Skip to content

Commit 615c6f3

Browse files
committed
move latent shuffle logic to latents-preview.h
1 parent 01af891 commit 615c6f3

File tree

2 files changed

+136
-107
lines changed

2 files changed

+136
-107
lines changed

latent-preview.h

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,133 @@ const float sd_latent_rgb_proj[4][3] = {
163163
{-0.178022f, -0.200862f, -0.678514f}};
164164
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
165165

166+
void unpatchify_latents(ggml_tensor* latents, int patch_size, char* dst_buf) {
167+
const int64_t N = latents->ne[3];
168+
const int64_t C_in = latents->ne[2];
169+
const int64_t H_in = latents->ne[1];
170+
const int64_t W_in = latents->ne[0];
171+
172+
const int64_t C_out = C_in / (patch_size * patch_size);
173+
const int64_t H_out = H_in * patch_size;
174+
const int64_t W_out = W_in * patch_size;
175+
176+
const char* src_ptr = (char*)latents->data;
177+
size_t elem_size = latents->nb[0];
178+
179+
bool alloc_dst_buf = dst_buf == nullptr;
180+
size_t dst_buf_size = latents->nb[3];
181+
if (alloc_dst_buf) {
182+
dst_buf = (char*)malloc(dst_buf_size);
183+
}
184+
185+
size_t dst_stride_w = elem_size;
186+
size_t dst_stride_h = dst_stride_w * W_out;
187+
size_t dst_stride_c = dst_stride_h * H_out;
188+
size_t dst_stride_n = dst_stride_c * C_out;
189+
190+
size_t dst_step_w = dst_stride_w * patch_size;
191+
size_t dst_step_h = dst_stride_h * patch_size;
192+
193+
for (int64_t n = 0; n < N; ++n) {
194+
for (int64_t c = 0; c < C_in; ++c) {
195+
int64_t c_out = c / (patch_size * patch_size);
196+
int64_t rem = c % (patch_size * patch_size);
197+
int64_t py = rem / patch_size;
198+
int64_t px = rem % patch_size;
199+
200+
char* dst_layer = dst_buf + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
201+
202+
for (int64_t y = 0; y < H_in; ++y) {
203+
char* dst_row = dst_layer + y * dst_step_h;
204+
205+
for (int64_t x = 0; x < W_in; ++x) {
206+
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);
207+
src_ptr += elem_size;
208+
}
209+
}
210+
}
211+
}
212+
213+
memcpy(latents->data, dst_buf, dst_buf_size);
214+
215+
latents->ne[0] = W_out;
216+
latents->ne[1] = H_out;
217+
latents->ne[2] = C_out;
218+
219+
latents->nb[0] = dst_stride_w;
220+
latents->nb[1] = dst_stride_h;
221+
latents->nb[2] = dst_stride_c;
222+
latents->nb[3] = dst_stride_n;
223+
if (alloc_dst_buf) {
224+
free(dst_buf);
225+
}
226+
}
227+
228+
void repatchify_latents(ggml_tensor* latents, int patch_size, char* dst_buf) {
229+
const int64_t N = latents->ne[3];
230+
const int64_t C_in = latents->ne[2];
231+
const int64_t H_in = latents->ne[1];
232+
const int64_t W_in = latents->ne[0];
233+
234+
const int64_t C_out = C_in * patch_size * patch_size;
235+
const int64_t H_out = H_in / patch_size;
236+
const int64_t W_out = W_in / patch_size;
237+
238+
const char* src_base = (char*)latents->data;
239+
const size_t elem_size = latents->nb[0];
240+
241+
const size_t src_stride_w = latents->nb[0];
242+
const size_t src_stride_h = latents->nb[1];
243+
const size_t src_stride_c = latents->nb[2];
244+
const size_t src_stride_n = latents->nb[3];
245+
246+
bool alloc_dst_buf = dst_buf == nullptr;
247+
size_t dst_buf_size = src_stride_n;
248+
if (alloc_dst_buf) {
249+
dst_buf = (char*)malloc(dst_buf_size);
250+
}
251+
252+
char* dst_ptr = dst_buf;
253+
254+
const size_t src_step_h = src_stride_h * patch_size;
255+
const size_t src_step_w = src_stride_w * patch_size;
256+
257+
for (int64_t n = 0; n < N; ++n) {
258+
for (int64_t c = 0; c < C_out; ++c) {
259+
int64_t c_rem = c % (patch_size * patch_size);
260+
int64_t c_in = c / (patch_size * patch_size);
261+
int64_t py = c_rem / patch_size;
262+
int64_t px = c_rem % patch_size;
263+
264+
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
265+
266+
for (int64_t y = 0; y < H_out; ++y) {
267+
const char* src_row = src_layer + y * src_step_h;
268+
269+
for (int64_t x = 0; x < W_out; ++x) {
270+
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
271+
dst_ptr += elem_size;
272+
}
273+
}
274+
}
275+
}
276+
277+
memcpy(latents->data, dst_buf, dst_buf_size);
278+
279+
latents->ne[0] = W_out;
280+
latents->ne[1] = H_out;
281+
latents->ne[2] = C_out;
282+
283+
latents->nb[0] = elem_size;
284+
latents->nb[1] = latents->nb[0] * W_out;
285+
latents->nb[2] = latents->nb[1] * H_out;
286+
latents->nb[3] = latents->nb[2] * C_out;
287+
288+
if (alloc_dst_buf) {
289+
free(dst_buf);
290+
}
291+
}
292+
166293
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {
167294
size_t buffer_head = 0;
168295
for (int k = 0; k < frames; k++) {

stable-diffusion.cpp

Lines changed: 9 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,63 +1321,14 @@ class StableDiffusionGGML {
13211321
}
13221322
if (patch_sz != 1) {
13231323
// unshuffle latents
1324-
const int64_t N = latents->ne[3];
1325-
const int64_t C_in = latents->ne[2];
1326-
const int64_t H_in = latents->ne[1];
1327-
const int64_t W_in = latents->ne[0];
1324+
std::vector<char> dst_buffer(latents->nb[GGML_MAX_DIMS-1]);
1325+
char* dst_buf = dst_buffer.data();
13281326

1329-
const int64_t C_out = C_in / (patch_sz * patch_sz);
1330-
const int64_t H_out = H_in * patch_sz;
1331-
const int64_t W_out = W_in * patch_sz;
1327+
unpatchify_latents(latents, patch_sz, dst_buf);
13321328

1333-
const char* src_ptr = (char*)latents->data;
1334-
size_t elem_size = latents->nb[0];
1335-
1336-
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1337-
char* dst_base = dst_buffer.data();
1338-
1339-
size_t dst_stride_w = elem_size;
1340-
size_t dst_stride_h = dst_stride_w * W_out;
1341-
size_t dst_stride_c = dst_stride_h * H_out;
1342-
size_t dst_stride_n = dst_stride_c * C_out;
1343-
1344-
size_t dst_step_w = dst_stride_w * patch_sz;
1345-
size_t dst_step_h = dst_stride_h * patch_sz;
1346-
1347-
for (int64_t n = 0; n < N; ++n) {
1348-
for (int64_t c = 0; c < C_in; ++c) {
1349-
int64_t c_out = c / (patch_sz * patch_sz);
1350-
int64_t rem = c % (patch_sz * patch_sz);
1351-
int64_t py = rem / patch_sz;
1352-
int64_t px = rem % patch_sz;
1353-
1354-
char* dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1355-
1356-
for (int64_t y = 0; y < H_in; ++y) {
1357-
char* dst_row = dst_layer + y * dst_step_h;
1358-
1359-
for (int64_t x = 0; x < W_in; ++x) {
1360-
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);
1361-
src_ptr += elem_size;
1362-
}
1363-
}
1364-
}
1365-
}
1366-
1367-
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1368-
1369-
latents->ne[0] = W_out;
1370-
latents->ne[1] = H_out;
1371-
latents->ne[2] = C_out;
1372-
1373-
latents->nb[0] = dst_stride_w;
1374-
latents->nb[1] = dst_stride_h;
1375-
latents->nb[2] = dst_stride_c;
1376-
latents->nb[3] = dst_stride_n;
1377-
1378-
width = W_out;
1379-
height = H_out;
1380-
dim = C_out;
1329+
width = latents->ne[0];
1330+
height = latents->ne[1];
1331+
dim = latents->ne[ggml_n_dims(latents) - 1];
13811332
}
13821333
const float (*latent_rgb_proj)[channel] = nullptr;
13831334
float* latent_rgb_bias = nullptr;
@@ -1453,59 +1404,10 @@ class StableDiffusionGGML {
14531404

14541405
if (patch_sz != 1) {
14551406
// restore shuffled latents
1456-
const int64_t N = latents->ne[3];
1457-
const int64_t C_in = latents->ne[2];
1458-
const int64_t H_in = latents->ne[1];
1459-
const int64_t W_in = latents->ne[0];
1460-
1461-
const int64_t C_out = C_in * patch_sz * patch_sz;
1462-
const int64_t H_out = H_in / patch_sz;
1463-
const int64_t W_out = W_in / patch_sz;
1464-
1465-
const char* src_base = (char*)latents->data;
1466-
const size_t elem_size = latents->nb[0];
1467-
1468-
const size_t src_stride_w = latents->nb[0];
1469-
const size_t src_stride_h = latents->nb[1];
1470-
const size_t src_stride_c = latents->nb[2];
1471-
const size_t src_stride_n = latents->nb[3];
1472-
1473-
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1474-
char* dst_ptr = dst_buffer.data();
1475-
1476-
const size_t src_step_h = src_stride_h * patch_sz;
1477-
const size_t src_step_w = src_stride_w * patch_sz;
1478-
1479-
for (int64_t n = 0; n < N; ++n) {
1480-
for (int64_t c = 0; c < C_out; ++c) {
1481-
int64_t c_rem = c % (patch_sz * patch_sz);
1482-
int64_t c_in = c / (patch_sz * patch_sz);
1483-
int64_t py = c_rem / patch_sz;
1484-
int64_t px = c_rem % patch_sz;
1485-
1486-
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1487-
1488-
for (int64_t y = 0; y < H_out; ++y) {
1489-
const char* src_row = src_layer + y * src_step_h;
1490-
1491-
for (int64_t x = 0; x < W_out; ++x) {
1492-
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
1493-
dst_ptr += elem_size;
1494-
}
1495-
}
1496-
}
1497-
}
1498-
1499-
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1500-
1501-
latents->ne[0] = W_out;
1502-
latents->ne[1] = H_out;
1503-
latents->ne[2] = C_out;
1407+
std::vector<char> dst_buffer(latents->nb[GGML_MAX_DIMS-1]);
1408+
char* dst_buf = dst_buffer.data();
15041409

1505-
latents->nb[0] = elem_size;
1506-
latents->nb[1] = latents->nb[0] * W_out;
1507-
latents->nb[2] = latents->nb[1] * H_out;
1508-
latents->nb[3] = latents->nb[2] * C_out;
1410+
repatchify_latents(latents, patch_sz, dst_buf);
15091411
}
15101412
} else {
15111413
if (preview_mode == PREVIEW_VAE) {

0 commit comments

Comments
 (0)