diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index d472bff13f..3860b9d335 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -7,11 +7,6 @@ # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations from typing import TYPE_CHECKING @@ -22,6 +17,7 @@ from .auth_schemes import AuthSchemeType from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig +from .credential_manager import CredentialManager from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger if TYPE_CHECKING: @@ -48,10 +44,14 @@ async def exchange_auth_token( self, ) -> AuthCredential: exchanger = OAuth2CredentialExchanger() - exchange_result = await exchanger.exchange( - self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme - ) - return exchange_result.credential + + # Restore secret if needed + credential = self.auth_config.exchanged_auth_credential + + with CredentialManager.restore_client_secret(credential): + res = await exchanger.exchange(credential, self.auth_config.auth_scheme) + return res.credential + async def parse_and_store_auth_response(self, state: State) -> None: @@ -183,21 +183,25 @@ def generate_auth_uri( ) scopes = list(scopes.keys()) - client = OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - ) - params = { - "access_type": "offline", - "prompt": "consent", - } - if auth_credential.oauth2.audience: - params["audience"] = auth_credential.oauth2.audience - uri, state = client.create_authorization_url( - url=authorization_endpoint, **params - ) + client_id = auth_credential.oauth2.client_id + + with CredentialManager.restore_client_secret(auth_credential): + client_secret = auth_credential.oauth2.client_secret + client = OAuth2Session( + client_id, + client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + ) + params = { + "access_type": "offline", + "prompt": "consent", + } + if auth_credential.oauth2.audience: + params["audience"] = auth_credential.oauth2.audience + uri, state = client.create_authorization_url( + url=authorization_endpoint, **params + ) exchanged_auth_credential = auth_credential.model_copy(deep=True) exchanged_auth_credential.oauth2.auth_uri = uri diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index 0316e5258e..e77c43f6af 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -78,9 +78,16 @@ def get_credential_key(self): ) auth_credential = self.raw_auth_credential - if auth_credential and auth_credential.model_extra: + if auth_credential and ( + auth_credential.model_extra or auth_credential.oauth2 + ): auth_credential = auth_credential.model_copy(deep=True) - auth_credential.model_extra.clear() + if auth_credential.model_extra: + auth_credential.model_extra.clear() + # Normalize secret to ensure stable key regardless of redaction + if auth_credential.oauth2: + auth_credential.oauth2.client_secret = None + credential_name = ( f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}" if auth_credential diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index 2497c7b6b3..1de9963acb 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -14,13 +14,13 @@ from __future__ import annotations +import contextlib import logging from typing import Optional from fastapi.openapi.models import OAuth2 from ..agents.callback_context import CallbackContext -from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_credential import AuthCredentialTypes @@ -76,11 +76,23 @@ class CredentialManager: ``` """ + # A map to store client secrets in memory. Key is client_id, value is client_secret + _CLIENT_SECRETS: dict[str, str] = {} + def __init__( self, auth_config: AuthConfig, ): - self._auth_config = auth_config + # We deep copy the auth_config to avoid modifying the original object passed + # by the user. This allows for safe redaction of sensitive information without + # causing side effects. + + self._auth_config = auth_config.model_copy(deep=True) + + # Secure the client secret + self._secure_client_secret(self._auth_config.raw_auth_credential) + self._secure_client_secret(self._auth_config.exchanged_auth_credential) + self._exchanger_registry = CredentialExchangerRegistry() self._refresher_registry = CredentialRefresherRegistry() self._discovery_manager = OAuth2DiscoveryManager() @@ -98,6 +110,8 @@ def __init__( ) # TODO: Move ServiceAccountCredentialExchanger to the auth module + from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger + self._exchanger_registry.register( AuthCredentialTypes.SERVICE_ACCOUNT, ServiceAccountCredentialExchanger(), @@ -111,6 +125,36 @@ def __init__( AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher ) + def _secure_client_secret(self, credential: Optional[AuthCredential]): + """Extracts client secret to memory and redacts it from the credential.""" + if ( + credential + and credential.oauth2 + and credential.oauth2.client_id + and credential.oauth2.client_secret + and credential.oauth2.client_secret != "" + ): + logger.info( + f"Securing client secret for client_id: {credential.oauth2.client_id}" + ) + # Store in memory map + CredentialManager._CLIENT_SECRETS[credential.oauth2.client_id] = ( + credential.oauth2.client_secret + ) + # Redact from config + credential.oauth2.client_secret = "" + else: + if credential and credential.oauth2: + logger.debug( + f"Not securing secret for client_id {credential.oauth2.client_id}:" + f" secret is {credential.oauth2.client_secret}" + ) + + @staticmethod + def get_client_secret(client_id: str) -> Optional[str]: + """Retrieves the client secret for a given client_id.""" + return CredentialManager._CLIENT_SECRETS.get(client_id) + def register_credential_exchanger( self, credential_type: AuthCredentialTypes, @@ -125,6 +169,9 @@ def register_credential_exchanger( self._exchanger_registry.register(credential_type, exchanger_instance) async def request_credential(self, callback_context: CallbackContext) -> None: + # We send the auth_config (which is already redacted in __init__) to the client + # Note: we need to ensure we don't send any stale exchanged credentials if they are not valid + # But usually CredentialManager manages that. callback_context.request_credential(self._auth_config) async def get_auth_credential( @@ -206,6 +253,40 @@ async def _load_from_auth_response( """Load credential from auth response in callback context.""" return callback_context.get_auth_response(self._auth_config) + @staticmethod + @contextlib.contextmanager + def restore_client_secret(credential: AuthCredential, secret: str = None): + """Context manager to temporarily restore client secret in a credential. + + Args: + credential: The credential to restore secret for. + secret: Optional secret to use. If not provided, looks up by client_id. + """ + if not credential or not credential.oauth2: + yield + return + + restored = False + if secret: + credential.oauth2.client_secret = secret + restored = True + elif ( + credential.oauth2.client_id + and credential.oauth2.client_secret == "" + ): + stored_secret = CredentialManager.get_client_secret( + credential.oauth2.client_id + ) + if stored_secret: + credential.oauth2.client_secret = stored_secret + restored = True + + try: + yield + finally: + if restored: + credential.oauth2.client_secret = "" + async def _exchange_credential( self, credential: AuthCredential ) -> tuple[AuthCredential, bool]: @@ -214,6 +295,8 @@ async def _exchange_credential( if not exchanger: return credential, False + from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger + if isinstance(exchanger, ServiceAccountCredentialExchanger): return ( exchanger.exchange_credential( @@ -221,11 +304,12 @@ async def _exchange_credential( ), True, ) - - exchange_result = await exchanger.exchange( - credential, self._auth_config.auth_scheme - ) - return exchange_result.credential, exchange_result.was_exchanged + else: + with self.restore_client_secret(credential): + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + return exchanged_credential, True async def _refresh_credential( self, credential: AuthCredential diff --git a/tests/unittests/auth/test_auth_handler_secrets.py b/tests/unittests/auth/test_auth_handler_secrets.py new file mode 100644 index 0000000000..11d8eb9d39 --- /dev/null +++ b/tests/unittests/auth/test_auth_handler_secrets.py @@ -0,0 +1,131 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_handler import AuthHandler +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +from google.adk.auth.exchanger.base_credential_exchanger import ExchangeResult +import pytest + + +class TestAuthHandlerSecrets: + + @pytest.fixture(autouse=True) + def clear_credential_manager_secrets(self): + """Clear CredentialManager secrets buffer before/after each test.""" + CredentialManager._CLIENT_SECRETS = {} + yield + CredentialManager._CLIENT_SECRETS = {} + + @pytest.mark.asyncio + async def test_exchange_auth_token_restores_and_reredacts_secret(self): + client_id = "test_client_id" + secret = "super_secret_value" + + # Setup secure storage + CredentialManager._CLIENT_SECRETS[client_id] = secret + + # Create credential with redacted secret + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(client_id=client_id, client_secret=""), + ) + + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = credential + auth_config.auth_scheme = Mock() + + handler = AuthHandler(auth_config) + + # Mock exchanger + mock_exchanger = AsyncMock() + + # Check secret inside exchange + def check_secret(cred, scheme): + assert cred.oauth2.client_secret == secret + return ExchangeResult(cred, True) + + mock_exchanger.exchange.side_effect = check_secret + + with patch( + "google.adk.auth.auth_handler.OAuth2CredentialExchanger", + return_value=mock_exchanger, + ): + await handler.exchange_auth_token() + + # Verify secret is re-redacted + assert credential.oauth2.client_secret == "" + + def test_generate_auth_uri_uses_restored_secret(self): + client_id = "test_client_id" + secret = "super_secret_value" + + # Setup secure storage + CredentialManager._CLIENT_SECRETS[client_id] = secret + + # Create credential with redacted secret + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=client_id, + client_secret="", + redirect_uri="http://localhost/callback", + ), + ) + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = credential + auth_config.auth_scheme = Mock() + # Mock flows for scopes + auth_config.auth_scheme.flows.implicit = None + auth_config.auth_scheme.flows.clientCredentials = None + auth_config.auth_scheme.flows.password = None + auth_config.auth_scheme.flows.authorizationCode.scopes = {"scope": "desc"} + auth_config.auth_scheme.flows.authorizationCode.authorizationUrl = ( + "http://auth" + ) + + handler = AuthHandler(auth_config) + + # Mock OAuth2Session + with ( + patch("google.adk.auth.auth_handler.OAuth2Session") as mock_session_cls, + patch("google.adk.auth.auth_handler.AUTHLIB_AVAILABLE", True), + ): + + mock_session = Mock() + mock_session.create_authorization_url.return_value = ( + "http://auth?param=1", + "state", + ) + mock_session_cls.return_value = mock_session + + handler.generate_auth_uri() + + # Verify session was created with the REAL secret, not redacted one + mock_session_cls.assert_called_with( + client_id, + secret, + scope="scope", + redirect_uri="http://localhost/callback", + ) diff --git a/tests/unittests/auth/test_auth_tool_key_stability.py b/tests/unittests/auth/test_auth_tool_key_stability.py new file mode 100644 index 0000000000..827ffbb98e --- /dev/null +++ b/tests/unittests/auth/test_auth_tool_key_stability.py @@ -0,0 +1,73 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TestAuthToolKeyStability(unittest.TestCase): + + def test_key_stability_with_different_secrets(self): + from google.adk.auth.auth_schemes import AuthSchemeType + from google.adk.auth.auth_schemes import OAuth2 + + # Consistent scheme for both + auth_scheme = OAuth2(type=AuthSchemeType.oauth2, flows={}) + + # Config 1: Real secret + auth_credential_1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client_id", client_secret="real_secret", auth_uri="uri" + ), + ) + config1 = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential_1 + ) + + # Config 2: Redacted secret + auth_credential_2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client_id", client_secret="", auth_uri="uri" + ), + ) + config2 = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential_2 + ) + + # Keys should be identical + key1 = config1.credential_key + key2 = config2.credential_key + + self.assertEqual(key1, key2, f"Keys should match! {key1} vs {key2}") diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py index ab021d1eaa..53ff1cd2cc 100644 --- a/tests/unittests/auth/test_credential_manager.py +++ b/tests/unittests/auth/test_credential_manager.py @@ -31,24 +31,35 @@ from google.adk.auth.auth_schemes import ExtendedOAuth2 from google.adk.auth.auth_tool import AuthConfig from google.adk.auth.credential_manager import CredentialManager -from google.adk.auth.credential_manager import ServiceAccountCredentialExchanger from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata +from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger import pytest +def create_auth_config_mock(): + """Creates a mock AuthConfig that returns itself on model_copy.""" + # We remove spec=AuthConfig because accessing Pydantic fields on a spec-ed mock + # can fail if they are not seen as class attributes or if we need dynamic attributes. + m = Mock() + m.spec = AuthConfig # Optional: if we want isinstance to work, but Mock(spec=X) enforces attributes. + # Let's just use a plain Mock and configure what we need. + m.model_copy.side_effect = lambda **kwargs: m + return m + + class TestCredentialManager: """Test suite for CredentialManager.""" def test_init(self): """Test CredentialManager initialization.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() manager = CredentialManager(auth_config) assert manager._auth_config == auth_config @pytest.mark.asyncio async def test_request_credential(self): """Test request_credential method.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() callback_context = Mock() callback_context.request_credential = Mock() @@ -61,7 +72,7 @@ async def test_request_credential(self): async def test_load_auth_credentials_success(self): """Test load_auth_credential with successful flow.""" # Create mocks - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None @@ -104,7 +115,7 @@ async def test_load_auth_credentials_success(self): @pytest.mark.asyncio async def test_load_auth_credentials_no_credential(self): """Test load_auth_credential when no credential is available.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None # Add auth_scheme for the _is_client_credentials_flow method @@ -134,8 +145,9 @@ async def test_load_auth_credentials_no_credential(self): @pytest.mark.asyncio async def test_load_existing_credential_already_exchanged(self): """Test _load_existing_credential when credential is already exchanged.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() mock_credential = Mock(spec=AuthCredential) + mock_credential.oauth2 = Mock() auth_config.exchanged_auth_credential = mock_credential callback_context = Mock() @@ -150,7 +162,7 @@ async def test_load_existing_credential_already_exchanged(self): @pytest.mark.asyncio async def test_load_existing_credential_with_credential_service(self): """Test _load_existing_credential with credential service.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.exchanged_auth_credential = None mock_credential = Mock(spec=AuthCredential) @@ -172,7 +184,7 @@ async def test_load_existing_credential_with_credential_service(self): @pytest.mark.asyncio async def test_load_from_credential_service_with_service(self): """Test _load_from_credential_service from callback context when credential service is available.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() mock_credential = Mock(spec=AuthCredential) @@ -196,7 +208,7 @@ async def test_load_from_credential_service_with_service(self): @pytest.mark.asyncio async def test_load_from_credential_service_no_service(self): """Test _load_from_credential_service when no credential service is available.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() # Mock invocation context with no credential service invocation_context = Mock() @@ -213,7 +225,7 @@ async def test_load_from_credential_service_no_service(self): @pytest.mark.asyncio async def test_save_credential_with_service(self): """Test _save_credential with credential service.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() mock_credential = Mock(spec=AuthCredential) # Mock credential service @@ -236,7 +248,7 @@ async def test_save_credential_with_service(self): @pytest.mark.asyncio async def test_save_credential_no_service(self): """Test _save_credential when no credential service is available.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.exchanged_auth_credential = None mock_credential = Mock(spec=AuthCredential) @@ -260,9 +272,10 @@ async def test_refresh_credential_oauth2(self): mock_oauth2_auth = Mock(spec=OAuth2Auth) mock_credential = Mock(spec=AuthCredential) + mock_credential.oauth2 = Mock() mock_credential.auth_type = AuthCredentialTypes.OAUTH2 - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = Mock() # Mock refresher @@ -297,7 +310,7 @@ async def test_refresh_credential_no_refresher(self): mock_credential = Mock(spec=AuthCredential) mock_credential.auth_type = AuthCredentialTypes.API_KEY - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() manager = CredentialManager(auth_config) @@ -316,9 +329,10 @@ async def test_refresh_credential_no_refresher(self): async def test_is_credential_ready_api_key(self): """Test _is_credential_ready with API key credential.""" mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.oauth2 = Mock() mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential manager = CredentialManager(auth_config) @@ -330,9 +344,10 @@ async def test_is_credential_ready_api_key(self): async def test_is_credential_ready_oauth2(self): """Test _is_credential_ready with OAuth2 credential (needs processing).""" mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.oauth2 = Mock() mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential manager = CredentialManager(auth_config) @@ -346,7 +361,7 @@ async def test_validate_credential_no_raw_credential_oauth2(self): auth_scheme = Mock() auth_scheme.type_ = AuthSchemeType.oauth2 - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.auth_scheme = auth_scheme @@ -361,7 +376,7 @@ async def test_validate_credential_no_raw_credential_openid(self): auth_scheme = Mock() auth_scheme.type_ = AuthSchemeType.openIdConnect - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.auth_scheme = auth_scheme @@ -376,7 +391,7 @@ async def test_validate_credential_no_raw_credential_other_scheme(self): auth_scheme = Mock() auth_scheme.type_ = AuthSchemeType.apiKey - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.auth_scheme = auth_scheme @@ -392,7 +407,7 @@ async def test_validate_credential_oauth2_missing_oauth2_field(self): mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 mock_raw_credential.oauth2 = None - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential auth_config.auth_scheme = Mock() @@ -407,10 +422,10 @@ async def test_validate_credential_oauth2_missing_scheme_info( ): """Test _validate_credential with OAuth2 missing scheme info.""" mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.oauth2 = Mock() mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 - mock_raw_credential.oauth2 = Mock(spec=OAuth2Auth) - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential auth_config.auth_scheme = extended_oauth2_scheme @@ -428,7 +443,7 @@ async def test_exchange_credentials_service_account( self, service_account_credential, oauth2_auth_scheme ): """Test _exchange_credential with service account credential.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = oauth2_auth_scheme exchanged_credential = Mock(spec=AuthCredential) @@ -457,7 +472,7 @@ async def test_exchange_credential_no_exchanger(self): mock_credential = Mock(spec=AuthCredential) mock_credential.auth_type = AuthCredentialTypes.API_KEY - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() manager = CredentialManager(auth_config) @@ -513,7 +528,7 @@ async def test_populate_auth_scheme_success( self, auth_server_metadata, extended_oauth2_scheme ): """Test _populate_auth_scheme successfully populates missing info.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = extended_oauth2_scheme manager = CredentialManager(auth_config) @@ -536,7 +551,7 @@ async def test_populate_auth_scheme_success( @pytest.mark.asyncio async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme): """Test _populate_auth_scheme when auto-discovery fails.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = extended_oauth2_scheme manager = CredentialManager(auth_config) @@ -555,7 +570,7 @@ async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme): @pytest.mark.asyncio async def test_populate_auth_scheme_noop(self, implicit_oauth2_scheme): """Test _populate_auth_scheme when auth scheme info not missing.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = implicit_oauth2_scheme manager = CredentialManager(auth_config) @@ -578,7 +593,7 @@ def test_is_client_credentials_flow_oauth2_with_client_credentials(self): ) ) - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = auth_scheme auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None @@ -603,7 +618,7 @@ def test_is_client_credentials_flow_oauth2_without_client_credentials(self): ) ) - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = auth_scheme auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None @@ -623,7 +638,7 @@ def test_is_client_credentials_flow_oidc_with_client_credentials(self): grant_types_supported=["authorization_code", "client_credentials"], ) - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = auth_scheme auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None @@ -643,7 +658,7 @@ def test_is_client_credentials_flow_oidc_without_client_credentials(self): grant_types_supported=["authorization_code"], ) - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = auth_scheme auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None @@ -657,7 +672,7 @@ def test_is_client_credentials_flow_other_scheme(self): # Create a non-OAuth2/OIDC scheme auth_scheme = Mock() - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = auth_scheme auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None diff --git a/tests/unittests/auth/test_credential_manager_secrets.py b/tests/unittests/auth/test_credential_manager_secrets.py new file mode 100644 index 0000000000..ea5d210818 --- /dev/null +++ b/tests/unittests/auth/test_credential_manager_secrets.py @@ -0,0 +1,184 @@ +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +import pytest + + +@pytest.fixture(autouse=True) +def clear_credential_manager_secrets(): + """Clear CredentialManager secrets buffer before/after each test.""" + CredentialManager._CLIENT_SECRETS = {} + yield + CredentialManager._CLIENT_SECRETS = {} + + +@pytest.mark.asyncio +async def test_credential_manager_redacts_secrets_in_raw_credential(): + """Test that CredentialManager redacts client_secret from raw_auth_credential upon initialization.""" + + # Setup + client_id = "test_client_id" + client_secret = "test_client_secret" + + oauth_auth = OAuth2Auth(client_id=client_id, client_secret=client_secret) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth_auth + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + ) + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + + # Act + manager = CredentialManager(auth_config) + + # Assert + # 1. Check if secret is in memory map + assert client_id in manager._CLIENT_SECRETS + assert manager._CLIENT_SECRETS[client_id] == client_secret + + # 2. Check if secret is redacted in the manager's config + assert ( + manager._auth_config.raw_auth_credential.oauth2.client_secret + == "" + ) + + # 3. Check original config is NOT modified (AuthConfig copy behavior) + # Since we used model_copy(deep=True), calling on Pydantic model copies it. + assert auth_config.raw_auth_credential.oauth2.client_secret == client_secret + + +@pytest.mark.asyncio +async def test_credential_manager_redacts_secrets_in_exchanged_credential(): + """Test that CredentialManager redacts client_secret from exchanged_auth_credential if present.""" + + # Setup + client_id = "test_client_id_exchanged" + client_secret = "test_client_secret_exchanged" + + oauth_auth = OAuth2Auth( + client_id=client_id, + client_secret=client_secret, + access_token="some_token", + ) + + exchanged_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth_auth + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + ) + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=None, + exchanged_auth_credential=exchanged_credential, + ) + + # Act + manager = CredentialManager(auth_config) + + # Assert + assert client_id in manager._CLIENT_SECRETS + assert manager._CLIENT_SECRETS[client_id] == client_secret + + assert ( + manager._auth_config.exchanged_auth_credential.oauth2.client_secret + == "" + ) + + +@pytest.mark.asyncio +async def test_exchange_credential_restores_secret(): + """Test that _exchange_credential restores the secret before calling exchanger.""" + + # Setup + client_id = "test_client_id_exchange" + client_secret = "test_client_secret_exchange" + + oauth_auth = OAuth2Auth(client_id=client_id, client_secret=client_secret) + + raw_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth_auth + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + ) + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=raw_credential + ) + + manager = CredentialManager(auth_config) + + # Secret should be redacted now + assert ( + manager._auth_config.raw_auth_credential.oauth2.client_secret + == "" + ) + + # Prepare a credential to be exchanged (e.g. from client response, has no secret or redacted) + credential_to_exchange = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=client_id, + client_secret="", # or None + auth_code="some_code", + ), + ) + + # Mock exchanger + mock_exchanger = AsyncMock() + + # We use side_effect to verify the secret at the moment of call, because the object is mutated later + def check_secret(cred, scheme): + assert cred.oauth2.client_secret == client_secret + return credential_to_exchange + + mock_exchanger.exchange.side_effect = check_secret + + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=mock_exchanger + ): + # Act + result_credential, exchanged = await manager._exchange_credential( + credential_to_exchange + ) + + # Assert + # Verification happened in side_effect + assert mock_exchanger.exchange.called + + # Check that the result credential (modified in place or returned) has secret REDACTED again + assert result_credential.oauth2.client_secret == "" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 1284e73bce..bdfaa7444e 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -84,10 +84,13 @@ def test_init_with_auth(self): # Create real auth scheme instances instead of mocks from fastapi.openapi.models import OAuth2 + test_client_secret = "test_secret" auth_scheme = OAuth2(flows={}) auth_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + oauth2=OAuth2Auth( + client_id="test_id", client_secret=test_client_secret + ), ) tool = MCPTool( @@ -100,6 +103,15 @@ def test_init_with_auth(self): # The auth config is stored in the parent class _credentials_manager assert tool._credentials_manager is not None assert tool._credentials_manager._auth_config.auth_scheme == auth_scheme + assert ( + tool._credentials_manager._auth_config.raw_auth_credential.oauth2.client_secret + == "" + ) + + # Restore the client secret and validate it's the same credential in the end. + tool._credentials_manager._auth_config.raw_auth_credential.oauth2.client_secret = ( + test_client_secret + ) assert ( tool._credentials_manager._auth_config.raw_auth_credential == auth_credential