Skip to content

Commit 01af891

Browse files
committed
support Flux.2 patched latents for proj preview
1 parent 28973fc commit 01af891

File tree

1 file changed

+123
-2
lines changed

1 file changed

+123
-2
lines changed

stable-diffusion.cpp

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,8 +1315,72 @@ class StableDiffusionGGML {
13151315
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
13161316

13171317
if (preview_mode == PREVIEW_PROJ) {
1318-
const float(*latent_rgb_proj)[channel] = nullptr;
1319-
float* latent_rgb_bias = nullptr;
1318+
int64_t patch_sz = 1;
1319+
if (sd_version_is_flux2(version)) {
1320+
patch_sz = 2;
1321+
}
1322+
if (patch_sz != 1) {
1323+
// unshuffle latents
1324+
const int64_t N = latents->ne[3];
1325+
const int64_t C_in = latents->ne[2];
1326+
const int64_t H_in = latents->ne[1];
1327+
const int64_t W_in = latents->ne[0];
1328+
1329+
const int64_t C_out = C_in / (patch_sz * patch_sz);
1330+
const int64_t H_out = H_in * patch_sz;
1331+
const int64_t W_out = W_in * patch_sz;
1332+
1333+
const char* src_ptr = (char*)latents->data;
1334+
size_t elem_size = latents->nb[0];
1335+
1336+
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1337+
char* dst_base = dst_buffer.data();
1338+
1339+
size_t dst_stride_w = elem_size;
1340+
size_t dst_stride_h = dst_stride_w * W_out;
1341+
size_t dst_stride_c = dst_stride_h * H_out;
1342+
size_t dst_stride_n = dst_stride_c * C_out;
1343+
1344+
size_t dst_step_w = dst_stride_w * patch_sz;
1345+
size_t dst_step_h = dst_stride_h * patch_sz;
1346+
1347+
for (int64_t n = 0; n < N; ++n) {
1348+
for (int64_t c = 0; c < C_in; ++c) {
1349+
int64_t c_out = c / (patch_sz * patch_sz);
1350+
int64_t rem = c % (patch_sz * patch_sz);
1351+
int64_t py = rem / patch_sz;
1352+
int64_t px = rem % patch_sz;
1353+
1354+
char* dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1355+
1356+
for (int64_t y = 0; y < H_in; ++y) {
1357+
char* dst_row = dst_layer + y * dst_step_h;
1358+
1359+
for (int64_t x = 0; x < W_in; ++x) {
1360+
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);
1361+
src_ptr += elem_size;
1362+
}
1363+
}
1364+
}
1365+
}
1366+
1367+
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1368+
1369+
latents->ne[0] = W_out;
1370+
latents->ne[1] = H_out;
1371+
latents->ne[2] = C_out;
1372+
1373+
latents->nb[0] = dst_stride_w;
1374+
latents->nb[1] = dst_stride_h;
1375+
latents->nb[2] = dst_stride_c;
1376+
latents->nb[3] = dst_stride_n;
1377+
1378+
width = W_out;
1379+
height = H_out;
1380+
dim = C_out;
1381+
}
1382+
const float (*latent_rgb_proj)[channel] = nullptr;
1383+
float* latent_rgb_bias = nullptr;
13201384

13211385
if (dim == 48) {
13221386
if (sd_version_is_wan(version)) {
@@ -1386,6 +1450,63 @@ class StableDiffusionGGML {
13861450
step_callback(step, frames, images, is_noisy);
13871451
free(data);
13881452
free(images);
1453+
1454+
if (patch_sz != 1) {
1455+
// restore shuffled latents
1456+
const int64_t N = latents->ne[3];
1457+
const int64_t C_in = latents->ne[2];
1458+
const int64_t H_in = latents->ne[1];
1459+
const int64_t W_in = latents->ne[0];
1460+
1461+
const int64_t C_out = C_in * patch_sz * patch_sz;
1462+
const int64_t H_out = H_in / patch_sz;
1463+
const int64_t W_out = W_in / patch_sz;
1464+
1465+
const char* src_base = (char*)latents->data;
1466+
const size_t elem_size = latents->nb[0];
1467+
1468+
const size_t src_stride_w = latents->nb[0];
1469+
const size_t src_stride_h = latents->nb[1];
1470+
const size_t src_stride_c = latents->nb[2];
1471+
const size_t src_stride_n = latents->nb[3];
1472+
1473+
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1474+
char* dst_ptr = dst_buffer.data();
1475+
1476+
const size_t src_step_h = src_stride_h * patch_sz;
1477+
const size_t src_step_w = src_stride_w * patch_sz;
1478+
1479+
for (int64_t n = 0; n < N; ++n) {
1480+
for (int64_t c = 0; c < C_out; ++c) {
1481+
int64_t c_rem = c % (patch_sz * patch_sz);
1482+
int64_t c_in = c / (patch_sz * patch_sz);
1483+
int64_t py = c_rem / patch_sz;
1484+
int64_t px = c_rem % patch_sz;
1485+
1486+
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1487+
1488+
for (int64_t y = 0; y < H_out; ++y) {
1489+
const char* src_row = src_layer + y * src_step_h;
1490+
1491+
for (int64_t x = 0; x < W_out; ++x) {
1492+
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
1493+
dst_ptr += elem_size;
1494+
}
1495+
}
1496+
}
1497+
}
1498+
1499+
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1500+
1501+
latents->ne[0] = W_out;
1502+
latents->ne[1] = H_out;
1503+
latents->ne[2] = C_out;
1504+
1505+
latents->nb[0] = elem_size;
1506+
latents->nb[1] = latents->nb[0] * W_out;
1507+
latents->nb[2] = latents->nb[1] * H_out;
1508+
latents->nb[3] = latents->nb[2] * C_out;
1509+
}
13891510
} else {
13901511
if (preview_mode == PREVIEW_VAE) {
13911512
process_latent_out(latents);

0 commit comments

Comments
 (0)