diff --git a/latent-preview.h b/latent-preview.h index 97409a7d8..8354a35e0 100644 --- a/latent-preview.h +++ b/latent-preview.h @@ -91,6 +91,41 @@ const float flux_latent_rgb_proj[16][3] = { {-0.111849f, -0.055589f, -0.032361f}}; float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f}; +const float flux2_latent_rgb_proj[32][3] = { + {0.000736f, -0.008385f, -0.019710f}, + {-0.001352f, -0.016392f, 0.020693f}, + {-0.006376f, 0.002428f, 0.036736f}, + {0.039384f, 0.074167f, 0.119789f}, + {0.007464f, -0.005705f, -0.004734f}, + {-0.004086f, 0.005287f, -0.000409f}, + {-0.032835f, 0.050802f, -0.028120f}, + {-0.003158f, -0.000835f, 0.000406f}, + {-0.112840f, -0.084337f, -0.023083f}, + {0.001462f, -0.006656f, 0.000549f}, + {-0.009980f, -0.007480f, 0.009702f}, + {0.032540f, 0.000214f, -0.061388f}, + {0.011023f, 0.000694f, 0.007143f}, + {-0.001468f, -0.006723f, -0.001678f}, + {-0.005921f, -0.010320f, -0.003907f}, + {-0.028434f, 0.027584f, 0.018457f}, + {0.014349f, 0.011523f, 0.000441f}, + {0.009874f, 0.003081f, 0.001507f}, + {0.002218f, 0.005712f, 0.001563f}, + {0.053010f, -0.019844f, 0.008683f}, + {-0.002507f, 0.005384f, 0.000938f}, + {-0.002177f, -0.011366f, 0.003559f}, + {-0.000261f, 0.015121f, -0.003240f}, + {-0.003944f, -0.002083f, 0.005043f}, + {-0.009138f, 0.011336f, 0.003781f}, + {0.011429f, 0.003985f, -0.003855f}, + {0.010518f, -0.005586f, 0.010131f}, + {0.007883f, 0.002912f, -0.001473f}, + {-0.003318f, -0.003160f, 0.003684f}, + {-0.034560f, -0.008740f, 0.012996f}, + {0.000166f, 0.001079f, -0.012153f}, + {0.017772f, 0.000937f, -0.011953f}}; +float flux2_latent_rgb_bias[3] = {-0.028738f, -0.098463f, -0.107619f}; + // This one was taken straight from // https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303 // (MiT Licence) @@ -128,16 +163,43 @@ const float sd_latent_rgb_proj[4][3] = { {-0.178022f, -0.200862f, -0.678514f}}; float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f}; -void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) { + +void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) { size_t buffer_head = 0; + + uint32_t latent_width = latents->ne[0]; + uint32_t latent_height = latents->ne[1]; + uint32_t dim = latents->ne[ggml_n_dims(latents) - 1]; + uint32_t frames = 1; + if (ggml_n_dims(latents) == 4) { + frames = latents->ne[2]; + } + + uint32_t rgb_width = latent_width * patch_size; + uint32_t rgb_height = latent_height * patch_size; + + uint32_t unpatched_dim = dim / (patch_size * patch_size); + for (int k = 0; k < frames; k++) { - for (int j = 0; j < height; j++) { - for (int i = 0; i < width; i++) { - size_t latent_id = (i * latents->nb[0] + j * latents->nb[1] + k * latents->nb[2]); + for (int rgb_x = 0; rgb_x < rgb_width; rgb_x++) { + for (int rgb_y = 0; rgb_y < rgb_height; rgb_y++) { + int latent_x = rgb_x / patch_size; + int latent_y = rgb_y / patch_size; + + int channel_offset = 0; + if (patch_size > 1) { + channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size)); + } + + size_t latent_id = (latent_x * latents->nb[0] + latent_y * latents->nb[1] + k * latents->nb[2]); + + // should be incremented by 1 for each pixel + size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x; + float r = 0, g = 0, b = 0; if (latent_rgb_proj != nullptr) { - for (int d = 0; d < dim; d++) { - float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[ggml_n_dims(latents) - 1]); + for (int d = 0; d < unpatched_dim; d++) { + float value = *(float*)((char*)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb[ggml_n_dims(latents) - 1]); r += value * latent_rgb_proj[d][0]; g += value * latent_rgb_proj[d][1]; b += value * latent_rgb_proj[d][2]; @@ -164,9 +226,9 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl g = g >= 0 ? g <= 1 ? g : 1 : 0; b = b >= 0 ? b <= 1 ? b : 1 : 0; - buffer[buffer_head++] = (uint8_t)(r * 255); - buffer[buffer_head++] = (uint8_t)(g * 255); - buffer[buffer_head++] = (uint8_t)(b * 255); + buffer[pixel_id * 3 + 0] = (uint8_t)(r * 255); + buffer[pixel_id * 3 + 1] = (uint8_t)(g * 255); + buffer[pixel_id * 3 + 2] = (uint8_t)(b * 255); } } } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 73065610d..ec1d38e64 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -406,8 +406,8 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -448,10 +448,10 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -460,10 +460,10 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -492,20 +492,20 @@ class StableDiffusionGGML { "", enable_vision); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else if (sd_version_is_z_image(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { cond_stage_model = std::make_shared(clip_backend, @@ -1337,10 +1337,17 @@ class StableDiffusionGGML { uint32_t dim = latents->ne[ggml_n_dims(latents) - 1]; if (preview_mode == PREVIEW_PROJ) { - const float(*latent_rgb_proj)[channel] = nullptr; - float* latent_rgb_bias = nullptr; + int64_t patch_sz = 1; + const float (*latent_rgb_proj)[channel] = nullptr; + float* latent_rgb_bias = nullptr; - if (dim == 48) { + if (dim == 128) { + if (sd_version_is_flux2(version)) { + latent_rgb_proj = flux2_latent_rgb_proj; + latent_rgb_bias = flux2_latent_rgb_bias; + patch_sz = 2; + } + } else if (dim == 48) { if (sd_version_is_wan(version)) { latent_rgb_proj = wan_22_latent_rgb_proj; latent_rgb_bias = wan_22_latent_rgb_bias; @@ -1349,6 +1356,11 @@ class StableDiffusionGGML { // unknown model return; } + } else if (dim == 32) { + if (sd_version_is_flux2(version)) { + latent_rgb_proj = flux2_latent_rgb_proj; + latent_rgb_bias = flux2_latent_rgb_bias; + } } else if (dim == 16) { // 16 channels VAE -> Flux or SD3 @@ -1393,12 +1405,15 @@ class StableDiffusionGGML { frames = latents->ne[2]; } - uint8_t* data = (uint8_t*)malloc(frames * width * height * channel * sizeof(uint8_t)); + uint32_t img_width = width * patch_sz; + uint32_t img_height = height * patch_sz; + + uint8_t* data = (uint8_t*)malloc(frames * img_width * img_height * channel * sizeof(uint8_t)); - preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, width, height, frames, dim); + preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, patch_sz); sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t)); for (int i = 0; i < frames; i++) { - images[i] = {width, height, channel, data + i * width * height * channel}; + images[i] = {img_width, img_height, channel, data + i * img_width * img_height * channel}; } step_callback(step, frames, images, is_noisy, step_callback_data); free(data); @@ -1942,12 +1957,12 @@ class StableDiffusionGGML { -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f, 0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f}; latents_std_vec = { - 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, - 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, - 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, - 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, - 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, - 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; + 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, + 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, + 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, + 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, + 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, + 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; } else if (latent->ne[channel_dim] == 128) { // flux2 latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f, @@ -1967,22 +1982,22 @@ class StableDiffusionGGML { -0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f, -0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f}; latents_std_vec = { - 1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f, - 1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f, - 1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f, - 1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f, - 1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f, - 1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f, - 1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f, - 1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f, - 1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f, - 1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f, - 1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f, - 1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f, - 1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f, - 1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f, - 1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f, - 1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f}; + 1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f, + 1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f, + 1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f, + 1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f, + 1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f, + 1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f, + 1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f, + 1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f, + 1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f, + 1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f, + 1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f, + 1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f, + 1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f, + 1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f, + 1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f, + 1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f}; } } @@ -2094,12 +2109,12 @@ class StableDiffusionGGML { } ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) { - int64_t t0 = ggml_time_ms(); - ggml_tensor* result = nullptr; + int64_t t0 = ggml_time_ms(); + ggml_tensor* result = nullptr; const int vae_scale_factor = get_vae_scale_factor(); int W = x->ne[0] / vae_scale_factor; int H = x->ne[1] / vae_scale_factor; - int C = get_latent_channel(); + int C = get_latent_channel(); if (vae_tiling_params.enabled && !encode_video) { // TODO wan2.2 vae support? int ne2; @@ -2224,8 +2239,8 @@ class StableDiffusionGGML { const int vae_scale_factor = get_vae_scale_factor(); int64_t W = x->ne[0] * vae_scale_factor; int64_t H = x->ne[1] * vae_scale_factor; - int64_t C = 3; - ggml_tensor* result = nullptr; + int64_t C = 3; + ggml_tensor* result = nullptr; if (decode_video) { int T = x->ne[2]; if (sd_version_is_wan(version)) {