diff --git a/distributed/__init__.py b/distributed/__init__.py index 3f075b977c..091c14e0eb 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -146,3 +146,4 @@ "widgets", "worker_client", ] +from distributed.condition import Condition diff --git a/distributed/condition.py b/distributed/condition.py new file mode 100644 index 0000000000..7df6020a17 --- /dev/null +++ b/distributed/condition.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import asyncio +import logging +import uuid +from collections import defaultdict +from contextlib import suppress + +from distributed.lock import Lock +from distributed.utils import log_errors +from distributed.worker import get_client + +logger = logging.getLogger(__name__) + + +class ConditionExtension: + """Scheduler extension for managing distributed Conditions + + Tracks waiters for each condition variable and implements notify logic. + """ + + def __init__(self, scheduler): + self.scheduler = scheduler + # name -> {client_id -> asyncio.Event} + self.waiters = defaultdict(dict) + + self.scheduler.handlers.update( + { + "condition_wait": self.wait, + "condition_notify": self.notify, + "condition_notify_all": self.notify_all, + } + ) + + @log_errors + async def wait(self, name=None, client=None): + """Register a waiter and block until notified""" + # Create event for this specific waiter + event = asyncio.Event() + self.waiters[name][client] = event + + try: + # Block until notified + await event.wait() + finally: + # Cleanup after waking or cancellation + with suppress(KeyError): + del self.waiters[name][client] + if not self.waiters[name]: + del self.waiters[name] + + async def notify(self, name=None, n=1): + """Wake up n waiters""" + if name not in self.waiters: + return + + # Wake up to n waiters + notified = 0 + for client_id in list(self.waiters[name].keys()): + if notified >= n: + break + event = self.waiters[name].get(client_id) + if event and not event.is_set(): + event.set() + notified += 1 + + async def notify_all(self, name=None): + """Wake up all waiters""" + if name not in self.waiters: + return + + for event in self.waiters[name].values(): + if not event.is_set(): + event.set() + + +class Condition: + """Distributed Condition Variable + + A distributed version of asyncio.Condition. Allows one or more clients + to wait until notified by another client. + + Like asyncio.Condition, this must be used with a lock. The lock is + released before waiting and reacquired afterwards. + + Parameters + ---------- + name : str, optional + Name of the condition. If not provided, a random name is generated. + client : Client, optional + Client instance. If not provided, uses the default client. + lock : Lock, optional + Lock to use with this condition. If not provided, creates a new Lock. + + Examples + -------- + >>> from distributed import Client, Condition + >>> client = Client() # doctest: +SKIP + >>> condition = Condition() # doctest: +SKIP + + >>> async with condition: # doctest: +SKIP + ... # Wait for some condition + ... await condition.wait() + + >>> async with condition: # doctest: +SKIP + ... # Modify shared state + ... condition.notify() # Wake one waiter + """ + + def __init__(self, name=None, client=None, lock=None): + self._client = client + self.name = name or "condition-" + uuid.uuid4().hex + + if lock is None: + lock = Lock(client=client) + elif not isinstance(lock, Lock): + raise TypeError(f"lock must be a Lock, not {type(lock)}") + + self._lock = lock + + @property + def client(self): + if not self._client: + try: + self._client = get_client() + except ValueError: + pass + return self._client + + def _verify_running(self): + if not self.client: + raise RuntimeError( + f"{type(self)} object not properly initialized. Ensure it's created within a Client context." + ) + + async def __aenter__(self): + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.release() + + def __repr__(self): + return f"" + + async def acquire(self, timeout=None): + """Acquire the underlying lock""" + self._verify_running() + return await self._lock.acquire(timeout=timeout) + + async def release(self): + """Release the underlying lock""" + self._verify_running() + return await self._lock.release() + + def locked(self): + """Return True if lock is held""" + return self._lock.locked() + + async def wait(self, timeout=None): + """Wait until notified. + + This method releases the underlying lock, waits until notified, + then reacquires the lock before returning. + + Must be called with the lock held. + + Parameters + ---------- + timeout : float, optional + Maximum time to wait in seconds. If None, wait indefinitely. + + Returns + ------- + bool + True if woken by notify, False if timeout occurred + + Raises + ------ + RuntimeError + If called without holding the lock + """ + self._verify_running() + + # Release lock before waiting (will error if not held) + await self.release() + + # Track if we need to reacquire (for proper cleanup on cancellation) + reacquired = False + + try: + # Wait for notification from scheduler + if timeout is not None: + try: + await asyncio.wait_for( + self.client.scheduler.condition_wait( + name=self.name, client=self.client.id + ), + timeout=timeout, + ) + return True + except asyncio.TimeoutError: + return False + else: + await self.client.scheduler.condition_wait( + name=self.name, client=self.client.id + ) + return True + finally: + # CRITICAL: Always reacquire lock, even on cancellation + # This maintains the invariant that wait() returns with lock held + try: + await self.acquire() + reacquired = True + except asyncio.CancelledError: + # If reacquisition is cancelled, we're in a bad state + # Try again without allowing cancellation + if not reacquired: + with suppress(Exception): + await asyncio.shield(self.acquire()) + raise + + async def wait_for(self, predicate, timeout=None): + """Wait until a predicate becomes true. + + Parameters + ---------- + predicate : callable + Function that returns True when the condition is met + timeout : float, optional + Maximum time to wait in seconds + + Returns + ------- + bool + The predicate result (should be True unless timeout) + """ + result = predicate() + while not result: + if timeout is not None: + # Consume timeout across multiple waits + import time + + start = time.time() + if not await self.wait(timeout=timeout): + return predicate() + timeout -= time.time() - start + if timeout <= 0: + return predicate() + else: + await self.wait() + result = predicate() + return result + + async def notify(self, n=1): + """Wake up n waiters (default: 1) + + Parameters + ---------- + n : int + Number of waiters to wake up + """ + self._verify_running() + await self.client.scheduler.condition_notify(name=self.name, n=n) + + async def notify_all(self): + """Wake up all waiters""" + self._verify_running() + await self.client.scheduler.condition_notify_all(name=self.name) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d73da7c71f..897cfbb32f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -97,6 +97,7 @@ ) from distributed.comm.addressing import addresses_from_user_args from distributed.compatibility import PeriodicCallback +from distributed.condition import ConditionExtension from distributed.core import ( ErrorMessage, OKMessage, @@ -195,6 +196,7 @@ "semaphores": SemaphoreExtension, "events": EventExtension, "amm": ActiveMemoryManagerExtension, + "conditions": ConditionExtension, "memory_sampler": MemorySamplerExtension, "shuffle": ShuffleSchedulerPlugin, "spans": SpansSchedulerExtension, diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py new file mode 100644 index 0000000000..3e57d47a2d --- /dev/null +++ b/distributed/tests/test_condition.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from distributed import Condition, Lock +from distributed.metrics import time +from distributed.utils_test import gen_cluster + + +@gen_cluster(client=True) +async def test_condition_basic(c, s, a, b): + """Basic condition wait and notify""" + condition = Condition() + + results = [] + + async def waiter(): + async with condition: + results.append("waiting") + await condition.wait() + results.append("notified") + + # Start waiter task + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) # Let waiter acquire lock and start waiting + + assert results == ["waiting"] + + # Notify the waiter + async with condition: + await condition.notify() + + await waiter_task + assert results == ["waiting", "notified"] + + +@gen_cluster(client=True) +async def test_condition_notify_one(c, s, a, b): + """notify() wakes only one waiter""" + condition = Condition() + + results = [] + + async def waiter(n): + async with condition: + await condition.wait() + results.append(n) + + # Start multiple waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(3)] + await asyncio.sleep(0.1) + + # Notify once - only one should wake + async with condition: + await condition.notify() + + await asyncio.sleep(0.1) + assert len(results) == 1 + + # Notify again + async with condition: + await condition.notify() + + await asyncio.sleep(0.1) + assert len(results) == 2 + + # Cleanup remaining + async with condition: + await condition.notify_all() + + await asyncio.gather(*tasks) + assert len(results) == 3 + + +@gen_cluster(client=True) +async def test_condition_notify_n(c, s, a, b): + """notify(n) wakes exactly n waiters""" + condition = Condition() + + results = [] + + async def waiter(n): + async with condition: + await condition.wait() + results.append(n) + + # Start 5 waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(5)] + await asyncio.sleep(0.1) + + # Notify 2 + async with condition: + await condition.notify(2) + + await asyncio.sleep(0.1) + assert len(results) == 2 + + # Notify 3 more + async with condition: + await condition.notify(3) + + await asyncio.gather(*tasks) + assert len(results) == 5 + + +@gen_cluster(client=True) +async def test_condition_notify_all(c, s, a, b): + """notify_all() wakes all waiters""" + condition = Condition() + + results = [] + + async def waiter(n): + async with condition: + await condition.wait() + results.append(n) + + # Start multiple waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(10)] + await asyncio.sleep(0.1) + + # Wake all at once + async with condition: + await condition.notify_all() + + await asyncio.gather(*tasks) + assert len(results) == 10 + + +@gen_cluster(client=True) +async def test_condition_wait_timeout(c, s, a, b): + """wait() with timeout returns False if not notified""" + condition = Condition() + + async with condition: + start = time() + result = await condition.wait(timeout=0.2) + elapsed = time() - start + + assert result is False + assert 0.15 < elapsed < 0.5 + + +@gen_cluster(client=True) +async def test_condition_wait_timeout_then_notify(c, s, a, b): + """wait() with timeout returns True if notified before timeout""" + condition = Condition() + + async def notifier(): + await asyncio.sleep(0.1) + async with condition: + await condition.notify() + + notifier_task = asyncio.create_task(notifier()) + + async with condition: + result = await condition.wait(timeout=1.0) + + await notifier_task + assert result is True + + +@gen_cluster(client=True) +async def test_condition_wait_for(c, s, a, b): + """wait_for() waits until predicate is true""" + condition = Condition() + state = {"value": 0} + + async def incrementer(): + for _i in range(5): + await asyncio.sleep(0.05) + async with condition: + state["value"] += 1 + await condition.notify_all() + + inc_task = asyncio.create_task(incrementer()) + + async with condition: + result = await condition.wait_for(lambda: state["value"] >= 3) + + await inc_task + assert result is True + assert state["value"] >= 3 + + +@gen_cluster(client=True) +async def test_condition_wait_for_timeout(c, s, a, b): + """wait_for() returns predicate result on timeout""" + condition = Condition() + + async with condition: + result = await condition.wait_for(lambda: False, timeout=0.2) + + assert result is False + + +@gen_cluster(client=True) +async def test_condition_context_manager(c, s, a, b): + """Condition works as async context manager""" + condition = Condition() + + assert not condition.locked() + + async with condition: + assert condition.locked() + + assert not condition.locked() + + +@gen_cluster(client=True) +async def test_condition_with_explicit_lock(c, s, a, b): + """Condition can use an explicit Lock""" + lock = Lock() + condition = Condition(lock=lock) + + async with lock: + assert condition.locked() + + assert not condition.locked() + + +@gen_cluster(client=True) +async def test_condition_multiple_notify_calls(c, s, a, b): + """Multiple notify calls work correctly""" + condition = Condition() + results = [] + + async def waiter(n): + async with condition: + await condition.wait() + results.append(n) + + # Start 3 waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(3)] + await asyncio.sleep(0.1) + + # Notify one at a time + for _ in range(3): + async with condition: + await condition.notify() + await asyncio.sleep(0.05) + + await asyncio.gather(*tasks) + assert set(results) == {0, 1, 2} + + +@gen_cluster(client=True) +async def test_condition_notify_without_waiters(c, s, a, b): + """notify() with no waiters is a no-op""" + condition = Condition() + + async with condition: + await condition.notify() + await condition.notify_all() + await condition.notify(5) + + # Should not raise or hang + + +@gen_cluster(client=True) +async def test_condition_producer_consumer(c, s, a, b): + """Producer-consumer pattern with Condition""" + condition = Condition() + queue = [] + + async def producer(): + for i in range(5): + await asyncio.sleep(0.05) + async with condition: + queue.append(i) + await condition.notify() + + async def consumer(): + items = [] + for _ in range(5): + async with condition: + await condition.wait_for(lambda: len(queue) > 0) + items.append(queue.pop(0)) + return items + + prod_task = asyncio.create_task(producer()) + cons_task = asyncio.create_task(consumer()) + + result = await cons_task + await prod_task + + assert result == [0, 1, 2, 3, 4] + + +@gen_cluster(client=True) +async def test_condition_same_name(c, s, a, b): + """Multiple Condition instances with same name share state""" + cond1 = Condition(name="shared") + cond2 = Condition(name="shared") + + result = [] + + async def waiter(): + async with cond1: + await cond1.wait() + result.append("done") + + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) + + # Notify via different instance + async with cond2: + await cond2.notify() + + await waiter_task + assert result == ["done"] + + +@gen_cluster(client=True) +async def test_condition_repr(c, s, a, b): + """Condition has readable repr""" + condition = Condition(name="test-cond") + assert "test-cond" in repr(condition) + + +@gen_cluster(client=True) +async def test_condition_waiter_cancelled(c, s, a, b): + """Cancelled waiter properly cleans up""" + condition = Condition() + + async def waiter(): + async with condition: + await condition.wait() + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) + + # Cancel the waiter + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + # Should not deadlock - lock should be released + async with condition: + await condition.notify() # No-op since no waiters + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_condition_across_workers(c, s, a): + """Condition works across different worker tasks""" + condition = Condition(name="cross-worker") + + def worker_wait(): + import asyncio + + from distributed import Condition + + async def wait_task(): + cond = Condition(name="cross-worker") + async with cond: + await cond.wait() + return "notified" + + return asyncio.run(wait_task()) + + # Submit waiter to worker + future = c.submit(worker_wait, pure=False) + await asyncio.sleep(0.2) + + # Notify from client + async with condition: + await condition.notify() + + result = await future + assert result == "notified" + + +@gen_cluster(client=True) +async def test_condition_locked_status(c, s, a, b): + """locked() returns correct status""" + condition = Condition() + + assert not condition.locked() + + await condition.acquire() + assert condition.locked() + + await condition.release() + assert not condition.locked() + + +@gen_cluster(client=True) +async def test_condition_reacquire_after_wait(c, s, a, b): + """wait() reacquires lock after being notified""" + condition = Condition() + lock_states = [] + + async def waiter(): + async with condition: + lock_states.append(("before_wait", condition.locked())) + await condition.wait() + lock_states.append(("after_wait", condition.locked())) + + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) + + async with condition: + lock_states.append(("notifier", condition.locked())) + await condition.notify() + + await waiter_task + + # Waiter should hold lock before and after wait + assert lock_states[0] == ("before_wait", True) + assert lock_states[1] == ("notifier", True) + assert lock_states[2] == ("after_wait", True) + + +@gen_cluster(client=True) +async def test_condition_many_waiters(c, s, a, b): + """Handle many waiters efficiently""" + condition = Condition() + results = [] + + async def waiter(n): + async with condition: + await condition.wait() + results.append(n) + + # Start many waiters + tasks = [asyncio.create_task(waiter(i)) for i in range(50)] + await asyncio.sleep(0.2) + + # Wake them all + async with condition: + await condition.notify_all() + + await asyncio.gather(*tasks) + assert len(results) == 50 + assert set(results) == set(range(50))