diff --git a/denoiser.hpp b/denoiser.hpp index 3b6be7552..96366ceef 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -671,6 +671,7 @@ static void sample_k_diffusion(sample_method_t method, ggml_context* work_ctx, ggml_tensor* x, std::vector sigmas, + int initial_step, std::shared_ptr rng, float eta) { size_t steps = sigmas.size() - 1; @@ -1248,12 +1249,13 @@ static void sample_k_diffusion(sample_method_t method, // - pred_sample_direction -> "direction pointing to // x_t" // - pred_prev_sample -> "x_t-1" - int timestep = - roundf(TIMESTEPS - - i * ((float)TIMESTEPS / steps)) - - 1; + int timestep = TIMESTEPS - 1 - + (int)roundf((initial_step + i) * + (TIMESTEPS / float(initial_step + steps))); // 1. get previous step value (=t-1) - int prev_timestep = timestep - TIMESTEPS / steps; + int prev_timestep = TIMESTEPS - 1 - + (int)roundf((initial_step + i + 1) * + (TIMESTEPS / float(initial_step + steps))); // The sigma here is chosen to cause the // CompVisDenoiser to produce t = timestep float sigma = compvis_sigmas[timestep]; @@ -1425,9 +1427,14 @@ static void sample_k_diffusion(sample_method_t method, // Analytic form for TCD timesteps int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * - (int)floor(i * ((float)original_steps / steps)); + (int)floor((initial_step + i) * + ((float)original_steps / (initial_step + steps))); // 1. get previous step value - int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); + int prev_timestep = i >= steps - 1 ? 0 : + TIMESTEPS - 1 - + (TIMESTEPS / original_steps) * + (int)floor((initial_step + i + 1) * + ((float)original_steps / (initial_step + steps))); // Here timestep_s is tau_n' in Algorithm 4. The _s // notation appears to be that from C. Lu, // "DPM-Solver: A Fast ODE Solver for Diffusion diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 73065610d..f84d6c968 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1480,6 +1480,7 @@ class StableDiffusionGGML { int shifted_timestep, sample_method_t method, const std::vector& sigmas, + int initial_step, int start_merge_step, SDCondition id_cond, std::vector ref_latents = {}, @@ -1837,7 +1838,7 @@ class StableDiffusionGGML { return denoised; }; - sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta); + sample_k_diffusion(method, denoise, work_ctx, x, sigmas, initial_step, sampler_rng, eta); if (easycache_enabled) { size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0; @@ -2762,6 +2763,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, int height, enum sample_method_t sample_method, const std::vector& sigmas, + int initial_step, int64_t seed, int batch_count, sd_image_t control_image, @@ -3056,6 +3058,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, shifted_timestep, sample_method, sigmas, + initial_step, start_merge_step, id_cond, ref_latents, @@ -3173,6 +3176,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_ctx->sd->get_image_seq_len(height, width), sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version); + int initial_step = 0; ggml_tensor* init_latent = nullptr; ggml_tensor* concat_latent = nullptr; @@ -3185,7 +3189,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g t_enc--; LOG_INFO("target t_enc is %zu steps", t_enc); std::vector sigma_sched; - sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end()); + initial_step = sample_steps - t_enc - 1; + sigma_sched.assign(sigmas.begin() + initial_step, sigmas.end()); sigmas = sigma_sched; ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); @@ -3373,6 +3378,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g height, sample_method, sigmas, + initial_step, seed, sd_img_gen_params->batch_count, sd_img_gen_params->control_image, @@ -3709,6 +3715,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_vid_gen_params->high_noise_sample_params.shifted_timestep, high_noise_sample_method, high_noise_sigmas, + 0, -1, {}, {}, @@ -3746,6 +3753,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_vid_gen_params->sample_params.shifted_timestep, sample_method, sigmas, + 0, -1, {}, {},