Skip to content

Commit 5ccb8d0

Browse files
committed
move latent shuffle logic to latents-preview.h
1 parent b23a989 commit 5ccb8d0

File tree

2 files changed

+136
-107
lines changed

2 files changed

+136
-107
lines changed

latent-preview.h

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,133 @@ const float sd_latent_rgb_proj[4][3] = {
163163
{-0.178022f, -0.200862f, -0.678514f}};
164164
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
165165

166+
void unpatchify_latents(ggml_tensor* latents, int patch_size, char* dst_buf) {
167+
const int64_t N = latents->ne[3];
168+
const int64_t C_in = latents->ne[2];
169+
const int64_t H_in = latents->ne[1];
170+
const int64_t W_in = latents->ne[0];
171+
172+
const int64_t C_out = C_in / (patch_size * patch_size);
173+
const int64_t H_out = H_in * patch_size;
174+
const int64_t W_out = W_in * patch_size;
175+
176+
const char* src_ptr = (char*)latents->data;
177+
size_t elem_size = latents->nb[0];
178+
179+
bool alloc_dst_buf = dst_buf == nullptr;
180+
size_t dst_buf_size = latents->nb[3];
181+
if (alloc_dst_buf) {
182+
dst_buf = (char*)malloc(dst_buf_size);
183+
}
184+
185+
size_t dst_stride_w = elem_size;
186+
size_t dst_stride_h = dst_stride_w * W_out;
187+
size_t dst_stride_c = dst_stride_h * H_out;
188+
size_t dst_stride_n = dst_stride_c * C_out;
189+
190+
size_t dst_step_w = dst_stride_w * patch_size;
191+
size_t dst_step_h = dst_stride_h * patch_size;
192+
193+
for (int64_t n = 0; n < N; ++n) {
194+
for (int64_t c = 0; c < C_in; ++c) {
195+
int64_t c_out = c / (patch_size * patch_size);
196+
int64_t rem = c % (patch_size * patch_size);
197+
int64_t py = rem / patch_size;
198+
int64_t px = rem % patch_size;
199+
200+
char* dst_layer = dst_buf + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
201+
202+
for (int64_t y = 0; y < H_in; ++y) {
203+
char* dst_row = dst_layer + y * dst_step_h;
204+
205+
for (int64_t x = 0; x < W_in; ++x) {
206+
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);
207+
src_ptr += elem_size;
208+
}
209+
}
210+
}
211+
}
212+
213+
memcpy(latents->data, dst_buf, dst_buf_size);
214+
215+
latents->ne[0] = W_out;
216+
latents->ne[1] = H_out;
217+
latents->ne[2] = C_out;
218+
219+
latents->nb[0] = dst_stride_w;
220+
latents->nb[1] = dst_stride_h;
221+
latents->nb[2] = dst_stride_c;
222+
latents->nb[3] = dst_stride_n;
223+
if (alloc_dst_buf) {
224+
free(dst_buf);
225+
}
226+
}
227+
228+
void repatchify_latents(ggml_tensor* latents, int patch_size, char* dst_buf) {
229+
const int64_t N = latents->ne[3];
230+
const int64_t C_in = latents->ne[2];
231+
const int64_t H_in = latents->ne[1];
232+
const int64_t W_in = latents->ne[0];
233+
234+
const int64_t C_out = C_in * patch_size * patch_size;
235+
const int64_t H_out = H_in / patch_size;
236+
const int64_t W_out = W_in / patch_size;
237+
238+
const char* src_base = (char*)latents->data;
239+
const size_t elem_size = latents->nb[0];
240+
241+
const size_t src_stride_w = latents->nb[0];
242+
const size_t src_stride_h = latents->nb[1];
243+
const size_t src_stride_c = latents->nb[2];
244+
const size_t src_stride_n = latents->nb[3];
245+
246+
bool alloc_dst_buf = dst_buf == nullptr;
247+
size_t dst_buf_size = src_stride_n;
248+
if (alloc_dst_buf) {
249+
dst_buf = (char*)malloc(dst_buf_size);
250+
}
251+
252+
char* dst_ptr = dst_buf;
253+
254+
const size_t src_step_h = src_stride_h * patch_size;
255+
const size_t src_step_w = src_stride_w * patch_size;
256+
257+
for (int64_t n = 0; n < N; ++n) {
258+
for (int64_t c = 0; c < C_out; ++c) {
259+
int64_t c_rem = c % (patch_size * patch_size);
260+
int64_t c_in = c / (patch_size * patch_size);
261+
int64_t py = c_rem / patch_size;
262+
int64_t px = c_rem % patch_size;
263+
264+
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
265+
266+
for (int64_t y = 0; y < H_out; ++y) {
267+
const char* src_row = src_layer + y * src_step_h;
268+
269+
for (int64_t x = 0; x < W_out; ++x) {
270+
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
271+
dst_ptr += elem_size;
272+
}
273+
}
274+
}
275+
}
276+
277+
memcpy(latents->data, dst_buf, dst_buf_size);
278+
279+
latents->ne[0] = W_out;
280+
latents->ne[1] = H_out;
281+
latents->ne[2] = C_out;
282+
283+
latents->nb[0] = elem_size;
284+
latents->nb[1] = latents->nb[0] * W_out;
285+
latents->nb[2] = latents->nb[1] * H_out;
286+
latents->nb[3] = latents->nb[2] * C_out;
287+
288+
if (alloc_dst_buf) {
289+
free(dst_buf);
290+
}
291+
}
292+
166293
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {
167294
size_t buffer_head = 0;
168295
for (int k = 0; k < frames; k++) {

stable-diffusion.cpp

Lines changed: 9 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,63 +1343,14 @@ class StableDiffusionGGML {
13431343
}
13441344
if (patch_sz != 1) {
13451345
// unshuffle latents
1346-
const int64_t N = latents->ne[3];
1347-
const int64_t C_in = latents->ne[2];
1348-
const int64_t H_in = latents->ne[1];
1349-
const int64_t W_in = latents->ne[0];
1346+
std::vector<char> dst_buffer(latents->nb[GGML_MAX_DIMS-1]);
1347+
char* dst_buf = dst_buffer.data();
13501348

1351-
const int64_t C_out = C_in / (patch_sz * patch_sz);
1352-
const int64_t H_out = H_in * patch_sz;
1353-
const int64_t W_out = W_in * patch_sz;
1349+
unpatchify_latents(latents, patch_sz, dst_buf);
13541350

1355-
const char* src_ptr = (char*)latents->data;
1356-
size_t elem_size = latents->nb[0];
1357-
1358-
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1359-
char* dst_base = dst_buffer.data();
1360-
1361-
size_t dst_stride_w = elem_size;
1362-
size_t dst_stride_h = dst_stride_w * W_out;
1363-
size_t dst_stride_c = dst_stride_h * H_out;
1364-
size_t dst_stride_n = dst_stride_c * C_out;
1365-
1366-
size_t dst_step_w = dst_stride_w * patch_sz;
1367-
size_t dst_step_h = dst_stride_h * patch_sz;
1368-
1369-
for (int64_t n = 0; n < N; ++n) {
1370-
for (int64_t c = 0; c < C_in; ++c) {
1371-
int64_t c_out = c / (patch_sz * patch_sz);
1372-
int64_t rem = c % (patch_sz * patch_sz);
1373-
int64_t py = rem / patch_sz;
1374-
int64_t px = rem % patch_sz;
1375-
1376-
char* dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1377-
1378-
for (int64_t y = 0; y < H_in; ++y) {
1379-
char* dst_row = dst_layer + y * dst_step_h;
1380-
1381-
for (int64_t x = 0; x < W_in; ++x) {
1382-
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);
1383-
src_ptr += elem_size;
1384-
}
1385-
}
1386-
}
1387-
}
1388-
1389-
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1390-
1391-
latents->ne[0] = W_out;
1392-
latents->ne[1] = H_out;
1393-
latents->ne[2] = C_out;
1394-
1395-
latents->nb[0] = dst_stride_w;
1396-
latents->nb[1] = dst_stride_h;
1397-
latents->nb[2] = dst_stride_c;
1398-
latents->nb[3] = dst_stride_n;
1399-
1400-
width = W_out;
1401-
height = H_out;
1402-
dim = C_out;
1351+
width = latents->ne[0];
1352+
height = latents->ne[1];
1353+
dim = latents->ne[ggml_n_dims(latents) - 1];
14031354
}
14041355
const float (*latent_rgb_proj)[channel] = nullptr;
14051356
float* latent_rgb_bias = nullptr;
@@ -1475,59 +1426,10 @@ class StableDiffusionGGML {
14751426

14761427
if (patch_sz != 1) {
14771428
// restore shuffled latents
1478-
const int64_t N = latents->ne[3];
1479-
const int64_t C_in = latents->ne[2];
1480-
const int64_t H_in = latents->ne[1];
1481-
const int64_t W_in = latents->ne[0];
1482-
1483-
const int64_t C_out = C_in * patch_sz * patch_sz;
1484-
const int64_t H_out = H_in / patch_sz;
1485-
const int64_t W_out = W_in / patch_sz;
1486-
1487-
const char* src_base = (char*)latents->data;
1488-
const size_t elem_size = latents->nb[0];
1489-
1490-
const size_t src_stride_w = latents->nb[0];
1491-
const size_t src_stride_h = latents->nb[1];
1492-
const size_t src_stride_c = latents->nb[2];
1493-
const size_t src_stride_n = latents->nb[3];
1494-
1495-
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1496-
char* dst_ptr = dst_buffer.data();
1497-
1498-
const size_t src_step_h = src_stride_h * patch_sz;
1499-
const size_t src_step_w = src_stride_w * patch_sz;
1500-
1501-
for (int64_t n = 0; n < N; ++n) {
1502-
for (int64_t c = 0; c < C_out; ++c) {
1503-
int64_t c_rem = c % (patch_sz * patch_sz);
1504-
int64_t c_in = c / (patch_sz * patch_sz);
1505-
int64_t py = c_rem / patch_sz;
1506-
int64_t px = c_rem % patch_sz;
1507-
1508-
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1509-
1510-
for (int64_t y = 0; y < H_out; ++y) {
1511-
const char* src_row = src_layer + y * src_step_h;
1512-
1513-
for (int64_t x = 0; x < W_out; ++x) {
1514-
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
1515-
dst_ptr += elem_size;
1516-
}
1517-
}
1518-
}
1519-
}
1520-
1521-
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1522-
1523-
latents->ne[0] = W_out;
1524-
latents->ne[1] = H_out;
1525-
latents->ne[2] = C_out;
1429+
std::vector<char> dst_buffer(latents->nb[GGML_MAX_DIMS-1]);
1430+
char* dst_buf = dst_buffer.data();
15261431

1527-
latents->nb[0] = elem_size;
1528-
latents->nb[1] = latents->nb[0] * W_out;
1529-
latents->nb[2] = latents->nb[1] * H_out;
1530-
latents->nb[3] = latents->nb[2] * C_out;
1432+
repatchify_latents(latents, patch_sz, dst_buf);
15311433
}
15321434
} else {
15331435
if (preview_mode == PREVIEW_VAE) {

0 commit comments

Comments
 (0)