Skip to content

Commit 56f134e

Browse files
committed
refactor preview_latent_video to support flux.2 patchified latents
1 parent 5ccb8d0 commit 56f134e

File tree

2 files changed

+86
-199
lines changed

2 files changed

+86
-199
lines changed

latent-preview.h

Lines changed: 28 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -163,143 +163,43 @@ const float sd_latent_rgb_proj[4][3] = {
163163
{-0.178022f, -0.200862f, -0.678514f}};
164164
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
165165

166-
/**
 * Rearranges a patchified latent tensor in place so that every group of
 * patch_size * patch_size input channels is spread out over a
 * patch_size x patch_size spatial block of one output channel.
 *
 * Dimensions are read as ne[0]=W, ne[1]=H, ne[2]=C, ne[3]=N. On return the
 * tensor's ne[0..2] are (W * patch_size, H * patch_size,
 * C / (patch_size * patch_size)) and nb[0..3] are rewritten to the matching
 * densely packed byte strides; the data pointer itself is unchanged.
 *
 * @param latents     Tensor to rewrite; its data, ne and nb fields are modified.
 * @param patch_size  Side length of one patch.
 * @param dst_buf     Scratch buffer used to build the result before it is copied
 *                    back into latents->data. When nullptr, a buffer of
 *                    latents->nb[3] bytes is malloc'ed here and freed before
 *                    returning; otherwise the caller's buffer is used and must
 *                    hold at least latents->nb[3] bytes.
 *
 * NOTE(review): the scratch size is nb[3] but writes are offset by
 * n * dst_stride_n, so N > 1 looks like it would run past the buffer --
 * confirm callers only pass N == 1.
 * NOTE(review): the malloc result is not checked before use.
 */
void unpatchify_latents(ggml_tensor* latents, int patch_size, char* dst_buf) {
167-
const int64_t N = latents->ne[3];
168-
const int64_t C_in = latents->ne[2];
169-
const int64_t H_in = latents->ne[1];
170-
const int64_t W_in = latents->ne[0];
171166

172-
const int64_t C_out = C_in / (patch_size * patch_size);  // integer division; presumably C_in is a multiple of patch_size^2 -- TODO confirm
173-
const int64_t H_out = H_in * patch_size;
174-
const int64_t W_out = W_in * patch_size;
175-
176-
const char* src_ptr = (char*)latents->data;  // advanced one element at a time in n/c/y/x loop order
177-
size_t elem_size = latents->nb[0];  // byte stride of dimension 0 is used as the element size
178-
179-
bool alloc_dst_buf = dst_buf == nullptr;  // remember whether the scratch buffer is owned by this function
180-
size_t dst_buf_size = latents->nb[3];
181-
if (alloc_dst_buf) {
182-
dst_buf = (char*)malloc(dst_buf_size);
183-
}
184-
185-
// Byte strides of the densely packed output layout.
size_t dst_stride_w = elem_size;
186-
size_t dst_stride_h = dst_stride_w * W_out;
187-
size_t dst_stride_c = dst_stride_h * H_out;
188-
size_t dst_stride_n = dst_stride_c * C_out;
189-
190-
// Moving by one input pixel moves by a whole patch in the output.
size_t dst_step_w = dst_stride_w * patch_size;
191-
size_t dst_step_h = dst_stride_h * patch_size;
192-
193-
for (int64_t n = 0; n < N; ++n) {
194-
for (int64_t c = 0; c < C_in; ++c) {
195-
int64_t c_out = c / (patch_size * patch_size);  // output channel this input channel belongs to
196-
int64_t rem = c % (patch_size * patch_size);  // index of this channel within its patch
197-
int64_t py = rem / patch_size;  // row inside the patch
198-
int64_t px = rem % patch_size;  // column inside the patch
199-
200-
// Address of output element (px, py) in channel c_out of batch n; the loops below step from there in whole patches.
char* dst_layer = dst_buf + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
201-
202-
for (int64_t y = 0; y < H_in; ++y) {
203-
char* dst_row = dst_layer + y * dst_step_h;
204-
205-
for (int64_t x = 0; x < W_in; ++x) {
206-
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);  // input (x, y, c) -> output (x * patch_size + px, y * patch_size + py, c_out)
207-
src_ptr += elem_size;
208-
}
209-
}
210-
}
211-
}
212-
213-
memcpy(latents->data, dst_buf, dst_buf_size);  // overwrite the tensor data with the rearranged copy
214-
215-
// Publish the new shape ...
latents->ne[0] = W_out;
216-
latents->ne[1] = H_out;
217-
latents->ne[2] = C_out;
218-
219-
// ... and the matching byte strides.
latents->nb[0] = dst_stride_w;
220-
latents->nb[1] = dst_stride_h;
221-
latents->nb[2] = dst_stride_c;
222-
latents->nb[3] = dst_stride_n;
223-
if (alloc_dst_buf) {
224-
free(dst_buf);  // only release the buffer when it was allocated here
225-
}
226-
}
227-
228-
void repatchify_latents(ggml_tensor* latents, int patch_size, char* dst_buf) {
229-
const int64_t N = latents->ne[3];
230-
const int64_t C_in = latents->ne[2];
231-
const int64_t H_in = latents->ne[1];
232-
const int64_t W_in = latents->ne[0];
233-
234-
const int64_t C_out = C_in * patch_size * patch_size;
235-
const int64_t H_out = H_in / patch_size;
236-
const int64_t W_out = W_in / patch_size;
237-
238-
const char* src_base = (char*)latents->data;
239-
const size_t elem_size = latents->nb[0];
240-
241-
const size_t src_stride_w = latents->nb[0];
242-
const size_t src_stride_h = latents->nb[1];
243-
const size_t src_stride_c = latents->nb[2];
244-
const size_t src_stride_n = latents->nb[3];
167+
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) {
168+
size_t buffer_head = 0;
245169

246-
bool alloc_dst_buf = dst_buf == nullptr;
247-
size_t dst_buf_size = src_stride_n;
248-
if (alloc_dst_buf) {
249-
dst_buf = (char*)malloc(dst_buf_size);
170+
uint32_t latent_width = latents->ne[0];
171+
uint32_t latent_height = latents->ne[1];
172+
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
173+
uint32_t frames = 1;
174+
if (ggml_n_dims(latents) == 4) {
175+
frames = latents->ne[2];
250176
}
251177

252-
char* dst_ptr = dst_buf;
253-
254-
const size_t src_step_h = src_stride_h * patch_size;
255-
const size_t src_step_w = src_stride_w * patch_size;
178+
uint32_t rgb_width = latent_width * patch_size;
179+
uint32_t rgb_height = latent_height * patch_size;
256180

257-
for (int64_t n = 0; n < N; ++n) {
258-
for (int64_t c = 0; c < C_out; ++c) {
259-
int64_t c_rem = c % (patch_size * patch_size);
260-
int64_t c_in = c / (patch_size * patch_size);
261-
int64_t py = c_rem / patch_size;
262-
int64_t px = c_rem % patch_size;
181+
uint32_t unpatched_dim = dim / (patch_size * patch_size);
263182

264-
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
265-
266-
for (int64_t y = 0; y < H_out; ++y) {
267-
const char* src_row = src_layer + y * src_step_h;
268-
269-
for (int64_t x = 0; x < W_out; ++x) {
270-
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
271-
dst_ptr += elem_size;
183+
for (int k = 0; k < frames; k++) {
184+
for (int rgb_x = 0; rgb_x < rgb_width; rgb_x++) {
185+
for (int rgb_y = 0; rgb_y < rgb_height; rgb_y++) {
186+
int latent_x = rgb_x / patch_size;
187+
int latent_y = rgb_y / patch_size;
188+
189+
int channel_offset = 0;
190+
if (patch_size > 1) {
191+
channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size));
272192
}
273-
}
274-
}
275-
}
276-
277-
memcpy(latents->data, dst_buf, dst_buf_size);
278-
279-
latents->ne[0] = W_out;
280-
latents->ne[1] = H_out;
281-
latents->ne[2] = C_out;
282193

283-
latents->nb[0] = elem_size;
284-
latents->nb[1] = latents->nb[0] * W_out;
285-
latents->nb[2] = latents->nb[1] * H_out;
286-
latents->nb[3] = latents->nb[2] * C_out;
194+
size_t latent_id = (latent_x * latents->nb[0] + latent_y * latents->nb[1] + k * latents->nb[2]);
287195

288-
if (alloc_dst_buf) {
289-
free(dst_buf);
290-
}
291-
}
196+
// should be incremented by 1 for each pixel
197+
size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x;
292198

293-
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {
294-
size_t buffer_head = 0;
295-
for (int k = 0; k < frames; k++) {
296-
for (int j = 0; j < height; j++) {
297-
for (int i = 0; i < width; i++) {
298-
size_t latent_id = (i * latents->nb[0] + j * latents->nb[1] + k * latents->nb[2]);
299199
float r = 0, g = 0, b = 0;
300200
if (latent_rgb_proj != nullptr) {
301-
for (int d = 0; d < dim; d++) {
302-
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[ggml_n_dims(latents) - 1]);
201+
for (int d = 0; d < unpatched_dim; d++) {
202+
float value = *(float*)((char*)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb[ggml_n_dims(latents) - 1]);
303203
r += value * latent_rgb_proj[d][0];
304204
g += value * latent_rgb_proj[d][1];
305205
b += value * latent_rgb_proj[d][2];
@@ -326,9 +226,9 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl
326226
g = g >= 0 ? g <= 1 ? g : 1 : 0;
327227
b = b >= 0 ? b <= 1 ? b : 1 : 0;
328228

329-
buffer[buffer_head++] = (uint8_t)(r * 255);
330-
buffer[buffer_head++] = (uint8_t)(g * 255);
331-
buffer[buffer_head++] = (uint8_t)(b * 255);
229+
buffer[pixel_id * 3 + 0] = (uint8_t)(r * 255);
230+
buffer[pixel_id * 3 + 1] = (uint8_t)(g * 255);
231+
buffer[pixel_id * 3 + 2] = (uint8_t)(b * 255);
332232
}
333233
}
334234
}

0 commit comments

Comments
 (0)