diff --git a/frontend/src/hooks/usePipeline.ts b/frontend/src/hooks/usePipeline.ts
index 6347cdcd..57e4b6e9 100644
--- a/frontend/src/hooks/usePipeline.ts
+++ b/frontend/src/hooks/usePipeline.ts
@@ -36,10 +36,12 @@ export function usePipeline(options: UsePipelineOptions = {}) {
       if (shownErrorRef.current !== errorMessage) {
         toast.error("Pipeline Error", {
           description: errorMessage,
-          duration: 5000,
+          duration: 10000, // Show longer for detailed errors
         });
         shownErrorRef.current = errorMessage;
       }
+      // Ensure loading state is cleared when error occurs
+      setIsLoading(false);
       // Don't set error in state - it's shown as toast and cleared on backend
       setError(null);
     } else {
@@ -54,7 +56,7 @@ export function usePipeline(options: UsePipelineOptions = {}) {
       if (shownErrorRef.current !== errorMessage) {
         toast.error("Pipeline Error", {
           description: errorMessage,
-          duration: 5000,
+          duration: 10000, // Show longer for detailed errors
         });
         shownErrorRef.current = errorMessage;
       }
@@ -91,10 +93,12 @@ export function usePipeline(options: UsePipelineOptions = {}) {
       if (shownErrorRef.current !== errorMessage) {
         toast.error("Pipeline Error", {
           description: errorMessage,
-          duration: 5000,
+          duration: 10000, // Show longer for detailed errors
         });
         shownErrorRef.current = errorMessage;
       }
+      // Ensure loading state is cleared when error occurs
+      setIsLoading(false);
       // Don't set error in state - it's shown as toast and cleared on backend
       setError(null);
     } else {
@@ -108,6 +112,7 @@ export function usePipeline(options: UsePipelineOptions = {}) {
         statusResponse.status === "error"
       ) {
         stopPolling();
+        setIsLoading(false); // Ensure loading is stopped
         return;
       }
     } catch (err) {
@@ -118,7 +123,7 @@ export function usePipeline(options: UsePipelineOptions = {}) {
       if (shownErrorRef.current !== errorMessage) {
         toast.error("Pipeline Error", {
           description: errorMessage,
-          duration: 5000,
+          duration: 10000, // Show longer for detailed errors
         });
         shownErrorRef.current = errorMessage;
       }
@@ -182,10 +187,11 @@ export function usePipeline(options: UsePipelineOptions = {}) {
           if (shownErrorRef.current !== errorMsg) {
             toast.error("Pipeline Error", {
               description: errorMsg,
-              duration: 5000,
+              duration: 10000, // Show longer for detailed errors
             });
             shownErrorRef.current = errorMsg;
           }
+          setIsLoading(false); // Ensure loading is stopped on error
           reject(new Error(errorMsg));
         } else {
           // Continue polling
@@ -217,7 +223,7 @@ export function usePipeline(options: UsePipelineOptions = {}) {
       if (shownErrorRef.current !== errorMessage) {
         toast.error("Pipeline Error", {
           description: errorMessage,
-          duration: 5000,
+          duration: 10000, // Show longer for detailed errors
         });
         shownErrorRef.current = errorMessage;
       }
diff --git a/src/scope/server/frame_processor_proxy.py b/src/scope/server/frame_processor_proxy.py
new file mode 100644
index 00000000..9ce5a8ee
--- /dev/null
+++ b/src/scope/server/frame_processor_proxy.py
@@ -0,0 +1,163 @@
+"""FrameProcessor Proxy - Communicates with FrameProcessor in worker process."""
+
+import logging
+import multiprocessing as mp
+import queue
+import uuid
+
+import torch
+from aiortc.mediastreams import VideoFrame
+
+from .pipeline_worker import WorkerCommand, WorkerResponse
+
+logger = logging.getLogger(__name__)
+
+# Constants
+DEFAULT_TIMEOUT = 60  # seconds
+
+
+class FrameProcessorProxy:
+    """Proxy object that communicates with FrameProcessor in worker process."""
+
+    def __init__(
+        self,
+        frame_processor_id: str,
+        command_queue: mp.Queue,
+        response_queue: mp.Queue,
+    ):
+        self.frame_processor_id = frame_processor_id
+        self._command_queue = command_queue
+        self._response_queue = response_queue
+
+    def put(self, frame: VideoFrame) -> bool:
+        """Put a frame into the FrameProcessor buffer."""
+        try:
+            # Serialize VideoFrame to dict for inter-process communication
+            frame_array = frame.to_ndarray(format="rgb24")
+            frame_data = {"array": frame_array}
+
+            self._command_queue.put(
+                {
+                    "command": WorkerCommand.PUT_FRAME.value,
+                    "frame_processor_id": self.frame_processor_id,
+                    "frame_data": frame_data,
+                }
+            )
+            return True
+        except Exception as e:
+            logger.error(f"Error putting frame: {e}")
+            return False
+
+    def get(self) -> torch.Tensor | None:
+        """Get a processed frame from the FrameProcessor."""
+        try:
+            # Request a frame
+            self._command_queue.put(
+                {
+                    "command": WorkerCommand.GET_FRAME.value,
+                    "frame_processor_id": self.frame_processor_id,
+                }
+            )
+
+            # Wait for response with timeout
+            try:
+                response = self._response_queue.get(timeout=1.0)
+
+                if response["status"] == WorkerResponse.FRAME.value:
+                    frame_data = response.get("frame_data")
+                    if frame_data and frame_data.get("__tensor__"):
+                        # Deserialize tensor from numpy array
+                        return torch.from_numpy(frame_data["data"])
+                    return None
+                elif response["status"] == WorkerResponse.RESULT.value:
+                    # No frame available
+                    return None
+                elif response["status"] == WorkerResponse.ERROR.value:
+                    error_msg = response.get("error", "Unknown error")
+                    logger.error(f"Error getting frame: {error_msg}")
+                    return None
+                else:
+                    logger.warning(f"Unexpected response status: {response['status']}")
+                    return None
+
+            except queue.Empty:
+                # Timeout - no frame available
+                return None
+
+        except Exception as e:
+            logger.error(f"Error getting frame: {e}")
+            return None
+
+    def get_current_pipeline_fps(self) -> float:
+        """Get the current dynamically calculated pipeline FPS."""
+        try:
+            self._command_queue.put(
+                {
+                    "command": WorkerCommand.GET_FPS.value,
+                    "frame_processor_id": self.frame_processor_id,
+                }
+            )
+
+            try:
+                response = self._response_queue.get(timeout=DEFAULT_TIMEOUT)
+
+                if response["status"] == WorkerResponse.RESULT.value:
+                    return response.get("result", 30.0)
+                elif response["status"] == WorkerResponse.ERROR.value:
+                    error_msg = response.get("error", "Unknown error")
+                    logger.error(f"Error getting FPS: {error_msg}")
+                    return 30.0
+                else:
+                    logger.warning(f"Unexpected response status: {response['status']}")
+                    return 30.0
+
+            except queue.Empty:
+                logger.error("Timeout waiting for FPS response")
+                return 30.0
+
+        except Exception as e:
+            logger.error(f"Error getting FPS: {e}")
+            return 30.0
+
+    def update_parameters(self, parameters: dict):
+        """Update parameters that will be used in the next pipeline call."""
+        try:
+            self._command_queue.put(
+                {
+                    "command": WorkerCommand.UPDATE_PARAMETERS.value,
+                    "frame_processor_id": self.frame_processor_id,
+                    "parameters": parameters,
+                }
+            )
+            return True
+        except Exception as e:
+            logger.error(f"Error updating parameters: {e}")
+            return False
+
+    def start(self):
+        """Start the FrameProcessor (already started when created)."""
+        pass
+
+    def stop(self):
+        """Stop and destroy the FrameProcessor."""
+        try:
+            self._command_queue.put(
+                {
+                    "command": WorkerCommand.DESTROY_FRAME_PROCESSOR.value,
+                    "frame_processor_id": self.frame_processor_id,
+                }
+            )
+
+            # Wait for response
+            try:
+                response = self._response_queue.get(timeout=DEFAULT_TIMEOUT)
+                if response["status"] == WorkerResponse.SUCCESS.value:
+                    logger.info(f"FrameProcessor {self.frame_processor_id} stopped")
+                else:
+                    error_msg = response.get("error", "Unknown error")
+                    logger.warning(f"Error stopping FrameProcessor: {error_msg}")
+            except queue.Empty:
+                logger.warning("Timeout waiting for FrameProcessor stop response")
+
+        except Exception as e:
+            logger.error(f"Error stopping FrameProcessor: {e}")
diff --git a/src/scope/server/pipeline_manager.py b/src/scope/server/pipeline_manager.py
index 0f10b226..16cc7452 100644
--- a/src/scope/server/pipeline_manager.py
+++ b/src/scope/server/pipeline_manager.py
@@ -1,18 +1,26 @@
 """Pipeline Manager for lazy loading and managing ML pipelines."""
 
 import asyncio
-import gc
 import logging
+import multiprocessing as mp
 import os
+import queue
 import threading
+import time
+import traceback
 from enum import Enum
 from typing import Any
 
-import torch
-from omegaconf import OmegaConf
+from .pipeline_worker import WorkerCommand, WorkerResponse, pipeline_worker_process
 
 logger = logging.getLogger(__name__)
 
+# Constants
+PIPELINE_LOAD_TIMEOUT = 300  # 5 minutes
+WORKER_SHUTDOWN_TIMEOUT = 5  # seconds
+WORKER_TERMINATE_TIMEOUT = 3  # seconds
+WORKER_KILL_TIMEOUT = 1  # seconds
+
 
 class PipelineNotAvailableException(Exception):
     """Exception raised when pipeline is not available for processing."""
@@ -30,7 +38,7 @@ class PipelineStatus(Enum):
 
 
 class PipelineManager:
-    """Manager for ML pipeline lifecycle."""
+    """Manager for ML pipeline lifecycle using separate process for GPU isolation."""
 
     def __init__(self):
         self._status = PipelineStatus.NOT_LOADED
@@ -40,6 +48,11 @@ def __init__(self):
         self._error_message = None
         self._lock = threading.RLock()  # Single reentrant lock for all access
 
+        # Worker process management
+        self._worker_process = None
+        self._command_queue = None
+        self._response_queue = None
+
     @property
     def status(self) -> PipelineStatus:
         """Get current pipeline status."""
@@ -56,13 +69,19 @@ def error_message(self) -> str | None:
         return self._error_message
 
     def get_pipeline(self):
-        """Get the loaded pipeline instance (thread-safe)."""
+        """Get the loaded pipeline instance (thread-safe).
+
+        Note: Pipeline is now loaded in worker process. Use create_frame_processor()
+        to get a FrameProcessor that uses the pipeline directly in the worker process.
+        """
         with self._lock:
-            if self._status != PipelineStatus.LOADED or self._pipeline is None:
+            if self._status != PipelineStatus.LOADED or self._worker_process is None:
                 raise PipelineNotAvailableException(
                     f"Pipeline not available. Status: {self._status.value}"
                 )
-            return self._pipeline
+            # Pipeline is in worker process, return a placeholder to indicate it's loaded
+            # Actual pipeline access is through FrameProcessor in worker process
+            return None
 
     def get_status_info(self) -> dict[str, Any]:
         """Get detailed status information (thread-safe).
@@ -78,13 +97,9 @@ def get_status_info(self) -> dict[str, Any]:
             load_params = self._load_params
 
             # Capture loaded LoRA adapters if pipeline exposes them
+            # Note: With worker process, we can't directly access pipeline attributes
+            # This would require a worker command to query the pipeline state
             loaded_lora_adapters = None
-            if self._pipeline is not None and hasattr(
-                self._pipeline, "loaded_lora_adapters"
-            ):
-                loaded_lora_adapters = getattr(
-                    self._pipeline, "loaded_lora_adapters", None
-                )
 
             # If there's an error, clear it after capturing it
             # This ensures errors don't persist across page reloads
@@ -172,34 +187,44 @@ def _load_pipeline_sync_wrapper(
             self._error_message = None
 
         # Release lock during slow loading operation
-        logger.info(f"Loading pipeline: {pipeline_id}")
+        logger.info(f"Loading pipeline in worker process: {pipeline_id}")
 
         try:
-            # Load the pipeline synchronously (we're already in executor thread)
-            pipeline = self._load_pipeline_implementation(pipeline_id, load_params)
+            # Start worker process and load pipeline
+            # This will raise RuntimeError if loading fails
+            self._start_worker_and_load_pipeline(pipeline_id, load_params)
 
             # Hold lock while updating state with loaded pipeline
            with self._lock:
-                self._pipeline = pipeline
                self._pipeline_id = pipeline_id
                self._load_params = load_params
                self._status = PipelineStatus.LOADED
 
-            logger.info(f"Pipeline {pipeline_id} loaded successfully")
+            logger.info(f"Pipeline {pipeline_id} loaded successfully in worker process")
             return True
 
         except Exception as e:
-            error_msg = f"Failed to load pipeline {pipeline_id}: {str(e)}"
+            # Capture full error message including traceback for better debugging
+            error_details = str(e)
+            # If the error already contains detailed info, use it; otherwise add traceback
+            if (
+                "traceback" not in error_details.lower()
+                and "Worker process" not in error_details
+            ):
+                error_details = f"{error_details}\n{traceback.format_exc()}"
+            error_msg = f"Failed to load pipeline {pipeline_id}: {error_details}"
            logger.error(error_msg)
 
            # Hold lock while updating state with error
            with self._lock:
                self._status = PipelineStatus.ERROR
                self._error_message = error_msg
-                self._pipeline = None
                self._pipeline_id = None
                self._load_params = None
 
+            # Cleanup worker on failure
+            self._stop_worker()
+
            return False
 
     def _apply_load_params(
@@ -241,204 +266,178 @@ def _apply_load_params(
             config["_lora_merge_mode"] = lora_merge_mode
 
     def _unload_pipeline_unsafe(self):
-        """Unload the current pipeline. Must be called with lock held."""
-        if self._pipeline:
+        """Unload the current pipeline. Must be called with lock held.
+
+        This will kill the worker process to ensure proper VRAM cleanup.
+ """ + if self._pipeline_id: logger.info(f"Unloading pipeline: {self._pipeline_id}") - # Change status and pipeline atomically + # Stop the worker process (this ensures VRAM is cleaned up) + self._stop_worker() + + # Reset state self._status = PipelineStatus.NOT_LOADED self._pipeline = None self._pipeline_id = None self._load_params = None self._error_message = None - # Cleanup resources - gc.collect() - if torch.cuda.is_available(): - try: - torch.cuda.empty_cache() - torch.cuda.synchronize() - logger.info("CUDA cache cleared") - except Exception as e: - logger.warning(f"CUDA cleanup failed: {e}") - - def _load_pipeline_implementation( + def _start_worker_and_load_pipeline( self, pipeline_id: str, load_params: dict | None = None - ): - """Synchronous pipeline loading (runs in thread executor).""" - if pipeline_id == "streamdiffusionv2": - from scope.core.pipelines import ( - StreamDiffusionV2Pipeline, - ) - - from .models_config import get_model_file_path, get_models_dir - - models_dir = get_models_dir() - config = OmegaConf.create( - { - "model_dir": str(models_dir), - "generator_path": str( - get_model_file_path( - "StreamDiffusionV2/wan_causal_dmd_v2v/model.pt" - ) - ), - "text_encoder_path": str( - get_model_file_path( - "WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors" - ) - ), - "tokenizer_path": str( - get_model_file_path("Wan2.1-T2V-1.3B/google/umt5-xxl") - ), - } - ) - - # Apply load parameters (resolution, seed, LoRAs) to config - self._apply_load_params( - config, - load_params, - default_height=512, - default_width=512, - default_seed=42, - ) + ) -> bool: + """Start worker process and load pipeline in it. - quantization = None - if load_params: - quantization = load_params.get("quantization", None) + Returns: + bool: True if successful, False otherwise - pipeline = StreamDiffusionV2Pipeline( - config, - quantization=quantization, - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - logger.info("StreamDiffusionV2 pipeline initialized") - return pipeline - - elif pipeline_id == "passthrough": - from scope.core.pipelines import PassthroughPipeline - - # Use load parameters for resolution, default to 512x512 - height = 512 - width = 512 - if load_params: - height = load_params.get("height", 512) - width = load_params.get("width", 512) - - pipeline = PassthroughPipeline( - height=height, - width=width, - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - logger.info("Passthrough pipeline initialized") - return pipeline + Raises: + RuntimeError: If worker process dies unexpectedly or fails to load pipeline + """ + # Stop any existing worker first + self._stop_worker() + + # Create communication queues with spawn context for better CUDA compatibility + # Using 'spawn' ensures a clean process without CUDA context issues + ctx = mp.get_context("spawn") + self._command_queue = ctx.Queue() + self._response_queue = ctx.Queue() + + # Start worker process with spawn context + self._worker_process = ctx.Process( + target=pipeline_worker_process, + args=(self._command_queue, self._response_queue), + daemon=False, # We want to control its lifecycle explicitly + ) + self._worker_process.start() - elif pipeline_id == "longlive": - from scope.core.pipelines import LongLivePipeline + logger.info(f"Started worker process (PID: {self._worker_process.pid})") - from .models_config import get_model_file_path, get_models_dir + # Send load command + self._command_queue.put( + { + "command": WorkerCommand.LOAD_PIPELINE.value, + "pipeline_id": pipeline_id, + "load_params": load_params, 
+ } + ) - config = OmegaConf.create( - { - "model_dir": str(get_models_dir()), - "generator_path": str( - get_model_file_path("LongLive-1.3B/models/longlive_base.pt") - ), - "lora_path": str( - get_model_file_path("LongLive-1.3B/models/lora.pt") - ), - "text_encoder_path": str( - get_model_file_path( - "WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors" + # Wait for response with timeout, checking if worker is still alive + start_time = time.time() + check_interval = 0.5 # Check worker status every 0.5 seconds + + while True: + # Check if worker process is still alive + if not self._worker_process.is_alive(): + # Worker died unexpectedly - get exit code + exit_code = self._worker_process.exitcode + if exit_code is not None and exit_code != 0: + # Process crashed (e.g., OOM, segmentation fault) + if exit_code == -9: # SIGKILL (often OOM killer) + error_msg = ( + f"Worker process was killed (likely out of memory). " + f"Exit code: {exit_code}. " + f"Pipeline loading failed for {pipeline_id}." ) - ), - "tokenizer_path": str( - get_model_file_path("Wan2.1-T2V-1.3B/google/umt5-xxl") - ), - } - ) - - # Apply load parameters (resolution, seed, LoRAs) to config - self._apply_load_params( - config, - load_params, - default_height=320, - default_width=576, - default_seed=42, - ) + elif exit_code < 0: + error_msg = ( + f"Worker process crashed with signal {abs(exit_code)}. " + f"Pipeline loading failed for {pipeline_id}." + ) + else: + error_msg = ( + f"Worker process exited unexpectedly with code {exit_code}. " + f"Pipeline loading failed for {pipeline_id}." + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + else: + # Process exited normally but we didn't get a response + error_msg = ( + f"Worker process exited before sending response. " + f"Pipeline loading failed for {pipeline_id}." + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + + # Check for timeout + elapsed = time.time() - start_time + if elapsed >= PIPELINE_LOAD_TIMEOUT: + error_msg = ( + f"Timeout waiting for pipeline load response after {PIPELINE_LOAD_TIMEOUT}s. " + f"Pipeline loading failed for {pipeline_id}." + ) + logger.error(error_msg) + raise RuntimeError(error_msg) - quantization = None - if load_params: - quantization = load_params.get("quantization", None) + # Try to get response with short timeout to allow periodic worker checks + try: + response = self._response_queue.get(timeout=check_interval) + + if response["status"] == WorkerResponse.SUCCESS.value: + logger.info( + f"Pipeline loaded successfully in worker: {response.get('message')}" + ) + return True + else: + # Worker sent an error response + error_msg = response.get("error", "Unknown error") + logger.error(f"Failed to load pipeline in worker: {error_msg}") + raise RuntimeError(f"Pipeline loading failed: {error_msg}") + + except queue.Empty: + # Timeout on queue.get - continue loop to check worker status + continue + + def _stop_worker(self): + """Stop the worker process if it's running. + + This ensures proper VRAM cleanup by killing the process. 
+ """ + if self._worker_process is not None and self._worker_process.is_alive(): + logger.info(f"Stopping worker process (PID: {self._worker_process.pid})") - pipeline = LongLivePipeline( - config, - quantization=quantization, - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - logger.info("LongLive pipeline initialized") - return pipeline + # Try graceful shutdown first + try: + self._command_queue.put(None) # Shutdown signal + self._worker_process.join(timeout=WORKER_SHUTDOWN_TIMEOUT) + except Exception as e: + logger.warning(f"Error during graceful shutdown: {e}") - elif pipeline_id == "krea-realtime-video": - from scope.core.pipelines import ( - KreaRealtimeVideoPipeline, - ) + # Force terminate if still alive + if self._worker_process.is_alive(): + logger.warning( + "Worker process did not shut down gracefully, terminating..." + ) + self._worker_process.terminate() + self._worker_process.join(timeout=WORKER_TERMINATE_TIMEOUT) - from .models_config import get_model_file_path, get_models_dir + # Final kill if still alive + if self._worker_process.is_alive(): + logger.warning("Worker process did not terminate, killing...") + self._worker_process.kill() + self._worker_process.join(timeout=WORKER_KILL_TIMEOUT) - config = OmegaConf.create( - { - "model_dir": str(get_models_dir()), - "generator_path": str( - get_model_file_path( - "krea-realtime-video/krea-realtime-video-14b.safetensors" - ) - ), - "text_encoder_path": str( - get_model_file_path( - "WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors" - ) - ), - "tokenizer_path": str( - get_model_file_path("Wan2.1-T2V-1.3B/google/umt5-xxl") - ), - "vae_path": str( - get_model_file_path("Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") - ), - } - ) + logger.info("Worker process stopped") - # Apply load parameters (resolution, seed, LoRAs) to config - self._apply_load_params( - config, - load_params, - default_height=512, - default_width=512, - default_seed=42, - ) + # Clean up queues + if self._command_queue is not None: + try: + self._command_queue.close() + self._command_queue.join_thread() + except Exception: + pass + self._command_queue = None - quantization = None - if load_params: - quantization = load_params.get("quantization", None) - - pipeline = KreaRealtimeVideoPipeline( - config, - quantization=quantization, - # Only compile diffusion model for hopper right now - compile=any( - x in torch.cuda.get_device_name(0).lower() - for x in ("h100", "hopper") - ), - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - logger.info("krea-realtime-video pipeline initialized") - return pipeline + if self._response_queue is not None: + try: + self._response_queue.close() + self._response_queue.join_thread() + except Exception: + pass + self._response_queue = None - else: - raise ValueError(f"Invalid pipeline ID: {pipeline_id}") + self._worker_process = None def unload_pipeline(self): """Unload the current pipeline (thread-safe).""" @@ -449,3 +448,45 @@ def is_loaded(self) -> bool: """Check if pipeline is loaded and ready (thread-safe).""" with self._lock: return self._status == PipelineStatus.LOADED + + def create_frame_processor(self, frame_processor_id: str, initial_parameters: dict = None): + """Create a FrameProcessor in the worker process (thread-safe). 
+ + Args: + frame_processor_id: Unique identifier for this FrameProcessor + initial_parameters: Initial parameters for the FrameProcessor + + Returns: + FrameProcessorProxy instance + """ + from .frame_processor_proxy import FrameProcessorProxy + + with self._lock: + if self._status != PipelineStatus.LOADED or self._worker_process is None: + raise PipelineNotAvailableException( + f"Pipeline not available. Status: {self._status.value}" + ) + + # Send command to create FrameProcessor in worker + self._command_queue.put( + { + "command": WorkerCommand.CREATE_FRAME_PROCESSOR.value, + "frame_processor_id": frame_processor_id, + "initial_parameters": initial_parameters or {}, + } + ) + + # Wait for response + try: + response = self._response_queue.get(timeout=60) + if response["status"] == WorkerResponse.FRAME_PROCESSOR_CREATED.value: + return FrameProcessorProxy( + frame_processor_id=frame_processor_id, + command_queue=self._command_queue, + response_queue=self._response_queue, + ) + else: + error_msg = response.get("error", "Unknown error") + raise RuntimeError(f"Failed to create FrameProcessor: {error_msg}") + except queue.Empty: + raise RuntimeError("Timeout waiting for FrameProcessor creation") diff --git a/src/scope/server/pipeline_worker.py b/src/scope/server/pipeline_worker.py new file mode 100644 index 00000000..e871eb21 --- /dev/null +++ b/src/scope/server/pipeline_worker.py @@ -0,0 +1,882 @@ +"""Pipeline Worker - Runs pipeline and FrameProcessor in a separate process for proper VRAM cleanup.""" + +import logging +import multiprocessing as mp +import os +import queue +import threading +import time +import traceback +from collections import deque +from enum import Enum + +import torch +from omegaconf import OmegaConf + +# Configure logging for worker process +logger = logging.getLogger(__name__) + + +class WorkerCommand(Enum): + """Commands that can be sent to the worker process.""" + + LOAD_PIPELINE = "load_pipeline" + UNLOAD_PIPELINE = "unload_pipeline" + CREATE_FRAME_PROCESSOR = "create_frame_processor" + DESTROY_FRAME_PROCESSOR = "destroy_frame_processor" + PUT_FRAME = "put_frame" + GET_FRAME = "get_frame" + UPDATE_PARAMETERS = "update_parameters" + GET_FPS = "get_fps" + SHUTDOWN = "shutdown" + + +class WorkerResponse(Enum): + """Response types from worker process.""" + + SUCCESS = "success" + ERROR = "error" + PIPELINE_LOADED = "pipeline_loaded" + PIPELINE_NOT_LOADED = "pipeline_not_loaded" + RESULT = "result" + FRAME_PROCESSOR_CREATED = "frame_processor_created" + FRAME = "frame" + + +# Constants for FrameProcessor +OUTPUT_QUEUE_MAX_SIZE_FACTOR = 3 +MIN_FPS = 1.0 +MAX_FPS = 60.0 +DEFAULT_FPS = 30.0 +SLEEP_TIME = 0.01 + + +class WorkerFrameProcessor: + """FrameProcessor that runs in the worker process and uses pipeline directly.""" + + def __init__( + self, + pipeline, + max_output_queue_size: int = 8, + max_parameter_queue_size: int = 8, + max_buffer_size: int = 30, + initial_parameters: dict = None, + ): + self.pipeline = pipeline + + self.frame_buffer = deque(maxlen=max_buffer_size) + self.frame_buffer_lock = threading.Lock() + self.output_queue = queue.Queue(maxsize=max_output_queue_size) + + # Current parameters used by processing thread + self.parameters = initial_parameters or {} + # Queue for parameter updates from external threads + self.parameters_queue = queue.Queue(maxsize=max_parameter_queue_size) + + self.worker_thread: threading.Thread | None = None + self.shutdown_event = threading.Event() + self.running = False + + self.is_prepared = False + + # FPS tracking 
variables + self.processing_time_per_frame = deque(maxlen=2) + self.last_fps_update = time.time() + self.fps_update_interval = 0.5 + self.min_fps = MIN_FPS + self.max_fps = MAX_FPS + self.current_pipeline_fps = DEFAULT_FPS + self.fps_lock = threading.Lock() + + self.paused = False + + def start(self): + if self.running: + return + + self.running = True + self.shutdown_event.clear() + self.worker_thread = threading.Thread(target=self.worker_loop, daemon=True) + self.worker_thread.start() + + logger.info("WorkerFrameProcessor started") + + def stop(self): + if not self.running: + return + + self.running = False + self.shutdown_event.set() + + if self.worker_thread and self.worker_thread.is_alive(): + if threading.current_thread() != self.worker_thread: + self.worker_thread.join(timeout=5.0) + + while not self.output_queue.empty(): + try: + self.output_queue.get_nowait() + except queue.Empty: + break + + with self.frame_buffer_lock: + self.frame_buffer.clear() + + logger.info("WorkerFrameProcessor stopped") + + def put(self, frame_data: dict) -> bool: + """Put a frame into the buffer. frame_data is a serialized VideoFrame.""" + if not self.running: + return False + + # Deserialize frame from dict + frame_array = frame_data.get("array") + if frame_array is None: + return False + + with self.frame_buffer_lock: + # Store as dict for now, will convert to tensor when processing + self.frame_buffer.append(frame_data) + return True + + def get(self) -> dict | None: + """Get a processed frame. Returns serialized tensor data.""" + if not self.running: + return None + + try: + frame_tensor = self.output_queue.get_nowait() + # Serialize tensor to dict for inter-process communication + return {"__tensor__": True, "data": frame_tensor.cpu().numpy()} + except queue.Empty: + return None + + def get_current_pipeline_fps(self) -> float: + """Get the current dynamically calculated pipeline FPS""" + with self.fps_lock: + return self.current_pipeline_fps + + def _calculate_pipeline_fps(self, start_time: float, num_frames: int): + """Calculate FPS based on processing time and number of frames created""" + processing_time = time.time() - start_time + if processing_time <= 0 or num_frames <= 0: + return + + time_per_frame = processing_time / num_frames + self.processing_time_per_frame.append(time_per_frame) + + current_time = time.time() + if current_time - self.last_fps_update >= self.fps_update_interval: + if len(self.processing_time_per_frame) >= 1: + avg_time_per_frame = sum(self.processing_time_per_frame) / len( + self.processing_time_per_frame + ) + + with self.fps_lock: + current_fps = self.current_pipeline_fps + estimated_fps = ( + 1.0 / avg_time_per_frame if avg_time_per_frame > 0 else current_fps + ) + + estimated_fps = max(self.min_fps, min(self.max_fps, estimated_fps)) + with self.fps_lock: + self.current_pipeline_fps = estimated_fps + + self.last_fps_update = current_time + + def update_parameters(self, parameters: dict): + """Update parameters that will be used in the next pipeline call.""" + try: + self.parameters_queue.put_nowait(parameters) + except queue.Full: + logger.info("Parameter queue full, dropping parameter update") + return False + + def worker_loop(self): + logger.info("WorkerFrameProcessor worker thread started") + + while self.running and not self.shutdown_event.is_set(): + try: + self.process_chunk() + + except Exception as e: + if self._is_recoverable(e): + logger.error(f"Error in worker loop: {e}") + continue + else: + logger.error( + f"Non-recoverable error in worker loop: {e}, 
stopping frame processor" + ) + self.stop() + break + logger.info("WorkerFrameProcessor worker thread stopped") + + def process_chunk(self): + start_time = time.time() + try: + # Check if there are new parameters + try: + new_parameters = self.parameters_queue.get_nowait() + if new_parameters != self.parameters: + if ( + "prompts" in new_parameters + and "transition" not in new_parameters + and "transition" in self.parameters + ): + self.parameters.pop("transition", None) + + self.parameters = {**self.parameters, **new_parameters} + except queue.Empty: + pass + + # Pause or resume the processing + paused = self.parameters.pop("paused", None) + if paused is not None and paused != self.paused: + self.paused = paused + if self.paused: + self.shutdown_event.wait(SLEEP_TIME) + return + + reset_cache = self.parameters.pop("reset_cache", None) + lora_scales = self.parameters.pop("lora_scales", None) + + if reset_cache: + logger.info("Clearing output buffer queue due to reset_cache request") + while not self.output_queue.empty(): + try: + self.output_queue.get_nowait() + except queue.Empty: + break + + requirements = None + if hasattr(self.pipeline, "prepare"): + requirements = self.pipeline.prepare(**self.parameters) + + video_input = None + if requirements is not None: + current_chunk_size = requirements.input_size + with self.frame_buffer_lock: + if not self.frame_buffer or len(self.frame_buffer) < current_chunk_size: + self.shutdown_event.wait(SLEEP_TIME) + return + video_input = self.prepare_chunk(current_chunk_size) + + call_params = dict(self.parameters.items()) + call_params["init_cache"] = not self.is_prepared + if reset_cache is not None: + call_params["init_cache"] = reset_cache + + if lora_scales is not None: + call_params["lora_scales"] = lora_scales + + if video_input is not None: + call_params["video"] = video_input + + # Call pipeline directly - no proxy needed! 
+ output = self.pipeline(**call_params) + + # Clear transition when complete + if "transition" in call_params and "transition" in self.parameters: + transition_active = False + if hasattr(self.pipeline, "state"): + transition_active = self.pipeline.state.get("_transition_active", False) + + transition = call_params.get("transition") + if not transition_active or transition is None: + self.parameters.pop("transition", None) + + processing_time = time.time() - start_time + num_frames = output.shape[0] + logger.debug( + f"Processed pipeline in {processing_time:.4f}s, {num_frames} frames" + ) + + # Normalize to [0, 255] and convert to uint8 + output = ( + (output * 255.0) + .clamp(0, 255) + .to(dtype=torch.uint8) + .contiguous() + .detach() + .cpu() + ) + + # Resize output queue to meet target max size + target_output_queue_max_size = num_frames * OUTPUT_QUEUE_MAX_SIZE_FACTOR + if self.output_queue.maxsize < target_output_queue_max_size: + logger.info( + f"Increasing output queue size to {target_output_queue_max_size}, current size {self.output_queue.maxsize}, num_frames {num_frames}" + ) + + old_queue = self.output_queue + self.output_queue = queue.Queue(maxsize=target_output_queue_max_size) + while not old_queue.empty(): + try: + frame = old_queue.get_nowait() + self.output_queue.put_nowait(frame) + except queue.Empty: + break + + for frame in output: + try: + self.output_queue.put_nowait(frame) + except queue.Full: + logger.warning("Output queue full, dropping processed frame") + self._calculate_pipeline_fps(start_time, num_frames) + continue + + self._calculate_pipeline_fps(start_time, num_frames) + except Exception as e: + if self._is_recoverable(e): + logger.error(f"Error processing chunk: {e}", exc_info=True) + else: + raise e + + self.is_prepared = True + + def prepare_chunk(self, chunk_size: int) -> list[torch.Tensor]: + """Sample frames uniformly from the buffer and convert them to tensors.""" + step = len(self.frame_buffer) / chunk_size + indices = [round(i * step) for i in range(chunk_size)] + video_frames_data = [self.frame_buffer[i] for i in indices] + + last_idx = indices[-1] + for _ in range(last_idx + 1): + self.frame_buffer.popleft() + + tensor_frames = [] + for frame_data in video_frames_data: + # Convert frame data to tensor + frame_array = frame_data.get("array") + if frame_array is not None: + tensor = torch.from_numpy(frame_array).float().unsqueeze(0) + tensor_frames.append(tensor) + + return tensor_frames + + @staticmethod + def _is_recoverable(error: Exception) -> bool: + """Check if an error is recoverable.""" + if isinstance(error, torch.cuda.OutOfMemoryError): + return False + return True + + +def _load_pipeline_implementation(pipeline_id: str, load_params: dict | None = None): + """Load a pipeline in the worker process. + + This is the same logic as in PipelineManager._load_pipeline_implementation + but runs in a separate process for proper VRAM isolation. 
+ """ + if pipeline_id == "streamdiffusionv2": + from scope.core.pipelines import ( + StreamDiffusionV2Pipeline, + ) + + from scope.server.models_config import get_model_file_path, get_models_dir + + models_dir = get_models_dir() + config = OmegaConf.create( + { + "model_dir": str(models_dir), + "generator_path": str( + get_model_file_path( + "StreamDiffusionV2/wan_causal_dmd_v2v/model.pt" + ) + ), + "text_encoder_path": str( + get_model_file_path( + "WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors" + ) + ), + "tokenizer_path": str( + get_model_file_path("Wan2.1-T2V-1.3B/google/umt5-xxl") + ), + } + ) + + # Apply load parameters (resolution, seed, LoRAs) to config + height = 512 + width = 512 + seed = 42 + loras = None + lora_merge_mode = "permanent_merge" + + if load_params: + height = load_params.get("height", 512) + width = load_params.get("width", 512) + seed = load_params.get("seed", 42) + loras = load_params.get("loras", None) + lora_merge_mode = load_params.get("lora_merge_mode", lora_merge_mode) + + config["height"] = height + config["width"] = width + config["seed"] = seed + if loras: + config["loras"] = loras + config["_lora_merge_mode"] = lora_merge_mode + + pipeline = StreamDiffusionV2Pipeline( + config, device=torch.device("cuda"), dtype=torch.bfloat16 + ) + logger.info("StreamDiffusionV2 pipeline initialized in worker process") + return pipeline + + elif pipeline_id == "passthrough": + from scope.core.pipelines import PassthroughPipeline + + # Use load parameters for resolution, default to 512x512 + height = 512 + width = 512 + if load_params: + height = load_params.get("height", 512) + width = load_params.get("width", 512) + + pipeline = PassthroughPipeline( + height=height, + width=width, + device=torch.device("cuda"), + dtype=torch.bfloat16, + ) + logger.info("Passthrough pipeline initialized in worker process") + return pipeline + + elif pipeline_id == "longlive": + from scope.core.pipelines import LongLivePipeline + + from scope.server.models_config import get_model_file_path, get_models_dir + + config = OmegaConf.create( + { + "model_dir": str(get_models_dir()), + "generator_path": str( + get_model_file_path("LongLive-1.3B/models/longlive_base.pt") + ), + "lora_path": str( + get_model_file_path("LongLive-1.3B/models/lora.pt") + ), + "text_encoder_path": str( + get_model_file_path( + "WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors" + ) + ), + "tokenizer_path": str( + get_model_file_path("Wan2.1-T2V-1.3B/google/umt5-xxl") + ), + } + ) + + # Apply load parameters (resolution, seed, LoRAs) to config + height = 320 + width = 576 + seed = 42 + loras = None + lora_merge_mode = "permanent_merge" + + if load_params: + height = load_params.get("height", 320) + width = load_params.get("width", 576) + seed = load_params.get("seed", 42) + loras = load_params.get("loras", None) + lora_merge_mode = load_params.get("lora_merge_mode", lora_merge_mode) + + config["height"] = height + config["width"] = width + config["seed"] = seed + if loras: + config["loras"] = loras + config["_lora_merge_mode"] = lora_merge_mode + + pipeline = LongLivePipeline( + config, device=torch.device("cuda"), dtype=torch.bfloat16 + ) + logger.info("LongLive pipeline initialized in worker process") + return pipeline + + elif pipeline_id == "krea-realtime-video": + from scope.core.pipelines import ( + KreaRealtimeVideoPipeline, + ) + + from scope.server.models_config import get_model_file_path, get_models_dir + + config = OmegaConf.create( + { + "model_dir": str(get_models_dir()), + "generator_path": str( + 
get_model_file_path( + "krea-realtime-video/krea-realtime-video-14b.safetensors" + ) + ), + "text_encoder_path": str( + get_model_file_path( + "WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors" + ) + ), + "tokenizer_path": str( + get_model_file_path("Wan2.1-T2V-1.3B/google/umt5-xxl") + ), + "vae_path": str( + get_model_file_path("Wan2.1-T2V-1.3B/Wan2.1_VAE.pth") + ), + } + ) + + # Apply load parameters (resolution, seed, LoRAs) to config + height = 512 + width = 512 + seed = 42 + loras = None + lora_merge_mode = "permanent_merge" + quantization = None + + if load_params: + height = load_params.get("height", 512) + width = load_params.get("width", 512) + seed = load_params.get("seed", 42) + loras = load_params.get("loras", None) + lora_merge_mode = load_params.get("lora_merge_mode", lora_merge_mode) + quantization = load_params.get("quantization", None) + + config["height"] = height + config["width"] = width + config["seed"] = seed + if loras: + config["loras"] = loras + config["_lora_merge_mode"] = lora_merge_mode + + pipeline = KreaRealtimeVideoPipeline( + config, + quantization=quantization, + # Only compile diffusion model for hopper right now + compile=any( + x in torch.cuda.get_device_name(0).lower() + for x in ("h100", "hopper") + ), + device=torch.device("cuda"), + dtype=torch.bfloat16, + ) + logger.info("krea-realtime-video pipeline initialized in worker process") + return pipeline + + else: + raise ValueError(f"Invalid pipeline ID: {pipeline_id}") + + +def pipeline_worker_process(command_queue: mp.Queue, response_queue: mp.Queue): + """Main worker process function that handles pipeline and FrameProcessor operations. + + This process runs in isolation and can be killed to ensure proper VRAM cleanup. + + Args: + command_queue: Queue for receiving commands from main process + response_queue: Queue for sending responses back to main process + """ + # Set up logging for worker process + logging.basicConfig( + level=logging.INFO, + format=f"%(asctime)s - [Worker-{os.getpid()}] - %(name)s - %(levelname)s - %(message)s", + ) + + logger.info(f"Pipeline worker process started (PID: {os.getpid()})") + + pipeline = None + pipeline_id = None + frame_processors: dict[str, WorkerFrameProcessor] = {} + + try: + while True: + try: + # Wait for commands from main process + command_data = command_queue.get() + + if command_data is None: + logger.info("Received shutdown signal") + break + + command = command_data.get("command") + + if command == WorkerCommand.LOAD_PIPELINE.value: + # Load pipeline + pipeline_id_to_load = command_data.get("pipeline_id") + load_params = command_data.get("load_params") + + logger.info( + f"Loading pipeline: {pipeline_id_to_load} with params: {load_params}" + ) + + try: + # Unload existing pipeline if any + if pipeline is not None: + logger.info(f"Unloading existing pipeline: {pipeline_id}") + # Stop all frame processors + for fp_id, fp in list(frame_processors.items()): + fp.stop() + del frame_processors[fp_id] + del pipeline + pipeline = None + pipeline_id = None + + # Load new pipeline + pipeline = _load_pipeline_implementation( + pipeline_id_to_load, load_params + ) + pipeline_id = pipeline_id_to_load + + response_queue.put( + { + "status": WorkerResponse.SUCCESS.value, + "message": f"Pipeline {pipeline_id} loaded successfully", + } + ) + logger.info(f"Pipeline {pipeline_id} loaded successfully") + + except Exception as e: + error_msg = ( + f"Failed to load pipeline: {str(e)}\n{traceback.format_exc()}" + ) + logger.error(error_msg) + response_queue.put( + 
{"status": WorkerResponse.ERROR.value, "error": error_msg} + ) + + elif command == WorkerCommand.CREATE_FRAME_PROCESSOR.value: + # Create a new FrameProcessor instance + if pipeline is None: + response_queue.put( + { + "status": WorkerResponse.ERROR.value, + "error": "Pipeline not loaded", + } + ) + continue + + try: + fp_id = command_data.get("frame_processor_id") + initial_parameters = command_data.get("initial_parameters", {}) + + if fp_id in frame_processors: + logger.warning(f"FrameProcessor {fp_id} already exists, stopping old one") + frame_processors[fp_id].stop() + + frame_processor = WorkerFrameProcessor( + pipeline=pipeline, + initial_parameters=initial_parameters, + ) + frame_processor.start() + frame_processors[fp_id] = frame_processor + + response_queue.put( + { + "status": WorkerResponse.FRAME_PROCESSOR_CREATED.value, + "frame_processor_id": fp_id, + } + ) + logger.info(f"Created FrameProcessor {fp_id}") + + except Exception as e: + error_msg = ( + f"Failed to create FrameProcessor: {str(e)}\n{traceback.format_exc()}" + ) + logger.error(error_msg) + response_queue.put( + {"status": WorkerResponse.ERROR.value, "error": error_msg} + ) + + elif command == WorkerCommand.DESTROY_FRAME_PROCESSOR.value: + # Destroy a FrameProcessor instance + fp_id = command_data.get("frame_processor_id") + if fp_id in frame_processors: + frame_processors[fp_id].stop() + del frame_processors[fp_id] + response_queue.put( + { + "status": WorkerResponse.SUCCESS.value, + "message": f"FrameProcessor {fp_id} destroyed", + } + ) + logger.info(f"Destroyed FrameProcessor {fp_id}") + else: + response_queue.put( + { + "status": WorkerResponse.ERROR.value, + "error": f"FrameProcessor {fp_id} not found", + } + ) + + elif command == WorkerCommand.PUT_FRAME.value: + # Put a frame into a FrameProcessor + fp_id = command_data.get("frame_processor_id") + frame_data = command_data.get("frame_data") + + if fp_id not in frame_processors: + response_queue.put( + { + "status": WorkerResponse.ERROR.value, + "error": f"FrameProcessor {fp_id} not found", + } + ) + continue + + try: + success = frame_processors[fp_id].put(frame_data) + # Don't send response for every frame to avoid queue buildup + # Only send response if there's an error + if not success: + response_queue.put( + { + "status": WorkerResponse.ERROR.value, + "error": "Failed to put frame", + } + ) + except Exception as e: + error_msg = f"Error putting frame: {str(e)}" + logger.error(error_msg) + response_queue.put( + {"status": WorkerResponse.ERROR.value, "error": error_msg} + ) + + elif command == WorkerCommand.GET_FRAME.value: + # Get a processed frame from a FrameProcessor + fp_id = command_data.get("frame_processor_id") + + if fp_id not in frame_processors: + response_queue.put( + { + "status": WorkerResponse.ERROR.value, + "error": f"FrameProcessor {fp_id} not found", + } + ) + continue + + try: + frame_data = frame_processors[fp_id].get() + if frame_data is not None: + response_queue.put( + { + "status": WorkerResponse.FRAME.value, + "frame_data": frame_data, + } + ) + else: + # No frame available - send empty response + response_queue.put( + { + "status": WorkerResponse.RESULT.value, + "result": None, + } + ) + except Exception as e: + error_msg = f"Error getting frame: {str(e)}" + logger.error(error_msg) + response_queue.put( + {"status": WorkerResponse.ERROR.value, "error": error_msg} + ) + + elif command == WorkerCommand.UPDATE_PARAMETERS.value: + # Update parameters for a FrameProcessor + fp_id = command_data.get("frame_processor_id") + parameters = 
command_data.get("parameters", {}) + + if fp_id not in frame_processors: + response_queue.put( + { + "status": WorkerResponse.ERROR.value, + "error": f"FrameProcessor {fp_id} not found", + } + ) + continue + + try: + frame_processors[fp_id].update_parameters(parameters) + # Don't send response for parameter updates to avoid queue buildup + except Exception as e: + error_msg = f"Error updating parameters: {str(e)}" + logger.error(error_msg) + response_queue.put( + {"status": WorkerResponse.ERROR.value, "error": error_msg} + ) + + elif command == WorkerCommand.GET_FPS.value: + # Get current FPS from a FrameProcessor + fp_id = command_data.get("frame_processor_id") + + if fp_id not in frame_processors: + response_queue.put( + { + "status": WorkerResponse.ERROR.value, + "error": f"FrameProcessor {fp_id} not found", + } + ) + continue + + try: + fps = frame_processors[fp_id].get_current_pipeline_fps() + response_queue.put( + { + "status": WorkerResponse.RESULT.value, + "result": fps, + } + ) + except Exception as e: + error_msg = f"Error getting FPS: {str(e)}" + logger.error(error_msg) + response_queue.put( + {"status": WorkerResponse.ERROR.value, "error": error_msg} + ) + + elif command == WorkerCommand.SHUTDOWN.value: + logger.info("Received shutdown command") + break + + except Exception as e: + error_msg = ( + f"Error processing command: {str(e)}\n{traceback.format_exc()}" + ) + logger.error(error_msg) + response_queue.put( + {"status": WorkerResponse.ERROR.value, "error": error_msg} + ) + + finally: + # Cleanup on exit + logger.info("Cleaning up worker process...") + # Stop all frame processors + for fp_id, fp in list(frame_processors.items()): + fp.stop() + frame_processors.clear() + if pipeline is not None: + del pipeline + + logger.info("Pipeline worker process shutting down") + + +def _serialize_tensors(obj): + """Serialize torch tensors for inter-process communication. + + For CUDA tensors, we move them to CPU first for serialization. 
+    """
+    if isinstance(obj, torch.Tensor):
+        # Move to CPU for serialization
+        return {"__tensor__": True, "data": obj.cpu().numpy()}
+    elif isinstance(obj, dict):
+        return {k: _serialize_tensors(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [_serialize_tensors(item) for item in obj]
+    elif isinstance(obj, tuple):
+        return tuple(_serialize_tensors(item) for item in obj)
+    else:
+        return obj
+
+
+def _deserialize_tensors(obj):
+    """Deserialize torch tensors from inter-process communication."""
+    if isinstance(obj, dict):
+        if obj.get("__tensor__"):
+            # Reconstruct tensor from numpy array
+            return torch.from_numpy(obj["data"])
+        return {k: _deserialize_tensors(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [_deserialize_tensors(item) for item in obj]
+    elif isinstance(obj, tuple):
+        return tuple(_deserialize_tensors(item) for item in obj)
+    return obj
diff --git a/src/scope/server/tracks.py b/src/scope/server/tracks.py
index 128c6fa0..d8dad49f 100644
--- a/src/scope/server/tracks.py
+++ b/src/scope/server/tracks.py
@@ -3,12 +3,12 @@
 import logging
 import threading
 import time
+import uuid
 
 from aiortc import MediaStreamTrack
 from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE, MediaStreamError
 from av import VideoFrame
 
-from .frame_processor import FrameProcessor
 from .pipeline_manager import PipelineManager
 
 logger = logging.getLogger(__name__)
@@ -33,6 +33,7 @@ def __init__(
         self.frame_ptime = 1.0 / fps
 
         self.frame_processor = None
+        self.frame_processor_id = str(uuid.uuid4())
         self.input_task = None
         self.input_task_running = False
         self._paused = False
@@ -86,10 +87,10 @@ async def next_timestamp(self) -> tuple[int, fractions.Fraction]:
 
     def initialize_output_processing(self):
         if not self.frame_processor:
-            self.frame_processor = FrameProcessor(
-                pipeline_manager=self.pipeline_manager,
+            # Create FrameProcessor in worker process via PipelineManager
+            self.frame_processor = self.pipeline_manager.create_frame_processor(
+                frame_processor_id=self.frame_processor_id,
                 initial_parameters=self.initial_parameters,
-                notification_callback=self.notification_callback,
             )
 
             self.frame_processor.start()
@@ -160,5 +161,6 @@ async def stop(self):
 
         if self.frame_processor is not None:
             self.frame_processor.stop()
+            self.frame_processor = None
 
         await super().stop()