Skip to content

Commit 74793e8

Browse files
author
Anton
committed
fix: custom semaphore
1 parent 779f421 commit 74793e8

File tree

3 files changed

+143
-5
lines changed

3 files changed

+143
-5
lines changed

taskiq/receiver/receiver.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from taskiq.message import TaskiqMessage
1515
from taskiq.receiver.params_parser import parse_params
1616
from taskiq.result import TaskiqResult
17+
from taskiq.semaphore import DequeSemaphore
1718
from taskiq.state import TaskiqState
1819
from taskiq.utils import DequeQueue, maybe_awaitable
1920

@@ -58,15 +59,15 @@ def __init__( # noqa: WPS211
5859
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
5960
self.task_hints[task.task_name] = get_type_hints(task.original_func)
6061
self.dependency_graphs[task.task_name] = DependencyGraph(task.original_func)
61-
self.sem: "Optional[asyncio.Semaphore]" = None
62+
self.sem: "Optional[DequeSemaphore]" = None
6263
if max_async_tasks is not None and max_async_tasks > 0:
63-
self.sem = asyncio.Semaphore(max_async_tasks)
64+
self.sem = DequeSemaphore(max_async_tasks)
6465
else:
6566
logger.warning(
6667
"Setting unlimited number of async tasks "
6768
+ "can result in undefined behavior",
6869
)
69-
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
70+
self.sem_prefetch = DequeSemaphore(max_prefetch)
7071
self.queue: DequeQueue[bytes] = DequeQueue()
7172

7273
self.sem_idle: Optional[asyncio.Semaphore] = None
@@ -309,7 +310,7 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
309310
break
310311
if message is QUEUE_SKIP:
311312
# Decrease max_prefetch
312-
prefetch_dec = asyncio.create_task(self.sem_prefetch.acquire())
313+
prefetch_dec = asyncio.create_task(self.sem_prefetch.acquire_first())
313314
prefetch_dec.add_done_callback(tasks.discard)
314315
tasks.add(prefetch_dec)
315316

@@ -356,5 +357,5 @@ async def task_idler(self, wait: float) -> None:
356357
# Decrease max_prefetch in runner
357358
task = asyncio.create_task(self.queue.put_first(QUEUE_SKIP))
358359
# Decrease max_tasks
359-
await self.sem.acquire()
360+
await self.sem.acquire_first()
360361
await task

taskiq/semaphore.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import asyncio
2+
import collections
3+
from typing import Any, Deque
4+
5+
from typing_extensions import Literal
6+
7+
8+
class DequeSemaphore:
9+
"""Custom deque based semaphore."""
10+
11+
def __init__(self, value: int) -> None:
12+
self._value = value
13+
self._waiters: Deque[asyncio.Future[Any]] = collections.deque()
14+
15+
if self._value < 0:
16+
raise ValueError("Value should be >= 0")
17+
18+
async def __aenter__(self) -> Literal[True]:
19+
return await self.acquire()
20+
21+
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
22+
self.release()
23+
24+
def locked(self) -> bool:
25+
"""
26+
Returns True if semaphore cannot be acquired immediately.
27+
28+
:returns: true or false
29+
"""
30+
return self._value == 0
31+
32+
def release(self) -> None:
33+
"""Release a semaphore, incrementing the internal counter by one.
34+
35+
When it was zero on entry and another coroutine is waiting for it to
36+
become larger than zero again, wake up that coroutine.
37+
"""
38+
self._value += 1
39+
self._wakeup_next()
40+
41+
async def acquire(self, first: bool = False) -> Literal[True]: # noqa: C901
42+
"""
43+
Acquire a semaphore.
44+
45+
:param first: acquire ASAP
46+
:raises asyncio.exceptions.CancelledError: task cancelled
47+
:returns: true
48+
"""
49+
if not self.locked() and not self._waiters:
50+
# No need to wait as the semaphore is not locked
51+
# and no one is waiting
52+
self._value -= 1
53+
return True
54+
55+
# if there are waiters or the semaphore is locked
56+
fut: asyncio.Future[Any] = asyncio.Future()
57+
58+
if first:
59+
self._waiters.appendleft(fut)
60+
else:
61+
self._waiters.append(fut)
62+
63+
try:
64+
try: # noqa: WPS501, WPS505
65+
await fut
66+
finally:
67+
self._waiters.remove(fut)
68+
69+
except asyncio.exceptions.CancelledError:
70+
if not fut.cancelled():
71+
self._value += 1
72+
self._wakeup_next()
73+
raise
74+
75+
if not self.locked():
76+
# This is required for strict FIFO ordering
77+
# otherwise it can cause starvation on the waiting tasks
78+
# The next loop iteration will wake up the task and switch
79+
self._wakeup_next()
80+
81+
return True
82+
83+
async def acquire_first(self) -> Literal[True]:
84+
"""
85+
Acquire a semaphore ASAP.
86+
87+
:returns: true
88+
"""
89+
return await self.acquire(True)
90+
91+
def _wakeup_next(self) -> None:
92+
if not self._waiters:
93+
return
94+
95+
for fut in self._waiters:
96+
if not fut.done():
97+
self._value -= 1
98+
fut.set_result(True)
99+
return

tests/test_semaphore.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import asyncio
2+
3+
import anyio
4+
import pytest
5+
6+
from taskiq.semaphore import DequeSemaphore
7+
8+
9+
@pytest.mark.anyio
10+
async def test_semaphore_exception() -> None:
11+
with pytest.raises(ValueError):
12+
DequeSemaphore(-1)
13+
14+
15+
@pytest.mark.anyio
16+
async def test_semaphore() -> None:
17+
sem = DequeSemaphore(1)
18+
19+
async def c1() -> None:
20+
await sem.acquire()
21+
22+
async def c2() -> None:
23+
await sem.acquire()
24+
25+
async def c3() -> None:
26+
await sem.acquire()
27+
28+
t1 = asyncio.create_task(c1())
29+
t2 = asyncio.create_task(c2())
30+
t3 = asyncio.create_task(c3())
31+
await asyncio.sleep(0)
32+
33+
sem.release()
34+
sem.release()
35+
t2.cancel()
36+
37+
async with anyio.maybe_async_cm(anyio.move_on_after(1)):
38+
await asyncio.gather(t1, t2, t3, return_exceptions=True)

0 commit comments

Comments
 (0)