From 8d7e9730c3b02e0c55fc72f212b4b188e80f4642 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:03:19 -0500 Subject: [PATCH 1/4] feat: vae front end + lightvae Signed-off-by: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> --- frontend/src/components/SettingsPanel.tsx | 35 ++++++++++++++++ frontend/src/data/parameterMetadata.ts | 5 +++ frontend/src/hooks/useStreamState.ts | 26 +++++++++--- frontend/src/lib/api.ts | 21 ++++++++++ frontend/src/pages/StreamPage.tsx | 11 +++++ frontend/src/types/index.ts | 5 +++ .../pipelines/krea_realtime_video/pipeline.py | 14 ++++--- src/scope/core/pipelines/longlive/pipeline.py | 11 +++-- .../core/pipelines/wan2_1/vae/__init__.py | 29 ++++++++------ .../core/pipelines/wan2_1/vae/modules/vae.py | 33 ++++++++++++--- src/scope/core/pipelines/wan2_1/vae/wan.py | 40 ++++++++++++++++--- src/scope/server/app.py | 9 +++++ src/scope/server/pipeline_manager.py | 7 +++- src/scope/server/schema.py | 22 ++++++++++ 14 files changed, 228 insertions(+), 40 deletions(-) diff --git a/frontend/src/components/SettingsPanel.tsx b/frontend/src/components/SettingsPanel.tsx index b21d592bc..97ea92574 100644 --- a/frontend/src/components/SettingsPanel.tsx +++ b/frontend/src/components/SettingsPanel.tsx @@ -34,6 +34,7 @@ import type { SettingsState, InputMode, PipelineInfo, + VaeType, } from "../types"; import { LoRAManager } from "./LoRAManager"; @@ -87,6 +88,11 @@ interface SettingsPanelProps { onVaceEnabledChange?: (enabled: boolean) => void; vaceContextScale?: number; onVaceContextScaleChange?: (scale: number) => void; + // VAE type selection + vaeType?: VaeType; + onVaeTypeChange?: (vaeType: VaeType) => void; + // Available VAE types from backend registry + vaeTypes?: string[]; } export function SettingsPanel({ @@ -126,6 +132,9 @@ export function SettingsPanel({ onVaceEnabledChange, vaceContextScale = 1.0, onVaceContextScaleChange, + vaeType = "wan", + onVaeTypeChange, + vaeTypes = ["wan"], }: SettingsPanelProps) { // Local slider state management hooks const noiseScaleSlider = useLocalSliderValue(noiseScale, onNoiseScaleChange); @@ -717,6 +726,32 @@ export function SettingsPanel({ + +
+ + +
diff --git a/frontend/src/data/parameterMetadata.ts b/frontend/src/data/parameterMetadata.ts index 8f0be5adc..1d875f2fd 100644 --- a/frontend/src/data/parameterMetadata.ts +++ b/frontend/src/data/parameterMetadata.ts @@ -82,4 +82,9 @@ export const PARAMETER_METADATA: Record = { tooltip: "The configuration of the sender that will send video to Spout-compatible apps like TouchDesigner, Resolume, OBS.", }, + vaeType: { + label: "VAE:", + tooltip: + "VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality.", + }, }; diff --git a/frontend/src/hooks/useStreamState.ts b/frontend/src/hooks/useStreamState.ts index 5fa290775..0fbe5e991 100644 --- a/frontend/src/hooks/useStreamState.ts +++ b/frontend/src/hooks/useStreamState.ts @@ -10,6 +10,7 @@ import type { import { getHardwareInfo, getPipelineSchemas, + getVaeTypes, type HardwareInfoResponse, type PipelineSchemasResponse, } from "../lib/api"; @@ -160,14 +161,19 @@ export function useStreamState() { null ); - // Fetch pipeline schemas and hardware info on mount + // Store available VAE types from registry + const [vaeTypes, setVaeTypes] = useState(["wan"]); + + // Fetch pipeline schemas, hardware info, and VAE types on mount useEffect(() => { const fetchInitialData = async () => { try { - const [schemasResult, hardwareResult] = await Promise.allSettled([ - getPipelineSchemas(), - getHardwareInfo(), - ]); + const [schemasResult, hardwareResult, vaeTypesResult] = + await Promise.allSettled([ + getPipelineSchemas(), + getHardwareInfo(), + getVaeTypes(), + ]); if (schemasResult.status === "fulfilled") { const schemas = schemasResult.value; @@ -205,6 +211,15 @@ export function useStreamState() { hardwareResult.reason ); } + + if (vaeTypesResult.status === "fulfilled") { + setVaeTypes(vaeTypesResult.value.vae_types); + } else { + console.error( + "useStreamState: Failed to fetch VAE types:", + vaeTypesResult.reason + ); + } } catch (error) { console.error("useStreamState: Failed to fetch initial data:", error); } @@ -283,6 +298,7 @@ export function useStreamState() { promptData, hardwareInfo, pipelineSchemas, + vaeTypes, updateMetrics, updateStreamStatus, updateSettings, diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 50fa5b811..cccf1cde2 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -425,3 +425,24 @@ export const getPipelineSchemas = const result = await response.json(); return result; }; + +export interface VaeTypesResponse { + vae_types: string[]; +} + +export const getVaeTypes = async (): Promise => { + const response = await fetch("/api/v1/vae/types", { + method: "GET", + headers: { "Content-Type": "application/json" }, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `Get VAE types failed: ${response.status} ${response.statusText}: ${errorText}` + ); + } + + const result = await response.json(); + return result; +}; diff --git a/frontend/src/pages/StreamPage.tsx b/frontend/src/pages/StreamPage.tsx index 0de1ef2e6..bdfafff92 100644 --- a/frontend/src/pages/StreamPage.tsx +++ b/frontend/src/pages/StreamPage.tsx @@ -21,6 +21,7 @@ import type { LoRAConfig, LoraMergeStrategy, DownloadProgress, + VaeType, } from "../types"; import type { PromptItem, PromptTransition } from "../lib/api"; import { checkModelStatus, downloadPipelineModels } from "../lib/api"; @@ -78,6 +79,7 @@ export function StreamPage() { getDefaults, supportsNoiseControls, spoutAvailable, + vaeTypes, } = 
useStreamState(); // Prompt state - use unified default prompts based on mode @@ -453,6 +455,11 @@ export function StreamPage() { // Note: This setting requires pipeline reload, so we don't send parameter update here }; + const handleVaeTypeChange = (vaeType: VaeType) => { + updateSettings({ vaeType }); + // Note: This setting requires pipeline reload, so we don't send parameter update here + }; + const handleKvCacheAttentionBiasChange = (bias: number) => { updateSettings({ kvCacheAttentionBias: bias }); // Send KV cache attention bias update to backend @@ -725,6 +732,7 @@ export function StreamPage() { if (currentPipeline?.supportsQuantization) { loadParams.seed = settings.seed ?? 42; loadParams.quantization = settings.quantization ?? null; + loadParams.vae_type = settings.vaeType ?? "wan"; } // Add LoRA parameters if pipeline supports LoRA @@ -1117,6 +1125,9 @@ export function StreamPage() { onVaceEnabledChange={handleVaceEnabledChange} vaceContextScale={settings.vaceContextScale ?? 1.0} onVaceContextScaleChange={handleVaceContextScaleChange} + vaeType={settings.vaeType ?? "wan"} + onVaeTypeChange={handleVaeTypeChange} + vaeTypes={vaeTypes} /> diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index ba373389a..56ec47e82 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -4,6 +4,9 @@ export type PipelineId = string; // Input mode for pipeline operation export type InputMode = "text" | "video"; +// VAE type for model selection (dynamic from backend registry) +export type VaeType = string; + // WebRTC ICE server configuration export interface IceServerConfig { urls: string | string[]; @@ -73,6 +76,8 @@ export interface SettingsState { vaceEnabled?: boolean; refImages?: string[]; vaceContextScale?: number; + // VAE type selection + vaeType?: VaeType; } export interface PipelineInfo { diff --git a/src/scope/core/pipelines/krea_realtime_video/pipeline.py b/src/scope/core/pipelines/krea_realtime_video/pipeline.py index 89d8cfa72..efc2bc526 100644 --- a/src/scope/core/pipelines/krea_realtime_video/pipeline.py +++ b/src/scope/core/pipelines/krea_realtime_video/pipeline.py @@ -18,7 +18,7 @@ from ..utils import Quantization, load_model_config, validate_resolution from ..wan2_1.components import WanDiffusionWrapper, WanTextEncoderWrapper from ..wan2_1.lora.mixin import LoRAEnabledPipeline -from ..wan2_1.vae import WanVAEWrapper +from ..wan2_1.vae import create_vae from .modular_blocks import KreaRealtimeVideoBlocks from .schema import KreaRealtimeVideoConfig @@ -128,12 +128,16 @@ def __init__( # Move text encoder to target device but use dtype of weights text_encoder = text_encoder.to(device=device) - # Load vae + # Load VAE using create_vae factory (supports multiple VAE types) + vae_type = getattr(config, "vae_type", "wan") start = time.time() - vae = WanVAEWrapper( - model_name=base_model_name, model_dir=model_dir, vae_path=vae_path + vae = create_vae( + model_dir=model_dir, + model_name=base_model_name, + vae_type=vae_type, + vae_path=vae_path, ) - print(f"Loaded VAE in {time.time() - start:.3f}s") + print(f"Loaded VAE (type={vae_type}) in {time.time() - start:.3f}s") # Move VAE to target device and use target dtype vae = vae.to(device=device, dtype=dtype) diff --git a/src/scope/core/pipelines/longlive/pipeline.py b/src/scope/core/pipelines/longlive/pipeline.py index af0f8091f..823ab1063 100644 --- a/src/scope/core/pipelines/longlive/pipeline.py +++ b/src/scope/core/pipelines/longlive/pipeline.py @@ -20,7 +20,7 @@ from ..wan2_1.lora.mixin import 
LoRAEnabledPipeline from ..wan2_1.lora.strategies.module_targeted_lora import ModuleTargetedLoRAStrategy from ..wan2_1.vace import VACEEnabledPipeline -from ..wan2_1.vae import WanVAEWrapper +from ..wan2_1.vae import create_vae from .modular_blocks import LongLiveBlocks from .schema import LongLiveConfig @@ -145,10 +145,13 @@ def __init__( # Move text encoder to target device but use dtype of weights text_encoder = text_encoder.to(device=device) - # Load VAE using unified WanVAEWrapper + # Load VAE using create_vae factory (supports multiple VAE types) + vae_type = getattr(config, "vae_type", "wan") start = time.time() - vae = WanVAEWrapper(model_dir=model_dir, model_name=base_model_name) - print(f"Loaded VAE in {time.time() - start:.3f}s") + vae = create_vae( + model_dir=model_dir, model_name=base_model_name, vae_type=vae_type + ) + print(f"Loaded VAE (type={vae_type}) in {time.time() - start:.3f}s") # Move VAE to target device and use target dtype vae = vae.to(device=device, dtype=dtype) diff --git a/src/scope/core/pipelines/wan2_1/vae/__init__.py b/src/scope/core/pipelines/wan2_1/vae/__init__.py index c643470ae..ec2cf4797 100644 --- a/src/scope/core/pipelines/wan2_1/vae/__init__.py +++ b/src/scope/core/pipelines/wan2_1/vae/__init__.py @@ -1,28 +1,33 @@ """Wan2.1 VAE implementations. -This module provides a registry-based factory for VAE instantiation, -supporting multiple VAE types (currently WanVAEWrapper, with LightVAE planned). +This module provides a unified VAE interface through WanVAEWrapper, which supports +both the full WanVAE and the 75% pruned LightVAE via the `use_lightvae` parameter. Usage: from scope.core.pipelines.wan2_1.vae import create_vae - # Default (WanVAEWrapper) + # Default (full WanVAE) vae = create_vae(model_dir="wan_models") # Explicit type (for UI dropdown) vae = create_vae(model_dir="wan_models", vae_type="wan") + # LightVAE (75% pruned, faster but lower quality) + vae = create_vae(model_dir="wan_models", vae_type="lightvae") + # With explicit path override vae = create_vae(model_dir="wan_models", vae_path="/path/to/custom_vae.pth") """ +from functools import partial + from .wan import WanVAEWrapper -# Registry mapping type names to VAE classes +# Registry mapping type names to VAE factory functions # UI dropdowns will use these keys VAE_REGISTRY: dict[str, type] = { "wan": WanVAEWrapper, - # "lightvae": LightVAE, # Future: add when LightVAE is implemented + "lightvae": partial(WanVAEWrapper, use_lightvae=True), } DEFAULT_VAE_TYPE = "wan" @@ -39,27 +44,27 @@ def create_vae( Args: model_dir: Base model directory model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B") - vae_type: VAE type from registry. Defaults to "wan". - This will be selectable via UI dropdown. + vae_type: VAE type ("wan" for full VAE, "lightvae" for 75% pruned). + Defaults to "wan". This is selectable via UI dropdown. vae_path: Optional explicit path override. If provided, takes precedence over model_dir/model_name path construction. Returns: - Initialized VAE instance + Initialized WanVAEWrapper instance Raises: - ValueError: If vae_type is not in registry + ValueError: If vae_type is not recognized """ vae_type = vae_type or DEFAULT_VAE_TYPE - vae_cls = VAE_REGISTRY.get(vae_type) - if vae_cls is None: + vae_factory = VAE_REGISTRY.get(vae_type) + if vae_factory is None: available = list(VAE_REGISTRY.keys()) raise ValueError( f"create_vae: Unknown VAE type '{vae_type}'. 
Available types: {available}" ) - return vae_cls(model_dir=model_dir, model_name=model_name, vae_path=vae_path) + return vae_factory(model_dir=model_dir, model_name=model_name, vae_path=vae_path) def list_vae_types() -> list[str]: diff --git a/src/scope/core/pipelines/wan2_1/vae/modules/vae.py b/src/scope/core/pipelines/wan2_1/vae/modules/vae.py index 8cbab16f2..2f3d1b0a0 100644 --- a/src/scope/core/pipelines/wan2_1/vae/modules/vae.py +++ b/src/scope/core/pipelines/wan2_1/vae/modules/vae.py @@ -305,6 +305,7 @@ def __init__( attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, + pruning_rate=0.0, ): super().__init__() self.dim = dim @@ -314,8 +315,9 @@ def __init__( self.attn_scales = attn_scales self.temperal_downsample = temperal_downsample - # dimensions + # dimensions (apply pruning to reduce channel dimensions) dims = [dim * u for u in [1] + dim_mult] + dims = [int(d * (1 - pruning_rate)) for d in dims] scale = 1.0 # init block @@ -419,6 +421,7 @@ def __init__( attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0, + pruning_rate=0.0, ): super().__init__() self.dim = dim @@ -428,8 +431,9 @@ def __init__( self.attn_scales = attn_scales self.temperal_upsample = temperal_upsample - # dimensions + # dimensions (apply pruning to reduce channel dimensions) dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + dims = [int(d * (1 - pruning_rate)) for d in dims] scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block @@ -544,6 +548,7 @@ def __init__( attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, + pruning_rate=0.0, ): super().__init__() self.dim = dim @@ -563,6 +568,7 @@ def __init__( attn_scales, self.temperal_downsample, dropout, + pruning_rate, ) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) @@ -574,6 +580,7 @@ def __init__( attn_scales, self.temperal_upsample, dropout, + pruning_rate, ) self.first_batch = True @@ -750,9 +757,16 @@ def clear_cache_encode(self): self._enc_feat_map = [None] * self._enc_conv_num -def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): +def _video_vae(pretrained_path=None, z_dim=None, device="cpu", pruning_rate=0.0, **kwargs): """ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. 
+ + Args: + pretrained_path: Path to checkpoint file (.pth or .safetensors) + z_dim: Latent dimension + device: Target device + pruning_rate: Channel pruning rate (0.0 = full VAE, 0.75 = LightVAE) + **kwargs: Additional model configuration """ # params cfg = dict( @@ -763,6 +777,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0, + pruning_rate=pruning_rate, ) cfg.update(**kwargs) @@ -770,9 +785,15 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): with torch.device("meta"): model = WanVAE_(**cfg) - # load checkpoint - logging.info(f"loading {pretrained_path}") - model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) + # load checkpoint (supports both .pth and .safetensors) + logging.info(f"_video_vae: loading {pretrained_path}") + if pretrained_path.endswith(".safetensors"): + from safetensors.torch import load_file + + state_dict = load_file(pretrained_path, device=str(device)) + else: + state_dict = torch.load(pretrained_path, map_location=device) + model.load_state_dict(state_dict, assign=True) return model diff --git a/src/scope/core/pipelines/wan2_1/vae/wan.py b/src/scope/core/pipelines/wan2_1/vae/wan.py index 8dc614b60..3bec689e6 100644 --- a/src/scope/core/pipelines/wan2_1/vae/wan.py +++ b/src/scope/core/pipelines/wan2_1/vae/wan.py @@ -1,14 +1,22 @@ -"""Unified Wan VAE wrapper with streaming and batch encoding/decoding.""" +"""Unified Wan VAE wrapper with streaming and batch encoding/decoding. + +This module provides a unified VAE wrapper that supports both the full WanVAE +and the 75% pruned LightVAE through a single interface with a `use_lightvae` parameter. +""" import os import torch from .constants import WAN_VAE_LATENT_MEAN, WAN_VAE_LATENT_STD -from .modules.vae import _video_vae +from .modules.vae import _video_vae, count_conv3d -# Default filename for standard Wan2.1 VAE checkpoint +# Default filenames for VAE checkpoints DEFAULT_VAE_FILENAME = "Wan2.1_VAE.pth" +LIGHTVAE_FILENAME = "lightvaew2_1.pth" + +# LightVAE pruning rate (75% of channels pruned) +LIGHTVAE_PRUNING_RATE = 0.75 class WanVAEWrapper(torch.nn.Module): @@ -16,6 +24,15 @@ class WanVAEWrapper(torch.nn.Module): This VAE supports both streaming (cached) and batch encoding/decoding modes. Normalization is always applied during encoding for consistent latent distributions. + + The wrapper can instantiate either the full WanVAE or the 75% pruned LightVAE + through the `use_lightvae` parameter. 
+ + Args: + model_dir: Base directory containing model files + model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B") + vae_path: Explicit path to VAE checkpoint (overrides model_dir/model_name) + use_lightvae: If True, use 75% pruned LightVAE (faster, lower quality) """ def __init__( @@ -23,12 +40,19 @@ def __init__( model_dir: str = "wan_models", model_name: str = "Wan2.1-T2V-1.3B", vae_path: str | None = None, + use_lightvae: bool = False, ): super().__init__() + # Determine pruning rate based on VAE type + pruning_rate = LIGHTVAE_PRUNING_RATE if use_lightvae else 0.0 + # Determine paths with priority: explicit vae_path > model_dir/model_name default if vae_path is None: - vae_path = os.path.join(model_dir, model_name, DEFAULT_VAE_FILENAME) + default_filename = ( + LIGHTVAE_FILENAME if use_lightvae else DEFAULT_VAE_FILENAME + ) + vae_path = os.path.join(model_dir, model_name, default_filename) self.register_buffer( "mean", torch.tensor(WAN_VAE_LATENT_MEAN, dtype=torch.float32) @@ -42,11 +66,15 @@ def __init__( _video_vae( pretrained_path=vae_path, z_dim=self.z_dim, + pruning_rate=pruning_rate, ) .eval() .requires_grad_(False) ) + # Cache encoder conv count for dynamic cache sizing + self._encoder_conv_count = count_conv3d(self.model.encoder) + def _get_scale(self, device: torch.device, dtype: torch.dtype) -> list: """Get normalization scale parameters on the correct device/dtype.""" return [ @@ -65,8 +93,8 @@ def _apply_encoding_normalization( return (latent - scale[0]) * scale[1] def _create_encoder_cache(self) -> list: - """Create a fresh encoder feature cache.""" - return [None] * 55 + """Create a fresh encoder feature cache with dynamic sizing.""" + return [None] * self._encoder_conv_count def encode_to_latent( self, diff --git a/src/scope/server/app.py b/src/scope/server/app.py index 6b79655ad..7a73abf19 100644 --- a/src/scope/server/app.py +++ b/src/scope/server/app.py @@ -52,6 +52,7 @@ PipelineLoadRequest, PipelineSchemasResponse, PipelineStatusResponse, + VaeTypesResponse, WebRTCOfferRequest, WebRTCOfferResponse, ) @@ -786,6 +787,14 @@ async def get_hardware_info(): raise HTTPException(status_code=500, detail=str(e)) from e +@app.get("/api/v1/vae/types", response_model=VaeTypesResponse) +async def get_vae_types(): + """Get available VAE types from the registry.""" + from scope.core.pipelines.wan2_1.vae import list_vae_types + + return VaeTypesResponse(vae_types=list_vae_types()) + + @app.get("/api/v1/logs/current") async def get_current_logs(): """Get the most recent application log file for bug reporting.""" diff --git a/src/scope/server/pipeline_manager.py b/src/scope/server/pipeline_manager.py index ef8873ec7..7137ea7d4 100644 --- a/src/scope/server/pipeline_manager.py +++ b/src/scope/server/pipeline_manager.py @@ -262,11 +262,11 @@ def _apply_load_params( default_width: int, default_seed: int = 42, ) -> None: - """Extract and apply common load parameters (resolution, seed, LoRAs) to config. + """Extract and apply common load parameters (resolution, seed, LoRAs, VAE type) to config. 
Args: config: Pipeline config dict to update - load_params: Load parameters dict (may contain height, width, seed, loras, lora_merge_mode) + load_params: Load parameters dict (may contain height, width, seed, loras, lora_merge_mode, vae_type) default_height: Default height if not in load_params default_width: Default width if not in load_params default_seed: Default seed if not in load_params @@ -276,6 +276,7 @@ def _apply_load_params( seed = default_seed loras = None lora_merge_mode = "permanent_merge" + vae_type = "wan" # Default VAE type if load_params: height = load_params.get("height", default_height) @@ -283,10 +284,12 @@ def _apply_load_params( seed = load_params.get("seed", default_seed) loras = load_params.get("loras", None) lora_merge_mode = load_params.get("lora_merge_mode", lora_merge_mode) + vae_type = load_params.get("vae_type", vae_type) config["height"] = height config["width"] = width config["seed"] = seed + config["vae_type"] = vae_type if loras: config["loras"] = loras # Pass merge_mode directly to mixin, not via config diff --git a/src/scope/server/schema.py b/src/scope/server/schema.py index 6cb6ed28e..aa937e8be 100644 --- a/src/scope/server/schema.py +++ b/src/scope/server/schema.py @@ -9,6 +9,10 @@ from scope.core.pipelines.longlive.schema import LongLiveConfig from scope.core.pipelines.streamdiffusionv2.schema import StreamDiffusionV2Config from scope.core.pipelines.utils import Quantization +from scope.core.pipelines.wan2_1.vae import DEFAULT_VAE_TYPE + +# VAE type literal based on available VAE types +VaeType = Literal["wan", "lightvae"] class HealthResponse(BaseModel): @@ -18,6 +22,12 @@ class HealthResponse(BaseModel): timestamp: str +class VaeTypesResponse(BaseModel): + """Response containing available VAE types from the registry.""" + + vae_types: list[str] + + class PromptItem(BaseModel): """Individual prompt with weight for blending.""" @@ -306,6 +316,10 @@ class StreamDiffusionV2LoadParams(LoRAEnabledLoadParams): default=True, description="Enable VACE (Video All-In-One Creation and Editing) support for reference image conditioning and structural guidance. When enabled, incoming video in V2V mode is routed to VACE for conditioning. When disabled, V2V uses faster regular encoding.", ) + vae_type: VaeType = Field( + default=DEFAULT_VAE_TYPE, + description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).", + ) class PassthroughLoadParams(PipelineLoadParams): @@ -345,6 +359,10 @@ class LongLiveLoadParams(LoRAEnabledLoadParams): default=True, description="Enable VACE (Video All-In-One Creation and Editing) support for reference image conditioning and structural guidance. When enabled, incoming video in V2V mode is routed to VACE for conditioning. When disabled, V2V uses faster regular encoding.", ) + vae_type: VaeType = Field( + default=DEFAULT_VAE_TYPE, + description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).", + ) class KreaRealtimeVideoLoadParams(LoRAEnabledLoadParams): @@ -374,6 +392,10 @@ class KreaRealtimeVideoLoadParams(LoRAEnabledLoadParams): default=Quantization.FP8_E4M3FN, description="Quantization method to use for diffusion model. If None, no quantization is applied.", ) + vae_type: VaeType = Field( + default=DEFAULT_VAE_TYPE, + description="VAE type to use. 
'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).", + ) class PipelineLoadRequest(BaseModel): From a0700f0a72bd20706c60d3984d7176a22c6fcdb2 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Fri, 2 Jan 2026 15:57:47 -0500 Subject: [PATCH 2/4] refactor: vae to pipeline schema, vae dropdown moved, proper attribution Signed-off-by: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> --- frontend/src/components/SettingsPanel.tsx | 57 ++++++++++--------- frontend/src/hooks/usePipelines.ts | 2 + frontend/src/hooks/useStreamState.ts | 26 ++------- frontend/src/lib/api.ts | 24 +------- frontend/src/pages/StreamPage.tsx | 3 +- frontend/src/types/index.ts | 3 + src/scope/core/pipelines/base_schema.py | 8 +++ .../pipelines/krea_realtime_video/schema.py | 1 + src/scope/core/pipelines/longlive/schema.py | 1 + .../pipelines/streamdiffusionv2/schema.py | 1 + src/scope/core/pipelines/utils.py | 7 +++ .../core/pipelines/wan2_1/vae/__init__.py | 6 -- .../core/pipelines/wan2_1/vae/modules/vae.py | 1 + src/scope/server/app.py | 9 --- src/scope/server/schema.py | 11 +--- 15 files changed, 65 insertions(+), 95 deletions(-) diff --git a/frontend/src/components/SettingsPanel.tsx b/frontend/src/components/SettingsPanel.tsx index 97ea92574..82f0ee778 100644 --- a/frontend/src/components/SettingsPanel.tsx +++ b/frontend/src/components/SettingsPanel.tsx @@ -398,6 +398,37 @@ export function SettingsPanel({ )} + {/* VAE Type Selection */} + {pipelines?.[pipelineId]?.supportsVaeType && ( +
+
+ + +
+
+ )} + {currentPipeline?.supportsLoRA && (
- -
- - -
diff --git a/frontend/src/hooks/usePipelines.ts b/frontend/src/hooks/usePipelines.ts index e173a6d43..06a3e2799 100644 --- a/frontend/src/hooks/usePipelines.ts +++ b/frontend/src/hooks/usePipelines.ts @@ -41,10 +41,12 @@ export function usePipelines() { supportsCacheManagement: schema.supports_cache_management, supportsKvCacheBias: schema.supports_kv_cache_bias, supportsQuantization: schema.supports_quantization, + supportsVaeType: schema.supports_vae_type, minDimension: schema.min_dimension, recommendedQuantizationVramThreshold: schema.recommended_quantization_vram_threshold ?? undefined, modified: schema.modified, + vaeTypes: schema.vae_types ?? undefined, }; } diff --git a/frontend/src/hooks/useStreamState.ts b/frontend/src/hooks/useStreamState.ts index 0fbe5e991..5fa290775 100644 --- a/frontend/src/hooks/useStreamState.ts +++ b/frontend/src/hooks/useStreamState.ts @@ -10,7 +10,6 @@ import type { import { getHardwareInfo, getPipelineSchemas, - getVaeTypes, type HardwareInfoResponse, type PipelineSchemasResponse, } from "../lib/api"; @@ -161,19 +160,14 @@ export function useStreamState() { null ); - // Store available VAE types from registry - const [vaeTypes, setVaeTypes] = useState(["wan"]); - - // Fetch pipeline schemas, hardware info, and VAE types on mount + // Fetch pipeline schemas and hardware info on mount useEffect(() => { const fetchInitialData = async () => { try { - const [schemasResult, hardwareResult, vaeTypesResult] = - await Promise.allSettled([ - getPipelineSchemas(), - getHardwareInfo(), - getVaeTypes(), - ]); + const [schemasResult, hardwareResult] = await Promise.allSettled([ + getPipelineSchemas(), + getHardwareInfo(), + ]); if (schemasResult.status === "fulfilled") { const schemas = schemasResult.value; @@ -211,15 +205,6 @@ export function useStreamState() { hardwareResult.reason ); } - - if (vaeTypesResult.status === "fulfilled") { - setVaeTypes(vaeTypesResult.value.vae_types); - } else { - console.error( - "useStreamState: Failed to fetch VAE types:", - vaeTypesResult.reason - ); - } } catch (error) { console.error("useStreamState: Failed to fetch initial data:", error); } @@ -298,7 +283,6 @@ export function useStreamState() { promptData, hardwareInfo, pipelineSchemas, - vaeTypes, updateMetrics, updateStreamStatus, updateSettings, diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index cccf1cde2..2f212dfa2 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -399,9 +399,12 @@ export interface PipelineSchemaInfo { supports_cache_management: boolean; supports_kv_cache_bias: boolean; supports_quantization: boolean; + supports_vae_type: boolean; min_dimension: number; recommended_quantization_vram_threshold: number | null; modified: boolean; + // Available VAE types from config schema enum + vae_types?: string[]; } export interface PipelineSchemasResponse { @@ -425,24 +428,3 @@ export const getPipelineSchemas = const result = await response.json(); return result; }; - -export interface VaeTypesResponse { - vae_types: string[]; -} - -export const getVaeTypes = async (): Promise => { - const response = await fetch("/api/v1/vae/types", { - method: "GET", - headers: { "Content-Type": "application/json" }, - }); - - if (!response.ok) { - const errorText = await response.text(); - throw new Error( - `Get VAE types failed: ${response.status} ${response.statusText}: ${errorText}` - ); - } - - const result = await response.json(); - return result; -}; diff --git a/frontend/src/pages/StreamPage.tsx b/frontend/src/pages/StreamPage.tsx index 
bdfafff92..695ec89b3 100644 --- a/frontend/src/pages/StreamPage.tsx +++ b/frontend/src/pages/StreamPage.tsx @@ -79,7 +79,6 @@ export function StreamPage() { getDefaults, supportsNoiseControls, spoutAvailable, - vaeTypes, } = useStreamState(); // Prompt state - use unified default prompts based on mode @@ -1127,7 +1126,7 @@ export function StreamPage() { onVaceContextScaleChange={handleVaceContextScaleChange} vaeType={settings.vaeType ?? "wan"} onVaeTypeChange={handleVaeTypeChange} - vaeTypes={vaeTypes} + vaeTypes={pipelines?.[settings.pipelineId]?.vaeTypes ?? ["wan"]} /> diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 56ec47e82..23baae1ce 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -103,8 +103,11 @@ export interface PipelineInfo { supportsCacheManagement?: boolean; supportsKvCacheBias?: boolean; supportsQuantization?: boolean; + supportsVaeType?: boolean; minDimension?: number; recommendedQuantizationVramThreshold?: number | null; + // Available VAE types from config schema enum + vaeTypes?: string[]; } export interface DownloadProgress { diff --git a/src/scope/core/pipelines/base_schema.py b/src/scope/core/pipelines/base_schema.py index e06c52dd0..6bd3bcd3e 100644 --- a/src/scope/core/pipelines/base_schema.py +++ b/src/scope/core/pipelines/base_schema.py @@ -160,6 +160,7 @@ class BasePipelineConfig(BaseModel): supports_cache_management: ClassVar[bool] = False supports_kv_cache_bias: ClassVar[bool] = False supports_quantization: ClassVar[bool] = False + supports_vae_type: ClassVar[bool] = False min_dimension: ClassVar[int] = 1 # Whether this pipeline contains modifications based on the original project modified: ClassVar[bool] = False @@ -288,6 +289,7 @@ def get_schema_with_metadata(cls) -> dict[str, Any]: metadata["supports_cache_management"] = cls.supports_cache_management metadata["supports_kv_cache_bias"] = cls.supports_kv_cache_bias metadata["supports_quantization"] = cls.supports_quantization + metadata["supports_vae_type"] = cls.supports_vae_type metadata["min_dimension"] = cls.min_dimension metadata["recommended_quantization_vram_threshold"] = ( cls.recommended_quantization_vram_threshold @@ -295,6 +297,12 @@ def get_schema_with_metadata(cls) -> dict[str, Any]: metadata["modified"] = cls.modified metadata["config_schema"] = cls.model_json_schema() + # Extract VAE types from schema if pipeline supports it + if cls.supports_vae_type: + from .utils import VaeType + + metadata["vae_types"] = [vae_type.value for vae_type in VaeType] + # Include mode-specific defaults (excluding None values and the "default" flag) mode_defaults = {} for mode_name, mode_config in cls.modes.items(): diff --git a/src/scope/core/pipelines/krea_realtime_video/schema.py b/src/scope/core/pipelines/krea_realtime_video/schema.py index 6f3582c44..2dd6b627c 100644 --- a/src/scope/core/pipelines/krea_realtime_video/schema.py +++ b/src/scope/core/pipelines/krea_realtime_video/schema.py @@ -16,6 +16,7 @@ class KreaRealtimeVideoConfig(BasePipelineConfig): supports_cache_management = True supports_kv_cache_bias = True supports_quantization = True + supports_vae_type = True min_dimension = 16 modified = True recommended_quantization_vram_threshold = 40.0 diff --git a/src/scope/core/pipelines/longlive/schema.py b/src/scope/core/pipelines/longlive/schema.py index 609673481..3605c76ca 100644 --- a/src/scope/core/pipelines/longlive/schema.py +++ b/src/scope/core/pipelines/longlive/schema.py @@ -17,6 +17,7 @@ class LongLiveConfig(BasePipelineConfig): 
supports_cache_management = True supports_quantization = True + supports_vae_type = True min_dimension = 16 modified = True diff --git a/src/scope/core/pipelines/streamdiffusionv2/schema.py b/src/scope/core/pipelines/streamdiffusionv2/schema.py index d08e6de7b..2d797ea4b 100644 --- a/src/scope/core/pipelines/streamdiffusionv2/schema.py +++ b/src/scope/core/pipelines/streamdiffusionv2/schema.py @@ -17,6 +17,7 @@ class StreamDiffusionV2Config(BasePipelineConfig): supports_cache_management = True supports_quantization = True + supports_vae_type = True min_dimension = 16 modified = True diff --git a/src/scope/core/pipelines/utils.py b/src/scope/core/pipelines/utils.py index 7b3ed741b..f9978b8e7 100644 --- a/src/scope/core/pipelines/utils.py +++ b/src/scope/core/pipelines/utils.py @@ -14,6 +14,13 @@ class Quantization(str, Enum): FP8_E4M3FN = "fp8_e4m3fn" +class VaeType(str, Enum): + """VAE type enumeration.""" + + WAN = "wan" + LIGHTVAE = "lightvae" + + def load_state_dict(weights_path: str) -> dict: """Load weights with automatic format detection.""" if not os.path.exists(weights_path): diff --git a/src/scope/core/pipelines/wan2_1/vae/__init__.py b/src/scope/core/pipelines/wan2_1/vae/__init__.py index ec2cf4797..50bdb16e7 100644 --- a/src/scope/core/pipelines/wan2_1/vae/__init__.py +++ b/src/scope/core/pipelines/wan2_1/vae/__init__.py @@ -67,15 +67,9 @@ def create_vae( return vae_factory(model_dir=model_dir, model_name=model_name, vae_path=vae_path) -def list_vae_types() -> list[str]: - """Return list of available VAE types for UI dropdowns.""" - return list(VAE_REGISTRY.keys()) - - __all__ = [ "WanVAEWrapper", "create_vae", - "list_vae_types", "VAE_REGISTRY", "DEFAULT_VAE_TYPE", ] diff --git a/src/scope/core/pipelines/wan2_1/vae/modules/vae.py b/src/scope/core/pipelines/wan2_1/vae/modules/vae.py index 2f3d1b0a0..05ebcad44 100644 --- a/src/scope/core/pipelines/wan2_1/vae/modules/vae.py +++ b/src/scope/core/pipelines/wan2_1/vae/modules/vae.py @@ -1,4 +1,5 @@ # Modified from https://github.com/chenfengxu714/StreamdiffusionV2 +# Pruning functionality adapted from https://github.com/ModelTC/LightX2V # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import logging diff --git a/src/scope/server/app.py b/src/scope/server/app.py index 7a73abf19..6b79655ad 100644 --- a/src/scope/server/app.py +++ b/src/scope/server/app.py @@ -52,7 +52,6 @@ PipelineLoadRequest, PipelineSchemasResponse, PipelineStatusResponse, - VaeTypesResponse, WebRTCOfferRequest, WebRTCOfferResponse, ) @@ -787,14 +786,6 @@ async def get_hardware_info(): raise HTTPException(status_code=500, detail=str(e)) from e -@app.get("/api/v1/vae/types", response_model=VaeTypesResponse) -async def get_vae_types(): - """Get available VAE types from the registry.""" - from scope.core.pipelines.wan2_1.vae import list_vae_types - - return VaeTypesResponse(vae_types=list_vae_types()) - - @app.get("/api/v1/logs/current") async def get_current_logs(): """Get the most recent application log file for bug reporting.""" diff --git a/src/scope/server/schema.py b/src/scope/server/schema.py index aa937e8be..532853261 100644 --- a/src/scope/server/schema.py +++ b/src/scope/server/schema.py @@ -8,12 +8,9 @@ from scope.core.pipelines.krea_realtime_video.schema import KreaRealtimeVideoConfig from scope.core.pipelines.longlive.schema import LongLiveConfig from scope.core.pipelines.streamdiffusionv2.schema import StreamDiffusionV2Config -from scope.core.pipelines.utils import Quantization +from scope.core.pipelines.utils import Quantization, VaeType from scope.core.pipelines.wan2_1.vae import DEFAULT_VAE_TYPE -# VAE type literal based on available VAE types -VaeType = Literal["wan", "lightvae"] - class HealthResponse(BaseModel): """Health check response schema.""" @@ -22,12 +19,6 @@ class HealthResponse(BaseModel): timestamp: str -class VaeTypesResponse(BaseModel): - """Response containing available VAE types from the registry.""" - - vae_types: list[str] - - class PromptItem(BaseModel): """Individual prompt with weight for blending.""" From d9d5a267770332ca6aafc013ad6c95a74d0e6f89 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Mon, 5 Jan 2026 07:48:48 -0500 Subject: [PATCH 3/4] Refactor VAE type to use Pydantic enum field Signed-off-by: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> --- frontend/src/components/SettingsPanel.tsx | 2 +- frontend/src/hooks/usePipelines.ts | 16 ++++++++++++++-- frontend/src/lib/api.ts | 6 +++--- frontend/src/types/index.ts | 3 +-- src/scope/core/pipelines/base_schema.py | 8 -------- .../core/pipelines/krea_realtime_video/schema.py | 8 +++++++- src/scope/core/pipelines/longlive/schema.py | 8 +++++++- .../core/pipelines/streamdiffusionv2/schema.py | 8 +++++++- 8 files changed, 40 insertions(+), 19 deletions(-) diff --git a/frontend/src/components/SettingsPanel.tsx b/frontend/src/components/SettingsPanel.tsx index 82f0ee778..b08cead90 100644 --- a/frontend/src/components/SettingsPanel.tsx +++ b/frontend/src/components/SettingsPanel.tsx @@ -399,7 +399,7 @@ export function SettingsPanel({ )} {/* VAE Type Selection */} - {pipelines?.[pipelineId]?.supportsVaeType && ( + {vaeTypes && vaeTypes.length > 0 && (
= {}; for (const [id, schema] of Object.entries(schemas.pipelines)) { + // Extract VAE types from JSON schema if vae_type field exists + // Pydantic v2 represents enum fields using $ref to definitions + let vaeTypes: string[] | undefined = undefined; + const vaeTypeProperty = schema.config_schema?.properties?.vae_type; + if (vaeTypeProperty?.$ref && schema.config_schema?.$defs) { + const refPath = vaeTypeProperty.$ref; + const defName = refPath.split("/").pop(); + const definition = schema.config_schema.$defs[defName || ""]; + if (definition && Array.isArray(definition.enum)) { + vaeTypes = definition.enum as string[]; + } + } + transformed[id] = { name: schema.name, about: schema.description, @@ -41,12 +54,11 @@ export function usePipelines() { supportsCacheManagement: schema.supports_cache_management, supportsKvCacheBias: schema.supports_kv_cache_bias, supportsQuantization: schema.supports_quantization, - supportsVaeType: schema.supports_vae_type, minDimension: schema.min_dimension, recommendedQuantizationVramThreshold: schema.recommended_quantization_vram_threshold ?? undefined, modified: schema.modified, - vaeTypes: schema.vae_types ?? undefined, + vaeTypes, }; } diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 2f212dfa2..3d93d15c7 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -356,6 +356,8 @@ export interface PipelineSchemaProperty { maximum?: number; items?: unknown; anyOf?: unknown[]; + enum?: unknown[]; + $ref?: string; } export interface PipelineConfigSchema { @@ -363,6 +365,7 @@ export interface PipelineConfigSchema { properties: Record; required?: string[]; title?: string; + $defs?: Record; } // Mode-specific default overrides @@ -399,12 +402,9 @@ export interface PipelineSchemaInfo { supports_cache_management: boolean; supports_kv_cache_bias: boolean; supports_quantization: boolean; - supports_vae_type: boolean; min_dimension: number; recommended_quantization_vram_threshold: number | null; modified: boolean; - // Available VAE types from config schema enum - vae_types?: string[]; } export interface PipelineSchemasResponse { diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 23baae1ce..b269a0f70 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -103,10 +103,9 @@ export interface PipelineInfo { supportsCacheManagement?: boolean; supportsKvCacheBias?: boolean; supportsQuantization?: boolean; - supportsVaeType?: boolean; minDimension?: number; recommendedQuantizationVramThreshold?: number | null; - // Available VAE types from config schema enum + // Available VAE types from config schema enum (derived from vae_type field presence) vaeTypes?: string[]; } diff --git a/src/scope/core/pipelines/base_schema.py b/src/scope/core/pipelines/base_schema.py index 6bd3bcd3e..e06c52dd0 100644 --- a/src/scope/core/pipelines/base_schema.py +++ b/src/scope/core/pipelines/base_schema.py @@ -160,7 +160,6 @@ class BasePipelineConfig(BaseModel): supports_cache_management: ClassVar[bool] = False supports_kv_cache_bias: ClassVar[bool] = False supports_quantization: ClassVar[bool] = False - supports_vae_type: ClassVar[bool] = False min_dimension: ClassVar[int] = 1 # Whether this pipeline contains modifications based on the original project modified: ClassVar[bool] = False @@ -289,7 +288,6 @@ def get_schema_with_metadata(cls) -> dict[str, Any]: metadata["supports_cache_management"] = cls.supports_cache_management metadata["supports_kv_cache_bias"] = cls.supports_kv_cache_bias metadata["supports_quantization"] 
= cls.supports_quantization - metadata["supports_vae_type"] = cls.supports_vae_type metadata["min_dimension"] = cls.min_dimension metadata["recommended_quantization_vram_threshold"] = ( cls.recommended_quantization_vram_threshold @@ -297,12 +295,6 @@ def get_schema_with_metadata(cls) -> dict[str, Any]: metadata["modified"] = cls.modified metadata["config_schema"] = cls.model_json_schema() - # Extract VAE types from schema if pipeline supports it - if cls.supports_vae_type: - from .utils import VaeType - - metadata["vae_types"] = [vae_type.value for vae_type in VaeType] - # Include mode-specific defaults (excluding None values and the "default" flag) mode_defaults = {} for mode_name, mode_config in cls.modes.items(): diff --git a/src/scope/core/pipelines/krea_realtime_video/schema.py b/src/scope/core/pipelines/krea_realtime_video/schema.py index 2dd6b627c..8fdfcf058 100644 --- a/src/scope/core/pipelines/krea_realtime_video/schema.py +++ b/src/scope/core/pipelines/krea_realtime_video/schema.py @@ -1,4 +1,7 @@ +from pydantic import Field + from ..base_schema import BasePipelineConfig, ModeDefaults +from ..utils import VaeType class KreaRealtimeVideoConfig(BasePipelineConfig): @@ -16,7 +19,6 @@ class KreaRealtimeVideoConfig(BasePipelineConfig): supports_cache_management = True supports_kv_cache_bias = True supports_quantization = True - supports_vae_type = True min_dimension = 16 modified = True recommended_quantization_vram_threshold = 40.0 @@ -27,6 +29,10 @@ class KreaRealtimeVideoConfig(BasePipelineConfig): height: int = 320 width: int = 576 denoising_steps: list[int] = [1000, 750, 500, 250] + vae_type: VaeType = Field( + default=VaeType.WAN, + description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).", + ) modes = { "text": ModeDefaults(default=True), diff --git a/src/scope/core/pipelines/longlive/schema.py b/src/scope/core/pipelines/longlive/schema.py index 3605c76ca..c70c1e96a 100644 --- a/src/scope/core/pipelines/longlive/schema.py +++ b/src/scope/core/pipelines/longlive/schema.py @@ -1,4 +1,7 @@ +from pydantic import Field + from ..base_schema import BasePipelineConfig, ModeDefaults +from ..utils import VaeType class LongLiveConfig(BasePipelineConfig): @@ -17,13 +20,16 @@ class LongLiveConfig(BasePipelineConfig): supports_cache_management = True supports_quantization = True - supports_vae_type = True min_dimension = 16 modified = True height: int = 320 width: int = 576 denoising_steps: list[int] = [1000, 750, 500, 250] + vae_type: VaeType = Field( + default=VaeType.WAN, + description="VAE type to use. 
'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).", + ) modes = { "text": ModeDefaults(default=True), diff --git a/src/scope/core/pipelines/streamdiffusionv2/schema.py b/src/scope/core/pipelines/streamdiffusionv2/schema.py index 2d797ea4b..a356578e9 100644 --- a/src/scope/core/pipelines/streamdiffusionv2/schema.py +++ b/src/scope/core/pipelines/streamdiffusionv2/schema.py @@ -1,4 +1,7 @@ +from pydantic import Field + from ..base_schema import BasePipelineConfig, ModeDefaults +from ..utils import VaeType class StreamDiffusionV2Config(BasePipelineConfig): @@ -17,7 +20,6 @@ class StreamDiffusionV2Config(BasePipelineConfig): supports_cache_management = True supports_quantization = True - supports_vae_type = True min_dimension = 16 modified = True @@ -25,6 +27,10 @@ class StreamDiffusionV2Config(BasePipelineConfig): noise_scale: float = 0.7 noise_controller: bool = True input_size: int = 4 + vae_type: VaeType = Field( + default=VaeType.WAN, + description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).", + ) modes = { "text": ModeDefaults( From 70cbfbb9cad6c09a2fbaa3eb4be555945528e55e Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:03:19 -0500 Subject: [PATCH 4/4] feat: tae Signed-off-by: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> --- frontend/src/data/parameterMetadata.ts | 2 +- src/scope/core/pipelines/utils.py | 1 + .../core/pipelines/wan2_1/vae/__init__.py | 5 +- src/scope/core/pipelines/wan2_1/vae/tae.py | 626 ++++++++++++++++++ 4 files changed, 632 insertions(+), 2 deletions(-) create mode 100644 src/scope/core/pipelines/wan2_1/vae/tae.py diff --git a/frontend/src/data/parameterMetadata.ts b/frontend/src/data/parameterMetadata.ts index 1d875f2fd..402821b0a 100644 --- a/frontend/src/data/parameterMetadata.ts +++ b/frontend/src/data/parameterMetadata.ts @@ -85,6 +85,6 @@ export const PARAMETER_METADATA: Record = { vaeType: { label: "VAE:", tooltip: - "VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality.", + "VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality. 
'tae' is a tiny autoencoder for fast preview quality.", }, }; diff --git a/src/scope/core/pipelines/utils.py b/src/scope/core/pipelines/utils.py index f9978b8e7..d2ffbe103 100644 --- a/src/scope/core/pipelines/utils.py +++ b/src/scope/core/pipelines/utils.py @@ -19,6 +19,7 @@ class VaeType(str, Enum): WAN = "wan" LIGHTVAE = "lightvae" + TAE = "tae" def load_state_dict(weights_path: str) -> dict: diff --git a/src/scope/core/pipelines/wan2_1/vae/__init__.py b/src/scope/core/pipelines/wan2_1/vae/__init__.py index 50bdb16e7..fae101e7e 100644 --- a/src/scope/core/pipelines/wan2_1/vae/__init__.py +++ b/src/scope/core/pipelines/wan2_1/vae/__init__.py @@ -21,6 +21,7 @@ from functools import partial +from .tae import TAEWrapper from .wan import WanVAEWrapper # Registry mapping type names to VAE factory functions @@ -28,6 +29,7 @@ VAE_REGISTRY: dict[str, type] = { "wan": WanVAEWrapper, "lightvae": partial(WanVAEWrapper, use_lightvae=True), + "tae": TAEWrapper, } DEFAULT_VAE_TYPE = "wan" @@ -38,7 +40,7 @@ def create_vae( model_name: str = "Wan2.1-T2V-1.3B", vae_type: str | None = None, vae_path: str | None = None, -) -> WanVAEWrapper: +) -> WanVAEWrapper | TAEWrapper: """Create VAE instance by type. Args: @@ -69,6 +71,7 @@ def create_vae( __all__ = [ "WanVAEWrapper", + "TAEWrapper", "create_vae", "VAE_REGISTRY", "DEFAULT_VAE_TYPE", diff --git a/src/scope/core/pipelines/wan2_1/vae/tae.py b/src/scope/core/pipelines/wan2_1/vae/tae.py new file mode 100644 index 000000000..ff1f2aa87 --- /dev/null +++ b/src/scope/core/pipelines/wan2_1/vae/tae.py @@ -0,0 +1,626 @@ +# Adapted from https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/video_encoders/hf/tae.py +"""Tiny AutoEncoder (TAE) wrapper for Wan2.1 models. + +TAE is a lightweight alternative VAE architecture from the LightX2V project. +Unlike WanVAE, TAE is a completely different architecture - a much smaller/faster +model designed for quick encoding/decoding previews. 
+ +Key differences from WanVAE: +- Uses MemBlock for temporal memory (different from CausalConv3d caching) +- Has TPool/TGrow blocks for temporal downsampling/upsampling +- Much simpler architecture with 64 channels throughout encoder +- Approximately 4x temporal upscaling in decoder (TGrow blocks expand frames) + +Streaming mode: +- TAE supports streaming decode via parallel processing with persistent MemBlock memory +- Each batch is processed in parallel (fast) while memory state is maintained across batches +- This provides both speed AND temporal continuity for smooth streaming +- First decode call has fewer output frames due to TGrow expansion and frame trimming (3 frames) +""" + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import load_file + +# Note: TAE does NOT use WAN_VAE_LATENT_MEAN/STD - it has its own latent space + +# Default checkpoint filename for Wan 2.1 TAE +DEFAULT_TAE_FILENAME = "taew2_1.pth" + + +def _conv(n_in: int, n_out: int, **kwargs) -> nn.Conv2d: + """Create a 3x3 Conv2d with padding.""" + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class _Clamp(nn.Module): + """Clamp activation using scaled tanh.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.tanh(x / 3) * 3 + + +class _MemBlock(nn.Module): + """Memory block that combines current input with past state.""" + + def __init__(self, n_in: int, n_out: int, act_func: nn.Module): + super().__init__() + self.conv = nn.Sequential( + _conv(n_in * 2, n_out), + act_func, + _conv(n_out, n_out), + act_func, + _conv(n_out, n_out), + ) + self.skip = ( + nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + ) + self.act = act_func + + def forward(self, x: torch.Tensor, past: torch.Tensor) -> torch.Tensor: + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + + +class _TPool(nn.Module): + """Temporal pooling block that combines multiple frames.""" + + def __init__(self, n_f: int, stride: int): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + + +class _TGrow(nn.Module): + """Temporal growth block that expands to multiple frames.""" + + def __init__(self, n_f: int, stride: int): + super().__init__() + self.stride = stride + self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + + +def _apply_model_parallel_streaming( + model: nn.Sequential, + x: torch.Tensor, + N: int, + initial_mem: list[torch.Tensor | None] | None = None, +) -> tuple[torch.Tensor, list[torch.Tensor | None]]: + """Apply model in parallel mode with streaming memory support. + + This processes all frames in parallel (fast) while maintaining temporal + continuity across batches by using initial memory from the previous batch. + + Args: + model: nn.Sequential of blocks to apply + x: input data reshaped to (N*T, C, H, W) + N: batch size (for reshaping) + initial_mem: Initial memory values for each MemBlock (from previous batch). + If None, uses zeros for first batch. 
+ + Returns: + Tuple of (NTCHW output tensor, list of final memory values for next batch) + """ + # Count MemBlocks for memory initialization + num_memblocks = sum(1 for b in model if isinstance(b, _MemBlock)) + + # Initialize memory list if not provided + if initial_mem is None: + initial_mem = [None] * num_memblocks + + # Track which MemBlock we're at + mem_idx = 0 + final_mem = [] + + for b in model: + if isinstance(b, _MemBlock): + NT, C, H, W = x.shape + T = NT // N + _x = x.reshape(N, T, C, H, W) + + # Create memory: pad with initial_mem at t=0, then shift frames + if initial_mem[mem_idx] is not None: + # Use previous batch's last frame as initial memory + init_frame = initial_mem[mem_idx].reshape(N, 1, C, H, W) + mem = torch.cat([init_frame, _x[:, :-1]], dim=1).reshape(x.shape) + else: + # First batch - use zeros + mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape( + x.shape + ) + + # Save last frame for next batch (input before processing) + final_mem.append(_x[:, -1:].reshape(N, C, H, W).clone()) + mem_idx += 1 + + x = b(x, mem) + else: + x = b(x) + + NT, C, H, W = x.shape + T = NT // N + return x.view(N, T, C, H, W), final_mem + + +def _apply_model_with_memblocks( + model: nn.Sequential, + x: torch.Tensor, + parallel: bool = True, + show_progress_bar: bool = False, +) -> torch.Tensor: + """Apply a sequential model with memblocks to the given input (batch mode). + + Args: + model: nn.Sequential of blocks to apply + x: input data, of dimensions NTCHW + parallel: unused, kept for API compatibility (always uses parallel) + show_progress_bar: unused, kept for API compatibility + + Returns: + NTCHW tensor of output data. + """ + assert x.ndim == 5, f"_apply_model_with_memblocks: TAE expects NTCHW, got {x.ndim}D" + N, T, C, H, W = x.shape + x = x.reshape(N * T, C, H, W) + result, _ = _apply_model_parallel_streaming(model, x, N, initial_mem=None) + return result + + +class _TAEModel(nn.Module): + """Tiny AutoEncoder model for Wan 2.1. + + This is a lightweight VAE designed for quick previews. It uses a different + architecture than the standard WanVAE, with MemBlocks for temporal processing. + + Supports two decode modes: + - Batch mode (decode_video): Process all frames at once + - Streaming mode (stream_decode): Process frames incrementally with persistent memory + """ + + def __init__( + self, + checkpoint_path: str | None = None, + decoder_time_upscale: tuple[bool, bool] = (True, True), + decoder_space_upscale: tuple[bool, bool, bool] = (True, True, True), + patch_size: int = 1, + latent_channels: int = 16, + ): + """Initialize TAE model. 
+ + Args: + checkpoint_path: Path to weight file (.pth or .safetensors) + decoder_time_upscale: Whether temporal upsampling is enabled for each block + decoder_space_upscale: Whether spatial upsampling is enabled for each block + patch_size: Input/output pixelshuffle patch-size (1 for Wan 2.1) + latent_channels: Number of latent channels (16 for Wan 2.1) + """ + super().__init__() + self.patch_size = patch_size + self.latent_channels = latent_channels + self.image_channels = 3 + + # Wan 2.1 uses ReLU activation + act_func = nn.ReLU(inplace=True) + + # Encoder: 64 channels throughout, simple architecture + self.encoder = nn.Sequential( + _conv(self.image_channels * self.patch_size**2, 64), + act_func, + _TPool(64, 2), + _conv(64, 64, stride=2, bias=False), + _MemBlock(64, 64, act_func), + _MemBlock(64, 64, act_func), + _MemBlock(64, 64, act_func), + _TPool(64, 2), + _conv(64, 64, stride=2, bias=False), + _MemBlock(64, 64, act_func), + _MemBlock(64, 64, act_func), + _MemBlock(64, 64, act_func), + _TPool(64, 1), + _conv(64, 64, stride=2, bias=False), + _MemBlock(64, 64, act_func), + _MemBlock(64, 64, act_func), + _MemBlock(64, 64, act_func), + _conv(64, self.latent_channels), + ) + + # Decoder with configurable upscaling + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 + self._decoder_time_upscale = decoder_time_upscale + + self.decoder = nn.Sequential( + _Clamp(), + _conv(self.latent_channels, n_f[0]), + act_func, + _MemBlock(n_f[0], n_f[0], act_func), + _MemBlock(n_f[0], n_f[0], act_func), + _MemBlock(n_f[0], n_f[0], act_func), + nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), + _TGrow(n_f[0], 1), + _conv(n_f[0], n_f[1], bias=False), + _MemBlock(n_f[1], n_f[1], act_func), + _MemBlock(n_f[1], n_f[1], act_func), + _MemBlock(n_f[1], n_f[1], act_func), + nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), + _TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), + _conv(n_f[1], n_f[2], bias=False), + _MemBlock(n_f[2], n_f[2], act_func), + _MemBlock(n_f[2], n_f[2], act_func), + _MemBlock(n_f[2], n_f[2], act_func), + nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), + _TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), + _conv(n_f[2], n_f[3], bias=False), + act_func, + _conv(n_f[3], self.image_channels * self.patch_size**2), + ) + + # Streaming state for parallel streaming encode/decode + self._encoder_mem: list[torch.Tensor | None] | None = None + self._decoder_mem: list[torch.Tensor | None] | None = None + self._frames_output: int = 0 # Track output frames for trim handling + + if checkpoint_path is not None: + ext = os.path.splitext(checkpoint_path)[1].lower() + if ext == ".pth": + state_dict = torch.load( + checkpoint_path, map_location="cpu", weights_only=True + ) + elif ext == ".safetensors": + state_dict = load_file(checkpoint_path, device="cpu") + else: + raise ValueError( + f"_TAEModel.__init__: Unsupported checkpoint format: {ext}. " + "Supported: .pth, .safetensors" + ) + self.load_state_dict(self._patch_tgrow_layers(state_dict)) + + def _patch_tgrow_layers(self, sd: dict) -> dict: + """Patch TGrow layers to use a smaller kernel if needed. 
+
+        Args:
+            sd: state dict to patch
+
+        Returns:
+            Patched state dict
+        """
+        new_sd = self.state_dict()
+        for i, layer in enumerate(self.decoder):
+            if isinstance(layer, _TGrow):
+                key = f"decoder.{i}.conv.weight"
+                if sd[key].shape[0] > new_sd[key].shape[0]:
+                    # Take the last-timestep output channels
+                    sd[key] = sd[key][-new_sd[key].shape[0] :]
+        return sd
+
+    def clear_decode_state(self):
+        """Clear decoder streaming state for a new sequence."""
+        self._decoder_mem = None
+        self._frames_output = 0
+
+    def clear_encode_state(self):
+        """Clear encoder streaming state for a new sequence."""
+        self._encoder_mem = None
+
+    def stream_encode(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Encode frames in streaming mode with persistent memory.
+
+        This uses parallel processing within each batch for speed, while maintaining
+        MemBlock memory across batches for smooth temporal continuity at chunk
+        boundaries.
+
+        Unlike encode_video, this maintains state across calls.
+        Call clear_encode_state() before a new sequence.
+
+        Args:
+            x: input NTCHW RGB (C=3) tensor with values in [0, 1]
+
+        Returns:
+            NTCHW latent tensor with approximately Gaussian values
+        """
+        if self.patch_size > 1:
+            x = F.pixel_unshuffle(x, self.patch_size)
+        if x.shape[1] % 4 != 0:
+            # Pad at end to multiple of 4
+            n_pad = 4 - x.shape[1] % 4
+            padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+            x = torch.cat([x, padding], 1)
+
+        N, T, C, H, W = x.shape
+        x_flat = x.reshape(N * T, C, H, W)
+
+        result, self._encoder_mem = _apply_model_parallel_streaming(
+            self.encoder,
+            x_flat,
+            N,
+            initial_mem=self._encoder_mem,
+        )
+
+        return result
+
+    def encode_video(
+        self,
+        x: torch.Tensor,
+        parallel: bool = True,
+        show_progress_bar: bool = False,
+    ) -> torch.Tensor:
+        """Encode a sequence of frames (batch mode).
+
+        Args:
+            x: input NTCHW RGB (C=3) tensor with values in [0, 1]
+            parallel: accepted for API compatibility; frames are currently always
+                processed in parallel regardless of this flag
+            show_progress_bar: accepted for API compatibility; currently unused
+
+        Returns:
+            NTCHW latent tensor with approximately Gaussian values
+        """
+        if self.patch_size > 1:
+            x = F.pixel_unshuffle(x, self.patch_size)
+        if x.shape[1] % 4 != 0:
+            # Pad at end to multiple of 4
+            n_pad = 4 - x.shape[1] % 4
+            padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+            x = torch.cat([x, padding], 1)
+        return _apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
+
+    def decode_video(
+        self,
+        x: torch.Tensor,
+        parallel: bool = True,
+        show_progress_bar: bool = False,
+    ) -> torch.Tensor:
+        """Decode a sequence of frames (batch mode).
+
+        Args:
+            x: input NTCHW latent tensor with approximately Gaussian values
+            parallel: accepted for API compatibility; frames are currently always
+                processed in parallel regardless of this flag
+            show_progress_bar: accepted for API compatibility; currently unused
+
+        Returns:
+            NTCHW RGB tensor with values clamped to [0, 1]
+        """
+        x = _apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
+        x = x.clamp_(0, 1)
+        if self.patch_size > 1:
+            x = F.pixel_shuffle(x, self.patch_size)
+        return x[:, self.frames_to_trim :]
+
+    def stream_decode(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Decode frames in streaming mode with persistent memory.
+
+        This uses parallel processing within each batch for speed, while maintaining
+        MemBlock memory across batches for smooth temporal continuity.
+ + On the first batch, frames are processed sequentially (first frame separately, + then remaining frames) to match WanVAE warmup behavior for better temporal + consistency. + + Unlike decode_video, this maintains state across calls. + Call clear_decode_state() before a new sequence. + + Args: + x: input NTCHW latent tensor (typically 1-4 frames at a time) + + Returns: + NTCHW RGB tensor with values in [0, 1]. + First call returns fewer frames due to temporal trim. + """ + N, T, C, H, W = x.shape + + # First batch warmup: process first frame separately, then remaining frames + # This matches WanVAE's warmup behavior for better temporal consistency + if self._frames_output == 0: + # Clear decoder memory state for first batch + self._decoder_mem = None + + # Process first frame separately + first_frame = x[:, :1, :, :, :] # [N, 1, C, H, W] + first_flat = first_frame.reshape(N * 1, C, H, W) + first_result, first_mem = _apply_model_parallel_streaming( + self.decoder, + first_flat, + N, + initial_mem=None, # Use zeros for first frame + ) + + # Process remaining frames if any + if T > 1: + remaining_frames = x[:, 1:, :, :, :] # [N, T-1, C, H, W] + remaining_flat = remaining_frames.reshape(N * (T - 1), C, H, W) + remaining_result, self._decoder_mem = _apply_model_parallel_streaming( + self.decoder, + remaining_flat, + N, + initial_mem=first_mem, # Use memory from first frame + ) + # Concatenate first frame and remaining frames + result = torch.cat([first_result, remaining_result], dim=1) + else: + # Only one frame + result = first_result + self._decoder_mem = first_mem + else: + # Subsequent batches: use parallel processing with persistent memory + x_flat = x.reshape(N * T, C, H, W) + result, self._decoder_mem = _apply_model_parallel_streaming( + self.decoder, + x_flat, + N, + initial_mem=self._decoder_mem, + ) + + result = result.clamp_(0, 1) + + if self.patch_size > 1: + result = F.pixel_shuffle(result, self.patch_size) + + # Handle temporal trim - only trim on first output + if self._frames_output == 0 and result.shape[1] > self.frames_to_trim: + result = result[:, self.frames_to_trim :] + + self._frames_output += result.shape[1] + + return result + + +class TAEWrapper(nn.Module): + """TAE wrapper with interface matching WanVAEWrapper. + + This provides a consistent interface for the Tiny AutoEncoder that matches + the WanVAEWrapper's encode_to_latent/decode_to_pixel/clear_cache API. + + Note: TAE is a lightweight preview encoder with its own latent space. It does + NOT use WanVAE's normalization constants - TAE produces approximately Gaussian + latents directly. Quality may be lower than WanVAE but encoding/decoding is faster. + + Streaming mode (use_cache=True): + TAE maintains persistent MemBlock memory for smooth frame-by-frame streaming. + This is faster than batch mode for real-time applications since it processes + smaller chunks while maintaining temporal continuity. + + Batch mode (use_cache=False): + Processes all frames at once without persistent state. Good for one-shot + encoding/decoding of complete videos. 
+ + Args: + model_dir: Base directory containing model files + model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B") + vae_path: Explicit path to TAE checkpoint (overrides model_dir/model_name) + """ + + def __init__( + self, + model_dir: str = "wan_models", + model_name: str = "Wan2.1-T2V-1.3B", + vae_path: str | None = None, + ): + super().__init__() + + # Determine checkpoint path + if vae_path is None: + vae_path = os.path.join(model_dir, model_name, DEFAULT_TAE_FILENAME) + + self.z_dim = 16 + + # Create TAE model + self.model = ( + _TAEModel( + checkpoint_path=vae_path, + patch_size=1, + latent_channels=self.z_dim, + ) + .eval() + .requires_grad_(False) + ) + + # Track state for streaming + self._first_batch = True + + def encode_to_latent( + self, + pixel: torch.Tensor, + use_cache: bool = True, + feat_cache: list | None = None, + ) -> torch.Tensor: + """Encode video pixels to latents. + + Args: + pixel: Input video tensor [batch, channels, frames, height, width] + use_cache: If True, use streaming encode with persistent memory. + If False, use batch encode (clears state). + feat_cache: Unused (kept for interface compatibility with WanVAEWrapper) + + Returns: + Latent tensor [batch, frames, channels, height, width] + + Note: + TAE produces approximately Gaussian latents directly without additional + normalization. The latent space is similar to but not identical to WanVAE. + + In streaming mode (use_cache=True), TAE maintains MemBlock state across + calls for smooth temporal continuity at chunk boundaries. + """ + # [batch, channels, frames, h, w] -> [batch, frames, channels, h, w] for TAE + pixel_ntchw = pixel.permute(0, 2, 1, 3, 4) + + # Scale from [-1, 1] to [0, 1] range expected by TAE + pixel_ntchw = (pixel_ntchw + 1) / 2 + + if use_cache: + # Streaming mode - use parallel processing with persistent memory + if self._first_batch: + self.model.clear_encode_state() + + latent = self.model.stream_encode(pixel_ntchw) + else: + # Batch mode - no persistent state + latent = self.model.encode_video( + pixel_ntchw, parallel=True, show_progress_bar=False + ) + + # Return in [batch, frames, channels, h, w] format + return latent + + def decode_to_pixel( + self, latent: torch.Tensor, use_cache: bool = True + ) -> torch.Tensor: + """Decode latents to video pixels. + + Args: + latent: Latent tensor [batch, frames, channels, height, width] + use_cache: If True, use streaming decode with persistent memory. + If False, use batch decode (clears state). + + Returns: + Video tensor [batch, frames, channels, height, width] in range [-1, 1] + + Note: + In streaming mode (use_cache=True), TAE maintains MemBlock state across + calls for smooth temporal continuity. Uses parallel processing within + each batch for speed. The first call may have fewer output frames due + to TGrow temporal expansion and frame trimming. + """ + if use_cache: + # Streaming mode - use parallel processing with persistent memory + if self._first_batch: + self.model.clear_decode_state() + self._first_batch = False + + output = self.model.stream_decode(latent) + else: + # Batch mode - no persistent state + output = self.model.decode_video( + latent, parallel=True, show_progress_bar=False + ) + + # Scale from [0, 1] to [-1, 1] range + output = output * 2 - 1 + output = output.clamp_(-1, 1) + + # Return in [batch, frames, channels, h, w] format + return output + + def clear_cache(self): + """Clear state for next sequence.""" + self._first_batch = True + self.model.clear_encode_state() + self.model.clear_decode_state()
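+
+
+# Minimal usage sketch (illustrative only): shows how the streaming
+# TAEWrapper API is intended to be driven chunk by chunk. The checkpoint is
+# resolved from the default model_dir/model_name, and the tensor sizes below
+# are hypothetical stand-ins for frames produced by a pipeline.
+if __name__ == "__main__":
+    wrapper = TAEWrapper(model_dir="wan_models", model_name="Wan2.1-T2V-1.3B")
+
+    # Two consecutive 4-frame chunks in [-1, 1], shaped [batch, C, T, H, W].
+    chunks = [torch.rand(1, 3, 4, 64, 64) * 2 - 1 for _ in range(2)]
+
+    for chunk in chunks:
+        # Streaming mode (use_cache=True) keeps MemBlock state across calls,
+        # so chunk boundaries stay temporally consistent.
+        latent = wrapper.encode_to_latent(chunk, use_cache=True)
+        frames = wrapper.decode_to_pixel(latent, use_cache=True)
+        print(latent.shape, frames.shape)
+
+    # Reset MemBlock state before starting an unrelated sequence.
+    wrapper.clear_cache()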