Skip to content

Commit 8782a69

Browse files
sully90seanzhougoogle
authored andcommitted
feat: add token_endpoint_auth_method support to OAuth2 credentials
Merge #2870 ## Summary Add `token_endpoint_auth_method` field to OAuth2Auth class to allow configuring OAuth2 token endpoint authentication methods. This enables users to specify how the client should authenticate with the authorization server's token endpoint. • Add `token_endpoint_auth_method` field to `OAuth2Auth` with default value `"client_secret_basic"` • Update `create_oauth2_session()` to pass the authentication method to `OAuth2Session` • Maintain backward compatibility with existing OAuth2 configurations ## Unit Tests Added unit test coverage with 3 new test methods: 1. `test_create_oauth2_session_with_token_endpoint_auth_method()` - Tests explicit auth method setting (`client_secret_post`) 2. `test_create_oauth2_session_with_default_token_endpoint_auth_method()` - Tests default behavior (`client_secret_basic`) 3. `test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method()` - Tests with OAuth2 scheme using `client_secret_jwt` **Test Results:** ✅ 16/16 OAuth2 credential utility tests passed ✅ 240/240 auth module tests passed (no regressions) ✅ Tests cover both GOOGLE_AI and VERTEX variants ✅ Pylint score: 9.41/10 ## Changes Made **src/google/adk/auth/auth_credential.py** - Added `token_endpoint_auth_method: Optional[str] = "client_secret_basic"` to `OAuth2Auth` class **src/google/adk/auth/oauth2_credential_util.py** - Updated `create_oauth2_session()` to pass `token_endpoint_auth_method` parameter to `OAuth2Session` **tests/unittests/auth/test_oauth2_credential_util.py** - Added 3 comprehensive test methods covering different authentication scenarios ## Backward Compatibility ✅ **Non-breaking change** - All existing OAuth2 configurations continue to work unchanged with the default `client_secret_basic` authentication method. ## Supported Authentication Methods - `client_secret_basic` (default) - Client credentials in Authorization header - `client_secret_post` - Client credentials in request body - `client_secret_jwt` - JWT with client secret - `private_key_jwt` - JWT with private key Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> COPYBARA_INTEGRATE_REVIEW=#2870 from sully90:feat/oauth2-token-endpoint-auth-method 04fe824 PiperOrigin-RevId: 843739984
1 parent 29c1115 commit 8782a69

File tree

3 files changed

+103
-8
lines changed

3 files changed

+103
-8
lines changed

src/google/adk/auth/auth_credential.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any
1919
from typing import Dict
2020
from typing import List
21+
from typing import Literal
2122
from typing import Optional
2223

2324
from pydantic import alias_generators
@@ -80,6 +81,14 @@ class OAuth2Auth(BaseModelWithConfig):
8081
expires_at: Optional[int] = None
8182
expires_in: Optional[int] = None
8283
audience: Optional[str] = None
84+
token_endpoint_auth_method: Optional[
85+
Literal[
86+
"client_secret_basic",
87+
"client_secret_post",
88+
"client_secret_jwt",
89+
"private_key_jwt",
90+
]
91+
] = "client_secret_basic"
8392

8493

8594
class ServiceAccountCredential(BaseModelWithConfig):

src/google/adk/auth/oauth2_credential_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def create_oauth2_session(
9191
scope=" ".join(scopes),
9292
redirect_uri=auth_credential.oauth2.redirect_uri,
9393
state=auth_credential.oauth2.state,
94+
token_endpoint_auth_method=auth_credential.oauth2.token_endpoint_auth_method,
9495
),
9596
token_endpoint,
9697
)

tests/unittests/auth/test_oauth2_credential_util.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import time
16+
from typing import Optional
1617
from unittest.mock import Mock
1718

1819
from authlib.oauth2.rfc6749 import OAuth2Token
@@ -25,6 +26,39 @@
2526
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
2627
from google.adk.auth.oauth2_credential_util import create_oauth2_session
2728
from google.adk.auth.oauth2_credential_util import update_credential_with_tokens
29+
import pytest
30+
31+
32+
@pytest.fixture
33+
def openid_connect_scheme() -> OpenIdConnectWithConfig:
34+
"""Fixture providing a standard OpenIdConnectWithConfig scheme."""
35+
return OpenIdConnectWithConfig(
36+
type_="openIdConnect",
37+
openId_connect_url="https://example.com/.well-known/openid_configuration",
38+
authorization_endpoint="https://example.com/auth",
39+
token_endpoint="https://example.com/token",
40+
scopes=["openid", "profile"],
41+
)
42+
43+
44+
def create_oauth2_auth_credential(
45+
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
46+
token_endpoint_auth_method: Optional[str] = None,
47+
):
48+
"""Helper function to create OAuth2Auth credential with optional token_endpoint_auth_method."""
49+
oauth2_auth = OAuth2Auth(
50+
client_id="test_client_id",
51+
client_secret="test_client_secret",
52+
redirect_uri="https://example.com/callback",
53+
state="test_state",
54+
)
55+
if token_endpoint_auth_method is not None:
56+
oauth2_auth.token_endpoint_auth_method = token_endpoint_auth_method
57+
58+
return AuthCredential(
59+
auth_type=auth_type,
60+
oauth2=oauth2_auth,
61+
)
2862

2963

3064
class TestOAuth2CredentialUtil:
@@ -41,14 +75,9 @@ def test_create_oauth2_session_openid_connect(self):
4175
token_endpoint="https://example.com/token",
4276
scopes=["openid", "profile"],
4377
)
44-
credential = AuthCredential(
45-
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
46-
oauth2=OAuth2Auth(
47-
client_id="test_client_id",
48-
client_secret="test_client_secret",
49-
redirect_uri="https://example.com/callback",
50-
state="test_state",
51-
),
78+
credential = create_oauth2_auth_credential(
79+
auth_type=AuthCredentialTypes.OAUTH2,
80+
token_endpoint_auth_method="client_secret_jwt",
5281
)
5382

5483
client, token_endpoint = create_oauth2_session(scheme, credential)
@@ -122,6 +151,62 @@ def test_create_oauth2_session_missing_credentials(self):
122151
assert client is None
123152
assert token_endpoint is None
124153

154+
@pytest.mark.parametrize(
155+
"token_endpoint_auth_method, expected_auth_method",
156+
[
157+
("client_secret_post", "client_secret_post"),
158+
(None, "client_secret_basic"),
159+
],
160+
)
161+
def test_create_oauth2_session_with_token_endpoint_auth_method(
162+
self,
163+
openid_connect_scheme,
164+
token_endpoint_auth_method,
165+
expected_auth_method,
166+
):
167+
"""Test create_oauth2_session with various token_endpoint_auth_method settings."""
168+
credential = create_oauth2_auth_credential(
169+
token_endpoint_auth_method=token_endpoint_auth_method
170+
)
171+
172+
client, token_endpoint = create_oauth2_session(
173+
openid_connect_scheme, credential
174+
)
175+
176+
assert client is not None
177+
assert token_endpoint == "https://example.com/token"
178+
assert client.client_id == "test_client_id"
179+
assert client.client_secret == "test_client_secret"
180+
assert client.token_endpoint_auth_method == expected_auth_method
181+
182+
def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method(
183+
self,
184+
):
185+
"""Test create_oauth2_session with OAuth2 scheme and token_endpoint_auth_method."""
186+
flows = OAuthFlows(
187+
authorizationCode=OAuthFlowAuthorizationCode(
188+
authorizationUrl="https://example.com/auth",
189+
tokenUrl="https://example.com/token",
190+
scopes={"read": "Read access", "write": "Write access"},
191+
)
192+
)
193+
scheme = OAuth2(type_="oauth2", flows=flows)
194+
credential = AuthCredential(
195+
auth_type=AuthCredentialTypes.OAUTH2,
196+
oauth2=OAuth2Auth(
197+
client_id="test_client_id",
198+
client_secret="test_client_secret",
199+
redirect_uri="https://example.com/callback",
200+
token_endpoint_auth_method="client_secret_jwt",
201+
),
202+
)
203+
204+
client, token_endpoint = create_oauth2_session(scheme, credential)
205+
206+
assert client is not None
207+
assert token_endpoint == "https://example.com/token"
208+
assert client.token_endpoint_auth_method == "client_secret_jwt"
209+
125210
def test_update_credential_with_tokens(self):
126211
"""Test update_credential_with_tokens function."""
127212
credential = AuthCredential(

0 commit comments

Comments
 (0)