diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py
index e8ce9bf4..ee4051dc 100644
--- a/materializationengine/celery_worker.py
+++ b/materializationengine/celery_worker.py
@@ -1,11 +1,15 @@
 import datetime
 import logging
 import os
+import signal
 import sys
+import threading
+import time
 import warnings
 from typing import Any, Callable, Dict
 
 import redis
+from celery import signals
 from celery.app.builtins import add_backend_cleanup_task
 from celery.schedules import crontab
 from celery.signals import after_setup_logger
@@ -22,6 +26,191 @@
 celery_logger = get_task_logger(__name__)
 
 
+_task_execution_count = 0
+_shutdown_requested = False
+_last_task_time = None
+_worker_start_time = None
+
+
+def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None:
+    """Delay and then terminate the worker process.
+
+    First attempts a graceful shutdown via SIGTERM, then forces exit if needed.
+    In Kubernetes, the container must exit for the pod to terminate.
+    """
+    # Delay slightly so task result propagation finishes
+    time.sleep(max(delay_seconds, 0))
+    celery_logger.info(
+        "Auto-shutdown: terminating worker PID %s after %s tasks",
+        os.getpid(),
+        observed_count,
+    )
+
+    # First try a graceful shutdown via SIGTERM
+    try:
+        os.kill(os.getpid(), signal.SIGTERM)
+        # Give Celery a moment to handle the signal gracefully
+        time.sleep(3)
+    except Exception as exc:
+        celery_logger.warning("Failed to send SIGTERM: %s", exc)
+
+    # Force exit if still running (Kubernetes requires the process to exit).
+    # os._exit() bypasses Python cleanup and immediately terminates the process,
+    # ensuring the container/pod terminates even if Celery's warm shutdown doesn't exit.
+    celery_logger.info("Forcing process exit to ensure container termination")
+    try:
+        os._exit(0)  # Exit with success code
+    except Exception as exc:
+        # Last resort: sys.exit() might be caught, but it is better than nothing
+        celery_logger.error("Failed to force exit, using sys.exit(): %s", exc)
+        sys.exit(0)
+
+
+def _auto_shutdown_handler(sender=None, **kwargs):
+    """Trigger worker shutdown after a configurable task count when enabled."""
+    if not celery.conf.get("worker_autoshutdown_enabled", False):
+        return
+
+    max_tasks = celery.conf.get("worker_autoshutdown_max_tasks", 1)
+    if max_tasks <= 0:
+        return
+
+    global _task_execution_count, _shutdown_requested, _last_task_time
+    if _shutdown_requested:
+        return
+
+    _task_execution_count += 1
+    _last_task_time = time.time()  # Update last task time
+
+    if _task_execution_count < max_tasks:
+        return
+
+    _shutdown_requested = True
+    delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2)
+    celery_logger.info(
+        "Auto-shutdown triggered after %s tasks; stopping consumer and terminating in %ss",
+        _task_execution_count,
+        delay,
+    )
+
+    # Immediately stop accepting new tasks by canceling the consumer.
+    # This prevents the worker from picking up new tasks during the shutdown delay.
+    try:
+        # Get the queue name from config or use the default
+        queue_name = celery.conf.get("task_default_queue") or celery.conf.get("task_routes", {}).get("*", {}).get("queue", "celery")
+        # Get the worker hostname from the task sender or use the current worker's hostname
+        worker_hostname = None
+        if hasattr(sender, 'hostname'):
+            worker_hostname = sender.hostname
+        elif hasattr(celery, 'control'):
+            # Try to get the hostname from the current worker
+            try:
+                from celery import current_app
+                inspect = current_app.control.inspect()
+                active_workers = inspect.active() if inspect else {}
+                if active_workers:
+                    worker_hostname = list(active_workers.keys())[0]
+            except Exception:
+                pass
+
+        if worker_hostname and queue_name:
+            celery_logger.info("Canceling consumer for queue '%s' on worker '%s'", queue_name, worker_hostname)
+            celery.control.cancel_consumer(queue_name, destination=[worker_hostname])
+        else:
+            celery_logger.warning("Could not determine worker hostname or queue name for consumer cancellation")
+    except Exception as exc:
+        celery_logger.warning("Failed to cancel consumer during shutdown: %s", exc)
+
+    shutdown_thread = threading.Thread(
+        target=_request_worker_shutdown,
+        args=(delay, _task_execution_count),
+        daemon=True,
+    )
+    shutdown_thread.start()
+
+
+def _monitor_idle_timeout():
+    """Monitor worker idle time and shut down if the idle timeout is exceeded."""
+    idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0)
+    if idle_timeout_seconds <= 0:
+        return  # Idle timeout not enabled
+
+    check_interval = min(30, idle_timeout_seconds / 4)  # Check every 30s or 1/4 of the timeout, whichever is smaller
+
+    global _last_task_time, _worker_start_time, _shutdown_requested
+
+    while not _shutdown_requested:
+        time.sleep(check_interval)
+
+        if _shutdown_requested:
+            break
+
+        current_time = time.time()
+
+        # If we've processed at least one task, use the last task time;
+        # otherwise, use the worker start time
+        if _last_task_time is not None:
+            idle_duration = current_time - _last_task_time
+            reference_time = _last_task_time
+        elif _worker_start_time is not None:
+            idle_duration = current_time - _worker_start_time
+            reference_time = _worker_start_time
+        else:
+            continue  # Haven't started yet
+
+        if idle_duration >= idle_timeout_seconds:
+            celery_logger.info(
+                "Idle timeout exceeded: worker has been idle for %.1f seconds (timeout: %d seconds). Shutting down.",
+                idle_duration,
+                idle_timeout_seconds,
+            )
+            _shutdown_requested = True
+            # Cancel the consumer to stop accepting new tasks
+            try:
+                queue_name = celery.conf.get("task_default_queue") or celery.conf.get("task_routes", {}).get("*", {}).get("queue", "celery")
+                from celery import current_app
+                inspect = current_app.control.inspect()
+                active_workers = inspect.active() if inspect else {}
+                if active_workers:
+                    worker_hostname = list(active_workers.keys())[0]
+                    celery_logger.info("Canceling consumer for queue '%s' on worker '%s'", queue_name, worker_hostname)
+                    celery.control.cancel_consumer(queue_name, destination=[worker_hostname])
+            except Exception as exc:
+                celery_logger.warning("Failed to cancel consumer during idle timeout shutdown: %s", exc)
+
+            # Shutdown after a short delay
+            delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2)
+            shutdown_thread = threading.Thread(
+                target=_request_worker_shutdown,
+                args=(delay, 0),  # 0 tasks since we're shutting down due to idle timeout
+                daemon=True,
+            )
+            shutdown_thread.start()
+            break
+
+
+def _worker_ready_handler(sender=None, **kwargs):
+    """Handle the worker_ready signal: start the idle-timeout monitor if enabled."""
+    global _worker_start_time, _shutdown_requested
+    _worker_start_time = time.time()
+
+    idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0)
+    if idle_timeout_seconds > 0:
+        celery_logger.info(
+            "Worker idle timeout enabled: %d seconds. Worker will shut down if idle for this duration.",
+            idle_timeout_seconds,
+        )
+        monitor_thread = threading.Thread(
+            target=_monitor_idle_timeout,
+            daemon=True,
+        )
+        monitor_thread.start()
+
+
+signals.task_postrun.connect(_auto_shutdown_handler, weak=False)
+signals.worker_ready.connect(_worker_ready_handler, weak=False)
+
+
 def create_celery(app=None):
     celery.conf.broker_url = app.config["CELERY_BROKER_URL"]
     celery.conf.result_backend = app.config["CELERY_RESULT_BACKEND"]
@@ -32,6 +221,32 @@ def create_celery(app=None):
     celery.conf.result_backend_transport_options = {
         "master_name": app.config["MASTER_NAME"]
     }
+
+    celery.conf.worker_autoshutdown_enabled = app.config.get(
+        "CELERY_WORKER_AUTOSHUTDOWN_ENABLED", False
+    )
+    celery.conf.worker_autoshutdown_max_tasks = app.config.get(
+        "CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS", 1
+    )
+    celery.conf.worker_autoshutdown_delay_seconds = app.config.get(
+        "CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS", 2
+    )
+    celery.conf.worker_idle_timeout_seconds = app.config.get(
+        "CELERY_WORKER_IDLE_TIMEOUT_SECONDS", 0
+    )
+
+    if celery.conf.worker_autoshutdown_enabled:
+        celery_logger.info(
+            "Worker auto-shutdown enabled: max_tasks=%s delay=%ss",
+            celery.conf.worker_autoshutdown_max_tasks,
+            celery.conf.worker_autoshutdown_delay_seconds,
+        )
+
+    if celery.conf.worker_idle_timeout_seconds > 0:
+        celery_logger.info(
+            "Worker idle timeout enabled: %s seconds",
+            celery.conf.worker_idle_timeout_seconds,
+        )
     # Configure Celery and related loggers
     log_level = app.config["LOGGING_LEVEL"]
     celery_logger.setLevel(log_level)
@@ -102,6 +317,15 @@ def __call__(self, *args, **kwargs):
     if os.environ.get("SLACK_WEBHOOK"):
         celery.Task.on_failure = post_to_slack_on_task_failure
 
+    # Manually trigger setup_periodic_tasks to ensure it runs with the correct configuration.
+    # The signal handler may have fired before beat_schedules was set, so we call it explicitly here.
+    try:
+        setup_periodic_tasks(celery)
+        celery_logger.info("Manually triggered setup_periodic_tasks after create_celery")
+    except Exception as e:
+        celery_logger.warning(f"Error manually triggering setup_periodic_tasks: {e}")
+        # Don't fail if this doesn't work - the signal handler should still work
+
     return celery
@@ -137,13 +361,25 @@ def days_till_next_month(date):
 
 @celery.on_after_configure.connect
 def setup_periodic_tasks(sender, **kwargs):
-    # remove expired task results in redis broker
-    sender.add_periodic_task(
-        crontab(hour=0, minute=0, day_of_week="*", day_of_month="*", month_of_year="*"),
-        add_backend_cleanup_task(celery),
-        name="Clean up back end results",
-    )
-
+    # Check if tasks are already registered to avoid duplicates
+    existing_schedule = sender.conf.get("beat_schedule", {})
+    cleanup_task_name = "Clean up back end results"
+
+    # Check if the cleanup task is already registered
+    cleanup_already_registered = cleanup_task_name in existing_schedule if existing_schedule else False
+
+    # Add the cleanup task only if it is not already registered
+    if not cleanup_already_registered:
+        # remove expired task results in redis broker
+        sender.add_periodic_task(
+            crontab(hour=0, minute=0, day_of_week="*", day_of_month="*", month_of_year="*"),
+            add_backend_cleanup_task(celery),
+            name=cleanup_task_name,
+        )
+        celery_logger.debug(f"Added cleanup task: {cleanup_task_name}")
+    else:
+        celery_logger.debug(f"Cleanup task '{cleanup_task_name}' already registered, skipping.")
+
+    # Try to get beat_schedules from celery.conf, falling back to BEAT_SCHEDULES if not found
     beat_schedules = celery.conf.get("beat_schedules")
     if not beat_schedules:
@@ -157,8 +393,9 @@ def setup_periodic_tasks(sender, **kwargs):
     celery_logger.debug(f"beat_schedules type: {type(beat_schedules)}, length: {len(beat_schedules) if isinstance(beat_schedules, (list, dict)) else 'N/A'}")
 
     if not beat_schedules:
-        celery_logger.info("No periodic tasks configured.")
+        celery_logger.debug("No periodic tasks configured yet (beat_schedules empty). Will retry after create_celery.")
         return
+
     try:
         schedules = CeleryBeatSchema(many=True).load(beat_schedules)
     except ValidationError as validation_error:
@@ -168,17 +405,24 @@
     min_databases = sender.conf.get("MIN_DATABASES")
     celery_logger.info(f"MIN_DATABASES: {min_databases}")
     for schedule in schedules:
+        task_name = schedule.get("name")
+
+        # Check if this specific task is already registered
+        if existing_schedule and task_name and task_name in existing_schedule:
+            celery_logger.debug(f"Task '{task_name}' already registered, skipping duplicate registration.")
+            continue
+
         try:
             task = configure_task(schedule, min_databases)
             sender.add_periodic_task(
                 create_crontab(schedule),
                 task,
-                name=schedule["name"],
+                name=task_name,
            )
-            celery_logger.info(f"Added task: {schedule['name']}")
+            celery_logger.info(f"Added task: {task_name}")
         except ConfigurationError as e:
             celery_logger.error(
-                f"Error configuring task '{schedule.get('name', 'Unknown')}': {str(e)}"
+                f"Error configuring task '{task_name or 'Unknown'}': {str(e)}"
             )
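Since create_celery() above copies the CELERY_WORKER_* keys from the Flask config into celery.conf, enabling the feature is purely a configuration change. A minimal sketch, assuming `app` is the Flask app produced by this repo's app factory; the values shown are illustrative, not recommendations:

# Illustrative values; the key names come from create_celery() above.
app.config.update(
    CELERY_WORKER_AUTOSHUTDOWN_ENABLED=True,     # shut the worker down after N tasks
    CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS=1,      # N = 1: one task per worker pod
    CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS=2,  # let result propagation finish first
    CELERY_WORKER_IDLE_TIMEOUT_SECONDS=600,      # also exit after 10 idle minutes
)
celery_app = create_celery(app)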
diff --git a/materializationengine/config.py b/materializationengine/config.py
index 4706a51a..9ee098a1 100644
--- a/materializationengine/config.py
+++ b/materializationengine/config.py
@@ -48,9 +48,13 @@ class BaseConfig:
     MERGE_TABLES = True
     AUTH_SERVICE_NAMESPACE = "datastack"
 
-    REDIS_HOST="localhost"
-    REDIS_PORT=6379
-    REDIS_PASSWORD=""
+    CELERY_WORKER_AUTOSHUTDOWN_ENABLED = False
+    CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS = 1
+    CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS = 2
+
+    REDIS_HOST = "localhost"
+    REDIS_PORT = 6379
+    REDIS_PASSWORD = ""
     SESSION_TYPE = "redis"
     PERMANENT_SESSION_LIFETIME = timedelta(hours=24)
     SESSION_PREFIX = "annotation_upload_"
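The new BaseConfig defaults keep auto-shutdown disabled. A hedged sketch of how a deployment-specific config might turn it on; the subclass name is illustrative, and CELERY_WORKER_IDLE_TIMEOUT_SECONDS gets no BaseConfig default in this diff, so create_celery() falls back to 0 (disabled):

class EphemeralWorkerConfig(BaseConfig):  # illustrative subclass name
    CELERY_WORKER_AUTOSHUTDOWN_ENABLED = True
    CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS = 1
    CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS = 5
    # No BaseConfig default for this key; 0 or absent means the idle monitor stays off
    CELERY_WORKER_IDLE_TIMEOUT_SECONDS = 900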
+ """ + celery_logger.info("Shutting down DatabaseConnectionManager...") + + # First, close all active sessions from session factories for database_name, session_factory in list(self._session_factories.items()): try: + # Remove all scoped sessions (this closes them) session_factory.remove() + celery_logger.debug(f"Removed scoped sessions for {database_name}") except Exception as e: - celery_logger.error(f"Error cleaning up sessions for {database_name}: {e}") + celery_logger.warning(f"Error removing sessions for {database_name}: {e}") + # Then dispose of all engines (this closes all connections in the pool) for database_name, engine in list(self._engines.items()): try: + # Log pool status before disposal + pool = engine.pool + checked_out = pool.checkedout() + if checked_out > 0: + celery_logger.warning( + f"Disposing engine for {database_name} with {checked_out} " + f"checked-out connections" + ) + + # Dispose closes all connections in the pool engine.dispose() + celery_logger.debug(f"Disposed engine connection pool for {database_name}") except Exception as e: celery_logger.error(f"Error disposing engine for {database_name}: {e}") - + + # Clear all caches self._session_factories.clear() self._engines.clear() + + celery_logger.info("DatabaseConnectionManager shutdown complete") def log_pool_status(self, database_name: str): """Log current connection pool status.""" @@ -185,7 +219,77 @@ def _get_mat_client(self, database: str): return self._clients[database] def invalidate_cache(self): + """Invalidate the cache by clearing all client references. + + Note: This does NOT close database connections. Use shutdown() for that. + """ self._clients = {} + + def shutdown(self): + """Shutdown all database connections and clear the cache. + + This method: + 1. Closes all cached sessions in DynamicAnnotationInterface clients + 2. Disposes of underlying database engines if available + 3. Clears all cached clients + + Should be called when the application is shutting down or when + you want to ensure all database connections are closed. 
+ """ + celery_logger.info("Shutting down DynamicMaterializationCache...") + + for database, client in list(self._clients.items()): + try: + # Close the cached session if it exists + if hasattr(client, 'database'): + # Close the session + if hasattr(client.database, 'close_session'): + try: + client.database.close_session() + celery_logger.debug(f"Closed session for {database}") + except Exception as e: + celery_logger.warning( + f"Error closing session for {database}: {e}" + ) + + # Dispose of the engine if it exists + if hasattr(client.database, 'engine'): + try: + engine = client.database.engine + if engine: + pool = engine.pool + checked_out = pool.checkedout() + if checked_out > 0: + celery_logger.warning( + f"Disposing engine for {database} with " + f"{checked_out} checked-out connections" + ) + engine.dispose() + celery_logger.debug(f"Disposed engine for {database}") + except Exception as e: + celery_logger.warning( + f"Error disposing engine for {database}: {e}" + ) + + # Also try to close any cached session directly + if hasattr(client.database, '_cached_session'): + try: + cached_session = client.database._cached_session + if cached_session: + cached_session.close() + celery_logger.debug(f"Closed cached session for {database}") + except Exception as e: + celery_logger.warning( + f"Error closing cached session for {database}: {e}" + ) + + except Exception as e: + celery_logger.error(f"Error shutting down client for {database}: {e}") + + # Clear all clients + self._clients.clear() + + celery_logger.info("DynamicMaterializationCache shutdown complete") dynamic_annotation_cache = DynamicMaterializationCache() diff --git a/materializationengine/shared_tasks.py b/materializationengine/shared_tasks.py index 49ff88d4..6d8238b4 100644 --- a/materializationengine/shared_tasks.py +++ b/materializationengine/shared_tasks.py @@ -488,12 +488,26 @@ def check_if_task_is_running(task_name: str, worker_name_prefix: str) -> bool: inspector = celery.control.inspect() active_tasks_dict = inspector.active() - workflow_active_tasks = next( - v for k, v in active_tasks_dict.items() if worker_name_prefix in k - ) + # Handle case where inspector.active() returns None (no workers available or disconnected) + if active_tasks_dict is None: + celery_logger.warning( + "Unable to inspect active tasks: no workers available or workers disconnected" + ) + return False - for active_task in workflow_active_tasks: - if task_name in active_task.values(): - celery_logger.info(f"Task {task_name} is running...") - return True + # Find active tasks from workers matching the prefix + matching_active_tasks = [ + v for k, v in active_tasks_dict.items() if worker_name_prefix in k + ] + + # If no matching workers found, return False + if not matching_active_tasks: + return False + + # Check all active tasks from matching workers for the task + for workflow_active_tasks in matching_active_tasks: + for active_task in workflow_active_tasks: + if task_name in active_task.values(): + celery_logger.info(f"Task {task_name} is running...") + return True return False diff --git a/materializationengine/workflows/periodic_database_removal.py b/materializationengine/workflows/periodic_database_removal.py index 8e5af9f4..3b0de47f 100644 --- a/materializationengine/workflows/periodic_database_removal.py +++ b/materializationengine/workflows/periodic_database_removal.py @@ -121,7 +121,7 @@ def remove_expired_databases(delete_threshold: int = 5, datastack: str = None) - ] dropped_dbs = [] - + celery_logger.info(f"Databases to delete: 
{databases_to_delete}") if len(databases) > delete_threshold: with engine.connect() as conn: conn.execution_options(isolation_level="AUTOCOMMIT")