@@ -1337,8 +1337,72 @@ class StableDiffusionGGML {
13371337 uint32_t dim = latents->ne [ggml_n_dims (latents) - 1 ];
13381338
13391339 if (preview_mode == PREVIEW_PROJ) {
1340- const float (*latent_rgb_proj)[channel] = nullptr ;
1341- float * latent_rgb_bias = nullptr ;
1340+ int64_t patch_sz = 1 ;
1341+ if (sd_version_is_flux2 (version)) {
1342+ patch_sz = 2 ;
1343+ }
1344+ if (patch_sz != 1 ) {
1345+ // unshuffle latents
1346+ const int64_t N = latents->ne [3 ];
1347+ const int64_t C_in = latents->ne [2 ];
1348+ const int64_t H_in = latents->ne [1 ];
1349+ const int64_t W_in = latents->ne [0 ];
1350+
1351+ const int64_t C_out = C_in / (patch_sz * patch_sz);
1352+ const int64_t H_out = H_in * patch_sz;
1353+ const int64_t W_out = W_in * patch_sz;
1354+
1355+ const char * src_ptr = (char *)latents->data ;
1356+ size_t elem_size = latents->nb [0 ];
1357+
1358+ std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1359+ char * dst_base = dst_buffer.data ();
1360+
1361+ size_t dst_stride_w = elem_size;
1362+ size_t dst_stride_h = dst_stride_w * W_out;
1363+ size_t dst_stride_c = dst_stride_h * H_out;
1364+ size_t dst_stride_n = dst_stride_c * C_out;
1365+
1366+ size_t dst_step_w = dst_stride_w * patch_sz;
1367+ size_t dst_step_h = dst_stride_h * patch_sz;
1368+
1369+ for (int64_t n = 0 ; n < N; ++n) {
1370+ for (int64_t c = 0 ; c < C_in; ++c) {
1371+ int64_t c_out = c / (patch_sz * patch_sz);
1372+ int64_t rem = c % (patch_sz * patch_sz);
1373+ int64_t py = rem / patch_sz;
1374+ int64_t px = rem % patch_sz;
1375+
1376+ char * dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1377+
1378+ for (int64_t y = 0 ; y < H_in; ++y) {
1379+ char * dst_row = dst_layer + y * dst_step_h;
1380+
1381+ for (int64_t x = 0 ; x < W_in; ++x) {
1382+ memcpy (dst_row + x * dst_step_w, src_ptr, elem_size);
1383+ src_ptr += elem_size;
1384+ }
1385+ }
1386+ }
1387+ }
1388+
1389+ memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1390+
1391+ latents->ne [0 ] = W_out;
1392+ latents->ne [1 ] = H_out;
1393+ latents->ne [2 ] = C_out;
1394+
1395+ latents->nb [0 ] = dst_stride_w;
1396+ latents->nb [1 ] = dst_stride_h;
1397+ latents->nb [2 ] = dst_stride_c;
1398+ latents->nb [3 ] = dst_stride_n;
1399+
1400+ width = W_out;
1401+ height = H_out;
1402+ dim = C_out;
1403+ }
1404+ const float (*latent_rgb_proj)[channel] = nullptr ;
1405+ float * latent_rgb_bias = nullptr ;
13421406
13431407 if (dim == 48 ) {
13441408 if (sd_version_is_wan (version)) {
@@ -1408,6 +1472,63 @@ class StableDiffusionGGML {
14081472 step_callback (step, frames, images, is_noisy, step_callback_data);
14091473 free (data);
14101474 free (images);
1475+
1476+ if (patch_sz != 1 ) {
1477+ // restore shuffled latents
1478+ const int64_t N = latents->ne [3 ];
1479+ const int64_t C_in = latents->ne [2 ];
1480+ const int64_t H_in = latents->ne [1 ];
1481+ const int64_t W_in = latents->ne [0 ];
1482+
1483+ const int64_t C_out = C_in * patch_sz * patch_sz;
1484+ const int64_t H_out = H_in / patch_sz;
1485+ const int64_t W_out = W_in / patch_sz;
1486+
1487+ const char * src_base = (char *)latents->data ;
1488+ const size_t elem_size = latents->nb [0 ];
1489+
1490+ const size_t src_stride_w = latents->nb [0 ];
1491+ const size_t src_stride_h = latents->nb [1 ];
1492+ const size_t src_stride_c = latents->nb [2 ];
1493+ const size_t src_stride_n = latents->nb [3 ];
1494+
1495+ std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1496+ char * dst_ptr = dst_buffer.data ();
1497+
1498+ const size_t src_step_h = src_stride_h * patch_sz;
1499+ const size_t src_step_w = src_stride_w * patch_sz;
1500+
1501+ for (int64_t n = 0 ; n < N; ++n) {
1502+ for (int64_t c = 0 ; c < C_out; ++c) {
1503+ int64_t c_rem = c % (patch_sz * patch_sz);
1504+ int64_t c_in = c / (patch_sz * patch_sz);
1505+ int64_t py = c_rem / patch_sz;
1506+ int64_t px = c_rem % patch_sz;
1507+
1508+ const char * src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1509+
1510+ for (int64_t y = 0 ; y < H_out; ++y) {
1511+ const char * src_row = src_layer + y * src_step_h;
1512+
1513+ for (int64_t x = 0 ; x < W_out; ++x) {
1514+ memcpy (dst_ptr, src_row + x * src_step_w, elem_size);
1515+ dst_ptr += elem_size;
1516+ }
1517+ }
1518+ }
1519+ }
1520+
1521+ memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1522+
1523+ latents->ne [0 ] = W_out;
1524+ latents->ne [1 ] = H_out;
1525+ latents->ne [2 ] = C_out;
1526+
1527+ latents->nb [0 ] = elem_size;
1528+ latents->nb [1 ] = latents->nb [0 ] * W_out;
1529+ latents->nb [2 ] = latents->nb [1 ] * H_out;
1530+ latents->nb [3 ] = latents->nb [2 ] * C_out;
1531+ }
14111532 } else {
14121533 if (preview_mode == PREVIEW_VAE) {
14131534 process_latent_out (latents);
0 commit comments