@@ -163,143 +163,43 @@ const float sd_latent_rgb_proj[4][3] = {
163163 {-0 .178022f , -0 .200862f , -0 .678514f }};
164164float sd_latent_rgb_bias[3 ] = {-0 .017478f , -0 .055834f , -0 .105825f };
165165
166- void unpatchify_latents (ggml_tensor* latents, int patch_size, char * dst_buf) {
167- const int64_t N = latents->ne [3 ];
168- const int64_t C_in = latents->ne [2 ];
169- const int64_t H_in = latents->ne [1 ];
170- const int64_t W_in = latents->ne [0 ];
171166
172- const int64_t C_out = C_in / (patch_size * patch_size);
173- const int64_t H_out = H_in * patch_size;
174- const int64_t W_out = W_in * patch_size;
175-
176- const char * src_ptr = (char *)latents->data ;
177- size_t elem_size = latents->nb [0 ];
178-
179- bool alloc_dst_buf = dst_buf == nullptr ;
180- size_t dst_buf_size = latents->nb [3 ];
181- if (alloc_dst_buf) {
182- dst_buf = (char *)malloc (dst_buf_size);
183- }
184-
185- size_t dst_stride_w = elem_size;
186- size_t dst_stride_h = dst_stride_w * W_out;
187- size_t dst_stride_c = dst_stride_h * H_out;
188- size_t dst_stride_n = dst_stride_c * C_out;
189-
190- size_t dst_step_w = dst_stride_w * patch_size;
191- size_t dst_step_h = dst_stride_h * patch_size;
192-
193- for (int64_t n = 0 ; n < N; ++n) {
194- for (int64_t c = 0 ; c < C_in; ++c) {
195- int64_t c_out = c / (patch_size * patch_size);
196- int64_t rem = c % (patch_size * patch_size);
197- int64_t py = rem / patch_size;
198- int64_t px = rem % patch_size;
199-
200- char * dst_layer = dst_buf + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
201-
202- for (int64_t y = 0 ; y < H_in; ++y) {
203- char * dst_row = dst_layer + y * dst_step_h;
204-
205- for (int64_t x = 0 ; x < W_in; ++x) {
206- memcpy (dst_row + x * dst_step_w, src_ptr, elem_size);
207- src_ptr += elem_size;
208- }
209- }
210- }
211- }
212-
213- memcpy (latents->data , dst_buf, dst_buf_size);
214-
215- latents->ne [0 ] = W_out;
216- latents->ne [1 ] = H_out;
217- latents->ne [2 ] = C_out;
218-
219- latents->nb [0 ] = dst_stride_w;
220- latents->nb [1 ] = dst_stride_h;
221- latents->nb [2 ] = dst_stride_c;
222- latents->nb [3 ] = dst_stride_n;
223- if (alloc_dst_buf) {
224- free (dst_buf);
225- }
226- }
227-
228- void repatchify_latents (ggml_tensor* latents, int patch_size, char * dst_buf) {
229- const int64_t N = latents->ne [3 ];
230- const int64_t C_in = latents->ne [2 ];
231- const int64_t H_in = latents->ne [1 ];
232- const int64_t W_in = latents->ne [0 ];
233-
234- const int64_t C_out = C_in * patch_size * patch_size;
235- const int64_t H_out = H_in / patch_size;
236- const int64_t W_out = W_in / patch_size;
237-
238- const char * src_base = (char *)latents->data ;
239- const size_t elem_size = latents->nb [0 ];
240-
241- const size_t src_stride_w = latents->nb [0 ];
242- const size_t src_stride_h = latents->nb [1 ];
243- const size_t src_stride_c = latents->nb [2 ];
244- const size_t src_stride_n = latents->nb [3 ];
167+ void preview_latent_video (uint8_t * buffer, struct ggml_tensor * latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) {
168+ size_t buffer_head = 0 ;
245169
246- bool alloc_dst_buf = dst_buf == nullptr ;
247- size_t dst_buf_size = src_stride_n;
248- if (alloc_dst_buf) {
249- dst_buf = (char *)malloc (dst_buf_size);
170+ uint32_t latent_width = latents->ne [0 ];
171+ uint32_t latent_height = latents->ne [1 ];
172+ uint32_t dim = latents->ne [ggml_n_dims (latents) - 1 ];
173+ uint32_t frames = 1 ;
174+ if (ggml_n_dims (latents) == 4 ) {
175+ frames = latents->ne [2 ];
250176 }
251177
252- char * dst_ptr = dst_buf;
253-
254- const size_t src_step_h = src_stride_h * patch_size;
255- const size_t src_step_w = src_stride_w * patch_size;
178+ uint32_t rgb_width = latent_width * patch_size;
179+ uint32_t rgb_height = latent_height * patch_size;
256180
257- for (int64_t n = 0 ; n < N; ++n) {
258- for (int64_t c = 0 ; c < C_out; ++c) {
259- int64_t c_rem = c % (patch_size * patch_size);
260- int64_t c_in = c / (patch_size * patch_size);
261- int64_t py = c_rem / patch_size;
262- int64_t px = c_rem % patch_size;
181+ uint32_t unpatched_dim = dim / (patch_size * patch_size);
263182
264- const char * src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
265-
266- for (int64_t y = 0 ; y < H_out; ++y) {
267- const char * src_row = src_layer + y * src_step_h;
268-
269- for (int64_t x = 0 ; x < W_out; ++x) {
270- memcpy (dst_ptr, src_row + x * src_step_w, elem_size);
271- dst_ptr += elem_size;
183+ for (int k = 0 ; k < frames; k++) {
184+ for (int rgb_x = 0 ; rgb_x < rgb_width; rgb_x++) {
185+ for (int rgb_y = 0 ; rgb_y < rgb_height; rgb_y++) {
186+ int latent_x = rgb_x / patch_size;
187+ int latent_y = rgb_y / patch_size;
188+
189+ int channel_offset = 0 ;
190+ if (patch_size > 1 ) {
191+ channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size));
272192 }
273- }
274- }
275- }
276-
277- memcpy (latents->data , dst_buf, dst_buf_size);
278-
279- latents->ne [0 ] = W_out;
280- latents->ne [1 ] = H_out;
281- latents->ne [2 ] = C_out;
282193
283- latents->nb [0 ] = elem_size;
284- latents->nb [1 ] = latents->nb [0 ] * W_out;
285- latents->nb [2 ] = latents->nb [1 ] * H_out;
286- latents->nb [3 ] = latents->nb [2 ] * C_out;
194+ size_t latent_id = (latent_x * latents->nb [0 ] + latent_y * latents->nb [1 ] + k * latents->nb [2 ]);
287195
288- if (alloc_dst_buf) {
289- free (dst_buf);
290- }
291- }
196+ // should be incremented by 1 for each pixel
197+ size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x;
292198
293- void preview_latent_video (uint8_t * buffer, struct ggml_tensor * latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {
294- size_t buffer_head = 0 ;
295- for (int k = 0 ; k < frames; k++) {
296- for (int j = 0 ; j < height; j++) {
297- for (int i = 0 ; i < width; i++) {
298- size_t latent_id = (i * latents->nb [0 ] + j * latents->nb [1 ] + k * latents->nb [2 ]);
299199 float r = 0 , g = 0 , b = 0 ;
300200 if (latent_rgb_proj != nullptr ) {
301- for (int d = 0 ; d < dim ; d++) {
302- float value = *(float *)((char *)latents->data + latent_id + d * latents->nb [ggml_n_dims (latents) - 1 ]);
201+ for (int d = 0 ; d < unpatched_dim ; d++) {
202+ float value = *(float *)((char *)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb [ggml_n_dims (latents) - 1 ]);
303203 r += value * latent_rgb_proj[d][0 ];
304204 g += value * latent_rgb_proj[d][1 ];
305205 b += value * latent_rgb_proj[d][2 ];
@@ -326,9 +226,9 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl
326226 g = g >= 0 ? g <= 1 ? g : 1 : 0 ;
327227 b = b >= 0 ? b <= 1 ? b : 1 : 0 ;
328228
329- buffer[buffer_head++ ] = (uint8_t )(r * 255 );
330- buffer[buffer_head++ ] = (uint8_t )(g * 255 );
331- buffer[buffer_head++ ] = (uint8_t )(b * 255 );
229+ buffer[pixel_id * 3 + 0 ] = (uint8_t )(r * 255 );
230+ buffer[pixel_id * 3 + 1 ] = (uint8_t )(g * 255 );
231+ buffer[pixel_id * 3 + 2 ] = (uint8_t )(b * 255 );
332232 }
333233 }
334234 }
0 commit comments