diff --git a/src/scope/core/pipelines/longlive/pipeline.py b/src/scope/core/pipelines/longlive/pipeline.py index ea472085..380d2374 100644 --- a/src/scope/core/pipelines/longlive/pipeline.py +++ b/src/scope/core/pipelines/longlive/pipeline.py @@ -221,6 +221,12 @@ def _generate(self, **kwargs) -> torch.Tensor: if "vace_ref_images" not in kwargs: self.state.set("vace_ref_images", None) + # Clear extension mode frame images from state if not provided to prevent reuse on non-extension chunks + if "first_frame_image" not in kwargs: + self.state.set("first_frame_image", None) + if "last_frame_image" not in kwargs: + self.state.set("last_frame_image", None) + if self.state.get("denoising_step_list") is None: self.state.set("denoising_step_list", DEFAULT_DENOISING_STEP_LIST) diff --git a/src/scope/core/pipelines/longlive/test_vace.py b/src/scope/core/pipelines/longlive/test_vace.py index e78270ed..2d81c60f 100644 --- a/src/scope/core/pipelines/longlive/test_vace.py +++ b/src/scope/core/pipelines/longlive/test_vace.py @@ -2,9 +2,13 @@ Unified test script for LongLive pipeline with VACE integration. Supports multiple modes: -- R2V: Reference-to-Video generation using reference images +- R2V: Reference-to-Video generation using reference images (condition only) - Depth guidance: Structural guidance using depth maps - Inpainting: Masked video-to-video generation +- Extension: Temporal extension with reference frames in output + * firstframe: first_frame_image at start of first chunk, generate continuation + * lastframe: Generate lead-up, last_frame_image at end of last chunk + * firstlastframe: first_frame_image at start AND last_frame_image at end Modes can be combined: - R2V + Depth @@ -12,6 +16,11 @@ - Depth only - Inpainting only - R2V only +- Extension only + +Key distinction - R2V vs Extension: +- R2V (ref_images): Reference images condition the entire video but DON'T appear in output +- Extension (first_frame_image/last_frame_image): Reference frames ARE in output video Usage: Edit the CONFIG dictionary below to enable/disable modes and set paths. 
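For orientation, here is a minimal sketch of how the two paths described above are driven per chunk by this test script. The kwarg names and values mirror the hunks below; `pipeline` is assumed to be an already-initialized LongLivePipeline, and the asset paths and the `r2v_kwargs`/`extension_kwargs` names are placeholders:

# R2V: reference images condition the chunk but are not expected to appear in the output
r2v_kwargs = {
    "prompts": [{"text": "", "weight": 100}],
    "vace_context_scale": 0.7,
    "vace_ref_images": ["frontend/public/assets/example.png"],
}

# Extension: the reference frame is meant to show up in the output itself
extension_kwargs = {
    "prompts": [{"text": "", "weight": 100}],
    "vace_context_scale": 1.5,
    "extension_mode": "firstframe",
    "first_frame_image": "frontend/public/assets/woman1.jpg",
}

# `pipeline` is an already-constructed LongLivePipeline (assumed here)
chunk = pipeline(**extension_kwargs)  # frames for one chunk, shaped [F, H, W, C]
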
@@ -35,11 +44,12 @@ CONFIG = { # ===== MODE SELECTION ===== - "use_r2v": True, # Reference-to-Video: condition on reference images + "use_r2v": False, # Reference-to-Video: condition on reference images "use_depth": False, # Depth guidance: structural control via depth maps "use_inpainting": False, # Inpainting: masked video-to-video generation + "use_extension": True, # Extension mode: temporal generation (firstframe/lastframe/firstlastframe) # ===== INPUT PATHS ===== - # R2V: List of reference image paths + # R2V: List of reference image paths (condition entire video, don't appear in output) "ref_images": [ "frontend/public/assets/example.png", # path/to/image.png ], @@ -48,16 +58,21 @@ # Inpainting: Input video and mask video paths "input_video": "frontend/public/assets/test.mp4", # path/to/input_video.mp4 "mask_video": "vace_tests/circle_mask.mp4", # path/to/mask_video.mp4 + # Extension: Frame images (appear in output video as actual frames) + "first_frame_image": "frontend/public/assets/woman1.jpg", # For firstframe or firstlastframe modes + "last_frame_image": "frontend/public/assets/woman2.jpg", # For lastframe or firstlastframe modes + "extension_mode": "firstlastframe", # "firstframe", "lastframe", or "firstlastframe" # ===== GENERATION PARAMETERS ===== "prompt": None, # Set to override mode-specific prompts, or None to use defaults "prompt_r2v": "", # Default prompt for R2V mode "prompt_depth": "a cat walking towards the camera", # Default prompt for depth mode "prompt_inpainting": "a fireball", # Default prompt for inpainting mode - "num_chunks": 3, # Number of generation chunks + "prompt_extension": "", # Default prompt for extension mode + "num_chunks": 2, # Number of generation chunks "frames_per_chunk": 12, # Frames per chunk (12 = 3 latent * 4 temporal upsample) "height": 512, "width": 512, - "vace_context_scale": 0.7, # VACE conditioning strength (0.0-1.0) + "vace_context_scale": 1.5, # VACE conditioning strength # ===== INPAINTING SPECIFIC ===== "mask_threshold": 0.5, # Threshold for binarizing mask (0-1) "mask_value": 127, # Gray value for masked regions (0-255) @@ -416,18 +431,27 @@ def main(): use_r2v = config["use_r2v"] use_depth = config["use_depth"] use_inpainting = config["use_inpainting"] + use_extension = config["use_extension"] # Validate mode selection if use_depth and use_inpainting: raise ValueError("Cannot use both depth and inpainting modes simultaneously") - if not (use_r2v or use_depth or use_inpainting): + if use_extension and (use_depth or use_inpainting): + raise ValueError( + "Extension mode cannot be combined with depth or inpainting modes" + ) + + if not (use_r2v or use_depth or use_inpainting or use_extension): raise ValueError("At least one mode must be enabled") # Select appropriate prompt based on mode if config["prompt"] is not None: # User override prompt = config["prompt"] + elif use_extension: + # Extension mode + prompt = config["prompt_extension"] elif use_inpainting: # Inpainting takes priority prompt = config["prompt_inpainting"] @@ -442,6 +466,9 @@ def main(): print(f" R2V: {use_r2v}") print(f" Depth Guidance: {use_depth}") print(f" Inpainting: {use_inpainting}") + print(f" Extension: {use_extension}") + if use_extension: + print(f" Mode: {config['extension_mode']}") print(f" Prompt: '{prompt}'") print(f" Chunks: {config['num_chunks']} x {config['frames_per_chunk']} frames") print(f" Resolution: {config['height']}x{config['width']}") @@ -474,7 +501,7 @@ def main(): ), "lora_path": 
str(get_model_file_path("LongLive-1.3B/models/lora.pt")), "vace_path": vace_path - if (use_r2v or use_depth or use_inpainting) + if (use_r2v or use_depth or use_inpainting or use_extension) else None, "text_encoder_path": str( get_model_file_path( @@ -490,8 +517,8 @@ def main(): } ) - # Set vace_in_dim for depth/inpainting modes - if use_depth or use_inpainting: + # Set vace_in_dim for depth/inpainting/extension modes (all use masked encoding: 32 + 64 = 96 channels) + if use_depth or use_inpainting or use_extension: pipeline_config.model_config.base_model_kwargs = ( pipeline_config.model_config.base_model_kwargs or {} ) @@ -504,6 +531,8 @@ def main(): total_frames = config["num_chunks"] * config["frames_per_chunk"] ref_images = None + first_frame_image = None + last_frame_image = None depth_video = None input_video_tensor = None mask_tensor = None @@ -529,6 +558,35 @@ def main(): use_r2v = False print() + # Load frame images for Extension mode + if use_extension: + print("=== Preparing Extension Inputs ===") + extension_mode = config["extension_mode"] + + # Load first_frame_image if needed + if extension_mode in ["firstframe", "firstlastframe"]: + first_frame_path = resolve_path(config["first_frame_image"], project_root) + if first_frame_path.exists(): + first_frame_image = str(first_frame_path) + print(f" First frame image: {first_frame_path}") + else: + raise FileNotFoundError( + f"First frame image not found: {first_frame_path}" + ) + + # Load last_frame_image if needed + if extension_mode in ["lastframe", "firstlastframe"]: + last_frame_path = resolve_path(config["last_frame_image"], project_root) + if last_frame_path.exists(): + last_frame_image = str(last_frame_path) + print(f" Last frame image: {last_frame_path}") + else: + raise FileNotFoundError( + f"Last frame image not found: {last_frame_path}" + ) + + print() + # Load depth video if use_depth: print("=== Preparing Depth Inputs ===") @@ -631,11 +689,39 @@ def main(): "vace_context_scale": config["vace_context_scale"], } - # Add R2V reference images (first chunk only) + # Add R2V reference images (first chunk only, R2V conditions entire video) if use_r2v and is_first_chunk and ref_images: kwargs["vace_ref_images"] = ref_images print(f"Chunk {chunk_index}: Using {len(ref_images)} reference image(s)") + # Add Extension mode parameters + # Extension applies per-chunk: first chunk for firstframe/firstlastframe, last for lastframe/firstlastframe + if use_extension: + extension_mode = config["extension_mode"] + is_last_chunk = chunk_index == config["num_chunks"] - 1 + + # Determine if we should apply extension for this chunk + apply_extension = False + if extension_mode == "firstframe" and is_first_chunk: + apply_extension = True + elif extension_mode == "lastframe" and is_last_chunk: + apply_extension = True + elif extension_mode == "firstlastframe" and ( + is_first_chunk or is_last_chunk + ): + apply_extension = True + + if apply_extension: + kwargs["extension_mode"] = extension_mode + if first_frame_image is not None: + kwargs["first_frame_image"] = first_frame_image + if last_frame_image is not None: + kwargs["last_frame_image"] = last_frame_image + print( + f"Chunk {chunk_index}: Extension mode={extension_mode}, " + f"chunk={'first' if is_first_chunk else 'last' if is_last_chunk else 'middle'}" + ) + # Add depth guidance if use_depth: depth_frames_available = depth_video.shape[2] @@ -710,6 +796,47 @@ def main(): print(f"\nFinal output shape: {output_video.shape}") + # DIAGNOSTIC: Check final frames in extension mode + if 
use_extension: + print("\n=== Extension Mode Diagnostics ===") + extension_mode = config["extension_mode"] + print(f"Extension mode: {extension_mode}") + + if extension_mode == "firstframe" or extension_mode == "firstlastframe": + print( + f"First frame value range: [{output_video[0].min():.3f}, {output_video[0].max():.3f}]" + ) + print(f"First frame mean: {output_video[0].mean():.3f}") + + if extension_mode == "lastframe" or extension_mode == "firstlastframe": + last_frame_idx = output_video.shape[0] - 1 + print( + f"Last frame (index {last_frame_idx}) value range: [{output_video[last_frame_idx].min():.3f}, {output_video[last_frame_idx].max():.3f}]" + ) + print(f"Last frame mean: {output_video[last_frame_idx].mean():.3f}") + + # Compare with reference image if available + if extension_mode == "lastframe" and last_frame_image: + print(f"\nLoading reference image for comparison: {last_frame_image}") + from PIL import Image + + ref_img = Image.open(last_frame_image).convert("RGB") + ref_img_resized = ref_img.resize((config["width"], config["height"])) + ref_np = np.array(ref_img_resized).astype(np.float32) / 255.0 + print( + f"Reference image value range: [{ref_np.min():.3f}, {ref_np.max():.3f}]" + ) + print(f"Reference image mean: {ref_np.mean():.3f}") + + # Compare last frame with reference + last_frame_np = output_video[last_frame_idx].numpy() + diff = np.abs(last_frame_np - ref_np) + print("Absolute difference between last frame and reference:") + print(f" Mean: {diff.mean():.3f}") + print(f" Max: {diff.max():.3f}") + print(f" Min: {diff.min():.3f}") + print() + # Save output video output_video_np = output_video.contiguous().numpy() output_video_np = np.clip(output_video_np, 0.0, 1.0) @@ -721,6 +848,8 @@ def main(): mode_suffix.append("depth") if use_inpainting: mode_suffix.append("inpainting") + if use_extension: + mode_suffix.append(f"extension_{config['extension_mode']}") output_filename = f"output_{'_'.join(mode_suffix)}.mp4" output_path = output_dir / output_filename @@ -795,6 +924,15 @@ def main(): print("Used depth maps for structural guidance") if use_inpainting: print("Used spatial masks for inpainting control") + if use_extension: + extension_info = [] + if first_frame_image: + extension_info.append("first frame") + if last_frame_image: + extension_info.append("last frame") + print( + f"Used extension mode '{config['extension_mode']}' with {' and '.join(extension_info)}" + ) if __name__ == "__main__": diff --git a/src/scope/core/pipelines/longlive/test_vace_extension_scale.py b/src/scope/core/pipelines/longlive/test_vace_extension_scale.py new file mode 100644 index 00000000..e61b6a86 --- /dev/null +++ b/src/scope/core/pipelines/longlive/test_vace_extension_scale.py @@ -0,0 +1,507 @@ +""" +Extension mode test script with scaling VACE context scale. + +This script tests extension mode with a VACE context scale that scales from +minimum to maximum across N chunks: +- First chunk: Uses 'firstframe' mode with first_frame_image +- Subsequent chunks (1 to N-1): Uses 'lastframe' mode with last_frame_image, + with VACE context scale scaling from min to max + +Usage: + Edit the CONFIG dictionary below to set paths and parameters. 
+ python -m scope.core.pipelines.longlive.test_vace_extension_scale +""" + +import time +from pathlib import Path + +import numpy as np +import torch +from diffusers.utils import export_to_video +from omegaconf import OmegaConf +from PIL import Image, ImageDraw, ImageFont + +from scope.core.config import get_model_file_path, get_models_dir + +from .pipeline import LongLivePipeline + +# ============================= CONFIGURATION ============================= + +CONFIG = { + # ===== INPUT PATHS ===== + "first_frame_image": "frontend/public/assets/example.png", # First frame reference + "last_frame_image": "frontend/public/assets/example1.png", # Last frame reference + # ===== GENERATION PARAMETERS ===== + "prompt": "", # Text prompt (can be empty for extension mode) + "num_chunks": 8, # Number of generation chunks + "frames_per_chunk": 12, # Frames per chunk (12 = 3 latent * 4 temporal upsample) + "height": 512, + "width": 512, + # ===== VACE CONTEXT SCALE PARAMETERS ===== + "vace_context_scale_min": 0.0, # Minimum VACE context scale (first chunk) + "vace_context_scale_max": 0.5, # Maximum VACE context scale (last chunk) + "interpolation_mode": "weak_middle", # Interpolation mode: "linear", "ease_in", "ease_out", "ease_in_out", "exponential", "logarithmic", "cosine", "strong_middle", "weak_middle" + # ===== OUTPUT ===== + "output_dir": "vace_tests/extension_scale", # path/to/output_dir +} + +# ========================= END CONFIGURATION ========================= + +# ============================= UTILITIES ============================= + + +def resolve_path(path_str: str, relative_to: Path) -> Path: + """Resolve path relative to a base directory or as absolute.""" + path = Path(path_str) + if path.is_absolute(): + return path + return (relative_to / path).resolve() + + +def add_vace_scale_overlay( + frames: np.ndarray, + vace_scales_per_chunk: list[float], + frames_per_chunk: int, +) -> np.ndarray: + """ + Add VACE context scale overlay to video frames. 
+ + Args: + frames: Video frames [F, H, W, C] in [0, 1] + vace_scales_per_chunk: List of VACE scales, one per chunk + frames_per_chunk: Number of frames per chunk + + Returns: + Frames with overlay [F, H, W, C] in [0, 1] + """ + num_frames, height, width, channels = frames.shape + overlayed_frames = [] + + # Try to load a font, fall back to default if not available + try: + # Try to use a larger font if available + font = ImageFont.truetype("arial.ttf", size=24) + except OSError: + try: + font = ImageFont.truetype( + "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", size=24 + ) + except OSError: + # Fall back to default font + font = ImageFont.load_default() + + for frame_idx in range(num_frames): + # Determine which chunk this frame belongs to + chunk_idx = frame_idx // frames_per_chunk + if chunk_idx >= len(vace_scales_per_chunk): + chunk_idx = len(vace_scales_per_chunk) - 1 + + vace_scale = vace_scales_per_chunk[chunk_idx] + + # Convert frame to PIL Image + frame_uint8 = (frames[frame_idx] * 255).astype(np.uint8) + pil_image = Image.fromarray(frame_uint8) + + # Create draw context + draw = ImageDraw.Draw(pil_image) + + # Prepare text + text = f"VACE Scale: {vace_scale:.3f}" + chunk_text = f"Chunk: {chunk_idx}" + + # Get text bounding boxes + bbox = draw.textbbox((0, 0), text, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + + chunk_bbox = draw.textbbox((0, 0), chunk_text, font=font) + chunk_text_width = chunk_bbox[2] - chunk_bbox[0] + chunk_text_height = chunk_bbox[3] - chunk_bbox[1] + + # Position text at top-left with padding + padding = 10 + text_x = padding + text_y = padding + + # Draw semi-transparent background rectangles + overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + overlay_draw = ImageDraw.Draw(overlay) + + # Background for scale text + overlay_draw.rectangle( + [ + text_x - 5, + text_y - 5, + text_x + text_width + 5, + text_y + text_height + 5, + ], + fill=(0, 0, 0, 180), # Semi-transparent black + ) + + # Background for chunk text + overlay_draw.rectangle( + [ + text_x - 5, + text_y + text_height + 5, + text_x + chunk_text_width + 5, + text_y + text_height + chunk_text_height + 10, + ], + fill=(0, 0, 0, 180), # Semi-transparent black + ) + + # Composite overlay onto image + pil_image = Image.alpha_composite(pil_image.convert("RGBA"), overlay).convert( + "RGB" + ) + + # Redraw text on the composited image + draw = ImageDraw.Draw(pil_image) + draw.text((text_x, text_y), text, fill=(255, 255, 255), font=font) + draw.text( + (text_x, text_y + text_height + 5), + chunk_text, + fill=(255, 255, 255), + font=font, + ) + + # Convert back to numpy + frame_overlayed = np.array(pil_image).astype(np.float32) / 255.0 + overlayed_frames.append(frame_overlayed) + + return np.array(overlayed_frames) + + +def apply_interpolation(t: float, mode: str) -> float: + """ + Apply interpolation function to normalized time t [0, 1]. 
+
+    Args:
+        t: Normalized time (0 to 1)
+        mode: Interpolation mode string
+
+    Returns:
+        Interpolated value [0, 1]
+    """
+    if mode == "linear":
+        return t
+    elif mode == "ease_in":
+        # Quadratic ease-in: slow start, fast end
+        return t * t
+    elif mode == "ease_out":
+        # Quadratic ease-out: fast start, slow end
+        return 1 - (1 - t) * (1 - t)
+    elif mode == "ease_in_out":
+        # Quadratic ease-in-out: slow start and end, fast middle
+        if t < 0.5:
+            return 2 * t * t
+        else:
+            return 1 - 2 * (1 - t) * (1 - t)
+    elif mode == "exponential":
+        # Exponential: very slow start, very fast end
+        if t == 0:
+            return 0
+        if t == 1:
+            return 1
+        # Normalized exponential: (2^(10*t) - 1) / (2^10 - 1)
+        return (2 ** (10 * t) - 1) / (2**10 - 1)
+    elif mode == "logarithmic":
+        # Logarithmic: fast start, very slow end
+        if t == 0:
+            return 0
+        if t == 1:
+            return 1
+        # Normalized logarithmic: log(10*t + 1) / log(11)
+        return np.log(10 * t + 1) / np.log(11)
+    elif mode == "cosine":
+        # Cosine: sinusoidal ease-in (slow start, fast end)
+        return 1 - np.cos(t * np.pi / 2)
+    elif mode == "strong_middle":
+        # Strong middle: emphasizes middle values (higher in middle)
+        # Uses a curve that peaks in the middle while still reaching 1 at the end
+        if t == 0:
+            return 0
+        if t == 1:
+            return 1
+        # Use sin(πt/2) as base, then add a bell curve centered at 0.5
+        # The bell curve is sin(πt) which peaks at 0.5
+        base = np.sin(t * np.pi / 2)
+        bell = np.sin(t * np.pi)
+        # Blend: more bell in middle, ensuring we don't exceed bounds
+        # At t=0.5: base ≈ 0.707, bell = 1, so blend ≈ 0.707 + 0.2*1*(1 - 0.707) ≈ 0.766
+        # At t=1: base = 1, bell = 0, so blend = 1
+        return base + 0.2 * bell * (1 - base)
+    elif mode == "weak_middle":
+        # Weak middle: de-emphasizes middle values (lower in middle)
+        # Uses a curve that dips in the middle while still reaching 1 at the end
+        if t == 0:
+            return 0
+        if t == 1:
+            return 1
+        # Linear base with middle dip: t * (1 - 0.5*sin(πt))
+        # sin(πt) is 0 at t=0,1 and peaks at t=0.5, so this creates a dip
+        return t * (1 - 0.5 * np.sin(t * np.pi))
+    else:
+        raise ValueError(f"Unknown interpolation mode: {mode}")
+
+
+def calculate_vace_scale(
+    chunk_index: int,
+    num_chunks: int,
+    scale_min: float,
+    scale_max: float,
+    interpolation_mode: str = "linear",
+) -> float:
+    """
+    Calculate VACE context scale for a given chunk.
+
+    Scales from scale_min to scale_max using the specified interpolation mode.
+
+    Args:
+        chunk_index: Current chunk index (0-based)
+        num_chunks: Total number of chunks
+        scale_min: Minimum scale value (first chunk)
+        scale_max: Maximum scale value (last chunk)
+        interpolation_mode: Interpolation mode ("linear", "ease_in", "ease_out", etc.)
+ + Returns: + VACE context scale for this chunk + """ + if num_chunks == 1: + return scale_max + + # Normalized time from 0 to 1 + t = chunk_index / (num_chunks - 1) + + # Apply interpolation function + t_interpolated = apply_interpolation(t, interpolation_mode) + + # Scale from min to max + scale = scale_min + t_interpolated * (scale_max - scale_min) + return scale + + +# ============================= MAIN ============================= + + +def main(): + print("=" * 80) + print(" LongLive Extension Mode - Scaling VACE Context Scale Test") + print("=" * 80) + + # Parse configuration + config = CONFIG + + print("\nConfiguration:") + print( + f" Extension mode: firstframe (chunk 0), then lastframe (chunks 1-{config['num_chunks'] - 1})" + ) + print(f" Prompt: '{config['prompt']}'") + print(f" Chunks: {config['num_chunks']} x {config['frames_per_chunk']} frames") + print(f" Resolution: {config['height']}x{config['width']}") + print( + f" VACE Scale: {config['vace_context_scale_min']} -> {config['vace_context_scale_max']}" + ) + print(f" Interpolation: {config['interpolation_mode']}") + + # Setup paths + script_dir = Path(__file__).parent + project_root = script_dir.parent.parent.parent.parent.parent + output_dir = resolve_path(config["output_dir"], script_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + print(f" Output: {output_dir}") + + # Setup device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f" Device: {device}\n") + + # Initialize pipeline + print("Initializing pipeline...") + + vace_path = str( + get_model_file_path("Wan2.1-VACE-1.3B/diffusion_pytorch_model.safetensors") + ) + + pipeline_config = OmegaConf.create( + { + "model_dir": str(get_models_dir()), + "generator_path": str( + get_model_file_path("LongLive-1.3B/models/longlive_base.pt") + ), + "lora_path": str(get_model_file_path("LongLive-1.3B/models/lora.pt")), + "vace_path": vace_path, + "text_encoder_path": str( + get_model_file_path( + "WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors" + ) + ), + "tokenizer_path": str( + get_model_file_path("Wan2.1-T2V-1.3B/google/umt5-xxl") + ), + "model_config": OmegaConf.load(script_dir / "model.yaml"), + "height": config["height"], + "width": config["width"], + } + ) + + # Set vace_in_dim for extension mode (masked encoding: 32 + 64 = 96 channels) + pipeline_config.model_config.base_model_kwargs = ( + pipeline_config.model_config.base_model_kwargs or {} + ) + pipeline_config.model_config.base_model_kwargs["vace_in_dim"] = 96 + + pipeline = LongLivePipeline(pipeline_config, device=device, dtype=torch.bfloat16) + print("Pipeline ready\n") + + # Load frame images for Extension mode + print("=== Preparing Extension Inputs ===") + + # Load first_frame_image + first_frame_path = resolve_path(config["first_frame_image"], project_root) + if not first_frame_path.exists(): + raise FileNotFoundError(f"First frame image not found: {first_frame_path}") + first_frame_image = str(first_frame_path) + print(f" First frame image: {first_frame_path}") + + # Load last_frame_image + last_frame_path = resolve_path(config["last_frame_image"], project_root) + if not last_frame_path.exists(): + raise FileNotFoundError(f"Last frame image not found: {last_frame_path}") + last_frame_image = str(last_frame_path) + print(f" Last frame image: {last_frame_path}") + print() + + # Generate video + print("=== Generating Video ===") + outputs = [] + latency_measures = [] + fps_measures = [] + vace_scales_used = [] + + frames_per_chunk = config["frames_per_chunk"] + num_chunks = 
config["num_chunks"] + + for chunk_index in range(num_chunks): + start_time = time.time() + + # Determine if this is first chunk + is_first_chunk = chunk_index == 0 + + # Prepare pipeline kwargs + kwargs = { + "prompts": [{"text": config["prompt"], "weight": 100}], + } + + # First chunk: use first_frame_image with scale of 1.0 + if is_first_chunk: + kwargs["extension_mode"] = "firstframe" + kwargs["first_frame_image"] = first_frame_image + kwargs["vace_context_scale"] = 1.0 + vace_scales_used.append(1.0) + extension_info = "first" + else: + # Subsequent chunks: use last_frame_image with scaling VACE context scale + # Scale from min to max across chunks 1 to N-1 + # For chunk i (where i >= 1), scale from 0 to 1 across remaining chunks + num_subsequent_chunks = num_chunks - 1 + chunk_position = ( + chunk_index - 1 + ) # Position within subsequent chunks (0 to num_subsequent_chunks - 1) + + vace_scale = calculate_vace_scale( + chunk_position, + num_subsequent_chunks, + config["vace_context_scale_min"], + config["vace_context_scale_max"], + config["interpolation_mode"], + ) + vace_scales_used.append(vace_scale) + + kwargs["extension_mode"] = "lastframe" + kwargs["last_frame_image"] = last_frame_image + kwargs["vace_context_scale"] = vace_scale + extension_info = f"last (scale={vace_scale:.3f})" + + print( + f"Chunk {chunk_index}/{num_chunks - 1}: " + f"VACE scale={kwargs['vace_context_scale']:.3f}, " + f"frames={frames_per_chunk}, " + f"extension={extension_info}" + ) + + # Generate + output = pipeline(**kwargs) + + # Metrics + num_output_frames, _, _, _ = output.shape + latency = time.time() - start_time + fps = num_output_frames / latency + + print( + f" Generated {num_output_frames} frames, " + f"latency={latency:.2f}s, fps={fps:.2f}" + ) + + latency_measures.append(latency) + fps_measures.append(fps) + outputs.append(output.detach().cpu()) + + # Concatenate outputs + output_video = torch.concat(outputs) + + print(f"\nFinal output shape: {output_video.shape}") + + # Convert to numpy and clip + # output_video is already [F, H, W, C] from pipeline + output_video_np = output_video.contiguous().numpy() + output_video_np = np.clip(output_video_np, 0.0, 1.0) + + # Add VACE scale overlay + print("Adding VACE scale overlay to frames...") + output_video_np = add_vace_scale_overlay( + output_video_np, + vace_scales_used, + frames_per_chunk, + ) + + output_filename = ( + f"output_extension_scale_" + f"{config['vace_context_scale_min']:.2f}to{config['vace_context_scale_max']:.2f}_" + f"{config['interpolation_mode']}_" + f"{num_chunks}chunks.mp4" + ) + output_path = output_dir / output_filename + export_to_video(output_video_np, output_path, fps=16) + + print(f"\nSaved output: {output_path}") + + # Statistics + print("\n=== Performance Statistics ===") + print( + f"Latency - Avg: {sum(latency_measures) / len(latency_measures):.2f}s, " + f"Max: {max(latency_measures):.2f}s, " + f"Min: {min(latency_measures):.2f}s" + ) + print( + f"FPS - Avg: {sum(fps_measures) / len(fps_measures):.2f}, " + f"Max: {max(fps_measures):.2f}, " + f"Min: {min(fps_measures):.2f}" + ) + + print("\n=== VACE Context Scale Progression ===") + for chunk_idx, scale in enumerate(vace_scales_used): + print(f" Chunk {chunk_idx}: {scale:.4f}") + + print("\n" + "=" * 80) + print(" Test Complete") + print("=" * 80) + print(f"\nResults saved to: {output_dir}") + print(f"Main output: {output_filename}") + print( + f"VACE context scale scaled from {config['vace_context_scale_min']:.2f} " + f"to {config['vace_context_scale_max']:.2f} across 
{num_chunks} chunks" + ) + + +if __name__ == "__main__": + main() diff --git a/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py b/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py index b3f2a100..396a60a8 100644 --- a/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py +++ b/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py @@ -30,7 +30,12 @@ OutputParam, ) -from ..utils.encoding import load_and_prepare_reference_images +from ..utils.encoding import ( + load_and_prepare_reference_images, + vace_encode_frames, + vace_encode_masks, + vace_latent, +) logger = logging.getLogger(__name__) @@ -96,6 +101,16 @@ def inputs(self) -> list[InputParam]: type_hint=torch.Tensor | None, description="Spatial control masks [B, 1, F, H, W]: defines WHERE to apply conditioning (white=generate, black=preserve). Defaults to ones (all white) when None. Works with any vace_input_frames type.", ), + InputParam( + "first_frame_image", + default=None, + description="Path to first frame reference image for extension mode. When provided alone, enables 'firstframe' mode (ref at start, generate after). When provided with last_frame_image, enables 'firstlastframe' mode (refs at both ends).", + ), + InputParam( + "last_frame_image", + default=None, + description="Path to last frame reference image for extension mode. When provided alone, enables 'lastframe' mode (generate before, ref at end). When provided with first_frame_image, enables 'firstlastframe' mode (refs at both ends).", + ), InputParam( "height", type_hint=int, @@ -135,22 +150,40 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState vace_ref_images = block_state.vace_ref_images vace_input_frames = block_state.vace_input_frames + first_frame_image: str | None = block_state.first_frame_image + last_frame_image: str | None = block_state.last_frame_image current_start = block_state.current_start_frame - # If neither input is provided, skip VACE conditioning - if ( - vace_ref_images is None or len(vace_ref_images) == 0 - ) and vace_input_frames is None: + # If no inputs provided, skip VACE conditioning + has_ref_images = vace_ref_images is not None and len(vace_ref_images) > 0 + has_input_frames = vace_input_frames is not None + has_first_frame = first_frame_image is not None + has_last_frame = last_frame_image is not None + has_extension = has_first_frame or has_last_frame + + if not has_ref_images and not has_input_frames and not has_extension: block_state.vace_context = None block_state.vace_ref_images = None self.set_block_state(state, block_state) return components, state # Determine encoding path based on what's provided (implicit mode detection) - has_ref_images = vace_ref_images is not None and len(vace_ref_images) > 0 - has_input_frames = vace_input_frames is not None + if has_extension: + # Extension mode: Generate frames before/after reference frame(s) + # Mode is inferred from which frame images are provided + if has_first_frame and has_last_frame: + extension_mode = "firstlastframe" + elif has_first_frame: + extension_mode = "firstframe" + else: + extension_mode = "lastframe" - if has_input_frames: + block_state.vace_context, block_state.vace_ref_images = ( + self._encode_extension_mode( + components, block_state, current_start, extension_mode + ) + ) + elif has_input_frames: # Standard VACE path: conditioning input (depth, flow, pose, etc.) 
# with optional reference images block_state.vace_context, block_state.vace_ref_images = ( @@ -237,6 +270,162 @@ def _encode_reference_only(self, components, block_state, current_start): # Return original paths, not tensors, so they can be reused in subsequent chunks return vace_context, ref_image_paths + def _encode_extension_mode( + self, components, block_state, current_start, extension_mode: str + ): + """ + Encode VACE context with reference frames and dummy frames for temporal extension. + + Loads reference image based on extension_mode (inferred from provided images), + replicates it across a temporal group, fills remaining frames with zeros (dummy frames), + and encodes with masks indicating which frames to inpaint (1=dummy, 0=reference). + + Args: + extension_mode: Inferred mode ('firstframe', 'lastframe', or 'firstlastframe') + """ + first_frame_image = block_state.first_frame_image + last_frame_image = block_state.last_frame_image + + images_to_load = [] + if extension_mode == "firstframe": + images_to_load = [first_frame_image] + elif extension_mode == "lastframe": + images_to_load = [last_frame_image] + elif extension_mode == "firstlastframe": + if current_start == 0: + images_to_load = [first_frame_image] + else: + images_to_load = [last_frame_image] + + prepared_refs = load_and_prepare_reference_images( + images_to_load, + block_state.height, + block_state.width, + components.config.device, + ) + + if hasattr(components, "vace_vae") and components.vace_vae is not None: + vace_vae = components.vace_vae + else: + vace_vae = components.vae + + vae_dtype = next(vace_vae.parameters()).dtype + + num_frames = ( + components.config.num_frame_per_block + * components.config.vae_temporal_downsample_factor + ) + + ref_at_start = extension_mode == "firstframe" or ( + extension_mode == "firstlastframe" and current_start == 0 + ) + + frames, masks = self._build_extension_frames_and_masks( + prepared_refs=prepared_refs, + num_frames=num_frames, + temporal_group_size=components.config.vae_temporal_downsample_factor, + ref_at_start=ref_at_start, + device=components.config.device, + dtype=vae_dtype, + height=block_state.height, + width=block_state.width, + ) + + frames_to_encode = [frames] + masks_to_encode = [masks] + + z0 = vace_encode_frames( + vae=vace_vae, + frames=frames_to_encode, + ref_images=[None], + masks=masks_to_encode, + pad_to_96=False, + use_cache=False, + ) + + vae_stride = ( + components.config.vae_temporal_downsample_factor, + components.config.vae_spatial_downsample_factor, + components.config.vae_spatial_downsample_factor, + ) + m0 = vace_encode_masks( + masks=masks_to_encode, + ref_images=[None], + vae_stride=vae_stride, + ) + + vace_context = vace_latent(z0, m0) + + logger.info( + f"_encode_extension_mode: mode={extension_mode}, current_start={current_start}, " + f"num_frames={num_frames}, vace_context_shape={vace_context[0].shape}" + ) + + return vace_context, prepared_refs + + def _build_extension_frames_and_masks( + self, + prepared_refs: list[torch.Tensor], + num_frames: int, + temporal_group_size: int, + ref_at_start: bool, + device: torch.device, + dtype: torch.dtype, + height: int, + width: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Build frames and masks for extension mode with reference frame replication. 
+ + Args: + prepared_refs: List of prepared reference images [C, 1, H, W] + num_frames: Total number of frames to generate + temporal_group_size: Number of frames in a temporal VAE group + ref_at_start: True for firstframe mode, False for lastframe mode + device: Target device + dtype: Target dtype for frames + height: Frame height + width: Frame width + + Returns: + Tuple of (frames, masks) where: + - frames: [C, F, H, W] tensor with reference frames and dummy frames + - masks: [1, F, H, W] tensor with 0s for reference frames, 1s for dummy frames + """ + num_ref_frames = temporal_group_size + num_dummy_frames = num_frames - num_ref_frames + + # Replicate reference across temporal group to prevent dilution during VAE encoding + ref_replicated = prepared_refs[0].repeat(1, num_ref_frames, 1, 1) + + # Create dummy frames (zeros = gray in normalized space) + dummy_frames = torch.zeros( + (3, num_dummy_frames, height, width), device=device, dtype=dtype + ) + + # Create masks: 0 for reference frames (keep), 1 for dummy frames (inpaint) + ref_masks = torch.zeros( + (1, num_ref_frames, height, width), + device=device, + dtype=torch.float32, + ) + dummy_masks = torch.ones( + (1, num_dummy_frames, height, width), + device=device, + dtype=torch.float32, + ) + + if ref_at_start: + # firstframe: [ref, ref, ref, zeros, zeros, ...] + frames = torch.cat([ref_replicated, dummy_frames], dim=1) + masks = torch.cat([ref_masks, dummy_masks], dim=1) + else: + # lastframe: [zeros, zeros, ..., ref, ref, ref] + frames = torch.cat([dummy_frames, ref_replicated], dim=1) + masks = torch.cat([dummy_masks, ref_masks], dim=1) + + return frames, masks + def _encode_with_conditioning(self, components, block_state, current_start): """ Encode VACE input using the standard VACE path, with optional reference images. diff --git a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py index 1bcfb35c..9a4608c6 100644 --- a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py +++ b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py @@ -233,8 +233,14 @@ def forward_vace( crossattn_cache, ): """Process VACE context to generate hints.""" + # Get target dtype from vace_patch_embedding parameters + target_dtype = next(self.vace_patch_embedding.parameters()).dtype + + # Convert all VACE context to model dtype first + vace_context_converted = [u.to(dtype=target_dtype) for u in vace_context] + # Embed VACE context - c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context_converted] c = [u.flatten(2).transpose(1, 2) for u in c] # Pad to seq_len diff --git a/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py b/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py index 66d187b1..5e4ee0ac 100644 --- a/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py +++ b/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py @@ -31,7 +31,10 @@ def vace_encode_frames( masks: Optional list of masks [B, 1, F, H, W] for masked video generation pad_to_96: Whether to pad to 96 channels (default True). Set False when masks will be added later. use_cache: Whether to use streaming encode cache for frames (default True). - Set False for one-off encoding (e.g., reference images only mode). + Only applies when masks=None. 
When masks are provided, caching is + handled automatically based on mask content: + - All-1s masks (conditioning mode): both inactive/reactive use cache + - Mixed masks (inpainting mode): inactive uses cache, reactive doesn't Returns: List of concatenated latents [ref_latents + frame_latents] @@ -58,14 +61,24 @@ def vace_encode_frames( masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks, strict=False)] reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks, strict=False)] + inactive_stacked = torch.stack(inactive, dim=0).to(dtype=vae_dtype) reactive_stacked = torch.stack(reactive, dim=0).to(dtype=vae_dtype) - # Use cache=True to ensure cache consistency for both inactive and reactive portions + + # Auto-detect mode based on mask content and handle caching appropriately: + # - Conditioning mode (mask all 1s): inactive=zeros, reactive=content → both use cache + # - Inpainting mode (mixed mask): both have content → reactive skips cache to + # avoid interference from inactive encoding + is_conditioning_mode = all((m > 0.5).all() for m in masks) + inactive_out = vae.encode_to_latent(inactive_stacked, use_cache=True) - reactive_out = vae.encode_to_latent(reactive_stacked, use_cache=True) - # Transpose [B, F, C, H, W] -> [B, C, F, H, W] and concatenate along channel dim + reactive_out = vae.encode_to_latent( + reactive_stacked, use_cache=is_conditioning_mode + ) + inactive_transposed = [lat.permute(1, 0, 2, 3) for lat in inactive_out] reactive_transposed = [lat.permute(1, 0, 2, 3) for lat in reactive_out] + latents = [ torch.cat((u, c), dim=0) for u, c in zip(inactive_transposed, reactive_transposed, strict=False)