@@ -1321,63 +1321,14 @@ class StableDiffusionGGML {
13211321 }
13221322 if (patch_sz != 1 ) {
13231323 // unshuffle latents
1324- const int64_t N = latents->ne [3 ];
1325- const int64_t C_in = latents->ne [2 ];
1326- const int64_t H_in = latents->ne [1 ];
1327- const int64_t W_in = latents->ne [0 ];
1324+ std::vector<char > dst_buffer (latents->nb [GGML_MAX_DIMS-1 ]);
1325+ char * dst_buf = dst_buffer.data ();
13281326
1329- const int64_t C_out = C_in / (patch_sz * patch_sz);
1330- const int64_t H_out = H_in * patch_sz;
1331- const int64_t W_out = W_in * patch_sz;
1327+ unpatchify_latents (latents, patch_sz, dst_buf);
13321328
1333- const char * src_ptr = (char *)latents->data ;
1334- size_t elem_size = latents->nb [0 ];
1335-
1336- std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1337- char * dst_base = dst_buffer.data ();
1338-
1339- size_t dst_stride_w = elem_size;
1340- size_t dst_stride_h = dst_stride_w * W_out;
1341- size_t dst_stride_c = dst_stride_h * H_out;
1342- size_t dst_stride_n = dst_stride_c * C_out;
1343-
1344- size_t dst_step_w = dst_stride_w * patch_sz;
1345- size_t dst_step_h = dst_stride_h * patch_sz;
1346-
1347- for (int64_t n = 0 ; n < N; ++n) {
1348- for (int64_t c = 0 ; c < C_in; ++c) {
1349- int64_t c_out = c / (patch_sz * patch_sz);
1350- int64_t rem = c % (patch_sz * patch_sz);
1351- int64_t py = rem / patch_sz;
1352- int64_t px = rem % patch_sz;
1353-
1354- char * dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1355-
1356- for (int64_t y = 0 ; y < H_in; ++y) {
1357- char * dst_row = dst_layer + y * dst_step_h;
1358-
1359- for (int64_t x = 0 ; x < W_in; ++x) {
1360- memcpy (dst_row + x * dst_step_w, src_ptr, elem_size);
1361- src_ptr += elem_size;
1362- }
1363- }
1364- }
1365- }
1366-
1367- memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1368-
1369- latents->ne [0 ] = W_out;
1370- latents->ne [1 ] = H_out;
1371- latents->ne [2 ] = C_out;
1372-
1373- latents->nb [0 ] = dst_stride_w;
1374- latents->nb [1 ] = dst_stride_h;
1375- latents->nb [2 ] = dst_stride_c;
1376- latents->nb [3 ] = dst_stride_n;
1377-
1378- width = W_out;
1379- height = H_out;
1380- dim = C_out;
1329+ width = latents->ne [0 ];
1330+ height = latents->ne [1 ];
1331+ dim = latents->ne [ggml_n_dims (latents) - 1 ];
13811332 }
13821333 const float (*latent_rgb_proj)[channel] = nullptr ;
13831334 float * latent_rgb_bias = nullptr ;
@@ -1453,59 +1404,10 @@ class StableDiffusionGGML {
14531404
14541405 if (patch_sz != 1 ) {
14551406 // restore shuffled latents
1456- const int64_t N = latents->ne [3 ];
1457- const int64_t C_in = latents->ne [2 ];
1458- const int64_t H_in = latents->ne [1 ];
1459- const int64_t W_in = latents->ne [0 ];
1460-
1461- const int64_t C_out = C_in * patch_sz * patch_sz;
1462- const int64_t H_out = H_in / patch_sz;
1463- const int64_t W_out = W_in / patch_sz;
1464-
1465- const char * src_base = (char *)latents->data ;
1466- const size_t elem_size = latents->nb [0 ];
1467-
1468- const size_t src_stride_w = latents->nb [0 ];
1469- const size_t src_stride_h = latents->nb [1 ];
1470- const size_t src_stride_c = latents->nb [2 ];
1471- const size_t src_stride_n = latents->nb [3 ];
1472-
1473- std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1474- char * dst_ptr = dst_buffer.data ();
1475-
1476- const size_t src_step_h = src_stride_h * patch_sz;
1477- const size_t src_step_w = src_stride_w * patch_sz;
1478-
1479- for (int64_t n = 0 ; n < N; ++n) {
1480- for (int64_t c = 0 ; c < C_out; ++c) {
1481- int64_t c_rem = c % (patch_sz * patch_sz);
1482- int64_t c_in = c / (patch_sz * patch_sz);
1483- int64_t py = c_rem / patch_sz;
1484- int64_t px = c_rem % patch_sz;
1485-
1486- const char * src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1487-
1488- for (int64_t y = 0 ; y < H_out; ++y) {
1489- const char * src_row = src_layer + y * src_step_h;
1490-
1491- for (int64_t x = 0 ; x < W_out; ++x) {
1492- memcpy (dst_ptr, src_row + x * src_step_w, elem_size);
1493- dst_ptr += elem_size;
1494- }
1495- }
1496- }
1497- }
1498-
1499- memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1500-
1501- latents->ne [0 ] = W_out;
1502- latents->ne [1 ] = H_out;
1503- latents->ne [2 ] = C_out;
1407+ std::vector<char > dst_buffer (latents->nb [GGML_MAX_DIMS-1 ]);
1408+ char * dst_buf = dst_buffer.data ();
15041409
1505- latents->nb [0 ] = elem_size;
1506- latents->nb [1 ] = latents->nb [0 ] * W_out;
1507- latents->nb [2 ] = latents->nb [1 ] * H_out;
1508- latents->nb [3 ] = latents->nb [2 ] * C_out;
1410+ repatchify_latents (latents, patch_sz, dst_buf);
15091411 }
15101412 } else {
15111413 if (preview_mode == PREVIEW_VAE) {
0 commit comments