= {};
for (const [id, schema] of Object.entries(schemas.pipelines)) {
+ // Extract VAE types from JSON schema if vae_type field exists
+ // Pydantic v2 represents enum fields using $ref to definitions
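+ // Illustrative (assumed) shape of the generated schema this code handles:
+ //   properties: { vae_type: { "$ref": "#/$defs/VaeType", ... } }
+ //   $defs: { VaeType: { "enum": ["wan", "lightvae", ...] } }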
+ let vaeTypes: string[] | undefined = undefined;
+ const vaeTypeProperty = schema.config_schema?.properties?.vae_type;
+ if (vaeTypeProperty?.$ref && schema.config_schema?.$defs) {
+ const refPath = vaeTypeProperty.$ref;
+ const defName = refPath.split("/").pop();
+ const definition = schema.config_schema.$defs[defName || ""];
+ if (definition && Array.isArray(definition.enum)) {
+ vaeTypes = definition.enum as string[];
+ }
+ }
+
transformed[id] = {
name: schema.name,
about: schema.description,
@@ -41,12 +54,11 @@ export function usePipelines() {
supportsCacheManagement: schema.supports_cache_management,
supportsKvCacheBias: schema.supports_kv_cache_bias,
supportsQuantization: schema.supports_quantization,
- supportsVaeType: schema.supports_vae_type,
minDimension: schema.min_dimension,
recommendedQuantizationVramThreshold:
schema.recommended_quantization_vram_threshold ?? undefined,
modified: schema.modified,
- vaeTypes: schema.vae_types ?? undefined,
+ vaeTypes,
};
}
diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts
index 2f212dfa2..3d93d15c7 100644
--- a/frontend/src/lib/api.ts
+++ b/frontend/src/lib/api.ts
@@ -356,6 +356,8 @@ export interface PipelineSchemaProperty {
maximum?: number;
items?: unknown;
anyOf?: unknown[];
+ enum?: unknown[];
+ $ref?: string;
}
export interface PipelineConfigSchema {
@@ -363,6 +365,7 @@ export interface PipelineConfigSchema {
properties: Record<string, PipelineSchemaProperty>;
required?: string[];
title?: string;
+ $defs?: Record<string, PipelineSchemaProperty>;
}
// Mode-specific default overrides
@@ -399,12 +402,9 @@ export interface PipelineSchemaInfo {
supports_cache_management: boolean;
supports_kv_cache_bias: boolean;
supports_quantization: boolean;
- supports_vae_type: boolean;
min_dimension: number;
recommended_quantization_vram_threshold: number | null;
modified: boolean;
- // Available VAE types from config schema enum
- vae_types?: string[];
}
export interface PipelineSchemasResponse {
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 23baae1ce..b269a0f70 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -103,10 +103,9 @@ export interface PipelineInfo {
supportsCacheManagement?: boolean;
supportsKvCacheBias?: boolean;
supportsQuantization?: boolean;
- supportsVaeType?: boolean;
minDimension?: number;
recommendedQuantizationVramThreshold?: number | null;
- // Available VAE types from config schema enum
+ // Available VAE types from config schema enum (derived from vae_type field presence)
vaeTypes?: string[];
}
diff --git a/src/scope/core/pipelines/base_schema.py b/src/scope/core/pipelines/base_schema.py
index 6bd3bcd3e..e06c52dd0 100644
--- a/src/scope/core/pipelines/base_schema.py
+++ b/src/scope/core/pipelines/base_schema.py
@@ -160,7 +160,6 @@ class BasePipelineConfig(BaseModel):
supports_cache_management: ClassVar[bool] = False
supports_kv_cache_bias: ClassVar[bool] = False
supports_quantization: ClassVar[bool] = False
- supports_vae_type: ClassVar[bool] = False
min_dimension: ClassVar[int] = 1
# Whether this pipeline contains modifications based on the original project
modified: ClassVar[bool] = False
@@ -289,7 +288,6 @@ def get_schema_with_metadata(cls) -> dict[str, Any]:
metadata["supports_cache_management"] = cls.supports_cache_management
metadata["supports_kv_cache_bias"] = cls.supports_kv_cache_bias
metadata["supports_quantization"] = cls.supports_quantization
- metadata["supports_vae_type"] = cls.supports_vae_type
metadata["min_dimension"] = cls.min_dimension
metadata["recommended_quantization_vram_threshold"] = (
cls.recommended_quantization_vram_threshold
@@ -297,12 +295,6 @@ def get_schema_with_metadata(cls) -> dict[str, Any]:
metadata["modified"] = cls.modified
metadata["config_schema"] = cls.model_json_schema()
- # Extract VAE types from schema if pipeline supports it
- if cls.supports_vae_type:
- from .utils import VaeType
-
- metadata["vae_types"] = [vae_type.value for vae_type in VaeType]
-
# Include mode-specific defaults (excluding None values and the "default" flag)
mode_defaults = {}
for mode_name, mode_config in cls.modes.items():
diff --git a/src/scope/core/pipelines/krea_realtime_video/schema.py b/src/scope/core/pipelines/krea_realtime_video/schema.py
index 2dd6b627c..8fdfcf058 100644
--- a/src/scope/core/pipelines/krea_realtime_video/schema.py
+++ b/src/scope/core/pipelines/krea_realtime_video/schema.py
@@ -1,4 +1,7 @@
+from pydantic import Field
+
from ..base_schema import BasePipelineConfig, ModeDefaults
+from ..utils import VaeType
class KreaRealtimeVideoConfig(BasePipelineConfig):
@@ -16,7 +19,6 @@ class KreaRealtimeVideoConfig(BasePipelineConfig):
supports_cache_management = True
supports_kv_cache_bias = True
supports_quantization = True
- supports_vae_type = True
min_dimension = 16
modified = True
recommended_quantization_vram_threshold = 40.0
@@ -27,6 +29,10 @@ class KreaRealtimeVideoConfig(BasePipelineConfig):
height: int = 320
width: int = 576
denoising_steps: list[int] = [1000, 750, 500, 250]
+ vae_type: VaeType = Field(
+ default=VaeType.WAN,
+ description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
+ )
modes = {
"text": ModeDefaults(default=True),
diff --git a/src/scope/core/pipelines/longlive/schema.py b/src/scope/core/pipelines/longlive/schema.py
index 3605c76ca..c70c1e96a 100644
--- a/src/scope/core/pipelines/longlive/schema.py
+++ b/src/scope/core/pipelines/longlive/schema.py
@@ -1,4 +1,7 @@
+from pydantic import Field
+
from ..base_schema import BasePipelineConfig, ModeDefaults
+from ..utils import VaeType
class LongLiveConfig(BasePipelineConfig):
@@ -17,13 +20,16 @@ class LongLiveConfig(BasePipelineConfig):
supports_cache_management = True
supports_quantization = True
- supports_vae_type = True
min_dimension = 16
modified = True
height: int = 320
width: int = 576
denoising_steps: list[int] = [1000, 750, 500, 250]
+ vae_type: VaeType = Field(
+ default=VaeType.WAN,
+ description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
+ )
modes = {
"text": ModeDefaults(default=True),
diff --git a/src/scope/core/pipelines/streamdiffusionv2/schema.py b/src/scope/core/pipelines/streamdiffusionv2/schema.py
index 2d797ea4b..a356578e9 100644
--- a/src/scope/core/pipelines/streamdiffusionv2/schema.py
+++ b/src/scope/core/pipelines/streamdiffusionv2/schema.py
@@ -1,4 +1,7 @@
+from pydantic import Field
+
from ..base_schema import BasePipelineConfig, ModeDefaults
+from ..utils import VaeType
class StreamDiffusionV2Config(BasePipelineConfig):
@@ -17,7 +20,6 @@ class StreamDiffusionV2Config(BasePipelineConfig):
supports_cache_management = True
supports_quantization = True
- supports_vae_type = True
min_dimension = 16
modified = True
@@ -25,6 +27,10 @@ class StreamDiffusionV2Config(BasePipelineConfig):
noise_scale: float = 0.7
noise_controller: bool = True
input_size: int = 4
+ vae_type: VaeType = Field(
+ default=VaeType.WAN,
+ description="VAE type to use. 'wan' is the full VAE, 'lightvae' is 75% pruned (faster but lower quality).",
+ )
modes = {
"text": ModeDefaults(
From 70cbfbb9cad6c09a2fbaa3eb4be555945528e55e Mon Sep 17 00:00:00 2001
From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com>
Date: Thu, 4 Dec 2025 14:03:19 -0500
Subject: [PATCH 4/4] feat: tae
Signed-off-by: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com>
---
frontend/src/data/parameterMetadata.ts | 2 +-
src/scope/core/pipelines/utils.py | 1 +
.../core/pipelines/wan2_1/vae/__init__.py | 5 +-
src/scope/core/pipelines/wan2_1/vae/tae.py | 626 ++++++++++++++++++
4 files changed, 632 insertions(+), 2 deletions(-)
create mode 100644 src/scope/core/pipelines/wan2_1/vae/tae.py
diff --git a/frontend/src/data/parameterMetadata.ts b/frontend/src/data/parameterMetadata.ts
index 1d875f2fd..402821b0a 100644
--- a/frontend/src/data/parameterMetadata.ts
+++ b/frontend/src/data/parameterMetadata.ts
@@ -85,6 +85,6 @@ export const PARAMETER_METADATA: Record<string, { label: string; tooltip: string }> = {
vaeType: {
label: "VAE:",
tooltip:
- "VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality.",
+ "VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality. 'tae' is a tiny autoencoder for fast preview quality.",
},
};
diff --git a/src/scope/core/pipelines/utils.py b/src/scope/core/pipelines/utils.py
index f9978b8e7..d2ffbe103 100644
--- a/src/scope/core/pipelines/utils.py
+++ b/src/scope/core/pipelines/utils.py
@@ -19,6 +19,7 @@ class VaeType(str, Enum):
WAN = "wan"
LIGHTVAE = "lightvae"
+ TAE = "tae"
def load_state_dict(weights_path: str) -> dict:
diff --git a/src/scope/core/pipelines/wan2_1/vae/__init__.py b/src/scope/core/pipelines/wan2_1/vae/__init__.py
index 50bdb16e7..fae101e7e 100644
--- a/src/scope/core/pipelines/wan2_1/vae/__init__.py
+++ b/src/scope/core/pipelines/wan2_1/vae/__init__.py
@@ -21,6 +21,7 @@
from functools import partial
+from .tae import TAEWrapper
from .wan import WanVAEWrapper
# Registry mapping type names to VAE factory functions
@@ -28,6 +29,7 @@
VAE_REGISTRY: dict[str, type] = {
"wan": WanVAEWrapper,
"lightvae": partial(WanVAEWrapper, use_lightvae=True),
+ "tae": TAEWrapper,
}
DEFAULT_VAE_TYPE = "wan"
@@ -38,7 +40,7 @@ def create_vae(
model_name: str = "Wan2.1-T2V-1.3B",
vae_type: str | None = None,
vae_path: str | None = None,
-) -> WanVAEWrapper:
+) -> WanVAEWrapper | TAEWrapper:
"""Create VAE instance by type.
Args:
@@ -69,6 +71,7 @@ def create_vae(
__all__ = [
"WanVAEWrapper",
+ "TAEWrapper",
"create_vae",
"VAE_REGISTRY",
"DEFAULT_VAE_TYPE",
diff --git a/src/scope/core/pipelines/wan2_1/vae/tae.py b/src/scope/core/pipelines/wan2_1/vae/tae.py
new file mode 100644
index 000000000..ff1f2aa87
--- /dev/null
+++ b/src/scope/core/pipelines/wan2_1/vae/tae.py
@@ -0,0 +1,626 @@
+# Adapted from https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/video_encoders/hf/tae.py
+"""Tiny AutoEncoder (TAE) wrapper for Wan2.1 models.
+
+TAE is a lightweight alternative VAE architecture from the LightX2V project.
+Unlike LightVAE (a pruned WanVAE), TAE is a completely different, much smaller
+and faster architecture designed for quick preview encoding/decoding.
+
+Key differences from WanVAE:
+- Uses MemBlock for temporal memory (different from CausalConv3d caching)
+- Has TPool/TGrow blocks for temporal downsampling/upsampling
+- Much simpler architecture, with 64 channels throughout the encoder
+- Approximately 4x temporal upscaling in decoder (TGrow blocks expand frames)
+
+Streaming mode:
+- TAE supports streaming decode via parallel processing with persistent MemBlock memory
+- Each batch is processed in parallel (fast) while memory state is maintained across batches
+- This provides both speed and temporal continuity for smooth streaming
+- The first decode call returns fewer output frames because warmup frames are
+  trimmed (2**sum(decoder_time_upscale) - 1 = 3 frames by default)
+"""
+
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from safetensors.torch import load_file
+
+# Note: TAE does NOT use WAN_VAE_LATENT_MEAN/STD - it has its own latent space
+
+# Default checkpoint filename for Wan 2.1 TAE
+DEFAULT_TAE_FILENAME = "taew2_1.pth"
+
+
+def _conv(n_in: int, n_out: int, **kwargs) -> nn.Conv2d:
+ """Create a 3x3 Conv2d with padding."""
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class _Clamp(nn.Module):
+ """Clamp activation using scaled tanh."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return torch.tanh(x / 3) * 3
+
+
+class _MemBlock(nn.Module):
+ """Memory block that combines current input with past state."""
+
+ def __init__(self, n_in: int, n_out: int, act_func: nn.Module):
+ super().__init__()
+ self.conv = nn.Sequential(
+ _conv(n_in * 2, n_out),
+ act_func,
+ _conv(n_out, n_out),
+ act_func,
+ _conv(n_out, n_out),
+ )
+ self.skip = (
+ nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ )
+ self.act = act_func
+
+ def forward(self, x: torch.Tensor, past: torch.Tensor) -> torch.Tensor:
+ return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
+
+
+class _TPool(nn.Module):
+ """Temporal pooling block that combines multiple frames."""
+
+ def __init__(self, n_f: int, stride: int):
+ super().__init__()
+ self.stride = stride
+ self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ _NT, C, H, W = x.shape
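+ # Fold `stride` consecutive frames into the channel dim, then 1x1-conv
+ # back to C channels:
+ # (N*T, C, H, W) -> (N*T/stride, stride*C, H, W) -> (N*T/stride, C, H, W)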
+ return self.conv(x.reshape(-1, self.stride * C, H, W))
+
+
+class _TGrow(nn.Module):
+ """Temporal growth block that expands to multiple frames."""
+
+ def __init__(self, n_f: int, stride: int):
+ super().__init__()
+ self.stride = stride
+ self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ _NT, C, H, W = x.shape
+ x = self.conv(x)
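+ # (N*T, C, H, W) -> (N*T, stride*C, H, W); the reshape below splits the
+ # channels back out so each input frame becomes `stride` output frames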
+ return x.reshape(-1, C, H, W)
+
+
+def _apply_model_parallel_streaming(
+ model: nn.Sequential,
+ x: torch.Tensor,
+ N: int,
+ initial_mem: list[torch.Tensor | None] | None = None,
+) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
+ """Apply model in parallel mode with streaming memory support.
+
+ This processes all frames in parallel (fast) while maintaining temporal
+ continuity across batches by using initial memory from the previous batch.
+
+ Args:
+ model: nn.Sequential of blocks to apply
+ x: input data reshaped to (N*T, C, H, W)
+ N: batch size (for reshaping)
+ initial_mem: Initial memory values for each MemBlock (from previous batch).
+ If None, uses zeros for first batch.
+
+ Returns:
+ Tuple of (NTCHW output tensor, list of final memory values for next batch)
+ """
+ # Count MemBlocks for memory initialization
+ num_memblocks = sum(1 for b in model if isinstance(b, _MemBlock))
+
+ # Initialize memory list if not provided
+ if initial_mem is None:
+ initial_mem = [None] * num_memblocks
+
+ # Track which MemBlock we're at
+ mem_idx = 0
+ final_mem = []
+
+ for b in model:
+ if isinstance(b, _MemBlock):
+ NT, C, H, W = x.shape
+ T = NT // N
+ _x = x.reshape(N, T, C, H, W)
+
+ # Create memory: pad with initial_mem at t=0, then shift frames
+ if initial_mem[mem_idx] is not None:
+ # Use previous batch's last frame as initial memory
+ init_frame = initial_mem[mem_idx].reshape(N, 1, C, H, W)
+ mem = torch.cat([init_frame, _x[:, :-1]], dim=1).reshape(x.shape)
+ else:
+ # First batch - use zeros
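+ # Pad one zero frame at t=0 and keep the first T frames,
+ # i.e. mem[t] = x[t-1] with mem[0] = 0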
+ mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(
+ x.shape
+ )
+
+ # Save last frame for next batch (input before processing)
+ final_mem.append(_x[:, -1:].reshape(N, C, H, W).clone())
+ mem_idx += 1
+
+ x = b(x, mem)
+ else:
+ x = b(x)
+
+ NT, C, H, W = x.shape
+ T = NT // N
+ return x.view(N, T, C, H, W), final_mem
+
+
+def _apply_model_with_memblocks(
+ model: nn.Sequential,
+ x: torch.Tensor,
+ parallel: bool = True,
+ show_progress_bar: bool = False,
+) -> torch.Tensor:
+ """Apply a sequential model with memblocks to the given input (batch mode).
+
+ Args:
+ model: nn.Sequential of blocks to apply
+ x: input data, of dimensions NTCHW
+ parallel: unused, kept for API compatibility (always uses parallel)
+ show_progress_bar: unused, kept for API compatibility
+
+ Returns:
+ NTCHW tensor of output data.
+ """
+ assert x.ndim == 5, f"_apply_model_with_memblocks: TAE expects NTCHW, got {x.ndim}D"
+ N, T, C, H, W = x.shape
+ x = x.reshape(N * T, C, H, W)
+ result, _ = _apply_model_parallel_streaming(model, x, N, initial_mem=None)
+ return result
+
+
+class _TAEModel(nn.Module):
+ """Tiny AutoEncoder model for Wan 2.1.
+
+ This is a lightweight VAE designed for quick previews. It uses a different
+ architecture than the standard WanVAE, with MemBlocks for temporal processing.
+
+ Supports two decode modes:
+ - Batch mode (decode_video): Process all frames at once
+ - Streaming mode (stream_decode): Process frames incrementally with persistent memory
+ """
+
+ def __init__(
+ self,
+ checkpoint_path: str | None = None,
+ decoder_time_upscale: tuple[bool, bool] = (True, True),
+ decoder_space_upscale: tuple[bool, bool, bool] = (True, True, True),
+ patch_size: int = 1,
+ latent_channels: int = 16,
+ ):
+ """Initialize TAE model.
+
+ Args:
+ checkpoint_path: Path to weight file (.pth or .safetensors)
+ decoder_time_upscale: Whether temporal upsampling is enabled for each block
+ decoder_space_upscale: Whether spatial upsampling is enabled for each block
+ patch_size: Input/output pixelshuffle patch-size (1 for Wan 2.1)
+ latent_channels: Number of latent channels (16 for Wan 2.1)
+ """
+ super().__init__()
+ self.patch_size = patch_size
+ self.latent_channels = latent_channels
+ self.image_channels = 3
+
+ # Wan 2.1 uses ReLU activation
+ act_func = nn.ReLU(inplace=True)
+
+ # Encoder: 64 channels throughout, simple architecture
+ self.encoder = nn.Sequential(
+ _conv(self.image_channels * self.patch_size**2, 64),
+ act_func,
+ _TPool(64, 2),
+ _conv(64, 64, stride=2, bias=False),
+ _MemBlock(64, 64, act_func),
+ _MemBlock(64, 64, act_func),
+ _MemBlock(64, 64, act_func),
+ _TPool(64, 2),
+ _conv(64, 64, stride=2, bias=False),
+ _MemBlock(64, 64, act_func),
+ _MemBlock(64, 64, act_func),
+ _MemBlock(64, 64, act_func),
+ _TPool(64, 1),
+ _conv(64, 64, stride=2, bias=False),
+ _MemBlock(64, 64, act_func),
+ _MemBlock(64, 64, act_func),
+ _MemBlock(64, 64, act_func),
+ _conv(64, self.latent_channels),
+ )
+
+ # Decoder with configurable upscaling
+ n_f = [256, 128, 64, 64]
+ self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
+ self._decoder_time_upscale = decoder_time_upscale
+
+ self.decoder = nn.Sequential(
+ _Clamp(),
+ _conv(self.latent_channels, n_f[0]),
+ act_func,
+ _MemBlock(n_f[0], n_f[0], act_func),
+ _MemBlock(n_f[0], n_f[0], act_func),
+ _MemBlock(n_f[0], n_f[0], act_func),
+ nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
+ _TGrow(n_f[0], 1),
+ _conv(n_f[0], n_f[1], bias=False),
+ _MemBlock(n_f[1], n_f[1], act_func),
+ _MemBlock(n_f[1], n_f[1], act_func),
+ _MemBlock(n_f[1], n_f[1], act_func),
+ nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
+ _TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
+ _conv(n_f[1], n_f[2], bias=False),
+ _MemBlock(n_f[2], n_f[2], act_func),
+ _MemBlock(n_f[2], n_f[2], act_func),
+ _MemBlock(n_f[2], n_f[2], act_func),
+ nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
+ _TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
+ _conv(n_f[2], n_f[3], bias=False),
+ act_func,
+ _conv(n_f[3], self.image_channels * self.patch_size**2),
+ )
+
+ # Streaming state for parallel streaming encode/decode
+ self._encoder_mem: list[torch.Tensor | None] | None = None
+ self._decoder_mem: list[torch.Tensor | None] | None = None
+ self._frames_output: int = 0 # Track output frames for trim handling
+
+ if checkpoint_path is not None:
+ ext = os.path.splitext(checkpoint_path)[1].lower()
+ if ext == ".pth":
+ state_dict = torch.load(
+ checkpoint_path, map_location="cpu", weights_only=True
+ )
+ elif ext == ".safetensors":
+ state_dict = load_file(checkpoint_path, device="cpu")
+ else:
+ raise ValueError(
+ f"_TAEModel.__init__: Unsupported checkpoint format: {ext}. "
+ "Supported: .pth, .safetensors"
+ )
+ self.load_state_dict(self._patch_tgrow_layers(state_dict))
+
+ def _patch_tgrow_layers(self, sd: dict) -> dict:
+ """Patch TGrow layers to use a smaller kernel if needed.
+
+ Args:
+ sd: state dict to patch
+
+ Returns:
+ Patched state dict
+ """
+ new_sd = self.state_dict()
+ for i, layer in enumerate(self.decoder):
+ if isinstance(layer, _TGrow):
+ key = f"decoder.{i}.conv.weight"
+ if sd[key].shape[0] > new_sd[key].shape[0]:
+ # Take the last-timestep output channels
+ sd[key] = sd[key][-new_sd[key].shape[0] :]
+ return sd
+
+ def clear_decode_state(self):
+ """Clear decoder streaming state for a new sequence."""
+ self._decoder_mem = None
+ self._frames_output = 0
+
+ def clear_encode_state(self):
+ """Clear encoder streaming state for a new sequence."""
+ self._encoder_mem = None
+
+ def stream_encode(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ """Encode frames in streaming mode with persistent memory.
+
+ This uses parallel processing within each batch for speed, while maintaining
+ MemBlock memory across batches for smooth temporal continuity at chunk
+ boundaries.
+
+ Unlike encode_video, this maintains state across calls.
+ Call clear_encode_state() before a new sequence.
+
+ Args:
+ x: input NTCHW RGB (C=3) tensor with values in [0, 1]
+
+ Returns:
+ NTCHW latent tensor with approximately Gaussian values
+ """
+ if self.patch_size > 1:
+ x = F.pixel_unshuffle(x, self.patch_size)
+ if x.shape[1] % 4 != 0:
+ # Pad at end to multiple of 4
+ n_pad = 4 - x.shape[1] % 4
+ padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+ x = torch.cat([x, padding], 1)
+
+ N, T, C, H, W = x.shape
+ x_flat = x.reshape(N * T, C, H, W)
+
+ result, self._encoder_mem = _apply_model_parallel_streaming(
+ self.encoder,
+ x_flat,
+ N,
+ initial_mem=self._encoder_mem,
+ )
+
+ return result
+
+ def encode_video(
+ self,
+ x: torch.Tensor,
+ parallel: bool = True,
+ show_progress_bar: bool = False,
+ ) -> torch.Tensor:
+ """Encode a sequence of frames.
+
+ Args:
+ x: input NTCHW RGB (C=3) tensor with values in [0, 1]
+ parallel: unused, kept for API compatibility (always uses parallel)
+ show_progress_bar: unused, kept for API compatibility
+
+ Returns:
+ NTCHW latent tensor with approximately Gaussian values
+ """
+ if self.patch_size > 1:
+ x = F.pixel_unshuffle(x, self.patch_size)
+ if x.shape[1] % 4 != 0:
+ # Pad at end to multiple of 4
+ n_pad = 4 - x.shape[1] % 4
+ padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+ x = torch.cat([x, padding], 1)
+ return _apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
+
+ def decode_video(
+ self,
+ x: torch.Tensor,
+ parallel: bool = True,
+ show_progress_bar: bool = False,
+ ) -> torch.Tensor:
+ """Decode a sequence of frames (batch mode).
+
+ Args:
+ x: input NTCHW latent tensor with approximately Gaussian values
+ parallel: unused, kept for API compatibility (always uses parallel)
+ show_progress_bar: unused, kept for API compatibility
+
+ Returns:
+ NTCHW RGB tensor with values clamped to [0, 1]
+ """
+ x = _apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
+ x = x.clamp_(0, 1)
+ if self.patch_size > 1:
+ x = F.pixel_shuffle(x, self.patch_size)
+ return x[:, self.frames_to_trim :]
+
+ def stream_decode(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ """Decode frames in streaming mode with persistent memory.
+
+ This uses parallel processing within each batch for speed, while maintaining
+ MemBlock memory across batches for smooth temporal continuity.
+
+ On the first batch, frames are processed sequentially (first frame separately,
+ then remaining frames) to match WanVAE warmup behavior for better temporal
+ consistency.
+
+ Unlike decode_video, this maintains state across calls.
+ Call clear_decode_state() before a new sequence.
+
+ Args:
+ x: input NTCHW latent tensor (typically 1-4 frames at a time)
+
+ Returns:
+ NTCHW RGB tensor with values in [0, 1].
+ First call returns fewer frames due to temporal trim.
+ """
+ N, T, C, H, W = x.shape
+
+ # First batch warmup: process first frame separately, then remaining frames
+ # This matches WanVAE's warmup behavior for better temporal consistency
+ if self._frames_output == 0:
+ # Clear decoder memory state for first batch
+ self._decoder_mem = None
+
+ # Process first frame separately
+ first_frame = x[:, :1, :, :, :] # [N, 1, C, H, W]
+ first_flat = first_frame.reshape(N * 1, C, H, W)
+ first_result, first_mem = _apply_model_parallel_streaming(
+ self.decoder,
+ first_flat,
+ N,
+ initial_mem=None, # Use zeros for first frame
+ )
+
+ # Process remaining frames if any
+ if T > 1:
+ remaining_frames = x[:, 1:, :, :, :] # [N, T-1, C, H, W]
+ remaining_flat = remaining_frames.reshape(N * (T - 1), C, H, W)
+ remaining_result, self._decoder_mem = _apply_model_parallel_streaming(
+ self.decoder,
+ remaining_flat,
+ N,
+ initial_mem=first_mem, # Use memory from first frame
+ )
+ # Concatenate first frame and remaining frames
+ result = torch.cat([first_result, remaining_result], dim=1)
+ else:
+ # Only one frame
+ result = first_result
+ self._decoder_mem = first_mem
+ else:
+ # Subsequent batches: use parallel processing with persistent memory
+ x_flat = x.reshape(N * T, C, H, W)
+ result, self._decoder_mem = _apply_model_parallel_streaming(
+ self.decoder,
+ x_flat,
+ N,
+ initial_mem=self._decoder_mem,
+ )
+
+ result = result.clamp_(0, 1)
+
+ if self.patch_size > 1:
+ result = F.pixel_shuffle(result, self.patch_size)
+
+ # Handle temporal trim - only trim on first output
+ if self._frames_output == 0 and result.shape[1] > self.frames_to_trim:
+ result = result[:, self.frames_to_trim :]
+
+ self._frames_output += result.shape[1]
+
+ return result
+
+
+class TAEWrapper(nn.Module):
+ """TAE wrapper with interface matching WanVAEWrapper.
+
+ This provides a consistent interface for the Tiny AutoEncoder that matches
+ the WanVAEWrapper's encode_to_latent/decode_to_pixel/clear_cache API.
+
+ Note: TAE is a lightweight preview encoder with its own latent space. It does
+ NOT use WanVAE's normalization constants - TAE produces approximately Gaussian
+ latents directly. Quality may be lower than WanVAE but encoding/decoding is faster.
+
+ Streaming mode (use_cache=True):
+ TAE maintains persistent MemBlock memory for smooth frame-by-frame streaming.
+ This is faster than batch mode for real-time applications since it processes
+ smaller chunks while maintaining temporal continuity.
+
+ Batch mode (use_cache=False):
+ Processes all frames at once without persistent state. Good for one-shot
+ encoding/decoding of complete videos.
+
+ Args:
+ model_dir: Base directory containing model files
+ model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B")
+ vae_path: Explicit path to TAE checkpoint (overrides model_dir/model_name)
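+
+    Example (illustrative sketch; the checkpoint path is a placeholder):
+        >>> vae = TAEWrapper(vae_path="/path/to/taew2_1.pth")
+        >>> pixel = torch.rand(1, 3, 8, 320, 576) * 2 - 1  # [batch, channels, frames, h, w]
+        >>> latent = vae.encode_to_latent(pixel, use_cache=False)
+        >>> video = vae.decode_to_pixel(latent, use_cache=False)  # pixels in [-1, 1]
+        >>> vae.clear_cache()  # reset streaming state before a new sequence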
+ """
+
+ def __init__(
+ self,
+ model_dir: str = "wan_models",
+ model_name: str = "Wan2.1-T2V-1.3B",
+ vae_path: str | None = None,
+ ):
+ super().__init__()
+
+ # Determine checkpoint path
+ if vae_path is None:
+ vae_path = os.path.join(model_dir, model_name, DEFAULT_TAE_FILENAME)
+
+ self.z_dim = 16
+
+ # Create TAE model
+ self.model = (
+ _TAEModel(
+ checkpoint_path=vae_path,
+ patch_size=1,
+ latent_channels=self.z_dim,
+ )
+ .eval()
+ .requires_grad_(False)
+ )
+
+ # Track state for streaming
+ self._first_batch = True
+
+ def encode_to_latent(
+ self,
+ pixel: torch.Tensor,
+ use_cache: bool = True,
+ feat_cache: list | None = None,
+ ) -> torch.Tensor:
+ """Encode video pixels to latents.
+
+ Args:
+ pixel: Input video tensor [batch, channels, frames, height, width]
+ use_cache: If True, use streaming encode with persistent memory.
+ If False, use batch encode (clears state).
+ feat_cache: Unused (kept for interface compatibility with WanVAEWrapper)
+
+ Returns:
+ Latent tensor [batch, frames, channels, height, width]
+
+ Note:
+ TAE produces approximately Gaussian latents directly without additional
+ normalization. The latent space is similar to but not identical to WanVAE.
+
+ In streaming mode (use_cache=True), TAE maintains MemBlock state across
+ calls for smooth temporal continuity at chunk boundaries.
+ """
+ # [batch, channels, frames, h, w] -> [batch, frames, channels, h, w] for TAE
+ pixel_ntchw = pixel.permute(0, 2, 1, 3, 4)
+
+ # Scale from [-1, 1] to [0, 1] range expected by TAE
+ pixel_ntchw = (pixel_ntchw + 1) / 2
+
+ if use_cache:
+ # Streaming mode - use parallel processing with persistent memory
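+ # Note: _first_batch is only set to False in decode_to_pixel, so streaming
+ # encode assumes decode_to_pixel runs in the same loop; otherwise encoder
+ # memory is re-cleared on every call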
+ if self._first_batch:
+ self.model.clear_encode_state()
+
+ latent = self.model.stream_encode(pixel_ntchw)
+ else:
+ # Batch mode - no persistent state
+ latent = self.model.encode_video(
+ pixel_ntchw, parallel=True, show_progress_bar=False
+ )
+
+ # Return in [batch, frames, channels, h, w] format
+ return latent
+
+ def decode_to_pixel(
+ self, latent: torch.Tensor, use_cache: bool = True
+ ) -> torch.Tensor:
+ """Decode latents to video pixels.
+
+ Args:
+ latent: Latent tensor [batch, frames, channels, height, width]
+ use_cache: If True, use streaming decode with persistent memory.
+ If False, use batch decode (clears state).
+
+ Returns:
+ Video tensor [batch, frames, channels, height, width] in range [-1, 1]
+
+ Note:
+ In streaming mode (use_cache=True), TAE maintains MemBlock state across
+ calls for smooth temporal continuity. Uses parallel processing within
+ each batch for speed. The first call may have fewer output frames due
+ to TGrow temporal expansion and frame trimming.
+ """
+ if use_cache:
+ # Streaming mode - use parallel processing with persistent memory
+ if self._first_batch:
+ self.model.clear_decode_state()
+ self._first_batch = False
+
+ output = self.model.stream_decode(latent)
+ else:
+ # Batch mode - no persistent state
+ output = self.model.decode_video(
+ latent, parallel=True, show_progress_bar=False
+ )
+
+ # Scale from [0, 1] to [-1, 1] range
+ output = output * 2 - 1
+ output = output.clamp_(-1, 1)
+
+ # Return in [batch, frames, channels, h, w] format
+ return output
+
+ def clear_cache(self):
+ """Clear state for next sequence."""
+ self._first_batch = True
+ self.model.clear_encode_state()
+ self.model.clear_decode_state()
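+
+
+if __name__ == "__main__":
+    # Smoke-test sketch (illustrative; not part of the upstream TAE code). With
+    # checkpoint_path=None the model keeps its random init, so outputs are
+    # meaningless - this only checks that chunked streaming decode produces the
+    # same number of frames as batch decode.
+    torch.manual_seed(0)
+    model = _TAEModel()
+    latent = torch.randn(1, 4, 16, 8, 8)  # NTCHW latents
+
+    with torch.no_grad():
+        batch_out = model.decode_video(latent)
+
+        model.clear_decode_state()
+        stream_out = torch.cat(
+            [model.stream_decode(latent[:, :2]), model.stream_decode(latent[:, 2:])],
+            dim=1,
+        )
+
+    # Both paths trim 2**sum(decoder_time_upscale) - 1 = 3 warmup frames
+    assert batch_out.shape == stream_out.shape, (batch_out.shape, stream_out.shape)
+    print("batch:", tuple(batch_out.shape), "stream:", tuple(stream_out.shape))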