diff --git a/releasenotes/notes/asynchronous_defined_retry_callbacks-254eecdc313c52f8.yaml b/releasenotes/notes/asynchronous_defined_retry_callbacks-254eecdc313c52f8.yaml new file mode 100644 index 00000000..908ad38e --- /dev/null +++ b/releasenotes/notes/asynchronous_defined_retry_callbacks-254eecdc313c52f8.yaml @@ -0,0 +1,16 @@ +--- +prelude: > + Example use cases: + - if we get logged out from the server + - if authentication tokens expire + We want to be able to automatically refresh our session by calling a specific + function. This function can be asynchronously defined. +features: + - | + Asynchronous defined retry callbacks: + - retry + - retry_error_callback +issues: + - | + #249 + diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index 979b6544..6991621d 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -15,10 +15,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +try: + from inspect import iscoroutinefunction +except ImportError: + iscoroutinefunction = None + import sys from asyncio import sleep -from tenacity import AttemptManager +import six + +from tenacity import AttemptManager, RetryAction, TryAgain from tenacity import BaseRetrying from tenacity import DoAttempt from tenacity import DoSleep @@ -30,12 +37,58 @@ def __init__(self, sleep=sleep, **kwargs): super(AsyncRetrying, self).__init__(**kwargs) self.sleep = sleep + async def iter(self, retry_state): # noqa + fut = retry_state.outcome + if fut is None: + if self.before is not None: + self.before(retry_state) + return DoAttempt() + + is_explicit_retry = retry_state.outcome.failed and isinstance( + retry_state.outcome.exception(), TryAgain + ) + if iscoroutinefunction(self.retry): + should_retry = await self.retry(retry_state=retry_state) + else: + should_retry = self.retry(retry_state=retry_state) + if not (is_explicit_retry or should_retry): + return fut.result() + + if self.after is not None: + self.after(retry_state=retry_state) + + self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start + if self.stop(retry_state=retry_state): + if self.retry_error_callback: + if iscoroutinefunction(self.retry_error_callback): + return await self.retry_error_callback(retry_state=retry_state) + else: + return self.retry_error_callback(retry_state=retry_state) + retry_exc = self.retry_error_cls(fut) + if self.reraise: + raise retry_exc.reraise() + six.raise_from(retry_exc, fut.exception()) + + if self.wait: + iteration_sleep = self.wait(retry_state=retry_state) + else: + iteration_sleep = 0.0 + retry_state.next_action = RetryAction(iteration_sleep) + retry_state.idle_for += iteration_sleep + self.statistics["idle_for"] += iteration_sleep + self.statistics["attempt_number"] += 1 + + if self.before_sleep is not None: + self.before_sleep(retry_state=retry_state) + + return DoSleep(iteration_sleep) + async def __call__(self, fn, *args, **kwargs): self.begin(fn) retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) while True: - do = self.iter(retry_state=retry_state) + do = await self.iter(retry_state=retry_state) if isinstance(do, DoAttempt): try: result = await fn(*args, **kwargs) @@ -56,7 +109,7 @@ def __aiter__(self): async def __anext__(self): while True: - do = self.iter(retry_state=self._retry_state) + do = await self.iter(retry_state=self._retry_state) if do is None: raise StopAsyncIteration elif isinstance(do, DoAttempt): @@ -69,6 +122,7 @@ async def __anext__(self): def wraps(self, fn): fn = super().wraps(fn) + # Ensure wrapper is recognized as a coroutine function. async def async_wrapped(*args, **kwargs): diff --git a/tenacity/tests/test_asyncio.py b/tenacity/tests/test_asyncio.py index 2057fd2d..f954ec1d 100644 --- a/tenacity/tests/test_asyncio.py +++ b/tenacity/tests/test_asyncio.py @@ -124,6 +124,49 @@ def after(retry_state): assert len(set(things)) == 1 assert list(attempt_nos2) == [1, 2, 3] + @asynctest + async def test_async_retry(self): + attempts = [] + + async def async_retry(retry_state): + if retry_state.outcome.failed: + attempts.append((retry_state.outcome, retry_state.attempt_number)) + return True + else: + attempts.append((retry_state.outcome, retry_state.attempt_number)) + return False + + thing = NoIOErrorAfterCount(2) + + await _retryable_coroutine.retry_with(retry=async_retry)(thing) + + things, attempt_numbers = zip(*attempts) + assert len(attempts) == 3 + + for thing in things[:-1]: + with pytest.raises(IOError): + thing.result() + + assert things[-1].result() is True + + @asynctest + async def test_async_callback_error_retry(self): + async def async_return_text(retry_state): + await asyncio.sleep(0.00001) + + return "Calling %s keeps raising errors after %s attempts" % ( + retry_state.fn.__name__, + retry_state.attempt_number, + ) + + thing = NoIOErrorAfterCount(3) + + result = await _retryable_coroutine_with_2_attempts.retry_with( + retry_error_callback=async_return_text + )(thing) + message = "Calling _retryable_coroutine_with_2_attempts keeps raising errors after 2 attempts" + assert result == message + class TestContextManager(unittest.TestCase): @asynctest