From ec79e47ed7e04d29e872e12c6420049adc488eaa Mon Sep 17 00:00:00 2001
From: nadzhou
Date: Fri, 5 Dec 2025 18:01:16 -0800
Subject: [PATCH 01/23] Update __init__.py, condition.py, scheduler.py

---
 distributed/__init__.py  |   1 +
 distributed/condition.py | 197 ++++
 distributed/scheduler.py | 826 ++++++++++-----------
 3 files changed, 400 insertions(+), 624 deletions(-)
 create mode 100644 distributed/condition.py

diff --git a/distributed/__init__.py b/distributed/__init__.py
index 3f075b977c..091c14e0eb 100644
--- a/distributed/__init__.py
+++ b/distributed/__init__.py
@@ -146,3 +146,4 @@
     "widgets",
     "worker_client",
 ]
+from distributed.condition import Condition
diff --git a/distributed/condition.py b/distributed/condition.py
new file mode 100644
index 0000000000..ad31630815
--- /dev/null
+++ b/distributed/condition.py
@@ -0,0 +1,197 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import uuid
+from collections import defaultdict
+from contextlib import suppress
+
+from distributed.utils import SyncMethodMixin, log_errors
+from distributed.worker import get_client
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionExtension:
+    """Scheduler extension for managing distributed Conditions"""
+
+    def __init__(self, scheduler):
+        self.scheduler = scheduler
+        # {condition_name: asyncio.Condition}
+        self._conditions = {}
+        # {condition_name: set of waiter_ids}
+        self._waiters = defaultdict(set)
+
+        self.scheduler.handlers.update(
+            {
+                "condition_wait": self.wait,
+                "condition_notify": self.notify,
+                "condition_notify_all": self.notify_all,
+                "condition_acquire": self.acquire,
+                "condition_release": self.release,
+            }
+        )
+
+    def _get_condition(self, name):
+        if name not in self._conditions:
+            self._conditions[name] = asyncio.Condition()
+        return self._conditions[name]
+
+    @log_errors
+    async def acquire(self, name=None, id=None):
+        """Acquire the underlying lock"""
+        condition = self._get_condition(name)
+        await condition.acquire()
+        return True
+
+    @log_errors
+    async def release(self, name=None, id=None):
+        """Release the underlying lock"""
+        if name not in self._conditions:
+            return False
+        condition = self._conditions[name]
+        condition.release()
+        return True
+
+    @log_errors
+    async def wait(self, name=None, id=None, timeout=None):
+        """Wait on the condition"""
+        condition = self._get_condition(name)
+        self._waiters[name].add(id)
+
+        try:
+            if timeout is not None:
+                await asyncio.wait_for(condition.wait(), timeout=timeout)
+            else:
+                await condition.wait()
+            return True
+        except asyncio.TimeoutError:
+            return False
+        finally:
+            self._waiters[name].discard(id)
+            # Clean up once there are no waiters left
+            if not self._waiters[name]:
+                with suppress(KeyError):
+                    del self._waiters[name]
+                with suppress(KeyError):
+                    del self._conditions[name]
+
+    @log_errors
+    def notify(self, name=None, n=1):
+        """Notify n waiters"""
+        if name not in self._conditions:
+            return 0
+        condition = self._conditions[name]
+        condition.notify(n=n)
+        return min(n, len(self._waiters.get(name, [])))
+
+    @log_errors
+    def notify_all(self, name=None):
+        """Notify all waiters"""
+        if name not in self._conditions:
+            return 0
+        condition = self._conditions[name]
+        count = len(self._waiters.get(name, []))
+        condition.notify_all()
+        return count
+
+
+class Condition(SyncMethodMixin):
+    """Distributed Condition Variable
+
+    Mimics the asyncio.Condition API. Allows coordination between
+    distributed workers and clients using a wait/notify pattern.
+
+    Examples
+    --------
+    >>> from distributed import Condition
+    >>> condition = Condition('my-condition')
+    >>> async with condition:
+    ...     await condition.wait()  # Wait for notification
+
+    >>> # In another worker/client
+    >>> condition = Condition('my-condition')
+    >>> async with condition:
+    ...     await condition.notify()  # Wake one waiter
+    """
+
+    def __init__(self, name=None, scheduler_rpc=None, loop=None):
+        self._scheduler = scheduler_rpc
+        self._loop = loop
+        self.name = name or f"condition-{uuid.uuid4().hex}"
+        self.id = uuid.uuid4().hex
+        self._locked = False
+
+    def _get_scheduler_rpc(self):
+        if self._scheduler:
+            return self._scheduler
+        try:
+            client = get_client()
+            return client.scheduler
+        except ValueError:
+            from distributed.worker import get_worker
+
+            worker = get_worker()
+            return worker.scheduler
+
+    async def acquire(self):
+        """Acquire the underlying lock"""
+        scheduler = self._get_scheduler_rpc()
+        result = await scheduler.condition_acquire(name=self.name, id=self.id)
+        self._locked = result
+        return result
+
+    async def release(self):
+        """Release the underlying lock"""
+        if not self._locked:
+            raise RuntimeError("Cannot release un-acquired lock")
+        scheduler = self._get_scheduler_rpc()
+        await scheduler.condition_release(name=self.name, id=self.id)
+        self._locked = False
+
+    async def wait(self, timeout=None):
+        """Wait until notified
+
+        Must be called while the lock is held. Releases the lock and waits
+        for notify(), then reacquires the lock before returning.
+        """
+        if not self._locked:
+            raise RuntimeError("Cannot wait on un-acquired condition")
+
+        scheduler = self._get_scheduler_rpc()
+        result = await scheduler.condition_wait(name=self.name, id=self.id, timeout=timeout)
+        return result
+
+    async def notify(self, n=1):
+        """Wake up one or more waiters"""
+        if not self._locked:
+            raise RuntimeError("Cannot notify on un-acquired condition")
+        scheduler = self._get_scheduler_rpc()
+        return await scheduler.condition_notify(name=self.name, n=n)
+
+    async def notify_all(self):
+        """Wake up all waiters"""
+        if not self._locked:
+            raise RuntimeError("Cannot notify on un-acquired condition")
+        scheduler = self._get_scheduler_rpc()
+        return await scheduler.condition_notify_all(name=self.name)
+
+    def locked(self):
+        """Return True if the lock is held"""
+        return self._locked
+
+    async def __aenter__(self):
+        await self.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.release()
+
+    def __enter__(self):
+        return self.sync(self.__aenter__)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.sync(self.__aexit__, exc_type, exc_val, exc_tb)
+
+    def __repr__(self):
+        return f"<Condition: {self.name}>"
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index ea5775aea6..e9b4bec324 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -144,6 +144,7 @@
     scatter_to_workers,
 )
 from distributed.variable import VariableExtension
+from distributed.condition import ConditionExtension

 if TYPE_CHECKING:
     from typing import TypeAlias, TypeVar
@@ -181,9 +182,7 @@
 logger = logging.getLogger(__name__)
 LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
-DEFAULT_DATA_SIZE = parse_bytes(
-    dask.config.get("distributed.scheduler.default-data-size")
-)
+DEFAULT_DATA_SIZE = parse_bytes(dask.config.get("distributed.scheduler.default-data-size"))

 STIMULUS_ID_UNSET = "<stimulus_id unset>"
 DEFAULT_EXTENSIONS = {
@@ -194,6 +193,7 @@
     "variables": VariableExtension,
     "semaphores": SemaphoreExtension,
     "events": EventExtension,
+    "conditions": ConditionExtension,
    "amm": 
ActiveMemoryManagerExtension, "memory_sampler": MemorySamplerExtension, "shuffle": ShuffleSchedulerPlugin, @@ -407,8 +407,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: return { k: getattr(self, k) for k in dir(self) - if not k.startswith("_") - and k not in {"sum", "managed_in_memory", "managed_spilled"} + if not k.startswith("_") and k not in {"sum", "managed_in_memory", "managed_spilled"} } @@ -580,9 +579,7 @@ def __hash__(self) -> int: return self._hash def __eq__(self, other: object) -> bool: - return self is other or ( - isinstance(other, WorkerState) and other.server_id == self.server_id - ) + return self is other or (isinstance(other, WorkerState) and other.server_id == self.server_id) @property def has_what(self) -> Set[TaskState]: @@ -833,9 +830,7 @@ def _dec_needs_replica(self, ts: TaskState) -> None: nbytes = ts.get_nbytes() # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min( - nbytes, self.scheduler._network_occ_global - ) + self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) def add_replica(self, ts: TaskState) -> None: """The worker acquired a replica of task""" @@ -848,18 +843,14 @@ def add_replica(self, ts: TaskState) -> None: del self.needs_what[ts] # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min( - nbytes, self.scheduler._network_occ_global - ) + self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) ts.who_has.add(self) self.nbytes += nbytes self._has_what[ts] = None @property def occupancy(self) -> float: - return self._occupancy_cache or self.scheduler._calc_occupancy( - self.task_prefix_count, self._network_occ - ) + return self._occupancy_cache or self.scheduler._calc_occupancy(self.task_prefix_count, self._network_occ) @dataclasses.dataclass @@ -921,9 +912,7 @@ def __repr__(self) -> str: return ( f"" ) @@ -981,10 +970,7 @@ def all_durations(self) -> defaultdict[str, float]: """Cumulative duration of all completed actions of tasks belonging to this collection, by action""" return defaultdict( float, - { - action: duration_us / 1e6 - for action, duration_us in self._all_durations_us.items() - }, + {action: duration_us / 1e6 for action, duration_us in self._all_durations_us.items()}, ) @property @@ -1089,13 +1075,7 @@ def active_states(self) -> dict[TaskStateState, int]: def __repr__(self) -> str: return ( - "<" - + self.name - + ": " - + ", ".join( - "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v - ) - + ">" + "<" + self.name + ": " + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + ">" ) @@ -1185,9 +1165,7 @@ def __repr__(self) -> str: "<" + (self.name or "no-group") + ": " - + ", ".join( - "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v - ) + + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + ">" ) @@ -1215,8 +1193,7 @@ def done(self) -> bool: recomputed. 
""" return all( - count == 0 or state in {"memory", "erred", "released", "forgotten"} - for state, count in self.states.items() + count == 0 or state in {"memory", "erred", "released", "forgotten"} for state, count in self.states.items() ) @@ -1775,15 +1752,9 @@ def __init__( self.resources = resources self.saturated = set() self.tasks = tasks - self.replicated_tasks = { - ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1 - } - self.computations = deque( - maxlen=dask.config.get("distributed.diagnostics.computations.max-history") - ) - self.erred_tasks = deque( - maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history") - ) + self.replicated_tasks = {ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1} + self.computations = deque(maxlen=dask.config.get("distributed.diagnostics.computations.max-history")) + self.erred_tasks = deque(maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history")) self.task_groups = {} self.task_prefixes = {} self.task_metadata = {} @@ -1796,61 +1767,38 @@ def __init__( self.workers = workers self._task_prefix_count_global = defaultdict(int) self._network_occ_global = 0 - self.running = { - ws for ws in self.workers.values() if ws.status == Status.running - } + self.running = {ws for ws in self.workers.values() if ws.status == Status.running} self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} - self.transition_log = deque( - maxlen=dask.config.get("distributed.admin.low-level-log-length") - ) + self.transition_log = deque(maxlen=dask.config.get("distributed.admin.low-level-log-length")) self.transition_counter = 0 self._idle_transition_counter = 0 self.transition_counter_max = transition_counter_max # Variables from dask.config, cached by __init__ for performance - self.UNKNOWN_TASK_DURATION = parse_timedelta( - dask.config.get("distributed.scheduler.unknown-task-duration") - ) + self.UNKNOWN_TASK_DURATION = parse_timedelta(dask.config.get("distributed.scheduler.unknown-task-duration")) self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( dask.config.get("distributed.worker.memory.recent-to-old-time") ) - self.MEMORY_REBALANCE_MEASURE = dask.config.get( - "distributed.worker.memory.rebalance.measure" - ) - self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get( - "distributed.worker.memory.rebalance.sender-min" - ) - self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( - "distributed.worker.memory.rebalance.recipient-max" - ) + self.MEMORY_REBALANCE_MEASURE = dask.config.get("distributed.worker.memory.rebalance.measure") + self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get("distributed.worker.memory.rebalance.sender-min") + self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get("distributed.worker.memory.rebalance.recipient-max") self.MEMORY_REBALANCE_HALF_GAP = ( - dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") - / 2.0 + dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") / 2.0 ) - self.WORKER_SATURATION = dask.config.get( - "distributed.scheduler.worker-saturation" - ) + self.WORKER_SATURATION = dask.config.get("distributed.scheduler.worker-saturation") if self.WORKER_SATURATION == "inf": # Special case necessary because there's no way to parse a float infinity # from a DASK_* environment variable self.WORKER_SATURATION = math.inf - if ( - not isinstance(self.WORKER_SATURATION, (int, float)) - or self.WORKER_SATURATION <= 0 - ): + if not isinstance(self.WORKER_SATURATION, (int, float)) or self.WORKER_SATURATION <= 0: raise ValueError( 
# pragma: nocover - "`distributed.scheduler.worker-saturation` must be a float > 0; got " - + repr(self.WORKER_SATURATION) + "`distributed.scheduler.worker-saturation` must be a float > 0; got " + repr(self.WORKER_SATURATION) ) - self.rootish_tg_threshold = dask.config.get( - "distributed.scheduler.rootish-taskgroup" - ) - self.rootish_tg_dependencies_threshold = dask.config.get( - "distributed.scheduler.rootish-taskgroup-dependencies" - ) + self.rootish_tg_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup") + self.rootish_tg_dependencies_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup-dependencies") @abstractmethod def log_event(self, topic: str | Collection[str], msg: Any) -> None: ... @@ -1984,9 +1932,7 @@ def _calc_occupancy( # State Transitions # ##################### - def _transition( - self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any - ) -> RecsMsgs: + def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any) -> RecsMsgs: """Transition a key from its current state to the finish state Examples @@ -2032,15 +1978,11 @@ def _transition( func = self._TRANSITIONS_TABLE.get((start, finish)) if func is not None: - recommendations, client_msgs, worker_msgs = func( - self, key, stimulus_id, **kwargs - ) + recommendations, client_msgs, worker_msgs = func(self, key, stimulus_id, **kwargs) elif "released" not in (start, finish): assert not kwargs, (kwargs, start, finish) - a_recs, a_cmsgs, a_wmsgs = self._transition( - key, "released", stimulus_id - ) + a_recs, a_cmsgs, a_wmsgs = self._transition(key, "released", stimulus_id) v = a_recs.get(key, finish) # The inner rec has higher priority? Is that always desired? @@ -2070,16 +2012,10 @@ def _transition( stimulus_id = STIMULUS_ID_UNSET actual_finish = ts._state - self.transition_log.append( - Transition( - key, start, actual_finish, recommendations, stimulus_id, time() - ) - ) + self.transition_log.append(Transition(key, start, actual_finish, recommendations, stimulus_id, time())) if self.validate: if stimulus_id == STIMULUS_ID_UNSET: - raise RuntimeError( - "stimulus_id not set during Scheduler transition" - ) + raise RuntimeError("stimulus_id not set during Scheduler transition") logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -2096,9 +2032,7 @@ def _transition( self.tasks[ts.key] = ts for plugin in list(self.plugins.values()): try: - plugin.transition( - key, start, actual_finish, stimulus_id=stimulus_id, **kwargs - ) + plugin.transition(key, start, actual_finish, stimulus_id=stimulus_id, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts.state == "forgotten": @@ -2282,9 +2216,7 @@ def _transition_queued_erred( traceback_text=traceback_text, ) - def decide_worker_rootish_queuing_disabled( - self, ts: TaskState - ) -> WorkerState | None: + def decide_worker_rootish_queuing_disabled(self, ts: TaskState) -> WorkerState | None: """Pick a worker for a runnable root-ish task, without queuing. 
This attempts to schedule sibling tasks on the same worker, reducing future data @@ -2315,25 +2247,16 @@ def decide_worker_rootish_queuing_disabled( tg = ts.group lws = tg.last_worker - if ( - lws - and tg.last_worker_tasks_left - and lws.status == Status.running - and self.workers.get(lws.address) is lws - ): + if lws and tg.last_worker_tasks_left and lws.status == Status.running and self.workers.get(lws.address) is lws: ws = lws else: # Last-used worker is full, unknown, retiring, or paused; # pick a new worker for the next few tasks ws = min(pool, key=partial(self.worker_objective, ts)) - tg.last_worker_tasks_left = math.floor( - (len(tg) / self.total_nthreads) * ws.nthreads - ) + tg.last_worker_tasks_left = math.floor((len(tg) / self.total_nthreads) * ws.nthreads) # Record `last_worker`, or clear it on the final task - tg.last_worker = ( - ws if tg.states["released"] + tg.states["waiting"] > 1 else None - ) + tg.last_worker = ws if tg.states["released"] + tg.states["waiting"] > 1 else None tg.last_worker_tasks_left -= 1 if self.validate and ws is not None: @@ -2574,9 +2497,7 @@ def _transition_processing_memory( recommendations: Recs = {} client_msgs: Msgs = {} - self._add_to_memory( - ts, ws, recommendations, client_msgs, type=type, typename=typename - ) + self._add_to_memory(ts, ws, recommendations, client_msgs, type=type, typename=typename) if self.validate: assert not ts.processing_on @@ -2596,9 +2517,7 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: ws.actors.discard(ts) if ts.who_wants: ts.exception_blame = ts - ts.exception = Serialized( - *serialize(RuntimeError("Worker holding Actor was lost")) - ) + ts.exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) return {ts.key: "erred"}, {}, {} # don't try to recreate recommendations: Recs = {} @@ -2625,9 +2544,7 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: recommendations[key] = "forgotten" elif ts.has_lost_dependencies: recommendations[key] = "forgotten" - elif (ts.who_wants or ts.waiters) and not any( - dts.state == "erred" for dts in ts.dependencies - ): + elif (ts.who_wants or ts.waiters) and not any(dts.state == "erred" for dts in ts.dependencies): recommendations[key] = "waiting" for dts in ts.waiters or (): @@ -3007,9 +2924,7 @@ def _transition_memory_erred(self, key: Key, stimulus_id: str) -> RecsMsgs: if not dts.who_has: dts.exception_blame = ts recommendations[dts.key] = "erred" - exception = Serialized( - *serialize(RuntimeError("Worker holding Actor was lost")) - ) + exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) report_msg = { "op": "task-erred", "key": key, @@ -3122,14 +3037,9 @@ def _transition_released_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs ("released", "erred"): _transition_released_erred, } - def story( - self, *keys_or_tasks_or_stimuli: Key | TaskState | str - ) -> list[Transition]: + def story(self, *keys_or_tasks_or_stimuli: Key | TaskState | str) -> list[Transition]: """Get all transitions that touch one of the input keys or stimulus_id's""" - keys_or_stimuli = { - key.key if isinstance(key, TaskState) else key - for key in keys_or_tasks_or_stimuli - } + keys_or_stimuli = {key.key if isinstance(key, TaskState) else key for key in keys_or_tasks_or_stimuli} return scheduler_story(keys_or_stimuli, self.transition_log) ############################## @@ -3205,14 +3115,9 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None: else: 
self.idle_task_count.discard(ws) - def is_unoccupied( - self, ws: WorkerState, occupancy: float, nprocessing: int - ) -> bool: + def is_unoccupied(self, ws: WorkerState, occupancy: float, nprocessing: int) -> bool: nthreads = ws.nthreads - return ( - nprocessing < nthreads - or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 - ) + return nprocessing < nthreads or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: """ @@ -3388,9 +3293,7 @@ def _validate_ready(self, ts: TaskState) -> None: assert ts not in self.queued assert all(dts.who_has for dts in ts.dependencies) - def _add_to_processing( - self, ts: TaskState, ws: WorkerState, stimulus_id: str - ) -> RecsMsgs: + def _add_to_processing(self, ts: TaskState, ws: WorkerState, stimulus_id: str) -> RecsMsgs: """Set a task as processing on a worker and return the worker messages to send""" if self.validate: self._validate_ready(ts) @@ -3408,11 +3311,7 @@ def _add_to_processing( ws.actors.add(ts) ndep_bytes = sum(dts.nbytes for dts in ts.dependencies) - if ( - ws.memory_limit - and ndep_bytes > ws.memory_limit - and dask.config.get("distributed.worker.memory.terminate") - ): + if ws.memory_limit and ndep_bytes > ws.memory_limit and dask.config.get("distributed.worker.memory.terminate"): # Note # ---- # This is a crude safety system, only meant to prevent order-of-magnitude @@ -3620,10 +3519,7 @@ def _task_to_msg(self, ts: TaskState) -> dict[str, Any]: "run_id": ts.run_id, "priority": ts.priority, "stimulus_id": f"compute-task-{time()}", - "who_has": { - dts.key: tuple(ws.address for ws in (dts.who_has or ())) - for dts in ts.dependencies - }, + "who_has": {dts.key: tuple(ws.address for ws in (dts.who_has or ())) for dts in ts.dependencies}, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, "run_spec": ToPickle(ts.run_spec), "resource_restrictions": ts.resource_restrictions, @@ -3795,16 +3691,10 @@ def __init__( self.services = {} self.scheduler_file = scheduler_file - self.worker_ttl = parse_timedelta( - worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") - ) - self.idle_timeout = parse_timedelta( - idle_timeout or dask.config.get("distributed.scheduler.idle-timeout") - ) + self.worker_ttl = parse_timedelta(worker_ttl or dask.config.get("distributed.scheduler.worker-ttl")) + self.idle_timeout = parse_timedelta(idle_timeout or dask.config.get("distributed.scheduler.idle-timeout")) self.idle_since = time() - self.no_workers_timeout = parse_timedelta( - dask.config.get("distributed.scheduler.no-workers-timeout") - ) + self.no_workers_timeout = parse_timedelta(dask.config.get("distributed.scheduler.no-workers-timeout")) self._no_workers_since = None self.time_started = self.idle_since # compatibility for dask-gateway @@ -3852,24 +3742,17 @@ def __init__( except ImportError: show_dashboard = False http_server_modules.append("distributed.http.scheduler.missing_bokeh") - routes = get_handlers( - server=self, modules=http_server_modules, prefix=http_prefix - ) + routes = get_handlers(server=self, modules=http_server_modules, prefix=http_prefix) self.start_http_server(routes, dashboard_address, default_port=8787) self.jupyter = jupyter if show_dashboard: - distributed.dashboard.scheduler.connect( - self.http_application, self.http_server, self, prefix=http_prefix - ) + distributed.dashboard.scheduler.connect(self.http_application, self.http_server, self, prefix=http_prefix) scheduler = self if self.jupyter: try: from 
jupyter_server.serverapp import ServerApp except ImportError: - raise ImportError( - "In order to use the Dask jupyter option you " - "need to have jupyterlab installed" - ) + raise ImportError("In order to use the Dask jupyter option you need to have jupyterlab installed") from traitlets.config import Config """HTTP handler to shut down the Jupyter server. @@ -3917,9 +3800,7 @@ async def post(self) -> None: argv=[], ) self._jupyter_server_application = j - shutdown_app = tornado.web.Application( - [(r"/jupyter/api/shutdown", ShutdownHandler)] - ) + shutdown_app = tornado.web.Application([(r"/jupyter/api/shutdown", ShutdownHandler)]) shutdown_app.settings = j.web_app.settings self.http_application.add_application(shutdown_app) self.http_application.add_application(j.web_app) @@ -4139,8 +4020,7 @@ def identity(self, n_workers: int = -1) -> dict[str, Any]: "total_threads": self.total_nthreads, "total_memory": self.total_memory, "workers": { - worker.address: worker.identity() - for worker in itertools.islice(self.workers.values(), n_workers) + worker.address: worker.identity() for worker in itertools.islice(self.workers.values(), n_workers) }, } return d @@ -4198,10 +4078,7 @@ async def get_cluster_state( workers_future.cancel() # Convert any RPC errors to strings - worker_states = { - k: repr(v) if isinstance(v, Exception) else v - for k, v in worker_states.items() - } + worker_states = {k: repr(v) if isinstance(v, Exception) else v for k, v in worker_states.items()} return { "scheduler": scheduler_state, @@ -4217,9 +4094,7 @@ async def dump_cluster_state_to_url( **storage_options: dict[str, Any], ) -> None: "Write a cluster state dump to an fsspec-compatible URL." - await cluster_dump.write_state( - partial(self.get_cluster_state, exclude), url, format, **storage_options - ) + await cluster_dump.write_state(partial(self.get_cluster_state, exclude), url, format, **storage_options) def get_worker_service_addr( self, worker: str, service_name: str, protocol: bool = False @@ -4287,9 +4162,7 @@ async def start_unsafe(self) -> Self: # formatting dashboard link can fail if distributed.dashboard.link # refers to non-existent env vars. 
except KeyError as e: - logger.warning( - f"Failed to format dashboard link, unknown value: {e}" - ) + logger.warning(f"Failed to format dashboard link, unknown value: {e}") link = f":{server.port}" else: link = f"{listen_ip}:{server.port}" @@ -4315,9 +4188,7 @@ def del_scheduler_file() -> None: await self.listen("tcp://localhost:0") os.environ["DASK_SCHEDULER_ADDRESS"] = self.listeners[-1].contact_address - await asyncio.gather( - *[plugin.start(self) for plugin in list(self.plugins.values())] - ) + await asyncio.gather(*[plugin.start(self) for plugin in list(self.plugins.values())]) self.start_periodic_callbacks() @@ -4349,15 +4220,11 @@ async def log_errors(func: Callable) -> None: except Exception: logger.exception("Plugin call failed during scheduler.close") - await asyncio.gather( - *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())] - ) + await asyncio.gather(*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]) await self.preloads.teardown() - await asyncio.gather( - *[log_errors(plugin.close) for plugin in list(self.plugins.values())] - ) + await asyncio.gather(*[log_errors(plugin.close) for plugin in list(self.plugins.values())]) for pc in self.periodic_callbacks.values(): pc.stop() @@ -4430,25 +4297,21 @@ def heartbeat_worker( dh["last-seen"] = local_now frac = 1 / len(self.workers) - self.bandwidth = ( - self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac - ) + self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: self.bandwidth_workers[address, other] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_workers[address, other] = self.bandwidth_workers[ - address, other - ] * alpha + bw * (1 - alpha) + self.bandwidth_workers[address, other] = self.bandwidth_workers[address, other] * alpha + bw * ( + 1 - alpha + ) for typ, (bw, count) in metrics["bandwidth"]["types"].items(): if typ not in self.bandwidth_types: self.bandwidth_types[typ] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( - 1 - alpha - ) + self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * (1 - alpha) ws.last_seen = local_now if executing is not None: @@ -4479,9 +4342,7 @@ def heartbeat_worker( # ws._nbytes is updated at a different time and sizeof() may not be accurate, # so size may be (temporarily) negative; floor it to zero. - size = max( - 0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"] - ) + size = max(0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"]) ws._memory_unmanaged_history.append((local_now, size)) if not memory_unmanaged_old: @@ -4626,9 +4487,7 @@ async def add_worker( logger.exception(exc, exc_info=exc) if ws.status == Status.running: - self.transitions( - self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id - ) + self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) logger.info("Register worker addr: %s name: %s", ws.address, ws.name) @@ -4812,9 +4671,7 @@ def _create_taskstate_from_graph( # _generate_taskstates is not the only thing that calls new_task(). A # TaskState may have also been created by client_desires_keys or scatter, # and only later gained a run_spec. 
- span_annotations = spans_ext.observe_tasks( - touched_tasks, span_metadata=span_metadata, code=code - ) + span_annotations = spans_ext.observe_tasks(touched_tasks, span_metadata=span_metadata, code=code) # In case of TaskGroup collision, spans may have changed # FIXME: Is this used anywhere besides tests? if span_annotations: @@ -4920,9 +4777,7 @@ async def update_graph( }, client=client, ) - self.client_releases_keys( - keys=keys, client=client, stimulus_id=stimulus_id - ) + self.client_releases_keys(keys=keys, client=client, stimulus_id=stimulus_id) evt_msg = { "action": "update-graph", "stimulus_id": stimulus_id, @@ -4955,8 +4810,7 @@ async def update_graph( "start_timestamp_seconds": start, "materialization_duration_seconds": materialization_done - start, "ordering_duration_seconds": materialization_done - ordering_done, - "state_initialization_duration_seconds": ordering_done - - task_state_created, + "state_initialization_duration_seconds": ordering_done - task_state_created, "duration_seconds": task_state_created - start, } ) @@ -5209,9 +5063,7 @@ def _set_priorities( ) if self.validate and istask(ts.run_spec): - assert isinstance(ts.priority, tuple) and all( - isinstance(el, (int, float)) for el in ts.priority - ) + assert isinstance(ts.priority, tuple) and all(isinstance(el, (int, float)) for el in ts.priority) def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened spots on worker threadpools @@ -5230,10 +5082,7 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """ if not self.queued: return - slots_available = sum( - _task_slots_available(ws, self.WORKER_SATURATION) - for ws in self.idle_task_count - ) + slots_available = sum(_task_slots_available(ws, self.WORKER_SATURATION) for ws in self.idle_task_count) if slots_available == 0: return @@ -5255,9 +5104,7 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: assert qts.state == "processing" assert not self.queued or self.queued.peek() != qts - def stimulus_task_finished( - self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any - ) -> RecsMsgs: + def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any) -> RecsMsgs: """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s[%d] %s", key, run_id, worker) @@ -5268,8 +5115,7 @@ def stimulus_task_finished( ts = self.tasks.get(key) if ts is None or ts.state in ("released", "queued", "no-worker"): logger.debug( - "Received already computed task, worker: %s, state: %s" - ", key: %s, who_has: %s", + "Received already computed task, worker: %s, state: %s, key: %s, who_has: %s", worker, ts.state if ts else "forgotten", key, @@ -5284,7 +5130,7 @@ def stimulus_task_finished( ] elif ts.state == "erred": logger.debug( - "Received already erred task, worker: %s" ", key: %s", + "Received already erred task, worker: %s, key: %s", worker, key, ) @@ -5361,9 +5207,7 @@ def stimulus_task_erred( **kwargs, ) - def stimulus_retry( - self, keys: Collection[Key], client: str | None = None - ) -> tuple[Key, ...]: + def stimulus_retry(self, keys: Collection[Key], client: str | None = None) -> tuple[Key, ...]: logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -5441,14 +5285,10 @@ async def remove_worker( ws = self.workers[address] - logger.info( - f"Remove worker addr: {ws.address} name: 
{ws.name} ({stimulus_id=})" - ) + logger.info(f"Remove worker addr: {ws.address} name: {ws.name} ({stimulus_id=})") if close: with suppress(AttributeError, CommClosedError): - self.stream_comms[address].send( - {"op": "close", "reason": "scheduler-remove-worker"} - ) + self.stream_comms[address].send({"op": "close", "reason": "scheduler-remove-worker"}) self.remove_resources(address) @@ -5502,8 +5342,7 @@ async def remove_worker( ) recommendations.update(r) logger.error( - "Task %s marked as failed because %d workers died" - " while trying to run it", + "Task %s marked as failed because %d workers died while trying to run it", ts.key, ts.suspicious, ) @@ -5556,9 +5395,7 @@ async def remove_worker( for plugin in list(self.plugins.values()): try: try: - result = plugin.remove_worker( - scheduler=self, worker=address, stimulus_id=stimulus_id - ) + result = plugin.remove_worker(scheduler=self, worker=address, stimulus_id=stimulus_id) except TypeError: parameters = inspect.signature(plugin.remove_worker).parameters if "stimulus_id" not in parameters and not any( @@ -5590,13 +5427,9 @@ async def remove_worker_from_events() -> None: if address not in self.workers: self._broker.truncate(address) - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") - ) + cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) - self._ongoing_background_tasks.call_later( - cleanup_delay, remove_worker_from_events - ) + self._ongoing_background_tasks.call_later(cleanup_delay, remove_worker_from_events) logger.debug("Removed worker %s", ws) for w in self.workers: @@ -5611,9 +5444,7 @@ async def remove_worker_from_events() -> None: return "OK" - def stimulus_cancel( - self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str - ) -> None: + def stimulus_cancel(self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str) -> None: """Stop execution on a list of keys""" logger.info("Client %s requests to cancel %d keys", client, len(keys)) self.log_event(client, {"action": "cancel", "count": len(keys), "force": force}) @@ -5675,9 +5506,7 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None: if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys( - self, keys: Collection[Key], client: str, stimulus_id: str | None = None - ) -> None: + def client_releases_keys(self, keys: Collection[Key], client: str, stimulus_id: str | None = None) -> None: """Remove keys from client desired list""" stimulus_id = stimulus_id or f"client-releases-keys-{time()}" if not isinstance(keys, list): @@ -5736,9 +5565,7 @@ def validate_queued(self, key: Key) -> None: assert not ts.waiting_on assert not ts.who_has assert not ts.processing_on - assert not ( - ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions - ) + assert not (ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions) for dts in ts.dependencies: assert dts.who_has assert ts in (dts.waiters or ()) @@ -5764,9 +5591,7 @@ def validate_memory(self, key: Key) -> None: assert ts not in self.unrunnable assert ts not in self.queued for dts in ts.dependents: - assert (dts in (ts.waiters or ())) == ( - dts.state in ("waiting", "queued", "processing", "no-worker") - ) + assert (dts in (ts.waiters or ())) == (dts.state in ("waiting", "queued", "processing", "no-worker")) assert ts not in (dts.waiting_on or ()) def validate_no_worker(self, key: Key) -> None: @@ -5797,9 
+5622,7 @@ def validate_key(self, key: Key, ts: TaskState | None = None) -> None: try: func = getattr(self, "validate_" + ts.state.replace("-", "_")) except AttributeError: - logger.error( - "self.validate_%s not found", ts.state.replace("-", "_") - ) + logger.error("self.validate_%s not found", ts.state.replace("-", "_")) else: func(key) except Exception as e: @@ -5865,9 +5688,9 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert task_prefix_counts.keys() == self._task_prefix_count_global.keys() for name, global_count in self._task_prefix_count_global.items(): - assert ( - task_prefix_counts[name] == global_count - ), f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" + assert task_prefix_counts[name] == global_count, ( + f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" + ) for ws in self.running: assert ws.status == Status.running @@ -5890,10 +5713,7 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert cs.client_key == c a = {w: ws.nbytes for w, ws in self.workers.items()} - b = { - w: sum(ts.get_nbytes() for ts in ws.has_what) - for w, ws in self.workers.items() - } + b = {w: sum(ts.get_nbytes() for ts in ws.has_what) for w, ws in self.workers.items()} assert a == b, (a, b) if self.transition_counter_max: @@ -5903,9 +5723,7 @@ def validate_state(self, allow_overlap: bool = False) -> None: # Manage Messages # ################### - def report( - self, msg: dict, ts: TaskState | None = None, client: str | None = None - ) -> None: + def report(self, msg: dict, ts: TaskState | None = None, client: str | None = None) -> None: """ Publish updates to all listening Queues and Comms @@ -5927,9 +5745,7 @@ def report( # Notify clients interested in key (including `client`) # Note that, if report() was called by update_graph(), `client` won't be in # ts.who_wants yet. - client_keys = [ - cs.client_key for cs in ts.who_wants or () if cs.client_key != client - ] + client_keys = [cs.client_key for cs in ts.who_wants or () if cs.client_key != client] if client is not None: client_keys.append(client) @@ -5942,13 +5758,9 @@ def report( # logger.debug("Scheduler sends message to client %s: %s", k, msg) except CommClosedError: if self.status == Status.running: - logger.critical( - "Closed comm %r while trying to write %s", c, msg, exc_info=True - ) + logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) - async def add_client( - self, comm: Comm, client: str, versions: dict[str, Any] - ) -> None: + async def add_client(self, comm: Comm, client: str, versions: dict[str, Any]) -> None: """Add client to network We listen to all future messages from this Comm. 
@@ -6026,13 +5838,9 @@ async def remove_client_from_events() -> None: if client not in self.clients: self._broker.truncate(client) - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") - ) + cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) if not self._ongoing_background_tasks.closed: - self._ongoing_background_tasks.call_later( - cleanup_delay, remove_client_from_events - ) + self._ongoing_background_tasks.call_later(cleanup_delay, remove_client_from_events) def send_task_to_worker(self, worker: str, ts: TaskState) -> None: """Send a single computational task to a worker""" @@ -6050,17 +5858,13 @@ def send_task_to_worker(self, worker: str, ts: TaskState) -> None: def handle_uncaught_error(self, **msg: Any) -> None: logger.exception(clean_exception(**msg)[1]) - def handle_task_finished( - self, key: Key, worker: str, stimulus_id: str, **msg: Any - ) -> None: + def handle_task_finished(self, key: Key, worker: str, stimulus_id: str, **msg: Any) -> None: if worker not in self.workers: return if self.validate: self.validate_key(key) - r: tuple = self.stimulus_task_finished( - key=key, worker=worker, stimulus_id=stimulus_id, **msg - ) + r: tuple = self.stimulus_task_finished(key=key, worker=worker, stimulus_id=stimulus_id, **msg) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) @@ -6099,9 +5903,7 @@ def handle_long_running( duration accounting as if the task has stopped. """ if worker not in self.workers: - logger.debug( - "Received long-running signal from unknown worker %s. Ignoring.", worker - ) + logger.debug("Received long-running signal from unknown worker %s. Ignoring.", worker) return if key not in self.tasks: @@ -6139,9 +5941,7 @@ def handle_long_running( self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) - def handle_worker_status_change( - self, status: str | Status, worker: str | WorkerState, stimulus_id: str - ) -> None: + def handle_worker_status_change(self, status: str | Status, worker: str | WorkerState, stimulus_id: str) -> None: ws = self.workers.get(worker) if isinstance(worker, str) else worker if not ws: return @@ -6164,9 +5964,7 @@ def handle_worker_status_change( if ws.status == Status.running: self.running.add(ws) self.check_idle_saturated(ws) - self.transitions( - self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id - ) + self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) @@ -6175,9 +5973,7 @@ def handle_worker_status_change( self.saturated.discard(ws) self._refresh_no_workers_since() - def handle_request_refresh_who_has( - self, keys: Iterable[Key], worker: str, stimulus_id: str - ) -> None: + def handle_request_refresh_who_has(self, keys: Iterable[Key], worker: str, stimulus_id: str) -> None: """Request from a Worker to refresh the who_has for some keys. Not to be confused with scheduler.who_has, which is a dedicated comm RPC request from a Client. 
@@ -6226,9 +6022,7 @@ async def handle_worker(self, comm: Comm, worker: str) -> None: finally: if worker in self.stream_comms: worker_comm.abort() - await self.remove_worker( - worker, stimulus_id=f"handle-worker-cleanup-{time()}" - ) + await self.remove_worker(worker, stimulus_id=f"handle-worker-cleanup-{time()}") def add_plugin( self, @@ -6289,9 +6083,7 @@ def remove_plugin(self, name: str | None = None) -> None: try: del self.plugins[name] except KeyError: - raise ValueError( - f"Could not find plugin {name!r} among the current scheduler plugins" - ) + raise ValueError(f"Could not find plugin {name!r} among the current scheduler plugins") async def register_scheduler_plugin( self, @@ -6353,9 +6145,7 @@ def client_send(self, client: str, msg: dict) -> None: c.send(msg) except CommClosedError: if self.status == Status.running: - logger.critical( - "Closed comm %r while trying to write %s", c, msg, exc_info=True - ) + logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) def send_all(self, client_msgs: Msgs, worker_msgs: Msgs) -> None: """Send messages to client and workers""" @@ -6433,14 +6223,10 @@ async def scatter( n = len(workers) if broadcast is True else broadcast await self.replicate(keys=keys, workers=workers, n=n) - self.log_event( - [client, "all"], {"action": "scatter", "client": client, "count": len(data)} - ) + self.log_event([client, "all"], {"action": "scatter", "client": client, "count": len(data)}) return keys - async def gather( - self, keys: Collection[Key], serializers: list[str] | None = None - ) -> dict[Key, object]: + async def gather(self, keys: Collection[Key], serializers: list[str] | None = None) -> dict[Key, object]: """Collect data from workers to the scheduler""" data = {} missing_keys = list(keys) @@ -6461,9 +6247,7 @@ async def gather( missing_keys, new_failed_keys, new_missing_workers, - ) = await gather_from_workers( - who_has, rpc=self.rpc, serializers=serializers - ) + ) = await gather_from_workers(who_has, rpc=self.rpc, serializers=serializers) data.update(new_data) failed_keys += new_failed_keys missing_workers.update(new_missing_workers) @@ -6473,10 +6257,7 @@ async def gather( if not failed_keys: return {"status": "OK", "data": data} - failed_states = { - key: self.tasks[key].state if key in self.tasks else "forgotten" - for key in failed_keys - } + failed_states = {key: self.tasks[key].state if key in self.tasks else "forgotten" for key in failed_keys} logger.error("Couldn't gather keys: %s", failed_states) return {"status": "error", "keys": list(failed_keys)} @@ -6586,24 +6367,16 @@ async def restart_workers( workers = list(set(workers).intersection(self.workers)) logger.info(f"Restarting {len(workers)} workers: {workers} ({stimulus_id=}") - nanny_workers = { - addr: self.workers[addr].nanny - for addr in workers - if self.workers[addr].nanny - } + nanny_workers = {addr: self.workers[addr].nanny for addr in workers if self.workers[addr].nanny} # Close non-Nanny workers. We have no way to restart them, so we just let them # go, and assume a deployment system is going to restart them for us. 
no_nanny_workers = [addr for addr in workers if addr not in nanny_workers] if no_nanny_workers: logger.warning( - f"Workers {no_nanny_workers} do not use a nanny and will be terminated " - "without restarting them" + f"Workers {no_nanny_workers} do not use a nanny and will be terminated without restarting them" ) await asyncio.gather( - *( - self.remove_worker(address=addr, stimulus_id=stimulus_id) - for addr in no_nanny_workers - ) + *(self.remove_worker(address=addr, stimulus_id=stimulus_id) for addr in no_nanny_workers) ) out: dict[str, Literal["OK", "removed", "timed out"]] out = {addr: "removed" for addr in no_nanny_workers} @@ -6613,9 +6386,7 @@ async def restart_workers( async with contextlib.AsyncExitStack() as stack: nannies = await asyncio.gather( *( - stack.enter_async_context( - rpc(nanny_address, connection_args=self.connection_args) - ) + stack.enter_async_context(rpc(nanny_address, connection_args=self.connection_args)) for nanny_address in nanny_workers.values() ) ) @@ -6651,16 +6422,8 @@ async def restart_workers( raise resp if bad_nannies: - logger.error( - f"Workers {list(bad_nannies)} did not shut down within {timeout}s; " - "force closing" - ) - await asyncio.gather( - *( - self.remove_worker(addr, stimulus_id=stimulus_id) - for addr in bad_nannies - ) - ) + logger.error(f"Workers {list(bad_nannies)} did not shut down within {timeout}s; force closing") + await asyncio.gather(*(self.remove_worker(addr, stimulus_id=stimulus_id) for addr in bad_nannies)) if on_error == "raise": raise TimeoutError( f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not " @@ -6669,15 +6432,10 @@ async def restart_workers( if client: self.log_event(client, {"action": "restart-workers", "workers": workers}) - self.log_event( - "all", {"action": "restart-workers", "workers": workers, "client": client} - ) + self.log_event("all", {"action": "restart-workers", "workers": workers, "client": client}) if not wait_for_workers: - logger.info( - "Workers restart finished (did not wait for new workers) " - f"({stimulus_id=}" - ) + logger.info(f"Workers restart finished (did not wait for new workers) ({stimulus_id=}") return out # NOTE: if new (unrelated) workers join while we're waiting, we may return @@ -6746,9 +6504,7 @@ async def broadcast( ERROR = object() - reuse_broadcast_comm = dask.config.get( - "distributed.scheduler.reuse-broadcast-comm", False - ) + reuse_broadcast_comm = dask.config.get("distributed.scheduler.reuse-broadcast-comm", False) close = not reuse_broadcast_comm async def send_message(addr: str) -> Any: @@ -6756,9 +6512,7 @@ async def send_message(addr: str) -> Any: comm = await self.rpc.connect(addr) comm.name = "Scheduler Broadcast" try: - resp = await send_recv( - comm, close=close, serializers=serializers, **msg - ) + resp = await send_recv(comm, close=close, serializers=serializers, **msg) finally: self.rpc.reuse(addr, comm) return resp @@ -6774,8 +6528,7 @@ async def send_message(addr: str) -> Any: return ERROR else: raise ValueError( - "on_error must be 'raise', 'return', 'return_pickle', " - f"or 'ignore'; got {on_error!r}" + f"on_error must be 'raise', 'return', 'return_pickle', or 'ignore'; got {on_error!r}" ) results = await All([send_message(address) for address in addresses]) @@ -6791,9 +6544,7 @@ async def proxy( d = await self.broadcast(msg=msg, workers=[worker], serializers=serializers) return d[worker] - async def gather_on_worker( - self, worker_address: str, who_has: dict[Key, list[str]] - ) -> set: + async def gather_on_worker(self, worker_address: str, 
who_has: dict[Key, list[str]]) -> set: """Peer-to-peer copy of keys from multiple workers to a single worker Parameters @@ -6809,15 +6560,12 @@ async def gather_on_worker( set of keys that failed to be copied """ try: - result = await retry_operation( - self.rpc(addr=worker_address).gather, who_has=who_has - ) + result = await retry_operation(self.rpc(addr=worker_address).gather, who_has=who_has) except OSError as e: # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during " - f"replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" ) return set(who_has) @@ -6832,9 +6580,7 @@ async def gather_on_worker( elif result["status"] == "partial-fail": keys_failed = set(result["keys"]) keys_ok = who_has.keys() - keys_failed - logger.warning( - f"Worker {worker_address} failed to acquire keys: {result['keys']}" - ) + logger.warning(f"Worker {worker_address} failed to acquire keys: {result['keys']}") else: # pragma: nocover raise ValueError(f"Unexpected message from {worker_address}: {result}") @@ -6847,9 +6593,7 @@ async def gather_on_worker( return keys_failed - async def delete_worker_data( - self, worker_address: str, keys: Collection[Key], stimulus_id: str - ) -> None: + async def delete_worker_data(self, worker_address: str, keys: Collection[Key], stimulus_id: str) -> None: """Delete data from a worker and update the corresponding worker/task states Parameters @@ -6869,8 +6613,7 @@ async def delete_worker_data( # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during " - f"replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" ) return @@ -6976,9 +6719,7 @@ async def rebalance( keys = set(keys) # unless already a set-like if not keys: return {"status": "OK"} - missing_data = [ - k for k in keys if k not in self.tasks or not self.tasks[k].who_has - ] + missing_data = [k for k in keys if k not in self.tasks or not self.tasks[k].who_has] if missing_data: return {"status": "partial-fail", "keys": missing_data} @@ -7053,9 +6794,7 @@ def _rebalance_find_msgs( # unmanaged memory that appeared over the last 30 seconds # (distributed.worker.memory.recent-to-old-time). # This lets us ignore temporary spikes caused by task heap usage. 
- memory_by_worker = [ - (ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers - ] + memory_by_worker = [(ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers] mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) for ws, ws_memory in memory_by_worker: @@ -7068,18 +6807,12 @@ def _rebalance_find_msgs( sender_min = 0.0 recipient_max = math.inf - if ( - ws._has_what - and ws_memory >= mean_memory + half_gap - and ws_memory >= sender_min - ): + if ws._has_what and ws_memory >= mean_memory + half_gap and ws_memory >= sender_min: # This may send the worker below sender_min (by design) snd_bytes_max = mean_memory - ws_memory # negative snd_bytes_min = snd_bytes_max + half_gap # negative # See definition of senders above - senders.append( - (snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what)) - ) + senders.append((snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what))) elif ws_memory < mean_memory - half_gap and ws_memory < recipient_max: # This may send the worker above recipient_max (by design) rec_bytes_max = ws_memory - mean_memory # negative @@ -7197,9 +6930,7 @@ async def _rebalance_move_data( FIXME this method is not robust when the cluster is not idle. """ # {recipient address: {key: [sender address, ...]}} - to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict( - lambda: defaultdict(list) - ) + to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict(lambda: defaultdict(list)) for snd_ws, rec_ws, ts in msgs: to_recipients[rec_ws.address][ts.key].append(snd_ws.address) failed_keys_by_recipient = dict( @@ -7221,9 +6952,7 @@ async def _rebalance_move_data( to_senders[snd_ws.address].append(ts.key) # Note: this never raises exceptions - await asyncio.gather( - *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items()) - ) + await asyncio.gather(*(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items())) for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) @@ -7302,17 +7031,13 @@ async def replicate( assert ts.who_has is not None del_candidates = tuple(ts.who_has & workers) if len(del_candidates) > n: - for ws in random.sample( - del_candidates, len(del_candidates) - n - ): + for ws in random.sample(del_candidates, len(del_candidates) - n): del_worker_tasks[ws].add(ts) # Note: this never raises exceptions await asyncio.gather( *[ - self.delete_worker_data( - ws.address, [t.key for t in tasks], stimulus_id - ) + self.delete_worker_data(ws.address, [t.key for t in tasks], stimulus_id) for ws, tasks in del_worker_tasks.items() ] ) @@ -7337,9 +7062,7 @@ async def replicate( assert count > 0 for ws in random.sample(tuple(workers - ts.who_has), count): - gathers[ws.address][ts.key] = [ - wws.address for wws in ts.who_has - ] + gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has] await asyncio.gather( *( @@ -7595,8 +7318,7 @@ async def retire_workers( raise TypeError("names and workers are mutually exclusive") if (names is not None or workers is not None) and kwargs: raise TypeError( - "Parameters for workers_to_close() are mutually exclusive with " - f"names and workers: {kwargs}" + f"Parameters for workers_to_close() are mutually exclusive with names and workers: {kwargs}" ) stimulus_id = stimulus_id or f"retire-workers-{time()}" @@ -7616,24 +7338,16 @@ async def retire_workers( stimulus_id, workers, ) - wss = { - self.workers[address] - for address in workers - if address in self.workers - } + 
wss = {self.workers[address] for address in workers if address in self.workers} else: - wss = { - self.workers[address] for address in self.workers_to_close(**kwargs) - } + wss = {self.workers[address] for address in self.workers_to_close(**kwargs)} if not wss: return {} stop_amm = False amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm") if not amm or not amm.running: - amm = ActiveMemoryManagerExtension( - self, policies=set(), register=False, start=True, interval=2.0 - ) + amm = ActiveMemoryManagerExtension(self, policies=set(), register=False, start=True, interval=2.0) stop_amm = True try: @@ -7645,9 +7359,7 @@ async def retire_workers( # Change Worker.status to closing_gracefully. Immediately set # the same on the scheduler to prevent race conditions. prev_status = ws.status - self.handle_worker_status_change( - Status.closing_gracefully, ws, stimulus_id - ) + self.handle_worker_status_change(Status.closing_gracefully, ws, stimulus_id) # FIXME: We should send a message to the nanny first; # eventually workers won't be able to close their own nannies. self.stream_comms[ws.address].send( @@ -7738,14 +7450,10 @@ async def _track_retire_worker( ) return ws.address, "no-recipients", ws.identity() - logger.debug( - f"All unique keys on worker {ws.address!r} have been replicated elsewhere" - ) + logger.debug(f"All unique keys on worker {ws.address!r} have been replicated elsewhere") if remove: - await self.remove_worker( - ws.address, expected=True, close=close, stimulus_id=stimulus_id - ) + await self.remove_worker(ws.address, expected=True, close=close, stimulus_id=stimulus_id) elif close: self.close_worker(ws.address) @@ -7887,9 +7595,7 @@ async def feed( if teardown: teardown(self, state) # type: ignore - def log_worker_event( - self, worker: str, topic: str | Collection[str], msg: Any - ) -> None: + def log_worker_event(self, worker: str, topic: str | Collection[str], msg: Any) -> None: if isinstance(msg, dict) and worker != topic: msg["worker"] = worker self.log_event(topic, msg) @@ -7902,46 +7608,25 @@ def subscribe_worker_status(self, comm: Comm) -> dict[str, Any]: del v["last_seen"] return ident - def get_processing( - self, workers: Iterable[str] | None = None - ) -> dict[str, list[Key]]: + def get_processing(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: if workers is not None: workers = set(map(self.coerce_address, workers)) return {w: [ts.key for ts in self.workers[w].processing] for w in workers} else: - return { - w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() - } + return {w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()} def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]: if keys is not None: return { - key: ( - [ws.address for ws in self.tasks[key].who_has or ()] - if key in self.tasks - else [] - ) - for key in keys + key: ([ws.address for ws in self.tasks[key].who_has or ()] if key in self.tasks else []) for key in keys } else: - return { - key: [ws.address for ws in ts.who_has or ()] - for key, ts in self.tasks.items() - } + return {key: [ws.address for ws in ts.who_has or ()] for key, ts in self.tasks.items()} - def get_has_what( - self, workers: Iterable[str] | None = None - ) -> dict[str, list[Key]]: + def get_has_what(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: if workers is not None: workers = map(self.coerce_address, workers) - return { - w: ( - [ts.key for ts in self.workers[w].has_what] - if w in self.workers - else [] - ) - 
for w in workers - } + return {w: ([ts.key for ts in self.workers[w].has_what] if w in self.workers else []) for w in workers} else: return {w: [ts.key for ts in ws.has_what] for w, ws in self.workers.items()} @@ -7952,13 +7637,9 @@ def get_ncores(self, workers: Iterable[str] | None = None) -> dict[str, int]: else: return {w: ws.nthreads for w, ws in self.workers.items()} - def get_ncores_running( - self, workers: Iterable[str] | None = None - ) -> dict[str, int]: + def get_ncores_running(self, workers: Iterable[str] | None = None) -> dict[str, int]: ncores = self.get_ncores(workers=workers) - return { - w: n for w, n in ncores.items() if self.workers[w].status == Status.running - } + return {w: n for w, n in ncores.items() if self.workers[w].status == Status.running} async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, Any]: workers: dict[str, list[Key] | None] @@ -7985,9 +7666,7 @@ async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, A if not workers: return {} - results = await asyncio.gather( - *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) - ) + results = await asyncio.gather(*(self.rpc(w).call_stack(keys=v) for w, v in workers.items())) response = {w: r for w, r in zip(workers, results) if r} return response @@ -8028,9 +7707,7 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: # implementing logic based on IP addresses would not necessarily help. # Randomize the connections to even out the mean measures. random.shuffle(workers) - futures = [ - self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers) - ] + futures = [self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers)] responses = await asyncio.gather(*futures) for d in responses: @@ -8039,17 +7716,12 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: result = {} for mode in out: - result[mode] = { - size: sum(durations) / len(durations) - for size, durations in out[mode].items() - } + result[mode] = {size: sum(durations) / len(durations) for size, durations in out[mode].items()} return result @log_errors - def get_nbytes( - self, keys: Iterable[Key] | None = None, summary: bool = True - ) -> dict[Key, int]: + def get_nbytes(self, keys: Iterable[Key] | None = None, summary: bool = True) -> dict[Key, int]: if keys is not None: result = {k: self.tasks[k].nbytes for k in keys} else: @@ -8135,9 +7807,7 @@ def get_task_prefix_states(self) -> dict[str, dict[str, int]]: return state def get_task_status(self, keys: Iterable[Key]) -> dict[Key, TaskStateState | None]: - return { - key: (self.tasks[key].state if key in self.tasks else None) for key in keys - } + return {key: (self.tasks[key].state if key in self.tasks else None) for key in keys} def get_task_stream( self, @@ -8160,14 +7830,11 @@ def start_task_metadata(self, name: str) -> None: def stop_task_metadata(self, name: str | None = None) -> dict: plugins = [ - p - for p in list(self.plugins.values()) - if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name + p for p in list(self.plugins.values()) if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name ] if len(plugins) != 1: raise ValueError( - "Expected to find exactly one CollectTaskMetaDataPlugin " - f"with name {name} but found {len(plugins)}." + f"Expected to find exactly one CollectTaskMetaDataPlugin with name {name} but found {len(plugins)}." 
) plugin = plugins[0] @@ -8192,14 +7859,10 @@ async def register_worker_plugin( self.worker_plugins[name] = plugin - responses = await self.broadcast( - msg=dict(op="plugin-add", plugin=plugin, name=name) - ) + responses = await self.broadcast(msg=dict(op="plugin-add", plugin=plugin, name=name)) return responses - async def unregister_worker_plugin( - self, comm: None, name: str - ) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_worker_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.worker_plugins.pop(name) @@ -8231,27 +7894,21 @@ async def register_nanny_plugin( async with self._starting_nannies_cond: if self._starting_nannies: logger.info("Waiting for Nannies to start %s", self._starting_nannies) - await self._starting_nannies_cond.wait_for( - lambda: not self._starting_nannies - ) + await self._starting_nannies_cond.wait_for(lambda: not self._starting_nannies) responses = await self.broadcast( msg=dict(op="plugin_add", plugin=plugin, name=name), nanny=True, ) return responses - async def unregister_nanny_plugin( - self, comm: None, name: str - ) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_nanny_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.nanny_plugins.pop(name) except KeyError: raise ValueError(f"The nanny plugin {name} does not exist") - responses = await self.broadcast( - msg=dict(op="plugin_remove", name=name), nanny=True - ) + responses = await self.broadcast(msg=dict(op="plugin_remove", name=name), nanny=True) return responses def transition( @@ -8276,9 +7933,7 @@ def transition( -------- Scheduler.transitions: transitive version of this function """ - recommendations, client_msgs, worker_msgs = self._transition( - key, finish, stimulus_id, **kwargs - ) + recommendations, client_msgs, worker_msgs = self._transition(key, finish, stimulus_id, **kwargs) self.send_all(client_msgs, worker_msgs) return recommendations @@ -8301,9 +7956,7 @@ async def get_story(self, keys_or_stimuli: Iterable[Key | str]) -> list[Transiti """ return self.story(*keys_or_stimuli) - def _reschedule( - self, key: Key, worker: str | None = None, *, stimulus_id: str - ) -> None: + def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) -> None: """Reschedule a task. This function should only be used when the task has already been released in @@ -8315,8 +7968,7 @@ def _reschedule( ts = self.tasks[key] except KeyError: logger.warning( - f"Attempting to reschedule task {key!r}, which was not " - "found on the scheduler. Aborting reschedule." + f"Attempting to reschedule task {key!r}, which was not found on the scheduler. Aborting reschedule." 
) return if ts.state != "processing": @@ -8331,9 +7983,7 @@ def _reschedule( # Utility functions # ##################### - def add_resources( - self, worker: str, resources: dict | None = None - ) -> Literal["OK"]: + def add_resources(self, worker: str, resources: dict | None = None) -> Literal["OK"]: ws = self.workers[worker] if resources: ws.resources.update(resources) @@ -8415,10 +8065,7 @@ async def get_profile( ) results = await asyncio.gather( - *( - self.rpc(w).profile(start=start, stop=stop, key=key, server=server) - for w in workers - ), + *(self.rpc(w).profile(start=start, stop=stop, key=key, server=server) for w in workers), return_exceptions=True, ) @@ -8438,9 +8085,7 @@ async def get_profile_metadata( stop: float | None = None, profile_cycle_interval: str | float | None = None, ) -> dict[str, Any]: - dt = profile_cycle_interval or dask.config.get( - "distributed.worker.profile.cycle" - ) + dt = profile_cycle_interval or dask.config.get("distributed.worker.profile.cycle") dt = parse_timedelta(dt, default="ms") if workers is None: @@ -8463,9 +8108,7 @@ async def get_profile_metadata( ) ] - keys: dict[Key, list[list]] = { - k: [] for v in results for t, d in v["keys"] for k in d - } + keys: dict[Key, list[list]] = {k: [] for v in results for t, d in v["keys"] for k in d} groups1 = [v["keys"] for v in results] groups2 = list(merge_sorted(*groups1, key=first)) @@ -8482,9 +8125,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} - async def performance_report( - self, start: float, last_count: int, code: str = "", mode: str | None = None - ) -> str: + async def performance_report(self, start: float, last_count: int, code: str = "", mode: str | None = None) -> str: stop = time() # Profiles compute_d, scheduler_d, workers_d = await asyncio.gather( @@ -8501,9 +8142,7 @@ def profile_to_figure(state: object) -> object: figure, source = profile.plot_figure(data, sizing_mode="stretch_both") return figure - compute, scheduler, workers = map( - profile_to_figure, (compute_d, scheduler_d, workers_d) - ) + compute, scheduler, workers = map(profile_to_figure, (compute_d, scheduler_d, workers_d)) del compute_d, scheduler_d, workers_d # Task stream @@ -8587,16 +8226,10 @@ def profile_to_figure(state: object) -> object: html = TabPanel(child=html, title="Summary") compute = TabPanel(child=compute, title="Worker Profile (compute)") workers = TabPanel(child=workers, title="Worker Profile (administrative)") - scheduler = TabPanel( - child=scheduler, title="Scheduler Profile (administrative)" - ) + scheduler = TabPanel(child=scheduler, title="Scheduler Profile (administrative)") task_stream = TabPanel(child=task_stream, title="Task Stream") - bandwidth_workers = TabPanel( - child=bandwidth_workers.root, title="Bandwidth (Workers)" - ) - bandwidth_types = TabPanel( - child=bandwidth_types.root, title="Bandwidth (Types)" - ) + bandwidth_workers = TabPanel(child=bandwidth_workers.root, title="Bandwidth (Workers)") + bandwidth_types = TabPanel(child=bandwidth_types.root, title="Bandwidth (Types)") system = TabPanel(child=sysmon.root, title="System") logs = TabPanel(child=logs.root, title="Scheduler Logs") @@ -8620,9 +8253,7 @@ def profile_to_figure(state: object) -> object: with tmpfile(extension=".html") as fn: output_file(filename=fn, title="Dask Performance Report", mode=mode) - template_directory = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates" - ) + template_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "dashboard", 
"templates") template_environment = get_env() template_environment.loader.searchpath.append(template_directory) template = template_environment.get_template("performance_report.html") @@ -8633,12 +8264,8 @@ def profile_to_figure(state: object) -> object: return data - async def get_worker_logs( - self, n: int | None = None, workers: list | None = None, nanny: bool = False - ) -> dict: - results = await self.broadcast( - msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny - ) + async def get_worker_logs(self, n: int | None = None, workers: list | None = None, nanny: bool = False) -> dict: + results = await self.broadcast(msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny) return results def log_event(self, topic: str | Collection[str], msg: Any) -> None: @@ -8675,16 +8302,11 @@ def get_events( ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]: return self._broker.get_events(topic) - async def get_worker_monitor_info( - self, recent: bool = False, starts: dict | None = None - ) -> dict: + async def get_worker_monitor_info(self, recent: bool = False, starts: dict | None = None) -> dict: if starts is None: starts = {} results = await asyncio.gather( - *( - self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) - for w in self.workers - ) + *(self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) for w in self.workers) ) return dict(zip(self.workers, results)) @@ -8737,11 +8359,7 @@ def check_idle(self) -> float | None: self.idle_since = None return None - if ( - self.queued - or self.unrunnable - or any(ws.processing for ws in self.workers.values()) - ): + if self.queued or self.unrunnable or any(ws.processing for ws in self.workers.values()): self.idle_since = None return None @@ -8750,9 +8368,7 @@ def check_idle(self) -> float | None: return self.idle_since if self.jupyter: - last_activity = ( - self._jupyter_server_application.web_app.last_activity().timestamp() - ) + last_activity = self._jupyter_server_application.web_app.last_activity().timestamp() if last_activity > self.idle_since: self.idle_since = last_activity return self.idle_since @@ -8764,16 +8380,11 @@ def check_idle(self) -> float | None: "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self._ongoing_background_tasks.call_soon( - self.close, reason="idle-timeout-exceeded" - ) + self._ongoing_background_tasks.call_soon(self.close, reason="idle-timeout-exceeded") return self.idle_since def _check_no_workers(self) -> None: - if ( - self.status in (Status.closing, Status.closed) - or self.no_workers_timeout is None - ): + if self.status in (Status.closing, Status.closed) or self.no_workers_timeout is None: return now = monotonic() @@ -8783,15 +8394,9 @@ def _check_no_workers(self) -> None: self._refresh_no_workers_since(now) - affected = self._check_unrunnable_task_timeouts( - now, recommendations=recommendations, stimulus_id=stimulus_id - ) + affected = self._check_unrunnable_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id) - affected.update( - self._check_queued_task_timeouts( - now, recommendations=recommendations, stimulus_id=stimulus_id - ) - ) + affected.update(self._check_queued_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id)) self.transitions(recommendations, stimulus_id=stimulus_id) if affected: self.log_event( @@ -8799,9 +8404,7 @@ def _check_no_workers(self) -> None: {"action": "no-workers-timeout-exceeded", "keys": affected}, ) - def 
_check_unrunnable_task_timeouts( - self, timestamp: float, recommendations: Recs, stimulus_id: str - ) -> set[Key]: + def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: assert self.no_workers_timeout unsatisfied = [] no_workers = [] @@ -8810,10 +8413,7 @@ def _check_unrunnable_task_timeouts( # unrunnable is insertion-ordered, which means that unrunnable_since will # be monotonically increasing in this loop. break - if ( - self._no_workers_since is None - or self._no_workers_since >= unrunnable_since - ): + if self._no_workers_since is None or self._no_workers_since >= unrunnable_since: unsatisfied.append(ts) else: no_workers.append(ts) @@ -8839,18 +8439,13 @@ def _check_unrunnable_task_timeouts( ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting " - "for its restrictions to become satisfied.", + "Task %s marked as failed because it timed out waiting for its restrictions to become satisfied.", ts.key, ) - self._fail_tasks_after_no_workers_timeout( - no_workers, recommendations, stimulus_id - ) + self._fail_tasks_after_no_workers_timeout(no_workers, recommendations, stimulus_id) return {ts.key for ts in concat([unsatisfied, no_workers])} - def _check_queued_task_timeouts( - self, timestamp: float, recommendations: Recs, stimulus_id: str - ) -> set[Key]: + def _check_queued_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: assert self.no_workers_timeout if self._no_workers_since is None: @@ -8859,9 +8454,7 @@ def _check_queued_task_timeouts( if timestamp <= self._no_workers_since + self.no_workers_timeout: return set() affected = list(self.queued) - self._fail_tasks_after_no_workers_timeout( - affected, recommendations, stimulus_id - ) + self._fail_tasks_after_no_workers_timeout(affected, recommendations, stimulus_id) return {ts.key for ts in affected} def _fail_tasks_after_no_workers_timeout( @@ -8885,8 +8478,7 @@ def _fail_tasks_after_no_workers_timeout( ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting " - "without any running workers.", + "Task %s marked as failed because it timed out waiting without any running workers.", ts.key, ) @@ -8961,9 +8553,7 @@ def adaptive_target(self, target_duration: float | None = None) -> int: to_close = self.workers_to_close() return len(self.workers) - len(to_close) - def request_acquire_replicas( - self, addr: str, keys: Iterable[Key], *, stimulus_id: str - ) -> None: + def request_acquire_replicas(self, addr: str, keys: Iterable[Key], *, stimulus_id: str) -> None: """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. @@ -8985,9 +8575,7 @@ def request_acquire_replicas( }, ) - def request_remove_replicas( - self, addr: str, keys: list[Key], *, stimulus_id: str - ) -> None: + def request_remove_replicas(self, addr: str, keys: list[Key], *, stimulus_id: str) -> None: """Asynchronously ask a worker to discard its replica of the listed keys. This must never be used to destroy the last replica of a key. This is a fire-and-forget operation, intended for housekeeping and not for computation. 
@@ -9171,13 +8759,7 @@ def validate_task_state(ts: TaskState) -> None:
     if ts.run_spec:  # was computed
         assert ts.type
         assert isinstance(ts.type, str)
-        assert not any(
-            [
-                ts in dts.waiting_on
-                for dts in ts.dependents
-                if dts.waiting_on is not None
-            ]
-        )
+        assert not any([ts in dts.waiting_on for dts in ts.dependents if dts.waiting_on is not None])
     for ws in ts.who_has:
         assert ts in ws.has_what, (
             "not in who_has' has_what",
@@ -9280,9 +8862,7 @@ def heartbeat_interval(n: int) -> float:
 def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int:
     """Number of tasks that can be sent to this worker without oversaturating it"""
     assert not math.isinf(saturation_factor)
-    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (
-        len(ws.processing) - len(ws.long_running)
-    )
+    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (len(ws.processing) - len(ws.long_running))


 def _worker_full(ws: WorkerState, saturation_factor: float) -> bool:
@@ -9326,9 +8906,7 @@ def __init__(
         resource_restrictions: dict[str, float],
         timeout: float,
     ):
-        super().__init__(
-            task, host_restrictions, worker_restrictions, resource_restrictions, timeout
-        )
+        super().__init__(task, host_restrictions, worker_restrictions, resource_restrictions, timeout)

     @property
     def task(self) -> Key:

From 9b9e0a842ee01d5e190b663f2c5d464e87a75f8b Mon Sep 17 00:00:00 2001
From: nadzhou
Date: Fri, 5 Dec 2025 18:02:06 -0800
Subject: [PATCH 02/23] Update

---
 distributed/tests/test_condition.py | 418 ++++++++++++++++++++++++++++
 1 file changed, 418 insertions(+)
 create mode 100644 distributed/tests/test_condition.py

diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py
new file mode 100644
index 0000000000..ea8e051796
--- /dev/null
+++ b/distributed/tests/test_condition.py
@@ -0,0 +1,418 @@
+import asyncio
+import pytest
+
+from distributed import Condition, Client, wait
+from distributed.utils_test import gen_cluster, inc
+from distributed.metrics import time
+
+
+@gen_cluster(client=True)
+async def test_condition_acquire_release(c, s, a, b):
+    """Test basic lock acquire/release"""
+    condition = Condition("test-lock")
+
+    assert not condition.locked()
+    await condition.acquire()
+    assert condition.locked()
+    await condition.release()
+    assert not condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_context_manager(c, s, a, b):
+    """Test context manager interface"""
+    condition = Condition("test-context")
+
+    assert not condition.locked()
+    async with condition:
+        assert condition.locked()
+    assert not condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_notify(c, s, a, b):
+    """Test basic wait/notify"""
+    condition = Condition("test-notify")
+    results = []
+
+    async def waiter():
+        async with condition:
+            results.append("waiting")
+            await condition.wait()
+            results.append("notified")
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            results.append("notifying")
+            condition.notify()
+
+    await asyncio.gather(waiter(), notifier())
+    assert results == ["waiting", "notifying", "notified"]
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_all(c, s, a, b):
+    """Test notify_all wakes all waiters"""
+    condition = Condition("test-notify-all")
+    results = []
+
+    async def waiter(i):
+        async with condition:
+            await condition.wait()
+            results.append(i)
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            condition.notify_all()
+
+    await asyncio.gather(
+        waiter(1), 
waiter(2), waiter(3), notifier() + ) + assert sorted(results) == [1, 2, 3] + + +@gen_cluster(client=True) +async def test_condition_notify_n(c, s, a, b): + """Test notify with specific count""" + condition = Condition("test-notify-n") + results = [] + + async def waiter(i): + async with condition: + await condition.wait() + results.append(i) + + async def notifier(): + await asyncio.sleep(0.2) + async with condition: + condition.notify(n=2) # Wake only 2 waiters + await asyncio.sleep(0.2) + async with condition: + condition.notify() # Wake remaining waiter + + await asyncio.gather( + waiter(1), waiter(2), waiter(3), notifier() + ) + assert sorted(results) == [1, 2, 3] + + +@gen_cluster(client=True) +async def test_condition_wait_timeout(c, s, a, b): + """Test wait with timeout""" + condition = Condition("test-timeout") + + start = time() + async with condition: + result = await condition.wait(timeout=0.5) + elapsed = time() - start + + assert result is False + assert 0.4 < elapsed < 0.7 + + +@gen_cluster(client=True) +async def test_condition_wait_timeout_then_notify(c, s, a, b): + """Test that timeout doesn't prevent subsequent notifications""" + condition = Condition("test-timeout-notify") + results = [] + + async def waiter(): + async with condition: + result = await condition.wait(timeout=0.2) + results.append(f"timeout: {result}") + + async with condition: + result = await condition.wait() + results.append(f"notified: {result}") + + async def notifier(): + await asyncio.sleep(0.5) + async with condition: + condition.notify() + + await asyncio.gather(waiter(), notifier()) + assert results == ["timeout: False", "notified: True"] + + +@gen_cluster(client=True) +async def test_condition_error_without_lock(c, s, a, b): + """Test errors when calling wait/notify without holding lock""" + condition = Condition("test-error") + + with pytest.raises(RuntimeError, match="without holding the lock"): + await condition.wait() + + with pytest.raises(RuntimeError, match="Cannot notify"): + await condition.notify() + + with pytest.raises(RuntimeError, match="Cannot notify"): + await condition.notify_all() + + +@gen_cluster(client=True) +async def test_condition_error_release_without_acquire(c, s, a, b): + """Test error when releasing without acquiring""" + condition = Condition("test-release-error") + + with pytest.raises(RuntimeError, match="Cannot release"): + await condition.release() + + +@gen_cluster(client=True) +async def test_condition_producer_consumer(c, s, a, b): + """Test classic producer-consumer pattern""" + condition = Condition("prod-cons") + queue = [] + + async def producer(): + for i in range(5): + await asyncio.sleep(0.1) + async with condition: + queue.append(i) + condition.notify() + + async def consumer(): + results = [] + for _ in range(5): + async with condition: + while not queue: + await condition.wait() + results.append(queue.pop(0)) + return results + + prod_task = asyncio.create_task(producer()) + cons_task = asyncio.create_task(consumer()) + + await prod_task + results = await cons_task + + assert results == [0, 1, 2, 3, 4] + + +@gen_cluster(client=True) +async def test_condition_multiple_producers_consumers(c, s, a, b): + """Test multiple producers and consumers""" + condition = Condition("multi-prod-cons") + queue = [] + + async def producer(start): + for i in range(start, start + 3): + await asyncio.sleep(0.05) + async with condition: + queue.append(i) + condition.notify() + + async def consumer(): + results = [] + for _ in range(3): + async with condition: + while not 
queue: + await condition.wait() + results.append(queue.pop(0)) + return results + + results = await asyncio.gather( + producer(0), producer(10), + consumer(), consumer() + ) + + # Last two results are from consumers + consumed = results[2] + results[3] + assert sorted(consumed) == [0, 1, 2, 10, 11, 12] + + +@gen_cluster(client=True) +async def test_condition_from_worker(c, s, a, b): + """Test condition accessed from worker tasks""" + def wait_on_condition(name): + from distributed import Condition + import asyncio + + async def _wait(): + condition = Condition(name) + async with condition: + await condition.wait() + return "worker_notified" + + from distributed.worker import get_worker + worker = get_worker() + return worker.loop.run_until_complete(_wait()) + + def notify_condition(name): + from distributed import Condition + import asyncio + + async def _notify(): + await asyncio.sleep(0.2) + condition = Condition(name) + async with condition: + condition.notify() + return "notified" + + from distributed.worker import get_worker + worker = get_worker() + return worker.loop.run_until_complete(_notify()) + + name = "worker-condition" + f1 = c.submit(wait_on_condition, name, workers=[a.address]) + f2 = c.submit(notify_condition, name, workers=[b.address]) + + results = await c.gather([f1, f2]) + assert results == ["worker_notified", "notified"] + + +@gen_cluster(client=True) +async def test_condition_same_name_different_instances(c, s, a, b): + """Test that multiple instances with same name share state""" + name = "shared-condition" + cond1 = Condition(name) + cond2 = Condition(name) + + results = [] + + async def waiter(): + async with cond1: + results.append("waiting") + await cond1.wait() + results.append("notified") + + async def notifier(): + await asyncio.sleep(0.2) + async with cond2: + results.append("notifying") + cond2.notify() + + await asyncio.gather(waiter(), notifier()) + assert results == ["waiting", "notifying", "notified"] + + +@gen_cluster(client=True) +async def test_condition_unique_names_independent(c, s, a, b): + """Test conditions with different names are independent""" + cond1 = Condition("cond-1") + cond2 = Condition("cond-2") + + async with cond1: + assert cond1.locked() + assert not cond2.locked() + + async with cond2: + assert not cond1.locked() + assert cond2.locked() + + +@gen_cluster(client=True) +async def test_condition_cleanup(c, s, a, b): + """Test that condition state is cleaned up after use""" + condition = Condition("cleanup-test") + + # Check initial state + assert "cleanup-test" not in s.extensions["conditions"]._lock_holders + assert "cleanup-test" not in s.extensions["conditions"]._waiters + + # Use condition + async with condition: + condition.notify() + + # State should be cleaned up + await asyncio.sleep(0.1) + assert "cleanup-test" not in s.extensions["conditions"]._lock_holders + + +@gen_cluster(client=True) +async def test_condition_barrier_pattern(c, s, a, b): + """Test barrier synchronization pattern""" + condition = Condition("barrier") + arrived = [] + n_workers = 3 + + async def worker(i): + async with condition: + arrived.append(i) + if len(arrived) < n_workers: + await condition.wait() + else: + condition.notify_all() + return f"worker-{i}-done" + + results = await asyncio.gather( + worker(0), worker(1), worker(2) + ) + + assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"] + assert len(arrived) == 3 + + +def test_condition_sync_interface(client): + """Test synchronous interface via SyncMethodMixin""" + condition 
= Condition("sync-test") + results = [] + + def worker(): + with condition: + results.append("locked") + results.append("released") + + worker() + assert results == ["locked", "released"] + + +@gen_cluster(client=True) +async def test_condition_multiple_notify_calls(c, s, a, b): + """Test multiple notify calls in sequence""" + condition = Condition("multi-notify") + results = [] + + async def waiter(i): + async with condition: + await condition.wait() + results.append(i) + + async def notifier(): + await asyncio.sleep(0.2) + async with condition: + condition.notify() + await asyncio.sleep(0.1) + async with condition: + condition.notify() + await asyncio.sleep(0.1) + async with condition: + condition.notify() + + await asyncio.gather( + waiter(1), waiter(2), waiter(3), notifier() + ) + assert sorted(results) == [1, 2, 3] + + +@gen_cluster(client=True) +async def test_condition_predicate_loop(c, s, a, b): + """Test typical predicate-based wait loop pattern""" + condition = Condition("predicate") + state = {"value": 0, "target": 5} + + async def waiter(): + async with condition: + while state["value"] < state["target"]: + await condition.wait() + return state["value"] + + async def updater(): + for i in range(1, 6): + await asyncio.sleep(0.1) + async with condition: + state["value"] = i + condition.notify_all() + + result, _ = await asyncio.gather(waiter(), updater()) + assert result == 5 + + +@gen_cluster(client=True) +async def test_condition_repr(c, s, a, b): + """Test string representation""" + condition = Condition("test-repr") + assert "test-repr" in repr(condition) + assert "Condition" in repr(condition) From ce5ea1822bc9cc3fb24a0d921cbc61ce0c72c41c Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 5 Dec 2025 18:10:36 -0800 Subject: [PATCH 03/23] Update scheduler.py --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e9b4bec324..590a84a924 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -193,8 +193,8 @@ "variables": VariableExtension, "semaphores": SemaphoreExtension, "events": EventExtension, - "conditions": ConditionExtension, "amm": ActiveMemoryManagerExtension, + "conditions": ConditionExtension, "memory_sampler": MemorySamplerExtension, "shuffle": ShuffleSchedulerPlugin, "spans": SpansSchedulerExtension, From eb7046ac61c27de78bba48ad8b5f06b3f5be9c1f Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 5 Dec 2025 18:12:41 -0800 Subject: [PATCH 04/23] Update scheduler.py --- distributed/scheduler.py | 824 +++++++++++++++++++++++++++++---------- 1 file changed, 624 insertions(+), 200 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 590a84a924..f2ad5080a8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -182,7 +182,9 @@ logger = logging.getLogger(__name__) LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -DEFAULT_DATA_SIZE = parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) +DEFAULT_DATA_SIZE = parse_bytes( + dask.config.get("distributed.scheduler.default-data-size") +) STIMULUS_ID_UNSET = "" DEFAULT_EXTENSIONS = { @@ -407,7 +409,8 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: return { k: getattr(self, k) for k in dir(self) - if not k.startswith("_") and k not in {"sum", "managed_in_memory", "managed_spilled"} + if not k.startswith("_") + and k not in {"sum", "managed_in_memory", "managed_spilled"} } @@ -579,7 +582,9 @@ def __hash__(self) -> int: return 
self._hash def __eq__(self, other: object) -> bool: - return self is other or (isinstance(other, WorkerState) and other.server_id == self.server_id) + return self is other or ( + isinstance(other, WorkerState) and other.server_id == self.server_id + ) @property def has_what(self) -> Set[TaskState]: @@ -830,7 +835,9 @@ def _dec_needs_replica(self, ts: TaskState) -> None: nbytes = ts.get_nbytes() # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) + self.scheduler._network_occ_global -= min( + nbytes, self.scheduler._network_occ_global + ) def add_replica(self, ts: TaskState) -> None: """The worker acquired a replica of task""" @@ -843,14 +850,18 @@ def add_replica(self, ts: TaskState) -> None: del self.needs_what[ts] # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) + self.scheduler._network_occ_global -= min( + nbytes, self.scheduler._network_occ_global + ) ts.who_has.add(self) self.nbytes += nbytes self._has_what[ts] = None @property def occupancy(self) -> float: - return self._occupancy_cache or self.scheduler._calc_occupancy(self.task_prefix_count, self._network_occ) + return self._occupancy_cache or self.scheduler._calc_occupancy( + self.task_prefix_count, self._network_occ + ) @dataclasses.dataclass @@ -912,7 +923,9 @@ def __repr__(self) -> str: return ( f"" ) @@ -970,7 +983,10 @@ def all_durations(self) -> defaultdict[str, float]: """Cumulative duration of all completed actions of tasks belonging to this collection, by action""" return defaultdict( float, - {action: duration_us / 1e6 for action, duration_us in self._all_durations_us.items()}, + { + action: duration_us / 1e6 + for action, duration_us in self._all_durations_us.items() + }, ) @property @@ -1075,7 +1091,13 @@ def active_states(self) -> dict[TaskStateState, int]: def __repr__(self) -> str: return ( - "<" + self.name + ": " + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + ">" + "<" + + self.name + + ": " + + ", ".join( + "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v + ) + + ">" ) @@ -1165,7 +1187,9 @@ def __repr__(self) -> str: "<" + (self.name or "no-group") + ": " - + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + + ", ".join( + "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v + ) + ">" ) @@ -1193,7 +1217,8 @@ def done(self) -> bool: recomputed. 
""" return all( - count == 0 or state in {"memory", "erred", "released", "forgotten"} for state, count in self.states.items() + count == 0 or state in {"memory", "erred", "released", "forgotten"} + for state, count in self.states.items() ) @@ -1752,9 +1777,15 @@ def __init__( self.resources = resources self.saturated = set() self.tasks = tasks - self.replicated_tasks = {ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1} - self.computations = deque(maxlen=dask.config.get("distributed.diagnostics.computations.max-history")) - self.erred_tasks = deque(maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history")) + self.replicated_tasks = { + ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1 + } + self.computations = deque( + maxlen=dask.config.get("distributed.diagnostics.computations.max-history") + ) + self.erred_tasks = deque( + maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history") + ) self.task_groups = {} self.task_prefixes = {} self.task_metadata = {} @@ -1767,38 +1798,61 @@ def __init__( self.workers = workers self._task_prefix_count_global = defaultdict(int) self._network_occ_global = 0 - self.running = {ws for ws in self.workers.values() if ws.status == Status.running} + self.running = { + ws for ws in self.workers.values() if ws.status == Status.running + } self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} - self.transition_log = deque(maxlen=dask.config.get("distributed.admin.low-level-log-length")) + self.transition_log = deque( + maxlen=dask.config.get("distributed.admin.low-level-log-length") + ) self.transition_counter = 0 self._idle_transition_counter = 0 self.transition_counter_max = transition_counter_max # Variables from dask.config, cached by __init__ for performance - self.UNKNOWN_TASK_DURATION = parse_timedelta(dask.config.get("distributed.scheduler.unknown-task-duration")) + self.UNKNOWN_TASK_DURATION = parse_timedelta( + dask.config.get("distributed.scheduler.unknown-task-duration") + ) self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( dask.config.get("distributed.worker.memory.recent-to-old-time") ) - self.MEMORY_REBALANCE_MEASURE = dask.config.get("distributed.worker.memory.rebalance.measure") - self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get("distributed.worker.memory.rebalance.sender-min") - self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get("distributed.worker.memory.rebalance.recipient-max") + self.MEMORY_REBALANCE_MEASURE = dask.config.get( + "distributed.worker.memory.rebalance.measure" + ) + self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get( + "distributed.worker.memory.rebalance.sender-min" + ) + self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( + "distributed.worker.memory.rebalance.recipient-max" + ) self.MEMORY_REBALANCE_HALF_GAP = ( - dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") / 2.0 + dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") + / 2.0 ) - self.WORKER_SATURATION = dask.config.get("distributed.scheduler.worker-saturation") + self.WORKER_SATURATION = dask.config.get( + "distributed.scheduler.worker-saturation" + ) if self.WORKER_SATURATION == "inf": # Special case necessary because there's no way to parse a float infinity # from a DASK_* environment variable self.WORKER_SATURATION = math.inf - if not isinstance(self.WORKER_SATURATION, (int, float)) or self.WORKER_SATURATION <= 0: + if ( + not isinstance(self.WORKER_SATURATION, (int, float)) + or self.WORKER_SATURATION <= 0 + ): raise ValueError( 
# pragma: nocover - "`distributed.scheduler.worker-saturation` must be a float > 0; got " + repr(self.WORKER_SATURATION) + "`distributed.scheduler.worker-saturation` must be a float > 0; got " + + repr(self.WORKER_SATURATION) ) - self.rootish_tg_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup") - self.rootish_tg_dependencies_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup-dependencies") + self.rootish_tg_threshold = dask.config.get( + "distributed.scheduler.rootish-taskgroup" + ) + self.rootish_tg_dependencies_threshold = dask.config.get( + "distributed.scheduler.rootish-taskgroup-dependencies" + ) @abstractmethod def log_event(self, topic: str | Collection[str], msg: Any) -> None: ... @@ -1932,7 +1986,9 @@ def _calc_occupancy( # State Transitions # ##################### - def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any) -> RecsMsgs: + def _transition( + self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any + ) -> RecsMsgs: """Transition a key from its current state to the finish state Examples @@ -1978,11 +2034,15 @@ def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwar func = self._TRANSITIONS_TABLE.get((start, finish)) if func is not None: - recommendations, client_msgs, worker_msgs = func(self, key, stimulus_id, **kwargs) + recommendations, client_msgs, worker_msgs = func( + self, key, stimulus_id, **kwargs + ) elif "released" not in (start, finish): assert not kwargs, (kwargs, start, finish) - a_recs, a_cmsgs, a_wmsgs = self._transition(key, "released", stimulus_id) + a_recs, a_cmsgs, a_wmsgs = self._transition( + key, "released", stimulus_id + ) v = a_recs.get(key, finish) # The inner rec has higher priority? Is that always desired? @@ -2012,10 +2072,16 @@ def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwar stimulus_id = STIMULUS_ID_UNSET actual_finish = ts._state - self.transition_log.append(Transition(key, start, actual_finish, recommendations, stimulus_id, time())) + self.transition_log.append( + Transition( + key, start, actual_finish, recommendations, stimulus_id, time() + ) + ) if self.validate: if stimulus_id == STIMULUS_ID_UNSET: - raise RuntimeError("stimulus_id not set during Scheduler transition") + raise RuntimeError( + "stimulus_id not set during Scheduler transition" + ) logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -2032,7 +2098,9 @@ def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwar self.tasks[ts.key] = ts for plugin in list(self.plugins.values()): try: - plugin.transition(key, start, actual_finish, stimulus_id=stimulus_id, **kwargs) + plugin.transition( + key, start, actual_finish, stimulus_id=stimulus_id, **kwargs + ) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts.state == "forgotten": @@ -2216,7 +2284,9 @@ def _transition_queued_erred( traceback_text=traceback_text, ) - def decide_worker_rootish_queuing_disabled(self, ts: TaskState) -> WorkerState | None: + def decide_worker_rootish_queuing_disabled( + self, ts: TaskState + ) -> WorkerState | None: """Pick a worker for a runnable root-ish task, without queuing. 
This attempts to schedule sibling tasks on the same worker, reducing future data @@ -2247,16 +2317,25 @@ def decide_worker_rootish_queuing_disabled(self, ts: TaskState) -> WorkerState | tg = ts.group lws = tg.last_worker - if lws and tg.last_worker_tasks_left and lws.status == Status.running and self.workers.get(lws.address) is lws: + if ( + lws + and tg.last_worker_tasks_left + and lws.status == Status.running + and self.workers.get(lws.address) is lws + ): ws = lws else: # Last-used worker is full, unknown, retiring, or paused; # pick a new worker for the next few tasks ws = min(pool, key=partial(self.worker_objective, ts)) - tg.last_worker_tasks_left = math.floor((len(tg) / self.total_nthreads) * ws.nthreads) + tg.last_worker_tasks_left = math.floor( + (len(tg) / self.total_nthreads) * ws.nthreads + ) # Record `last_worker`, or clear it on the final task - tg.last_worker = ws if tg.states["released"] + tg.states["waiting"] > 1 else None + tg.last_worker = ( + ws if tg.states["released"] + tg.states["waiting"] > 1 else None + ) tg.last_worker_tasks_left -= 1 if self.validate and ws is not None: @@ -2497,7 +2576,9 @@ def _transition_processing_memory( recommendations: Recs = {} client_msgs: Msgs = {} - self._add_to_memory(ts, ws, recommendations, client_msgs, type=type, typename=typename) + self._add_to_memory( + ts, ws, recommendations, client_msgs, type=type, typename=typename + ) if self.validate: assert not ts.processing_on @@ -2517,7 +2598,9 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: ws.actors.discard(ts) if ts.who_wants: ts.exception_blame = ts - ts.exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) + ts.exception = Serialized( + *serialize(RuntimeError("Worker holding Actor was lost")) + ) return {ts.key: "erred"}, {}, {} # don't try to recreate recommendations: Recs = {} @@ -2544,7 +2627,9 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: recommendations[key] = "forgotten" elif ts.has_lost_dependencies: recommendations[key] = "forgotten" - elif (ts.who_wants or ts.waiters) and not any(dts.state == "erred" for dts in ts.dependencies): + elif (ts.who_wants or ts.waiters) and not any( + dts.state == "erred" for dts in ts.dependencies + ): recommendations[key] = "waiting" for dts in ts.waiters or (): @@ -2924,7 +3009,9 @@ def _transition_memory_erred(self, key: Key, stimulus_id: str) -> RecsMsgs: if not dts.who_has: dts.exception_blame = ts recommendations[dts.key] = "erred" - exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) + exception = Serialized( + *serialize(RuntimeError("Worker holding Actor was lost")) + ) report_msg = { "op": "task-erred", "key": key, @@ -3037,9 +3124,14 @@ def _transition_released_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs ("released", "erred"): _transition_released_erred, } - def story(self, *keys_or_tasks_or_stimuli: Key | TaskState | str) -> list[Transition]: + def story( + self, *keys_or_tasks_or_stimuli: Key | TaskState | str + ) -> list[Transition]: """Get all transitions that touch one of the input keys or stimulus_id's""" - keys_or_stimuli = {key.key if isinstance(key, TaskState) else key for key in keys_or_tasks_or_stimuli} + keys_or_stimuli = { + key.key if isinstance(key, TaskState) else key + for key in keys_or_tasks_or_stimuli + } return scheduler_story(keys_or_stimuli, self.transition_log) ############################## @@ -3115,9 +3207,14 @@ def check_idle_saturated(self, ws: WorkerState, occ: float 
= -1.0) -> None: else: self.idle_task_count.discard(ws) - def is_unoccupied(self, ws: WorkerState, occupancy: float, nprocessing: int) -> bool: + def is_unoccupied( + self, ws: WorkerState, occupancy: float, nprocessing: int + ) -> bool: nthreads = ws.nthreads - return nprocessing < nthreads or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 + return ( + nprocessing < nthreads + or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 + ) def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: """ @@ -3293,7 +3390,9 @@ def _validate_ready(self, ts: TaskState) -> None: assert ts not in self.queued assert all(dts.who_has for dts in ts.dependencies) - def _add_to_processing(self, ts: TaskState, ws: WorkerState, stimulus_id: str) -> RecsMsgs: + def _add_to_processing( + self, ts: TaskState, ws: WorkerState, stimulus_id: str + ) -> RecsMsgs: """Set a task as processing on a worker and return the worker messages to send""" if self.validate: self._validate_ready(ts) @@ -3311,7 +3410,11 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState, stimulus_id: str) - ws.actors.add(ts) ndep_bytes = sum(dts.nbytes for dts in ts.dependencies) - if ws.memory_limit and ndep_bytes > ws.memory_limit and dask.config.get("distributed.worker.memory.terminate"): + if ( + ws.memory_limit + and ndep_bytes > ws.memory_limit + and dask.config.get("distributed.worker.memory.terminate") + ): # Note # ---- # This is a crude safety system, only meant to prevent order-of-magnitude @@ -3519,7 +3622,10 @@ def _task_to_msg(self, ts: TaskState) -> dict[str, Any]: "run_id": ts.run_id, "priority": ts.priority, "stimulus_id": f"compute-task-{time()}", - "who_has": {dts.key: tuple(ws.address for ws in (dts.who_has or ())) for dts in ts.dependencies}, + "who_has": { + dts.key: tuple(ws.address for ws in (dts.who_has or ())) + for dts in ts.dependencies + }, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, "run_spec": ToPickle(ts.run_spec), "resource_restrictions": ts.resource_restrictions, @@ -3691,10 +3797,16 @@ def __init__( self.services = {} self.scheduler_file = scheduler_file - self.worker_ttl = parse_timedelta(worker_ttl or dask.config.get("distributed.scheduler.worker-ttl")) - self.idle_timeout = parse_timedelta(idle_timeout or dask.config.get("distributed.scheduler.idle-timeout")) + self.worker_ttl = parse_timedelta( + worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") + ) + self.idle_timeout = parse_timedelta( + idle_timeout or dask.config.get("distributed.scheduler.idle-timeout") + ) self.idle_since = time() - self.no_workers_timeout = parse_timedelta(dask.config.get("distributed.scheduler.no-workers-timeout")) + self.no_workers_timeout = parse_timedelta( + dask.config.get("distributed.scheduler.no-workers-timeout") + ) self._no_workers_since = None self.time_started = self.idle_since # compatibility for dask-gateway @@ -3742,17 +3854,24 @@ def __init__( except ImportError: show_dashboard = False http_server_modules.append("distributed.http.scheduler.missing_bokeh") - routes = get_handlers(server=self, modules=http_server_modules, prefix=http_prefix) + routes = get_handlers( + server=self, modules=http_server_modules, prefix=http_prefix + ) self.start_http_server(routes, dashboard_address, default_port=8787) self.jupyter = jupyter if show_dashboard: - distributed.dashboard.scheduler.connect(self.http_application, self.http_server, self, prefix=http_prefix) + distributed.dashboard.scheduler.connect( + self.http_application, 
self.http_server, self, prefix=http_prefix + ) scheduler = self if self.jupyter: try: from jupyter_server.serverapp import ServerApp except ImportError: - raise ImportError("In order to use the Dask jupyter option you need to have jupyterlab installed") + raise ImportError( + "In order to use the Dask jupyter option you " + "need to have jupyterlab installed" + ) from traitlets.config import Config """HTTP handler to shut down the Jupyter server. @@ -3800,7 +3919,9 @@ async def post(self) -> None: argv=[], ) self._jupyter_server_application = j - shutdown_app = tornado.web.Application([(r"/jupyter/api/shutdown", ShutdownHandler)]) + shutdown_app = tornado.web.Application( + [(r"/jupyter/api/shutdown", ShutdownHandler)] + ) shutdown_app.settings = j.web_app.settings self.http_application.add_application(shutdown_app) self.http_application.add_application(j.web_app) @@ -4020,7 +4141,8 @@ def identity(self, n_workers: int = -1) -> dict[str, Any]: "total_threads": self.total_nthreads, "total_memory": self.total_memory, "workers": { - worker.address: worker.identity() for worker in itertools.islice(self.workers.values(), n_workers) + worker.address: worker.identity() + for worker in itertools.islice(self.workers.values(), n_workers) }, } return d @@ -4078,7 +4200,10 @@ async def get_cluster_state( workers_future.cancel() # Convert any RPC errors to strings - worker_states = {k: repr(v) if isinstance(v, Exception) else v for k, v in worker_states.items()} + worker_states = { + k: repr(v) if isinstance(v, Exception) else v + for k, v in worker_states.items() + } return { "scheduler": scheduler_state, @@ -4094,7 +4219,9 @@ async def dump_cluster_state_to_url( **storage_options: dict[str, Any], ) -> None: "Write a cluster state dump to an fsspec-compatible URL." - await cluster_dump.write_state(partial(self.get_cluster_state, exclude), url, format, **storage_options) + await cluster_dump.write_state( + partial(self.get_cluster_state, exclude), url, format, **storage_options + ) def get_worker_service_addr( self, worker: str, service_name: str, protocol: bool = False @@ -4162,7 +4289,9 @@ async def start_unsafe(self) -> Self: # formatting dashboard link can fail if distributed.dashboard.link # refers to non-existent env vars. 
except KeyError as e: - logger.warning(f"Failed to format dashboard link, unknown value: {e}") + logger.warning( + f"Failed to format dashboard link, unknown value: {e}" + ) link = f":{server.port}" else: link = f"{listen_ip}:{server.port}" @@ -4188,7 +4317,9 @@ def del_scheduler_file() -> None: await self.listen("tcp://localhost:0") os.environ["DASK_SCHEDULER_ADDRESS"] = self.listeners[-1].contact_address - await asyncio.gather(*[plugin.start(self) for plugin in list(self.plugins.values())]) + await asyncio.gather( + *[plugin.start(self) for plugin in list(self.plugins.values())] + ) self.start_periodic_callbacks() @@ -4220,11 +4351,15 @@ async def log_errors(func: Callable) -> None: except Exception: logger.exception("Plugin call failed during scheduler.close") - await asyncio.gather(*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]) + await asyncio.gather( + *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())] + ) await self.preloads.teardown() - await asyncio.gather(*[log_errors(plugin.close) for plugin in list(self.plugins.values())]) + await asyncio.gather( + *[log_errors(plugin.close) for plugin in list(self.plugins.values())] + ) for pc in self.periodic_callbacks.values(): pc.stop() @@ -4297,21 +4432,25 @@ def heartbeat_worker( dh["last-seen"] = local_now frac = 1 / len(self.workers) - self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + self.bandwidth = ( + self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + ) for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: self.bandwidth_workers[address, other] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_workers[address, other] = self.bandwidth_workers[address, other] * alpha + bw * ( - 1 - alpha - ) + self.bandwidth_workers[address, other] = self.bandwidth_workers[ + address, other + ] * alpha + bw * (1 - alpha) for typ, (bw, count) in metrics["bandwidth"]["types"].items(): if typ not in self.bandwidth_types: self.bandwidth_types[typ] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * (1 - alpha) + self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( + 1 - alpha + ) ws.last_seen = local_now if executing is not None: @@ -4342,7 +4481,9 @@ def heartbeat_worker( # ws._nbytes is updated at a different time and sizeof() may not be accurate, # so size may be (temporarily) negative; floor it to zero. - size = max(0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"]) + size = max( + 0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"] + ) ws._memory_unmanaged_history.append((local_now, size)) if not memory_unmanaged_old: @@ -4487,7 +4628,9 @@ async def add_worker( logger.exception(exc, exc_info=exc) if ws.status == Status.running: - self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) logger.info("Register worker addr: %s name: %s", ws.address, ws.name) @@ -4671,7 +4814,9 @@ def _create_taskstate_from_graph( # _generate_taskstates is not the only thing that calls new_task(). A # TaskState may have also been created by client_desires_keys or scatter, # and only later gained a run_spec. 
- span_annotations = spans_ext.observe_tasks(touched_tasks, span_metadata=span_metadata, code=code) + span_annotations = spans_ext.observe_tasks( + touched_tasks, span_metadata=span_metadata, code=code + ) # In case of TaskGroup collision, spans may have changed # FIXME: Is this used anywhere besides tests? if span_annotations: @@ -4777,7 +4922,9 @@ async def update_graph( }, client=client, ) - self.client_releases_keys(keys=keys, client=client, stimulus_id=stimulus_id) + self.client_releases_keys( + keys=keys, client=client, stimulus_id=stimulus_id + ) evt_msg = { "action": "update-graph", "stimulus_id": stimulus_id, @@ -4810,7 +4957,8 @@ async def update_graph( "start_timestamp_seconds": start, "materialization_duration_seconds": materialization_done - start, "ordering_duration_seconds": materialization_done - ordering_done, - "state_initialization_duration_seconds": ordering_done - task_state_created, + "state_initialization_duration_seconds": ordering_done + - task_state_created, "duration_seconds": task_state_created - start, } ) @@ -5063,7 +5211,9 @@ def _set_priorities( ) if self.validate and istask(ts.run_spec): - assert isinstance(ts.priority, tuple) and all(isinstance(el, (int, float)) for el in ts.priority) + assert isinstance(ts.priority, tuple) and all( + isinstance(el, (int, float)) for el in ts.priority + ) def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened spots on worker threadpools @@ -5082,7 +5232,10 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """ if not self.queued: return - slots_available = sum(_task_slots_available(ws, self.WORKER_SATURATION) for ws in self.idle_task_count) + slots_available = sum( + _task_slots_available(ws, self.WORKER_SATURATION) + for ws in self.idle_task_count + ) if slots_available == 0: return @@ -5104,7 +5257,9 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: assert qts.state == "processing" assert not self.queued or self.queued.peek() != qts - def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any) -> RecsMsgs: + def stimulus_task_finished( + self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any + ) -> RecsMsgs: """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s[%d] %s", key, run_id, worker) @@ -5115,7 +5270,8 @@ def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id ts = self.tasks.get(key) if ts is None or ts.state in ("released", "queued", "no-worker"): logger.debug( - "Received already computed task, worker: %s, state: %s, key: %s, who_has: %s", + "Received already computed task, worker: %s, state: %s" + ", key: %s, who_has: %s", worker, ts.state if ts else "forgotten", key, @@ -5130,7 +5286,7 @@ def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id ] elif ts.state == "erred": logger.debug( - "Received already erred task, worker: %s, key: %s", + "Received already erred task, worker: %s" ", key: %s", worker, key, ) @@ -5207,7 +5363,9 @@ def stimulus_task_erred( **kwargs, ) - def stimulus_retry(self, keys: Collection[Key], client: str | None = None) -> tuple[Key, ...]: + def stimulus_retry( + self, keys: Collection[Key], client: str | None = None + ) -> tuple[Key, ...]: logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -5285,10 +5443,14 @@ 
async def remove_worker( ws = self.workers[address] - logger.info(f"Remove worker addr: {ws.address} name: {ws.name} ({stimulus_id=})") + logger.info( + f"Remove worker addr: {ws.address} name: {ws.name} ({stimulus_id=})" + ) if close: with suppress(AttributeError, CommClosedError): - self.stream_comms[address].send({"op": "close", "reason": "scheduler-remove-worker"}) + self.stream_comms[address].send( + {"op": "close", "reason": "scheduler-remove-worker"} + ) self.remove_resources(address) @@ -5342,7 +5504,8 @@ async def remove_worker( ) recommendations.update(r) logger.error( - "Task %s marked as failed because %d workers died while trying to run it", + "Task %s marked as failed because %d workers died" + " while trying to run it", ts.key, ts.suspicious, ) @@ -5395,7 +5558,9 @@ async def remove_worker( for plugin in list(self.plugins.values()): try: try: - result = plugin.remove_worker(scheduler=self, worker=address, stimulus_id=stimulus_id) + result = plugin.remove_worker( + scheduler=self, worker=address, stimulus_id=stimulus_id + ) except TypeError: parameters = inspect.signature(plugin.remove_worker).parameters if "stimulus_id" not in parameters and not any( @@ -5427,9 +5592,13 @@ async def remove_worker_from_events() -> None: if address not in self.workers: self._broker.truncate(address) - cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") + ) - self._ongoing_background_tasks.call_later(cleanup_delay, remove_worker_from_events) + self._ongoing_background_tasks.call_later( + cleanup_delay, remove_worker_from_events + ) logger.debug("Removed worker %s", ws) for w in self.workers: @@ -5444,7 +5613,9 @@ async def remove_worker_from_events() -> None: return "OK" - def stimulus_cancel(self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str) -> None: + def stimulus_cancel( + self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str + ) -> None: """Stop execution on a list of keys""" logger.info("Client %s requests to cancel %d keys", client, len(keys)) self.log_event(client, {"action": "cancel", "count": len(keys), "force": force}) @@ -5506,7 +5677,9 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None: if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys(self, keys: Collection[Key], client: str, stimulus_id: str | None = None) -> None: + def client_releases_keys( + self, keys: Collection[Key], client: str, stimulus_id: str | None = None + ) -> None: """Remove keys from client desired list""" stimulus_id = stimulus_id or f"client-releases-keys-{time()}" if not isinstance(keys, list): @@ -5565,7 +5738,9 @@ def validate_queued(self, key: Key) -> None: assert not ts.waiting_on assert not ts.who_has assert not ts.processing_on - assert not (ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions) + assert not ( + ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions + ) for dts in ts.dependencies: assert dts.who_has assert ts in (dts.waiters or ()) @@ -5591,7 +5766,9 @@ def validate_memory(self, key: Key) -> None: assert ts not in self.unrunnable assert ts not in self.queued for dts in ts.dependents: - assert (dts in (ts.waiters or ())) == (dts.state in ("waiting", "queued", "processing", "no-worker")) + assert (dts in (ts.waiters or ())) == ( + dts.state in ("waiting", "queued", "processing", 
"no-worker") + ) assert ts not in (dts.waiting_on or ()) def validate_no_worker(self, key: Key) -> None: @@ -5622,7 +5799,9 @@ def validate_key(self, key: Key, ts: TaskState | None = None) -> None: try: func = getattr(self, "validate_" + ts.state.replace("-", "_")) except AttributeError: - logger.error("self.validate_%s not found", ts.state.replace("-", "_")) + logger.error( + "self.validate_%s not found", ts.state.replace("-", "_") + ) else: func(key) except Exception as e: @@ -5688,9 +5867,9 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert task_prefix_counts.keys() == self._task_prefix_count_global.keys() for name, global_count in self._task_prefix_count_global.items(): - assert task_prefix_counts[name] == global_count, ( - f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" - ) + assert ( + task_prefix_counts[name] == global_count + ), f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" for ws in self.running: assert ws.status == Status.running @@ -5713,7 +5892,10 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert cs.client_key == c a = {w: ws.nbytes for w, ws in self.workers.items()} - b = {w: sum(ts.get_nbytes() for ts in ws.has_what) for w, ws in self.workers.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws.has_what) + for w, ws in self.workers.items() + } assert a == b, (a, b) if self.transition_counter_max: @@ -5723,7 +5905,9 @@ def validate_state(self, allow_overlap: bool = False) -> None: # Manage Messages # ################### - def report(self, msg: dict, ts: TaskState | None = None, client: str | None = None) -> None: + def report( + self, msg: dict, ts: TaskState | None = None, client: str | None = None + ) -> None: """ Publish updates to all listening Queues and Comms @@ -5745,7 +5929,9 @@ def report(self, msg: dict, ts: TaskState | None = None, client: str | None = No # Notify clients interested in key (including `client`) # Note that, if report() was called by update_graph(), `client` won't be in # ts.who_wants yet. - client_keys = [cs.client_key for cs in ts.who_wants or () if cs.client_key != client] + client_keys = [ + cs.client_key for cs in ts.who_wants or () if cs.client_key != client + ] if client is not None: client_keys.append(client) @@ -5758,9 +5944,13 @@ def report(self, msg: dict, ts: TaskState | None = None, client: str | None = No # logger.debug("Scheduler sends message to client %s: %s", k, msg) except CommClosedError: if self.status == Status.running: - logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) + logger.critical( + "Closed comm %r while trying to write %s", c, msg, exc_info=True + ) - async def add_client(self, comm: Comm, client: str, versions: dict[str, Any]) -> None: + async def add_client( + self, comm: Comm, client: str, versions: dict[str, Any] + ) -> None: """Add client to network We listen to all future messages from this Comm. 
@@ -5838,9 +6028,13 @@ async def remove_client_from_events() -> None: if client not in self.clients: self._broker.truncate(client) - cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") + ) if not self._ongoing_background_tasks.closed: - self._ongoing_background_tasks.call_later(cleanup_delay, remove_client_from_events) + self._ongoing_background_tasks.call_later( + cleanup_delay, remove_client_from_events + ) def send_task_to_worker(self, worker: str, ts: TaskState) -> None: """Send a single computational task to a worker""" @@ -5858,13 +6052,17 @@ def send_task_to_worker(self, worker: str, ts: TaskState) -> None: def handle_uncaught_error(self, **msg: Any) -> None: logger.exception(clean_exception(**msg)[1]) - def handle_task_finished(self, key: Key, worker: str, stimulus_id: str, **msg: Any) -> None: + def handle_task_finished( + self, key: Key, worker: str, stimulus_id: str, **msg: Any + ) -> None: if worker not in self.workers: return if self.validate: self.validate_key(key) - r: tuple = self.stimulus_task_finished(key=key, worker=worker, stimulus_id=stimulus_id, **msg) + r: tuple = self.stimulus_task_finished( + key=key, worker=worker, stimulus_id=stimulus_id, **msg + ) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) @@ -5903,7 +6101,9 @@ def handle_long_running( duration accounting as if the task has stopped. """ if worker not in self.workers: - logger.debug("Received long-running signal from unknown worker %s. Ignoring.", worker) + logger.debug( + "Received long-running signal from unknown worker %s. Ignoring.", worker + ) return if key not in self.tasks: @@ -5941,7 +6141,9 @@ def handle_long_running( self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) - def handle_worker_status_change(self, status: str | Status, worker: str | WorkerState, stimulus_id: str) -> None: + def handle_worker_status_change( + self, status: str | Status, worker: str | WorkerState, stimulus_id: str + ) -> None: ws = self.workers.get(worker) if isinstance(worker, str) else worker if not ws: return @@ -5964,7 +6166,9 @@ def handle_worker_status_change(self, status: str | Status, worker: str | Worker if ws.status == Status.running: self.running.add(ws) self.check_idle_saturated(ws) - self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) @@ -5973,7 +6177,9 @@ def handle_worker_status_change(self, status: str | Status, worker: str | Worker self.saturated.discard(ws) self._refresh_no_workers_since() - def handle_request_refresh_who_has(self, keys: Iterable[Key], worker: str, stimulus_id: str) -> None: + def handle_request_refresh_who_has( + self, keys: Iterable[Key], worker: str, stimulus_id: str + ) -> None: """Request from a Worker to refresh the who_has for some keys. Not to be confused with scheduler.who_has, which is a dedicated comm RPC request from a Client. 
@@ -6022,7 +6228,9 @@ async def handle_worker(self, comm: Comm, worker: str) -> None: finally: if worker in self.stream_comms: worker_comm.abort() - await self.remove_worker(worker, stimulus_id=f"handle-worker-cleanup-{time()}") + await self.remove_worker( + worker, stimulus_id=f"handle-worker-cleanup-{time()}" + ) def add_plugin( self, @@ -6083,7 +6291,9 @@ def remove_plugin(self, name: str | None = None) -> None: try: del self.plugins[name] except KeyError: - raise ValueError(f"Could not find plugin {name!r} among the current scheduler plugins") + raise ValueError( + f"Could not find plugin {name!r} among the current scheduler plugins" + ) async def register_scheduler_plugin( self, @@ -6145,7 +6355,9 @@ def client_send(self, client: str, msg: dict) -> None: c.send(msg) except CommClosedError: if self.status == Status.running: - logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) + logger.critical( + "Closed comm %r while trying to write %s", c, msg, exc_info=True + ) def send_all(self, client_msgs: Msgs, worker_msgs: Msgs) -> None: """Send messages to client and workers""" @@ -6223,10 +6435,14 @@ async def scatter( n = len(workers) if broadcast is True else broadcast await self.replicate(keys=keys, workers=workers, n=n) - self.log_event([client, "all"], {"action": "scatter", "client": client, "count": len(data)}) + self.log_event( + [client, "all"], {"action": "scatter", "client": client, "count": len(data)} + ) return keys - async def gather(self, keys: Collection[Key], serializers: list[str] | None = None) -> dict[Key, object]: + async def gather( + self, keys: Collection[Key], serializers: list[str] | None = None + ) -> dict[Key, object]: """Collect data from workers to the scheduler""" data = {} missing_keys = list(keys) @@ -6247,7 +6463,9 @@ async def gather(self, keys: Collection[Key], serializers: list[str] | None = No missing_keys, new_failed_keys, new_missing_workers, - ) = await gather_from_workers(who_has, rpc=self.rpc, serializers=serializers) + ) = await gather_from_workers( + who_has, rpc=self.rpc, serializers=serializers + ) data.update(new_data) failed_keys += new_failed_keys missing_workers.update(new_missing_workers) @@ -6257,7 +6475,10 @@ async def gather(self, keys: Collection[Key], serializers: list[str] | None = No if not failed_keys: return {"status": "OK", "data": data} - failed_states = {key: self.tasks[key].state if key in self.tasks else "forgotten" for key in failed_keys} + failed_states = { + key: self.tasks[key].state if key in self.tasks else "forgotten" + for key in failed_keys + } logger.error("Couldn't gather keys: %s", failed_states) return {"status": "error", "keys": list(failed_keys)} @@ -6367,16 +6588,24 @@ async def restart_workers( workers = list(set(workers).intersection(self.workers)) logger.info(f"Restarting {len(workers)} workers: {workers} ({stimulus_id=}") - nanny_workers = {addr: self.workers[addr].nanny for addr in workers if self.workers[addr].nanny} + nanny_workers = { + addr: self.workers[addr].nanny + for addr in workers + if self.workers[addr].nanny + } # Close non-Nanny workers. We have no way to restart them, so we just let them # go, and assume a deployment system is going to restart them for us. 
no_nanny_workers = [addr for addr in workers if addr not in nanny_workers] if no_nanny_workers: logger.warning( - f"Workers {no_nanny_workers} do not use a nanny and will be terminated without restarting them" + f"Workers {no_nanny_workers} do not use a nanny and will be terminated " + "without restarting them" ) await asyncio.gather( - *(self.remove_worker(address=addr, stimulus_id=stimulus_id) for addr in no_nanny_workers) + *( + self.remove_worker(address=addr, stimulus_id=stimulus_id) + for addr in no_nanny_workers + ) ) out: dict[str, Literal["OK", "removed", "timed out"]] out = {addr: "removed" for addr in no_nanny_workers} @@ -6386,7 +6615,9 @@ async def restart_workers( async with contextlib.AsyncExitStack() as stack: nannies = await asyncio.gather( *( - stack.enter_async_context(rpc(nanny_address, connection_args=self.connection_args)) + stack.enter_async_context( + rpc(nanny_address, connection_args=self.connection_args) + ) for nanny_address in nanny_workers.values() ) ) @@ -6422,8 +6653,16 @@ async def restart_workers( raise resp if bad_nannies: - logger.error(f"Workers {list(bad_nannies)} did not shut down within {timeout}s; force closing") - await asyncio.gather(*(self.remove_worker(addr, stimulus_id=stimulus_id) for addr in bad_nannies)) + logger.error( + f"Workers {list(bad_nannies)} did not shut down within {timeout}s; " + "force closing" + ) + await asyncio.gather( + *( + self.remove_worker(addr, stimulus_id=stimulus_id) + for addr in bad_nannies + ) + ) if on_error == "raise": raise TimeoutError( f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not " @@ -6432,10 +6671,15 @@ async def restart_workers( if client: self.log_event(client, {"action": "restart-workers", "workers": workers}) - self.log_event("all", {"action": "restart-workers", "workers": workers, "client": client}) + self.log_event( + "all", {"action": "restart-workers", "workers": workers, "client": client} + ) if not wait_for_workers: - logger.info(f"Workers restart finished (did not wait for new workers) ({stimulus_id=}") + logger.info( + "Workers restart finished (did not wait for new workers) " + f"({stimulus_id=}" + ) return out # NOTE: if new (unrelated) workers join while we're waiting, we may return @@ -6504,7 +6748,9 @@ async def broadcast( ERROR = object() - reuse_broadcast_comm = dask.config.get("distributed.scheduler.reuse-broadcast-comm", False) + reuse_broadcast_comm = dask.config.get( + "distributed.scheduler.reuse-broadcast-comm", False + ) close = not reuse_broadcast_comm async def send_message(addr: str) -> Any: @@ -6512,7 +6758,9 @@ async def send_message(addr: str) -> Any: comm = await self.rpc.connect(addr) comm.name = "Scheduler Broadcast" try: - resp = await send_recv(comm, close=close, serializers=serializers, **msg) + resp = await send_recv( + comm, close=close, serializers=serializers, **msg + ) finally: self.rpc.reuse(addr, comm) return resp @@ -6528,7 +6776,8 @@ async def send_message(addr: str) -> Any: return ERROR else: raise ValueError( - f"on_error must be 'raise', 'return', 'return_pickle', or 'ignore'; got {on_error!r}" + "on_error must be 'raise', 'return', 'return_pickle', " + f"or 'ignore'; got {on_error!r}" ) results = await All([send_message(address) for address in addresses]) @@ -6544,7 +6793,9 @@ async def proxy( d = await self.broadcast(msg=msg, workers=[worker], serializers=serializers) return d[worker] - async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[str]]) -> set: + async def gather_on_worker( + self, worker_address: str, 
who_has: dict[Key, list[str]] + ) -> set: """Peer-to-peer copy of keys from multiple workers to a single worker Parameters @@ -6560,12 +6811,15 @@ async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[st set of keys that failed to be copied """ try: - result = await retry_operation(self.rpc(addr=worker_address).gather, who_has=who_has) + result = await retry_operation( + self.rpc(addr=worker_address).gather, who_has=who_has + ) except OSError as e: # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during " + f"replication: {e.__class__.__name__}: {e}" ) return set(who_has) @@ -6580,7 +6834,9 @@ async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[st elif result["status"] == "partial-fail": keys_failed = set(result["keys"]) keys_ok = who_has.keys() - keys_failed - logger.warning(f"Worker {worker_address} failed to acquire keys: {result['keys']}") + logger.warning( + f"Worker {worker_address} failed to acquire keys: {result['keys']}" + ) else: # pragma: nocover raise ValueError(f"Unexpected message from {worker_address}: {result}") @@ -6593,7 +6849,9 @@ async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[st return keys_failed - async def delete_worker_data(self, worker_address: str, keys: Collection[Key], stimulus_id: str) -> None: + async def delete_worker_data( + self, worker_address: str, keys: Collection[Key], stimulus_id: str + ) -> None: """Delete data from a worker and update the corresponding worker/task states Parameters @@ -6613,7 +6871,8 @@ async def delete_worker_data(self, worker_address: str, keys: Collection[Key], s # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during " + f"replication: {e.__class__.__name__}: {e}" ) return @@ -6719,7 +6978,9 @@ async def rebalance( keys = set(keys) # unless already a set-like if not keys: return {"status": "OK"} - missing_data = [k for k in keys if k not in self.tasks or not self.tasks[k].who_has] + missing_data = [ + k for k in keys if k not in self.tasks or not self.tasks[k].who_has + ] if missing_data: return {"status": "partial-fail", "keys": missing_data} @@ -6794,7 +7055,9 @@ def _rebalance_find_msgs( # unmanaged memory that appeared over the last 30 seconds # (distributed.worker.memory.recent-to-old-time). # This lets us ignore temporary spikes caused by task heap usage. 
- memory_by_worker = [(ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers] + memory_by_worker = [ + (ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers + ] mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) for ws, ws_memory in memory_by_worker: @@ -6807,12 +7070,18 @@ def _rebalance_find_msgs( sender_min = 0.0 recipient_max = math.inf - if ws._has_what and ws_memory >= mean_memory + half_gap and ws_memory >= sender_min: + if ( + ws._has_what + and ws_memory >= mean_memory + half_gap + and ws_memory >= sender_min + ): # This may send the worker below sender_min (by design) snd_bytes_max = mean_memory - ws_memory # negative snd_bytes_min = snd_bytes_max + half_gap # negative # See definition of senders above - senders.append((snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what))) + senders.append( + (snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what)) + ) elif ws_memory < mean_memory - half_gap and ws_memory < recipient_max: # This may send the worker above recipient_max (by design) rec_bytes_max = ws_memory - mean_memory # negative @@ -6930,7 +7199,9 @@ async def _rebalance_move_data( FIXME this method is not robust when the cluster is not idle. """ # {recipient address: {key: [sender address, ...]}} - to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict(lambda: defaultdict(list)) + to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict( + lambda: defaultdict(list) + ) for snd_ws, rec_ws, ts in msgs: to_recipients[rec_ws.address][ts.key].append(snd_ws.address) failed_keys_by_recipient = dict( @@ -6952,7 +7223,9 @@ async def _rebalance_move_data( to_senders[snd_ws.address].append(ts.key) # Note: this never raises exceptions - await asyncio.gather(*(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items())) + await asyncio.gather( + *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items()) + ) for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) @@ -7031,13 +7304,17 @@ async def replicate( assert ts.who_has is not None del_candidates = tuple(ts.who_has & workers) if len(del_candidates) > n: - for ws in random.sample(del_candidates, len(del_candidates) - n): + for ws in random.sample( + del_candidates, len(del_candidates) - n + ): del_worker_tasks[ws].add(ts) # Note: this never raises exceptions await asyncio.gather( *[ - self.delete_worker_data(ws.address, [t.key for t in tasks], stimulus_id) + self.delete_worker_data( + ws.address, [t.key for t in tasks], stimulus_id + ) for ws, tasks in del_worker_tasks.items() ] ) @@ -7062,7 +7339,9 @@ async def replicate( assert count > 0 for ws in random.sample(tuple(workers - ts.who_has), count): - gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has] + gathers[ws.address][ts.key] = [ + wws.address for wws in ts.who_has + ] await asyncio.gather( *( @@ -7318,7 +7597,8 @@ async def retire_workers( raise TypeError("names and workers are mutually exclusive") if (names is not None or workers is not None) and kwargs: raise TypeError( - f"Parameters for workers_to_close() are mutually exclusive with names and workers: {kwargs}" + "Parameters for workers_to_close() are mutually exclusive with " + f"names and workers: {kwargs}" ) stimulus_id = stimulus_id or f"retire-workers-{time()}" @@ -7338,16 +7618,24 @@ async def retire_workers( stimulus_id, workers, ) - wss = {self.workers[address] for address in workers if address in self.workers} + wss = { + 
self.workers[address] + for address in workers + if address in self.workers + } else: - wss = {self.workers[address] for address in self.workers_to_close(**kwargs)} + wss = { + self.workers[address] for address in self.workers_to_close(**kwargs) + } if not wss: return {} stop_amm = False amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm") if not amm or not amm.running: - amm = ActiveMemoryManagerExtension(self, policies=set(), register=False, start=True, interval=2.0) + amm = ActiveMemoryManagerExtension( + self, policies=set(), register=False, start=True, interval=2.0 + ) stop_amm = True try: @@ -7359,7 +7647,9 @@ async def retire_workers( # Change Worker.status to closing_gracefully. Immediately set # the same on the scheduler to prevent race conditions. prev_status = ws.status - self.handle_worker_status_change(Status.closing_gracefully, ws, stimulus_id) + self.handle_worker_status_change( + Status.closing_gracefully, ws, stimulus_id + ) # FIXME: We should send a message to the nanny first; # eventually workers won't be able to close their own nannies. self.stream_comms[ws.address].send( @@ -7450,10 +7740,14 @@ async def _track_retire_worker( ) return ws.address, "no-recipients", ws.identity() - logger.debug(f"All unique keys on worker {ws.address!r} have been replicated elsewhere") + logger.debug( + f"All unique keys on worker {ws.address!r} have been replicated elsewhere" + ) if remove: - await self.remove_worker(ws.address, expected=True, close=close, stimulus_id=stimulus_id) + await self.remove_worker( + ws.address, expected=True, close=close, stimulus_id=stimulus_id + ) elif close: self.close_worker(ws.address) @@ -7595,7 +7889,9 @@ async def feed( if teardown: teardown(self, state) # type: ignore - def log_worker_event(self, worker: str, topic: str | Collection[str], msg: Any) -> None: + def log_worker_event( + self, worker: str, topic: str | Collection[str], msg: Any + ) -> None: if isinstance(msg, dict) and worker != topic: msg["worker"] = worker self.log_event(topic, msg) @@ -7608,25 +7904,46 @@ def subscribe_worker_status(self, comm: Comm) -> dict[str, Any]: del v["last_seen"] return ident - def get_processing(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: + def get_processing( + self, workers: Iterable[str] | None = None + ) -> dict[str, list[Key]]: if workers is not None: workers = set(map(self.coerce_address, workers)) return {w: [ts.key for ts in self.workers[w].processing] for w in workers} else: - return {w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()} + return { + w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() + } def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]: if keys is not None: return { - key: ([ws.address for ws in self.tasks[key].who_has or ()] if key in self.tasks else []) for key in keys + key: ( + [ws.address for ws in self.tasks[key].who_has or ()] + if key in self.tasks + else [] + ) + for key in keys } else: - return {key: [ws.address for ws in ts.who_has or ()] for key, ts in self.tasks.items()} + return { + key: [ws.address for ws in ts.who_has or ()] + for key, ts in self.tasks.items() + } - def get_has_what(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: + def get_has_what( + self, workers: Iterable[str] | None = None + ) -> dict[str, list[Key]]: if workers is not None: workers = map(self.coerce_address, workers) - return {w: ([ts.key for ts in self.workers[w].has_what] if w in self.workers else []) for w in 
workers} + return { + w: ( + [ts.key for ts in self.workers[w].has_what] + if w in self.workers + else [] + ) + for w in workers + } else: return {w: [ts.key for ts in ws.has_what] for w, ws in self.workers.items()} @@ -7637,9 +7954,13 @@ def get_ncores(self, workers: Iterable[str] | None = None) -> dict[str, int]: else: return {w: ws.nthreads for w, ws in self.workers.items()} - def get_ncores_running(self, workers: Iterable[str] | None = None) -> dict[str, int]: + def get_ncores_running( + self, workers: Iterable[str] | None = None + ) -> dict[str, int]: ncores = self.get_ncores(workers=workers) - return {w: n for w, n in ncores.items() if self.workers[w].status == Status.running} + return { + w: n for w, n in ncores.items() if self.workers[w].status == Status.running + } async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, Any]: workers: dict[str, list[Key] | None] @@ -7666,7 +7987,9 @@ async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, A if not workers: return {} - results = await asyncio.gather(*(self.rpc(w).call_stack(keys=v) for w, v in workers.items())) + results = await asyncio.gather( + *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) + ) response = {w: r for w, r in zip(workers, results) if r} return response @@ -7707,7 +8030,9 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: # implementing logic based on IP addresses would not necessarily help. # Randomize the connections to even out the mean measures. random.shuffle(workers) - futures = [self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers)] + futures = [ + self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers) + ] responses = await asyncio.gather(*futures) for d in responses: @@ -7716,12 +8041,17 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: result = {} for mode in out: - result[mode] = {size: sum(durations) / len(durations) for size, durations in out[mode].items()} + result[mode] = { + size: sum(durations) / len(durations) + for size, durations in out[mode].items() + } return result @log_errors - def get_nbytes(self, keys: Iterable[Key] | None = None, summary: bool = True) -> dict[Key, int]: + def get_nbytes( + self, keys: Iterable[Key] | None = None, summary: bool = True + ) -> dict[Key, int]: if keys is not None: result = {k: self.tasks[k].nbytes for k in keys} else: @@ -7807,7 +8137,9 @@ def get_task_prefix_states(self) -> dict[str, dict[str, int]]: return state def get_task_status(self, keys: Iterable[Key]) -> dict[Key, TaskStateState | None]: - return {key: (self.tasks[key].state if key in self.tasks else None) for key in keys} + return { + key: (self.tasks[key].state if key in self.tasks else None) for key in keys + } def get_task_stream( self, @@ -7830,11 +8162,14 @@ def start_task_metadata(self, name: str) -> None: def stop_task_metadata(self, name: str | None = None) -> dict: plugins = [ - p for p in list(self.plugins.values()) if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name + p + for p in list(self.plugins.values()) + if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name ] if len(plugins) != 1: raise ValueError( - f"Expected to find exactly one CollectTaskMetaDataPlugin with name {name} but found {len(plugins)}." + "Expected to find exactly one CollectTaskMetaDataPlugin " + f"with name {name} but found {len(plugins)}." 
) plugin = plugins[0] @@ -7859,10 +8194,14 @@ async def register_worker_plugin( self.worker_plugins[name] = plugin - responses = await self.broadcast(msg=dict(op="plugin-add", plugin=plugin, name=name)) + responses = await self.broadcast( + msg=dict(op="plugin-add", plugin=plugin, name=name) + ) return responses - async def unregister_worker_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_worker_plugin( + self, comm: None, name: str + ) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.worker_plugins.pop(name) @@ -7894,21 +8233,27 @@ async def register_nanny_plugin( async with self._starting_nannies_cond: if self._starting_nannies: logger.info("Waiting for Nannies to start %s", self._starting_nannies) - await self._starting_nannies_cond.wait_for(lambda: not self._starting_nannies) + await self._starting_nannies_cond.wait_for( + lambda: not self._starting_nannies + ) responses = await self.broadcast( msg=dict(op="plugin_add", plugin=plugin, name=name), nanny=True, ) return responses - async def unregister_nanny_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_nanny_plugin( + self, comm: None, name: str + ) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.nanny_plugins.pop(name) except KeyError: raise ValueError(f"The nanny plugin {name} does not exist") - responses = await self.broadcast(msg=dict(op="plugin_remove", name=name), nanny=True) + responses = await self.broadcast( + msg=dict(op="plugin_remove", name=name), nanny=True + ) return responses def transition( @@ -7933,7 +8278,9 @@ def transition( -------- Scheduler.transitions: transitive version of this function """ - recommendations, client_msgs, worker_msgs = self._transition(key, finish, stimulus_id, **kwargs) + recommendations, client_msgs, worker_msgs = self._transition( + key, finish, stimulus_id, **kwargs + ) self.send_all(client_msgs, worker_msgs) return recommendations @@ -7956,7 +8303,9 @@ async def get_story(self, keys_or_stimuli: Iterable[Key | str]) -> list[Transiti """ return self.story(*keys_or_stimuli) - def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) -> None: + def _reschedule( + self, key: Key, worker: str | None = None, *, stimulus_id: str + ) -> None: """Reschedule a task. This function should only be used when the task has already been released in @@ -7968,7 +8317,8 @@ def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) ts = self.tasks[key] except KeyError: logger.warning( - f"Attempting to reschedule task {key!r}, which was not found on the scheduler. Aborting reschedule." + f"Attempting to reschedule task {key!r}, which was not " + "found on the scheduler. Aborting reschedule." 
) return if ts.state != "processing": @@ -7983,7 +8333,9 @@ def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) # Utility functions # ##################### - def add_resources(self, worker: str, resources: dict | None = None) -> Literal["OK"]: + def add_resources( + self, worker: str, resources: dict | None = None + ) -> Literal["OK"]: ws = self.workers[worker] if resources: ws.resources.update(resources) @@ -8065,7 +8417,10 @@ async def get_profile( ) results = await asyncio.gather( - *(self.rpc(w).profile(start=start, stop=stop, key=key, server=server) for w in workers), + *( + self.rpc(w).profile(start=start, stop=stop, key=key, server=server) + for w in workers + ), return_exceptions=True, ) @@ -8085,7 +8440,9 @@ async def get_profile_metadata( stop: float | None = None, profile_cycle_interval: str | float | None = None, ) -> dict[str, Any]: - dt = profile_cycle_interval or dask.config.get("distributed.worker.profile.cycle") + dt = profile_cycle_interval or dask.config.get( + "distributed.worker.profile.cycle" + ) dt = parse_timedelta(dt, default="ms") if workers is None: @@ -8108,7 +8465,9 @@ async def get_profile_metadata( ) ] - keys: dict[Key, list[list]] = {k: [] for v in results for t, d in v["keys"] for k in d} + keys: dict[Key, list[list]] = { + k: [] for v in results for t, d in v["keys"] for k in d + } groups1 = [v["keys"] for v in results] groups2 = list(merge_sorted(*groups1, key=first)) @@ -8125,7 +8484,9 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} - async def performance_report(self, start: float, last_count: int, code: str = "", mode: str | None = None) -> str: + async def performance_report( + self, start: float, last_count: int, code: str = "", mode: str | None = None + ) -> str: stop = time() # Profiles compute_d, scheduler_d, workers_d = await asyncio.gather( @@ -8142,7 +8503,9 @@ def profile_to_figure(state: object) -> object: figure, source = profile.plot_figure(data, sizing_mode="stretch_both") return figure - compute, scheduler, workers = map(profile_to_figure, (compute_d, scheduler_d, workers_d)) + compute, scheduler, workers = map( + profile_to_figure, (compute_d, scheduler_d, workers_d) + ) del compute_d, scheduler_d, workers_d # Task stream @@ -8226,10 +8589,16 @@ def profile_to_figure(state: object) -> object: html = TabPanel(child=html, title="Summary") compute = TabPanel(child=compute, title="Worker Profile (compute)") workers = TabPanel(child=workers, title="Worker Profile (administrative)") - scheduler = TabPanel(child=scheduler, title="Scheduler Profile (administrative)") + scheduler = TabPanel( + child=scheduler, title="Scheduler Profile (administrative)" + ) task_stream = TabPanel(child=task_stream, title="Task Stream") - bandwidth_workers = TabPanel(child=bandwidth_workers.root, title="Bandwidth (Workers)") - bandwidth_types = TabPanel(child=bandwidth_types.root, title="Bandwidth (Types)") + bandwidth_workers = TabPanel( + child=bandwidth_workers.root, title="Bandwidth (Workers)" + ) + bandwidth_types = TabPanel( + child=bandwidth_types.root, title="Bandwidth (Types)" + ) system = TabPanel(child=sysmon.root, title="System") logs = TabPanel(child=logs.root, title="Scheduler Logs") @@ -8253,7 +8622,9 @@ def profile_to_figure(state: object) -> object: with tmpfile(extension=".html") as fn: output_file(filename=fn, title="Dask Performance Report", mode=mode) - template_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates") + template_directory = 
os.path.join( + os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates" + ) template_environment = get_env() template_environment.loader.searchpath.append(template_directory) template = template_environment.get_template("performance_report.html") @@ -8264,8 +8635,12 @@ def profile_to_figure(state: object) -> object: return data - async def get_worker_logs(self, n: int | None = None, workers: list | None = None, nanny: bool = False) -> dict: - results = await self.broadcast(msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny) + async def get_worker_logs( + self, n: int | None = None, workers: list | None = None, nanny: bool = False + ) -> dict: + results = await self.broadcast( + msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny + ) return results def log_event(self, topic: str | Collection[str], msg: Any) -> None: @@ -8302,11 +8677,16 @@ def get_events( ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]: return self._broker.get_events(topic) - async def get_worker_monitor_info(self, recent: bool = False, starts: dict | None = None) -> dict: + async def get_worker_monitor_info( + self, recent: bool = False, starts: dict | None = None + ) -> dict: if starts is None: starts = {} results = await asyncio.gather( - *(self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) for w in self.workers) + *( + self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) + for w in self.workers + ) ) return dict(zip(self.workers, results)) @@ -8359,7 +8739,11 @@ def check_idle(self) -> float | None: self.idle_since = None return None - if self.queued or self.unrunnable or any(ws.processing for ws in self.workers.values()): + if ( + self.queued + or self.unrunnable + or any(ws.processing for ws in self.workers.values()) + ): self.idle_since = None return None @@ -8368,7 +8752,9 @@ def check_idle(self) -> float | None: return self.idle_since if self.jupyter: - last_activity = self._jupyter_server_application.web_app.last_activity().timestamp() + last_activity = ( + self._jupyter_server_application.web_app.last_activity().timestamp() + ) if last_activity > self.idle_since: self.idle_since = last_activity return self.idle_since @@ -8380,11 +8766,16 @@ def check_idle(self) -> float | None: "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self._ongoing_background_tasks.call_soon(self.close, reason="idle-timeout-exceeded") + self._ongoing_background_tasks.call_soon( + self.close, reason="idle-timeout-exceeded" + ) return self.idle_since def _check_no_workers(self) -> None: - if self.status in (Status.closing, Status.closed) or self.no_workers_timeout is None: + if ( + self.status in (Status.closing, Status.closed) + or self.no_workers_timeout is None + ): return now = monotonic() @@ -8394,9 +8785,15 @@ def _check_no_workers(self) -> None: self._refresh_no_workers_since(now) - affected = self._check_unrunnable_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id) + affected = self._check_unrunnable_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id + ) - affected.update(self._check_queued_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id)) + affected.update( + self._check_queued_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id + ) + ) self.transitions(recommendations, stimulus_id=stimulus_id) if affected: self.log_event( @@ -8404,7 +8801,9 @@ def _check_no_workers(self) -> None: {"action": 
"no-workers-timeout-exceeded", "keys": affected}, ) - def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: + def _check_unrunnable_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: assert self.no_workers_timeout unsatisfied = [] no_workers = [] @@ -8413,7 +8812,10 @@ def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Rec # unrunnable is insertion-ordered, which means that unrunnable_since will # be monotonically increasing in this loop. break - if self._no_workers_since is None or self._no_workers_since >= unrunnable_since: + if ( + self._no_workers_since is None + or self._no_workers_since >= unrunnable_since + ): unsatisfied.append(ts) else: no_workers.append(ts) @@ -8439,13 +8841,18 @@ def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Rec ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting for its restrictions to become satisfied.", + "Task %s marked as failed because it timed out waiting " + "for its restrictions to become satisfied.", ts.key, ) - self._fail_tasks_after_no_workers_timeout(no_workers, recommendations, stimulus_id) + self._fail_tasks_after_no_workers_timeout( + no_workers, recommendations, stimulus_id + ) return {ts.key for ts in concat([unsatisfied, no_workers])} - def _check_queued_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: + def _check_queued_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: assert self.no_workers_timeout if self._no_workers_since is None: @@ -8454,7 +8861,9 @@ def _check_queued_task_timeouts(self, timestamp: float, recommendations: Recs, s if timestamp <= self._no_workers_since + self.no_workers_timeout: return set() affected = list(self.queued) - self._fail_tasks_after_no_workers_timeout(affected, recommendations, stimulus_id) + self._fail_tasks_after_no_workers_timeout( + affected, recommendations, stimulus_id + ) return {ts.key for ts in affected} def _fail_tasks_after_no_workers_timeout( @@ -8478,7 +8887,8 @@ def _fail_tasks_after_no_workers_timeout( ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting without any running workers.", + "Task %s marked as failed because it timed out waiting " + "without any running workers.", ts.key, ) @@ -8553,7 +8963,9 @@ def adaptive_target(self, target_duration: float | None = None) -> int: to_close = self.workers_to_close() return len(self.workers) - len(to_close) - def request_acquire_replicas(self, addr: str, keys: Iterable[Key], *, stimulus_id: str) -> None: + def request_acquire_replicas( + self, addr: str, keys: Iterable[Key], *, stimulus_id: str + ) -> None: """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. @@ -8575,7 +8987,9 @@ def request_acquire_replicas(self, addr: str, keys: Iterable[Key], *, stimulus_i }, ) - def request_remove_replicas(self, addr: str, keys: list[Key], *, stimulus_id: str) -> None: + def request_remove_replicas( + self, addr: str, keys: list[Key], *, stimulus_id: str + ) -> None: """Asynchronously ask a worker to discard its replica of the listed keys. This must never be used to destroy the last replica of a key. 
This is a fire-and-forget operation, intended for housekeeping and not for computation. @@ -8759,7 +9173,13 @@ def validate_task_state(ts: TaskState) -> None: if ts.run_spec: # was computed assert ts.type assert isinstance(ts.type, str) - assert not any([ts in dts.waiting_on for dts in ts.dependents if dts.waiting_on is not None]) + assert not any( + [ + ts in dts.waiting_on + for dts in ts.dependents + if dts.waiting_on is not None + ] + ) for ws in ts.who_has: assert ts in ws.has_what, ( "not in who_has' has_what", @@ -8862,7 +9282,9 @@ def heartbeat_interval(n: int) -> float: def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int: """Number of tasks that can be sent to this worker without oversaturating it""" assert not math.isinf(saturation_factor) - return max(math.ceil(saturation_factor * ws.nthreads), 1) - (len(ws.processing) - len(ws.long_running)) + return max(math.ceil(saturation_factor * ws.nthreads), 1) - ( + len(ws.processing) - len(ws.long_running) + ) def _worker_full(ws: WorkerState, saturation_factor: float) -> bool: @@ -8906,7 +9328,9 @@ def __init__( resource_restrictions: dict[str, float], timeout: float, ): - super().__init__(task, host_restrictions, worker_restrictions, resource_restrictions, timeout) + super().__init__( + task, host_restrictions, worker_restrictions, resource_restrictions, timeout + ) @property def task(self) -> Key: From 7a1cef8b09a66b9a1bf86e6e4ce4ca6566433082 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 5 Dec 2025 19:05:37 -0800 Subject: [PATCH 05/23] Update condition.py,test_condition.py --- distributed/condition.py | 3 +- distributed/tests/test_condition.py | 150 +++++++++++++--------------- 2 files changed, 73 insertions(+), 80 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index ad31630815..4d250b7f2a 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -6,8 +6,7 @@ from collections import defaultdict from contextlib import suppress -from distributed.utils import log_errors, wait_for, TimeoutError -from distributed.utils import SyncMethodMixin +from distributed.utils import SyncMethodMixin, log_errors from distributed.worker import get_client logger = logging.getLogger(__name__) diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index ea8e051796..102e629fbf 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -1,16 +1,17 @@ import asyncio + import pytest -from distributed import Condition, Client, wait -from distributed.utils_test import gen_cluster, inc +from distributed import Condition from distributed.metrics import time +from distributed.utils_test import gen_cluster @gen_cluster(client=True) -async def test_condition_acqui re_release(c, s, a, b): +async def test_condition_acquire_release(c, s, a, b): """Test basic lock acquire/release""" condition = Condition("test-lock") - + assert not condition.locked() await condition.acquire() assert condition.locked() @@ -22,7 +23,7 @@ async def test_condition_acqui re_release(c, s, a, b): async def test_condition_context_manager(c, s, a, b): """Test context manager interface""" condition = Condition("test-context") - + assert not condition.locked() async with condition: assert condition.locked() @@ -34,19 +35,19 @@ async def test_condition_wait_notify(c, s, a, b): """Test basic wait/notify""" condition = Condition("test-notify") results = [] - + async def waiter(): async with condition: results.append("waiting") await condition.wait() 
results.append("notified") - + async def notifier(): await asyncio.sleep(0.2) async with condition: results.append("notifying") condition.notify() - + await asyncio.gather(waiter(), notifier()) assert results == ["waiting", "notifying", "notified"] @@ -56,20 +57,18 @@ async def test_condition_notify_all(c, s, a, b): """Test notify_all wakes all waiters""" condition = Condition("test-notify-all") results = [] - + async def waiter(i): async with condition: await condition.wait() results.append(i) - + async def notifier(): await asyncio.sleep(0.2) async with condition: condition.notify_all() - - await asyncio.gather( - waiter(1), waiter(2), waiter(3), notifier() - ) + + await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) assert sorted(results) == [1, 2, 3] @@ -78,12 +77,12 @@ async def test_condition_notify_n(c, s, a, b): """Test notify with specific count""" condition = Condition("test-notify-n") results = [] - + async def waiter(i): async with condition: await condition.wait() results.append(i) - + async def notifier(): await asyncio.sleep(0.2) async with condition: @@ -91,10 +90,8 @@ async def notifier(): await asyncio.sleep(0.2) async with condition: condition.notify() # Wake remaining waiter - - await asyncio.gather( - waiter(1), waiter(2), waiter(3), notifier() - ) + + await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) assert sorted(results) == [1, 2, 3] @@ -102,12 +99,12 @@ async def notifier(): async def test_condition_wait_timeout(c, s, a, b): """Test wait with timeout""" condition = Condition("test-timeout") - + start = time() async with condition: result = await condition.wait(timeout=0.5) elapsed = time() - start - + assert result is False assert 0.4 < elapsed < 0.7 @@ -117,21 +114,21 @@ async def test_condition_wait_timeout_then_notify(c, s, a, b): """Test that timeout doesn't prevent subsequent notifications""" condition = Condition("test-timeout-notify") results = [] - + async def waiter(): async with condition: result = await condition.wait(timeout=0.2) results.append(f"timeout: {result}") - + async with condition: result = await condition.wait() results.append(f"notified: {result}") - + async def notifier(): await asyncio.sleep(0.5) async with condition: condition.notify() - + await asyncio.gather(waiter(), notifier()) assert results == ["timeout: False", "notified: True"] @@ -140,13 +137,13 @@ async def notifier(): async def test_condition_error_without_lock(c, s, a, b): """Test errors when calling wait/notify without holding lock""" condition = Condition("test-error") - + with pytest.raises(RuntimeError, match="without holding the lock"): await condition.wait() - + with pytest.raises(RuntimeError, match="Cannot notify"): await condition.notify() - + with pytest.raises(RuntimeError, match="Cannot notify"): await condition.notify_all() @@ -155,7 +152,7 @@ async def test_condition_error_without_lock(c, s, a, b): async def test_condition_error_release_without_acquire(c, s, a, b): """Test error when releasing without acquiring""" condition = Condition("test-release-error") - + with pytest.raises(RuntimeError, match="Cannot release"): await condition.release() @@ -165,14 +162,14 @@ async def test_condition_producer_consumer(c, s, a, b): """Test classic producer-consumer pattern""" condition = Condition("prod-cons") queue = [] - + async def producer(): for i in range(5): await asyncio.sleep(0.1) async with condition: queue.append(i) condition.notify() - + async def consumer(): results = [] for _ in range(5): @@ -181,13 +178,13 @@ async def consumer(): 
await condition.wait() results.append(queue.pop(0)) return results - + prod_task = asyncio.create_task(producer()) cons_task = asyncio.create_task(consumer()) - + await prod_task results = await cons_task - + assert results == [0, 1, 2, 3, 4] @@ -196,14 +193,14 @@ async def test_condition_multiple_producers_consumers(c, s, a, b): """Test multiple producers and consumers""" condition = Condition("multi-prod-cons") queue = [] - + async def producer(start): for i in range(start, start + 3): await asyncio.sleep(0.05) async with condition: queue.append(i) condition.notify() - + async def consumer(): results = [] for _ in range(3): @@ -212,12 +209,9 @@ async def consumer(): await condition.wait() results.append(queue.pop(0)) return results - - results = await asyncio.gather( - producer(0), producer(10), - consumer(), consumer() - ) - + + results = await asyncio.gather(producer(0), producer(10), consumer(), consumer()) + # Last two results are from consumers consumed = results[2] + results[3] assert sorted(consumed) == [0, 1, 2, 10, 11, 12] @@ -226,39 +220,43 @@ async def consumer(): @gen_cluster(client=True) async def test_condition_from_worker(c, s, a, b): """Test condition accessed from worker tasks""" + def wait_on_condition(name): + from distributed import Condition - import asyncio - + async def _wait(): condition = Condition(name) async with condition: await condition.wait() return "worker_notified" - + from distributed.worker import get_worker + worker = get_worker() return worker.loop.run_until_complete(_wait()) - + def notify_condition(name): - from distributed import Condition import asyncio - + + from distributed import Condition + async def _notify(): await asyncio.sleep(0.2) condition = Condition(name) async with condition: condition.notify() return "notified" - + from distributed.worker import get_worker + worker = get_worker() return worker.loop.run_until_complete(_notify()) - + name = "worker-condition" f1 = c.submit(wait_on_condition, name, workers=[a.address]) f2 = c.submit(notify_condition, name, workers=[b.address]) - + results = await c.gather([f1, f2]) assert results == ["worker_notified", "notified"] @@ -269,21 +267,21 @@ async def test_condition_same_name_different_instances(c, s, a, b): name = "shared-condition" cond1 = Condition(name) cond2 = Condition(name) - + results = [] - + async def waiter(): async with cond1: results.append("waiting") await cond1.wait() results.append("notified") - + async def notifier(): await asyncio.sleep(0.2) async with cond2: results.append("notifying") cond2.notify() - + await asyncio.gather(waiter(), notifier()) assert results == ["waiting", "notifying", "notified"] @@ -293,11 +291,11 @@ async def test_condition_unique_names_independent(c, s, a, b): """Test conditions with different names are independent""" cond1 = Condition("cond-1") cond2 = Condition("cond-2") - + async with cond1: assert cond1.locked() assert not cond2.locked() - + async with cond2: assert not cond1.locked() assert cond2.locked() @@ -307,15 +305,15 @@ async def test_condition_unique_names_independent(c, s, a, b): async def test_condition_cleanup(c, s, a, b): """Test that condition state is cleaned up after use""" condition = Condition("cleanup-test") - + # Check initial state assert "cleanup-test" not in s.extensions["conditions"]._lock_holders assert "cleanup-test" not in s.extensions["conditions"]._waiters - + # Use condition async with condition: condition.notify() - + # State should be cleaned up await asyncio.sleep(0.1) assert "cleanup-test" not in 
s.extensions["conditions"]._lock_holders @@ -327,7 +325,7 @@ async def test_condition_barrier_pattern(c, s, a, b): condition = Condition("barrier") arrived = [] n_workers = 3 - + async def worker(i): async with condition: arrived.append(i) @@ -336,11 +334,9 @@ async def worker(i): else: condition.notify_all() return f"worker-{i}-done" - - results = await asyncio.gather( - worker(0), worker(1), worker(2) - ) - + + results = await asyncio.gather(worker(0), worker(1), worker(2)) + assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"] assert len(arrived) == 3 @@ -349,12 +345,12 @@ def test_condition_sync_interface(client): """Test synchronous interface via SyncMethodMixin""" condition = Condition("sync-test") results = [] - + def worker(): with condition: results.append("locked") results.append("released") - + worker() assert results == ["locked", "released"] @@ -364,12 +360,12 @@ async def test_condition_multiple_notify_calls(c, s, a, b): """Test multiple notify calls in sequence""" condition = Condition("multi-notify") results = [] - + async def waiter(i): async with condition: await condition.wait() results.append(i) - + async def notifier(): await asyncio.sleep(0.2) async with condition: @@ -380,10 +376,8 @@ async def notifier(): await asyncio.sleep(0.1) async with condition: condition.notify() - - await asyncio.gather( - waiter(1), waiter(2), waiter(3), notifier() - ) + + await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) assert sorted(results) == [1, 2, 3] @@ -392,20 +386,20 @@ async def test_condition_predicate_loop(c, s, a, b): """Test typical predicate-based wait loop pattern""" condition = Condition("predicate") state = {"value": 0, "target": 5} - + async def waiter(): async with condition: while state["value"] < state["target"]: await condition.wait() return state["value"] - + async def updater(): for i in range(1, 6): await asyncio.sleep(0.1) async with condition: state["value"] = i condition.notify_all() - + result, _ = await asyncio.gather(waiter(), updater()) assert result == 5 From fda7dce00582296f9d94fcdbe3b0042fb78a19d6 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 5 Dec 2025 19:09:19 -0800 Subject: [PATCH 06/23] Update scheduler.py --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f2ad5080a8..07951263f4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -97,6 +97,7 @@ ) from distributed.comm.addressing import addresses_from_user_args from distributed.compatibility import PeriodicCallback +from distributed.condition import ConditionExtension from distributed.core import ( ErrorMessage, OKMessage, @@ -144,7 +145,6 @@ scatter_to_workers, ) from distributed.variable import VariableExtension -from distributed.condition import ConditionExtension if TYPE_CHECKING: from typing import TypeAlias, TypeVar From e04d5da7de2a18061fc62d752db8a25d3776d7ba Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 5 Dec 2025 19:23:33 -0800 Subject: [PATCH 07/23] Update condition.py --- distributed/condition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/condition.py b/distributed/condition.py index 4d250b7f2a..24d9a8a022 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -158,7 +158,9 @@ async def wait(self, timeout=None): raise RuntimeError("Cannot wait on un-acquired condition") scheduler = self._get_scheduler_rpc() - result = await scheduler.condition_wait(name=self.name, 
id=self.id, timeout=timeout) + result = await scheduler.condition_wait( + name=self.name, id=self.id, timeout=timeout + ) return result async def notify(self, n=1): From f777c64313d750e047d60f679627d202ee841c2b Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 14 Dec 2025 13:28:30 -0800 Subject: [PATCH 08/23] Update condition.py --- distributed/condition.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 24d9a8a022..530cc169d4 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -66,14 +66,12 @@ async def wait(self, name=None, id=None, timeout=None): return True except asyncio.TimeoutError: return False + except asyncio.CancelledError: + raise finally: self._waiters[name].discard(id) - # Cleanup if no waiters if not self._waiters[name]: - with suppress(KeyError): - del self._waiters[name] - with suppress(KeyError): - del self._conditions[name] + del self._waiters[name] @log_errors def notify(self, name=None, n=1): @@ -158,9 +156,7 @@ async def wait(self, timeout=None): raise RuntimeError("Cannot wait on un-acquired condition") scheduler = self._get_scheduler_rpc() - result = await scheduler.condition_wait( - name=self.name, id=self.id, timeout=timeout - ) + result = await scheduler.condition_wait(name=self.name, id=self.id, timeout=timeout) return result async def notify(self, n=1): From 585b6770886474de0835a2da53f82ce8b16c670f Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 14 Dec 2025 13:33:00 -0800 Subject: [PATCH 09/23] Update condition.py --- distributed/condition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/condition.py b/distributed/condition.py index 530cc169d4..2ccf9a7699 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -156,7 +156,9 @@ async def wait(self, timeout=None): raise RuntimeError("Cannot wait on un-acquired condition") scheduler = self._get_scheduler_rpc() - result = await scheduler.condition_wait(name=self.name, id=self.id, timeout=timeout) + result = await scheduler.condition_wait( + name=self.name, id=self.id, timeout=timeout + ) return result async def notify(self, n=1): From 332c4d9fc945c5f49413c9be5becfd938dbc32ab Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 14 Dec 2025 13:35:20 -0800 Subject: [PATCH 10/23] Update condition.py --- distributed/condition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/condition.py b/distributed/condition.py index 2ccf9a7699..798056d653 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -4,7 +4,6 @@ import logging import uuid from collections import defaultdict -from contextlib import suppress from distributed.utils import SyncMethodMixin, log_errors from distributed.worker import get_client From 59badc62ad34d5656005ebba545b45bbbc3b5fb6 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 14 Dec 2025 14:31:46 -0800 Subject: [PATCH 11/23] Update condition.py --- distributed/condition.py | 198 ++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 88 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 798056d653..6aed6ab4ee 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -3,9 +3,10 @@ import asyncio import logging import uuid -from collections import defaultdict -from distributed.utils import SyncMethodMixin, log_errors +from dask.utils import parse_timedelta + +from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for from 
distributed.worker import get_client logger = logging.getLogger(__name__) @@ -16,10 +17,9 @@ class ConditionExtension: def __init__(self, scheduler): self.scheduler = scheduler - # {condition_name: asyncio.Condition} - self._conditions = {} - # {condition_name: set of waiter_ids} - self._waiters = defaultdict(set) + self._locks = {} # {name: asyncio.Lock} + self._lock_holders = {} # {name: client_id} + self._waiters = {} # {name: {waiter_id: asyncio.Event}} self.scheduler.handlers.update( { @@ -30,152 +30,171 @@ def __init__(self, scheduler): } ) - def _get_condition(self, name): - if name not in self._conditions: - self._conditions[name] = asyncio.Condition() - return self._conditions[name] + def _get_lock(self, name): + if name not in self._locks: + self._locks[name] = asyncio.Lock() + return self._locks[name] @log_errors async def acquire(self, name=None, id=None): - """Acquire the underlying lock""" - condition = self._get_condition(name) - await condition.acquire() + lock = self._get_lock(name) + await lock.acquire() + self._lock_holders[name] = id return True @log_errors async def release(self, name=None, id=None): - """Release the underlying lock""" - if name not in self._conditions: + if self._lock_holders.get(name) != id: return False - condition = self._conditions[name] - condition.release() + + lock = self._locks[name] + lock.release() + del self._lock_holders[name] + + # Cleanup if no waiters + if name not in self._waiters or not self._waiters[name]: + del self._locks[name] + return True @log_errors async def wait(self, name=None, id=None, timeout=None): - """Wait on condition""" - condition = self._get_condition(name) - self._waiters[name].add(id) + # Verify lock is held by this client + if self._lock_holders.get(name) != id: + raise RuntimeError("wait() called without holding the lock") + + lock = self._locks[name] + + # Create event for this waiter + if name not in self._waiters: + self._waiters[name] = {} + event = asyncio.Event() + self._waiters[name][id] = event + + # Release lock + lock.release() + del self._lock_holders[name] + + # Wait on event + future = event.wait() + if timeout is not None: + future = wait_for(future, timeout) try: - if timeout: - await asyncio.wait_for(condition.wait(), timeout=timeout) - else: - await condition.wait() - return True - except asyncio.TimeoutError: - return False - except asyncio.CancelledError: - raise + await future + result = True + except TimeoutError: + result = False finally: - self._waiters[name].discard(id) + # Cleanup waiter + self._waiters[name].pop(id, None) if not self._waiters[name]: del self._waiters[name] + # Reacquire lock + await lock.acquire() + self._lock_holders[name] = id + + return result + @log_errors def notify(self, name=None, n=1): - """Notify n waiters""" - if name not in self._conditions: - return 0 - condition = self._conditions[name] - condition.notify(n=n) - return min(n, len(self._waiters.get(name, []))) + if self._lock_holders.get(name) is None: + raise RuntimeError("notify() called without holding the lock") + + waiters = self._waiters.get(name, {}) + count = 0 + for event in list(waiters.values())[:n]: + event.set() + count += 1 + return count @log_errors def notify_all(self, name=None): - """Notify all waiters""" - if name not in self._conditions: - return 0 - condition = self._conditions[name] - count = len(self._waiters.get(name, [])) - condition.notify_all() - return count + if self._lock_holders.get(name) is None: + raise RuntimeError("notify_all() called without holding the lock") + + waiters 
= self._waiters.get(name, {}) + for event in waiters.values(): + event.set() + return len(waiters) class Condition(SyncMethodMixin): """Distributed Condition Variable - Mimics asyncio.Condition API. Allows coordination between - distributed workers using wait/notify pattern. + Parameters + ---------- + name: str, optional + Name of the condition. Same name = shared state. + client: Client, optional + Client for scheduler communication. Examples -------- - >>> from distributed import Condition >>> condition = Condition('my-condition') >>> async with condition: - ... await condition.wait() # Wait for notification - - >>> # In another worker/client - >>> condition = Condition('my-condition') - >>> async with condition: - ... condition.notify() # Wake one waiter + ... await condition.wait() """ - def __init__(self, name=None, scheduler_rpc=None, loop=None): - self._scheduler = scheduler_rpc - self._loop = loop + def __init__(self, name=None, client=None): + self._client = client self.name = name or f"condition-{uuid.uuid4().hex}" self.id = uuid.uuid4().hex self._locked = False - def _get_scheduler_rpc(self): - if self._scheduler: - return self._scheduler - try: - client = get_client() - return client.scheduler - except ValueError: - from distributed.worker import get_worker + @property + def client(self): + if not self._client: + try: + self._client = get_client() + except ValueError: + pass + return self._client - worker = get_worker() - return worker.scheduler + def _verify_running(self): + if not self.client: + raise RuntimeError(f"{type(self)} object not properly initialized.") async def acquire(self): - """Acquire underlying lock""" - scheduler = self._get_scheduler_rpc() - result = await scheduler.condition_acquire(name=self.name, id=self.id) + self._verify_running() + result = await self.client.scheduler.condition_acquire( + name=self.name, id=self.id + ) self._locked = result return result async def release(self): - """Release underlying lock""" if not self._locked: raise RuntimeError("Cannot release un-acquired lock") - scheduler = self._get_scheduler_rpc() - await scheduler.condition_release(name=self.name, id=self.id) + self._verify_running() + await self.client.scheduler.condition_release(name=self.name, id=self.id) self._locked = False async def wait(self, timeout=None): - """Wait until notified - - Must be called while lock is held. Releases lock and waits - for notify(), then reacquires lock before returning. 
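
The wait() contract described here is the same one asyncio.Condition documents, which is also why callers are expected to re-check their predicate in a loop after waking. A minimal, stdlib-only sketch of that pattern (the distributed class in this patch follows the same shape):

    import asyncio

    async def main():
        cond = asyncio.Condition()
        state = {"ready": False}

        async def waiter():
            async with cond:                   # the lock must be held before wait()
                while not state["ready"]:      # re-check the predicate after waking
                    await cond.wait()          # releases, sleeps, then reacquires
                return "saw-ready"

        async def setter():
            await asyncio.sleep(0.1)
            async with cond:
                state["ready"] = True
                cond.notify_all()

        result, _ = await asyncio.gather(waiter(), setter())
        assert result == "saw-ready"

    asyncio.run(main())
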
- """ if not self._locked: - raise RuntimeError("Cannot wait on un-acquired condition") + raise RuntimeError("wait() called without holding the lock") - scheduler = self._get_scheduler_rpc() - result = await scheduler.condition_wait( + self._verify_running() + timeout = parse_timedelta(timeout) + result = await self.client.scheduler.condition_wait( name=self.name, id=self.id, timeout=timeout ) return result async def notify(self, n=1): - """Wake up one or more waiters""" if not self._locked: - raise RuntimeError("Cannot notify on un-acquired condition") - scheduler = self._get_scheduler_rpc() - return await scheduler.condition_notify(name=self.name, n=n) + raise RuntimeError("Cannot notify without holding the lock") + self._verify_running() + return await self.client.scheduler.condition_notify(name=self.name, n=n) async def notify_all(self): - """Wake up all waiters""" if not self._locked: - raise RuntimeError("Cannot notify on un-acquired condition") - scheduler = self._get_scheduler_rpc() - return await scheduler.condition_notify_all(name=self.name) + raise RuntimeError("Cannot notify without holding the lock") + self._verify_running() + return await self.client.scheduler.condition_notify_all(name=self.name) def locked(self): - """Return True if lock is held""" return self._locked async def __aenter__(self): @@ -193,3 +212,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): def __repr__(self): return f"" + + def __reduce__(self): + return (Condition, (self.name,)) From a48d28f2b2e6796d3fccd9db7d690878889b26da Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 14 Dec 2025 15:02:07 -0800 Subject: [PATCH 12/23] Update condition.py,test_condition.py --- distributed/condition.py | 80 ++++++++++++++++++++++++++--- distributed/tests/test_condition.py | 45 +++++----------- 2 files changed, 84 insertions(+), 41 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 6aed6ab4ee..1a46a52341 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -27,6 +27,7 @@ def __init__(self, scheduler): "condition_notify": self.notify, "condition_acquire": self.acquire, "condition_release": self.release, + "condition_notify_all": self.notify_all, } ) @@ -37,6 +38,7 @@ def _get_lock(self, name): @log_errors async def acquire(self, name=None, id=None): + """Acquire the underlying lock""" lock = self._get_lock(name) await lock.acquire() self._lock_holders[name] = id @@ -44,6 +46,7 @@ async def acquire(self, name=None, id=None): @log_errors async def release(self, name=None, id=None): + """Release the underlying lock""" if self._lock_holders.get(name) != id: return False @@ -59,6 +62,7 @@ async def release(self, name=None, id=None): @log_errors async def wait(self, name=None, id=None, timeout=None): + """Wait on condition""" # Verify lock is held by this client if self._lock_holders.get(name) != id: raise RuntimeError("wait() called without holding the lock") @@ -99,6 +103,7 @@ async def wait(self, name=None, id=None, timeout=None): @log_errors def notify(self, name=None, n=1): + """Notify n waiters""" if self._lock_holders.get(name) is None: raise RuntimeError("notify() called without holding the lock") @@ -111,6 +116,7 @@ def notify(self, name=None, n=1): @log_errors def notify_all(self, name=None): + """Notify all waiters""" if self._lock_holders.get(name) is None: raise RuntimeError("notify_all() called without holding the lock") @@ -123,18 +129,27 @@ def notify_all(self, name=None): class Condition(SyncMethodMixin): """Distributed Condition Variable + Mimics 
asyncio.Condition API. Allows coordination between + distributed workers using wait/notify pattern. + Parameters ---------- - name: str, optional + name : str, optional Name of the condition. Same name = shared state. - client: Client, optional + client : Client, optional Client for scheduler communication. Examples -------- + >>> from distributed import Condition + >>> condition = Condition('my-condition') + >>> async with condition: + ... await condition.wait() # Wait for notification + + >>> # In another worker/client >>> condition = Condition('my-condition') >>> async with condition: - ... await condition.wait() + ... condition.notify() # Wake one waiter """ def __init__(self, name=None, client=None): @@ -152,11 +167,20 @@ def client(self): pass return self._client + @property + def loop(self): + return self.client.loop if self.client else None + def _verify_running(self): if not self.client: - raise RuntimeError(f"{type(self)} object not properly initialized.") + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen" + " if the object is being deserialized outside of the context of" + " a Client or Worker." + ) async def acquire(self): + """Acquire underlying lock""" self._verify_running() result = await self.client.scheduler.condition_acquire( name=self.name, id=self.id @@ -165,6 +189,7 @@ async def acquire(self): return result async def release(self): + """Release underlying lock""" if not self._locked: raise RuntimeError("Cannot release un-acquired lock") self._verify_running() @@ -172,6 +197,21 @@ async def release(self): self._locked = False async def wait(self, timeout=None): + """Wait until notified + + Must be called while lock is held. Releases lock and waits + for notify(), then reacquires lock before returning. + + Parameters + ---------- + timeout : number or string or timedelta, optional + Seconds to wait on the condition in the scheduler. + + Returns + ------- + bool + True if notified, False if timeout occurred + """ if not self._locked: raise RuntimeError("wait() called without holding the lock") @@ -182,19 +222,43 @@ async def wait(self, timeout=None): ) return result - async def notify(self, n=1): + def notify(self, n=1): + """Wake up one or more waiters + + Parameters + ---------- + n : int, optional + Number of waiters to wake. Default is 1. 
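
The n-waiters semantics mirrors asyncio.Condition.notify(); a small stdlib-only sketch of counted wakeups, analogous to what test_condition_notify_n exercises against the distributed class:

    import asyncio

    async def main():
        cond = asyncio.Condition()
        woken = []

        async def waiter(i):
            async with cond:
                await cond.wait()
                woken.append(i)

        waiters = [asyncio.create_task(waiter(i)) for i in range(3)]
        await asyncio.sleep(0.1)          # let all three block in wait()

        async with cond:
            cond.notify(2)                # wakes exactly two of the three waiters
        await asyncio.sleep(0.1)
        assert len(woken) == 2

        async with cond:
            cond.notify()                 # wakes the remaining waiter
        await asyncio.gather(*waiters)
        assert sorted(woken) == [0, 1, 2]

    asyncio.run(main())
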
+ + Returns + ------- + int + Number of waiters notified + """ if not self._locked: raise RuntimeError("Cannot notify without holding the lock") self._verify_running() - return await self.client.scheduler.condition_notify(name=self.name, n=n) + return self.client.sync( + self.client.scheduler.condition_notify, name=self.name, n=n + ) + + def notify_all(self): + """Wake up all waiters - async def notify_all(self): + Returns + ------- + int + Number of waiters notified + """ if not self._locked: raise RuntimeError("Cannot notify without holding the lock") self._verify_running() - return await self.client.scheduler.condition_notify_all(name=self.name) + return self.client.sync( + self.client.scheduler.condition_notify_all, name=self.name + ) def locked(self): + """Return True if lock is held""" return self._locked async def __aenter__(self): diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 102e629fbf..d1ebde26bd 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -11,7 +11,6 @@ async def test_condition_acquire_release(c, s, a, b): """Test basic lock acquire/release""" condition = Condition("test-lock") - assert not condition.locked() await condition.acquire() assert condition.locked() @@ -23,7 +22,6 @@ async def test_condition_acquire_release(c, s, a, b): async def test_condition_context_manager(c, s, a, b): """Test context manager interface""" condition = Condition("test-context") - assert not condition.locked() async with condition: assert condition.locked() @@ -119,7 +117,6 @@ async def waiter(): async with condition: result = await condition.wait(timeout=0.2) results.append(f"timeout: {result}") - async with condition: result = await condition.wait() results.append(f"notified: {result}") @@ -142,10 +139,10 @@ async def test_condition_error_without_lock(c, s, a, b): await condition.wait() with pytest.raises(RuntimeError, match="Cannot notify"): - await condition.notify() + condition.notify() with pytest.raises(RuntimeError, match="Cannot notify"): - await condition.notify_all() + condition.notify_all() @gen_cluster(client=True) @@ -184,7 +181,6 @@ async def consumer(): await prod_task results = await cons_task - assert results == [0, 1, 2, 3, 4] @@ -211,7 +207,6 @@ async def consumer(): return results results = await asyncio.gather(producer(0), producer(10), consumer(), consumer()) - # Last two results are from consumers consumed = results[2] + results[3] assert sorted(consumed) == [0, 1, 2, 10, 11, 12] @@ -222,41 +217,27 @@ async def test_condition_from_worker(c, s, a, b): """Test condition accessed from worker tasks""" def wait_on_condition(name): - from distributed import Condition - async def _wait(): - condition = Condition(name) - async with condition: - await condition.wait() - return "worker_notified" - - from distributed.worker import get_worker - - worker = get_worker() - return worker.loop.run_until_complete(_wait()) + condition = Condition(name) + with condition: + condition.wait() + return "worker_notified" def notify_condition(name): - import asyncio + import time from distributed import Condition - async def _notify(): - await asyncio.sleep(0.2) - condition = Condition(name) - async with condition: - condition.notify() - return "notified" - - from distributed.worker import get_worker - - worker = get_worker() - return worker.loop.run_until_complete(_notify()) + time.sleep(0.2) + condition = Condition(name) + with condition: + condition.notify() + return "notified" name = "worker-condition" 
f1 = c.submit(wait_on_condition, name, workers=[a.address]) f2 = c.submit(notify_condition, name, workers=[b.address]) - results = await c.gather([f1, f2]) assert results == ["worker_notified", "notified"] @@ -267,7 +248,6 @@ async def test_condition_same_name_different_instances(c, s, a, b): name = "shared-condition" cond1 = Condition(name) cond2 = Condition(name) - results = [] async def waiter(): @@ -336,7 +316,6 @@ async def worker(i): return f"worker-{i}-done" results = await asyncio.gather(worker(0), worker(1), worker(2)) - assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"] assert len(arrived) == 3 From d8f98d56382ca4272166580e29379bb1edb60ecf Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 14 Dec 2025 16:25:08 -0800 Subject: [PATCH 13/23] Update condition.py --- distributed/condition.py | 141 ++++++++++++--------------------------- 1 file changed, 41 insertions(+), 100 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 1a46a52341..4e730c3c4b 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -3,9 +3,11 @@ import asyncio import logging import uuid +from collections import defaultdict from dask.utils import parse_timedelta +from distributed.semaphore import Semaphore from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for from distributed.worker import get_client @@ -13,72 +15,36 @@ class ConditionExtension: - """Scheduler extension for managing distributed Conditions""" + """Scheduler extension for managing Condition variable notifications + + This extension only handles wait/notify coordination. + The underlying lock is a Semaphore managed by SemaphoreExtension. + """ def __init__(self, scheduler): self.scheduler = scheduler - self._locks = {} # {name: asyncio.Lock} - self._lock_holders = {} # {name: client_id} - self._waiters = {} # {name: {waiter_id: asyncio.Event}} + # {condition_name: {waiter_id: asyncio.Event}} + self._waiters = defaultdict(dict) self.scheduler.handlers.update( { "condition_wait": self.wait, "condition_notify": self.notify, - "condition_acquire": self.acquire, - "condition_release": self.release, "condition_notify_all": self.notify_all, } ) - def _get_lock(self, name): - if name not in self._locks: - self._locks[name] = asyncio.Lock() - return self._locks[name] - - @log_errors - async def acquire(self, name=None, id=None): - """Acquire the underlying lock""" - lock = self._get_lock(name) - await lock.acquire() - self._lock_holders[name] = id - return True - - @log_errors - async def release(self, name=None, id=None): - """Release the underlying lock""" - if self._lock_holders.get(name) != id: - return False - - lock = self._locks[name] - lock.release() - del self._lock_holders[name] - - # Cleanup if no waiters - if name not in self._waiters or not self._waiters[name]: - del self._locks[name] - - return True - @log_errors async def wait(self, name=None, id=None, timeout=None): - """Wait on condition""" - # Verify lock is held by this client - if self._lock_holders.get(name) != id: - raise RuntimeError("wait() called without holding the lock") - - lock = self._locks[name] + """Wait to be notified + Caller must already hold the lock (Semaphore lease). + This only manages the wait/notify Events. 
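
The bookkeeping this revision describes is essentially a per-condition map of waiter ids to Events. A stripped-down, in-process sketch of that registry (names and signatures here are illustrative, not the extension's actual handlers):

    import asyncio
    from collections import defaultdict

    # In-process sketch of the registry: {condition name: {waiter id: Event}}
    waiters = defaultdict(dict)

    async def wait(name, waiter_id, timeout=None):
        event = asyncio.Event()
        waiters[name][waiter_id] = event           # register before blocking
        try:
            await asyncio.wait_for(event.wait(), timeout)
            return True
        except asyncio.TimeoutError:
            return False
        finally:
            waiters[name].pop(waiter_id, None)     # always clean up the entry
            if not waiters[name]:
                del waiters[name]

    def notify(name, n=1):
        count = 0
        for event in list(waiters.get(name, {}).values())[:n]:
            event.set()
            count += 1
        return count

    async def demo():
        task = asyncio.create_task(wait("data-ready", "w1"))
        await asyncio.sleep(0)                     # let the waiter register itself
        assert notify("data-ready") == 1
        assert await task is True

    asyncio.run(demo())
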
+ """ # Create event for this waiter - if name not in self._waiters: - self._waiters[name] = {} event = asyncio.Event() self._waiters[name][id] = event - # Release lock - lock.release() - del self._lock_holders[name] - # Wait on event future = event.wait() if timeout is not None: @@ -95,18 +61,11 @@ async def wait(self, name=None, id=None, timeout=None): if not self._waiters[name]: del self._waiters[name] - # Reacquire lock - await lock.acquire() - self._lock_holders[name] = id - return result @log_errors def notify(self, name=None, n=1): """Notify n waiters""" - if self._lock_holders.get(name) is None: - raise RuntimeError("notify() called without holding the lock") - waiters = self._waiters.get(name, {}) count = 0 for event in list(waiters.values())[:n]: @@ -117,9 +76,6 @@ def notify(self, name=None, n=1): @log_errors def notify_all(self, name=None): """Notify all waiters""" - if self._lock_holders.get(name) is None: - raise RuntimeError("notify_all() called without holding the lock") - waiters = self._waiters.get(name, {}) for event in waiters.values(): event.set() @@ -129,8 +85,7 @@ def notify_all(self, name=None): class Condition(SyncMethodMixin): """Distributed Condition Variable - Mimics asyncio.Condition API. Allows coordination between - distributed workers using wait/notify pattern. + Combines a Semaphore (lock) with wait/notify coordination. Parameters ---------- @@ -144,19 +99,20 @@ class Condition(SyncMethodMixin): >>> from distributed import Condition >>> condition = Condition('my-condition') >>> async with condition: - ... await condition.wait() # Wait for notification + ... await condition.wait() >>> # In another worker/client >>> condition = Condition('my-condition') >>> async with condition: - ... condition.notify() # Wake one waiter + ... 
condition.notify() """ def __init__(self, name=None, client=None): - self._client = client self.name = name or f"condition-{uuid.uuid4().hex}" self.id = uuid.uuid4().hex - self._locked = False + # Use Semaphore(max_leases=1) as the underlying lock + self._lock = Semaphore(max_leases=1, name=f"{self.name}-lock") + self._client = client @property def client(self): @@ -169,7 +125,7 @@ def client(self): @property def loop(self): - return self.client.loop if self.client else None + return self._lock.loop def _verify_running(self): if not self.client: @@ -181,20 +137,12 @@ def _verify_running(self): async def acquire(self): """Acquire underlying lock""" - self._verify_running() - result = await self.client.scheduler.condition_acquire( - name=self.name, id=self.id - ) - self._locked = result + result = await self._lock.acquire() return result async def release(self): """Release underlying lock""" - if not self._locked: - raise RuntimeError("Cannot release un-acquired lock") - self._verify_running() - await self.client.scheduler.condition_release(name=self.name, id=self.id) - self._locked = False + await self._lock.release() async def wait(self, timeout=None): """Wait until notified @@ -212,30 +160,29 @@ async def wait(self, timeout=None): bool True if notified, False if timeout occurred """ - if not self._locked: + if not self._lock.locked(): raise RuntimeError("wait() called without holding the lock") self._verify_running() timeout = parse_timedelta(timeout) - result = await self.client.scheduler.condition_wait( - name=self.name, id=self.id, timeout=timeout - ) - return result - def notify(self, n=1): - """Wake up one or more waiters + # Release lock + await self._lock.release() - Parameters - ---------- - n : int, optional - Number of waiters to wake. Default is 1. 
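
Client-side, wait() in this revision is a release / remote-wait / reacquire sandwich around the scheduler call, with the reacquire in a finally block. A local stand-in for that flow, assuming a plain asyncio.Lock and Event in place of the distributed pieces:

    import asyncio

    async def wait_on(lock, notified, timeout=None):
        # Caller holds `lock`; drop it so notifiers can make progress.
        lock.release()
        try:
            await asyncio.wait_for(notified.wait(), timeout)
            return True
        except asyncio.TimeoutError:
            return False
        finally:
            await lock.acquire()        # always reacquire before returning

    async def demo():
        lock, notified = asyncio.Lock(), asyncio.Event()
        await lock.acquire()
        waiter = asyncio.create_task(wait_on(lock, notified))
        await asyncio.sleep(0)          # waiter has released the lock by now
        notified.set()
        assert await waiter is True
        assert lock.locked()            # reacquired on the way out

    asyncio.run(demo())
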
+ # Wait for notification + try: + result = await self.client.scheduler.condition_wait( + name=self.name, id=self.id, timeout=timeout + ) + finally: + # Reacquire lock + await self._lock.acquire() - Returns - ------- - int - Number of waiters notified - """ - if not self._locked: + return result + + def notify(self, n=1): + """Wake up one or more waiters""" + if not self._lock.locked(): raise RuntimeError("Cannot notify without holding the lock") self._verify_running() return self.client.sync( @@ -243,14 +190,8 @@ def notify(self, n=1): ) def notify_all(self): - """Wake up all waiters - - Returns - ------- - int - Number of waiters notified - """ - if not self._locked: + """Wake up all waiters""" + if not self._lock.locked(): raise RuntimeError("Cannot notify without holding the lock") self._verify_running() return self.client.sync( @@ -259,7 +200,7 @@ def notify_all(self): def locked(self): """Return True if lock is held""" - return self._locked + return self._lock.locked() async def __aenter__(self): await self.acquire() From d2e2af962d929329078c5cff972205b5778957a0 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sat, 20 Dec 2025 23:32:43 -0800 Subject: [PATCH 14/23] Update condition.py,test_condition.py --- distributed/condition.py | 95 ++++++++++++++++++----------- distributed/tests/test_condition.py | 9 ++- 2 files changed, 62 insertions(+), 42 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 4e730c3c4b..213c4837a2 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -7,7 +7,7 @@ from dask.utils import parse_timedelta -from distributed.semaphore import Semaphore +from distributed.lock import Lock from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for from distributed.worker import get_client @@ -17,8 +17,8 @@ class ConditionExtension: """Scheduler extension for managing Condition variable notifications - This extension only handles wait/notify coordination. - The underlying lock is a Semaphore managed by SemaphoreExtension. + Coordinates wait/notify between distributed clients. + The lock itself is managed by LockExtension. """ def __init__(self, scheduler): @@ -36,16 +36,13 @@ def __init__(self, scheduler): @log_errors async def wait(self, name=None, id=None, timeout=None): - """Wait to be notified + """Register waiter and block until notified - Caller must already hold the lock (Semaphore lease). - This only manages the wait/notify Events. + Caller must have released the lock before calling this. 
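
Ordering matters here: with release-then-register, a notification that lands between the two steps is lost, which is what a later patch in this series addresses by registering the waiter before the lock is dropped. A schematic comparison (hypothetical helper names, not the extension's API; the sleep stands in for the gap between two separate round-trips):

    import asyncio

    async def wait_racy(lock, waiters):
        lock.release()
        await asyncio.sleep(0)            # a notify() arriving here finds no Event
        event = asyncio.Event()
        waiters.append(event)
        await event.wait()

    async def wait_safe(lock, waiters):
        event = asyncio.Event()
        waiters.append(event)             # visible to notifiers before the lock drops
        lock.release()
        await event.wait()
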
""" - # Create event for this waiter event = asyncio.Event() self._waiters[name][id] = event - # Wait on event future = event.wait() if timeout is not None: future = wait_for(future, timeout) @@ -56,7 +53,6 @@ async def wait(self, name=None, id=None, timeout=None): except TimeoutError: result = False finally: - # Cleanup waiter self._waiters[name].pop(id, None) if not self._waiters[name]: del self._waiters[name] @@ -65,7 +61,7 @@ async def wait(self, name=None, id=None, timeout=None): @log_errors def notify(self, name=None, n=1): - """Notify n waiters""" + """Wake up n waiters""" waiters = self._waiters.get(name, {}) count = 0 for event in list(waiters.values())[:n]: @@ -75,7 +71,7 @@ def notify(self, name=None, n=1): @log_errors def notify_all(self, name=None): - """Notify all waiters""" + """Wake up all waiters""" waiters = self._waiters.get(name, {}) for event in waiters.values(): event.set() @@ -85,33 +81,36 @@ def notify_all(self, name=None): class Condition(SyncMethodMixin): """Distributed Condition Variable - Combines a Semaphore (lock) with wait/notify coordination. + Combines a Lock with wait/notify coordination across the cluster. Parameters ---------- name : str, optional - Name of the condition. Same name = shared state. + Name of the condition. Conditions with the same name share state. client : Client, optional Client for scheduler communication. Examples -------- - >>> from distributed import Condition - >>> condition = Condition('my-condition') + Producer-consumer pattern: + + >>> condition = Condition('data-ready') + >>> # Consumer >>> async with condition: - ... await condition.wait() + ... while not data_available(): + ... await condition.wait() + ... process_data() - >>> # In another worker/client - >>> condition = Condition('my-condition') + >>> # Producer >>> async with condition: - ... condition.notify() + ... produce_data() + ... condition.notify_all() """ def __init__(self, name=None, client=None): self.name = name or f"condition-{uuid.uuid4().hex}" self.id = uuid.uuid4().hex - # Use Semaphore(max_leases=1) as the underlying lock - self._lock = Semaphore(max_leases=1, name=f"{self.name}-lock") + self._lock = Lock(name=f"{self.name}-lock") self._client = client @property @@ -136,29 +135,33 @@ def _verify_running(self): ) async def acquire(self): - """Acquire underlying lock""" - result = await self._lock.acquire() - return result + """Acquire the underlying lock""" + return await self._lock.acquire() async def release(self): - """Release underlying lock""" + """Release the underlying lock""" await self._lock.release() async def wait(self, timeout=None): """Wait until notified - Must be called while lock is held. Releases lock and waits - for notify(), then reacquires lock before returning. + Must be called while lock is held. Atomically releases lock, + waits for notify(), then reacquires lock before returning. Parameters ---------- timeout : number or string or timedelta, optional - Seconds to wait on the condition in the scheduler. + Maximum time to wait for notification. 
Returns ------- bool True if notified, False if timeout occurred + + Raises + ------ + RuntimeError + If called without holding the lock """ if not self._lock.locked(): raise RuntimeError("wait() called without holding the lock") @@ -166,40 +169,58 @@ async def wait(self, timeout=None): self._verify_running() timeout = parse_timedelta(timeout) - # Release lock + # Atomically: release lock, wait for notify, reacquire lock await self._lock.release() - - # Wait for notification try: result = await self.client.scheduler.condition_wait( name=self.name, id=self.id, timeout=timeout ) finally: - # Reacquire lock await self._lock.acquire() return result def notify(self, n=1): - """Wake up one or more waiters""" + """Wake up one or more waiters + + Must be called while holding the lock. + + Parameters + ---------- + n : int, optional + Number of waiters to wake. Default is 1. + + Returns + ------- + int + Number of waiters actually notified + """ if not self._lock.locked(): - raise RuntimeError("Cannot notify without holding the lock") + raise RuntimeError("notify() called without holding the lock") self._verify_running() return self.client.sync( self.client.scheduler.condition_notify, name=self.name, n=n ) def notify_all(self): - """Wake up all waiters""" + """Wake up all waiters + + Must be called while holding the lock. + + Returns + ------- + int + Number of waiters notified + """ if not self._lock.locked(): - raise RuntimeError("Cannot notify without holding the lock") + raise RuntimeError("notify_all() called without holding the lock") self._verify_running() return self.client.sync( self.client.scheduler.condition_notify_all, name=self.name ) def locked(self): - """Return True if lock is held""" + """Return True if the lock is currently held""" return self._lock.locked() async def __aenter__(self): diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index d1ebde26bd..67b03cbca8 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -150,7 +150,7 @@ async def test_condition_error_release_without_acquire(c, s, a, b): """Test error when releasing without acquiring""" condition = Condition("test-release-error") - with pytest.raises(RuntimeError, match="Cannot release"): + with pytest.raises(RuntimeError, match="Released too often"): await condition.release() @@ -286,17 +286,16 @@ async def test_condition_cleanup(c, s, a, b): """Test that condition state is cleaned up after use""" condition = Condition("cleanup-test") - # Check initial state - assert "cleanup-test" not in s.extensions["conditions"]._lock_holders + # Check initial state - only check waiters since locks are managed by LockExtension assert "cleanup-test" not in s.extensions["conditions"]._waiters # Use condition async with condition: condition.notify() - # State should be cleaned up + # Waiter state should be cleaned up await asyncio.sleep(0.1) - assert "cleanup-test" not in s.extensions["conditions"]._lock_holders + assert "cleanup-test" not in s.extensions["conditions"]._waiters @gen_cluster(client=True) From ea75275b6e7462b72a2eb76157bada7736095961 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 21 Dec 2025 00:05:56 -0800 Subject: [PATCH 15/23] Update --- continuous_integration/scripts/test_report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/scripts/test_report.py b/continuous_integration/scripts/test_report.py index 1f749a78fd..1f339ed7ef 100644 --- a/continuous_integration/scripts/test_report.py +++ 
b/continuous_integration/scripts/test_report.py @@ -489,7 +489,7 @@ def main(argv: list[str] | None = None) -> None: total.groupby([total.file, total.test]) .filter(lambda g: (g.status == "x").sum() >= args.nfails) .reset_index() - .assign(test=lambda df: df.file + "." + df.test) # type: ignore + .assign(test=lambda df: df.file + "." + df.test) .groupby("test") ) overall = {name: grouped.get_group(name) for name in grouped.groups} From f2773f20e1c1fce58766ecb4eafd2c244de6fe6b Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 21 Dec 2025 11:48:11 -0800 Subject: [PATCH 16/23] Update condition.py,test_condition.py --- distributed/condition.py | 242 ++++++++++++++-------------- distributed/tests/test_condition.py | 10 +- 2 files changed, 127 insertions(+), 125 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 213c4837a2..1731d0b668 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -3,11 +3,10 @@ import asyncio import logging import uuid -from collections import defaultdict +from collections import defaultdict, deque from dask.utils import parse_timedelta -from distributed.lock import Lock from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for from distributed.worker import get_client @@ -15,19 +14,21 @@ class ConditionExtension: - """Scheduler extension for managing Condition variable notifications - - Coordinates wait/notify between distributed clients. - The lock itself is managed by LockExtension. - """ + """Scheduler extension managing Condition lock and notifications""" def __init__(self, scheduler): self.scheduler = scheduler - # {condition_name: {waiter_id: asyncio.Event}} - self._waiters = defaultdict(dict) + # {condition_name: client_id} - who holds each lock + self._lock_holders = {} + # {condition_name: deque of (client_id, future)} - waiting to acquire + self._acquire_waiters = defaultdict(deque) + # {condition_name: {waiter_id: (client_id, Event)}} - waiting for notify + self._notify_waiters = defaultdict(dict) self.scheduler.handlers.update( { + "condition_acquire": self.acquire, + "condition_release": self.release, "condition_wait": self.wait, "condition_notify": self.notify, "condition_notify_all": self.notify_all, @@ -35,14 +36,58 @@ def __init__(self, scheduler): ) @log_errors - async def wait(self, name=None, id=None, timeout=None): - """Register waiter and block until notified + async def acquire(self, name=None, client_id=None): + """Acquire lock - blocks until available""" + if name not in self._lock_holders: + # Lock is free + self._lock_holders[name] = client_id + return True + + if self._lock_holders[name] == client_id: + # Already hold it (shouldn't happen in normal use) + return True + + # Lock is held by someone else - wait our turn + future = asyncio.Future() + self._acquire_waiters[name].append((client_id, future)) + await future + return True + + @log_errors + async def release(self, name=None, client_id=None): + """Release lock""" + if name not in self._lock_holders: + raise RuntimeError("Released too often") + + if self._lock_holders[name] != client_id: + raise RuntimeError("Cannot release lock held by another client") - Caller must have released the lock before calling this. 
- """ + del self._lock_holders[name] + + # Wake next waiter if any + waiters = self._acquire_waiters.get(name, deque()) + while waiters: + next_client_id, future = waiters.popleft() + if not future.done(): + self._lock_holders[name] = next_client_id + future.set_result(True) + break + + @log_errors + async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None): + """Release lock, wait for notify, reacquire lock""" + # Verify caller holds lock + if self._lock_holders.get(name) != client_id: + raise RuntimeError("wait() called without holding the lock") + + # Release lock (waking next acquire waiter if any) + await self.release(name=name, client_id=client_id) + + # Register as notify waiter event = asyncio.Event() - self._waiters[name][id] = event + self._notify_waiters[name][waiter_id] = (client_id, event) + # Wait for notification future = event.wait() if timeout is not None: future = wait_for(future, timeout) @@ -53,65 +98,52 @@ async def wait(self, name=None, id=None, timeout=None): except TimeoutError: result = False finally: - self._waiters[name].pop(id, None) - if not self._waiters[name]: - del self._waiters[name] + # Cleanup waiter + self._notify_waiters[name].pop(waiter_id, None) + if not self._notify_waiters[name]: + del self._notify_waiters[name] + + # Reacquire lock - blocks until available + await self.acquire(name=name, client_id=client_id) return result @log_errors - def notify(self, name=None, n=1): + def notify(self, name=None, client_id=None, n=1): """Wake up n waiters""" - waiters = self._waiters.get(name, {}) + # Verify caller holds lock + if self._lock_holders.get(name) != client_id: + raise RuntimeError("notify() called without holding the lock") + + waiters = self._notify_waiters.get(name, {}) count = 0 - for event in list(waiters.values())[:n]: + for _, (_, event) in list(waiters.items())[:n]: event.set() count += 1 return count @log_errors - def notify_all(self, name=None): + def notify_all(self, name=None, client_id=None): """Wake up all waiters""" - waiters = self._waiters.get(name, {}) - for event in waiters.values(): + # Verify caller holds lock + if self._lock_holders.get(name) != client_id: + raise RuntimeError("notify_all() called without holding the lock") + + waiters = self._notify_waiters.get(name, {}) + for _, event in waiters.values(): event.set() return len(waiters) class Condition(SyncMethodMixin): - """Distributed Condition Variable - - Combines a Lock with wait/notify coordination across the cluster. - - Parameters - ---------- - name : str, optional - Name of the condition. Conditions with the same name share state. - client : Client, optional - Client for scheduler communication. - - Examples - -------- - Producer-consumer pattern: - - >>> condition = Condition('data-ready') - >>> # Consumer - >>> async with condition: - ... while not data_available(): - ... await condition.wait() - ... process_data() - - >>> # Producer - >>> async with condition: - ... produce_data() - ... 
condition.notify_all() - """ + """Distributed Condition Variable""" def __init__(self, name=None, client=None): self.name = name or f"condition-{uuid.uuid4().hex}" - self.id = uuid.uuid4().hex - self._lock = Lock(name=f"{self.name}-lock") + self._waiter_id = uuid.uuid4().hex + self._client_id = uuid.uuid4().hex self._client = client + self._is_locked = False # Track local state @property def client(self): @@ -124,104 +156,72 @@ def client(self): @property def loop(self): - return self._lock.loop + return self.client.loop def _verify_running(self): if not self.client: - raise RuntimeError( - f"{type(self)} object not properly initialized. This can happen" - " if the object is being deserialized outside of the context of" - " a Client or Worker." - ) + raise RuntimeError(f"{type(self)} object not properly initialized") async def acquire(self): - """Acquire the underlying lock""" - return await self._lock.acquire() + """Acquire lock""" + self._verify_running() + await self.client.scheduler.condition_acquire( + name=self.name, client_id=self._client_id + ) + self._is_locked = True async def release(self): - """Release the underlying lock""" - await self._lock.release() + """Release lock""" + self._verify_running() + await self.client.scheduler.condition_release( + name=self.name, client_id=self._client_id + ) + self._is_locked = False async def wait(self, timeout=None): - """Wait until notified - - Must be called while lock is held. Atomically releases lock, - waits for notify(), then reacquires lock before returning. - - Parameters - ---------- - timeout : number or string or timedelta, optional - Maximum time to wait for notification. - - Returns - ------- - bool - True if notified, False if timeout occurred - - Raises - ------ - RuntimeError - If called without holding the lock - """ - if not self._lock.locked(): + """Wait for notification - atomically releases and reacquires lock""" + if not self._is_locked: raise RuntimeError("wait() called without holding the lock") self._verify_running() timeout = parse_timedelta(timeout) - # Atomically: release lock, wait for notify, reacquire lock - await self._lock.release() - try: - result = await self.client.scheduler.condition_wait( - name=self.name, id=self.id, timeout=timeout - ) - finally: - await self._lock.acquire() - + # This handles release, wait, reacquire atomically on scheduler + result = await self.client.scheduler.condition_wait( + name=self.name, + waiter_id=self._waiter_id, + client_id=self._client_id, + timeout=timeout, + ) + # Lock is reacquired by the time this returns return result def notify(self, n=1): - """Wake up one or more waiters - - Must be called while holding the lock. - - Parameters - ---------- - n : int, optional - Number of waiters to wake. Default is 1. - - Returns - ------- - int - Number of waiters actually notified - """ - if not self._lock.locked(): + """Wake up n waiters""" + if not self._is_locked: raise RuntimeError("notify() called without holding the lock") self._verify_running() return self.client.sync( - self.client.scheduler.condition_notify, name=self.name, n=n + self.client.scheduler.condition_notify, + name=self.name, + client_id=self._client_id, + n=n, ) def notify_all(self): - """Wake up all waiters - - Must be called while holding the lock. 
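
The lock-held requirement for notify()/notify_all() matches asyncio.Condition, which raises the equivalent RuntimeError; a quick stdlib illustration of the rule:

    import asyncio

    async def main():
        cond = asyncio.Condition()
        try:
            cond.notify()              # notifying without the lock is an error
        except RuntimeError as exc:
            print("rejected:", exc)
        async with cond:
            cond.notify_all()          # fine: the lock is held here

    asyncio.run(main())
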
- - Returns - ------- - int - Number of waiters notified - """ - if not self._lock.locked(): + """Wake up all waiters""" + if not self._is_locked: raise RuntimeError("notify_all() called without holding the lock") self._verify_running() return self.client.sync( - self.client.scheduler.condition_notify_all, name=self.name + self.client.scheduler.condition_notify_all, + name=self.name, + client_id=self._client_id, ) def locked(self): - """Return True if the lock is currently held""" - return self._lock.locked() + """Return True if lock is held by this instance""" + return self._is_locked async def __aenter__(self): await self.acquire() diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 67b03cbca8..74721c7456 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -286,16 +286,18 @@ async def test_condition_cleanup(c, s, a, b): """Test that condition state is cleaned up after use""" condition = Condition("cleanup-test") - # Check initial state - only check waiters since locks are managed by LockExtension - assert "cleanup-test" not in s.extensions["conditions"]._waiters + # Check initial state + assert "cleanup-test" not in s.extensions["conditions"]._lock_holders + assert "cleanup-test" not in s.extensions["conditions"]._notify_waiters # Use condition async with condition: condition.notify() - # Waiter state should be cleaned up + # State should be cleaned up await asyncio.sleep(0.1) - assert "cleanup-test" not in s.extensions["conditions"]._waiters + assert "cleanup-test" not in s.extensions["conditions"]._lock_holders + assert "cleanup-test" not in s.extensions["conditions"]._notify_waiters @gen_cluster(client=True) From 2a34eb0ec4db3df338c805fd0c6d5e09ecef57a6 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 21 Dec 2025 13:00:28 -0800 Subject: [PATCH 17/23] Update condition.py,test_condition.py --- distributed/condition.py | 268 +++++++++++++++++++++------- distributed/tests/test_condition.py | 19 -- 2 files changed, 207 insertions(+), 80 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 1731d0b668..67eb4551f2 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -14,16 +14,30 @@ class ConditionExtension: - """Scheduler extension managing Condition lock and notifications""" + """Scheduler extension managing Condition lock and notifications + + State managed: + - _locks: Which client holds which condition's lock + - _acquire_queue: Clients waiting to acquire lock (FIFO) + - _waiters: Clients in wait() (released lock, awaiting notify) + - _client_conditions: Reverse index for cleanup on disconnect + """ def __init__(self, scheduler): self.scheduler = scheduler + # {condition_name: client_id} - who holds each lock - self._lock_holders = {} - # {condition_name: deque of (client_id, future)} - waiting to acquire - self._acquire_waiters = defaultdict(deque) - # {condition_name: {waiter_id: (client_id, Event)}} - waiting for notify - self._notify_waiters = defaultdict(dict) + self._locks = {} + + # {condition_name: deque[(client_id, future)]} - waiting to acquire + self._acquire_queue = defaultdict(deque) + + # {condition_name: {waiter_id: (client_id, event, reacquire_future)}} + # - clients in wait(), will need to reacquire after notify + self._waiters = defaultdict(dict) + + # {client_id: set(condition_names)} - for cleanup on disconnect + self._client_conditions = defaultdict(set) self.scheduler.handlers.update( { @@ -35,115 +49,227 @@ def __init__(self, 
scheduler): } ) + # Register cleanup on client disconnect + self.scheduler.extensions["conditions"] = self + + def _track_client(self, name, client_id): + """Track that a client is using this condition""" + self._client_conditions[client_id].add(name) + + def _untrack_client(self, name, client_id): + """Stop tracking client for this condition""" + if client_id in self._client_conditions: + self._client_conditions[client_id].discard(name) + if not self._client_conditions[client_id]: + del self._client_conditions[client_id] + @log_errors async def acquire(self, name=None, client_id=None): """Acquire lock - blocks until available""" - if name not in self._lock_holders: + self._track_client(name, client_id) + + if name not in self._locks: # Lock is free - self._lock_holders[name] = client_id + self._locks[name] = client_id return True - if self._lock_holders[name] == client_id: - # Already hold it (shouldn't happen in normal use) + if self._locks[name] == client_id: + # Re-entrant acquire (from same client) return True - # Lock is held by someone else - wait our turn + # Lock is held - queue up and wait future = asyncio.Future() - self._acquire_waiters[name].append((client_id, future)) - await future - return True + self._acquire_queue[name].append((client_id, future)) + + try: + await future + return True + except asyncio.CancelledError: + # Remove from queue if cancelled + queue = self._acquire_queue.get(name, deque()) + try: + queue.remove((client_id, future)) + except ValueError: + pass # Already removed + raise + + def _wake_next_acquirer(self, name): + """Wake the next client waiting to acquire this lock""" + queue = self._acquire_queue.get(name, deque()) + + while queue: + client_id, future = queue.popleft() + if not future.done(): + self._locks[name] = client_id + future.set_result(True) + return True + + # No waiters left + if name in self._acquire_queue: + del self._acquire_queue[name] + return False @log_errors async def release(self, name=None, client_id=None): """Release lock""" - if name not in self._lock_holders: + if name not in self._locks: raise RuntimeError("Released too often") - if self._lock_holders[name] != client_id: + if self._locks[name] != client_id: raise RuntimeError("Cannot release lock held by another client") - del self._lock_holders[name] + del self._locks[name] - # Wake next waiter if any - waiters = self._acquire_waiters.get(name, deque()) - while waiters: - next_client_id, future = waiters.popleft() - if not future.done(): - self._lock_holders[name] = next_client_id - future.set_result(True) - break + # Wake next waiter trying to acquire + if not self._wake_next_acquirer(name): + # No acquire waiters - cleanup if no notify waiters either + if name not in self._waiters: + self._untrack_client(name, client_id) @log_errors async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None): - """Release lock, wait for notify, reacquire lock""" + """Release lock, wait for notify, reacquire lock + + Critical: Register for notify BEFORE releasing lock to prevent lost wakeup + """ # Verify caller holds lock - if self._lock_holders.get(name) != client_id: + if self._locks.get(name) != client_id: raise RuntimeError("wait() called without holding the lock") - # Release lock (waking next acquire waiter if any) - await self.release(name=name, client_id=client_id) + # 1. 
Register for notification FIRST (prevents lost wakeup) + notify_event = asyncio.Event() + reacquire_future = asyncio.Future() + self._waiters[name][waiter_id] = (client_id, notify_event, reacquire_future) - # Register as notify waiter - event = asyncio.Event() - self._notify_waiters[name][waiter_id] = (client_id, event) + # 2. Release lock (allows notifier to proceed) + await self.release(name=name, client_id=client_id) - # Wait for notification - future = event.wait() + # 3. Wait for notification + wait_future = notify_event.wait() if timeout is not None: - future = wait_for(future, timeout) + wait_future = wait_for(wait_future, timeout) + notified = False try: - await future - result = True + await wait_future + notified = True except TimeoutError: - result = False + notified = False + except asyncio.CancelledError: + # Cancelled - cleanup and don't reacquire + self._waiters[name].pop(waiter_id, None) + if not self._waiters[name]: + del self._waiters[name] + raise finally: - # Cleanup waiter - self._notify_waiters[name].pop(waiter_id, None) - if not self._notify_waiters[name]: - del self._notify_waiters[name] + # Cleanup waiter registration + self._waiters[name].pop(waiter_id, None) + if not self._waiters[name]: + del self._waiters[name] - # Reacquire lock - blocks until available - await self.acquire(name=name, client_id=client_id) + # 4. Reacquire lock before returning + # This might block if other clients are waiting + await self.acquire(name=name, client_id=client_id) - return result + return notified @log_errors def notify(self, name=None, client_id=None, n=1): """Wake up n waiters""" # Verify caller holds lock - if self._lock_holders.get(name) != client_id: + if self._locks.get(name) != client_id: raise RuntimeError("notify() called without holding the lock") - waiters = self._notify_waiters.get(name, {}) + waiters = self._waiters.get(name, {}) count = 0 - for _, (_, event) in list(waiters.items())[:n]: + + for waiter_id in list(waiters.keys())[:n]: + _, event, _ = waiters[waiter_id] event.set() count += 1 + return count @log_errors def notify_all(self, name=None, client_id=None): """Wake up all waiters""" # Verify caller holds lock - if self._lock_holders.get(name) != client_id: + if self._locks.get(name) != client_id: raise RuntimeError("notify_all() called without holding the lock") - waiters = self._notify_waiters.get(name, {}) - for _, event in waiters.values(): + waiters = self._waiters.get(name, {}) + + for _, event, _ in waiters.values(): event.set() + return len(waiters) + async def remove_client(self, client): + """Cleanup when client disconnects""" + conditions = self._client_conditions.pop(client, set()) + + for name in conditions: + # Release any locks held by this client + if self._locks.get(name) == client: + try: + await self.release(name=name, client_id=client) + except Exception as e: + logger.warning(f"Error releasing lock for {name}: {e}") + + # Cancel acquire waiters from this client + queue = self._acquire_queue.get(name, deque()) + to_remove = [] + for i, (cid, future) in enumerate(queue): + if cid == client and not future.done(): + future.cancel() + to_remove.append(i) + for i in reversed(to_remove): + try: + del queue[i] + except IndexError: + pass + + # Cancel notify waiters from this client + waiters = self._waiters.get(name, {}) + to_remove = [] + for waiter_id, (cid, event, reacq) in waiters.items(): + if cid == client: + event.set() # Wake them up so they can cleanup + if not reacq.done(): + reacq.cancel() + to_remove.append(waiter_id) + for wid in 
to_remove: + waiters.pop(wid, None) + class Condition(SyncMethodMixin): - """Distributed Condition Variable""" + """Distributed Condition Variable + + Provides wait/notify synchronization across distributed clients. + Multiple Condition instances with the same name share state. + + Examples + -------- + >>> condition = Condition('data-ready') + >>> + >>> # Consumer + >>> async with condition: + ... while not data_available(): + ... await condition.wait() + ... process_data() + >>> + >>> # Producer + >>> async with condition: + ... produce_data() + ... condition.notify_all() + """ def __init__(self, name=None, client=None): self.name = name or f"condition-{uuid.uuid4().hex}" self._waiter_id = uuid.uuid4().hex self._client_id = uuid.uuid4().hex self._client = client - self._is_locked = False # Track local state + self._is_locked = False @property def client(self): @@ -160,10 +286,14 @@ def loop(self): def _verify_running(self): if not self.client: - raise RuntimeError(f"{type(self)} object not properly initialized") + raise RuntimeError( + f"{type(self)} object not properly initialized. " + "This can happen if the object is being deserialized " + "outside of the context of a Client or Worker." + ) async def acquire(self): - """Acquire lock""" + """Acquire the lock""" self._verify_running() await self.client.scheduler.condition_acquire( name=self.name, client_id=self._client_id @@ -171,7 +301,7 @@ async def acquire(self): self._is_locked = True async def release(self): - """Release lock""" + """Release the lock""" self._verify_running() await self.client.scheduler.condition_release( name=self.name, client_id=self._client_id @@ -179,25 +309,41 @@ async def release(self): self._is_locked = False async def wait(self, timeout=None): - """Wait for notification - atomically releases and reacquires lock""" + """Wait for notification + + Must be called while holding the lock. Atomically releases lock, + waits for notify(), then reacquires lock before returning. 
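
On the scheduler side, this revision hands the lock off in FIFO order by completing the next queued future in release(). A self-contained sketch of that hand-off (illustrative names, not the extension itself):

    import asyncio
    from collections import defaultdict, deque

    holders = {}                        # {name: client_id currently holding the lock}
    queue = defaultdict(deque)          # {name: deque of (client_id, Future)}

    async def acquire(name, client_id):
        if name not in holders:
            holders[name] = client_id
            return
        fut = asyncio.get_running_loop().create_future()
        queue[name].append((client_id, fut))
        await fut                       # completed by release() when it is our turn

    def release(name, client_id):
        assert holders.get(name) == client_id, "released by a non-holder"
        del holders[name]
        while queue[name]:
            next_id, fut = queue[name].popleft()
            if not fut.done():          # skip entries cancelled while queued
                holders[name] = next_id
                fut.set_result(True)
                break

    async def demo():
        await acquire("cond-lock", "client-a")
        pending = asyncio.create_task(acquire("cond-lock", "client-b"))
        await asyncio.sleep(0)          # client-b is now queued behind client-a
        release("cond-lock", "client-a")
        await pending                   # hand-off done: client-b holds the lock
        assert holders["cond-lock"] == "client-b"

    asyncio.run(demo())
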
+ + Parameters + ---------- + timeout : float, optional + Maximum time to wait in seconds + + Returns + ------- + bool + True if notified, False if timeout + """ if not self._is_locked: raise RuntimeError("wait() called without holding the lock") self._verify_running() timeout = parse_timedelta(timeout) - # This handles release, wait, reacquire atomically on scheduler + # Scheduler handles atomic release/wait/reacquire result = await self.client.scheduler.condition_wait( name=self.name, waiter_id=self._waiter_id, client_id=self._client_id, timeout=timeout, ) - # Lock is reacquired by the time this returns + + # Lock is reacquired when this returns + # _is_locked stays True return result def notify(self, n=1): - """Wake up n waiters""" + """Wake up n waiters (default 1)""" if not self._is_locked: raise RuntimeError("notify() called without holding the lock") self._verify_running() @@ -220,7 +366,7 @@ def notify_all(self): ) def locked(self): - """Return True if lock is held by this instance""" + """Return True if this instance holds the lock""" return self._is_locked async def __aenter__(self): diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 74721c7456..965f759cf3 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -281,25 +281,6 @@ async def test_condition_unique_names_independent(c, s, a, b): assert cond2.locked() -@gen_cluster(client=True) -async def test_condition_cleanup(c, s, a, b): - """Test that condition state is cleaned up after use""" - condition = Condition("cleanup-test") - - # Check initial state - assert "cleanup-test" not in s.extensions["conditions"]._lock_holders - assert "cleanup-test" not in s.extensions["conditions"]._notify_waiters - - # Use condition - async with condition: - condition.notify() - - # State should be cleaned up - await asyncio.sleep(0.1) - assert "cleanup-test" not in s.extensions["conditions"]._lock_holders - assert "cleanup-test" not in s.extensions["conditions"]._notify_waiters - - @gen_cluster(client=True) async def test_condition_barrier_pattern(c, s, a, b): """Test barrier synchronization pattern""" From b29efe74e77a38733105419a6c70ad071a9ec5bc Mon Sep 17 00:00:00 2001 From: nadzhou Date: Tue, 23 Dec 2025 14:22:39 -0800 Subject: [PATCH 18/23] Update condition.py,test_condition.py --- distributed/condition.py | 8 +- distributed/tests/test_condition.py | 414 ++++------------------------ 2 files changed, 68 insertions(+), 354 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 67eb4551f2..bce34c757e 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -267,7 +267,6 @@ class Condition(SyncMethodMixin): def __init__(self, name=None, client=None): self.name = name or f"condition-{uuid.uuid4().hex}" self._waiter_id = uuid.uuid4().hex - self._client_id = uuid.uuid4().hex self._client = client self._is_locked = False @@ -280,6 +279,13 @@ def client(self): pass return self._client + @property + def _client_id(self): + """Use actual Dask client ID - all Conditions in same client share identity""" + if self.client: + return self.client.id + raise RuntimeError(f"{type(self).__name__} requires a connected client") + @property def loop(self): return self.client.loop diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 965f759cf3..75856e940b 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -1,373 +1,81 @@ -import asyncio - -import pytest - 
-from distributed import Condition -from distributed.metrics import time -from distributed.utils_test import gen_cluster - - -@gen_cluster(client=True) -async def test_condition_acquire_release(c, s, a, b): - """Test basic lock acquire/release""" - condition = Condition("test-lock") - assert not condition.locked() - await condition.acquire() - assert condition.locked() - await condition.release() - assert not condition.locked() - - -@gen_cluster(client=True) -async def test_condition_context_manager(c, s, a, b): - """Test context manager interface""" - condition = Condition("test-context") - assert not condition.locked() - async with condition: - assert condition.locked() - assert not condition.locked() - - -@gen_cluster(client=True) -async def test_condition_wait_notify(c, s, a, b): - """Test basic wait/notify""" - condition = Condition("test-notify") - results = [] - - async def waiter(): - async with condition: - results.append("waiting") - await condition.wait() - results.append("notified") - - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - results.append("notifying") - condition.notify() - - await asyncio.gather(waiter(), notifier()) - assert results == ["waiting", "notifying", "notified"] - - -@gen_cluster(client=True) -async def test_condition_notify_all(c, s, a, b): - """Test notify_all wakes all waiters""" - condition = Condition("test-notify-all") - results = [] - - async def waiter(i): - async with condition: - await condition.wait() - results.append(i) - - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - condition.notify_all() - - await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) - assert sorted(results) == [1, 2, 3] - - -@gen_cluster(client=True) -async def test_condition_notify_n(c, s, a, b): - """Test notify with specific count""" - condition = Condition("test-notify-n") - results = [] - - async def waiter(i): - async with condition: - await condition.wait() - results.append(i) - - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - condition.notify(n=2) # Wake only 2 waiters - await asyncio.sleep(0.2) - async with condition: - condition.notify() # Wake remaining waiter - - await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) - assert sorted(results) == [1, 2, 3] - - -@gen_cluster(client=True) -async def test_condition_wait_timeout(c, s, a, b): - """Test wait with timeout""" - condition = Condition("test-timeout") - - start = time() - async with condition: - result = await condition.wait(timeout=0.5) - elapsed = time() - start - - assert result is False - assert 0.4 < elapsed < 0.7 - - -@gen_cluster(client=True) -async def test_condition_wait_timeout_then_notify(c, s, a, b): - """Test that timeout doesn't prevent subsequent notifications""" - condition = Condition("test-timeout-notify") - results = [] - - async def waiter(): - async with condition: - result = await condition.wait(timeout=0.2) - results.append(f"timeout: {result}") - async with condition: - result = await condition.wait() - results.append(f"notified: {result}") - - async def notifier(): - await asyncio.sleep(0.5) - async with condition: - condition.notify() - - await asyncio.gather(waiter(), notifier()) - assert results == ["timeout: False", "notified: True"] - - -@gen_cluster(client=True) -async def test_condition_error_without_lock(c, s, a, b): - """Test errors when calling wait/notify without holding lock""" - condition = Condition("test-error") - - with pytest.raises(RuntimeError, match="without 
holding the lock"): - await condition.wait() - - with pytest.raises(RuntimeError, match="Cannot notify"): - condition.notify() - - with pytest.raises(RuntimeError, match="Cannot notify"): - condition.notify_all() - - -@gen_cluster(client=True) -async def test_condition_error_release_without_acquire(c, s, a, b): - """Test error when releasing without acquiring""" - condition = Condition("test-release-error") - - with pytest.raises(RuntimeError, match="Released too often"): - await condition.release() - - -@gen_cluster(client=True) -async def test_condition_producer_consumer(c, s, a, b): - """Test classic producer-consumer pattern""" - condition = Condition("prod-cons") - queue = [] - - async def producer(): - for i in range(5): - await asyncio.sleep(0.1) - async with condition: - queue.append(i) - condition.notify() - - async def consumer(): - results = [] - for _ in range(5): - async with condition: - while not queue: - await condition.wait() - results.append(queue.pop(0)) - return results - - prod_task = asyncio.create_task(producer()) - cons_task = asyncio.create_task(consumer()) - - await prod_task - results = await cons_task - assert results == [0, 1, 2, 3, 4] - - -@gen_cluster(client=True) -async def test_condition_multiple_producers_consumers(c, s, a, b): - """Test multiple producers and consumers""" - condition = Condition("multi-prod-cons") - queue = [] - - async def producer(start): - for i in range(start, start + 3): - await asyncio.sleep(0.05) - async with condition: - queue.append(i) - condition.notify() - - async def consumer(): - results = [] - for _ in range(3): - async with condition: - while not queue: - await condition.wait() - results.append(queue.pop(0)) - return results - - results = await asyncio.gather(producer(0), producer(10), consumer(), consumer()) - # Last two results are from consumers - consumed = results[2] + results[3] - assert sorted(consumed) == [0, 1, 2, 10, 11, 12] - - -@gen_cluster(client=True) -async def test_condition_from_worker(c, s, a, b): - """Test condition accessed from worker tasks""" - - def wait_on_condition(name): - from distributed import Condition - - condition = Condition(name) - with condition: - condition.wait() - return "worker_notified" - - def notify_condition(name): - import time - - from distributed import Condition - - time.sleep(0.2) - condition = Condition(name) - with condition: - condition.notify() - return "notified" - - name = "worker-condition" - f1 = c.submit(wait_on_condition, name, workers=[a.address]) - f2 = c.submit(notify_condition, name, workers=[b.address]) - results = await c.gather([f1, f2]) - assert results == ["worker_notified", "notified"] - - -@gen_cluster(client=True) -async def test_condition_same_name_different_instances(c, s, a, b): - """Test that multiple instances with same name share state""" - name = "shared-condition" - cond1 = Condition(name) - cond2 = Condition(name) - results = [] - - async def waiter(): - async with cond1: - results.append("waiting") - await cond1.wait() - results.append("notified") - - async def notifier(): - await asyncio.sleep(0.2) - async with cond2: - results.append("notifying") - cond2.notify() - - await asyncio.gather(waiter(), notifier()) - assert results == ["waiting", "notifying", "notified"] - - -@gen_cluster(client=True) -async def test_condition_unique_names_independent(c, s, a, b): - """Test conditions with different names are independent""" - cond1 = Condition("cond-1") - cond2 = Condition("cond-2") - - async with cond1: - assert cond1.locked() - assert not 
cond2.locked() - - async with cond2: - assert not cond1.locked() - assert cond2.locked() +from __future__ import annotations +import asyncio +import logging +import uuid +from collections import defaultdict, deque -@gen_cluster(client=True) -async def test_condition_barrier_pattern(c, s, a, b): - """Test barrier synchronization pattern""" - condition = Condition("barrier") - arrived = [] - n_workers = 3 +from dask.utils import parse_timedelta - async def worker(i): - async with condition: - arrived.append(i) - if len(arrived) < n_workers: - await condition.wait() - else: - condition.notify_all() - return f"worker-{i}-done" +from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for +from distributed.worker import get_client - results = await asyncio.gather(worker(0), worker(1), worker(2)) - assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"] - assert len(arrived) == 3 +logger = logging.getLogger(__name__) -def test_condition_sync_interface(client): - """Test synchronous interface via SyncMethodMixin""" - condition = Condition("sync-test") - results = [] +class ConditionExtension: + """Scheduler extension managing Condition lock and notifications - def worker(): - with condition: - results.append("locked") - results.append("released") + State managed: + - _locks: Which client holds which condition's lock + - _acquire_queue: Clients waiting to acquire lock (FIFO) + - _waiters: Clients in wait() (released lock, awaiting notify) + - _client_conditions: Reverse index for cleanup on disconnect + """ - worker() - assert results == ["locked", "released"] + def __init__(self, scheduler): + self.scheduler = scheduler + # {condition_name: client_id} - who holds each lock + self._locks = {} -@gen_cluster(client=True) -async def test_condition_multiple_notify_calls(c, s, a, b): - """Test multiple notify calls in sequence""" - condition = Condition("multi-notify") - results = [] + # {condition_name: deque[(client_id, future)]} - waiting to acquire + self._acquire_queue = defaultdict(deque) - async def waiter(i): - async with condition: - await condition.wait() - results.append(i) + # {condition_name: {waiter_id: (client_id, event, reacquire_future)}} + # - clients in wait(), will need to reacquire after notify + self._waiters = defaultdict(dict) - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - condition.notify() - await asyncio.sleep(0.1) - async with condition: - condition.notify() - await asyncio.sleep(0.1) - async with condition: - condition.notify() + # {client_id: set(condition_names)} - for cleanup on disconnect + self._client_conditions = defaultdict(set) - await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) - assert sorted(results) == [1, 2, 3] + self.scheduler.handlers.update( + { + "condition_acquire": self.acquire, + "condition_release": self.release, + "condition_wait": self.wait, + "condition_notify": self.notify, + "condition_notify_all": self.notify_all, + } + ) + # Register cleanup on client disconnect + self.scheduler.extensions["conditions"] = self -@gen_cluster(client=True) -async def test_condition_predicate_loop(c, s, a, b): - """Test typical predicate-based wait loop pattern""" - condition = Condition("predicate") - state = {"value": 0, "target": 5} + def _track_client(self, name, client_id): + """Track that a client is using this condition""" + self._client_conditions[client_id].add(name) - async def waiter(): - async with condition: - while state["value"] < state["target"]: - await 
condition.wait() - return state["value"] + def _untrack_client(self, name, client_id): + """Stop tracking client for this condition""" + if client_id in self._client_conditions: + self._client_conditions[client_id].discard(name) + if not self._client_conditions[client_id]: + del self._client_conditions[client_id] - async def updater(): - for i in range(1, 6): - await asyncio.sleep(0.1) - async with condition: - state["value"] = i - condition.notify_all() + @log_errors + async def acquire(self, name=None, client_id=None): + """Acquire lock - blocks until available""" + self._track_client(name, client_id) - result, _ = await asyncio.gather(waiter(), updater()) - assert result == 5 + if name not in self._locks: + # Lock is free + self._locks[name] = client_id + return True + if self._locks[name] == client_id: + # Re-entrant acquire (from same client) + return True -@gen_cluster(client=True) -async def test_condition_repr(c, s, a, b): - """Test string representation""" - condition = Condition("test-repr") - assert "test-repr" in repr(condition) - assert "Condition" in repr(condition) + # Lock is held - queue up and wait + future = asyncio.Future() From 2b73e57895249609c000f4c66d43480fa6292771 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Tue, 23 Dec 2025 14:31:01 -0800 Subject: [PATCH 19/23] Update test_condition.py --- distributed/tests/test_condition.py | 410 +++++++++++++++++++++++----- 1 file changed, 349 insertions(+), 61 deletions(-) diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 75856e940b..67c2b21bbb 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -1,81 +1,369 @@ -from __future__ import annotations - import asyncio -import logging -import uuid -from collections import defaultdict, deque -from dask.utils import parse_timedelta +import pytest + +from distributed import Condition +from distributed.metrics import time +from distributed.utils_test import gen_cluster + + +@gen_cluster(client=True) +async def test_condition_acquire_release(c, s, a, b): + """Test basic lock acquire/release""" + condition = Condition("test-lock") + assert not condition.locked() + await condition.acquire() + assert condition.locked() + await condition.release() + assert not condition.locked() + + +@gen_cluster(client=True) +async def test_condition_context_manager(c, s, a, b): + """Test context manager interface""" + condition = Condition("test-context") + assert not condition.locked() + async with condition: + assert condition.locked() + assert not condition.locked() + + +@gen_cluster(client=True) +async def test_condition_wait_notify(c, s, a, b): + """Test basic wait/notify""" + condition = Condition("test-notify") + results = [] + + async def waiter(): + async with condition: + results.append("waiting") + await condition.wait() + results.append("notified") + + async def notifier(): + await asyncio.sleep(0.2) + async with condition: + results.append("notifying") + condition.notify() + + await asyncio.gather(waiter(), notifier()) + assert results == ["waiting", "notifying", "notified"] + + +@gen_cluster(client=True) +async def test_condition_notify_all(c, s, a, b): + """Test notify_all wakes all waiters""" + condition = Condition("test-notify-all") + results = [] + + async def waiter(i): + async with condition: + await condition.wait() + results.append(i) + + async def notifier(): + await asyncio.sleep(0.2) + async with condition: + condition.notify_all() + + await asyncio.gather(waiter(1), waiter(2), waiter(3), 
notifier()) + assert sorted(results) == [1, 2, 3] + + +@gen_cluster(client=True) +async def test_condition_notify_n(c, s, a, b): + """Test notify with specific count""" + condition = Condition("test-notify-n") + results = [] + + async def waiter(i): + async with condition: + await condition.wait() + results.append(i) + + async def notifier(): + await asyncio.sleep(0.2) + async with condition: + condition.notify(n=2) # Wake only 2 waiters + await asyncio.sleep(0.2) + async with condition: + condition.notify() # Wake remaining waiter + + await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) + assert sorted(results) == [1, 2, 3] + + +@gen_cluster(client=True) +async def test_condition_wait_timeout(c, s, a, b): + """Test wait with timeout""" + condition = Condition("test-timeout") + + start = time() + async with condition: + result = await condition.wait(timeout=0.5) + elapsed = time() - start + + assert result is False + assert 0.4 < elapsed < 0.7 + + +@gen_cluster(client=True) +async def test_condition_wait_timeout_then_notify(c, s, a, b): + """Test that timeout doesn't prevent subsequent notifications""" + condition = Condition("test-timeout-notify") + results = [] + + async def waiter(): + async with condition: + result = await condition.wait(timeout=0.2) + results.append(f"timeout: {result}") + async with condition: + result = await condition.wait() + results.append(f"notified: {result}") + + async def notifier(): + await asyncio.sleep(0.5) + async with condition: + condition.notify() + + await asyncio.gather(waiter(), notifier()) + assert results == ["timeout: False", "notified: True"] + + +@gen_cluster(client=True) +async def test_condition_error_without_lock(c, s, a, b): + """Test errors when calling wait/notify without holding lock""" + condition = Condition("test-error") + + with pytest.raises(RuntimeError, match="without holding the lock"): + await condition.wait() + + with pytest.raises(RuntimeError, match="Cannot notify"): + condition.notify() + + with pytest.raises(RuntimeError, match="Cannot notify"): + condition.notify_all() + + +@gen_cluster(client=True) +async def test_condition_error_release_without_acquire(c, s, a, b): + """Test error when releasing without acquiring""" + condition = Condition("test-release-error") + + with pytest.raises(RuntimeError, match="Released too often"): + await condition.release() + + +@gen_cluster(client=True) +async def test_condition_producer_consumer(c, s, a, b): + """Test classic producer-consumer pattern""" + condition = Condition("prod-cons") + queue = [] + + async def producer(): + for i in range(5): + await asyncio.sleep(0.1) + async with condition: + queue.append(i) + condition.notify() + + async def consumer(): + results = [] + for _ in range(5): + async with condition: + while not queue: + await condition.wait() + results.append(queue.pop(0)) + return results + + prod_task = asyncio.create_task(producer()) + cons_task = asyncio.create_task(consumer()) + + await prod_task + results = await cons_task + assert results == [0, 1, 2, 3, 4] + + +@gen_cluster(client=True) +async def test_condition_multiple_producers_consumers(c, s, a, b): + """Test multiple producers and consumers""" + condition = Condition("multi-prod-cons") + queue = [] + + async def producer(start): + for i in range(start, start + 3): + await asyncio.sleep(0.05) + async with condition: + queue.append(i) + condition.notify() + + async def consumer(): + results = [] + for _ in range(3): + async with condition: + while not queue: + await condition.wait() + 
results.append(queue.pop(0)) + return results + + results = await asyncio.gather(producer(0), producer(10), consumer(), consumer()) + # Last two results are from consumers + consumed = results[2] + results[3] + assert sorted(consumed) == [0, 1, 2, 10, 11, 12] + + +@gen_cluster(client=True) +async def test_condition_same_name_different_instances(c, s, a, b): + """Test that multiple instances with same name share state""" + name = "shared-condition" + cond1 = Condition(name) + cond2 = Condition(name) + results = [] + + async def waiter(): + async with cond1: + results.append("waiting") + await cond1.wait() + results.append("notified") + + async def notifier(): + await asyncio.sleep(0.2) + async with cond2: + results.append("notifying") + cond2.notify() + + await asyncio.gather(waiter(), notifier()) + assert results == ["waiting", "notifying", "notified"] + + +@gen_cluster(client=True) +async def test_condition_unique_names_independent(c, s, a, b): + """Test conditions with different names are independent""" + cond1 = Condition("cond-1") + cond2 = Condition("cond-2") + + async with cond1: + assert cond1.locked() + assert not cond2.locked() + + async with cond2: + assert not cond1.locked() + assert cond2.locked() + + +@gen_cluster(client=True) +async def test_condition_barrier_pattern(c, s, a, b): + """Test barrier synchronization pattern""" + condition = Condition("barrier") + arrived = [] + n_workers = 3 + + async def worker(i): + async with condition: + arrived.append(i) + if len(arrived) < n_workers: + await condition.wait() + else: + condition.notify_all() + return f"worker-{i}-done" + + results = await asyncio.gather(worker(0), worker(1), worker(2)) + assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"] + assert len(arrived) == 3 + + +def test_condition_sync_interface(client): + """Test synchronous interface via SyncMethodMixin""" + condition = Condition("sync-test") + results = [] + + def worker(): + with condition: + results.append("locked") + results.append("released") + + worker() + assert results == ["locked", "released"] + + +@gen_cluster(client=True) +async def test_condition_multiple_notify_calls(c, s, a, b): + """Test multiple notify calls in sequence""" + condition = Condition("multi-notify") + results = [] -from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for -from distributed.worker import get_client + async def waiter(i): + async with condition: + await condition.wait() + results.append(i) -logger = logging.getLogger(__name__) + async def notifier(): + await asyncio.sleep(0.2) + async with condition: + condition.notify() + await asyncio.sleep(0.1) + async with condition: + condition.notify() + await asyncio.sleep(0.1) + async with condition: + condition.notify() + await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) + assert sorted(results) == [1, 2, 3] -class ConditionExtension: - """Scheduler extension managing Condition lock and notifications - State managed: - - _locks: Which client holds which condition's lock - - _acquire_queue: Clients waiting to acquire lock (FIFO) - - _waiters: Clients in wait() (released lock, awaiting notify) - - _client_conditions: Reverse index for cleanup on disconnect - """ +@gen_cluster(client=True) +async def test_condition_predicate_loop(c, s, a, b): + """Test typical predicate-based wait loop pattern""" + condition = Condition("predicate") + state = {"value": 0, "target": 5} - def __init__(self, scheduler): - self.scheduler = scheduler + async def waiter(): + async with 
condition: + while state["value"] < state["target"]: + await condition.wait() + return state["value"] - # {condition_name: client_id} - who holds each lock - self._locks = {} + async def updater(): + for i in range(1, 6): + await asyncio.sleep(0.1) + async with condition: + state["value"] = i + condition.notify_all() - # {condition_name: deque[(client_id, future)]} - waiting to acquire - self._acquire_queue = defaultdict(deque) + result, _ = await asyncio.gather(waiter(), updater()) + assert result == 5 - # {condition_name: {waiter_id: (client_id, event, reacquire_future)}} - # - clients in wait(), will need to reacquire after notify - self._waiters = defaultdict(dict) - # {client_id: set(condition_names)} - for cleanup on disconnect - self._client_conditions = defaultdict(set) +@gen_cluster(client=True) +async def test_condition_repr(c, s, a, b): + """Test string representation""" + condition = Condition("test-repr") + assert "test-repr" in repr(condition) + assert "Condition" in repr(condition) - self.scheduler.handlers.update( - { - "condition_acquire": self.acquire, - "condition_release": self.release, - "condition_wait": self.wait, - "condition_notify": self.notify, - "condition_notify_all": self.notify_all, - } - ) - # Register cleanup on client disconnect - self.scheduler.extensions["conditions"] = self +@gen_cluster(client=True) +async def test_condition_reentrant_acquire(c, s, a, b): + """Test that the same client can re-acquire the lock""" + condition = Condition("reentrant") - def _track_client(self, name, client_id): - """Track that a client is using this condition""" - self._client_conditions[client_id].add(name) + await condition.acquire() + assert condition.locked() - def _untrack_client(self, name, client_id): - """Stop tracking client for this condition""" - if client_id in self._client_conditions: - self._client_conditions[client_id].discard(name) - if not self._client_conditions[client_id]: - del self._client_conditions[client_id] + # Should succeed without blocking (reentrant) + await condition.acquire() + assert condition.locked() - @log_errors - async def acquire(self, name=None, client_id=None): - """Acquire lock - blocks until available""" - self._track_client(name, client_id) + await condition.release() + assert not condition.locked() - if name not in self._locks: - # Lock is free - self._locks[name] = client_id - return True - if self._locks[name] == client_id: - # Re-entrant acquire (from same client) - return True +@gen_cluster(client=True) +async def test_condition_multiple_instances_share_client_id(c, s, a, b): + """Test that multiple Condition instances in same client share client_id""" + cond1 = Condition("test-1") + cond2 = Condition("test-2") - # Lock is held - queue up and wait - future = asyncio.Future() + # Both should have same client ID (the client 'c') + assert cond1._client_id == cond2._client_id == c.id From 8ebeb4c4ccb9d81a26de9738017cc7a46a2286b0 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 26 Dec 2025 16:30:54 -0800 Subject: [PATCH 20/23] Update condition.py,test_condition.py --- distributed/condition.py | 223 ++++++++++++++++++---------- distributed/tests/test_condition.py | 64 ++++++-- 2 files changed, 197 insertions(+), 90 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index bce34c757e..9bb3f49900 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -16,6 +16,9 @@ class ConditionExtension: """Scheduler extension managing Condition lock and notifications + A Condition provides wait/notify 
synchronization across distributed clients. + The lock is NOT re-entrant - attempting to acquire while holding will block. + State managed: - _locks: Which client holds which condition's lock - _acquire_queue: Clients waiting to acquire lock (FIFO) @@ -29,11 +32,10 @@ def __init__(self, scheduler): # {condition_name: client_id} - who holds each lock self._locks = {} - # {condition_name: deque[(client_id, future)]} - waiting to acquire + # {condition_name: deque[(client_id, future)]} - FIFO queue waiting to acquire self._acquire_queue = defaultdict(deque) - # {condition_name: {waiter_id: (client_id, event, reacquire_future)}} - # - clients in wait(), will need to reacquire after notify + # {condition_name: {waiter_id: (client_id, event)}} - clients in wait() self._waiters = defaultdict(dict) # {client_id: set(condition_names)} - for cleanup on disconnect @@ -46,10 +48,10 @@ def __init__(self, scheduler): "condition_wait": self.wait, "condition_notify": self.notify, "condition_notify_all": self.notify_all, + "condition_locked": self.locked, } ) - # Register cleanup on client disconnect self.scheduler.extensions["conditions"] = self def _track_client(self, name, client_id): @@ -65,7 +67,11 @@ def _untrack_client(self, name, client_id): @log_errors async def acquire(self, name=None, client_id=None): - """Acquire lock - blocks until available""" + """Acquire lock - blocks until available + + NOT re-entrant: if same client tries to acquire while holding, + it will block like any other client. + """ self._track_client(name, client_id) if name not in self._locks: @@ -73,11 +79,7 @@ async def acquire(self, name=None, client_id=None): self._locks[name] = client_id return True - if self._locks[name] == client_id: - # Re-entrant acquire (from same client) - return True - - # Lock is held - queue up and wait + # Lock is held (even if by same client) - must wait future = asyncio.Future() self._acquire_queue[name].append((client_id, future)) @@ -90,7 +92,7 @@ async def acquire(self, name=None, client_id=None): try: queue.remove((client_id, future)) except ValueError: - pass # Already removed + pass # Already removed or processed raise def _wake_next_acquirer(self, name): @@ -104,105 +106,130 @@ def _wake_next_acquirer(self, name): future.set_result(True) return True - # No waiters left + # No waiters left, clean up queue if name in self._acquire_queue: del self._acquire_queue[name] return False @log_errors async def release(self, name=None, client_id=None): - """Release lock""" + """Release lock + + Raises RuntimeError if: + - Lock is not held + - Lock is held by a different client + """ if name not in self._locks: - raise RuntimeError("Released too often") + raise RuntimeError("release() called without holding the lock") if self._locks[name] != client_id: - raise RuntimeError("Cannot release lock held by another client") + raise RuntimeError("release() called without holding the lock") del self._locks[name] # Wake next waiter trying to acquire if not self._wake_next_acquirer(name): # No acquire waiters - cleanup if no notify waiters either - if name not in self._waiters: + if name not in self._waiters or not self._waiters[name]: self._untrack_client(name, client_id) @log_errors async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None): """Release lock, wait for notify, reacquire lock - Critical: Register for notify BEFORE releasing lock to prevent lost wakeup + This is atomic from the caller's perspective. 
The lock is physically + released to allow notifiers to proceed, but the waiter will reacquire + before returning. + + Returns: + - True if notified + - False if timeout occurred """ # Verify caller holds lock if self._locks.get(name) != client_id: raise RuntimeError("wait() called without holding the lock") - # 1. Register for notification FIRST (prevents lost wakeup) + # 1. Register for notification FIRST (prevents lost wakeup race) notify_event = asyncio.Event() - reacquire_future = asyncio.Future() - self._waiters[name][waiter_id] = (client_id, notify_event, reacquire_future) + self._waiters[name][waiter_id] = (client_id, notify_event) # 2. Release lock (allows notifier to proceed) - await self.release(name=name, client_id=client_id) + del self._locks[name] + self._wake_next_acquirer(name) # 3. Wait for notification - wait_future = notify_event.wait() - if timeout is not None: - wait_future = wait_for(wait_future, timeout) - notified = False try: - await wait_future + if timeout is not None: + await wait_for(notify_event.wait(), timeout) + else: + await notify_event.wait() notified = True - except TimeoutError: + except (TimeoutError, asyncio.TimeoutError): notified = False except asyncio.CancelledError: - # Cancelled - cleanup and don't reacquire + # On cancellation, still cleanup and don't reacquire self._waiters[name].pop(waiter_id, None) if not self._waiters[name]: del self._waiters[name] raise finally: - # Cleanup waiter registration - self._waiters[name].pop(waiter_id, None) - if not self._waiters[name]: - del self._waiters[name] + # Always cleanup waiter registration (except CancelledError which raises above) + if waiter_id in self._waiters.get(name, {}): + self._waiters[name].pop(waiter_id, None) + if not self._waiters[name]: + del self._waiters[name] - # 4. Reacquire lock before returning - # This might block if other clients are waiting + # 4. Reacquire lock before returning (will block if others waiting) await self.acquire(name=name, client_id=client_id) return notified @log_errors def notify(self, name=None, client_id=None, n=1): - """Wake up n waiters""" - # Verify caller holds lock + """Wake up n waiters (default 1) + + Must be called while holding the lock. + Returns number of waiters actually notified. + """ if self._locks.get(name) != client_id: raise RuntimeError("notify() called without holding the lock") waiters = self._waiters.get(name, {}) count = 0 + # Wake first n waiters for waiter_id in list(waiters.keys())[:n]: - _, event, _ = waiters[waiter_id] - event.set() - count += 1 + _, event = waiters[waiter_id] + if not event.is_set(): + event.set() + count += 1 return count @log_errors def notify_all(self, name=None, client_id=None): - """Wake up all waiters""" - # Verify caller holds lock + """Wake up all waiters + + Must be called while holding the lock. + Returns number of waiters actually notified. 
+ """ if self._locks.get(name) != client_id: raise RuntimeError("notify_all() called without holding the lock") waiters = self._waiters.get(name, {}) + count = 0 + + for _, event in waiters.values(): + if not event.is_set(): + event.set() + count += 1 - for _, event, _ in waiters.values(): - event.set() + return count - return len(waiters) + def locked(self, name=None, client_id=None): + """Check if this client holds the lock""" + return self._locks.get(name) == client_id async def remove_client(self, client): """Cleanup when client disconnects""" @@ -223,23 +250,19 @@ async def remove_client(self, client): if cid == client and not future.done(): future.cancel() to_remove.append(i) + # Remove in reverse to maintain indices for i in reversed(to_remove): try: del queue[i] except IndexError: pass - # Cancel notify waiters from this client + # Wake and cleanup notify waiters from this client waiters = self._waiters.get(name, {}) - to_remove = [] - for waiter_id, (cid, event, reacq) in waiters.items(): + for waiter_id, (cid, event) in list(waiters.items()): if cid == client: - event.set() # Wake them up so they can cleanup - if not reacq.done(): - reacq.cancel() - to_remove.append(waiter_id) - for wid in to_remove: - waiters.pop(wid, None) + event.set() # Wake them so they can cleanup + waiters.pop(waiter_id, None) class Condition(SyncMethodMixin): @@ -248,6 +271,9 @@ class Condition(SyncMethodMixin): Provides wait/notify synchronization across distributed clients. Multiple Condition instances with the same name share state. + The lock is NOT re-entrant. Attempting to acquire while holding + will block indefinitely. + Examples -------- >>> condition = Condition('data-ready') @@ -262,13 +288,19 @@ class Condition(SyncMethodMixin): >>> async with condition: ... produce_data() ... condition.notify_all() + + Notes + ----- + Like threading.Condition, wait() atomically releases the lock and + waits for notification, then reacquires before returning. This means + the condition can change between being notified and reacquiring the lock, + so always use wait() in a while loop checking the actual condition. """ def __init__(self, name=None, client=None): self.name = name or f"condition-{uuid.uuid4().hex}" - self._waiter_id = uuid.uuid4().hex + self._waiter_id = None # Created fresh for each wait() call self._client = client - self._is_locked = False @property def client(self): @@ -281,7 +313,7 @@ def client(self): @property def _client_id(self): - """Use actual Dask client ID - all Conditions in same client share identity""" + """Use actual Dask client ID""" if self.client: return self.client.id raise RuntimeError(f"{type(self).__name__} requires a connected client") @@ -299,20 +331,24 @@ def _verify_running(self): ) async def acquire(self): - """Acquire the lock""" + """Acquire the lock + + Blocks until the lock is available. NOT re-entrant. + """ self._verify_running() await self.client.scheduler.condition_acquire( name=self.name, client_id=self._client_id ) - self._is_locked = True async def release(self): - """Release the lock""" + """Release the lock + + Raises RuntimeError if not holding the lock. + """ self._verify_running() await self.client.scheduler.condition_release( name=self.name, client_id=self._client_id ) - self._is_locked = False async def wait(self, timeout=None): """Wait for notification @@ -320,38 +356,53 @@ async def wait(self, timeout=None): Must be called while holding the lock. Atomically releases lock, waits for notify(), then reacquires lock before returning. 
+ Because the lock is released and reacquired, the condition may have + changed by the time this returns. Always use in a while loop: + + >>> async with condition: + ... while not predicate(): + ... await condition.wait() + Parameters ---------- timeout : float, optional - Maximum time to wait in seconds + Maximum time to wait in seconds. If None, wait indefinitely. Returns ------- bool - True if notified, False if timeout + True if notified, False if timeout occurred """ - if not self._is_locked: - raise RuntimeError("wait() called without holding the lock") - self._verify_running() timeout = parse_timedelta(timeout) - # Scheduler handles atomic release/wait/reacquire + # Create fresh waiter_id for this wait() call + waiter_id = uuid.uuid4().hex + result = await self.client.scheduler.condition_wait( name=self.name, - waiter_id=self._waiter_id, + waiter_id=waiter_id, client_id=self._client_id, timeout=timeout, ) - # Lock is reacquired when this returns - # _is_locked stays True return result def notify(self, n=1): - """Wake up n waiters (default 1)""" - if not self._is_locked: - raise RuntimeError("notify() called without holding the lock") + """Wake up n waiters (default 1) + + Must be called while holding the lock. + + Parameters + ---------- + n : int + Number of waiters to wake up + + Returns + ------- + int + Number of waiters actually notified + """ self._verify_running() return self.client.sync( self.client.scheduler.condition_notify, @@ -361,9 +412,15 @@ def notify(self, n=1): ) def notify_all(self): - """Wake up all waiters""" - if not self._is_locked: - raise RuntimeError("notify_all() called without holding the lock") + """Wake up all waiters + + Must be called while holding the lock. + + Returns + ------- + int + Number of waiters actually notified + """ self._verify_running() return self.client.sync( self.client.scheduler.condition_notify_all, @@ -372,8 +429,19 @@ def notify_all(self): ) def locked(self): - """Return True if this instance holds the lock""" - return self._is_locked + """Check if this client holds the lock + + Returns + ------- + bool + True if this client currently holds the lock + """ + self._verify_running() + return self.client.sync( + self.client.scheduler.condition_locked, + name=self.name, + client_id=self._client_id, + ) async def __aenter__(self): await self.acquire() @@ -381,6 +449,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.release() + return False def __enter__(self): return self.sync(self.__aenter__) diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 67c2b21bbb..d43ad51df7 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -138,10 +138,10 @@ async def test_condition_error_without_lock(c, s, a, b): with pytest.raises(RuntimeError, match="without holding the lock"): await condition.wait() - with pytest.raises(RuntimeError, match="Cannot notify"): + with pytest.raises(RuntimeError, match="without holding the lock"): condition.notify() - with pytest.raises(RuntimeError, match="Cannot notify"): + with pytest.raises(RuntimeError, match="without holding the lock"): condition.notify_all() @@ -150,7 +150,7 @@ async def test_condition_error_release_without_acquire(c, s, a, b): """Test error when releasing without acquiring""" condition = Condition("test-release-error") - with pytest.raises(RuntimeError, match="Released too often"): + with pytest.raises(RuntimeError, match="without holding the lock"): await 
condition.release() @@ -344,19 +344,57 @@ async def test_condition_repr(c, s, a, b): @gen_cluster(client=True) -async def test_condition_reentrant_acquire(c, s, a, b): - """Test that the same client can re-acquire the lock""" - condition = Condition("reentrant") +async def test_condition_not_reentrant(c, s, a, b): + """Test that lock is NOT re-entrant - second acquire blocks""" + condition = Condition("not-reentrant") + results = [] - await condition.acquire() - assert condition.locked() + async def try_reentrant(): + await condition.acquire() + results.append("first-acquired") - # Should succeed without blocking (reentrant) - await condition.acquire() - assert condition.locked() + # This should block (not re-entrant) + # We'll use a timeout to detect the block + try: + await asyncio.wait_for(condition.acquire(), timeout=0.5) + results.append("second-acquired") # Should not reach + except asyncio.TimeoutError: + results.append("blocked-as-expected") - await condition.release() - assert not condition.locked() + await condition.release() + results.append("released") + + await try_reentrant() + assert results == ["first-acquired", "blocked-as-expected", "released"] + + +@gen_cluster(client=True) +async def test_condition_multiple_instances_same_name_not_reentrant(c, s, a, b): + """Test that two instances with same name share lock (not re-entrant)""" + cond1 = Condition("shared") + cond2 = Condition("shared") + results = [] + + async def holder(): + await cond1.acquire() + results.append("cond1-acquired") + await asyncio.sleep(0.3) + await cond1.release() + results.append("cond1-released") + + async def waiter(): + await asyncio.sleep(0.1) # Let holder acquire first + try: + # This should block because cond1 holds the lock + await asyncio.wait_for(cond2.acquire(), timeout=0.5) + results.append("cond2-acquired-unexpectedly") + except asyncio.TimeoutError: + results.append("cond2-blocked") + + await asyncio.gather(holder(), waiter()) + assert "cond1-acquired" in results + assert "cond2-blocked" in results + assert "cond2-acquired-unexpectedly" not in results @gen_cluster(client=True) From 98f0c559aafb8a5bd58e3b3f2b789056264a52ed Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 26 Dec 2025 19:36:15 -0800 Subject: [PATCH 21/23] Update test_condition.py --- distributed/tests/test_condition.py | 53 ++++++++--------------------- 1 file changed, 15 insertions(+), 38 deletions(-) diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index d43ad51df7..04a9b5ca5b 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -345,56 +345,33 @@ async def test_condition_repr(c, s, a, b): @gen_cluster(client=True) async def test_condition_not_reentrant(c, s, a, b): - """Test that lock is NOT re-entrant - second acquire blocks""" - condition = Condition("not-reentrant") - results = [] - - async def try_reentrant(): - await condition.acquire() - results.append("first-acquired") - - # This should block (not re-entrant) - # We'll use a timeout to detect the block - try: - await asyncio.wait_for(condition.acquire(), timeout=0.5) - results.append("second-acquired") # Should not reach - except asyncio.TimeoutError: - results.append("blocked-as-expected") - - await condition.release() - results.append("released") - - await try_reentrant() - assert results == ["first-acquired", "blocked-as-expected", "released"] - - -@gen_cluster(client=True) -async def test_condition_multiple_instances_same_name_not_reentrant(c, s, a, b): - """Test that two instances 
with same name share lock (not re-entrant)""" - cond1 = Condition("shared") - cond2 = Condition("shared") + """Test that lock is NOT re-entrant within same async task""" + cond1 = Condition("not-reentrant") + cond2 = Condition("not-reentrant") # Same name = same lock results = [] async def holder(): await cond1.acquire() results.append("cond1-acquired") - await asyncio.sleep(0.3) + await asyncio.sleep(0.5) # Hold lock await cond1.release() results.append("cond1-released") async def waiter(): await asyncio.sleep(0.1) # Let holder acquire first - try: - # This should block because cond1 holds the lock - await asyncio.wait_for(cond2.acquire(), timeout=0.5) - results.append("cond2-acquired-unexpectedly") - except asyncio.TimeoutError: - results.append("cond2-blocked") + results.append("cond2-attempting") + # This should block until holder releases + await cond2.acquire() + results.append("cond2-acquired") + await cond2.release() await asyncio.gather(holder(), waiter()) - assert "cond1-acquired" in results - assert "cond2-blocked" in results - assert "cond2-acquired-unexpectedly" not in results + + # Verify order: holder acquires, waiter attempts, holder releases, waiter acquires + assert results[0] == "cond1-acquired" + assert results[1] == "cond2-attempting" + assert results[2] == "cond1-released" + assert results[3] == "cond2-acquired" @gen_cluster(client=True) From 0e49188ca998aea6d0c4425da7169b744be7e8d1 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 4 Jan 2026 20:20:05 -0800 Subject: [PATCH 22/23] Update condition.py,test_condition.py --- distributed/condition.py | 540 +++++++++------------------ distributed/tests/test_condition.py | 560 +++++++++++++++------------- 2 files changed, 479 insertions(+), 621 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 9bb3f49900..29315ebae0 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -3,304 +3,121 @@ import asyncio import logging import uuid -from collections import defaultdict, deque +from collections import defaultdict +from contextlib import suppress -from dask.utils import parse_timedelta - -from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for +from distributed.client import Future +from distributed.lock import Lock +from distributed.utils import log_errors from distributed.worker import get_client logger = logging.getLogger(__name__) class ConditionExtension: - """Scheduler extension managing Condition lock and notifications - - A Condition provides wait/notify synchronization across distributed clients. - The lock is NOT re-entrant - attempting to acquire while holding will block. + """Scheduler extension for managing distributed Conditions - State managed: - - _locks: Which client holds which condition's lock - - _acquire_queue: Clients waiting to acquire lock (FIFO) - - _waiters: Clients in wait() (released lock, awaiting notify) - - _client_conditions: Reverse index for cleanup on disconnect + Tracks waiters for each condition variable and implements notify logic. 
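The mechanism is essentially a per-name registry of asyncio.Event objects, one per waiting client, and notify simply sets up to n of them. A rough standalone sketch of that idea (illustrative only, not the extension's actual code):

>>> import asyncio
>>> from collections import defaultdict
>>> waiters = defaultdict(dict)          # name -> {client_id: asyncio.Event}
>>> async def wait(name, client_id):
...     event = waiters[name][client_id] = asyncio.Event()
...     try:
...         await event.wait()           # parked until a notifier sets the event
...     finally:
...         waiters[name].pop(client_id, None)
>>> def notify(name, n=1):
...     for event in list(waiters[name].values())[:n]:
...         event.set()                  # wake up to n parked waiters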
""" def __init__(self, scheduler): self.scheduler = scheduler - - # {condition_name: client_id} - who holds each lock - self._locks = {} - - # {condition_name: deque[(client_id, future)]} - FIFO queue waiting to acquire - self._acquire_queue = defaultdict(deque) - - # {condition_name: {waiter_id: (client_id, event)}} - clients in wait() - self._waiters = defaultdict(dict) - - # {client_id: set(condition_names)} - for cleanup on disconnect - self._client_conditions = defaultdict(set) + # name -> {client_id -> asyncio.Event} + self.waiters = defaultdict(dict) self.scheduler.handlers.update( { - "condition_acquire": self.acquire, - "condition_release": self.release, "condition_wait": self.wait, "condition_notify": self.notify, "condition_notify_all": self.notify_all, - "condition_locked": self.locked, } ) - self.scheduler.extensions["conditions"] = self - - def _track_client(self, name, client_id): - """Track that a client is using this condition""" - self._client_conditions[client_id].add(name) - - def _untrack_client(self, name, client_id): - """Stop tracking client for this condition""" - if client_id in self._client_conditions: - self._client_conditions[client_id].discard(name) - if not self._client_conditions[client_id]: - del self._client_conditions[client_id] - @log_errors - async def acquire(self, name=None, client_id=None): - """Acquire lock - blocks until available - - NOT re-entrant: if same client tries to acquire while holding, - it will block like any other client. - """ - self._track_client(name, client_id) - - if name not in self._locks: - # Lock is free - self._locks[name] = client_id - return True - - # Lock is held (even if by same client) - must wait - future = asyncio.Future() - self._acquire_queue[name].append((client_id, future)) - - try: - await future - return True - except asyncio.CancelledError: - # Remove from queue if cancelled - queue = self._acquire_queue.get(name, deque()) - try: - queue.remove((client_id, future)) - except ValueError: - pass # Already removed or processed - raise - - def _wake_next_acquirer(self, name): - """Wake the next client waiting to acquire this lock""" - queue = self._acquire_queue.get(name, deque()) - - while queue: - client_id, future = queue.popleft() - if not future.done(): - self._locks[name] = client_id - future.set_result(True) - return True - - # No waiters left, clean up queue - if name in self._acquire_queue: - del self._acquire_queue[name] - return False - - @log_errors - async def release(self, name=None, client_id=None): - """Release lock - - Raises RuntimeError if: - - Lock is not held - - Lock is held by a different client - """ - if name not in self._locks: - raise RuntimeError("release() called without holding the lock") - - if self._locks[name] != client_id: - raise RuntimeError("release() called without holding the lock") - - del self._locks[name] - - # Wake next waiter trying to acquire - if not self._wake_next_acquirer(name): - # No acquire waiters - cleanup if no notify waiters either - if name not in self._waiters or not self._waiters[name]: - self._untrack_client(name, client_id) - - @log_errors - async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None): - """Release lock, wait for notify, reacquire lock - - This is atomic from the caller's perspective. The lock is physically - released to allow notifiers to proceed, but the waiter will reacquire - before returning. 
- - Returns: - - True if notified - - False if timeout occurred - """ - # Verify caller holds lock - if self._locks.get(name) != client_id: - raise RuntimeError("wait() called without holding the lock") - - # 1. Register for notification FIRST (prevents lost wakeup race) - notify_event = asyncio.Event() - self._waiters[name][waiter_id] = (client_id, notify_event) - - # 2. Release lock (allows notifier to proceed) - del self._locks[name] - self._wake_next_acquirer(name) + async def wait(self, name=None, client=None): + """Register a waiter and block until notified""" + # Create event for this specific waiter + event = asyncio.Event() + self.waiters[name][client] = event - # 3. Wait for notification - notified = False try: - if timeout is not None: - await wait_for(notify_event.wait(), timeout) - else: - await notify_event.wait() - notified = True - except (TimeoutError, asyncio.TimeoutError): - notified = False - except asyncio.CancelledError: - # On cancellation, still cleanup and don't reacquire - self._waiters[name].pop(waiter_id, None) - if not self._waiters[name]: - del self._waiters[name] - raise + # Block until notified + await event.wait() finally: - # Always cleanup waiter registration (except CancelledError which raises above) - if waiter_id in self._waiters.get(name, {}): - self._waiters[name].pop(waiter_id, None) - if not self._waiters[name]: - del self._waiters[name] - - # 4. Reacquire lock before returning (will block if others waiting) - await self.acquire(name=name, client_id=client_id) - - return notified - - @log_errors - def notify(self, name=None, client_id=None, n=1): - """Wake up n waiters (default 1) - - Must be called while holding the lock. - Returns number of waiters actually notified. - """ - if self._locks.get(name) != client_id: - raise RuntimeError("notify() called without holding the lock") - - waiters = self._waiters.get(name, {}) - count = 0 - - # Wake first n waiters - for waiter_id in list(waiters.keys())[:n]: - _, event = waiters[waiter_id] - if not event.is_set(): + # Cleanup after waking or cancellation + with suppress(KeyError): + del self.waiters[name][client] + if not self.waiters[name]: + del self.waiters[name] + + async def notify(self, name=None, n=1): + """Wake up n waiters""" + if name not in self.waiters: + return + + # Wake up to n waiters + notified = 0 + for client_id in list(self.waiters[name].keys()): + if notified >= n: + break + event = self.waiters[name].get(client_id) + if event and not event.is_set(): event.set() - count += 1 - - return count - - @log_errors - def notify_all(self, name=None, client_id=None): - """Wake up all waiters - - Must be called while holding the lock. - Returns number of waiters actually notified. 
- """ - if self._locks.get(name) != client_id: - raise RuntimeError("notify_all() called without holding the lock") + notified += 1 - waiters = self._waiters.get(name, {}) - count = 0 + async def notify_all(self, name=None): + """Wake up all waiters""" + if name not in self.waiters: + return - for _, event in waiters.values(): + for event in self.waiters[name].values(): if not event.is_set(): event.set() - count += 1 - return count - def locked(self, name=None, client_id=None): - """Check if this client holds the lock""" - return self._locks.get(name) == client_id - - async def remove_client(self, client): - """Cleanup when client disconnects""" - conditions = self._client_conditions.pop(client, set()) - - for name in conditions: - # Release any locks held by this client - if self._locks.get(name) == client: - try: - await self.release(name=name, client_id=client) - except Exception as e: - logger.warning(f"Error releasing lock for {name}: {e}") - - # Cancel acquire waiters from this client - queue = self._acquire_queue.get(name, deque()) - to_remove = [] - for i, (cid, future) in enumerate(queue): - if cid == client and not future.done(): - future.cancel() - to_remove.append(i) - # Remove in reverse to maintain indices - for i in reversed(to_remove): - try: - del queue[i] - except IndexError: - pass - - # Wake and cleanup notify waiters from this client - waiters = self._waiters.get(name, {}) - for waiter_id, (cid, event) in list(waiters.items()): - if cid == client: - event.set() # Wake them so they can cleanup - waiters.pop(waiter_id, None) - - -class Condition(SyncMethodMixin): +class Condition: """Distributed Condition Variable - Provides wait/notify synchronization across distributed clients. - Multiple Condition instances with the same name share state. + A distributed version of asyncio.Condition. Allows one or more clients + to wait until notified by another client. + + Like asyncio.Condition, this must be used with a lock. The lock is + released before waiting and reacquired afterwards. - The lock is NOT re-entrant. Attempting to acquire while holding - will block indefinitely. + Parameters + ---------- + name : str, optional + Name of the condition. If not provided, a random name is generated. + client : Client, optional + Client instance. If not provided, uses the default client. + lock : Lock, optional + Lock to use with this condition. If not provided, creates a new Lock. Examples -------- - >>> condition = Condition('data-ready') - >>> - >>> # Consumer - >>> async with condition: - ... while not data_available(): - ... await condition.wait() - ... process_data() - >>> - >>> # Producer - >>> async with condition: - ... produce_data() - ... condition.notify_all() - - Notes - ----- - Like threading.Condition, wait() atomically releases the lock and - waits for notification, then reacquires before returning. This means - the condition can change between being notified and reacquiring the lock, - so always use wait() in a while loop checking the actual condition. + >>> from distributed import Client, Condition + >>> client = Client() # doctest: +SKIP + >>> condition = Condition() # doctest: +SKIP + + >>> async with condition: # doctest: +SKIP + ... # Wait for some condition + ... await condition.wait() + + >>> async with condition: # doctest: +SKIP + ... # Modify shared state + ... 
condition.notify() # Wake one waiter """ - def __init__(self, name=None, client=None): - self.name = name or f"condition-{uuid.uuid4().hex}" - self._waiter_id = None # Created fresh for each wait() call + def __init__(self, name=None, client=None, lock=None): self._client = client + self.name = name or "condition-" + uuid.uuid4().hex + + if lock is None: + lock = Lock(client=client) + elif not isinstance(lock, Lock): + raise TypeError(f"lock must be a Lock, not {type(lock)}") + + self._lock = lock @property def client(self): @@ -311,57 +128,43 @@ def client(self): pass return self._client - @property - def _client_id(self): - """Use actual Dask client ID""" - if self.client: - return self.client.id - raise RuntimeError(f"{type(self).__name__} requires a connected client") - - @property - def loop(self): - return self.client.loop - def _verify_running(self): if not self.client: raise RuntimeError( - f"{type(self)} object not properly initialized. " - "This can happen if the object is being deserialized " - "outside of the context of a Client or Worker." + f"{type(self)} object not properly initialized. Ensure it's created within a Client context." ) - async def acquire(self): - """Acquire the lock + async def __aenter__(self): + await self.acquire() + return self - Blocks until the lock is available. NOT re-entrant. - """ + async def __aexit__(self, exc_type, exc, tb): + await self.release() + + def __repr__(self): + return f"" + + async def acquire(self, timeout=None): + """Acquire the underlying lock""" self._verify_running() - await self.client.scheduler.condition_acquire( - name=self.name, client_id=self._client_id - ) + return await self._lock.acquire(timeout=timeout) async def release(self): - """Release the lock - - Raises RuntimeError if not holding the lock. - """ + """Release the underlying lock""" self._verify_running() - await self.client.scheduler.condition_release( - name=self.name, client_id=self._client_id - ) + return await self._lock.release() - async def wait(self, timeout=None): - """Wait for notification + def locked(self): + """Return True if lock is held""" + return self._lock.locked() - Must be called while holding the lock. Atomically releases lock, - waits for notify(), then reacquires lock before returning. + async def wait(self, timeout=None): + """Wait until notified. - Because the lock is released and reacquired, the condition may have - changed by the time this returns. Always use in a while loop: + This method releases the underlying lock, waits until notified, + then reacquires the lock before returning. - >>> async with condition: - ... while not predicate(): - ... await condition.wait() + Must be called with the lock held. 
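An illustrative call pattern showing the boolean timeout result (the `state` dict stands in for whatever shared application state the condition guards and is not part of this patch); the wait_for() helper defined below wraps this re-check loop around a predicate:

>>> async with condition:
...     while not state.get("ready"):
...         if not await condition.wait(timeout=5):
...             break  # gave up after 5 seconds without a notification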
Parameters ---------- @@ -371,94 +174,97 @@ async def wait(self, timeout=None): Returns ------- bool - True if notified, False if timeout occurred + True if woken by notify, False if timeout occurred + + Raises + ------ + RuntimeError + If called without holding the lock """ self._verify_running() - timeout = parse_timedelta(timeout) - # Create fresh waiter_id for this wait() call - waiter_id = uuid.uuid4().hex - - result = await self.client.scheduler.condition_wait( - name=self.name, - waiter_id=waiter_id, - client_id=self._client_id, - timeout=timeout, - ) - - return result + # Release lock before waiting (will error if not held) + await self.release() - def notify(self, n=1): - """Wake up n waiters (default 1) + # Track if we need to reacquire (for proper cleanup on cancellation) + reacquired = False - Must be called while holding the lock. + try: + # Wait for notification from scheduler + if timeout is not None: + try: + await asyncio.wait_for( + self.client.scheduler.condition_wait( + name=self.name, client=self.client.id + ), + timeout=timeout, + ) + return True + except asyncio.TimeoutError: + return False + else: + await self.client.scheduler.condition_wait( + name=self.name, client=self.client.id + ) + return True + finally: + # CRITICAL: Always reacquire lock, even on cancellation + # This maintains the invariant that wait() returns with lock held + try: + await self.acquire() + reacquired = True + except asyncio.CancelledError: + # If reacquisition is cancelled, we're in a bad state + # Try again without allowing cancellation + if not reacquired: + with suppress(Exception): + await asyncio.shield(self.acquire()) + raise + + async def wait_for(self, predicate, timeout=None): + """Wait until a predicate becomes true. Parameters ---------- - n : int - Number of waiters to wake up + predicate : callable + Function that returns True when the condition is met + timeout : float, optional + Maximum time to wait in seconds Returns ------- - int - Number of waiters actually notified + bool + The predicate result (should be True unless timeout) """ - self._verify_running() - return self.client.sync( - self.client.scheduler.condition_notify, - name=self.name, - client_id=self._client_id, - n=n, - ) - - def notify_all(self): - """Wake up all waiters + result = predicate() + while not result: + if timeout is not None: + # Consume timeout across multiple waits + import time + + start = time.time() + if not await self.wait(timeout=timeout): + return predicate() + timeout -= time.time() - start + if timeout <= 0: + return predicate() + else: + await self.wait() + result = predicate() + return result - Must be called while holding the lock. 
+ async def notify(self, n=1): + """Wake up n waiters (default: 1) - Returns - ------- - int - Number of waiters actually notified + Parameters + ---------- + n : int + Number of waiters to wake up """ self._verify_running() - return self.client.sync( - self.client.scheduler.condition_notify_all, - name=self.name, - client_id=self._client_id, - ) - - def locked(self): - """Check if this client holds the lock + await self.client.scheduler.condition_notify(name=self.name, n=n) - Returns - ------- - bool - True if this client currently holds the lock - """ + async def notify_all(self): + """Wake up all waiters""" self._verify_running() - return self.client.sync( - self.client.scheduler.condition_locked, - name=self.name, - client_id=self._client_id, - ) - - async def __aenter__(self): - await self.acquire() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.release() - return False - - def __enter__(self): - return self.sync(self.__aenter__) - - def __exit__(self, exc_type, exc_val, exc_tb): - self.sync(self.__aexit__, exc_type, exc_val, exc_tb) - - def __repr__(self): - return f"" - - def __reduce__(self): - return (Condition, (self.name,)) + await self.client.scheduler.condition_notify_all(name=self.name) diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 04a9b5ca5b..94c6f41406 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -1,37 +1,18 @@ -import asyncio +from __future__ import annotations +import asyncio import pytest -from distributed import Condition +from distributed import Condition, Lock, Client +from distributed.utils_test import gen_cluster, inc from distributed.metrics import time -from distributed.utils_test import gen_cluster @gen_cluster(client=True) -async def test_condition_acquire_release(c, s, a, b): - """Test basic lock acquire/release""" - condition = Condition("test-lock") - assert not condition.locked() - await condition.acquire() - assert condition.locked() - await condition.release() - assert not condition.locked() - +async def test_condition_basic(c, s, a, b): + """Basic condition wait and notify""" + condition = Condition() -@gen_cluster(client=True) -async def test_condition_context_manager(c, s, a, b): - """Test context manager interface""" - condition = Condition("test-context") - assert not condition.locked() - async with condition: - assert condition.locked() - assert not condition.locked() - - -@gen_cluster(client=True) -async def test_condition_wait_notify(c, s, a, b): - """Test basic wait/notify""" - condition = Condition("test-notify") results = [] async def waiter(): @@ -40,345 +21,416 @@ async def waiter(): await condition.wait() results.append("notified") - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - results.append("notifying") - condition.notify() + # Start waiter task + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) # Let waiter acquire lock and start waiting + + assert results == ["waiting"] + + # Notify the waiter + async with condition: + await condition.notify() - await asyncio.gather(waiter(), notifier()) - assert results == ["waiting", "notifying", "notified"] + await waiter_task + assert results == ["waiting", "notified"] @gen_cluster(client=True) -async def test_condition_notify_all(c, s, a, b): - """Test notify_all wakes all waiters""" - condition = Condition("test-notify-all") +async def test_condition_notify_one(c, s, a, b): + """notify() wakes only one waiter""" + 
condition = Condition() + results = [] - async def waiter(i): + async def waiter(n): async with condition: await condition.wait() - results.append(i) + results.append(n) - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - condition.notify_all() + # Start multiple waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(3)] + await asyncio.sleep(0.1) - await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) - assert sorted(results) == [1, 2, 3] + # Notify once - only one should wake + async with condition: + await condition.notify() + + await asyncio.sleep(0.1) + assert len(results) == 1 + + # Notify again + async with condition: + await condition.notify() + + await asyncio.sleep(0.1) + assert len(results) == 2 + + # Cleanup remaining + async with condition: + await condition.notify_all() + + await asyncio.gather(*tasks) + assert len(results) == 3 @gen_cluster(client=True) async def test_condition_notify_n(c, s, a, b): - """Test notify with specific count""" - condition = Condition("test-notify-n") + """notify(n) wakes exactly n waiters""" + condition = Condition() + results = [] - async def waiter(i): + async def waiter(n): async with condition: await condition.wait() - results.append(i) + results.append(n) - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - condition.notify(n=2) # Wake only 2 waiters - await asyncio.sleep(0.2) + # Start 5 waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(5)] + await asyncio.sleep(0.1) + + # Notify 2 + async with condition: + await condition.notify(2) + + await asyncio.sleep(0.1) + assert len(results) == 2 + + # Notify 3 more + async with condition: + await condition.notify(3) + + await asyncio.gather(*tasks) + assert len(results) == 5 + + +@gen_cluster(client=True) +async def test_condition_notify_all(c, s, a, b): + """notify_all() wakes all waiters""" + condition = Condition() + + results = [] + + async def waiter(n): async with condition: - condition.notify() # Wake remaining waiter + await condition.wait() + results.append(n) - await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) - assert sorted(results) == [1, 2, 3] + # Start multiple waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(10)] + await asyncio.sleep(0.1) + + # Wake all at once + async with condition: + await condition.notify_all() + + await asyncio.gather(*tasks) + assert len(results) == 10 @gen_cluster(client=True) async def test_condition_wait_timeout(c, s, a, b): - """Test wait with timeout""" - condition = Condition("test-timeout") + """wait() with timeout returns False if not notified""" + condition = Condition() - start = time() async with condition: - result = await condition.wait(timeout=0.5) - elapsed = time() - start + start = time() + result = await condition.wait(timeout=0.2) + elapsed = time() - start assert result is False - assert 0.4 < elapsed < 0.7 + assert 0.15 < elapsed < 0.5 @gen_cluster(client=True) async def test_condition_wait_timeout_then_notify(c, s, a, b): - """Test that timeout doesn't prevent subsequent notifications""" - condition = Condition("test-timeout-notify") - results = [] - - async def waiter(): - async with condition: - result = await condition.wait(timeout=0.2) - results.append(f"timeout: {result}") - async with condition: - result = await condition.wait() - results.append(f"notified: {result}") + """wait() with timeout returns True if notified before timeout""" + condition = Condition() async def notifier(): - await asyncio.sleep(0.5) + await 
asyncio.sleep(0.1) async with condition: - condition.notify() + await condition.notify() + + notifier_task = asyncio.create_task(notifier()) + + async with condition: + result = await condition.wait(timeout=1.0) + + await notifier_task + assert result is True + + +@gen_cluster(client=True) +async def test_condition_wait_for(c, s, a, b): + """wait_for() waits until predicate is true""" + condition = Condition() + state = {"value": 0} + + async def incrementer(): + for i in range(5): + await asyncio.sleep(0.05) + async with condition: + state["value"] += 1 + await condition.notify_all() + + inc_task = asyncio.create_task(incrementer()) + + async with condition: + result = await condition.wait_for(lambda: state["value"] >= 3) + + await inc_task + assert result is True + assert state["value"] >= 3 + + +@gen_cluster(client=True) +async def test_condition_wait_for_timeout(c, s, a, b): + """wait_for() returns predicate result on timeout""" + condition = Condition() + + async with condition: + result = await condition.wait_for(lambda: False, timeout=0.2) + + assert result is False + + +@gen_cluster(client=True) +async def test_condition_context_manager(c, s, a, b): + """Condition works as async context manager""" + condition = Condition() - await asyncio.gather(waiter(), notifier()) - assert results == ["timeout: False", "notified: True"] + assert not condition.locked() + + async with condition: + assert condition.locked() + + assert not condition.locked() @gen_cluster(client=True) -async def test_condition_error_without_lock(c, s, a, b): - """Test errors when calling wait/notify without holding lock""" - condition = Condition("test-error") +async def test_condition_with_explicit_lock(c, s, a, b): + """Condition can use an explicit Lock""" + lock = Lock() + condition = Condition(lock=lock) - with pytest.raises(RuntimeError, match="without holding the lock"): - await condition.wait() + async with lock: + assert condition.locked() + + assert not condition.locked() - with pytest.raises(RuntimeError, match="without holding the lock"): - condition.notify() - with pytest.raises(RuntimeError, match="without holding the lock"): - condition.notify_all() +@gen_cluster(client=True) +async def test_condition_multiple_notify_calls(c, s, a, b): + """Multiple notify calls work correctly""" + condition = Condition() + results = [] + + async def waiter(n): + async with condition: + await condition.wait() + results.append(n) + + # Start 3 waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(3)] + await asyncio.sleep(0.1) + + # Notify one at a time + for _ in range(3): + async with condition: + await condition.notify() + await asyncio.sleep(0.05) + + await asyncio.gather(*tasks) + assert set(results) == {0, 1, 2} @gen_cluster(client=True) -async def test_condition_error_release_without_acquire(c, s, a, b): - """Test error when releasing without acquiring""" - condition = Condition("test-release-error") +async def test_condition_notify_without_waiters(c, s, a, b): + """notify() with no waiters is a no-op""" + condition = Condition() - with pytest.raises(RuntimeError, match="without holding the lock"): - await condition.release() + async with condition: + await condition.notify() + await condition.notify_all() + await condition.notify(5) + + # Should not raise or hang @gen_cluster(client=True) async def test_condition_producer_consumer(c, s, a, b): - """Test classic producer-consumer pattern""" - condition = Condition("prod-cons") + """Producer-consumer pattern with Condition""" + condition = Condition() 
queue = [] async def producer(): for i in range(5): - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) async with condition: queue.append(i) - condition.notify() + await condition.notify() async def consumer(): - results = [] + items = [] for _ in range(5): async with condition: - while not queue: - await condition.wait() - results.append(queue.pop(0)) - return results + await condition.wait_for(lambda: len(queue) > 0) + items.append(queue.pop(0)) + return items prod_task = asyncio.create_task(producer()) cons_task = asyncio.create_task(consumer()) + result = await cons_task await prod_task - results = await cons_task - assert results == [0, 1, 2, 3, 4] + + assert result == [0, 1, 2, 3, 4] @gen_cluster(client=True) -async def test_condition_multiple_producers_consumers(c, s, a, b): - """Test multiple producers and consumers""" - condition = Condition("multi-prod-cons") - queue = [] +async def test_condition_same_name(c, s, a, b): + """Multiple Condition instances with same name share state""" + cond1 = Condition(name="shared") + cond2 = Condition(name="shared") - async def producer(start): - for i in range(start, start + 3): - await asyncio.sleep(0.05) - async with condition: - queue.append(i) - condition.notify() + result = [] - async def consumer(): - results = [] - for _ in range(3): - async with condition: - while not queue: - await condition.wait() - results.append(queue.pop(0)) - return results + async def waiter(): + async with cond1: + await cond1.wait() + result.append("done") + + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) + + # Notify via different instance + async with cond2: + await cond2.notify() - results = await asyncio.gather(producer(0), producer(10), consumer(), consumer()) - # Last two results are from consumers - consumed = results[2] + results[3] - assert sorted(consumed) == [0, 1, 2, 10, 11, 12] + await waiter_task + assert result == ["done"] @gen_cluster(client=True) -async def test_condition_same_name_different_instances(c, s, a, b): - """Test that multiple instances with same name share state""" - name = "shared-condition" - cond1 = Condition(name) - cond2 = Condition(name) - results = [] +async def test_condition_repr(c, s, a, b): + """Condition has readable repr""" + condition = Condition(name="test-cond") + assert "test-cond" in repr(condition) - async def waiter(): - async with cond1: - results.append("waiting") - await cond1.wait() - results.append("notified") - async def notifier(): - await asyncio.sleep(0.2) - async with cond2: - results.append("notifying") - cond2.notify() +@gen_cluster(client=True) +async def test_condition_waiter_cancelled(c, s, a, b): + """Cancelled waiter properly cleans up""" + condition = Condition() - await asyncio.gather(waiter(), notifier()) - assert results == ["waiting", "notifying", "notified"] + async def waiter(): + async with condition: + await condition.wait() + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) -@gen_cluster(client=True) -async def test_condition_unique_names_independent(c, s, a, b): - """Test conditions with different names are independent""" - cond1 = Condition("cond-1") - cond2 = Condition("cond-2") + # Cancel the waiter + task.cancel() - async with cond1: - assert cond1.locked() - assert not cond2.locked() + with pytest.raises(asyncio.CancelledError): + await task - async with cond2: - assert not cond1.locked() - assert cond2.locked() + # Should not deadlock - lock should be released + async with condition: + await condition.notify() # No-op since no waiters 
-@gen_cluster(client=True) -async def test_condition_barrier_pattern(c, s, a, b): - """Test barrier synchronization pattern""" - condition = Condition("barrier") - arrived = [] - n_workers = 3 +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_condition_across_workers(c, s, a): + """Condition works across different worker tasks""" + condition = Condition(name="cross-worker") - async def worker(i): - async with condition: - arrived.append(i) - if len(arrived) < n_workers: - await condition.wait() - else: - condition.notify_all() - return f"worker-{i}-done" + def worker_wait(): + import asyncio + from distributed import Condition - results = await asyncio.gather(worker(0), worker(1), worker(2)) - assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"] - assert len(arrived) == 3 + async def wait_task(): + cond = Condition(name="cross-worker") + async with cond: + await cond.wait() + return "notified" + return asyncio.run(wait_task()) -def test_condition_sync_interface(client): - """Test synchronous interface via SyncMethodMixin""" - condition = Condition("sync-test") - results = [] + # Submit waiter to worker + future = c.submit(worker_wait, pure=False) + await asyncio.sleep(0.2) - def worker(): - with condition: - results.append("locked") - results.append("released") + # Notify from client + async with condition: + await condition.notify() - worker() - assert results == ["locked", "released"] + result = await future + assert result == "notified" @gen_cluster(client=True) -async def test_condition_multiple_notify_calls(c, s, a, b): - """Test multiple notify calls in sequence""" - condition = Condition("multi-notify") - results = [] +async def test_condition_locked_status(c, s, a, b): + """locked() returns correct status""" + condition = Condition() - async def waiter(i): - async with condition: - await condition.wait() - results.append(i) + assert not condition.locked() - async def notifier(): - await asyncio.sleep(0.2) - async with condition: - condition.notify() - await asyncio.sleep(0.1) - async with condition: - condition.notify() - await asyncio.sleep(0.1) - async with condition: - condition.notify() + await condition.acquire() + assert condition.locked() - await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) - assert sorted(results) == [1, 2, 3] + await condition.release() + assert not condition.locked() @gen_cluster(client=True) -async def test_condition_predicate_loop(c, s, a, b): - """Test typical predicate-based wait loop pattern""" - condition = Condition("predicate") - state = {"value": 0, "target": 5} +async def test_condition_reacquire_after_wait(c, s, a, b): + """wait() reacquires lock after being notified""" + condition = Condition() + lock_states = [] async def waiter(): async with condition: - while state["value"] < state["target"]: - await condition.wait() - return state["value"] + lock_states.append(("before_wait", condition.locked())) + await condition.wait() + lock_states.append(("after_wait", condition.locked())) - async def updater(): - for i in range(1, 6): - await asyncio.sleep(0.1) - async with condition: - state["value"] = i - condition.notify_all() + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) - result, _ = await asyncio.gather(waiter(), updater()) - assert result == 5 + async with condition: + lock_states.append(("notifier", condition.locked())) + await condition.notify() + await waiter_task -@gen_cluster(client=True) -async def test_condition_repr(c, s, a, b): - """Test string 
representation""" - condition = Condition("test-repr") - assert "test-repr" in repr(condition) - assert "Condition" in repr(condition) + # Waiter should hold lock before and after wait + assert lock_states[0] == ("before_wait", True) + assert lock_states[1] == ("notifier", True) + assert lock_states[2] == ("after_wait", True) @gen_cluster(client=True) -async def test_condition_not_reentrant(c, s, a, b): - """Test that lock is NOT re-entrant within same async task""" - cond1 = Condition("not-reentrant") - cond2 = Condition("not-reentrant") # Same name = same lock +async def test_condition_many_waiters(c, s, a, b): + """Handle many waiters efficiently""" + condition = Condition() results = [] - async def holder(): - await cond1.acquire() - results.append("cond1-acquired") - await asyncio.sleep(0.5) # Hold lock - await cond1.release() - results.append("cond1-released") - - async def waiter(): - await asyncio.sleep(0.1) # Let holder acquire first - results.append("cond2-attempting") - # This should block until holder releases - await cond2.acquire() - results.append("cond2-acquired") - await cond2.release() - - await asyncio.gather(holder(), waiter()) - - # Verify order: holder acquires, waiter attempts, holder releases, waiter acquires - assert results[0] == "cond1-acquired" - assert results[1] == "cond2-attempting" - assert results[2] == "cond1-released" - assert results[3] == "cond2-acquired" + async def waiter(n): + async with condition: + await condition.wait() + results.append(n) + # Start many waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(50)] + await asyncio.sleep(0.2) -@gen_cluster(client=True) -async def test_condition_multiple_instances_share_client_id(c, s, a, b): - """Test that multiple Condition instances in same client share client_id""" - cond1 = Condition("test-1") - cond2 = Condition("test-2") + # Wake them all + async with condition: + await condition.notify_all() - # Both should have same client ID (the client 'c') - assert cond1._client_id == cond2._client_id == c.id + await asyncio.gather(*tasks) + assert len(results) == 50 + assert set(results) == set(range(50)) From 055fb8ba9d5058d084e8c6d1e23bf993159f5238 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Sun, 4 Jan 2026 20:37:42 -0800 Subject: [PATCH 23/23] Update condition.py,test_condition.py --- distributed/condition.py | 1 - distributed/tests/test_condition.py | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/distributed/condition.py b/distributed/condition.py index 29315ebae0..7df6020a17 100644 --- a/distributed/condition.py +++ b/distributed/condition.py @@ -6,7 +6,6 @@ from collections import defaultdict from contextlib import suppress -from distributed.client import Future from distributed.lock import Lock from distributed.utils import log_errors from distributed.worker import get_client diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py index 94c6f41406..3e57d47a2d 100644 --- a/distributed/tests/test_condition.py +++ b/distributed/tests/test_condition.py @@ -1,11 +1,12 @@ from __future__ import annotations import asyncio + import pytest -from distributed import Condition, Lock, Client -from distributed.utils_test import gen_cluster, inc +from distributed import Condition, Lock from distributed.metrics import time +from distributed.utils_test import gen_cluster @gen_cluster(client=True) @@ -168,7 +169,7 @@ async def test_condition_wait_for(c, s, a, b): state = {"value": 0} async def incrementer(): - for i in range(5): + for _i in 
range(5): await asyncio.sleep(0.05) async with condition: state["value"] += 1 @@ -349,6 +350,7 @@ async def test_condition_across_workers(c, s, a): def worker_wait(): import asyncio + from distributed import Condition async def wait_task():