@@ -1343,63 +1343,14 @@ class StableDiffusionGGML {
13431343 }
13441344 if (patch_sz != 1 ) {
13451345 // unshuffle latents
1346- const int64_t N = latents->ne [3 ];
1347- const int64_t C_in = latents->ne [2 ];
1348- const int64_t H_in = latents->ne [1 ];
1349- const int64_t W_in = latents->ne [0 ];
1346+ std::vector<char > dst_buffer (latents->nb [GGML_MAX_DIMS-1 ]);
1347+ char * dst_buf = dst_buffer.data ();
13501348
1351- const int64_t C_out = C_in / (patch_sz * patch_sz);
1352- const int64_t H_out = H_in * patch_sz;
1353- const int64_t W_out = W_in * patch_sz;
1349+ unpatchify_latents (latents, patch_sz, dst_buf);
13541350
1355- const char * src_ptr = (char *)latents->data ;
1356- size_t elem_size = latents->nb [0 ];
1357-
1358- std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1359- char * dst_base = dst_buffer.data ();
1360-
1361- size_t dst_stride_w = elem_size;
1362- size_t dst_stride_h = dst_stride_w * W_out;
1363- size_t dst_stride_c = dst_stride_h * H_out;
1364- size_t dst_stride_n = dst_stride_c * C_out;
1365-
1366- size_t dst_step_w = dst_stride_w * patch_sz;
1367- size_t dst_step_h = dst_stride_h * patch_sz;
1368-
1369- for (int64_t n = 0 ; n < N; ++n) {
1370- for (int64_t c = 0 ; c < C_in; ++c) {
1371- int64_t c_out = c / (patch_sz * patch_sz);
1372- int64_t rem = c % (patch_sz * patch_sz);
1373- int64_t py = rem / patch_sz;
1374- int64_t px = rem % patch_sz;
1375-
1376- char * dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1377-
1378- for (int64_t y = 0 ; y < H_in; ++y) {
1379- char * dst_row = dst_layer + y * dst_step_h;
1380-
1381- for (int64_t x = 0 ; x < W_in; ++x) {
1382- memcpy (dst_row + x * dst_step_w, src_ptr, elem_size);
1383- src_ptr += elem_size;
1384- }
1385- }
1386- }
1387- }
1388-
1389- memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1390-
1391- latents->ne [0 ] = W_out;
1392- latents->ne [1 ] = H_out;
1393- latents->ne [2 ] = C_out;
1394-
1395- latents->nb [0 ] = dst_stride_w;
1396- latents->nb [1 ] = dst_stride_h;
1397- latents->nb [2 ] = dst_stride_c;
1398- latents->nb [3 ] = dst_stride_n;
1399-
1400- width = W_out;
1401- height = H_out;
1402- dim = C_out;
1351+ width = latents->ne [0 ];
1352+ height = latents->ne [1 ];
1353+ dim = latents->ne [ggml_n_dims (latents) - 1 ];
14031354 }
14041355 const float (*latent_rgb_proj)[channel] = nullptr ;
14051356 float * latent_rgb_bias = nullptr ;
@@ -1475,59 +1426,10 @@ class StableDiffusionGGML {
14751426
14761427 if (patch_sz != 1 ) {
14771428 // restore shuffled latents
1478- const int64_t N = latents->ne [3 ];
1479- const int64_t C_in = latents->ne [2 ];
1480- const int64_t H_in = latents->ne [1 ];
1481- const int64_t W_in = latents->ne [0 ];
1482-
1483- const int64_t C_out = C_in * patch_sz * patch_sz;
1484- const int64_t H_out = H_in / patch_sz;
1485- const int64_t W_out = W_in / patch_sz;
1486-
1487- const char * src_base = (char *)latents->data ;
1488- const size_t elem_size = latents->nb [0 ];
1489-
1490- const size_t src_stride_w = latents->nb [0 ];
1491- const size_t src_stride_h = latents->nb [1 ];
1492- const size_t src_stride_c = latents->nb [2 ];
1493- const size_t src_stride_n = latents->nb [3 ];
1494-
1495- std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1496- char * dst_ptr = dst_buffer.data ();
1497-
1498- const size_t src_step_h = src_stride_h * patch_sz;
1499- const size_t src_step_w = src_stride_w * patch_sz;
1500-
1501- for (int64_t n = 0 ; n < N; ++n) {
1502- for (int64_t c = 0 ; c < C_out; ++c) {
1503- int64_t c_rem = c % (patch_sz * patch_sz);
1504- int64_t c_in = c / (patch_sz * patch_sz);
1505- int64_t py = c_rem / patch_sz;
1506- int64_t px = c_rem % patch_sz;
1507-
1508- const char * src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1509-
1510- for (int64_t y = 0 ; y < H_out; ++y) {
1511- const char * src_row = src_layer + y * src_step_h;
1512-
1513- for (int64_t x = 0 ; x < W_out; ++x) {
1514- memcpy (dst_ptr, src_row + x * src_step_w, elem_size);
1515- dst_ptr += elem_size;
1516- }
1517- }
1518- }
1519- }
1520-
1521- memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1522-
1523- latents->ne [0 ] = W_out;
1524- latents->ne [1 ] = H_out;
1525- latents->ne [2 ] = C_out;
1429+ std::vector<char > dst_buffer (latents->nb [GGML_MAX_DIMS-1 ]);
1430+ char * dst_buf = dst_buffer.data ();
15261431
1527- latents->nb [0 ] = elem_size;
1528- latents->nb [1 ] = latents->nb [0 ] * W_out;
1529- latents->nb [2 ] = latents->nb [1 ] * H_out;
1530- latents->nb [3 ] = latents->nb [2 ] * C_out;
1432+ repatchify_latents (latents, patch_sz, dst_buf);
15311433 }
15321434 } else {
15331435 if (preview_mode == PREVIEW_VAE) {
0 commit comments