Skip to content

Commit 88f05fa

Browse files
sleepyStickShaneHarveyblink1073NoahStappdependabot[bot]
authored
PYTHON-5518: withTransaction API retries too frequently (#2600)
Co-authored-by: Shane Harvey <shnhrv@gmail.com> Co-authored-by: Steven Silvester <steven.silvester@ieee.org> Co-authored-by: Noah Stapp <noah.stapp@mongodb.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
1 parent d767759 commit 88f05fa

File tree

4 files changed

+168
-16
lines changed

4 files changed

+168
-16
lines changed

pymongo/asynchronous/client_session.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,15 +472,20 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
472472
# This limit is non-configurable and was chosen to be twice the 60 second
473473
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
474474
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
475-
_BACKOFF_MAX = 1
476-
_BACKOFF_INITIAL = 0.050 # 50ms initial backoff
475+
_BACKOFF_MAX = 0.500 # 500ms max backoff
476+
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
477477

478478

479479
def _within_time_limit(start_time: float) -> bool:
480480
"""Are we within the with_transaction retry limit?"""
481481
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
482482

483483

484+
def _would_exceed_time_limit(start_time: float, backoff: float) -> bool:
485+
"""Is the backoff within the with_transaction retry limit?"""
486+
return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT
487+
488+
484489
_T = TypeVar("_T")
485490

486491
if TYPE_CHECKING:
@@ -708,10 +713,14 @@ async def callback(session, custom_arg, custom_kwarg=None):
708713
"""
709714
start_time = time.monotonic()
710715
retry = 0
716+
last_error: Optional[BaseException] = None
711717
while True:
712718
if retry: # Implement exponential backoff on retry.
713719
jitter = random.random() # noqa: S311
714-
backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX)
720+
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
721+
if _would_exceed_time_limit(start_time, backoff):
722+
assert last_error is not None
723+
raise last_error
715724
await asyncio.sleep(backoff)
716725
retry += 1
717726
await self.start_transaction(
@@ -721,6 +730,7 @@ async def callback(session, custom_arg, custom_kwarg=None):
721730
ret = await callback(self)
722731
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
723732
except BaseException as exc:
733+
last_error = exc
724734
if self.in_transaction:
725735
await self.abort_transaction()
726736
if (

pymongo/synchronous/client_session.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,15 +470,20 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
470470
# This limit is non-configurable and was chosen to be twice the 60 second
471471
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
472472
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
473-
_BACKOFF_MAX = 1
474-
_BACKOFF_INITIAL = 0.050 # 50ms initial backoff
473+
_BACKOFF_MAX = 0.500 # 500ms max backoff
474+
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
475475

476476

477477
def _within_time_limit(start_time: float) -> bool:
478478
"""Are we within the with_transaction retry limit?"""
479479
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
480480

481481

482+
def _would_exceed_time_limit(start_time: float, backoff: float) -> bool:
483+
"""Is the backoff within the with_transaction retry limit?"""
484+
return time.monotonic() + backoff - start_time >= _WITH_TRANSACTION_RETRY_TIME_LIMIT
485+
486+
482487
_T = TypeVar("_T")
483488

484489
if TYPE_CHECKING:
@@ -706,17 +711,22 @@ def callback(session, custom_arg, custom_kwarg=None):
706711
"""
707712
start_time = time.monotonic()
708713
retry = 0
714+
last_error: Optional[BaseException] = None
709715
while True:
710716
if retry: # Implement exponential backoff on retry.
711717
jitter = random.random() # noqa: S311
712-
backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX)
718+
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
719+
if _would_exceed_time_limit(start_time, backoff):
720+
assert last_error is not None
721+
raise last_error
713722
time.sleep(backoff)
714723
retry += 1
715724
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
716725
try:
717726
ret = callback(self)
718727
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
719728
except BaseException as exc:
729+
last_error = exc
720730
if self.in_transaction:
721731
self.abort_transaction()
722732
if (

test/asynchronous/test_transactions.py

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from __future__ import annotations
1717

1818
import asyncio
19+
import random
1920
import sys
21+
import time
2022
from io import BytesIO
2123
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
2224

@@ -441,7 +443,7 @@ async def set_fail_point(self, command_args):
441443
await self.configure_fail_point(client, command_args)
442444

443445
@async_client_context.require_transactions
444-
async def test_callback_raises_custom_error(self):
446+
async def test_1_callback_raises_custom_error(self):
445447
class _MyException(Exception):
446448
pass
447449

@@ -453,7 +455,7 @@ async def raise_error(_):
453455
await s.with_transaction(raise_error)
454456

455457
@async_client_context.require_transactions
456-
async def test_callback_returns_value(self):
458+
async def test_2_callback_returns_value(self):
457459
async def callback(_):
458460
return "Foo"
459461

@@ -481,7 +483,7 @@ def callback(_):
481483
self.assertEqual(await s.with_transaction(callback), "Foo")
482484

483485
@async_client_context.require_transactions
484-
async def test_callback_not_retried_after_timeout(self):
486+
async def test_3_1_callback_not_retried_after_timeout(self):
485487
listener = OvertCommandListener()
486488
client = await self.async_rs_client(event_listeners=[listener])
487489
coll = client[self.db.name].test
@@ -509,7 +511,7 @@ async def callback(session):
509511

510512
@async_client_context.require_test_commands
511513
@async_client_context.require_transactions
512-
async def test_callback_not_retried_after_commit_timeout(self):
514+
async def test_3_2_callback_not_retried_after_commit_timeout(self):
513515
listener = OvertCommandListener()
514516
client = await self.async_rs_client(event_listeners=[listener])
515517
coll = client[self.db.name].test
@@ -543,7 +545,7 @@ async def callback(session):
543545

544546
@async_client_context.require_test_commands
545547
@async_client_context.require_transactions
546-
async def test_commit_not_retried_after_timeout(self):
548+
async def test_3_3_commit_not_retried_after_timeout(self):
547549
listener = OvertCommandListener()
548550
client = await self.async_rs_client(event_listeners=[listener])
549551
coll = client[self.db.name].test
@@ -613,6 +615,72 @@ async def callback(session):
613615
await s.with_transaction(callback)
614616
self.assertFalse(s.in_transaction)
615617

618+
@async_client_context.require_test_commands
619+
@async_client_context.require_transactions
620+
async def test_4_retry_backoff_is_enforced(self):
621+
client = async_client_context.client
622+
coll = client[self.db.name].test
623+
# patch random to make it deterministic -- once to effectively have
624+
# no backoff and the second time with "max" backoff (always waiting the longest
625+
# possible time)
626+
_original_random_random = random.random
627+
628+
def always_one():
629+
return 1
630+
631+
def always_zero():
632+
return 0
633+
634+
random.random = always_zero
635+
# set fail point to trigger transaction failure and trigger backoff
636+
await self.set_fail_point(
637+
{
638+
"configureFailPoint": "failCommand",
639+
"mode": {"times": 13},
640+
"data": {
641+
"failCommands": ["commitTransaction"],
642+
"errorCode": 251,
643+
},
644+
}
645+
)
646+
self.addAsyncCleanup(
647+
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
648+
)
649+
650+
async def callback(session):
651+
await coll.insert_one({}, session=session)
652+
653+
start = time.monotonic()
654+
async with self.client.start_session() as s:
655+
await s.with_transaction(callback)
656+
end = time.monotonic()
657+
no_backoff_time = end - start
658+
659+
random.random = always_one
660+
# set fail point to trigger transaction failure and trigger backoff
661+
await self.set_fail_point(
662+
{
663+
"configureFailPoint": "failCommand",
664+
"mode": {
665+
"times": 13
666+
}, # sufficiently high enough such that the time effect of backoff is noticeable
667+
"data": {
668+
"failCommands": ["commitTransaction"],
669+
"errorCode": 251,
670+
},
671+
}
672+
)
673+
self.addAsyncCleanup(
674+
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
675+
)
676+
start = time.monotonic()
677+
async with self.client.start_session() as s:
678+
await s.with_transaction(callback)
679+
end = time.monotonic()
680+
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
681+
682+
random.random = _original_random_random
683+
616684

617685
class TestOptionsInsideTransactionProse(AsyncTransactionsBase):
618686
@async_client_context.require_transactions

test/test_transactions.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from __future__ import annotations
1717

1818
import asyncio
19+
import random
1920
import sys
21+
import time
2022
from io import BytesIO
2123
from test.utils_spec_runner import SpecRunner
2224

@@ -433,7 +435,7 @@ def set_fail_point(self, command_args):
433435
self.configure_fail_point(client, command_args)
434436

435437
@client_context.require_transactions
436-
def test_callback_raises_custom_error(self):
438+
def test_1_callback_raises_custom_error(self):
437439
class _MyException(Exception):
438440
pass
439441

@@ -445,7 +447,7 @@ def raise_error(_):
445447
s.with_transaction(raise_error)
446448

447449
@client_context.require_transactions
448-
def test_callback_returns_value(self):
450+
def test_2_callback_returns_value(self):
449451
def callback(_):
450452
return "Foo"
451453

@@ -473,7 +475,7 @@ def callback(_):
473475
self.assertEqual(s.with_transaction(callback), "Foo")
474476

475477
@client_context.require_transactions
476-
def test_callback_not_retried_after_timeout(self):
478+
def test_3_1_callback_not_retried_after_timeout(self):
477479
listener = OvertCommandListener()
478480
client = self.rs_client(event_listeners=[listener])
479481
coll = client[self.db.name].test
@@ -501,7 +503,7 @@ def callback(session):
501503

502504
@client_context.require_test_commands
503505
@client_context.require_transactions
504-
def test_callback_not_retried_after_commit_timeout(self):
506+
def test_3_2_callback_not_retried_after_commit_timeout(self):
505507
listener = OvertCommandListener()
506508
client = self.rs_client(event_listeners=[listener])
507509
coll = client[self.db.name].test
@@ -533,7 +535,7 @@ def callback(session):
533535

534536
@client_context.require_test_commands
535537
@client_context.require_transactions
536-
def test_commit_not_retried_after_timeout(self):
538+
def test_3_3_commit_not_retried_after_timeout(self):
537539
listener = OvertCommandListener()
538540
client = self.rs_client(event_listeners=[listener])
539541
coll = client[self.db.name].test
@@ -601,6 +603,68 @@ def callback(session):
601603
s.with_transaction(callback)
602604
self.assertFalse(s.in_transaction)
603605

606+
@client_context.require_test_commands
607+
@client_context.require_transactions
608+
def test_4_retry_backoff_is_enforced(self):
609+
client = client_context.client
610+
coll = client[self.db.name].test
611+
# patch random to make it deterministic -- once to effectively have
612+
# no backoff and the second time with "max" backoff (always waiting the longest
613+
# possible time)
614+
_original_random_random = random.random
615+
616+
def always_one():
617+
return 1
618+
619+
def always_zero():
620+
return 0
621+
622+
random.random = always_zero
623+
# set fail point to trigger transaction failure and trigger backoff
624+
self.set_fail_point(
625+
{
626+
"configureFailPoint": "failCommand",
627+
"mode": {"times": 13},
628+
"data": {
629+
"failCommands": ["commitTransaction"],
630+
"errorCode": 251,
631+
},
632+
}
633+
)
634+
self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"})
635+
636+
def callback(session):
637+
coll.insert_one({}, session=session)
638+
639+
start = time.monotonic()
640+
with self.client.start_session() as s:
641+
s.with_transaction(callback)
642+
end = time.monotonic()
643+
no_backoff_time = end - start
644+
645+
random.random = always_one
646+
# set fail point to trigger transaction failure and trigger backoff
647+
self.set_fail_point(
648+
{
649+
"configureFailPoint": "failCommand",
650+
"mode": {
651+
"times": 13
652+
}, # sufficiently high enough such that the time effect of backoff is noticeable
653+
"data": {
654+
"failCommands": ["commitTransaction"],
655+
"errorCode": 251,
656+
},
657+
}
658+
)
659+
self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"})
660+
start = time.monotonic()
661+
with self.client.start_session() as s:
662+
s.with_transaction(callback)
663+
end = time.monotonic()
664+
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
665+
666+
random.random = _original_random_random
667+
604668

605669
class TestOptionsInsideTransactionProse(TransactionsBase):
606670
@client_context.require_transactions

0 commit comments

Comments
 (0)