From defd6aa1ebb47fc3ca5559ac54bf11f295f3e394 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Wed, 7 Jan 2026 13:15:35 +0530 Subject: [PATCH] feat: add requestID info in error exceptions --- google/cloud/spanner_v1/__init__.py | 3 + google/cloud/spanner_v1/_helpers.py | 72 +++++++++++++- google/cloud/spanner_v1/batch.py | 22 +++-- google/cloud/spanner_v1/database.py | 90 +++++++++++++++--- google/cloud/spanner_v1/exceptions.py | 42 ++++++++ google/cloud/spanner_v1/pool.py | 36 +++---- google/cloud/spanner_v1/request_id_header.py | 13 +++ google/cloud/spanner_v1/session.py | 95 ++++++++++--------- google/cloud/spanner_v1/snapshot.py | 72 +++++++++----- google/cloud/spanner_v1/transaction.py | 76 ++++++++------- noxfile.py | 1 + .../test_aborted_transaction.py | 39 +++++--- .../test_dbapi_isolation_level.py | 1 + tests/system/test_observability_options.py | 31 +++--- tests/unit/test_batch.py | 30 +++++- tests/unit/test_database.py | 21 +++- tests/unit/test_database_session_manager.py | 10 +- tests/unit/test_exceptions.py | 65 +++++++++++++ tests/unit/test_pool.py | 15 +++ tests/unit/test_session.py | 74 ++++++++++++--- tests/unit/test_snapshot.py | 39 +++++++- tests/unit/test_spanner.py | 27 ++++++ tests/unit/test_transaction.py | 15 +++ 23 files changed, 697 insertions(+), 192 deletions(-) create mode 100644 google/cloud/spanner_v1/exceptions.py create mode 100644 tests/unit/test_exceptions.py diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 48b11d9342..4f77269bb2 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -65,6 +65,7 @@ from .types.type import TypeCode from .data_types import JsonObject, Interval from .transaction import BatchTransactionId, DefaultTransactionOptions +from .exceptions import wrap_with_request_id from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1.client import Client @@ -88,6 +89,8 @@ # google.cloud.spanner_v1 "__version__", 
"param_types", + # google.cloud.spanner_v1.exceptions + "wrap_with_request_id", # google.cloud.spanner_v1.client "Client", # google.cloud.spanner_v1.keyset diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 8a200fe812..c5e11dd2bb 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -22,6 +22,7 @@ import threading import logging import uuid +from contextlib import contextmanager from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value @@ -34,8 +35,12 @@ from google.cloud.spanner_v1.types import ExecuteSqlRequest from google.cloud.spanner_v1.types import TransactionOptions from google.cloud.spanner_v1.data_types import JsonObject, Interval -from google.cloud.spanner_v1.request_id_header import with_request_id +from google.cloud.spanner_v1.request_id_header import ( + with_request_id, + with_request_id_metadata_only, +) from google.cloud.spanner_v1.types import TypeCode +from google.cloud.spanner_v1.exceptions import wrap_with_request_id from google.rpc.error_details_pb2 import RetryInfo @@ -612,11 +617,14 @@ def _retry( try: return func() except Exception as exc: - if ( + is_allowed = ( allowed_exceptions is None or exc.__class__ in allowed_exceptions - ) and retries < retry_count: + ) + + if is_allowed and retries < retry_count: if ( allowed_exceptions is not None + and exc.__class__ in allowed_exceptions and allowed_exceptions[exc.__class__] is not None ): allowed_exceptions[exc.__class__](exc) @@ -767,9 +775,67 @@ def reset(self): def _metadata_with_request_id(*args, **kwargs): + """Return metadata with request ID header. + + This function returns only the metadata list (not a tuple), + maintaining backward compatibility with existing code. 
+ + Args: + *args: Arguments to pass to with_request_id_metadata_only + **kwargs: Keyword arguments to pass to with_request_id_metadata_only + + Returns: + list: gRPC metadata with request ID header + """ + return with_request_id_metadata_only(*args, **kwargs) + + +def _metadata_with_request_id_and_req_id(*args, **kwargs): + """Return both metadata and request ID string. + + This is used when we need to augment errors with the request ID. + + Args: + *args: Arguments to pass to with_request_id + **kwargs: Keyword arguments to pass to with_request_id + + Returns: + tuple: (metadata, request_id) + """ return with_request_id(*args, **kwargs) +def _augment_error_with_request_id(error, request_id=None): + """Augment an error with request ID information. + + Args: + error: The error to augment (typically GoogleAPICallError) + request_id (str): The request ID to include + + Returns: + The augmented error with request ID information + """ + return wrap_with_request_id(error, request_id) + + +@contextmanager +def _augment_errors_with_request_id(request_id): + """Context manager to augment exceptions with request ID. + + Args: + request_id (str): The request ID to include in exceptions + + Yields: + None + """ + try: + yield + except Exception as exc: + augmented = _augment_error_with_request_id(exc, request_id) + # The same exception object is re-raised, so its type is preserved + raise augmented from exc + + def _merge_Transaction_Options( defaultTransactionOptions: TransactionOptions, mergeTransactionOptions: TransactionOptions, diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 0792e600dc..e70d214783 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -252,20 +252,22 @@ def wrapped_method(): max_commit_delay=max_commit_delay, request_options=request_options, ) + # This code is retried due to ABORTED, hence nth_request + # should be increased. attempt can only be increased if + # we encounter UNAVAILABLE or INTERNAL.
+ call_metadata, error_augmenter = database.with_error_augmentation( + getattr(database, "_next_nth_request", 0), + 1, + metadata, + span, + ) commit_method = functools.partial( api.commit, request=commit_request, - metadata=database.metadata_with_request_id( - # This code is retried due to ABORTED, hence nth_request - # should be increased. attempt can only be increased if - # we encounter UNAVAILABLE or INTERNAL. - getattr(database, "_next_nth_request", 0), - 1, - metadata, - span, - ), + metadata=call_metadata, ) - return commit_method() + with error_augmenter: + return commit_method() response = _retry_on_aborted_exception( wrapped_method, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 33c442602c..4977a4abb9 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -25,7 +25,6 @@ import google.auth.credentials from google.api_core.retry import Retry -from google.api_core.retry import if_exception_type from google.cloud.exceptions import NotFound from google.api_core.exceptions import Aborted from google.api_core import gapic_v1 @@ -55,6 +54,8 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _metadata_with_request_id, + _augment_errors_with_request_id, + _metadata_with_request_id_and_req_id, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -496,6 +497,66 @@ def metadata_with_request_id( span, ) + def metadata_and_request_id( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + """Return metadata and request ID string. + + This method returns both the gRPC metadata with request ID header + and the request ID string itself, which can be used to augment errors. 
+ + Args: + nth_request: The request sequence number + nth_attempt: The attempt number (for retries) + prior_metadata: Prior metadata to include + span: Optional span for tracing + + Returns: + tuple: (metadata_list, request_id_string) + """ + if span is None: + span = get_current_span() + + return _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + """Return request metadata and an error-augmenting context manager. + + The metadata carries the request ID header; the context manager, when + entered around the gRPC call, augments raised errors with that request ID. + + Args: + nth_request: The request sequence number + nth_attempt: The attempt number (for retries) + prior_metadata: Prior metadata to include + span: Optional span for tracing + + Returns: + tuple: (metadata_list, context_manager) + """ + if span is None: + span = get_current_span() + + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + + return metadata, _augment_errors_with_request_id(request_id) + def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented @@ -783,16 +844,18 @@ def execute_pdml(): try: add_span_event(span, "Starting BeginTransaction") - txn = api.begin_transaction( - session=session.name, - options=txn_options, - metadata=self.metadata_with_request_id( - self._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = self.with_error_augmentation( + self._next_nth_request, + 1, + metadata, + span, ) + with error_augmenter: + txn = api.begin_transaction( + session=session.name, + options=txn_options, + metadata=call_metadata, + ) txn_selector = TransactionSelector(id=txn.id) @@ -2060,5 +2123,10 @@ def _retry_on_aborted(func, retry_config): :type retry_config: 
Retry :param retry_config: retry object with the settings to be used """ - retry = retry_config.with_predicate(if_exception_type(Aborted)) + + def _is_aborted(exc): + """Check if exception is Aborted.""" + return isinstance(exc, Aborted) + + retry = retry_config.with_predicate(_is_aborted) return retry(func) diff --git a/google/cloud/spanner_v1/exceptions.py b/google/cloud/spanner_v1/exceptions.py new file mode 100644 index 0000000000..361079b4f2 --- /dev/null +++ b/google/cloud/spanner_v1/exceptions.py @@ -0,0 +1,42 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Spanner exception utilities with request ID support.""" + +from google.api_core.exceptions import GoogleAPICallError + + +def wrap_with_request_id(error, request_id=None): + """Add request ID information to a GoogleAPICallError. + + This function adds request_id as an attribute to the exception, + preserving the original exception type for exception handling compatibility. + The request_id is also appended to the error message so it appears in logs. + + Args: + error: The error to augment. If not a GoogleAPICallError, returns as-is + request_id (str): The request ID to include + + Returns: + The original error with request_id attribute added and message updated + (if GoogleAPICallError and request_id is provided), otherwise returns + the original error unchanged. 
+ """ + if isinstance(error, GoogleAPICallError) and request_id: + # Add request_id as an attribute for programmatic access + error.request_id = request_id + # Modify the message to include request_id so it appears in logs + if hasattr(error, "message") and error.message: + error.message = f"{error.message}, request_id = {request_id}" + return error diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index a75c13cb7a..348a01e940 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -259,15 +259,17 @@ def bind(self, database): f"Creating {request.session_count} sessions", span_event_attributes, ) - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + database._next_nth_request, + 1, + metadata, + span, ) + with error_augmenter: + resp = api.batch_create_sessions( + request=request, + metadata=call_metadata, + ) add_span_event( span, @@ -570,15 +572,17 @@ def bind(self, database): ) as span, MetricsCapture(): returned_session_count = 0 while returned_session_count < self.size: - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + database._next_nth_request, + 1, + metadata, + span, ) + with error_augmenter: + resp = api.batch_create_sessions( + request=request, + metadata=call_metadata, + ) add_span_event( span, diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py index 95c25b94f7..fb84e56100 100644 --- a/google/cloud/spanner_v1/request_id_header.py +++ b/google/cloud/spanner_v1/request_id_header.py @@ -43,6 +43,19 @@ def with_request_id( all_metadata = (other_metadata or []).copy() 
all_metadata.append((REQ_ID_HEADER_KEY, req_id)) + if span: + span.set_attribute(X_GOOG_SPANNER_REQUEST_ID_SPAN_ATTR, req_id) + + return all_metadata, req_id + + +def with_request_id_metadata_only( + client_id, channel_id, nth_request, attempt, other_metadata=[], span=None +): + req_id = build_request_id(client_id, channel_id, nth_request, attempt) + all_metadata = (other_metadata or []).copy() + all_metadata.append((REQ_ID_HEADER_KEY, req_id)) + if span: span.set_attribute(X_GOOG_SPANNER_REQUEST_ID_SPAN_ATTR, req_id) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 4c29014e15..e7bc913c27 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -25,13 +25,13 @@ from google.api_core.gapic_v1 import method from google.cloud.spanner_v1._helpers import _delay_until_retry from google.cloud.spanner_v1._helpers import _get_retry_delay - -from google.cloud.spanner_v1 import ExecuteSqlRequest -from google.cloud.spanner_v1 import CreateSessionRequest from google.cloud.spanner_v1._helpers import ( _metadata_with_prefix, _metadata_with_leader_aware_routing, ) + +from google.cloud.spanner_v1 import ExecuteSqlRequest +from google.cloud.spanner_v1 import CreateSessionRequest from google.cloud.spanner_v1._opentelemetry_tracing import ( add_span_event, get_current_span, @@ -185,6 +185,7 @@ def create(self): if self._is_multiplexed else "CloudSpanner.CreateSession" ) + nth_request = database._next_nth_request with trace_call( span_name, self, @@ -192,15 +193,14 @@ def create(self): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): - session_pb = api.create_session( - request=create_session_request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span ) + with error_augmenter: + session_pb = api.create_session( + 
request=create_session_request, + metadata=call_metadata, + ) self._session_id = session_pb.name.split("/")[-1] def exists(self): @@ -235,26 +235,26 @@ def exists(self): ) observability_options = getattr(self._database, "observability_options", None) + nth_request = database._next_nth_request with trace_call( "CloudSpanner.GetSession", self, observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): - try: - api.get_session( - name=self.name, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), - ) - span.set_attribute("session_found", True) - except NotFound: - span.set_attribute("session_found", False) - return False + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span + ) + with error_augmenter: + try: + api.get_session( + name=self.name, + metadata=call_metadata, + ) + span.set_attribute("session_found", True) + except NotFound: + span.set_attribute("session_found", False) + return False return True @@ -288,6 +288,7 @@ def delete(self): api = database.spanner_api metadata = _metadata_with_prefix(database.name) observability_options = getattr(self._database, "observability_options", None) + nth_request = database._next_nth_request with trace_call( "CloudSpanner.DeleteSession", self, @@ -298,15 +299,14 @@ def delete(self): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): - api.delete_session( - name=self.name, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span ) + with error_augmenter: + api.delete_session( + name=self.name, + metadata=call_metadata, + ) def ping(self): """Ping the session to keep it alive by executing "SELECT 1". 
@@ -318,18 +318,19 @@ def ping(self): database = self._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) + nth_request = database._next_nth_request with trace_call("CloudSpanner.Session.ping", self) as span: - request = ExecuteSqlRequest(session=self.name, sql="SELECT 1") - api.execute_sql( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - _metadata_with_prefix(database.name), - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span ) + with error_augmenter: + request = ExecuteSqlRequest(session=self.name, sql="SELECT 1") + api.execute_sql( + request=request, + metadata=call_metadata, + ) def snapshot(self, **kw): """Create a snapshot to perform a set of reads with shared staleness. @@ -585,7 +586,10 @@ def run_in_transaction(self, func, *args, **kw): attributes, ) _delay_until_retry( - exc, deadline, attempts, default_retry_delay=default_retry_delay + exc, + deadline, + attempts, + default_retry_delay=default_retry_delay, ) continue @@ -628,7 +632,10 @@ def run_in_transaction(self, func, *args, **kw): attributes, ) _delay_until_retry( - exc, deadline, attempts, default_retry_delay=default_retry_delay + exc, + deadline, + attempts, + default_retry_delay=default_retry_delay, ) except GoogleAPICallError: diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 89cbc9fe88..af34f3ac2f 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -47,6 +47,7 @@ _check_rst_stream_error, _SessionWrapper, AtomicCounter, + _augment_error_with_request_id, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call, add_span_event from google.cloud.spanner_v1.streamed import StreamedResultSet @@ -103,6 +104,7 @@ def _restart_on_unavailable( iterator = None attempt = 1 nth_request = getattr(request_id_manager, "_next_nth_request", 0) + current_request_id 
= None while True: try: @@ -115,14 +117,18 @@ def _restart_on_unavailable( observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): + ( + call_metadata, + current_request_id, + ) = request_id_manager.metadata_and_request_id( + nth_request, + attempt, + metadata, + span, + ) iterator = method( request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), + metadata=call_metadata, ) # Add items from iterator to buffer. @@ -158,14 +164,18 @@ def _restart_on_unavailable( transaction_selector = transaction._build_transaction_selector_pb() request.transaction = transaction_selector attempt += 1 + ( + call_metadata, + current_request_id, + ) = request_id_manager.metadata_and_request_id( + nth_request, + attempt, + metadata, + span, + ) iterator = method( request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), + metadata=call_metadata, ) continue @@ -175,7 +185,7 @@ def _restart_on_unavailable( for resumable_message in _STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES ) if not resumable_error: - raise + raise _augment_error_with_request_id(exc, current_request_id) del item_buffer[:] with trace_call( trace_name, @@ -189,17 +199,25 @@ def _restart_on_unavailable( transaction_selector = transaction._build_transaction_selector_pb() attempt += 1 request.transaction = transaction_selector + ( + call_metadata, + current_request_id, + ) = request_id_manager.metadata_and_request_id( + nth_request, + attempt, + metadata, + span, + ) iterator = method( request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), + metadata=call_metadata, ) continue + except Exception as exc: + # Augment any other exception with the request ID + raise _augment_error_with_request_id(exc, current_request_id) + if len(item_buffer) == 0: break @@ -961,17 +979,19 @@ def wrapped_method(): 
begin_transaction_request = BeginTransactionRequest( **begin_request_kwargs ) + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + attempt.increment(), + metadata, + span, + ) begin_transaction_method = functools.partial( api.begin_transaction, request=begin_transaction_request, - metadata=database.metadata_with_request_id( - nth_request, - attempt.increment(), - metadata, - span, - ), + metadata=call_metadata, ) - return begin_transaction_method() + with error_augmenter: + return begin_transaction_method() def before_next_retry(nth_retry, delay_in_seconds): add_span_event( diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index de8b421840..413ac0af1f 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -185,18 +185,20 @@ def rollback(self) -> None: def wrapped_method(*args, **kwargs): attempt.increment() + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + attempt.value, + metadata, + span, + ) rollback_method = functools.partial( api.rollback, session=session.name, transaction_id=self._transaction_id, - metadata=database.metadata_with_request_id( - nth_request, - attempt.value, - metadata, - span, - ), + metadata=call_metadata, ) - return rollback_method(*args, **kwargs) + with error_augmenter: + return rollback_method(*args, **kwargs) _retry( wrapped_method, @@ -298,17 +300,19 @@ def wrapped_method(*args, **kwargs): if is_multiplexed and self._precommit_token is not None: commit_request_args["precommit_token"] = self._precommit_token + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + attempt.value, + metadata, + span, + ) commit_method = functools.partial( api.commit, request=CommitRequest(**commit_request_args), - metadata=database.metadata_with_request_id( - nth_request, - attempt.value, - metadata, - span, - ), + metadata=call_metadata, ) - return commit_method(*args, 
**kwargs) + with error_augmenter: + return commit_method(*args, **kwargs) commit_retry_event_name = "Transaction Commit Attempt Failed. Retrying" @@ -335,18 +339,20 @@ def before_next_retry(nth_retry, delay_in_seconds): if commit_response_pb._pb.HasField("precommit_token"): add_span_event(span, commit_retry_event_name) nth_request = database._next_nth_request - commit_response_pb = api.commit( - request=CommitRequest( - precommit_token=commit_response_pb.precommit_token, - **common_commit_request_args, - ), - metadata=database.metadata_with_request_id( - nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + 1, + metadata, + span, ) + with error_augmenter: + commit_response_pb = api.commit( + request=CommitRequest( + precommit_token=commit_response_pb.precommit_token, + **common_commit_request_args, + ), + metadata=call_metadata, + ) add_span_event(span, "Commit Done") @@ -510,16 +516,18 @@ def execute_update( def wrapped_method(*args, **kwargs): attempt.increment() + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, attempt.value, metadata + ) execute_sql_method = functools.partial( api.execute_sql, request=execute_sql_request, - metadata=database.metadata_with_request_id( - nth_request, attempt.value, metadata - ), + metadata=call_metadata, retry=retry, timeout=timeout, ) - return execute_sql_method(*args, **kwargs) + with error_augmenter: + return execute_sql_method(*args, **kwargs) result_set_pb: ResultSet = self._execute_request( wrapped_method, @@ -658,16 +666,18 @@ def batch_update( def wrapped_method(*args, **kwargs): attempt.increment() + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, attempt.value, metadata + ) execute_batch_dml_method = functools.partial( api.execute_batch_dml, request=execute_batch_dml_request, - metadata=database.metadata_with_request_id( - nth_request, attempt.value, metadata - ), + 
metadata=call_metadata, retry=retry, timeout=timeout, ) - return execute_batch_dml_method(*args, **kwargs) + with error_augmenter: + return execute_batch_dml_method(*args, **kwargs) response_pb: ExecuteBatchDmlResponse = self._execute_request( wrapped_method, diff --git a/noxfile.py b/noxfile.py index e85fba3c54..2cd172c587 100644 --- a/noxfile.py +++ b/noxfile.py @@ -558,6 +558,7 @@ def prerelease_deps(session, protobuf_implementation, database_dialect): # dependency of google-auth "cffi", "cryptography", + "cachetools", ] for dep in prerel_deps: diff --git a/tests/mockserver_tests/test_aborted_transaction.py b/tests/mockserver_tests/test_aborted_transaction.py index a1f9f1ba1e..7963538c59 100644 --- a/tests/mockserver_tests/test_aborted_transaction.py +++ b/tests/mockserver_tests/test_aborted_transaction.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import random - from google.cloud.spanner_v1 import ( BeginTransactionRequest, CommitRequest, @@ -33,8 +31,19 @@ from test_utils import retry from google.cloud.spanner_v1.database_sessions_manager import TransactionType + +def _is_aborted_error(exc): + """Check if exception is Aborted.""" + return isinstance(exc, exceptions.Aborted) + + +# Retry on Aborted exceptions retry_maybe_aborted_txn = retry.RetryErrors( - exceptions.Aborted, max_tries=5, delay=0, backoff=1 + exceptions.Aborted, + error_predicate=_is_aborted_error, + max_tries=5, + delay=0, + backoff=1, ) @@ -119,17 +128,21 @@ def test_batch_commit_aborted(self): TransactionType.READ_WRITE, ) - @retry_maybe_aborted_txn def test_retry_helper(self): - # Randomly add an Aborted error for the Commit method on the mock server. 
- if random.random() < 0.5: - add_error(SpannerServicer.Commit.__name__, aborted_status()) - session = self.database.session() - session.create() - transaction = session.transaction() - transaction.begin() - transaction.insert("my_table", ["col1, col2"], [{"col1": 1, "col2": "One"}]) - transaction.commit() + # Add an Aborted error for the Commit method on the mock server. + # The error is popped after the first use, so the retry will succeed. + add_error(SpannerServicer.Commit.__name__, aborted_status()) + + @retry_maybe_aborted_txn + def do_commit(): + session = self.database.session() + session.create() + transaction = session.transaction() + transaction.begin() + transaction.insert("my_table", ["col1, col2"], [{"col1": 1, "col2": "One"}]) + transaction.commit() + + do_commit() def _insert_mutations(transaction: Transaction): diff --git a/tests/mockserver_tests/test_dbapi_isolation_level.py b/tests/mockserver_tests/test_dbapi_isolation_level.py index 679740969a..e912914b19 100644 --- a/tests/mockserver_tests/test_dbapi_isolation_level.py +++ b/tests/mockserver_tests/test_dbapi_isolation_level.py @@ -146,5 +146,6 @@ def test_begin_isolation_level(self): def test_begin_invalid_isolation_level(self): connection = Connection(self.instance, self.database) with connection.cursor() as cursor: + # The Unknown exception has request_id attribute added with self.assertRaises(Unknown): cursor.execute("begin isolation level does_not_exist") diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index 8ebcffcb7f..48a8c8b2ed 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -530,20 +530,23 @@ def test_database_partitioned_error(): if multiplexed_enabled else "CloudSpanner.CreateSession" ) - want_statuses = [ - ( - "CloudSpanner.Database.execute_partitioned_pdml", - codes.ERROR, - "InvalidArgument: 400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' 
WHERE id > 1\n ^", - ), - (expected_session_span_name, codes.OK, None), - ( - "CloudSpanner.ExecuteStreamingSql", - codes.ERROR, - "InvalidArgument: 400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^", - ), - ] - assert got_statuses == want_statuses + expected_error_prefix = "InvalidArgument: 400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^" + + # Check the statuses - error messages may include request_id suffix + assert len(got_statuses) == 3 + + # First status: execute_partitioned_pdml with error + assert got_statuses[0][0] == "CloudSpanner.Database.execute_partitioned_pdml" + assert got_statuses[0][1] == codes.ERROR + assert got_statuses[0][2].startswith(expected_error_prefix) + + # Second status: session creation OK + assert got_statuses[1] == (expected_session_span_name, codes.OK, None) + + # Third status: ExecuteStreamingSql with error + assert got_statuses[2][0] == "CloudSpanner.ExecuteStreamingSql" + assert got_statuses[2][1] == codes.ERROR + assert got_statuses[2][2].startswith(expected_error_prefix) def _make_credentials(): diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index e8297030eb..d75792fe07 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -41,6 +41,8 @@ from google.cloud.spanner_v1._helpers import ( AtomicCounter, _metadata_with_request_id, + _augment_errors_with_request_id, + _metadata_with_request_id_and_req_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID @@ -213,9 +215,13 @@ def test_commit_grpc_error(self, mock_region): batch = self._make_one(session) batch.delete(TABLE_NAME, keyset=keyset) - with self.assertRaises(Unknown): + # Exception has request_id attribute added + with self.assertRaises(Unknown) as context: batch.commit() + # Verify the exception has request_id attribute + self.assertTrue(hasattr(context.exception, "request_id")) + req_id = 
f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" self.assertSpanAttributes( "CloudSpanner.Batch.commit", @@ -281,7 +287,7 @@ def test_commit_ok(self, mock_region): def test_aborted_exception_on_commit_with_retries(self): # Test case to verify that an Aborted exception is raised when # batch.commit() is called and the transaction is aborted internally. - + # The exception has request_id attribute added. database = _Database() # Setup the spanner API which throws Aborted exception when calling commit API. api = database.spanner_api = _FauxSpannerAPI(_aborted_error=True) @@ -294,12 +300,13 @@ def test_aborted_exception_on_commit_with_retries(self): batch = self._make_one(session) batch.insert(TABLE_NAME, COLUMNS, VALUES) - # Assertion: Ensure that calling batch.commit() raises the Aborted exception + # Assertion: Ensure that calling batch.commit() raises Aborted with self.assertRaises(Aborted) as context: batch.commit(timeout_secs=0.1, default_retry_delay=0) - # Verify additional details about the exception - self.assertEqual(str(context.exception), "409 Transaction was aborted") + # Verify exception includes request_id attribute + self.assertIn("409 Transaction was aborted", str(context.exception)) + self.assertTrue(hasattr(context.exception, "request_id")) self.assertGreater( api.commit.call_count, 1, "commit should be called more than once" ) @@ -821,6 +828,19 @@ def metadata_with_request_id( span, ) + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + @property def _channel_id(self): return 1 diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 92001fb52c..929f0c0010 100644 --- a/tests/unit/test_database.py +++ 
b/tests/unit/test_database.py @@ -34,6 +34,8 @@ from google.cloud.spanner_v1._helpers import ( AtomicCounter, _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID from google.cloud.spanner_v1.session import Session @@ -2265,12 +2267,16 @@ def test_context_mgr_w_aborted_commit_status(self): pool.put(session) checkout = self._make_one(database, timeout_secs=0.1, default_retry_delay=0) - with self.assertRaises(Aborted): + # Exception has request_id attribute added + with self.assertRaises(Aborted) as context: with checkout as batch: self.assertIsNone(pool._session) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) + # Verify the exception has request_id attribute + self.assertTrue(hasattr(context.exception, "request_id")) + self.assertIs(pool._session, session) expected_txn_options = TransactionOptions(read_write={}) @@ -3635,6 +3641,19 @@ def metadata_with_request_id( def _channel_id(self): return 1 + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + class _Pool(object): _bound = None diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index c6156b5e8c..6c90cd62ab 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -208,16 +208,22 @@ def test_exception_bad_request(self): api = manager._database.spanner_api api.create_session.side_effect = BadRequest("") - with self.assertRaises(BadRequest): + # Exception has request_id attribute added + with self.assertRaises(BadRequest) as cm: manager.get_session(TransactionType.READ_ONLY) + # Verify the 
exception has request_id attribute + self.assertTrue(hasattr(cm.exception, "request_id")) def test_exception_failed_precondition(self): manager = self._manager api = manager._database.spanner_api api.create_session.side_effect = FailedPrecondition("") - with self.assertRaises(FailedPrecondition): + # Exception has request_id attribute added + with self.assertRaises(FailedPrecondition) as cm: manager.get_session(TransactionType.READ_ONLY) + # Verify the exception has request_id attribute + self.assertTrue(hasattr(cm.exception, "request_id")) def test__use_multiplexed_read_only(self): transaction_type = TransactionType.READ_ONLY diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py new file mode 100644 index 0000000000..802928153b --- /dev/null +++ b/tests/unit/test_exceptions.py @@ -0,0 +1,65 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Spanner exception handling with request IDs.""" + +import unittest + +from google.api_core.exceptions import Aborted +from google.cloud.spanner_v1.exceptions import wrap_with_request_id + + +class TestWrapWithRequestId(unittest.TestCase): + """Test wrap_with_request_id function.""" + + def test_wrap_with_request_id_with_google_api_error(self): + """Test adding request_id to GoogleAPICallError preserves original type.""" + error = Aborted("Transaction aborted") + request_id = "1.12345.1.0.1.1" + + result = wrap_with_request_id(error, request_id) + + # Should return the same error object (not wrapped) + self.assertIs(result, error) + # Should still be the original exception type + self.assertIsInstance(result, Aborted) + # Should have request_id attribute + self.assertEqual(result.request_id, request_id) + # String representation should include request_id + self.assertIn(request_id, str(result)) + self.assertIn("Transaction aborted", str(result)) + + def test_wrap_with_request_id_without_request_id(self): + """Test that without request_id, error is returned unchanged.""" + error = Aborted("Transaction aborted") + + result = wrap_with_request_id(error) + + self.assertIs(result, error) + self.assertFalse(hasattr(result, "request_id")) + + def test_wrap_with_request_id_with_non_google_api_error(self): + """Test that non-GoogleAPICallError is returned unchanged.""" + error = Exception("Some other error") + request_id = "1.12345.1.0.1.1" + + result = wrap_with_request_id(error, request_id) + + # Non-GoogleAPICallError should be returned unchanged + self.assertIs(result, error) + self.assertFalse(hasattr(result, "request_id")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index ec03e4350b..a44bb8c5a1 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -21,6 +21,8 @@ import mock from google.cloud.spanner_v1._helpers import ( _metadata_with_request_id, + 
_metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, AtomicCounter, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID @@ -1450,6 +1452,19 @@ def metadata_with_request_id( def _channel_id(self): return 1 + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + class _Queue(object): _size = 1 diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8026c50c24..6266274e19 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -92,7 +92,11 @@ def inject_into_mock_database(mockdb): def metadata_with_request_id( nth_request, nth_attempt, prior_metadata=[], span=None ): - nth_req = nth_request.fget(mockdb) + # Handle both cases: nth_request as an integer or as a property descriptor + if isinstance(nth_request, int): + nth_req = nth_request + else: + nth_req = nth_request.fget(mockdb) return _metadata_with_request_id( nth_client_id, channel_id, @@ -104,11 +108,45 @@ def metadata_with_request_id( setattr(mockdb, "metadata_with_request_id", metadata_with_request_id) - @property - def _next_nth_request(self): - return self._nth_request.increment() + # Create a property-like object using type() to make it work with mock + type(mockdb)._next_nth_request = property( + lambda self: self._nth_request.increment() + ) + + # Use a closure to capture nth_client_id and channel_id + def make_with_error_augmentation(db_nth_client_id, db_channel_id): + def with_error_augmentation( + nth_request, nth_attempt, prior_metadata=[], span=None + ): + """Context manager for gRPC calls with error augmentation.""" + from google.cloud.spanner_v1._helpers import ( + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, + ) + + if 
span is None: + from google.cloud.spanner_v1._opentelemetry_tracing import ( + get_current_span, + ) + + span = get_current_span() + + metadata, request_id = _metadata_with_request_id_and_req_id( + db_nth_client_id, + db_channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) - setattr(mockdb, "_next_nth_request", _next_nth_request) + return metadata, _augment_errors_with_request_id(request_id) + + return with_error_augmentation + + mockdb.with_error_augmentation = make_with_error_augmentation( + nth_client_id, channel_id + ) return mockdb @@ -443,8 +481,11 @@ def test_create_error(self, mock_region): database.spanner_api = gax_api session = self._make_one(database) - with self.assertRaises(Unknown): + # Exception has request_id attribute added + with self.assertRaises(Unknown) as cm: session.create() + # Verify the exception has request_id attribute + self.assertTrue(hasattr(cm.exception, "request_id")) req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" self.assertSpanAttributes( @@ -547,8 +588,11 @@ def test_exists_error(self, mock_region): session = self._make_one(database) session._session_id = self.SESSION_ID - with self.assertRaises(Unknown): + # Exception has request_id attribute added + with self.assertRaises(Unknown) as cm: session.exists() + # Verify the exception has request_id attribute + self.assertTrue(hasattr(cm.exception, "request_id")) req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" gax_api.get_session.assert_called_once_with( @@ -1292,8 +1336,10 @@ def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) - with self.assertRaises(Unknown): + # Exception has request_id attribute added + with self.assertRaises(Unknown) as context: session.run_in_transaction(unit_of_work) + self.assertTrue(hasattr(context.exception, "request_id")) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] @@ 
-1661,8 +1707,10 @@ def _time(_results=[1, 1.5]): with mock.patch("time.time", _time): with mock.patch("time.sleep") as sleep_mock: - with self.assertRaises(Aborted): + # Exception has request_id attribute added + with self.assertRaises(Aborted) as context: session.run_in_transaction(unit_of_work, "abc", timeout_secs=1) + self.assertTrue(hasattr(context.exception, "request_id")) sleep_mock.assert_not_called() @@ -1729,8 +1777,10 @@ def _time(_results=[1, 2, 4, 8]): with mock.patch("time.time", _time), mock.patch( "google.cloud.spanner_v1._helpers.random.random", return_value=0 ), mock.patch("time.sleep") as sleep_mock: - with self.assertRaises(Aborted): + # Exception has request_id attribute added + with self.assertRaises(Aborted) as context: session.run_in_transaction(unit_of_work, timeout_secs=8) + self.assertTrue(hasattr(context.exception, "request_id")) # unpacking call args into list call_args = [call_[0][0] for call_ in sleep_mock.call_args_list] @@ -1928,8 +1978,10 @@ def unit_of_work(txn, *args, **kw): txn.insert(TABLE_NAME, COLUMNS, VALUES) return 42 - with self.assertRaises(Unknown): + # Exception has request_id attribute added + with self.assertRaises(Unknown) as context: session.run_in_transaction(unit_of_work, "abc", some_arg="def") + self.assertTrue(hasattr(context.exception, "request_id")) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 974cc8e75e..62c1c0c40b 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -44,6 +44,8 @@ ) from google.cloud.spanner_v1._helpers import ( _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, AtomicCounter, ) from google.cloud.spanner_v1.param_types import INT64 @@ -297,8 +299,10 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): session = _Session(database) derived = _build_snapshot_derived(session) resumable = 
self._call_fut(derived, restart, request, session=session) - with self.assertRaises(InternalServerError): + # Exception has request_id attribute added + with self.assertRaises(InternalServerError) as context: list(resumable) + self.assertTrue(hasattr(context.exception, "request_id")) restart.assert_called_once_with( request=request, metadata=[ @@ -371,8 +375,10 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): session = _Session(database) derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) - with self.assertRaises(InternalServerError): + # Exception has request_id attribute added + with self.assertRaises(InternalServerError) as context: list(resumable) + self.assertTrue(hasattr(context.exception, "request_id")) restart.assert_called_once_with( request=request, metadata=[ @@ -546,8 +552,10 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): session = _Session(database) derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) - with self.assertRaises(InternalServerError): + # Exception has request_id attribute added + with self.assertRaises(InternalServerError) as context: list(resumable) + self.assertTrue(hasattr(context.exception, "request_id")) restart.assert_called_once_with( request=request, metadata=[ @@ -2168,6 +2176,31 @@ def metadata_with_request_id( span, ) + def metadata_and_request_id( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + return _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, 
_augment_errors_with_request_id(request_id) + @property def _channel_id(self): return 1 diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index d1de23d2d0..ecd7d4fd86 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -42,6 +42,8 @@ _make_value_pb, _merge_query_options, _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID import mock @@ -1319,10 +1321,35 @@ def metadata_with_request_id( span, ) + def metadata_and_request_id( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + return _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + @property def _channel_id(self): return 1 + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + class _Session(object): _transaction = None diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 510251656e..405521509f 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -35,6 +35,8 @@ from google.cloud.spanner_v1._helpers import ( AtomicCounter, _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, ) from google.cloud.spanner_v1.batch import _make_write_pb from google.cloud.spanner_v1.database import Database @@ -1420,6 +1422,19 @@ def metadata_with_request_id( span, ) + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + 
nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + @property def _channel_id(self): return 1