From c5cdb83d5d36ced7d5b247fe8c3404fbbd788bdb Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Tue, 25 Nov 2025 20:18:47 -0800 Subject: [PATCH 01/13] add worker shutdown options --- materializationengine/celery_worker.py | 76 ++++++++++++++++++++++++++ materializationengine/config.py | 4 ++ 2 files changed, 80 insertions(+) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index e8ce9bf4..d3fbad14 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -1,11 +1,15 @@ import datetime import logging import os +import signal import sys +import threading +import time import warnings from typing import Any, Callable, Dict import redis +from celery import signals from celery.app.builtins import add_backend_cleanup_task from celery.schedules import crontab from celery.signals import after_setup_logger @@ -22,6 +26,61 @@ celery_logger = get_task_logger(__name__) +_task_execution_count = 0 +_shutdown_requested = False + + +def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None: + """Delay and then terminate the worker process.""" + # Delay slightly so task result propagation finishes + time.sleep(max(delay_seconds, 0)) + celery_logger.info( + "Auto-shutdown: terminating worker PID %s after %s tasks", + os.getpid(), + observed_count, + ) + try: + os.kill(os.getpid(), signal.SIGTERM) + except Exception as exc: # pragma: no cover - best-effort shutdown + celery_logger.error("Failed to terminate worker: %s", exc) + + +def _auto_shutdown_handler(sender=None, **kwargs): + """Trigger worker shutdown after configurable task count when enabled.""" + if not celery.conf.get("worker_autoshutdown_enabled", False): + return + + max_tasks = celery.conf.get("worker_autoshutdown_max_tasks", 1) + if max_tasks <= 0: + return + + global _task_execution_count, _shutdown_requested + if _shutdown_requested: + return + + _task_execution_count += 1 + + if _task_execution_count < max_tasks: + return + + _shutdown_requested = True + delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2) + celery_logger.info( + "Auto-shutdown triggered after %s tasks; terminating in %ss", + _task_execution_count, + delay, + ) + shutdown_thread = threading.Thread( + target=_request_worker_shutdown, + args=(delay, _task_execution_count), + daemon=True, + ) + shutdown_thread.start() + + +signals.task_postrun.connect(_auto_shutdown_handler, weak=False) + + def create_celery(app=None): celery.conf.broker_url = app.config["CELERY_BROKER_URL"] celery.conf.result_backend = app.config["CELERY_RESULT_BACKEND"] @@ -31,6 +90,23 @@ def create_celery(app=None): } celery.conf.result_backend_transport_options = { "master_name": app.config["MASTER_NAME"] + + celery.conf.worker_autoshutdown_enabled = app.config.get( + "CELERY_WORKER_AUTOSHUTDOWN_ENABLED", False + ) + celery.conf.worker_autoshutdown_max_tasks = app.config.get( + "CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS", 1 + ) + celery.conf.worker_autoshutdown_delay_seconds = app.config.get( + "CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS", 2 + ) + + if celery.conf.worker_autoshutdown_enabled: + celery_logger.info( + "Worker auto-shutdown enabled: max_tasks=%s delay=%ss", + celery.conf.worker_autoshutdown_max_tasks, + celery.conf.worker_autoshutdown_delay_seconds, + ) } # Configure Celery and related loggers log_level = app.config["LOGGING_LEVEL"] diff --git a/materializationengine/config.py b/materializationengine/config.py index 4706a51a..d3fc5157 100644 --- a/materializationengine/config.py +++ 
b/materializationengine/config.py @@ -48,6 +48,10 @@ class BaseConfig: MERGE_TABLES = True AUTH_SERVICE_NAMESPACE = "datastack" + CELERY_WORKER_AUTOSHUTDOWN_ENABLED = False + CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS = 1 + CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS = 2 + REDIS_HOST="localhost" REDIS_PORT=6379 REDIS_PASSWORD="" From 130193c24e9efd0c363c1b1114ead818dec25ff8 Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Wed, 26 Nov 2025 08:10:00 -0800 Subject: [PATCH 02/13] fix syntax --- materializationengine/celery_worker.py | 32 +++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index d3fbad14..591f9d64 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -90,24 +90,24 @@ def create_celery(app=None): } celery.conf.result_backend_transport_options = { "master_name": app.config["MASTER_NAME"] + } - celery.conf.worker_autoshutdown_enabled = app.config.get( - "CELERY_WORKER_AUTOSHUTDOWN_ENABLED", False - ) - celery.conf.worker_autoshutdown_max_tasks = app.config.get( - "CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS", 1 - ) - celery.conf.worker_autoshutdown_delay_seconds = app.config.get( - "CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS", 2 - ) + celery.conf.worker_autoshutdown_enabled = app.config.get( + "CELERY_WORKER_AUTOSHUTDOWN_ENABLED", False + ) + celery.conf.worker_autoshutdown_max_tasks = app.config.get( + "CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS", 1 + ) + celery.conf.worker_autoshutdown_delay_seconds = app.config.get( + "CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS", 2 + ) - if celery.conf.worker_autoshutdown_enabled: - celery_logger.info( - "Worker auto-shutdown enabled: max_tasks=%s delay=%ss", - celery.conf.worker_autoshutdown_max_tasks, - celery.conf.worker_autoshutdown_delay_seconds, - ) - } + if celery.conf.worker_autoshutdown_enabled: + celery_logger.info( + "Worker auto-shutdown enabled: max_tasks=%s delay=%ss", + celery.conf.worker_autoshutdown_max_tasks, + celery.conf.worker_autoshutdown_delay_seconds, + ) # Configure Celery and related loggers log_level = app.config["LOGGING_LEVEL"] celery_logger.setLevel(log_level) From d533b4004d2bcdddfc6bb6257cbf8dce0242035a Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Wed, 26 Nov 2025 08:10:26 -0800 Subject: [PATCH 03/13] Update materializationengine/config.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- materializationengine/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/materializationengine/config.py b/materializationengine/config.py index d3fc5157..9ee098a1 100644 --- a/materializationengine/config.py +++ b/materializationengine/config.py @@ -52,9 +52,9 @@ class BaseConfig: CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS = 1 CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS = 2 - REDIS_HOST="localhost" - REDIS_PORT=6379 - REDIS_PASSWORD="" + REDIS_HOST = "localhost" + REDIS_PORT = 6379 + REDIS_PASSWORD = "" SESSION_TYPE = "redis" PERMANENT_SESSION_LIFETIME = timedelta(hours=24) SESSION_PREFIX = "annotation_upload_" From 87364732e74ee52e997863fe9c0b2a292a87a490 Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Thu, 11 Dec 2025 08:48:10 -0800 Subject: [PATCH 04/13] trying to fix schedule registration --- materializationengine/celery_worker.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index 591f9d64..8d3c46ca 100644 
--- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -178,6 +178,15 @@ def __call__(self, *args, **kwargs): if os.environ.get("SLACK_WEBHOOK"): celery.Task.on_failure = post_to_slack_on_task_failure + # Manually trigger setup_periodic_tasks to ensure it runs with the correct configuration + # The signal handler may have fired before beat_schedules was set, so we call it explicitly here + try: + setup_periodic_tasks(celery) + celery_logger.info("Manually triggered setup_periodic_tasks after create_celery") + except Exception as e: + celery_logger.warning(f"Error manually triggering setup_periodic_tasks: {e}") + # Don't fail if this doesn't work - the signal handler should still work + return celery @@ -233,8 +242,18 @@ def setup_periodic_tasks(sender, **kwargs): celery_logger.debug(f"beat_schedules type: {type(beat_schedules)}, length: {len(beat_schedules) if isinstance(beat_schedules, (list, dict)) else 'N/A'}") if not beat_schedules: - celery_logger.info("No periodic tasks configured.") + celery_logger.debug("No periodic tasks configured yet (beat_schedules empty). Will retry after create_celery.") return + + # Check if tasks from beat_schedules are already registered to avoid duplicates + existing_schedule = sender.conf.get("beat_schedule", {}) + if existing_schedule: + # Check if any of our scheduled tasks are already registered + schedule_names = [s.get("name") for s in beat_schedules if isinstance(s, dict)] + already_registered = any(name in existing_schedule for name in schedule_names if name) + if already_registered: + celery_logger.debug("Periodic tasks already registered in beat_schedule, skipping duplicate registration.") + return try: schedules = CeleryBeatSchema(many=True).load(beat_schedules) except ValidationError as validation_error: From 93dc213537d510327eb528775a8ec06b3c2a486f Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Thu, 11 Dec 2025 11:44:13 -0800 Subject: [PATCH 05/13] trying a different shutdown option --- materializationengine/celery_worker.py | 31 +++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index 8d3c46ca..3e491c71 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -66,10 +66,39 @@ def _auto_shutdown_handler(sender=None, **kwargs): _shutdown_requested = True delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2) celery_logger.info( - "Auto-shutdown triggered after %s tasks; terminating in %ss", + "Auto-shutdown triggered after %s tasks; stopping consumer and terminating in %ss", _task_execution_count, delay, ) + + # Immediately stop accepting new tasks by canceling the consumer + # This prevents the worker from picking up new tasks during the shutdown delay + try: + # Get queue name from config or use default + queue_name = celery.conf.get("task_default_queue") or celery.conf.get("task_routes", {}).get("*", {}).get("queue", "celery") + # Get worker hostname from the task sender or use current worker's hostname + worker_hostname = None + if hasattr(sender, 'hostname'): + worker_hostname = sender.hostname + elif hasattr(celery, 'control'): + # Try to get hostname from current worker + try: + from celery import current_app + inspect = current_app.control.inspect() + active_workers = inspect.active() if inspect else {} + if active_workers: + worker_hostname = list(active_workers.keys())[0] + except Exception: + pass + + if worker_hostname and 
queue_name: + celery_logger.info("Canceling consumer for queue '%s' on worker '%s'", queue_name, worker_hostname) + celery.control.cancel_consumer(queue_name, destination=[worker_hostname]) + else: + celery_logger.warning("Could not determine worker hostname or queue name for consumer cancellation") + except Exception as exc: + celery_logger.warning("Failed to cancel consumer during shutdown: %s", exc) + shutdown_thread = threading.Thread( target=_request_worker_shutdown, args=(delay, _task_execution_count), From bd5ef8af60d16b081f9ee059c06438df783c37bc Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Thu, 11 Dec 2025 19:59:05 -0800 Subject: [PATCH 06/13] adding idle timeout --- materializationengine/celery_worker.py | 93 +++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index 3e491c71..953909db 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -28,6 +28,8 @@ _task_execution_count = 0 _shutdown_requested = False +_last_task_time = None +_worker_start_time = None def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None: @@ -54,11 +56,12 @@ def _auto_shutdown_handler(sender=None, **kwargs): if max_tasks <= 0: return - global _task_execution_count, _shutdown_requested + global _task_execution_count, _shutdown_requested, _last_task_time if _shutdown_requested: return _task_execution_count += 1 + _last_task_time = time.time() # Update last task time if _task_execution_count < max_tasks: return @@ -107,7 +110,86 @@ def _auto_shutdown_handler(sender=None, **kwargs): shutdown_thread.start() +def _monitor_idle_timeout(): + """Monitor worker idle time and shutdown if idle timeout exceeded.""" + idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0) + if idle_timeout_seconds <= 0: + return # Idle timeout not enabled + + check_interval = min(30, idle_timeout_seconds / 4) # Check every 30s or 1/4 of timeout, whichever is smaller + + global _last_task_time, _worker_start_time, _shutdown_requested + + while not _shutdown_requested: + time.sleep(check_interval) + + if _shutdown_requested: + break + + current_time = time.time() + + # If we've processed at least one task, use last task time + # Otherwise, use worker start time + if _last_task_time is not None: + idle_duration = current_time - _last_task_time + reference_time = _last_task_time + elif _worker_start_time is not None: + idle_duration = current_time - _worker_start_time + reference_time = _worker_start_time + else: + continue # Haven't started yet + + if idle_duration >= idle_timeout_seconds: + celery_logger.info( + "Idle timeout exceeded: worker has been idle for %.1f seconds (timeout: %d seconds). 
Shutting down.", + idle_duration, + idle_timeout_seconds, + ) + _shutdown_requested = True + # Cancel consumer to stop accepting new tasks + try: + queue_name = celery.conf.get("task_default_queue") or celery.conf.get("task_routes", {}).get("*", {}).get("queue", "celery") + from celery import current_app + inspect = current_app.control.inspect() + active_workers = inspect.active() if inspect else {} + if active_workers: + worker_hostname = list(active_workers.keys())[0] + celery_logger.info("Canceling consumer for queue '%s' on worker '%s'", queue_name, worker_hostname) + celery.control.cancel_consumer(queue_name, destination=[worker_hostname]) + except Exception as exc: + celery_logger.warning("Failed to cancel consumer during idle timeout shutdown: %s", exc) + + # Shutdown after a short delay + delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2) + shutdown_thread = threading.Thread( + target=_request_worker_shutdown, + args=(delay, 0), # 0 tasks since we're shutting down due to idle timeout + daemon=True, + ) + shutdown_thread.start() + break + + +def _worker_ready_handler(sender=None, **kwargs): + """Handle worker ready signal - start idle timeout monitor if enabled.""" + global _worker_start_time, _shutdown_requested + _worker_start_time = time.time() + + idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0) + if idle_timeout_seconds > 0: + celery_logger.info( + "Worker idle timeout enabled: %d seconds. Worker will shutdown if idle for this duration.", + idle_timeout_seconds, + ) + monitor_thread = threading.Thread( + target=_monitor_idle_timeout, + daemon=True, + ) + monitor_thread.start() + + signals.task_postrun.connect(_auto_shutdown_handler, weak=False) +signals.worker_ready.connect(_worker_ready_handler, weak=False) def create_celery(app=None): @@ -130,6 +212,9 @@ def create_celery(app=None): celery.conf.worker_autoshutdown_delay_seconds = app.config.get( "CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS", 2 ) + celery.conf.worker_idle_timeout_seconds = app.config.get( + "CELERY_WORKER_IDLE_TIMEOUT_SECONDS", 0 + ) if celery.conf.worker_autoshutdown_enabled: celery_logger.info( @@ -137,6 +222,12 @@ def create_celery(app=None): celery.conf.worker_autoshutdown_max_tasks, celery.conf.worker_autoshutdown_delay_seconds, ) + + if celery.conf.worker_idle_timeout_seconds > 0: + celery_logger.info( + "Worker idle timeout enabled: %s seconds", + celery.conf.worker_idle_timeout_seconds, + ) # Configure Celery and related loggers log_level = app.config["LOGGING_LEVEL"] celery_logger.setLevel(log_level) From ccaa28245447042a4d01582b62be399f56df945d Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Fri, 12 Dec 2025 04:50:48 -0800 Subject: [PATCH 07/13] make active tasks more robust to --- materializationengine/shared_tasks.py | 28 ++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/materializationengine/shared_tasks.py b/materializationengine/shared_tasks.py index 49ff88d4..6d8238b4 100644 --- a/materializationengine/shared_tasks.py +++ b/materializationengine/shared_tasks.py @@ -488,12 +488,26 @@ def check_if_task_is_running(task_name: str, worker_name_prefix: str) -> bool: inspector = celery.control.inspect() active_tasks_dict = inspector.active() - workflow_active_tasks = next( - v for k, v in active_tasks_dict.items() if worker_name_prefix in k - ) + # Handle case where inspector.active() returns None (no workers available or disconnected) + if active_tasks_dict is None: + celery_logger.warning( + "Unable to inspect active 
tasks: no workers available or workers disconnected" + ) + return False - for active_task in workflow_active_tasks: - if task_name in active_task.values(): - celery_logger.info(f"Task {task_name} is running...") - return True + # Find active tasks from workers matching the prefix + matching_active_tasks = [ + v for k, v in active_tasks_dict.items() if worker_name_prefix in k + ] + + # If no matching workers found, return False + if not matching_active_tasks: + return False + + # Check all active tasks from matching workers for the task + for workflow_active_tasks in matching_active_tasks: + for active_task in workflow_active_tasks: + if task_name in active_task.values(): + celery_logger.info(f"Task {task_name} is running...") + return True return False From 09877ae192194091f2e1a84da71ef46e0f386961 Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Sun, 14 Dec 2025 08:40:33 -0800 Subject: [PATCH 08/13] restructuring end of workflow logic to improve shutdown procedures --- materializationengine/database.py | 108 +++++++++++++++++- materializationengine/shared_tasks.py | 39 +++++++ .../workflows/ingest_new_annotations.py | 32 ++++-- .../workflows/update_database_workflow.py | 27 ++++- .../workflows/update_root_ids.py | 19 +-- 5 files changed, 199 insertions(+), 26 deletions(-) diff --git a/materializationengine/database.py b/materializationengine/database.py index 3c38ea60..29ccdf26 100644 --- a/materializationengine/database.py +++ b/materializationengine/database.py @@ -130,20 +130,54 @@ def session_scope(self, database_name: str): def cleanup(self): """Cleanup any remaining sessions and dispose of engine pools.""" + self.shutdown() + + def shutdown(self): + """Shutdown all database connections and dispose of all engines. + + This method: + 1. Closes all active sessions from session factories + 2. Removes all scoped sessions + 3. Disposes of all engine connection pools + 4. Clears all cached engines and session factories + + Should be called when the application is shutting down or when + you want to ensure all database connections are closed. 
+ """ + celery_logger.info("Shutting down DatabaseConnectionManager...") + + # First, close all active sessions from session factories for database_name, session_factory in list(self._session_factories.items()): try: + # Remove all scoped sessions (this closes them) session_factory.remove() + celery_logger.debug(f"Removed scoped sessions for {database_name}") except Exception as e: - celery_logger.error(f"Error cleaning up sessions for {database_name}: {e}") + celery_logger.warning(f"Error removing sessions for {database_name}: {e}") + # Then dispose of all engines (this closes all connections in the pool) for database_name, engine in list(self._engines.items()): try: + # Log pool status before disposal + pool = engine.pool + checked_out = pool.checkedout() + if checked_out > 0: + celery_logger.warning( + f"Disposing engine for {database_name} with {checked_out} " + f"checked-out connections" + ) + + # Dispose closes all connections in the pool engine.dispose() + celery_logger.debug(f"Disposed engine connection pool for {database_name}") except Exception as e: celery_logger.error(f"Error disposing engine for {database_name}: {e}") - + + # Clear all caches self._session_factories.clear() self._engines.clear() + + celery_logger.info("DatabaseConnectionManager shutdown complete") def log_pool_status(self, database_name: str): """Log current connection pool status.""" @@ -185,7 +219,77 @@ def _get_mat_client(self, database: str): return self._clients[database] def invalidate_cache(self): + """Invalidate the cache by clearing all client references. + + Note: This does NOT close database connections. Use shutdown() for that. + """ self._clients = {} + + def shutdown(self): + """Shutdown all database connections and clear the cache. + + This method: + 1. Closes all cached sessions in DynamicAnnotationInterface clients + 2. Disposes of underlying database engines if available + 3. Clears all cached clients + + Should be called when the application is shutting down or when + you want to ensure all database connections are closed. 
+ """ + celery_logger.info("Shutting down DynamicMaterializationCache...") + + for database, client in list(self._clients.items()): + try: + # Close the cached session if it exists + if hasattr(client, 'database'): + # Close the session + if hasattr(client.database, 'close_session'): + try: + client.database.close_session() + celery_logger.debug(f"Closed session for {database}") + except Exception as e: + celery_logger.warning( + f"Error closing session for {database}: {e}" + ) + + # Dispose of the engine if it exists + if hasattr(client.database, 'engine'): + try: + engine = client.database.engine + if engine: + pool = engine.pool + checked_out = pool.checkedout() + if checked_out > 0: + celery_logger.warning( + f"Disposing engine for {database} with " + f"{checked_out} checked-out connections" + ) + engine.dispose() + celery_logger.debug(f"Disposed engine for {database}") + except Exception as e: + celery_logger.warning( + f"Error disposing engine for {database}: {e}" + ) + + # Also try to close any cached session directly + if hasattr(client.database, '_cached_session'): + try: + cached_session = client.database._cached_session + if cached_session: + cached_session.close() + celery_logger.debug(f"Closed cached session for {database}") + except Exception as e: + celery_logger.warning( + f"Error closing cached session for {database}: {e}" + ) + + except Exception as e: + celery_logger.error(f"Error shutting down client for {database}: {e}") + + # Clear all clients + self._clients.clear() + + celery_logger.info("DynamicMaterializationCache shutdown complete") dynamic_annotation_cache = DynamicMaterializationCache() diff --git a/materializationengine/shared_tasks.py b/materializationengine/shared_tasks.py index 6d8238b4..4f7e239d 100644 --- a/materializationengine/shared_tasks.py +++ b/materializationengine/shared_tasks.py @@ -115,6 +115,45 @@ def fin(self, *args, **kwargs): def workflow_complete(self, workflow_name): return f"{workflow_name} completed successfully" + +@celery.task(name="orchestration:cleanup_and_shutdown", acks_late=True, bind=True) +def cleanup_and_shutdown(self, *args, **kwargs): + """Final cleanup task that ensures all resources are closed before shutdown. 
+ + This task should be the last task in a workflow chain to ensure: + - Database connection pools are closed + - All resources are properly released + - The worker can safely shut down after this task completes + + Returns: + bool: Always returns True to indicate successful cleanup + """ + celery_logger.info("Performing final cleanup before shutdown") + + # Shutdown all database connections + try: + from materializationengine.database import db_manager, dynamic_annotation_cache + + # Shutdown DatabaseConnectionManager (closes all SQLAlchemy engines and sessions) + try: + db_manager.shutdown() + celery_logger.info("DatabaseConnectionManager shutdown complete") + except Exception as e: + celery_logger.warning(f"Error shutting down DatabaseConnectionManager: {e}") + + # Shutdown DynamicMaterializationCache (closes all DynamicAnnotationInterface connections) + try: + dynamic_annotation_cache.shutdown() + celery_logger.info("DynamicMaterializationCache shutdown complete") + except Exception as e: + celery_logger.warning(f"Error shutting down DynamicMaterializationCache: {e}") + + celery_logger.info("All database connections closed - cleanup complete") + except Exception as e: + celery_logger.error(f"Error during cleanup: {e}") + + return True + def get_materialization_info( datastack_info: dict, analysis_version: int = None, diff --git a/materializationengine/workflows/ingest_new_annotations.py b/materializationengine/workflows/ingest_new_annotations.py index e45ad4ad..95ca0704 100644 --- a/materializationengine/workflows/ingest_new_annotations.py +++ b/materializationengine/workflows/ingest_new_annotations.py @@ -608,7 +608,7 @@ def set_root_id_to_none_task( def ingest_new_annotations_workflow(mat_metadata: dict): """Celery workflow to ingest new annotations. In addition, it will create missing segmentation data table if it does not exist. - Returns celery chain primitive. + Returns celery chain primitive without executing it. 
Workflow: - Create linked segmentation table if not exists @@ -622,17 +622,26 @@ def ingest_new_annotations_workflow(mat_metadata: dict): annotation_chunks (List[int]): list of annotation primary key ids Returns: - chain: chain of celery tasks + chain: chain of celery tasks (not executed - returns signature only) """ - celery_logger.info("Ingesting new annotations...") + celery_logger.info("Preparing ingest new annotations workflow...") + + # Skip large tables if mat_metadata["row_count"] >= 1_000_000: + celery_logger.info(f"Skipping table with {mat_metadata['row_count']} rows (>= 1,000,000)") return fin.si() + + # Generate chunks synchronously (lightweight operation) annotation_chunks = generate_chunked_model_ids(mat_metadata) - table_created = create_missing_segmentation_table(mat_metadata) - if table_created: - celery_logger.info(f'Table created: {mat_metadata["segmentation_table_name"]}') - + + if not annotation_chunks: + celery_logger.info("No annotation chunks to process") + return fin.si() + + # Build the workflow chain - create table first, then process chunks + # The create_missing_segmentation_table task will be executed as part of the chain ingest_workflow = chain( + create_missing_segmentation_table.si(mat_metadata), chord( [ chain( @@ -642,10 +651,11 @@ def ingest_new_annotations_workflow(mat_metadata: dict): ], fin.si(), ) - ).apply_async() - tasks_completed = monitor_workflow_state(ingest_workflow) - if tasks_completed: - return fin.si() + ) + + # Return the chain signature - don't execute it here + # The caller will execute it as part of a larger chain + return ingest_workflow @celery.task(name="workflow:create_missing_segmentation_table", acks_late=True) diff --git a/materializationengine/workflows/update_database_workflow.py b/materializationengine/workflows/update_database_workflow.py index 19eba84a..28c8e3b7 100644 --- a/materializationengine/workflows/update_database_workflow.py +++ b/materializationengine/workflows/update_database_workflow.py @@ -8,9 +8,9 @@ from materializationengine.celery_init import celery from materializationengine.shared_tasks import ( get_materialization_info, - monitor_workflow_state, workflow_complete, fin, + cleanup_and_shutdown, ) from materializationengine.task import LockedTask from materializationengine.utils import get_config_param @@ -96,12 +96,27 @@ def update_database_workflow(self, datastack_info: dict, **kwargs): else: update_live_database_workflow.append(fin.si()) + # Build the complete workflow chain + # The workflow functions now return task signatures (not executed) + # So we can chain them together and execute once run_update_database_workflow = chain( - *update_live_database_workflow, workflow_complete.si("update_root_ids") - ).apply_async(kwargs={"Datastack": datastack_info["datastack"]}) + *update_live_database_workflow, + workflow_complete.si("update_root_ids"), + cleanup_and_shutdown.si(), # Final cleanup task to close resources + ) + + # Execute the entire chain asynchronously + # All tasks in the chain will be tracked by Celery + # task_postrun will fire for each task, including the cleanup task + celery_logger.info("Executing update database workflow chain") + result = run_update_database_workflow.apply_async( + kwargs={"Datastack": datastack_info["datastack"]} + ) + + # Return the result ID - the chain will execute asynchronously + # The worker will track all tasks in the chain + celery_logger.info(f"Workflow chain started with root task ID: {result.id}") + return True except Exception as e: celery_logger.error(f"An 
error has occurred: {e}") raise e - tasks_completed = monitor_workflow_state(run_update_database_workflow) - if tasks_completed: - return True diff --git a/materializationengine/workflows/update_root_ids.py b/materializationengine/workflows/update_root_ids.py index aa497278..16080ef1 100644 --- a/materializationengine/workflows/update_root_ids.py +++ b/materializationengine/workflows/update_root_ids.py @@ -50,7 +50,7 @@ def expired_root_id_workflow(datastack_info: dict, **kwargs): def update_root_ids_workflow(mat_metadata: dict): """Celery workflow that updates expired root ids in a - segmentation table. + segmentation table. Returns celery chain primitive without executing it. Workflow: - Lookup supervoxel id associated with expired root id @@ -65,18 +65,23 @@ def update_root_ids_workflow(mat_metadata: dict): chunked_roots (List[int]): chunks of expired root ids to lookup Returns: - chain: chain of celery tasks + chain: chain of celery tasks (not executed - returns signature only) """ - celery_logger.info("Setup expired root id workflow...") + celery_logger.info("Preparing expired root id workflow...") + + # Generate chunks synchronously (lightweight operation) if mat_metadata.get("lookup_all_root_ids"): chunked_ids = generate_chunked_model_ids(mat_metadata) else: chunked_ids = get_expired_root_ids_from_pcg(mat_metadata) if not chunked_ids: + celery_logger.info("No expired root IDs to process") return fin.si() + # Build the workflow chain - don't execute it here + # The caller will execute it as part of a larger chain update_root_workflow = chain( chord( [ @@ -86,10 +91,10 @@ def update_root_ids_workflow(mat_metadata: dict): fin.si(), ), update_metadata.si(mat_metadata), - ).apply_async() - tasks_completed = monitor_workflow_state(update_root_workflow) - if tasks_completed: - return workflow_complete.si("update_root_ids_workflow") + ) + + # Return the chain signature - don't execute it here + return update_root_workflow def get_expired_root_ids_from_pcg(mat_metadata: dict, expired_chunk_size: int = 100): From e1d2e94edf43b61cd0cd337f4815c1ab4562628a Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Sun, 14 Dec 2025 12:31:55 -0800 Subject: [PATCH 09/13] trying to fix shutdown with completion logic on child tasks --- materializationengine/celery_worker.py | 160 ++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 5 deletions(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index 953909db..b975d53b 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -47,8 +47,72 @@ def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None: celery_logger.error("Failed to terminate worker: %s", exc) +def _has_active_workflow_tasks(completed_task_id: str, task_name: str = None) -> bool: + """Check if there are other active tasks that might be part of a workflow. + + When a task spawns a chain/chord, it returns immediately but the child tasks + continue running. This function checks if there are other active tasks that + might be part of the workflow. + + We add a small delay to handle the race condition where chain tasks haven't + started yet when the parent task completes. 
+ + Args: + completed_task_id: The task ID that just completed + task_name: Optional task name for logging + + Returns: + True if there are other active tasks (likely part of a workflow), False otherwise + """ + try: + import time + # Small delay to allow chain tasks to start + # This handles the race condition where apply_async() is called but + # the child tasks haven't been picked up by workers yet + time.sleep(0.5) + + from celery import current_app + inspect = current_app.control.inspect() + + # Get active tasks for all workers + active_tasks = inspect.active() + if not active_tasks: + return False + + # Check all workers (tasks in a chain might be on different workers) + total_other_tasks = 0 + for worker_name, tasks in active_tasks.items(): + # Filter out the task that just completed + other_tasks = [t for t in tasks if t.get('id') != completed_task_id] + total_other_tasks += len(other_tasks) + + if total_other_tasks > 0: + # There are other active tasks - this might be a workflow + celery_logger.debug( + "Found %d other active task(s) across workers (completed: %s, task: %s)", + total_other_tasks, + completed_task_id, + task_name or "unknown", + ) + return True + + return False + except Exception as e: + # If we can't check, assume no other tasks (safer to count than to skip) + celery_logger.debug(f"Could not check for active workflow tasks: {e}") + return False + + def _auto_shutdown_handler(sender=None, **kwargs): - """Trigger worker shutdown after configurable task count when enabled.""" + """Trigger worker shutdown after configurable task count when enabled. + + This handler uses a robust workflow detection mechanism: + - When a task completes, it checks if there are other active tasks on any worker + - If other tasks are active, it assumes this task started a workflow and skips counting + - Only tasks with no other active tasks are counted (workflow completion) + - This works for any workflow structure (chains, chords, groups) without hardcoding task names + - A small delay allows chain tasks to start before checking (handles race conditions) + """ if not celery.conf.get("worker_autoshutdown_enabled", False): return @@ -60,8 +124,43 @@ def _auto_shutdown_handler(sender=None, **kwargs): if _shutdown_requested: return + # Get task ID and name from kwargs (task_postrun provides task_id) + task_id = kwargs.get('task_id') + task_name = None + if sender: + if hasattr(sender, 'name'): + task_name = sender.name + elif hasattr(sender, 'task') and hasattr(sender.task, 'name'): + task_name = sender.task.name + + # Always update last task time when ANY task completes (workflow or child) + # This ensures idle timeout doesn't incorrectly trigger while workflows are running + _last_task_time = time.time() + + # If we have a task_id, check if there are other active tasks + # Tasks that spawn workflows (chains/chords) will have other tasks running + # We skip counting workflow starter tasks and only count when all tasks complete + if task_id: + has_other_tasks = _has_active_workflow_tasks(task_id, task_name) + if has_other_tasks: + celery_logger.debug( + "Skipping auto-shutdown count for %s (task_id: %s) - other tasks still active", + task_name or "unknown", + task_id, + ) + # Don't count this task, but we already updated _last_task_time above + return + + # This task has no active children, so it's a workflow completion _task_execution_count += 1 - _last_task_time = time.time() # Update last task time + + celery_logger.debug( + "Task completed (no active children): %s (task_id: %s, count: 
%d/%d)", + task_name or "unknown", + task_id or "unknown", + _task_execution_count, + max_tasks, + ) if _task_execution_count < max_tasks: return @@ -69,8 +168,9 @@ def _auto_shutdown_handler(sender=None, **kwargs): _shutdown_requested = True delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2) celery_logger.info( - "Auto-shutdown triggered after %s tasks; stopping consumer and terminating in %ss", + "Auto-shutdown triggered after %s tasks (last task: %s); stopping consumer and terminating in %ss", _task_execution_count, + task_name or "unknown", delay, ) @@ -111,7 +211,16 @@ def _auto_shutdown_handler(sender=None, **kwargs): def _monitor_idle_timeout(): - """Monitor worker idle time and shutdown if idle timeout exceeded.""" + """Monitor worker idle time and shutdown if idle timeout exceeded. + + This function checks both: + 1. Time since last task completion + 2. Whether there are any active tasks (to avoid shutting down during workflows) + + Only shuts down if BOTH conditions are met: + - Idle timeout exceeded + - No active tasks on any worker + """ idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0) if idle_timeout_seconds <= 0: return # Idle timeout not enabled @@ -140,8 +249,49 @@ def _monitor_idle_timeout(): continue # Haven't started yet if idle_duration >= idle_timeout_seconds: + # Before shutting down, verify there are no active tasks + # This prevents shutdown during workflows where tasks are waiting for children + try: + from celery import current_app + inspect = current_app.control.inspect() + active_tasks = inspect.active() + + has_active_tasks = False + if active_tasks: + # Check if there are any active tasks on any worker + for worker_name, tasks in active_tasks.items(): + if tasks: + has_active_tasks = True + celery_logger.debug( + "Idle timeout check: found %d active task(s) on worker %s, " + "not shutting down yet", + len(tasks), + worker_name, + ) + break + + if has_active_tasks: + # There are active tasks - don't shutdown, but continue monitoring + celery_logger.debug( + "Idle timeout exceeded (%.1f seconds) but active tasks detected, " + "continuing to monitor", + idle_duration, + ) + continue + + except Exception as exc: + # If we can't check for active tasks, be conservative and don't shutdown + celery_logger.warning( + "Could not check for active tasks during idle timeout: %s. " + "Not shutting down to be safe.", + exc, + ) + continue + + # Idle timeout exceeded AND no active tasks - safe to shutdown celery_logger.info( - "Idle timeout exceeded: worker has been idle for %.1f seconds (timeout: %d seconds). Shutting down.", + "Idle timeout exceeded: worker has been idle for %.1f seconds (timeout: %d seconds) " + "and no active tasks detected. 
Shutting down.", idle_duration, idle_timeout_seconds, ) From f02ae1607529736b5761a30ece3bffa183b03dd2 Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Sun, 14 Dec 2025 14:01:49 -0800 Subject: [PATCH 10/13] trying new shutdown options --- materializationengine/celery_worker.py | 160 +----------------- materializationengine/shared_tasks.py | 39 ----- .../workflows/ingest_new_annotations.py | 32 ++-- .../workflows/update_database_workflow.py | 27 +-- .../workflows/update_root_ids.py | 19 +-- 5 files changed, 29 insertions(+), 248 deletions(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index b975d53b..953909db 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -47,72 +47,8 @@ def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None: celery_logger.error("Failed to terminate worker: %s", exc) -def _has_active_workflow_tasks(completed_task_id: str, task_name: str = None) -> bool: - """Check if there are other active tasks that might be part of a workflow. - - When a task spawns a chain/chord, it returns immediately but the child tasks - continue running. This function checks if there are other active tasks that - might be part of the workflow. - - We add a small delay to handle the race condition where chain tasks haven't - started yet when the parent task completes. - - Args: - completed_task_id: The task ID that just completed - task_name: Optional task name for logging - - Returns: - True if there are other active tasks (likely part of a workflow), False otherwise - """ - try: - import time - # Small delay to allow chain tasks to start - # This handles the race condition where apply_async() is called but - # the child tasks haven't been picked up by workers yet - time.sleep(0.5) - - from celery import current_app - inspect = current_app.control.inspect() - - # Get active tasks for all workers - active_tasks = inspect.active() - if not active_tasks: - return False - - # Check all workers (tasks in a chain might be on different workers) - total_other_tasks = 0 - for worker_name, tasks in active_tasks.items(): - # Filter out the task that just completed - other_tasks = [t for t in tasks if t.get('id') != completed_task_id] - total_other_tasks += len(other_tasks) - - if total_other_tasks > 0: - # There are other active tasks - this might be a workflow - celery_logger.debug( - "Found %d other active task(s) across workers (completed: %s, task: %s)", - total_other_tasks, - completed_task_id, - task_name or "unknown", - ) - return True - - return False - except Exception as e: - # If we can't check, assume no other tasks (safer to count than to skip) - celery_logger.debug(f"Could not check for active workflow tasks: {e}") - return False - - def _auto_shutdown_handler(sender=None, **kwargs): - """Trigger worker shutdown after configurable task count when enabled. 
- - This handler uses a robust workflow detection mechanism: - - When a task completes, it checks if there are other active tasks on any worker - - If other tasks are active, it assumes this task started a workflow and skips counting - - Only tasks with no other active tasks are counted (workflow completion) - - This works for any workflow structure (chains, chords, groups) without hardcoding task names - - A small delay allows chain tasks to start before checking (handles race conditions) - """ + """Trigger worker shutdown after configurable task count when enabled.""" if not celery.conf.get("worker_autoshutdown_enabled", False): return @@ -124,43 +60,8 @@ def _auto_shutdown_handler(sender=None, **kwargs): if _shutdown_requested: return - # Get task ID and name from kwargs (task_postrun provides task_id) - task_id = kwargs.get('task_id') - task_name = None - if sender: - if hasattr(sender, 'name'): - task_name = sender.name - elif hasattr(sender, 'task') and hasattr(sender.task, 'name'): - task_name = sender.task.name - - # Always update last task time when ANY task completes (workflow or child) - # This ensures idle timeout doesn't incorrectly trigger while workflows are running - _last_task_time = time.time() - - # If we have a task_id, check if there are other active tasks - # Tasks that spawn workflows (chains/chords) will have other tasks running - # We skip counting workflow starter tasks and only count when all tasks complete - if task_id: - has_other_tasks = _has_active_workflow_tasks(task_id, task_name) - if has_other_tasks: - celery_logger.debug( - "Skipping auto-shutdown count for %s (task_id: %s) - other tasks still active", - task_name or "unknown", - task_id, - ) - # Don't count this task, but we already updated _last_task_time above - return - - # This task has no active children, so it's a workflow completion _task_execution_count += 1 - - celery_logger.debug( - "Task completed (no active children): %s (task_id: %s, count: %d/%d)", - task_name or "unknown", - task_id or "unknown", - _task_execution_count, - max_tasks, - ) + _last_task_time = time.time() # Update last task time if _task_execution_count < max_tasks: return @@ -168,9 +69,8 @@ def _auto_shutdown_handler(sender=None, **kwargs): _shutdown_requested = True delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2) celery_logger.info( - "Auto-shutdown triggered after %s tasks (last task: %s); stopping consumer and terminating in %ss", + "Auto-shutdown triggered after %s tasks; stopping consumer and terminating in %ss", _task_execution_count, - task_name or "unknown", delay, ) @@ -211,16 +111,7 @@ def _auto_shutdown_handler(sender=None, **kwargs): def _monitor_idle_timeout(): - """Monitor worker idle time and shutdown if idle timeout exceeded. - - This function checks both: - 1. Time since last task completion - 2. 
Whether there are any active tasks (to avoid shutting down during workflows) - - Only shuts down if BOTH conditions are met: - - Idle timeout exceeded - - No active tasks on any worker - """ + """Monitor worker idle time and shutdown if idle timeout exceeded.""" idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0) if idle_timeout_seconds <= 0: return # Idle timeout not enabled @@ -249,49 +140,8 @@ def _monitor_idle_timeout(): continue # Haven't started yet if idle_duration >= idle_timeout_seconds: - # Before shutting down, verify there are no active tasks - # This prevents shutdown during workflows where tasks are waiting for children - try: - from celery import current_app - inspect = current_app.control.inspect() - active_tasks = inspect.active() - - has_active_tasks = False - if active_tasks: - # Check if there are any active tasks on any worker - for worker_name, tasks in active_tasks.items(): - if tasks: - has_active_tasks = True - celery_logger.debug( - "Idle timeout check: found %d active task(s) on worker %s, " - "not shutting down yet", - len(tasks), - worker_name, - ) - break - - if has_active_tasks: - # There are active tasks - don't shutdown, but continue monitoring - celery_logger.debug( - "Idle timeout exceeded (%.1f seconds) but active tasks detected, " - "continuing to monitor", - idle_duration, - ) - continue - - except Exception as exc: - # If we can't check for active tasks, be conservative and don't shutdown - celery_logger.warning( - "Could not check for active tasks during idle timeout: %s. " - "Not shutting down to be safe.", - exc, - ) - continue - - # Idle timeout exceeded AND no active tasks - safe to shutdown celery_logger.info( - "Idle timeout exceeded: worker has been idle for %.1f seconds (timeout: %d seconds) " - "and no active tasks detected. Shutting down.", + "Idle timeout exceeded: worker has been idle for %.1f seconds (timeout: %d seconds). Shutting down.", idle_duration, idle_timeout_seconds, ) diff --git a/materializationengine/shared_tasks.py b/materializationengine/shared_tasks.py index 4f7e239d..6d8238b4 100644 --- a/materializationengine/shared_tasks.py +++ b/materializationengine/shared_tasks.py @@ -115,45 +115,6 @@ def fin(self, *args, **kwargs): def workflow_complete(self, workflow_name): return f"{workflow_name} completed successfully" - -@celery.task(name="orchestration:cleanup_and_shutdown", acks_late=True, bind=True) -def cleanup_and_shutdown(self, *args, **kwargs): - """Final cleanup task that ensures all resources are closed before shutdown. 
- - This task should be the last task in a workflow chain to ensure: - - Database connection pools are closed - - All resources are properly released - - The worker can safely shut down after this task completes - - Returns: - bool: Always returns True to indicate successful cleanup - """ - celery_logger.info("Performing final cleanup before shutdown") - - # Shutdown all database connections - try: - from materializationengine.database import db_manager, dynamic_annotation_cache - - # Shutdown DatabaseConnectionManager (closes all SQLAlchemy engines and sessions) - try: - db_manager.shutdown() - celery_logger.info("DatabaseConnectionManager shutdown complete") - except Exception as e: - celery_logger.warning(f"Error shutting down DatabaseConnectionManager: {e}") - - # Shutdown DynamicMaterializationCache (closes all DynamicAnnotationInterface connections) - try: - dynamic_annotation_cache.shutdown() - celery_logger.info("DynamicMaterializationCache shutdown complete") - except Exception as e: - celery_logger.warning(f"Error shutting down DynamicMaterializationCache: {e}") - - celery_logger.info("All database connections closed - cleanup complete") - except Exception as e: - celery_logger.error(f"Error during cleanup: {e}") - - return True - def get_materialization_info( datastack_info: dict, analysis_version: int = None, diff --git a/materializationengine/workflows/ingest_new_annotations.py b/materializationengine/workflows/ingest_new_annotations.py index 95ca0704..e45ad4ad 100644 --- a/materializationengine/workflows/ingest_new_annotations.py +++ b/materializationengine/workflows/ingest_new_annotations.py @@ -608,7 +608,7 @@ def set_root_id_to_none_task( def ingest_new_annotations_workflow(mat_metadata: dict): """Celery workflow to ingest new annotations. In addition, it will create missing segmentation data table if it does not exist. - Returns celery chain primitive without executing it. + Returns celery chain primitive. 
Workflow: - Create linked segmentation table if not exists @@ -622,26 +622,17 @@ def ingest_new_annotations_workflow(mat_metadata: dict): annotation_chunks (List[int]): list of annotation primary key ids Returns: - chain: chain of celery tasks (not executed - returns signature only) + chain: chain of celery tasks """ - celery_logger.info("Preparing ingest new annotations workflow...") - - # Skip large tables + celery_logger.info("Ingesting new annotations...") if mat_metadata["row_count"] >= 1_000_000: - celery_logger.info(f"Skipping table with {mat_metadata['row_count']} rows (>= 1,000,000)") return fin.si() - - # Generate chunks synchronously (lightweight operation) annotation_chunks = generate_chunked_model_ids(mat_metadata) - - if not annotation_chunks: - celery_logger.info("No annotation chunks to process") - return fin.si() - - # Build the workflow chain - create table first, then process chunks - # The create_missing_segmentation_table task will be executed as part of the chain + table_created = create_missing_segmentation_table(mat_metadata) + if table_created: + celery_logger.info(f'Table created: {mat_metadata["segmentation_table_name"]}') + ingest_workflow = chain( - create_missing_segmentation_table.si(mat_metadata), chord( [ chain( @@ -651,11 +642,10 @@ def ingest_new_annotations_workflow(mat_metadata: dict): ], fin.si(), ) - ) - - # Return the chain signature - don't execute it here - # The caller will execute it as part of a larger chain - return ingest_workflow + ).apply_async() + tasks_completed = monitor_workflow_state(ingest_workflow) + if tasks_completed: + return fin.si() @celery.task(name="workflow:create_missing_segmentation_table", acks_late=True) diff --git a/materializationengine/workflows/update_database_workflow.py b/materializationengine/workflows/update_database_workflow.py index 28c8e3b7..19eba84a 100644 --- a/materializationengine/workflows/update_database_workflow.py +++ b/materializationengine/workflows/update_database_workflow.py @@ -8,9 +8,9 @@ from materializationengine.celery_init import celery from materializationengine.shared_tasks import ( get_materialization_info, + monitor_workflow_state, workflow_complete, fin, - cleanup_and_shutdown, ) from materializationengine.task import LockedTask from materializationengine.utils import get_config_param @@ -96,27 +96,12 @@ def update_database_workflow(self, datastack_info: dict, **kwargs): else: update_live_database_workflow.append(fin.si()) - # Build the complete workflow chain - # The workflow functions now return task signatures (not executed) - # So we can chain them together and execute once run_update_database_workflow = chain( - *update_live_database_workflow, - workflow_complete.si("update_root_ids"), - cleanup_and_shutdown.si(), # Final cleanup task to close resources - ) - - # Execute the entire chain asynchronously - # All tasks in the chain will be tracked by Celery - # task_postrun will fire for each task, including the cleanup task - celery_logger.info("Executing update database workflow chain") - result = run_update_database_workflow.apply_async( - kwargs={"Datastack": datastack_info["datastack"]} - ) - - # Return the result ID - the chain will execute asynchronously - # The worker will track all tasks in the chain - celery_logger.info(f"Workflow chain started with root task ID: {result.id}") - return True + *update_live_database_workflow, workflow_complete.si("update_root_ids") + ).apply_async(kwargs={"Datastack": datastack_info["datastack"]}) except Exception as e: celery_logger.error(f"An 
error has occurred: {e}") raise e + tasks_completed = monitor_workflow_state(run_update_database_workflow) + if tasks_completed: + return True diff --git a/materializationengine/workflows/update_root_ids.py b/materializationengine/workflows/update_root_ids.py index 16080ef1..aa497278 100644 --- a/materializationengine/workflows/update_root_ids.py +++ b/materializationengine/workflows/update_root_ids.py @@ -50,7 +50,7 @@ def expired_root_id_workflow(datastack_info: dict, **kwargs): def update_root_ids_workflow(mat_metadata: dict): """Celery workflow that updates expired root ids in a - segmentation table. Returns celery chain primitive without executing it. + segmentation table. Workflow: - Lookup supervoxel id associated with expired root id @@ -65,23 +65,18 @@ def update_root_ids_workflow(mat_metadata: dict): chunked_roots (List[int]): chunks of expired root ids to lookup Returns: - chain: chain of celery tasks (not executed - returns signature only) + chain: chain of celery tasks """ - celery_logger.info("Preparing expired root id workflow...") - - # Generate chunks synchronously (lightweight operation) + celery_logger.info("Setup expired root id workflow...") if mat_metadata.get("lookup_all_root_ids"): chunked_ids = generate_chunked_model_ids(mat_metadata) else: chunked_ids = get_expired_root_ids_from_pcg(mat_metadata) if not chunked_ids: - celery_logger.info("No expired root IDs to process") return fin.si() - # Build the workflow chain - don't execute it here - # The caller will execute it as part of a larger chain update_root_workflow = chain( chord( [ @@ -91,10 +86,10 @@ def update_root_ids_workflow(mat_metadata: dict): fin.si(), ), update_metadata.si(mat_metadata), - ) - - # Return the chain signature - don't execute it here - return update_root_workflow + ).apply_async() + tasks_completed = monitor_workflow_state(update_root_workflow) + if tasks_completed: + return workflow_complete.si("update_root_ids_workflow") def get_expired_root_ids_from_pcg(mat_metadata: dict, expired_chunk_size: int = 100): From cefef79d28496612d010279504d40356fb6ca358 Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Sun, 14 Dec 2025 14:50:16 -0800 Subject: [PATCH 11/13] trying to improve shutdown --- materializationengine/celery_worker.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index 953909db..ebe8cd30 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -33,7 +33,11 @@ def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None: - """Delay and then terminate the worker process.""" + """Delay and then terminate the worker process. + + First attempts graceful shutdown via SIGTERM, then forces exit if needed. + In Kubernetes, the container must exit for the pod to terminate. 
+ """ # Delay slightly so task result propagation finishes time.sleep(max(delay_seconds, 0)) celery_logger.info( @@ -41,10 +45,25 @@ def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None: os.getpid(), observed_count, ) + + # First try graceful shutdown via SIGTERM try: os.kill(os.getpid(), signal.SIGTERM) - except Exception as exc: # pragma: no cover - best-effort shutdown - celery_logger.error("Failed to terminate worker: %s", exc) + # Give Celery a moment to handle the signal gracefully + time.sleep(3) + except Exception as exc: + celery_logger.warning("Failed to send SIGTERM: %s", exc) + + # Force exit if still running (Kubernetes requires process to exit) + # os._exit() bypasses Python cleanup and immediately terminates the process + # This ensures the container/pod terminates even if Celery's warm shutdown doesn't exit + celery_logger.info("Forcing process exit to ensure container termination") + try: + os._exit(0) # Exit with success code + except Exception as exc: + # Last resort: use sys.exit() which might be caught but is better than nothing + celery_logger.error("Failed to force exit, using sys.exit(): %s", exc) + sys.exit(0) def _auto_shutdown_handler(sender=None, **kwargs): From f13d4bb7e1fe867d0159f6c167f4e02ae999ef6c Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Mon, 15 Dec 2025 07:04:49 -0800 Subject: [PATCH 12/13] adding database delete logging --- materializationengine/workflows/periodic_database_removal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/materializationengine/workflows/periodic_database_removal.py b/materializationengine/workflows/periodic_database_removal.py index 8e5af9f4..3b0de47f 100644 --- a/materializationengine/workflows/periodic_database_removal.py +++ b/materializationengine/workflows/periodic_database_removal.py @@ -121,7 +121,7 @@ def remove_expired_databases(delete_threshold: int = 5, datastack: str = None) - ] dropped_dbs = [] - + celery_logger.info(f"Databases to delete: {databases_to_delete}") if len(databases) > delete_threshold: with engine.connect() as conn: conn.execution_options(isolation_level="AUTOCOMMIT") From 7ce830b1b35ce0d46ff5ff1f42216a08c2db2c7d Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Fri, 19 Dec 2025 08:36:04 -0800 Subject: [PATCH 13/13] making task registration more robust to duplication --- materializationengine/celery_worker.py | 48 ++++++++++++++++---------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index ebe8cd30..ee4051dc 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -361,13 +361,25 @@ def days_till_next_month(date): @celery.on_after_configure.connect def setup_periodic_tasks(sender, **kwargs): - # remove expired task results in redis broker - sender.add_periodic_task( - crontab(hour=0, minute=0, day_of_week="*", day_of_month="*", month_of_year="*"), - add_backend_cleanup_task(celery), - name="Clean up back end results", - ) - + # Check if tasks are already registered to avoid duplicates + existing_schedule = sender.conf.get("beat_schedule", {}) + cleanup_task_name = "Clean up back end results" + + # Check if cleanup task is already registered + cleanup_already_registered = cleanup_task_name in existing_schedule if existing_schedule else False + + # Add cleanup task only if not already registered + if not cleanup_already_registered: + # remove expired task results in redis broker + sender.add_periodic_task( + 
crontab(hour=0, minute=0, day_of_week="*", day_of_month="*", month_of_year="*"), + add_backend_cleanup_task(celery), + name=cleanup_task_name, + ) + celery_logger.debug(f"Added cleanup task: {cleanup_task_name}") + else: + celery_logger.debug(f"Cleanup task '{cleanup_task_name}' already registered, skipping.") + # Try to get beat_schedules from celery.conf, fallback to BEAT_SCHEDULES if not found beat_schedules = celery.conf.get("beat_schedules") if not beat_schedules: @@ -384,15 +396,6 @@ def setup_periodic_tasks(sender, **kwargs): celery_logger.debug("No periodic tasks configured yet (beat_schedules empty). Will retry after create_celery.") return - # Check if tasks from beat_schedules are already registered to avoid duplicates - existing_schedule = sender.conf.get("beat_schedule", {}) - if existing_schedule: - # Check if any of our scheduled tasks are already registered - schedule_names = [s.get("name") for s in beat_schedules if isinstance(s, dict)] - already_registered = any(name in existing_schedule for name in schedule_names if name) - if already_registered: - celery_logger.debug("Periodic tasks already registered in beat_schedule, skipping duplicate registration.") - return try: schedules = CeleryBeatSchema(many=True).load(beat_schedules) except ValidationError as validation_error: @@ -402,17 +405,24 @@ def setup_periodic_tasks(sender, **kwargs): min_databases = sender.conf.get("MIN_DATABASES") celery_logger.info(f"MIN_DATABASES: {min_databases}") for schedule in schedules: + task_name = schedule.get("name") + + # Check if this specific task is already registered + if existing_schedule and task_name and task_name in existing_schedule: + celery_logger.debug(f"Task '{task_name}' already registered, skipping duplicate registration.") + continue + try: task = configure_task(schedule, min_databases) sender.add_periodic_task( create_crontab(schedule), task, - name=schedule["name"], + name=task_name, ) - celery_logger.info(f"Added task: {schedule['name']}") + celery_logger.info(f"Added task: {task_name}") except ConfigurationError as e: celery_logger.error( - f"Error configuring task '{schedule.get('name', 'Unknown')}': {str(e)}" + f"Error configuring task '{task_name or 'Unknown'}': {str(e)}" )
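
Note on the auto-shutdown handler (PATCH 01/02, simplified again in PATCH 10/11): the mechanism counts completed tasks from Celery's task_postrun signal and, once the configured limit is reached, terminates the worker from a delayed daemon thread. The sketch below is a minimal, self-contained version of that pattern; the app name, broker URL, and config keys are placeholders, and a lock is added around the counter, which the patch does not use. It is not a drop-in copy of the project's create_celery wiring.

import os
import signal
import threading
import time

from celery import Celery, signals

app = Celery("sketch", broker="redis://localhost:6379/0")  # placeholder broker URL

_task_count = 0
_shutdown_requested = False
_count_lock = threading.Lock()


def _terminate_worker(delay_seconds: int) -> None:
    """Wait briefly so task results propagate, then stop this worker process."""
    time.sleep(max(delay_seconds, 0))
    # SIGTERM requests a warm shutdown; PATCH 11 follows it with os._exit(0)
    # so the container still exits if the warm shutdown never completes.
    os.kill(os.getpid(), signal.SIGTERM)
    time.sleep(3)
    os._exit(0)


@signals.task_postrun.connect(weak=False)
def _auto_shutdown(sender=None, **kwargs):
    """Count finished tasks and trigger shutdown after the configured number."""
    global _task_count, _shutdown_requested
    max_tasks = app.conf.get("worker_autoshutdown_max_tasks", 1)
    if max_tasks <= 0:
        return
    with _count_lock:
        if _shutdown_requested:
            return
        _task_count += 1
        if _task_count < max_tasks:
            return
        _shutdown_requested = True
    delay = app.conf.get("worker_autoshutdown_delay_seconds", 2)
    threading.Thread(target=_terminate_worker, args=(delay,), daemon=True).start()

Two caveats worth keeping in mind: Celery's built-in worker_max_tasks_per_child recycles pool child processes rather than stopping the whole worker, which is presumably why a custom handler is used here, and under the prefork pool task_postrun fires in the pool child process, so this pattern behaves most predictably with the solo or threads pools.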
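
Note on the idle timeout (PATCH 06): a daemon thread started from worker_ready compares the time since the last completed task (or since worker start) against a configured limit and shuts the worker down when it is exceeded. Reduced to its core, with the shutdown call stubbed out and the timeout hard-coded as a stand-in for CELERY_WORKER_IDLE_TIMEOUT_SECONDS, the loop looks roughly like this:

import threading
import time

from celery import signals

_last_task_time = None
_worker_start_time = None
_stop = threading.Event()


@signals.task_postrun.connect(weak=False)
def _record_task_completion(sender=None, **kwargs):
    # Every finished task resets the idle clock.
    global _last_task_time
    _last_task_time = time.time()


def _idle_monitor(idle_timeout_seconds: float, shutdown_callback) -> None:
    # Poll a few times per timeout window, capped at 30 s as in the patch.
    interval = min(30, idle_timeout_seconds / 4)
    while not _stop.is_set():
        time.sleep(interval)
        reference = _last_task_time or _worker_start_time
        if reference is None:
            continue
        if time.time() - reference >= idle_timeout_seconds:
            shutdown_callback()
            return


@signals.worker_ready.connect(weak=False)
def _start_idle_monitor(sender=None, **kwargs):
    global _worker_start_time
    _worker_start_time = time.time()
    timeout = 300  # stand-in value; the patch reads worker_idle_timeout_seconds from config
    if timeout > 0:
        threading.Thread(
            target=_idle_monitor,
            args=(timeout, lambda: print("idle timeout reached; worker would terminate here")),
            daemon=True,
        ).start()

PATCH 09 additionally refused to shut down while inspect().active() still reported running tasks, and PATCH 10 removed that check again; the sketch reflects the simpler final behaviour.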
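
Note on stopping new task consumption (PATCH 05/06): before terminating, the handler calls the remote-control command app.control.cancel_consumer(queue, destination=[hostname]) so the worker stops pulling work from its queue during the shutdown delay. The patch discovers the hostname via inspect().active() and takes the first key, which can address the wrong worker when several are online; below is a sketch that instead captures this worker's own hostname at startup (queue name and broker URL are placeholders):

from celery import Celery, signals

app = Celery("sketch", broker="redis://localhost:6379/0")  # placeholder broker URL

_worker_hostname = None


@signals.worker_ready.connect(weak=False)
def _remember_hostname(sender=None, **kwargs):
    # worker_ready's sender is the consumer; its hostname attribute is the
    # "worker@host" name that remote-control commands use for addressing.
    global _worker_hostname
    _worker_hostname = getattr(sender, "hostname", None)


def stop_consuming(queue_name: str = "celery") -> None:
    """Ask only this worker to stop consuming from queue_name."""
    if _worker_hostname:
        app.control.cancel_consumer(queue_name, destination=[_worker_hostname])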
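
Note on the database teardown (PATCH 08): DatabaseConnectionManager.shutdown() follows the standard SQLAlchemy order: remove the scoped sessions first, then dispose each engine so its pool closes every connection, then clear the registries. Stripped of the project's caching classes, the sequence is roughly the following; the registries and database names here are hypothetical.

import logging

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

log = logging.getLogger(__name__)

# Hypothetical registries keyed by database name, mirroring the patch's caches.
engines = {"example_db": create_engine("sqlite:///example.db")}
session_factories = {
    name: scoped_session(sessionmaker(bind=engine)) for name, engine in engines.items()
}


def shutdown_all() -> None:
    # 1) remove() closes and discards every session held by the scoped registry.
    for name, factory in list(session_factories.items()):
        try:
            factory.remove()
        except Exception as exc:
            log.warning("error removing sessions for %s: %s", name, exc)

    # 2) dispose() closes all pooled connections; report checked-out ones first.
    #    Not every pool class exposes checkedout(), hence the getattr guard.
    for name, engine in list(engines.items()):
        checked_out = getattr(engine.pool, "checkedout", lambda: 0)()
        if checked_out:
            log.warning("%s still has %s checked-out connections at shutdown", name, checked_out)
        engine.dispose()

    session_factories.clear()
    engines.clear()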
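
Note on the periodic-task deduplication (PATCH 13): because setup_periodic_tasks can run both from the on_after_configure signal and from the explicit call added in PATCH 04, the patch checks each task name against sender.conf's beat_schedule before registering it, relying on add_periodic_task recording entries there keyed by name. A minimal sketch of the same guard, with a made-up task and schedule:

from celery import Celery
from celery.schedules import crontab

app = Celery("sketch", broker="redis://localhost:6379/0")  # placeholder broker URL


@app.task
def nightly_cleanup():
    return "done"


@app.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
    # add_periodic_task() records entries in sender.conf.beat_schedule keyed by
    # name, so the existing schedule doubles as a duplicate registry.
    existing = sender.conf.get("beat_schedule") or {}
    name = "Nightly cleanup"
    if name in existing:
        return
    sender.add_periodic_task(
        crontab(hour=0, minute=0),
        nightly_cleanup.s(),
        name=name,
    )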
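
Note on the workflow restructuring (PATCH 08, reverted in PATCH 10): the short-lived rewrite made ingest_new_annotations_workflow and update_root_ids_workflow return unexecuted chain signatures so that update_database_workflow could compose everything into a single chain and call apply_async exactly once, rather than each helper blocking in monitor_workflow_state while its children ran. The two styles differ roughly as below; task bodies and names are illustrative only.

from celery import Celery, chain, chord

app = Celery("sketch", broker="redis://localhost:6379/0")  # placeholder broker URL


@app.task
def process_chunk(chunk):
    return len(chunk)


@app.task
def fin(*args, **kwargs):
    return "fin"


def build_subworkflow(chunks):
    """PATCH 08 style: return an unexecuted signature and let the caller launch it."""
    return chain(chord([process_chunk.si(c) for c in chunks], fin.si()), fin.si())


def run_everything(chunk_groups):
    # Compose the sub-workflows into one outer chain and launch it once;
    # no task sits in a monitoring loop waiting for its children to finish.
    workflow = chain(*[build_subworkflow(g) for g in chunk_groups])
    return workflow.apply_async()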