diff --git a/continuous_integration/scripts/test_report.py b/continuous_integration/scripts/test_report.py
index 1f749a78fd..1f339ed7ef 100644
--- a/continuous_integration/scripts/test_report.py
+++ b/continuous_integration/scripts/test_report.py
@@ -489,7 +489,7 @@ def main(argv: list[str] | None = None) -> None:
         total.groupby([total.file, total.test])
         .filter(lambda g: (g.status == "x").sum() >= args.nfails)
         .reset_index()
-        .assign(test=lambda df: df.file + "." + df.test)  # type: ignore
+        .assign(test=lambda df: df.file + "." + df.test)
         .groupby("test")
     )
     overall = {name: grouped.get_group(name) for name in grouped.groups}
diff --git a/distributed/__init__.py b/distributed/__init__.py
index 3f075b977c..091c14e0eb 100644
--- a/distributed/__init__.py
+++ b/distributed/__init__.py
@@ -146,3 +146,4 @@
     "widgets",
     "worker_client",
 ]
+from distributed.condition import Condition
diff --git a/distributed/condition.py b/distributed/condition.py
new file mode 100644
index 0000000000..9bb3f49900
--- /dev/null
+++ b/distributed/condition.py
@@ -0,0 +1,464 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import uuid
+from collections import defaultdict, deque
+
+from dask.utils import parse_timedelta
+
+from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for
+from distributed.worker import get_client
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionExtension:
+    """Scheduler extension managing Condition lock and notifications
+
+    A Condition provides wait/notify synchronization across distributed clients.
+    The lock is NOT re-entrant - attempting to acquire while holding will block.
+
+    State managed:
+    - _locks: Which client holds which condition's lock
+    - _acquire_queue: Clients waiting to acquire lock (FIFO)
+    - _waiters: Clients in wait() (released lock, awaiting notify)
+    - _client_conditions: Reverse index for cleanup on disconnect
+    """
+
+    def __init__(self, scheduler):
+        self.scheduler = scheduler
+
+        # {condition_name: client_id} - who holds each lock
+        self._locks = {}
+
+        # {condition_name: deque[(client_id, future)]} - FIFO queue waiting to acquire
+        self._acquire_queue = defaultdict(deque)
+
+        # {condition_name: {waiter_id: (client_id, event)}} - clients in wait()
+        self._waiters = defaultdict(dict)
+
+        # {client_id: set(condition_names)} - for cleanup on disconnect
+        self._client_conditions = defaultdict(set)
+
+        self.scheduler.handlers.update(
+            {
+                "condition_acquire": self.acquire,
+                "condition_release": self.release,
+                "condition_wait": self.wait,
+                "condition_notify": self.notify,
+                "condition_notify_all": self.notify_all,
+                "condition_locked": self.locked,
+            }
+        )
+
+        self.scheduler.extensions["conditions"] = self
+
+    def _track_client(self, name, client_id):
+        """Track that a client is using this condition"""
+        self._client_conditions[client_id].add(name)
+
+    def _untrack_client(self, name, client_id):
+        """Stop tracking client for this condition"""
+        if client_id in self._client_conditions:
+            self._client_conditions[client_id].discard(name)
+            if not self._client_conditions[client_id]:
+                del self._client_conditions[client_id]
+
+    @log_errors
+    async def acquire(self, name=None, client_id=None):
+        """Acquire lock - blocks until available
+
+        NOT re-entrant: if same client tries to acquire while holding,
+        it will block like any other client.
+        """
+        self._track_client(name, client_id)
+
+        if name not in self._locks:
+            # Lock is free
+            self._locks[name] = client_id
+            return True
+
+        # Lock is held (even if by same client) - must wait
+        future = asyncio.Future()
+        self._acquire_queue[name].append((client_id, future))
+
+        try:
+            await future
+            return True
+        except asyncio.CancelledError:
+            # Remove from queue if cancelled
+            queue = self._acquire_queue.get(name, deque())
+            try:
+                queue.remove((client_id, future))
+            except ValueError:
+                pass  # Already removed or processed
+            raise
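+
+    # Illustrative note (not part of the API): acquisition is FIFO, because
+    # _acquire_queue is a deque consumed with popleft().  For example:
+    #
+    #   client A: acquire()  -> holds the lock
+    #   client B: acquire()  -> queued
+    #   client C: acquire()  -> queued behind B
+    #   client A: release()  -> B's future resolves; B now holds the lock
+    #   client B: release()  -> C's future resolves; C now holds the lock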
+
+    def _wake_next_acquirer(self, name):
+        """Wake the next client waiting to acquire this lock"""
+        queue = self._acquire_queue.get(name, deque())
+
+        while queue:
+            client_id, future = queue.popleft()
+            if not future.done():
+                self._locks[name] = client_id
+                future.set_result(True)
+                return True
+
+        # No waiters left, clean up queue
+        if name in self._acquire_queue:
+            del self._acquire_queue[name]
+        return False
+
+    @log_errors
+    async def release(self, name=None, client_id=None):
+        """Release lock
+
+        Raises RuntimeError if:
+        - Lock is not held
+        - Lock is held by a different client
+        """
+        if name not in self._locks:
+            raise RuntimeError("release() called without holding the lock")
+
+        if self._locks[name] != client_id:
+            raise RuntimeError("release() called without holding the lock")
+
+        del self._locks[name]
+
+        # Wake next waiter trying to acquire
+        if not self._wake_next_acquirer(name):
+            # No acquire waiters - cleanup if no notify waiters either
+            if name not in self._waiters or not self._waiters[name]:
+                self._untrack_client(name, client_id)
+
+    @log_errors
+    async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None):
+        """Release lock, wait for notify, reacquire lock
+
+        This is atomic from the caller's perspective. The lock is physically
+        released to allow notifiers to proceed, but the waiter will reacquire
+        before returning.
+
+        Returns:
+        - True if notified
+        - False if timeout occurred
+        """
+        # Verify caller holds lock
+        if self._locks.get(name) != client_id:
+            raise RuntimeError("wait() called without holding the lock")
+
+        # 1. Register for notification FIRST (prevents lost wakeup race)
+        notify_event = asyncio.Event()
+        self._waiters[name][waiter_id] = (client_id, notify_event)
+
+        # 2. Release lock (allows notifier to proceed)
+        del self._locks[name]
+        self._wake_next_acquirer(name)
+
+        # 3. Wait for notification
+        notified = False
+        try:
+            if timeout is not None:
+                await wait_for(notify_event.wait(), timeout)
+            else:
+                await notify_event.wait()
+            notified = True
+        except (TimeoutError, asyncio.TimeoutError):
+            notified = False
+        except asyncio.CancelledError:
+            # On cancellation, still cleanup and don't reacquire
+            self._waiters[name].pop(waiter_id, None)
+            if not self._waiters[name]:
+                del self._waiters[name]
+            raise
+        finally:
+            # Always clean up the waiter registration (the CancelledError
+            # branch above has already removed it before re-raising)
+            if waiter_id in self._waiters.get(name, {}):
+                self._waiters[name].pop(waiter_id, None)
+                if not self._waiters[name]:
+                    del self._waiters[name]
+
+        # 4. Reacquire lock before returning (will block if others waiting)
+        await self.acquire(name=name, client_id=client_id)
+
+        return notified
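+
+    # Illustrative timeline (sketch) of the lost-wakeup race that step 1
+    # prevents, i.e. what could happen if the waiter registered only *after*
+    # releasing the lock:
+    #
+    #   waiter:   release lock
+    #   notifier: acquire lock, notify()   -> sees no waiters, wakeup is lost
+    #   waiter:   register event, await it -> blocks forever
+    #
+    # Registering the event before releasing the lock means a notify() that
+    # runs in that window still sets the event.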
+
+    @log_errors
+    def notify(self, name=None, client_id=None, n=1):
+        """Wake up n waiters (default 1)
+
+        Must be called while holding the lock.
+        Returns number of waiters actually notified.
+        """
+        if self._locks.get(name) != client_id:
+            raise RuntimeError("notify() called without holding the lock")
+
+        waiters = self._waiters.get(name, {})
+        count = 0
+
+        # Wake first n waiters
+        for waiter_id in list(waiters.keys())[:n]:
+            _, event = waiters[waiter_id]
+            if not event.is_set():
+                event.set()
+                count += 1
+
+        return count
+
+    @log_errors
+    def notify_all(self, name=None, client_id=None):
+        """Wake up all waiters
+
+        Must be called while holding the lock.
+        Returns number of waiters actually notified.
+        """
+        if self._locks.get(name) != client_id:
+            raise RuntimeError("notify_all() called without holding the lock")
+
+        waiters = self._waiters.get(name, {})
+        count = 0
+
+        for _, event in waiters.values():
+            if not event.is_set():
+                event.set()
+                count += 1
+
+        return count
+
+    def locked(self, name=None, client_id=None):
+        """Check if this client holds the lock"""
+        return self._locks.get(name) == client_id
+
+    async def remove_client(self, client):
+        """Cleanup when client disconnects"""
+        conditions = self._client_conditions.pop(client, set())
+
+        for name in conditions:
+            # Release any locks held by this client
+            if self._locks.get(name) == client:
+                try:
+                    await self.release(name=name, client_id=client)
+                except Exception as e:
+                    logger.warning(f"Error releasing lock for {name}: {e}")
+
+            # Cancel acquire waiters from this client
+            queue = self._acquire_queue.get(name, deque())
+            to_remove = []
+            for i, (cid, future) in enumerate(queue):
+                if cid == client and not future.done():
+                    future.cancel()
+                    to_remove.append(i)
+            # Remove in reverse to maintain indices
+            for i in reversed(to_remove):
+                try:
+                    del queue[i]
+                except IndexError:
+                    pass
+
+            # Wake and clean up notify waiters from this client
+            waiters = self._waiters.get(name, {})
+            for waiter_id, (cid, event) in list(waiters.items()):
+                if cid == client:
+                    event.set()  # Wake them so they can clean up
+                    waiters.pop(waiter_id, None)
+
+
+class Condition(SyncMethodMixin):
+    """Distributed Condition Variable
+
+    Provides wait/notify synchronization across distributed clients.
+    Multiple Condition instances with the same name share state.
+
+    The lock is NOT re-entrant. Attempting to acquire while holding
+    will block indefinitely.
+
+    Examples
+    --------
+    >>> condition = Condition('data-ready')
+    >>>
+    >>> # Consumer
+    >>> async with condition:
+    ...     while not data_available():
+    ...         await condition.wait()
+    ...     process_data()
+    >>>
+    >>> # Producer
+    >>> async with condition:
+    ...     produce_data()
+    ...     await condition.notify_all()
+
+    Notes
+    -----
+    Like threading.Condition, wait() atomically releases the lock and
+    waits for notification, then reacquires before returning. This means
+    the condition can change between being notified and reacquiring the lock,
+    so always use wait() in a while loop checking the actual condition.
+
+    With an asynchronous client, notify(), notify_all() and locked() go
+    through Client.sync and return coroutines that must be awaited; with
+    a synchronous client they block.
+    """
+
+    def __init__(self, name=None, client=None):
+        self.name = name or f"condition-{uuid.uuid4().hex}"
+        self._client = client
+
+    @property
+    def client(self):
+        if not self._client:
+            try:
+                self._client = get_client()
+            except ValueError:
+                pass
+        return self._client
+
+    @property
+    def _client_id(self):
+        """Use actual Dask client ID"""
+        if self.client:
+            return self.client.id
+        raise RuntimeError(f"{type(self).__name__} requires a connected client")
+
+    @property
+    def loop(self):
+        return self.client.loop
+
+    def _verify_running(self):
+        if not self.client:
+            raise RuntimeError(
+                f"{type(self)} object not properly initialized. "
+                "This can happen if the object is being deserialized "
+                "outside of the context of a Client or Worker."
+            )
+
+    async def acquire(self):
+        """Acquire the lock
+
+        Blocks until the lock is available. NOT re-entrant.
+        """
+        self._verify_running()
+        await self.client.scheduler.condition_acquire(
+            name=self.name, client_id=self._client_id
+        )
+
+    async def release(self):
+        """Release the lock
+
+        Raises RuntimeError if not holding the lock.
+        """
+        self._verify_running()
+        await self.client.scheduler.condition_release(
+            name=self.name, client_id=self._client_id
+        )
+
+    async def wait(self, timeout=None):
+        """Wait for notification
+
+        Must be called while holding the lock. Atomically releases lock,
+        waits for notify(), then reacquires lock before returning.
+
+        Because the lock is released and reacquired, the condition may have
+        changed by the time this returns. Always use in a while loop:
+
+        >>> async with condition:
+        ...     while not predicate():
+        ...         await condition.wait()
+
+        Parameters
+        ----------
+        timeout : float, optional
+            Maximum time to wait in seconds. If None, wait indefinitely.
+
+        Returns
+        -------
+        bool
+            True if notified, False if timeout occurred
+        """
+        self._verify_running()
+        timeout = parse_timedelta(timeout)
+
+        # Create a fresh waiter_id for this wait() call
+        waiter_id = uuid.uuid4().hex
+
+        result = await self.client.scheduler.condition_wait(
+            name=self.name,
+            waiter_id=waiter_id,
+            client_id=self._client_id,
+            timeout=timeout,
+        )
+
+        return result
+
+    def notify(self, n=1):
+        """Wake up n waiters (default 1)
+
+        Must be called while holding the lock. With an asynchronous client
+        the returned coroutine must be awaited.
+
+        Parameters
+        ----------
+        n : int
+            Number of waiters to wake up
+
+        Returns
+        -------
+        int
+            Number of waiters actually notified
+        """
+        self._verify_running()
+        return self.client.sync(
+            self.client.scheduler.condition_notify,
+            name=self.name,
+            client_id=self._client_id,
+            n=n,
+        )
+
+    def notify_all(self):
+        """Wake up all waiters
+
+        Must be called while holding the lock. With an asynchronous client
+        the returned coroutine must be awaited.
+
+        Returns
+        -------
+        int
+            Number of waiters actually notified
+        """
+        self._verify_running()
+        return self.client.sync(
+            self.client.scheduler.condition_notify_all,
+            name=self.name,
+            client_id=self._client_id,
+        )
+
+    def locked(self):
+        """Check if this client holds the lock
+
+        Returns
+        -------
+        bool
+            True if this client currently holds the lock
+        """
+        self._verify_running()
+        return self.client.sync(
+            self.client.scheduler.condition_locked,
+            name=self.name,
+            client_id=self._client_id,
+        )
+
+    async def __aenter__(self):
+        await self.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.release()
+        return False
+
+    def __enter__(self):
+        return self.sync(self.__aenter__)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.sync(self.__aexit__, exc_type, exc_val, exc_tb)
+
+    def __repr__(self):
+        return f"<Condition: {self.name}>"
+
+    def __reduce__(self):
+        return (Condition, (self.name,))
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index ea5775aea6..07951263f4 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -97,6 +97,7 @@
 )
 from distributed.comm.addressing import addresses_from_user_args
 from distributed.compatibility import PeriodicCallback
+from distributed.condition import ConditionExtension
 from distributed.core import (
     ErrorMessage,
     OKMessage,
@@ -195,6 +196,7 @@
     "semaphores": SemaphoreExtension,
     "events": EventExtension,
     "amm": ActiveMemoryManagerExtension,
+    "conditions": ConditionExtension,
     "memory_sampler": MemorySamplerExtension,
     "shuffle": ShuffleSchedulerPlugin,
     "spans": SpansSchedulerExtension,
diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py
new file mode 100644
index 0000000000..04a9b5ca5b
--- /dev/null
+++ b/distributed/tests/test_condition.py
@@ -0,0 +1,384 @@
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+
+from distributed import Condition
+from distributed.metrics import time
+from distributed.utils_test import client  # noqa: F401
+from distributed.utils_test import cluster_fixture  # noqa: F401
+from distributed.utils_test import gen_cluster
+from distributed.utils_test import loop  # noqa: F401
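+
+# The async tests below use an asynchronous client, so notify(), notify_all()
+# and locked() (which go through Client.sync) return coroutines and are
+# awaited; in the synchronous test they are called directly and block.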
+
+
+@gen_cluster(client=True)
+async def test_condition_acquire_release(c, s, a, b):
+    """Test basic lock acquire/release"""
+    condition = Condition("test-lock")
+    assert not await condition.locked()
+    await condition.acquire()
+    assert await condition.locked()
+    await condition.release()
+    assert not await condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_context_manager(c, s, a, b):
+    """Test context manager interface"""
+    condition = Condition("test-context")
+    assert not await condition.locked()
+    async with condition:
+        assert await condition.locked()
+    assert not await condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_notify(c, s, a, b):
+    """Test basic wait/notify"""
+    condition = Condition("test-notify")
+    results = []
+
+    async def waiter():
+        async with condition:
+            results.append("waiting")
+            await condition.wait()
+            results.append("notified")
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            results.append("notifying")
+            await condition.notify()
+
+    await asyncio.gather(waiter(), notifier())
+    assert results == ["waiting", "notifying", "notified"]
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_all(c, s, a, b):
+    """Test notify_all wakes all waiters"""
+    condition = Condition("test-notify-all")
+    results = []
+
+    async def waiter(i):
+        async with condition:
+            await condition.wait()
+            results.append(i)
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            await condition.notify_all()
+
+    await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier())
+    assert sorted(results) == [1, 2, 3]
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_n(c, s, a, b):
+    """Test notify with specific count"""
+    condition = Condition("test-notify-n")
+    results = []
+
+    async def waiter(i):
+        async with condition:
+            await condition.wait()
+            results.append(i)
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            await condition.notify(n=2)  # Wake only 2 waiters
+        await asyncio.sleep(0.2)
+        async with condition:
+            await condition.notify()  # Wake remaining waiter
+
+    await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier())
+    assert sorted(results) == [1, 2, 3]
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_timeout(c, s, a, b):
+    """Test wait with timeout"""
+    condition = Condition("test-timeout")
+
+    start = time()
+    async with condition:
+        result = await condition.wait(timeout=0.5)
+    elapsed = time() - start
+
+    assert result is False
+    assert 0.4 < elapsed < 0.7
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_timeout_then_notify(c, s, a, b):
+    """Test that timeout doesn't prevent subsequent notifications"""
+    condition = Condition("test-timeout-notify")
+    results = []
+
+    async def waiter():
+        async with condition:
+            result = await condition.wait(timeout=0.2)
+            results.append(f"timeout: {result}")
+        async with condition:
+            result = await condition.wait()
+            results.append(f"notified: {result}")
+
+    async def notifier():
+        await asyncio.sleep(0.5)
+        async with condition:
+            await condition.notify()
+
+    await asyncio.gather(waiter(), notifier())
+    assert results == ["timeout: False", "notified: True"]
+
+
+@gen_cluster(client=True)
+async def test_condition_error_without_lock(c, s, a, b):
+    """Test errors when calling wait/notify without holding the lock"""
+    condition = Condition("test-error")
+
+    with pytest.raises(RuntimeError, match="without holding the lock"):
+        await condition.wait()
+
+    with pytest.raises(RuntimeError, match="without holding the lock"):
+        await condition.notify()
+
+    with pytest.raises(RuntimeError, match="without holding the lock"):
+        await condition.notify_all()
+
+
+@gen_cluster(client=True)
+async def test_condition_error_release_without_acquire(c, s, a, b):
+    """Test error when releasing without acquiring"""
+    condition = Condition("test-release-error")
+
+    with pytest.raises(RuntimeError, match="without holding the lock"):
+        await condition.release()
+
+
+@gen_cluster(client=True)
+async def test_condition_producer_consumer(c, s, a, b):
+    """Test classic producer-consumer pattern"""
+    condition = Condition("prod-cons")
+    queue = []
+
+    async def producer():
+        for i in range(5):
+            await asyncio.sleep(0.1)
+            async with condition:
+                queue.append(i)
+                await condition.notify()
+
+    async def consumer():
+        results = []
+        for _ in range(5):
+            async with condition:
+                while not queue:
+                    await condition.wait()
+                results.append(queue.pop(0))
+        return results
+
+    prod_task = asyncio.create_task(producer())
+    cons_task = asyncio.create_task(consumer())
+
+    await prod_task
+    results = await cons_task
+    assert results == [0, 1, 2, 3, 4]
+
+
+@gen_cluster(client=True)
+async def test_condition_multiple_producers_consumers(c, s, a, b):
+    """Test multiple producers and consumers"""
+    condition = Condition("multi-prod-cons")
+    queue = []
+
+    async def producer(start):
+        for i in range(start, start + 3):
+            await asyncio.sleep(0.05)
+            async with condition:
+                queue.append(i)
+                await condition.notify()
+
+    async def consumer():
+        results = []
+        for _ in range(3):
+            async with condition:
+                while not queue:
+                    await condition.wait()
+                results.append(queue.pop(0))
+        return results
+
+    results = await asyncio.gather(producer(0), producer(10), consumer(), consumer())
+    # The last two results are from the consumers
+    consumed = results[2] + results[3]
+    assert sorted(consumed) == [0, 1, 2, 10, 11, 12]
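+
+
+# Illustrative extra check (sketch): Condition() without an explicit name
+# generates a unique "condition-<uuid>" name, so unnamed conditions are
+# independent of each other.
+@gen_cluster(client=True)
+async def test_condition_default_name_unique(c, s, a, b):
+    cond1 = Condition()
+    cond2 = Condition()
+    assert cond1.name != cond2.name
+    assert cond1.name.startswith("condition-")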
+
+
+@gen_cluster(client=True)
+async def test_condition_same_name_different_instances(c, s, a, b):
+    """Test that multiple instances with same name share state"""
+    name = "shared-condition"
+    cond1 = Condition(name)
+    cond2 = Condition(name)
+    results = []
+
+    async def waiter():
+        async with cond1:
+            results.append("waiting")
+            await cond1.wait()
+            results.append("notified")
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with cond2:
+            results.append("notifying")
+            await cond2.notify()
+
+    await asyncio.gather(waiter(), notifier())
+    assert results == ["waiting", "notifying", "notified"]
+
+
+@gen_cluster(client=True)
+async def test_condition_unique_names_independent(c, s, a, b):
+    """Test conditions with different names are independent"""
+    cond1 = Condition("cond-1")
+    cond2 = Condition("cond-2")
+
+    async with cond1:
+        assert await cond1.locked()
+        assert not await cond2.locked()
+
+    async with cond2:
+        assert not await cond1.locked()
+        assert await cond2.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_barrier_pattern(c, s, a, b):
+    """Test barrier synchronization pattern"""
+    condition = Condition("barrier")
+    arrived = []
+    n_workers = 3
+
+    async def worker(i):
+        async with condition:
+            arrived.append(i)
+            if len(arrived) < n_workers:
+                await condition.wait()
+            else:
+                await condition.notify_all()
+        return f"worker-{i}-done"
+
+    results = await asyncio.gather(worker(0), worker(1), worker(2))
+    assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"]
+    assert len(arrived) == 3
+
+
+def test_condition_sync_interface(client):
+    """Test synchronous interface via SyncMethodMixin"""
+    condition = Condition("sync-test")
+    results = []
+
+    def worker():
+        with condition:
+            results.append("locked")
+        results.append("released")
+
+    worker()
+    assert results == ["locked", "released"]
+
+
+@gen_cluster(client=True)
+async def test_condition_multiple_notify_calls(c, s, a, b):
+    """Test multiple notify calls in sequence"""
+    condition = Condition("multi-notify")
+    results = []
+
+    async def waiter(i):
+        async with condition:
+            await condition.wait()
+            results.append(i)
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            await condition.notify()
+        await asyncio.sleep(0.1)
+        async with condition:
+            await condition.notify()
+        await asyncio.sleep(0.1)
+        async with condition:
+            await condition.notify()
+
+    await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier())
+    assert sorted(results) == [1, 2, 3]
+
+
+@gen_cluster(client=True)
+async def test_condition_predicate_loop(c, s, a, b):
+    """Test typical predicate-based wait loop pattern"""
+    condition = Condition("predicate")
+    state = {"value": 0, "target": 5}
+
+    async def waiter():
+        async with condition:
+            while state["value"] < state["target"]:
+                await condition.wait()
+            return state["value"]
+
+    async def updater():
+        for i in range(1, 6):
+            await asyncio.sleep(0.1)
+            async with condition:
+                state["value"] = i
+                await condition.notify_all()
+
+    result, _ = await asyncio.gather(waiter(), updater())
+    assert result == 5
+
+
+@gen_cluster(client=True)
+async def test_condition_repr(c, s, a, b):
+    """Test string representation"""
+    condition = Condition("test-repr")
+    assert "test-repr" in repr(condition)
+    assert "Condition" in repr(condition)
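+
+
+# Illustrative extra check (sketch): wait() reacquires the lock even after a
+# timeout, so the caller still holds the lock when wait() returns False.
+@gen_cluster(client=True)
+async def test_condition_wait_timeout_reacquires_lock(c, s, a, b):
+    condition = Condition("timeout-reacquire")
+    async with condition:
+        result = await condition.wait(timeout=0.1)
+        assert result is False
+        assert await condition.locked()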
+
+
+@gen_cluster(client=True)
+async def test_condition_not_reentrant(c, s, a, b):
+    """Test that the lock is NOT re-entrant, even for the same client"""
+    cond1 = Condition("not-reentrant")
+    cond2 = Condition("not-reentrant")  # Same name = same lock
+    results = []
+
+    async def holder():
+        await cond1.acquire()
+        results.append("cond1-acquired")
+        await asyncio.sleep(0.5)  # Hold lock
+        await cond1.release()
+        results.append("cond1-released")
+
+    async def waiter():
+        await asyncio.sleep(0.1)  # Let holder acquire first
+        results.append("cond2-attempting")
+        # This should block until holder releases
+        await cond2.acquire()
+        results.append("cond2-acquired")
+        await cond2.release()
+
+    await asyncio.gather(holder(), waiter())
+
+    # Verify order: holder acquires, waiter attempts, holder releases,
+    # waiter acquires
+    assert results[0] == "cond1-acquired"
+    assert results[1] == "cond2-attempting"
+    assert results[2] == "cond1-released"
+    assert results[3] == "cond2-acquired"
+
+
+@gen_cluster(client=True)
+async def test_condition_multiple_instances_share_client_id(c, s, a, b):
+    """Test that multiple Condition instances in same client share client_id"""
+    cond1 = Condition("test-1")
+    cond2 = Condition("test-2")
+
+    # Both should have the same client ID (the client `c`)
+    assert cond1._client_id == cond2._client_id == c.id
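+
+
+# Illustrative extra check (sketch): a notification is not remembered (unlike
+# Event.set()); notify()/notify_all() report how many waiters they actually
+# woke, which is 0 when nobody is waiting.
+@gen_cluster(client=True)
+async def test_condition_notify_without_waiters(c, s, a, b):
+    condition = Condition("no-waiters")
+    async with condition:
+        assert await condition.notify() == 0
+        assert await condition.notify_all() == 0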