Skip to content

Commit b23a989

Browse files
committed
support Flux.2 patched latents for proj preview
1 parent 58bdc8a commit b23a989

File tree

1 file changed

+123
-2
lines changed

1 file changed

+123
-2
lines changed

stable-diffusion.cpp

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,8 +1337,72 @@ class StableDiffusionGGML {
13371337
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
13381338

13391339
if (preview_mode == PREVIEW_PROJ) {
1340-
const float(*latent_rgb_proj)[channel] = nullptr;
1341-
float* latent_rgb_bias = nullptr;
1340+
int64_t patch_sz = 1;
1341+
if (sd_version_is_flux2(version)) {
1342+
patch_sz = 2;
1343+
}
1344+
if (patch_sz != 1) {
1345+
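// unshuffle latents: spread each group of patch_sz*patch_sz channels into a patch_sz x patch_sz spatial block (C shrinks by patch_sz^2, H and W grow by patch_sz)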
1346+
const int64_t N = latents->ne[3];
1347+
const int64_t C_in = latents->ne[2];
1348+
const int64_t H_in = latents->ne[1];
1349+
const int64_t W_in = latents->ne[0];
1350+
1351+
const int64_t C_out = C_in / (patch_sz * patch_sz);
1352+
const int64_t H_out = H_in * patch_sz;
1353+
const int64_t W_out = W_in * patch_sz;
1354+
1355+
const char* src_ptr = (char*)latents->data;
1356+
size_t elem_size = latents->nb[0];
1357+
1358+
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1359+
char* dst_base = dst_buffer.data();
1360+
1361+
size_t dst_stride_w = elem_size;
1362+
size_t dst_stride_h = dst_stride_w * W_out;
1363+
size_t dst_stride_c = dst_stride_h * H_out;
1364+
size_t dst_stride_n = dst_stride_c * C_out;
1365+
1366+
size_t dst_step_w = dst_stride_w * patch_sz;
1367+
size_t dst_step_h = dst_stride_h * patch_sz;
1368+
1369+
for (int64_t n = 0; n < N; ++n) {
1370+
for (int64_t c = 0; c < C_in; ++c) {
1371+
int64_t c_out = c / (patch_sz * patch_sz);
1372+
int64_t rem = c % (patch_sz * patch_sz);
1373+
int64_t py = rem / patch_sz;
1374+
int64_t px = rem % patch_sz;
1375+
1376+
char* dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1377+
1378+
for (int64_t y = 0; y < H_in; ++y) {
1379+
char* dst_row = dst_layer + y * dst_step_h;
1380+
1381+
for (int64_t x = 0; x < W_in; ++x) {
1382+
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);
1383+
src_ptr += elem_size;
1384+
}
1385+
}
1386+
}
1387+
}
1388+
1389+
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1390+
1391+
latents->ne[0] = W_out;
1392+
latents->ne[1] = H_out;
1393+
latents->ne[2] = C_out;
1394+
1395+
latents->nb[0] = dst_stride_w;
1396+
latents->nb[1] = dst_stride_h;
1397+
latents->nb[2] = dst_stride_c;
1398+
latents->nb[3] = dst_stride_n;
1399+
1400+
width = W_out;
1401+
height = H_out;
1402+
dim = C_out;
1403+
}
1404+
const float (*latent_rgb_proj)[channel] = nullptr;
1405+
float* latent_rgb_bias = nullptr;
13421406

13431407
if (dim == 48) {
13441408
if (sd_version_is_wan(version)) {
@@ -1408,6 +1472,63 @@ class StableDiffusionGGML {
14081472
step_callback(step, frames, images, is_noisy, step_callback_data);
14091473
free(data);
14101474
free(images);
1475+
1476+
if (patch_sz != 1) {
1477+
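// restore shuffled latents: pack each patch_sz x patch_sz spatial block back into patch_sz*patch_sz channels and reset ne/nb to the original patched layout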
1478+
const int64_t N = latents->ne[3];
1479+
const int64_t C_in = latents->ne[2];
1480+
const int64_t H_in = latents->ne[1];
1481+
const int64_t W_in = latents->ne[0];
1482+
1483+
const int64_t C_out = C_in * patch_sz * patch_sz;
1484+
const int64_t H_out = H_in / patch_sz;
1485+
const int64_t W_out = W_in / patch_sz;
1486+
1487+
const char* src_base = (char*)latents->data;
1488+
const size_t elem_size = latents->nb[0];
1489+
1490+
const size_t src_stride_w = latents->nb[0];
1491+
const size_t src_stride_h = latents->nb[1];
1492+
const size_t src_stride_c = latents->nb[2];
1493+
const size_t src_stride_n = latents->nb[3];
1494+
1495+
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1496+
char* dst_ptr = dst_buffer.data();
1497+
1498+
const size_t src_step_h = src_stride_h * patch_sz;
1499+
const size_t src_step_w = src_stride_w * patch_sz;
1500+
1501+
for (int64_t n = 0; n < N; ++n) {
1502+
for (int64_t c = 0; c < C_out; ++c) {
1503+
int64_t c_rem = c % (patch_sz * patch_sz);
1504+
int64_t c_in = c / (patch_sz * patch_sz);
1505+
int64_t py = c_rem / patch_sz;
1506+
int64_t px = c_rem % patch_sz;
1507+
1508+
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1509+
1510+
for (int64_t y = 0; y < H_out; ++y) {
1511+
const char* src_row = src_layer + y * src_step_h;
1512+
1513+
for (int64_t x = 0; x < W_out; ++x) {
1514+
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
1515+
dst_ptr += elem_size;
1516+
}
1517+
}
1518+
}
1519+
}
1520+
1521+
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1522+
1523+
latents->ne[0] = W_out;
1524+
latents->ne[1] = H_out;
1525+
latents->ne[2] = C_out;
1526+
1527+
latents->nb[0] = elem_size;
1528+
latents->nb[1] = latents->nb[0] * W_out;
1529+
latents->nb[2] = latents->nb[1] * H_out;
1530+
latents->nb[3] = latents->nb[2] * C_out;
1531+
}
14111532
} else {
14121533
if (preview_mode == PREVIEW_VAE) {
14131534
process_latent_out(latents);

0 commit comments

Comments
 (0)