22
33import asyncio
44import concurrent .futures
5+ import contextlib
56import contextvars
67import functools
78import inspect
89import threading
9- import warnings
1010import weakref
1111from collections import deque
1212from concurrent .futures .thread import ThreadPoolExecutor
@@ -523,6 +523,22 @@ def done(task: asyncio.Future[_T]) -> None:
523523 return result
524524
525525
526+ _pool = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
527+
528+
529+ @contextlib .asynccontextmanager
530+ async def async_lock (lock : threading .RLock ) -> AsyncGenerator [None , None ]:
531+ locked = lock .acquire (False )
532+ while not locked :
533+ await asyncio .sleep (0 )
534+ locked = lock .acquire (False )
535+ try :
536+ yield
537+ finally :
538+ if locked :
539+ lock .release ()
540+
541+
526542class Event :
527543 """Thread safe version of an async Event"""
528544
@@ -547,39 +563,47 @@ def set(self) -> None:
547563 if not self ._value :
548564 self ._value = True
549565
550- for fut in self ._waiters :
566+ while self ._waiters :
567+ fut = self ._waiters .popleft ()
568+
551569 if not fut .done ():
552570 if fut ._loop == asyncio .get_running_loop ():
553571 if not fut .done ():
554572 fut .set_result (True )
555573 else :
556574
557- def s (w : asyncio .Future [Any ]) -> None :
558- if not w .done ():
559- w .set_result (True )
575+ def s (w : asyncio .Future [Any ], ev : threading .Event ) -> None :
576+ try :
577+ if not w .done ():
578+ w .set_result (True )
579+ finally :
580+ ev .set ()
560581
561582 if not fut .done ():
562- fut ._loop .call_soon_threadsafe (s , fut )
583+ done = threading .Event ()
584+
585+ fut ._loop .call_soon_threadsafe (s , fut , done )
586+
587+ if not done .wait (120 ):
588+ raise RuntimeError ("Callback timeout" )
563589
564590 def clear (self ) -> None :
565591 with self ._lock :
566592 self ._value = False
567593
568594 async def wait (self , timeout : Optional [float ] = None ) -> bool :
569- if self ._value :
570- return True
571-
572- fut = create_sub_future ()
573595 with self ._lock :
596+ if self ._value :
597+ return True
598+
599+ fut = create_sub_future ()
574600 self ._waiters .append (fut )
601+
575602 try :
576603 await asyncio .wait_for (fut , timeout )
577604 return True
578605 except asyncio .TimeoutError :
579606 return False
580- finally :
581- with self ._lock :
582- self ._waiters .remove (fut )
583607
584608
585609class Semaphore :
@@ -600,8 +624,8 @@ def __repr__(self) -> str:
600624 extra = f"{ extra } , waiters:{ len (self ._waiters )} "
601625 return f"<{ res [1 :- 1 ]} [{ extra } ]>"
602626
603- def _wake_up_next (self ) -> None :
604- with self ._lock :
627+ async def _wake_up_next (self ) -> None :
628+ async with async_lock ( self ._lock ) :
605629 while self ._waiters :
606630 waiter = self ._waiters .popleft ()
607631
@@ -612,14 +636,23 @@ def _wake_up_next(self) -> None:
612636 else :
613637 if waiter ._loop .is_running ():
614638
615- def s (w : asyncio .Future [Any ]) -> None :
616- if w ._loop .is_running () and not w .done ():
617- w .set_result (True )
639+ def s (w : asyncio .Future [Any ], ev : threading .Event ) -> None :
640+ try :
641+ if w ._loop .is_running () and not w .done ():
642+ w .set_result (True )
643+ finally :
644+ ev .set ()
618645
619646 if not waiter .done ():
620- waiter ._loop .call_soon_threadsafe (s , waiter )
647+ done = threading .Event ()
648+
649+ waiter ._loop .call_soon_threadsafe (s , waiter , done )
650+
651+ if not done .wait (120 ):
652+ raise RuntimeError ("Callback timeout" )
653+
621654 else :
622- warnings . warn ("Loop is not running." )
655+ raise RuntimeError ("Loop is not running." )
623656
624657 def locked (self ) -> bool :
625658 with self ._lock :
@@ -628,7 +661,7 @@ def locked(self) -> bool:
628661 async def acquire (self , timeout : Optional [float ] = None ) -> bool :
629662 while self ._value <= 0 :
630663 fut = create_sub_future ()
631- with self ._lock :
664+ async with async_lock ( self ._lock ) :
632665 self ._waiters .append (fut )
633666 try :
634667 await asyncio .wait_for (fut , timeout )
@@ -638,26 +671,26 @@ async def acquire(self, timeout: Optional[float] = None) -> bool:
638671 if not fut .done ():
639672 fut .cancel ()
640673 if self ._value > 0 and not fut .cancelled ():
641- self ._wake_up_next ()
674+ await self ._wake_up_next ()
642675
643676 raise
644677
645- with self ._lock :
678+ async with async_lock ( self ._lock ) :
646679 self ._value -= 1
647680
648681 return True
649682
650- def release (self ) -> None :
683+ async def release (self ) -> None :
651684 self ._value += 1
652- self ._wake_up_next ()
685+ await self ._wake_up_next ()
653686
654687 async def __aenter__ (self ) -> None :
655688 await self .acquire ()
656689
657690 async def __aexit__ (
658691 self , exc_type : Optional [Type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
659692 ) -> None :
660- self .release ()
693+ await self .release ()
661694
662695
663696class BoundedSemaphore (Semaphore ):
@@ -667,10 +700,10 @@ def __init__(self, value: int = 1) -> None:
667700 self ._bound_value = value
668701 super ().__init__ (value )
669702
670- def release (self ) -> None :
703+ async def release (self ) -> None :
671704 if self ._value >= self ._bound_value :
672705 raise ValueError ("BoundedSemaphore released too many times" )
673- super ().release ()
706+ await super ().release ()
674707
675708
676709class Lock :
@@ -683,8 +716,8 @@ def __repr__(self) -> str:
683716 async def acquire (self , timeout : Optional [float ] = None ) -> bool :
684717 return await self ._block .acquire (timeout )
685718
686- def release (self ) -> None :
687- self ._block .release ()
719+ async def release (self ) -> None :
720+ await self ._block .release ()
688721
689722 @property
690723 def locked (self ) -> bool :
@@ -696,7 +729,7 @@ async def __aenter__(self) -> None:
696729 async def __aexit__ (
697730 self , exc_type : Optional [Type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
698731 ) -> None :
699- self .release ()
732+ await self .release ()
700733
701734
702735class OldLock :
@@ -772,7 +805,7 @@ async def release(self) -> None:
772805 if self ._locked :
773806 self ._locked = False
774807 else :
775- warnings . warn (f"Lock is not acquired ({ len (self ._waiters ) if self ._waiters else 0 } waiters)." )
808+ raise RuntimeError (f"Lock is not acquired ({ len (self ._waiters ) if self ._waiters else 0 } waiters)." )
776809
777810 await self ._wake_up_next ()
778811
0 commit comments