diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 697de81b1c..27f26c35b2 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -472,8 +472,8 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 -_BACKOFF_MAX = 1 -_BACKOFF_INITIAL = 0.050 # 50ms initial backoff +_BACKOFF_MAX = 0.500 # 500ms max backoff +_BACKOFF_INITIAL = 0.005 # 5ms initial backoff def _within_time_limit(start_time: float) -> bool: @@ -481,6 +481,11 @@ def _within_time_limit(start_time: float) -> bool: return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT +def _would_exceed_time_limit(start_time: float, backoff: float) -> bool: + """Is the backoff within the with_transaction retry limit?""" + return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT + + _T = TypeVar("_T") if TYPE_CHECKING: @@ -708,10 +713,14 @@ async def callback(session, custom_arg, custom_kwarg=None): """ start_time = time.monotonic() retry = 0 + last_error: Optional[BaseException] = None while True: if retry: # Implement exponential backoff on retry. jitter = random.random() # noqa: S311 - backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX) + backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX) + if _would_exceed_time_limit(start_time, backoff): + assert last_error is not None + raise last_error await asyncio.sleep(backoff) retry += 1 await self.start_transaction( @@ -721,6 +730,7 @@ async def callback(session, custom_arg, custom_kwarg=None): ret = await callback(self) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as exc: + last_error = exc if self.in_transaction: await self.abort_transaction() if ( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index d5a37eb108..28999bcd62 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -470,8 +470,8 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 -_BACKOFF_MAX = 1 -_BACKOFF_INITIAL = 0.050 # 50ms initial backoff +_BACKOFF_MAX = 0.500 # 500ms max backoff +_BACKOFF_INITIAL = 0.005 # 5ms initial backoff def _within_time_limit(start_time: float) -> bool: @@ -479,6 +479,11 @@ def _within_time_limit(start_time: float) -> bool: return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT +def _would_exceed_time_limit(start_time: float, backoff: float) -> bool: + """Is the backoff within the with_transaction retry limit?""" + return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT + + _T = TypeVar("_T") if TYPE_CHECKING: @@ -706,10 +711,14 @@ def callback(session, custom_arg, custom_kwarg=None): """ start_time = time.monotonic() retry = 0 + last_error: Optional[BaseException] = None while True: if retry: # Implement exponential backoff on retry. jitter = random.random() # noqa: S311 - backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX) + backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX) + if _would_exceed_time_limit(start_time, backoff): + assert last_error is not None + raise last_error time.sleep(backoff) retry += 1 self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) @@ -717,6 +726,7 @@ def callback(session, custom_arg, custom_kwarg=None): ret = callback(self) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as exc: + last_error = exc if self.in_transaction: self.abort_transaction() if ( diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 29c5d26423..18f9778463 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -16,7 +16,9 @@ from __future__ import annotations import asyncio +import random import sys +import time from io import BytesIO from test.asynchronous.utils_spec_runner import AsyncSpecRunner @@ -441,7 +443,7 @@ async def set_fail_point(self, command_args): await self.configure_fail_point(client, command_args) @async_client_context.require_transactions - async def test_callback_raises_custom_error(self): + async def test_1_callback_raises_custom_error(self): class _MyException(Exception): pass @@ -453,7 +455,7 @@ async def raise_error(_): await s.with_transaction(raise_error) @async_client_context.require_transactions - async def test_callback_returns_value(self): + async def test_2_callback_returns_value(self): async def callback(_): return "Foo" @@ -481,7 +483,7 @@ def callback(_): self.assertEqual(await s.with_transaction(callback), "Foo") @async_client_context.require_transactions - async def test_callback_not_retried_after_timeout(self): + async def test_3_1_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -509,7 +511,7 @@ async def callback(session): @async_client_context.require_test_commands @async_client_context.require_transactions - async def test_callback_not_retried_after_commit_timeout(self): + async def test_3_2_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -543,7 +545,7 @@ async def callback(session): @async_client_context.require_test_commands @async_client_context.require_transactions - async def test_commit_not_retried_after_timeout(self): + async def test_3_3_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -613,6 +615,72 @@ async def callback(session): await s.with_transaction(callback) self.assertFalse(s.in_transaction) + @async_client_context.require_test_commands + @async_client_context.require_transactions + async def test_4_retry_backoff_is_enforced(self): + client = async_client_context.client + coll = client[self.db.name].test + # patch random to make it deterministic -- once to effectively have + # no backoff and the second time with "max" backoff (always waiting the longest + # possible time) + _original_random_random = random.random + + def always_one(): + return 1 + + def always_zero(): + return 0 + + random.random = always_zero + # set fail point to trigger transaction failure and trigger backoff + await self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 13}, + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + + async def callback(session): + await coll.insert_one({}, session=session) + + start = time.monotonic() + async with self.client.start_session() as s: + await s.with_transaction(callback) + end = time.monotonic() + no_backoff_time = end - start + + random.random = always_one + # set fail point to trigger transaction failure and trigger backoff + await self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": { + "times": 13 + }, # sufficiently high enough such that the time effect of backoff is noticeable + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + start = time.monotonic() + async with self.client.start_session() as s: + await s.with_transaction(callback) + end = time.monotonic() + self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2 + + random.random = _original_random_random + class TestOptionsInsideTransactionProse(AsyncTransactionsBase): @async_client_context.require_transactions diff --git a/test/test_transactions.py b/test/test_transactions.py index 37e1a249e0..94d70396fc 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -16,7 +16,9 @@ from __future__ import annotations import asyncio +import random import sys +import time from io import BytesIO from test.utils_spec_runner import SpecRunner @@ -433,7 +435,7 @@ def set_fail_point(self, command_args): self.configure_fail_point(client, command_args) @client_context.require_transactions - def test_callback_raises_custom_error(self): + def test_1_callback_raises_custom_error(self): class _MyException(Exception): pass @@ -445,7 +447,7 @@ def raise_error(_): s.with_transaction(raise_error) @client_context.require_transactions - def test_callback_returns_value(self): + def test_2_callback_returns_value(self): def callback(_): return "Foo" @@ -473,7 +475,7 @@ def callback(_): self.assertEqual(s.with_transaction(callback), "Foo") @client_context.require_transactions - def test_callback_not_retried_after_timeout(self): + def test_3_1_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -501,7 +503,7 @@ def callback(session): @client_context.require_test_commands @client_context.require_transactions - def test_callback_not_retried_after_commit_timeout(self): + def test_3_2_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -533,7 +535,7 @@ def callback(session): @client_context.require_test_commands @client_context.require_transactions - def test_commit_not_retried_after_timeout(self): + def test_3_3_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -601,6 +603,68 @@ def callback(session): s.with_transaction(callback) self.assertFalse(s.in_transaction) + @client_context.require_test_commands + @client_context.require_transactions + def test_4_retry_backoff_is_enforced(self): + client = client_context.client + coll = client[self.db.name].test + # patch random to make it deterministic -- once to effectively have + # no backoff and the second time with "max" backoff (always waiting the longest + # possible time) + _original_random_random = random.random + + def always_one(): + return 1 + + def always_zero(): + return 0 + + random.random = always_zero + # set fail point to trigger transaction failure and trigger backoff + self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 13}, + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}) + + def callback(session): + coll.insert_one({}, session=session) + + start = time.monotonic() + with self.client.start_session() as s: + s.with_transaction(callback) + end = time.monotonic() + no_backoff_time = end - start + + random.random = always_one + # set fail point to trigger transaction failure and trigger backoff + self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": { + "times": 13 + }, # sufficiently high enough such that the time effect of backoff is noticeable + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}) + start = time.monotonic() + with self.client.start_session() as s: + s.with_transaction(callback) + end = time.monotonic() + self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2 + + random.random = _original_random_random + class TestOptionsInsideTransactionProse(TransactionsBase): @client_context.require_transactions