Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
features:
- Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters.
52 changes: 50 additions & 2 deletions tenacity/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import sys
import typing
from asyncio import sleep
from inspect import iscoroutinefunction

from tenacity import AttemptManager
from tenacity import BaseRetrying
from tenacity import DoAttempt
from tenacity import DoSleep
from tenacity import RetryAction
from tenacity import RetryCallState
from tenacity import TryAgain

WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable)
_RetValT = typing.TypeVar("_RetValT")
Expand All @@ -45,7 +48,7 @@ async def __call__( # type: ignore # Change signature from supertype

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
while True:
do = self.iter(retry_state=retry_state)
do = await self.iter(retry_state=retry_state)
if isinstance(do, DoAttempt):
try:
result = await fn(*args, **kwargs)
Expand All @@ -66,7 +69,7 @@ def __aiter__(self) -> "AsyncRetrying":

async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]:
while True:
do = self.iter(retry_state=self._retry_state)
do = await self.iter(retry_state=self._retry_state)
if do is None:
raise StopAsyncIteration
elif isinstance(do, DoAttempt):
Expand All @@ -90,3 +93,48 @@ async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async_wrapped.retry_with = fn.retry_with

return async_wrapped

@staticmethod
async def handle_custom_function(
func: typing.Union[typing.Callable, typing.Awaitable], retry_state: RetryCallState
) -> typing.Any:
if iscoroutinefunction(func):
return await func(retry_state)
return func(retry_state)

async def iter(self, retry_state: "RetryCallState") -> typing.Union[DoAttempt, DoSleep, typing.Any]: # noqa
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jd I also wanted to use some base method to avoid re-defining the iter method again, however, one of them is sync def iter and other one is async async def iter. Any ideas how to make it DRY?

I wanted to use handle_custom_function to handle sync/async calls, however, if we make handle_custom_function sync function, then it might need to use asyncio.get_event_loop() to run awaitable functions (running async functions from sync function).

Copy link
Author

@mastizada mastizada Sep 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jd any ideas? I would love to have a progress with this PR and start using async callbacks for retry_error_callback

fut = retry_state.outcome
if fut is None:
if self.before is not None:
await self.handle_custom_function(self.before, retry_state)
return DoAttempt()

is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain)
if not (is_explicit_retry or self.retry(retry_state=retry_state)):
return fut.result()

if self.after is not None:
await self.handle_custom_function(self.after, retry_state)

self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
if self.stop(retry_state=retry_state):
if self.retry_error_callback:
return await self.handle_custom_function(self.retry_error_callback, retry_state)
retry_exc = self.retry_error_cls(fut)
if self.reraise:
raise retry_exc.reraise()
raise retry_exc from fut.exception()

if self.wait:
_sleep = await self.handle_custom_function(self.wait, retry_state=retry_state)
else:
_sleep = 0.0
retry_state.next_action = RetryAction(_sleep)
retry_state.idle_for += _sleep
self.statistics["idle_for"] += _sleep
self.statistics["attempt_number"] += 1

if self.before_sleep is not None:
await self.handle_custom_function(self.before_sleep, retry_state)

return DoSleep(_sleep)
131 changes: 128 additions & 3 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@

import asyncio
import inspect
import logging
import unittest
from functools import wraps

from tenacity import AsyncRetrying, RetryError
from tenacity import AsyncRetrying, Future, RetryCallState, RetryError
from tenacity import _asyncio as tasyncio
from tenacity import retry, stop_after_attempt
from tenacity import before_sleep_log, retry, retry_if_result, stop_after_attempt
from tenacity.wait import wait_fixed

from .test_tenacity import NoIOErrorAfterCount, current_time_ms
from .test_tenacity import CapturingHandler, NoIOErrorAfterCount, NoneReturnUntilAfterCount, current_time_ms


def asynctest(callable_):
Expand Down Expand Up @@ -86,6 +87,31 @@ def test_retry_attributes(self):
assert hasattr(_retryable_coroutine, "retry")
assert hasattr(_retryable_coroutine, "retry_with")

@asynctest
async def test_async_retry_error_callback_handler(self):
num_attempts = 3
self.attempt_counter = 0

async def _retry_error_callback_handler(retry_state: RetryCallState):
_retry_error_callback_handler.called_times += 1
return retry_state.outcome

_retry_error_callback_handler.called_times = 0

@retry(
stop=stop_after_attempt(num_attempts),
retry_error_callback=_retry_error_callback_handler,
)
async def _foobar():
self.attempt_counter += 1
raise Exception("This exception should not be raised")

result = await _foobar()

self.assertEqual(_retry_error_callback_handler.called_times, 1)
self.assertEqual(num_attempts, self.attempt_counter)
self.assertIsInstance(result, Future)

@asynctest
async def test_attempt_number_is_correct_for_interleaved_coroutines(self):

Expand Down Expand Up @@ -157,5 +183,104 @@ async def test_sleeps(self):
self.assertLess(t, 1.1)


class TestAsyncBeforeAfterAttempts(unittest.TestCase):
_attempt_number = 0

@asynctest
async def test_before_attempts(self):
TestAsyncBeforeAfterAttempts._attempt_number = 0

async def _before(retry_state):
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number

@retry(
wait=wait_fixed(1),
stop=stop_after_attempt(1),
before=_before,
)
async def _test_before():
pass

await _test_before()

self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1)

@asynctest
async def test_after_attempts(self):
TestAsyncBeforeAfterAttempts._attempt_number = 0

async def _after(retry_state):
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number

@retry(
wait=wait_fixed(0.1),
stop=stop_after_attempt(3),
after=_after,
)
async def _test_after():
if TestAsyncBeforeAfterAttempts._attempt_number < 2:
raise Exception("testing after_attempts handler")
else:
pass

await _test_after()

self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2)

@asynctest
async def test_before_sleep(self):
async def _before_sleep(retry_state):
self.assertGreater(retry_state.next_action.sleep, 0)
_before_sleep.attempt_number = retry_state.attempt_number

_before_sleep.attempt_number = 0

@retry(
wait=wait_fixed(0.01),
stop=stop_after_attempt(3),
before_sleep=_before_sleep,
)
async def _test_before_sleep():
if _before_sleep.attempt_number < 2:
raise Exception("testing before_sleep_attempts handler")

await _test_before_sleep()
self.assertEqual(_before_sleep.attempt_number, 2)

async def _test_before_sleep_log_returns(self, exc_info):
thing = NoneReturnUntilAfterCount(2)
logger = logging.getLogger(self.id())
logger.propagate = False
logger.setLevel(logging.INFO)
handler = CapturingHandler()
logger.addHandler(handler)
try:
_before_sleep = before_sleep_log(logger, logging.INFO, exc_info=exc_info)
_retry = retry_if_result(lambda result: result is None)
retrying = AsyncRetrying(
wait=wait_fixed(0.01),
stop=stop_after_attempt(3),
retry=_retry,
before_sleep=_before_sleep,
)
await retrying(_async_function, thing)
finally:
logger.removeHandler(handler)

etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$"
self.assertEqual(len(handler.records), 2)
fmt = logging.Formatter().format
self.assertRegex(fmt(handler.records[0]), etalon_re)
self.assertRegex(fmt(handler.records[1]), etalon_re)

@asynctest
async def test_before_sleep_log_returns_without_exc_info(self):
await self._test_before_sleep_log_returns(exc_info=False)

@asynctest
async def test_before_sleep_log_returns_with_exc_info(self):
await self._test_before_sleep_log_returns(exc_info=True)


if __name__ == "__main__":
unittest.main()