266 changes: 255 additions & 11 deletions materializationengine/celery_worker.py
@@ -1,11 +1,15 @@
import datetime
import logging
import os
import signal
import sys
import threading
import time
import warnings
from typing import Any, Callable, Dict

import redis
from celery import signals
from celery.app.builtins import add_backend_cleanup_task
from celery.schedules import crontab
from celery.signals import after_setup_logger
@@ -22,6 +26,191 @@
celery_logger = get_task_logger(__name__)


_task_execution_count = 0
_shutdown_requested = False

Bug: Task count is per-child, not per-worker with prefork

The module-level globals _task_execution_count and _shutdown_requested are process-local. With Celery's prefork pool (used in this project with concurrency up to 4), each child process maintains its own independent counter due to fork semantics. This means if worker_autoshutdown_max_tasks is set to 10 and concurrency is 4, each child would need to run 10 tasks individually before triggering shutdown - potentially 40 total tasks before any action, rather than the expected 10. The counting doesn't aggregate across worker children as the configuration name suggests.

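One way to make the count worker-wide would be to keep it in the Redis broker this module already connects to, keyed by the main worker's PID so every forked child increments the same value. A rough sketch, not part of this PR; the key name, the TTL, and the redis_client handle (built from the app's REDIS_* settings) are all assumptions:

import os
import redis

def _increment_shared_task_count(redis_client: redis.Redis, max_tasks: int) -> bool:
    """Increment a counter shared by all prefork children of this worker.

    os.getppid() identifies the main worker process that forked this child,
    so every child of the same worker updates the same key. Returns True once
    the aggregated count reaches max_tasks.
    """
    key = f"autoshutdown:task_count:{os.getppid()}"  # hypothetical key format
    count = redis_client.incr(key)          # atomic across child processes
    redis_client.expire(key, 24 * 60 * 60)  # drop stale keys after a worker restart
    return count >= max_tasks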

_last_task_time = None
_worker_start_time = None


def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None:
"""Delay and then terminate the worker process.

First attempts graceful shutdown via SIGTERM, then forces exit if needed.
In Kubernetes, the container must exit for the pod to terminate.
"""
# Delay slightly so task result propagation finishes
time.sleep(max(delay_seconds, 0))
celery_logger.info(
"Auto-shutdown: terminating worker PID %s after %s tasks",
os.getpid(),
observed_count,
)

# First try graceful shutdown via SIGTERM
try:
os.kill(os.getpid(), signal.SIGTERM)

Bug: SIGTERM targets child process, not main worker with prefork

The auto-shutdown feature intends to terminate the worker after a configurable number of tasks, but with Celery's prefork pool (which this project uses), the task_postrun signal fires in the child process, not the main worker process. Calling os.kill(os.getpid(), signal.SIGTERM) from that context terminates only the child process, which the main worker will simply replace with a new one. The worker continues running indefinitely, defeating the feature's purpose. To properly shut down the entire worker, the code would need to signal the parent/main process rather than the current child process.

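A minimal sketch of signaling the main worker instead of the child. It assumes the handler runs inside a prefork child, so the parent PID is the main worker process; under the solo pool this would need a guard:

import os
import signal

def _terminate_main_worker() -> None:
    """Send SIGTERM to the main worker process rather than the current prefork child."""
    # In a prefork child, os.getppid() is the main worker. Celery handles SIGTERM
    # as a warm shutdown, which also winds down the other pool children.
    os.kill(os.getppid(), signal.SIGTERM)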

# Give Celery a moment to handle the signal gracefully
time.sleep(3)
except Exception as exc:
celery_logger.warning("Failed to send SIGTERM: %s", exc)

# Force exit if still running (Kubernetes requires process to exit)
# os._exit() bypasses Python cleanup and immediately terminates the process
# This ensures the container/pod terminates even if Celery's warm shutdown doesn't exit
celery_logger.info("Forcing process exit to ensure container termination")
try:
os._exit(0) # Exit with success code
except Exception as exc:
# Last resort: use sys.exit() which might be caught but is better than nothing
celery_logger.error("Failed to force exit, using sys.exit(): %s", exc)
sys.exit(0)


def _auto_shutdown_handler(sender=None, **kwargs):
"""Trigger worker shutdown after configurable task count when enabled."""
if not celery.conf.get("worker_autoshutdown_enabled", False):
return

max_tasks = celery.conf.get("worker_autoshutdown_max_tasks", 1)
if max_tasks <= 0:
return

global _task_execution_count, _shutdown_requested, _last_task_time
if _shutdown_requested:
return

_task_execution_count += 1
_last_task_time = time.time() # Update last task time

if _task_execution_count < max_tasks:
return

_shutdown_requested = True

Bug: Race condition in global shutdown state without synchronization

The global variables _task_execution_count and _shutdown_requested are accessed and modified without synchronization (e.g., a threading lock). When using Celery with eventlet, gevent, or threads pools where multiple tasks execute concurrently, two tasks could both pass the if _shutdown_requested: check before either sets it to True. This could cause multiple shutdown threads to be spawned and _task_execution_count to be incremented incorrectly. While the worker still shuts down, the logged task count may be wrong and duplicate shutdown attempts occur.

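A sketch of serializing access with a lock so only one task can claim the shutdown. The helper name is illustrative and the globals mirror the module-level ones defined above:

import threading

_state_lock = threading.Lock()
_task_execution_count = 0
_shutdown_requested = False

def _record_task_and_check_limit(max_tasks: int) -> bool:
    """Atomically bump the counter and claim the shutdown exactly once."""
    global _task_execution_count, _shutdown_requested
    with _state_lock:
        if _shutdown_requested:
            return False
        _task_execution_count += 1
        if _task_execution_count < max_tasks:
            return False
        _shutdown_requested = True  # only the call that crosses the threshold returns True
        return True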

delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2)
celery_logger.info(
"Auto-shutdown triggered after %s tasks; stopping consumer and terminating in %ss",
_task_execution_count,
delay,
)

# Immediately stop accepting new tasks by canceling the consumer
# This prevents the worker from picking up new tasks during the shutdown delay
try:
# Get queue name from config or use default
queue_name = celery.conf.get("task_default_queue") or celery.conf.get("task_routes", {}).get("*", {}).get("queue", "celery")

Bug: Queue name fallback logic fails with string task_routes

The fallback logic for determining queue_name assumes task_routes is a dict with pattern matching keys, but in this codebase task_routes is set to a string "materializationengine.task_router.TaskRouter" at line 167. If task_default_queue is not set or returns a falsy value, the expression celery.conf.get("task_routes", {}).get("*", {}) will raise AttributeError because strings don't have a .get() method. While this is caught by the outer try-except block, it causes the consumer cancellation feature to silently fail, allowing the worker to potentially accept new tasks during the shutdown delay period.

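A sketch of a fallback that tolerates task_routes being a router class path rather than a dict; "celery" is Celery's default queue name:

def _resolve_queue_name(conf) -> str:
    """Return the queue to cancel without assuming task_routes is a dict."""
    queue_name = conf.get("task_default_queue")
    if queue_name:
        return queue_name
    routes = conf.get("task_routes")
    if isinstance(routes, dict):
        return routes.get("*", {}).get("queue", "celery")
    return "celery"  # task_routes is a string (e.g. a TaskRouter path) or unset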

# Get worker hostname from the task sender or use current worker's hostname
worker_hostname = None
if hasattr(sender, 'hostname'):
worker_hostname = sender.hostname
elif hasattr(celery, 'control'):
# Try to get hostname from current worker
try:
from celery import current_app
inspect = current_app.control.inspect()
active_workers = inspect.active() if inspect else {}
if active_workers:
worker_hostname = list(active_workers.keys())[0]

Bug: Consumer cancellation may target wrong worker in cluster

The code uses inspect.active() which returns all active workers across the entire Celery cluster, then selects the first worker with list(active_workers.keys())[0]. In a multi-worker environment, this could pick an arbitrary worker's hostname instead of the current worker, causing cancel_consumer to stop task consumption on a different worker than the one that reached its task limit or idle timeout.

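One alternative is to record this worker's own node name once at startup and reuse it, avoiding the cluster-wide inspect call entirely. Sketch only: it assumes the worker_ready sender exposes the node name as a hostname attribute (worth verifying for the pool in use), and celery is the module-level app already used in this file:

_current_node_name = None

def _remember_node_name(sender=None, **kwargs):
    """worker_ready's sender is this worker's consumer; record its node name once."""
    global _current_node_name
    _current_node_name = getattr(sender, "hostname", None)

def _cancel_own_consumer(queue_name):
    """Stop consumption on this worker only, never on an arbitrary cluster member."""
    if _current_node_name:
        celery.control.cancel_consumer(queue_name, destination=[_current_node_name])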

except Exception:
pass

Bug: Incorrect worker hostname attribute access on task sender

The sender parameter in task_postrun is the Task instance, which doesn't have a hostname attribute directly. The worker hostname is available at sender.request.hostname, not sender.hostname. The hasattr(sender, 'hostname') check will return False, always falling through to the fallback logic which calls inspect.active(). This fallback makes a network call to query all workers and picks the first one from the dictionary, which may not be the current worker, potentially canceling the consumer on a different worker.

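A small helper reflecting that attribute path (sketch only; the helper name is illustrative):

from typing import Optional

def _worker_hostname_from_task(sender) -> Optional[str]:
    """Node name of the worker executing this task (e.g. "celery@worker-0"), if available."""
    request = getattr(sender, "request", None)
    return getattr(request, "hostname", None)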


if worker_hostname and queue_name:
celery_logger.info("Canceling consumer for queue '%s' on worker '%s'", queue_name, worker_hostname)
celery.control.cancel_consumer(queue_name, destination=[worker_hostname])
else:
celery_logger.warning("Could not determine worker hostname or queue name for consumer cancellation")
except Exception as exc:
celery_logger.warning("Failed to cancel consumer during shutdown: %s", exc)

shutdown_thread = threading.Thread(
target=_request_worker_shutdown,
args=(delay, _task_execution_count),
daemon=True,
)
shutdown_thread.start()


def _monitor_idle_timeout():
"""Monitor worker idle time and shutdown if idle timeout exceeded."""
idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0)
if idle_timeout_seconds <= 0:
return # Idle timeout not enabled

check_interval = min(30, idle_timeout_seconds / 4) # Check every 30s or 1/4 of timeout, whichever is smaller

global _last_task_time, _worker_start_time, _shutdown_requested

while not _shutdown_requested:
time.sleep(check_interval)

if _shutdown_requested:
break

current_time = time.time()

# If we've processed at least one task, use last task time
# Otherwise, use worker start time
if _last_task_time is not None:
idle_duration = current_time - _last_task_time
reference_time = _last_task_time
elif _worker_start_time is not None:
idle_duration = current_time - _worker_start_time
reference_time = _worker_start_time
else:
continue # Haven't started yet

if idle_duration >= idle_timeout_seconds:
celery_logger.info(
"Idle timeout exceeded: worker has been idle for %.1f seconds (timeout: %d seconds). Shutting down.",
idle_duration,
idle_timeout_seconds,
)
_shutdown_requested = True
# Cancel consumer to stop accepting new tasks
try:
queue_name = celery.conf.get("task_default_queue") or celery.conf.get("task_routes", {}).get("*", {}).get("queue", "celery")
from celery import current_app
inspect = current_app.control.inspect()
active_workers = inspect.active() if inspect else {}
if active_workers:
worker_hostname = list(active_workers.keys())[0]
celery_logger.info("Canceling consumer for queue '%s' on worker '%s'", queue_name, worker_hostname)
celery.control.cancel_consumer(queue_name, destination=[worker_hostname])
except Exception as exc:
celery_logger.warning("Failed to cancel consumer during idle timeout shutdown: %s", exc)

# Shutdown after a short delay
delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2)
shutdown_thread = threading.Thread(
target=_request_worker_shutdown,
args=(delay, 0), # 0 tasks since we're shutting down due to idle timeout
daemon=True,
)
shutdown_thread.start()
break


def _worker_ready_handler(sender=None, **kwargs):
"""Handle worker ready signal - start idle timeout monitor if enabled."""
global _worker_start_time, _shutdown_requested
_worker_start_time = time.time()

idle_timeout_seconds = celery.conf.get("worker_idle_timeout_seconds", 0)
if idle_timeout_seconds > 0:
celery_logger.info(
"Worker idle timeout enabled: %d seconds. Worker will shutdown if idle for this duration.",
idle_timeout_seconds,
)
monitor_thread = threading.Thread(
target=_monitor_idle_timeout,
daemon=True,
)
monitor_thread.start()


signals.task_postrun.connect(_auto_shutdown_handler, weak=False)
signals.worker_ready.connect(_worker_ready_handler, weak=False)


def create_celery(app=None):
celery.conf.broker_url = app.config["CELERY_BROKER_URL"]
celery.conf.result_backend = app.config["CELERY_RESULT_BACKEND"]
@@ -32,6 +221,32 @@ def create_celery(app=None):
celery.conf.result_backend_transport_options = {
"master_name": app.config["MASTER_NAME"]
}

celery.conf.worker_autoshutdown_enabled = app.config.get(
"CELERY_WORKER_AUTOSHUTDOWN_ENABLED", False
)
celery.conf.worker_autoshutdown_max_tasks = app.config.get(
"CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS", 1
)
celery.conf.worker_autoshutdown_delay_seconds = app.config.get(
"CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS", 2
)
celery.conf.worker_idle_timeout_seconds = app.config.get(
"CELERY_WORKER_IDLE_TIMEOUT_SECONDS", 0
)

if celery.conf.worker_autoshutdown_enabled:
celery_logger.info(
"Worker auto-shutdown enabled: max_tasks=%s delay=%ss",
celery.conf.worker_autoshutdown_max_tasks,
celery.conf.worker_autoshutdown_delay_seconds,
)

if celery.conf.worker_idle_timeout_seconds > 0:
celery_logger.info(
"Worker idle timeout enabled: %s seconds",
celery.conf.worker_idle_timeout_seconds,
)
# Configure Celery and related loggers
log_level = app.config["LOGGING_LEVEL"]
celery_logger.setLevel(log_level)
@@ -102,6 +317,15 @@ def __call__(self, *args, **kwargs):
if os.environ.get("SLACK_WEBHOOK"):
celery.Task.on_failure = post_to_slack_on_task_failure

# Manually trigger setup_periodic_tasks to ensure it runs with the correct configuration
# The signal handler may have fired before beat_schedules was set, so we call it explicitly here
try:
setup_periodic_tasks(celery)
celery_logger.info("Manually triggered setup_periodic_tasks after create_celery")
except Exception as e:
celery_logger.warning(f"Error manually triggering setup_periodic_tasks: {e}")
# Don't fail if this doesn't work - the signal handler should still work

return celery


@@ -137,13 +361,25 @@ def days_till_next_month(date):

@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
# remove expired task results in redis broker
sender.add_periodic_task(
crontab(hour=0, minute=0, day_of_week="*", day_of_month="*", month_of_year="*"),
add_backend_cleanup_task(celery),
name="Clean up back end results",
)

# Check if tasks are already registered to avoid duplicates
existing_schedule = sender.conf.get("beat_schedule", {})
cleanup_task_name = "Clean up back end results"

# Check if cleanup task is already registered
cleanup_already_registered = cleanup_task_name in existing_schedule if existing_schedule else False

# Add cleanup task only if not already registered
if not cleanup_already_registered:
# remove expired task results in redis broker
sender.add_periodic_task(
crontab(hour=0, minute=0, day_of_week="*", day_of_month="*", month_of_year="*"),
add_backend_cleanup_task(celery),
name=cleanup_task_name,
)
celery_logger.debug(f"Added cleanup task: {cleanup_task_name}")
else:
celery_logger.debug(f"Cleanup task '{cleanup_task_name}' already registered, skipping.")

# Try to get beat_schedules from celery.conf, fallback to BEAT_SCHEDULES if not found
beat_schedules = celery.conf.get("beat_schedules")
if not beat_schedules:
@@ -157,8 +393,9 @@ def setup_periodic_tasks(sender, **kwargs):
celery_logger.debug(f"beat_schedules type: {type(beat_schedules)}, length: {len(beat_schedules) if isinstance(beat_schedules, (list, dict)) else 'N/A'}")

if not beat_schedules:
celery_logger.info("No periodic tasks configured.")
celery_logger.debug("No periodic tasks configured yet (beat_schedules empty). Will retry after create_celery.")
return

try:
schedules = CeleryBeatSchema(many=True).load(beat_schedules)
except ValidationError as validation_error:
@@ -168,17 +405,24 @@
min_databases = sender.conf.get("MIN_DATABASES")
celery_logger.info(f"MIN_DATABASES: {min_databases}")
for schedule in schedules:
task_name = schedule.get("name")

# Check if this specific task is already registered
if existing_schedule and task_name and task_name in existing_schedule:
celery_logger.debug(f"Task '{task_name}' already registered, skipping duplicate registration.")
continue

try:
task = configure_task(schedule, min_databases)
sender.add_periodic_task(
create_crontab(schedule),
task,
name=schedule["name"],
name=task_name,
)
celery_logger.info(f"Added task: {schedule['name']}")
celery_logger.info(f"Added task: {task_name}")
except ConfigurationError as e:
celery_logger.error(
f"Error configuring task '{schedule.get('name', 'Unknown')}': {str(e)}"
f"Error configuring task '{task_name or 'Unknown'}': {str(e)}"
)


10 changes: 7 additions & 3 deletions materializationengine/config.py
@@ -48,9 +48,13 @@ class BaseConfig:
MERGE_TABLES = True
AUTH_SERVICE_NAMESPACE = "datastack"

REDIS_HOST="localhost"
REDIS_PORT=6379
REDIS_PASSWORD=""
CELERY_WORKER_AUTOSHUTDOWN_ENABLED = False
CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS = 1
CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS = 2

REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_PASSWORD = ""
SESSION_TYPE = "redis"
PERMANENT_SESSION_LIFETIME = timedelta(hours=24)
SESSION_PREFIX = "annotation_upload_"