@@ -1315,8 +1315,72 @@ class StableDiffusionGGML {
13151315 uint32_t dim = latents->ne [ggml_n_dims (latents) - 1 ];
13161316
13171317 if (preview_mode == PREVIEW_PROJ) {
1318- const float (*latent_rgb_proj)[channel] = nullptr ;
1319- float * latent_rgb_bias = nullptr ;
1318+ int64_t patch_sz = 1 ;
1319+ if (sd_version_is_flux2 (version)) {
1320+ patch_sz = 2 ;
1321+ }
1322+ if (patch_sz != 1 ) {
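1323+ // unshuffle latents: spread each run of patch_sz*patch_sz channels over patch_sz x patch_sz pixel blocks (C shrinks, H and W grow), staged in dst_buffer and copied back over latents->data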
1324+ const int64_t N = latents->ne [3 ];
1325+ const int64_t C_in = latents->ne [2 ];
1326+ const int64_t H_in = latents->ne [1 ];
1327+ const int64_t W_in = latents->ne [0 ];
1328+
1329+ const int64_t C_out = C_in / (patch_sz * patch_sz);
1330+ const int64_t H_out = H_in * patch_sz;
1331+ const int64_t W_out = W_in * patch_sz;
1332+
1333+ const char * src_ptr = (char *)latents->data ;
1334+ size_t elem_size = latents->nb [0 ];
1335+
1336+ std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1337+ char * dst_base = dst_buffer.data ();
1338+
1339+ size_t dst_stride_w = elem_size;
1340+ size_t dst_stride_h = dst_stride_w * W_out;
1341+ size_t dst_stride_c = dst_stride_h * H_out;
1342+ size_t dst_stride_n = dst_stride_c * C_out;
1343+
1344+ size_t dst_step_w = dst_stride_w * patch_sz;
1345+ size_t dst_step_h = dst_stride_h * patch_sz;
1346+
1347+ for (int64_t n = 0 ; n < N; ++n) {
1348+ for (int64_t c = 0 ; c < C_in; ++c) {
1349+ int64_t c_out = c / (patch_sz * patch_sz);
1350+ int64_t rem = c % (patch_sz * patch_sz);
1351+ int64_t py = rem / patch_sz;
1352+ int64_t px = rem % patch_sz;
1353+
1354+ char * dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1355+
1356+ for (int64_t y = 0 ; y < H_in; ++y) {
1357+ char * dst_row = dst_layer + y * dst_step_h;
1358+
1359+ for (int64_t x = 0 ; x < W_in; ++x) {
1360+ memcpy (dst_row + x * dst_step_w, src_ptr, elem_size);
1361+ src_ptr += elem_size;
1362+ }
1363+ }
1364+ }
1365+ }
1366+
1367+ memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1368+
1369+ latents->ne [0 ] = W_out;
1370+ latents->ne [1 ] = H_out;
1371+ latents->ne [2 ] = C_out;
1372+
1373+ latents->nb [0 ] = dst_stride_w;
1374+ latents->nb [1 ] = dst_stride_h;
1375+ latents->nb [2 ] = dst_stride_c;
1376+ latents->nb [3 ] = dst_stride_n;
1377+
1378+ width = W_out;
1379+ height = H_out;
1380+ dim = C_out;
1381+ }
1382+ const float (*latent_rgb_proj)[channel] = nullptr ;
1383+ float * latent_rgb_bias = nullptr ;
13201384
13211385 if (dim == 48 ) {
13221386 if (sd_version_is_wan (version)) {
@@ -1386,6 +1450,63 @@ class StableDiffusionGGML {
13861450 step_callback (step, frames, images, is_noisy);
13871451 free (data);
13881452 free (images);
1453+
1454+ if (patch_sz != 1 ) {
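1455+ // restore shuffled latents: fold each patch_sz x patch_sz pixel block back into patch_sz*patch_sz channels, copy the result over latents->data and reset ne/nb to the packed shape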
1456+ const int64_t N = latents->ne [3 ];
1457+ const int64_t C_in = latents->ne [2 ];
1458+ const int64_t H_in = latents->ne [1 ];
1459+ const int64_t W_in = latents->ne [0 ];
1460+
1461+ const int64_t C_out = C_in * patch_sz * patch_sz;
1462+ const int64_t H_out = H_in / patch_sz;
1463+ const int64_t W_out = W_in / patch_sz;
1464+
1465+ const char * src_base = (char *)latents->data ;
1466+ const size_t elem_size = latents->nb [0 ];
1467+
1468+ const size_t src_stride_w = latents->nb [0 ];
1469+ const size_t src_stride_h = latents->nb [1 ];
1470+ const size_t src_stride_c = latents->nb [2 ];
1471+ const size_t src_stride_n = latents->nb [3 ];
1472+
1473+ std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1474+ char * dst_ptr = dst_buffer.data ();
1475+
1476+ const size_t src_step_h = src_stride_h * patch_sz;
1477+ const size_t src_step_w = src_stride_w * patch_sz;
1478+
1479+ for (int64_t n = 0 ; n < N; ++n) {
1480+ for (int64_t c = 0 ; c < C_out; ++c) {
1481+ int64_t c_rem = c % (patch_sz * patch_sz);
1482+ int64_t c_in = c / (patch_sz * patch_sz);
1483+ int64_t py = c_rem / patch_sz;
1484+ int64_t px = c_rem % patch_sz;
1485+
1486+ const char * src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1487+
1488+ for (int64_t y = 0 ; y < H_out; ++y) {
1489+ const char * src_row = src_layer + y * src_step_h;
1490+
1491+ for (int64_t x = 0 ; x < W_out; ++x) {
1492+ memcpy (dst_ptr, src_row + x * src_step_w, elem_size);
1493+ dst_ptr += elem_size;
1494+ }
1495+ }
1496+ }
1497+ }
1498+
1499+ memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1500+
1501+ latents->ne [0 ] = W_out;
1502+ latents->ne [1 ] = H_out;
1503+ latents->ne [2 ] = C_out;
1504+
1505+ latents->nb [0 ] = elem_size;
1506+ latents->nb [1 ] = latents->nb [0 ] * W_out;
1507+ latents->nb [2 ] = latents->nb [1 ] * H_out;
1508+ latents->nb [3 ] = latents->nb [2 ] * C_out;
1509+ }
13891510 } else {
13901511 if (preview_mode == PREVIEW_VAE) {
13911512 process_latent_out (latents);
0 commit comments