From f0de9e70e920c79760b58fae4f3c494e057069dc Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Fri, 21 Nov 2025 17:27:04 +0530 Subject: [PATCH 01/16] - Implemented SEP-990 feature for providing support for Enterprise Managed Auth support. - Written unit test cases for client and server implementation of the enterprise managed auth code. --- src/mcp/client/auth/extensions/__init__.py | 20 + .../extensions/enterprise_managed_auth.py | 415 ++++++++++++++ src/mcp/server/auth/extensions/__init__.py | 16 + .../extensions/enterprise_managed_auth.py | 277 ++++++++++ .../test_enterprise_managed_auth_client.py | 471 ++++++++++++++++ .../test_enterprise_managed_auth_server.py | 523 ++++++++++++++++++ 6 files changed, 1722 insertions(+) create mode 100644 src/mcp/client/auth/extensions/enterprise_managed_auth.py create mode 100644 src/mcp/server/auth/extensions/__init__.py create mode 100644 src/mcp/server/auth/extensions/enterprise_managed_auth.py create mode 100644 tests/client/auth/test_enterprise_managed_auth_client.py create mode 100644 tests/server/auth/test_enterprise_managed_auth_server.py diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py index e69de29bb2..7b3ece607d 100644 --- a/src/mcp/client/auth/extensions/__init__.py +++ b/src/mcp/client/auth/extensions/__init__.py @@ -0,0 +1,20 @@ +"""MCP Client Auth Extensions.""" + +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + TokenExchangeParameters, + TokenExchangeResponse, + decode_id_jag, + validate_token_exchange_params, +) + +__all__ = [ + "EnterpriseAuthOAuthClientProvider", + "IDJAGClaims", + "TokenExchangeParameters", + "TokenExchangeResponse", + "decode_id_jag", + "validate_token_exchange_params", +] + diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py new file mode 100644 index 0000000000..d5fb9e56ef --- /dev/null +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -0,0 +1,415 @@ +""" +Enterprise Managed Authorization extension for MCP (SEP-990). + +Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for +enterprise SSO integration. +""" + +import logging +from typing import Any + +import httpx +from pydantic import BaseModel, Field + +from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage +from mcp.shared.auth import OAuthClientMetadata, OAuthToken + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Data Models +# ============================================================================ + + +class TokenExchangeParameters(BaseModel): + """Parameters for RFC 8693 Token Exchange request.""" + + requested_token_type: str = Field( + default="urn:ietf:params:oauth:token-type:id-jag", + description="Type of token being requested (ID-JAG)", + ) + + audience: str = Field( + ..., + description="Issuer URL of the MCP Server's authorization server", + ) + + resource: str = Field( + ..., + description="RFC 9728 Resource Identifier of the MCP Server", + ) + + scope: str | None = Field( + default=None, + description="Space-separated list of scopes being requested", + ) + + subject_token: str = Field( + ..., + description="ID Token or SAML assertion for the end user", + ) + + subject_token_type: str = Field( + ..., + description="Type of subject token (id_token or saml2)", + ) + + @classmethod + def from_id_token( + cls, + id_token: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for OIDC ID Token exchange.""" + return cls( + subject_token=id_token, + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + @classmethod + def from_saml_assertion( + cls, + saml_assertion: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for SAML assertion exchange.""" + return cls( + subject_token=saml_assertion, + subject_token_type="urn:ietf:params:oauth:token-type:saml2", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + +class TokenExchangeResponse(BaseModel): + """Response from RFC 8693 Token Exchange.""" + + issued_token_type: str = Field( + ..., + description="Type of token issued (should be id-jag)", + ) + + access_token: str = Field( + ..., + description="The ID-JAG token (named access_token per RFC 8693)", + ) + + token_type: str = Field( + ..., + description="Token type (should be N_A for ID-JAG)", + ) + + scope: str | None = Field( + default=None, + description="Granted scopes", + ) + + expires_in: int | None = Field( + default=None, + description="Lifetime in seconds", + ) + + @property + def id_jag(self) -> str: + """Get the ID-JAG token.""" + return self.access_token + + +class IDJAGClaims(BaseModel): + """Claims structure for Identity Assertion JWT Authorization Grant.""" + + model_config = {"extra": "allow"} + + # JWT header + typ: str = Field( + ..., + description="JWT type - must be 'oauth-id-jag+jwt'", + ) + + # Required claims + jti: str = Field(..., description="Unique JWT ID") + iss: str = Field(..., description="IdP issuer URL") + sub: str = Field(..., description="Subject (user) identifier") + aud: str = Field(..., description="MCP Server's auth server issuer") + resource: str = Field(..., description="MCP Server resource identifier") + client_id: str = Field(..., description="MCP Client identifier") + exp: int = Field(..., description="Expiration timestamp") + iat: int = Field(..., description="Issued-at timestamp") + + # Optional claims + scope: str | None = Field(None, description="Space-separated scopes") + email: str | None = Field(None, description="User email") + + +class EnterpriseAuthOAuthClientProvider(OAuthClientProvider): + """ + OAuth client provider for Enterprise Managed Authorization (SEP-990). + + Implements: + - RFC 8693: Token Exchange (ID Token → ID-JAG) + - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token) + """ + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + idp_token_endpoint: str, + token_exchange_params: TokenExchangeParameters, + redirect_handler: Any = None, + callback_handler: Any = None, + timeout: float = 300.0, + ) -> None: + """ + Initialize Enterprise Auth OAuth Client. + + Args: + server_url: MCP server URL + client_metadata: OAuth client metadata + storage: Token storage implementation + idp_token_endpoint: Enterprise IdP token endpoint URL + token_exchange_params: Token exchange parameters + redirect_handler: Optional redirect handler + callback_handler: Optional callback handler + timeout: Request timeout in seconds + """ + super().__init__( + server_url=server_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + timeout=timeout, + ) + self.idp_token_endpoint = idp_token_endpoint + self.token_exchange_params = token_exchange_params + self._id_jag: str | None = None + + async def exchange_token_for_id_jag( + self, + client: httpx.AsyncClient, + ) -> str: + """ + Exchange ID Token for ID-JAG using RFC 8693 Token Exchange. + + Args: + client: HTTP client for making requests + + Returns: + The ID-JAG token string + + Raises: + OAuthTokenError: If token exchange fails + """ + logger.info("Starting token exchange for ID-JAG") + + # Build token exchange request + token_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "requested_token_type": self.token_exchange_params.requested_token_type, + "audience": self.token_exchange_params.audience, + "resource": self.token_exchange_params.resource, + "subject_token": self.token_exchange_params.subject_token, + "subject_token_type": self.token_exchange_params.subject_token_type, + } + + if self.token_exchange_params.scope: + token_data["scope"] = self.token_exchange_params.scope + + # Add client authentication if needed + if self.context.client_info: + token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_secret: + token_data["client_secret"] = self.context.client_info.client_secret + + try: + response = await client.post( + self.idp_token_endpoint, + data=token_data, + timeout=self.context.timeout, + ) + + if response.status_code != 200: + error_data = ( + response.json() + if response.headers.get("content-type", "").startswith("application/json") + else {} + ) + error = error_data.get("error", "unknown_error") + error_description = error_data.get("error_description", "Token exchange failed") + raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}") + + # Parse response + token_response = TokenExchangeResponse.model_validate_json(response.content) + + # Validate response + if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag": + raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}") + + if token_response.token_type != "N_A": + logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'") + + logger.info("Successfully obtained ID-JAG") + self._id_jag = token_response.id_jag + return token_response.id_jag + + except httpx.HTTPError as e: + raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e + + async def exchange_id_jag_for_access_token( + self, + client: httpx.AsyncClient, + id_jag: str, + ) -> OAuthToken: + """ + Exchange ID-JAG for access token using RFC 7523 JWT Bearer Grant. + + Args: + client: HTTP client for making requests + id_jag: The ID-JAG token + + Returns: + OAuth access token + + Raises: + OAuthTokenError: If JWT bearer grant fails + """ + logger.info("Exchanging ID-JAG for access token") + + # Discover token endpoint from MCP server if not already done + if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint: + raise OAuthFlowError("MCP server token endpoint not discovered") + + token_endpoint = str(self.context.oauth_metadata.token_endpoint) + + # Build JWT bearer grant request + token_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": id_jag, + } + + # Add client authentication + if self.context.client_info: + token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_secret: + token_data["client_secret"] = self.context.client_info.client_secret + + try: + response = await client.post( + token_endpoint, + data=token_data, + timeout=self.context.timeout, + ) + + if response.status_code != 200: + error_data = ( + response.json() + if response.headers.get("content-type", "").startswith("application/json") + else {} + ) + error = error_data.get("error", "unknown_error") + error_description = error_data.get("error_description", "JWT bearer grant failed") + raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}") + + # Parse OAuth token response + token = OAuthToken.model_validate_json(response.content) + + # Store tokens + self.context.current_tokens = token + self.context.update_token_expiry(token) + await self.context.storage.set_tokens(token) + + logger.info("Successfully obtained access token via ID-JAG") + return token + + except httpx.HTTPError as e: + raise OAuthTokenError(f"HTTP error during JWT bearer grant: {e}") from e + + async def _perform_authorization(self) -> httpx.Request: + """ + Perform enterprise authorization flow. + + Overrides parent method to use token exchange + JWT bearer grant + instead of standard authorization code flow. + """ + # Check if we already have valid tokens + if self.context.is_token_valid(): + # Return a dummy request - we don't need to make any request + return httpx.Request("GET", self.context.server_url) + + # For now, raise NotImplementedError as this requires integration + # with the full httpx auth flow + raise NotImplementedError( + "Full enterprise auth flow integration not yet implemented. " + "Use exchange_token_for_id_jag and exchange_id_jag_for_access_token directly." + ) + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def decode_id_jag(id_jag: str, verify: bool = False) -> IDJAGClaims: + """ + Decode an ID-JAG token without verification. + + Args: + id_jag: The ID-JAG token string + verify: Whether to verify signature (requires key) + + Returns: + Decoded ID-JAG claims + + Note: + For verification, use server-side validation instead. + """ + import jwt + + # Decode without verification for inspection + claims = jwt.decode(id_jag, options={"verify_signature": False}) + header = jwt.get_unverified_header(id_jag) + + # Add typ from header to claims + claims["typ"] = header.get("typ", "") + + return IDJAGClaims.model_validate(claims) + + +def validate_token_exchange_params( + params: TokenExchangeParameters, +) -> None: + """ + Validate token exchange parameters. + + Args: + params: Token exchange parameters to validate + + Raises: + ValueError: If parameters are invalid + """ + if not params.subject_token: + raise ValueError("subject_token is required") + + if not params.audience: + raise ValueError("audience is required") + + if not params.resource: + raise ValueError("resource is required") + + if params.subject_token_type not in [ + "urn:ietf:params:oauth:token-type:id_token", + "urn:ietf:params:oauth:token-type:saml2", + ]: + raise ValueError(f"Invalid subject_token_type: {params.subject_token_type}") + diff --git a/src/mcp/server/auth/extensions/__init__.py b/src/mcp/server/auth/extensions/__init__.py new file mode 100644 index 0000000000..2d75f43b9f --- /dev/null +++ b/src/mcp/server/auth/extensions/__init__.py @@ -0,0 +1,16 @@ +"""MCP Server Auth Extensions.""" + +from mcp.server.auth.extensions.enterprise_managed_auth import ( + IDJAGClaims, + IDJAGValidator, + JWTValidationConfig, + ReplayPreventionStore, +) + +__all__ = [ + "IDJAGClaims", + "IDJAGValidator", + "JWTValidationConfig", + "ReplayPreventionStore", +] + diff --git a/src/mcp/server/auth/extensions/enterprise_managed_auth.py b/src/mcp/server/auth/extensions/enterprise_managed_auth.py new file mode 100644 index 0000000000..084394a127 --- /dev/null +++ b/src/mcp/server/auth/extensions/enterprise_managed_auth.py @@ -0,0 +1,277 @@ +""" +Server-side Enterprise Managed Authorization (SEP-990). + +Implements JWT validation for ID-JAG tokens and JWT bearer grant handling. +""" + +import logging +import time +from typing import Any + +import jwt +from jwt import PyJWKClient +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Configuration Models +# ============================================================================ + + +class JWTValidationConfig(BaseModel): + """Configuration for JWT validation.""" + + trusted_idp_issuers: list[str] = Field( + ..., + description="List of trusted IdP issuer URLs", + ) + + server_auth_issuer: str = Field( + ..., + description="This server's authorization server issuer URL", + ) + + server_resource_id: str = Field( + ..., + description="This server's resource identifier", + ) + + jwks_uri: str | None = Field( + default=None, + description="JWKS URI for key verification (if single IdP)", + ) + + jwks_cache_ttl: int = Field( + default=3600, + description="JWKS cache TTL in seconds", + ) + + allowed_algorithms: list[str] = Field( + default=["RS256", "ES256"], + description="Allowed JWT signing algorithms", + ) + + replay_prevention_enabled: bool = Field( + default=True, + description="Enable JTI-based replay prevention", + ) + + replay_cache_ttl: int = Field( + default=3600, + description="Replay cache TTL in seconds", + ) + + clock_skew_seconds: int = Field( + default=60, + description="Allowed clock skew for exp/iat validation", + ) + + +class IDJAGClaims(BaseModel): + """Validated ID-JAG claims.""" + + model_config = {"extra": "allow"} + + # JWT header + typ: str + + # Required claims + jti: str + iss: str + sub: str + aud: str + resource: str + client_id: str + exp: int + iat: int + + # Optional claims + scope: str | None = None + email: str | None = None + + +# ============================================================================ +# Replay Prevention +# ============================================================================ + + +class ReplayPreventionStore: + """In-memory store for replay prevention (production should use Redis/similar).""" + + def __init__(self, ttl: int = 3600): + self._used_jtis: dict[str, float] = {} + self._ttl = ttl + + def mark_used(self, jti: str) -> None: + """Mark a JTI as used.""" + self._cleanup() + self._used_jtis[jti] = time.time() + + def is_used(self, jti: str) -> bool: + """Check if a JTI has been used.""" + self._cleanup() + return jti in self._used_jtis + + def _cleanup(self) -> None: + """Remove expired entries.""" + now = time.time() + self._used_jtis = { + jti: timestamp for jti, timestamp in self._used_jtis.items() if now - timestamp < self._ttl + } + + +# ============================================================================ +# JWT Validator +# ============================================================================ + + +class IDJAGValidator: + """Validator for ID-JAG tokens.""" + + def __init__(self, config: JWTValidationConfig): + self.config = config + self.replay_store = ReplayPreventionStore(ttl=config.replay_cache_ttl) + + # Initialize JWKS client if provided + self.jwks_client: PyJWKClient | None = None + if config.jwks_uri: + self.jwks_client = PyJWKClient( + config.jwks_uri, + cache_keys=True, + max_cached_keys=16, + cache_jwk_set=True, + lifespan=config.jwks_cache_ttl, + ) + + def validate_id_jag( + self, + id_jag: str, + expected_client_id: str, + ) -> IDJAGClaims: + """ + Validate an ID-JAG token. + + Args: + id_jag: The ID-JAG token to validate + expected_client_id: The client_id from client authentication + + Returns: + Validated ID-JAG claims + + Raises: + jwt.InvalidTokenError: If validation fails + ValueError: If claims are invalid + """ + # Step 1: Decode and get header + header = jwt.get_unverified_header(id_jag) + + # Validate typ header + if header.get("typ") != "oauth-id-jag+jwt": + raise ValueError(f"Invalid typ header: expected 'oauth-id-jag+jwt', got '{header.get('typ')}'") + + # Step 2: Get signing key + if self.jwks_client: + signing_key = self.jwks_client.get_signing_key_from_jwt(id_jag) + key = signing_key.key + else: + # For testing/development - decode without verification + logger.warning("No JWKS client configured - skipping signature verification") + key = None + + # Step 3: Decode and verify JWT + try: + claims = jwt.decode( + id_jag, + key, + algorithms=self.config.allowed_algorithms, + options={ + "verify_signature": key is not None, + "verify_exp": True, + "verify_iat": True, + }, + leeway=self.config.clock_skew_seconds, + ) + except jwt.ExpiredSignatureError: + raise ValueError("ID-JAG has expired") + except jwt.InvalidTokenError as e: + raise ValueError(f"Invalid ID-JAG: {e}") + + # Step 4: Validate issuer + if claims.get("iss") not in self.config.trusted_idp_issuers: + raise ValueError(f"Untrusted issuer: {claims.get('iss')}") + + # Step 5: Validate audience + if claims.get("aud") != self.config.server_auth_issuer: + raise ValueError( + f"Invalid audience: expected '{self.config.server_auth_issuer}', " + f"got '{claims.get('aud')}'" + ) + + # Step 6: Validate resource + if claims.get("resource") != self.config.server_resource_id: + raise ValueError( + f"Invalid resource: expected '{self.config.server_resource_id}', " + f"got '{claims.get('resource')}'" + ) + + # Step 7: Validate client_id + if claims.get("client_id") != expected_client_id: + raise ValueError( + f"client_id mismatch: expected '{expected_client_id}', " f"got '{claims.get('client_id')}'" + ) + + # Step 8: Check for replay (if enabled) + jti = claims.get("jti") + if not jti: + raise ValueError("Missing jti claim") + + if self.config.replay_prevention_enabled: + if self.replay_store.is_used(jti): + raise ValueError(f"Token replay detected: jti '{jti}' already used") + self.replay_store.mark_used(jti) + + # Step 9: Create validated claims object + claims["typ"] = header["typ"] + return IDJAGClaims.model_validate(claims) + + async def handle_jwt_bearer_grant( + self, + assertion: str, + client_id: str, + ) -> dict[str, Any]: + """ + Handle JWT bearer grant request. + + Args: + assertion: The ID-JAG assertion + client_id: Authenticated client ID + + Returns: + Token response dict + + Raises: + ValueError: If validation fails + """ + # Validate ID-JAG + claims = self.validate_id_jag(assertion, client_id) + + # TODO: Generate and return access token + # This is where you'd integrate with your token generation logic + logger.info( + "JWT bearer grant validated successfully", + extra={ + "client_id": client_id, + "sub": claims.sub, + "scope": claims.scope, + }, + ) + + return { + "token_type": "Bearer", + "access_token": "generated_access_token_here", + "expires_in": 3600, + "scope": claims.scope, + } + diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py new file mode 100644 index 0000000000..7472830877 --- /dev/null +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -0,0 +1,471 @@ +"""Tests for Enterprise Managed Authorization client-side implementation.""" + +import time +from unittest.mock import AsyncMock, Mock + +import httpx +import jwt +import pytest + +from mcp.client.auth import OAuthTokenError +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + TokenExchangeParameters, + TokenExchangeResponse, + decode_id_jag, + validate_token_exchange_params, +) +from mcp.shared.auth import OAuthClientMetadata + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def sample_id_token(): + """Generate a sample ID token for testing.""" + payload = { + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "mcp-client-app", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + "email": "user@example.com", + } + return jwt.encode(payload, "secret", algorithm="HS256") + + +@pytest.fixture +def sample_id_jag(): + """Generate a sample ID-JAG token for testing.""" + payload = { + "jti": "unique-jwt-id-12345", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "mcp-client-app", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read write", + } + token = jwt.encode(payload, "secret", algorithm="HS256") + + # Manually add typ to header + header = jwt.get_unverified_header(token) + header["typ"] = "oauth-id-jag+jwt" + + return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"}) + + +@pytest.fixture +def mock_token_storage(): + """Create a mock token storage.""" + storage = Mock() + storage.get_tokens = AsyncMock(return_value=None) + storage.set_tokens = AsyncMock() + storage.get_client_info = AsyncMock(return_value=None) + storage.set_client_info = AsyncMock() + return storage + + +# ============================================================================ +# Tests for TokenExchangeParameters +# ============================================================================ + + +def test_token_exchange_params_from_id_token(): + """Test creating TokenExchangeParameters from ID token.""" + params = TokenExchangeParameters.from_id_token( + id_token="eyJhbGc...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read write", + ) + + assert params.subject_token == "eyJhbGc..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read write" + assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag" + + +def test_token_exchange_params_from_saml_assertion(): + """Test creating TokenExchangeParameters from SAML assertion.""" + params = TokenExchangeParameters.from_saml_assertion( + saml_assertion="...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read", + ) + + assert params.subject_token == "..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read" + + +def test_validate_token_exchange_params_valid(): + """Test validating valid token exchange parameters.""" + params = TokenExchangeParameters.from_id_token( + id_token="token", + mcp_server_auth_issuer="https://auth.example/", + mcp_server_resource_id="https://server.example/", + ) + + # Should not raise + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_invalid_token_type(): + """Test validation fails for invalid subject token type.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="invalid:type", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="Invalid subject_token_type"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_subject_token(): + """Test validation fails for missing subject token.""" + params = TokenExchangeParameters( + subject_token="", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="subject_token is required"): + validate_token_exchange_params(params) + + +# ============================================================================ +# Tests for TokenExchangeResponse +# ============================================================================ + + +def test_token_exchange_response_parsing(): + """Test parsing token exchange response.""" + response_json = """{ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": "eyJhbGc...", + "token_type": "N_A", + "scope": "read write", + "expires_in": 300 + }""" + + response = TokenExchangeResponse.model_validate_json(response_json) + + assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag" + assert response.id_jag == "eyJhbGc..." + assert response.access_token == "eyJhbGc..." + assert response.token_type == "N_A" + assert response.scope == "read write" + assert response.expires_in == 300 + + +def test_token_exchange_response_id_jag_property(): + """Test id_jag property returns access_token.""" + response = TokenExchangeResponse( + issued_token_type="urn:ietf:params:oauth:token-type:id-jag", + access_token="the-id-jag-token", + token_type="N_A", + ) + + assert response.id_jag == "the-id-jag-token" + + +# ============================================================================ +# Tests for IDJAGClaims +# ============================================================================ + + +def test_decode_id_jag(sample_id_jag): + """Test decoding ID-JAG token.""" + claims = decode_id_jag(sample_id_jag) + + assert claims.iss == "https://idp.example.com" + assert claims.sub == "user123" + assert claims.aud == "https://auth.mcp-server.example/" + assert claims.resource == "https://mcp-server.example/" + assert claims.client_id == "mcp-client-app" + assert claims.scope == "read write" + + +def test_id_jag_claims_with_extra_fields(): + """Test IDJAGClaims allows extra fields.""" + claims_data = { + "typ": "oauth-id-jag+jwt", + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.server.example/", + "resource": "https://server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read", + "email": "user@example.com", + "custom_claim": "custom_value", # Extra field + } + + claims = IDJAGClaims.model_validate(claims_data) + assert claims.email == "user@example.com" + # Extra field should be preserved + assert claims.model_extra.get("custom_claim") == "custom_value" + + +# ============================================================================ +# Tests for EnterpriseAuthOAuthClientProvider +# ============================================================================ + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_success(sample_id_token, sample_id_jag, mock_token_storage): + """Test successful token exchange for ID-JAG.""" + # Create provider + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify + assert id_jag == sample_id_jag + assert provider._id_jag == sample_id_jag + + # Verify request was made correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[0][0] == "https://idp.example.com/oauth2/token" + assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange" + assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag" + assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/" + assert call_args[1]["data"]["resource"] == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_error(sample_id_token, mock_token_storage): + """Test token exchange failure handling.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response + mock_response = httpx.Response( + status_code=400, + json={ + "error": "invalid_request", + "error_description": "Invalid subject token", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Token exchange failed"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token, mock_token_storage): + """Test token exchange with unexpected token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with wrong token type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "access_token": "some-token", + "token_type": "Bearer", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Unexpected token type"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_success(sample_id_jag, mock_token_storage): + """Test successful JWT bearer grant to get access token.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + from mcp.shared.auth import OAuthMetadata + from pydantic import HttpUrl + + provider.context.oauth_metadata = OAuthMetadata( + issuer=HttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Verify tokens were stored + mock_token_storage.set_tokens.assert_called_once() + + # Verify request was made correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert call_args[1]["data"]["assertion"] == sample_id_jag + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag, mock_token_storage): + """Test JWT bearer grant fails without OAuth metadata.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # No OAuth metadata set + mock_client = Mock(spec=httpx.AsyncClient) + + # Should raise OAuthFlowError + from mcp.client.auth import OAuthFlowError + + with pytest.raises(OAuthFlowError, match="token endpoint not discovered"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_perform_authorization_not_implemented(mock_token_storage): + """Test that _perform_authorization raises NotImplementedError.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Should raise NotImplementedError + with pytest.raises(NotImplementedError, match="not yet implemented"): + await provider._perform_authorization() + diff --git a/tests/server/auth/test_enterprise_managed_auth_server.py b/tests/server/auth/test_enterprise_managed_auth_server.py new file mode 100644 index 0000000000..908d389387 --- /dev/null +++ b/tests/server/auth/test_enterprise_managed_auth_server.py @@ -0,0 +1,523 @@ +"""Tests for Enterprise Managed Authorization server-side implementation.""" + +import time +from unittest.mock import patch + +import jwt +import pytest + +from src.mcp.server.auth.extensions.enterprise_managed_auth import ( + IDJAGClaims, + IDJAGValidator, + JWTValidationConfig, + ReplayPreventionStore, +) + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def jwt_validation_config(): + """Create a basic JWT validation config.""" + return JWTValidationConfig( + trusted_idp_issuers=["https://idp.example.com"], + server_auth_issuer="https://auth.mcp-server.example/", + server_resource_id="https://mcp-server.example/", + replay_prevention_enabled=True, + ) + + +@pytest.fixture +def valid_id_jag_claims(): + """Create valid ID-JAG claims.""" + return { + "jti": "unique-jwt-id-12345", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "mcp-client-app", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read write", + "email": "user@example.com", + } + + +@pytest.fixture +def create_id_jag(valid_id_jag_claims): + """Factory to create ID-JAG tokens.""" + def _create(claims=None, secret="test-secret"): + claims_data = valid_id_jag_claims.copy() + if claims: + claims_data.update(claims) + return jwt.encode( + claims_data, + secret, + algorithm="HS256", + headers={"typ": "oauth-id-jag+jwt"}, + ) + return _create + + +# ============================================================================ +# Tests for ReplayPreventionStore +# ============================================================================ + + +def test_replay_prevention_store_mark_and_check(): + """Test marking JTI as used and checking.""" + store = ReplayPreventionStore(ttl=3600) + + jti = "test-jti-123" + + # Initially not used + assert not store.is_used(jti) + + # Mark as used + store.mark_used(jti) + + # Now should be used + assert store.is_used(jti) + + +def test_replay_prevention_store_cleanup(): + """Test that expired JTIs are cleaned up.""" + store = ReplayPreventionStore(ttl=1) # 1 second TTL + + jti1 = "test-jti-1" + jti2 = "test-jti-2" + + # Mark first JTI + store.mark_used(jti1) + assert store.is_used(jti1) + + # Wait for expiry + time.sleep(1.1) + + # Mark second JTI (triggers cleanup) + store.mark_used(jti2) + + # First JTI should be cleaned up + assert not store.is_used(jti1) + + # Second JTI should still be there + assert store.is_used(jti2) + + +def test_replay_prevention_store_multiple_jtis(): + """Test storing multiple JTIs.""" + store = ReplayPreventionStore(ttl=3600) + + jtis = [f"jti-{i}" for i in range(10)] + + for jti in jtis: + store.mark_used(jti) + + for jti in jtis: + assert store.is_used(jti) + + +# ============================================================================ +# Tests for JWTValidationConfig +# ============================================================================ + + +def test_jwt_validation_config_defaults(): + """Test JWT validation config with default values.""" + config = JWTValidationConfig( + trusted_idp_issuers=["https://idp.example.com"], + server_auth_issuer="https://auth.server.example/", + server_resource_id="https://server.example/", + ) + + assert config.jwks_uri is None + assert config.jwks_cache_ttl == 3600 + assert config.allowed_algorithms == ["RS256", "ES256"] + assert config.replay_prevention_enabled is True + assert config.replay_cache_ttl == 3600 + assert config.clock_skew_seconds == 60 + + +def test_jwt_validation_config_custom_values(): + """Test JWT validation config with custom values.""" + config = JWTValidationConfig( + trusted_idp_issuers=["https://idp1.example.com", "https://idp2.example.com"], + server_auth_issuer="https://auth.server.example/", + server_resource_id="https://server.example/", + jwks_uri="https://idp.example.com/.well-known/jwks.json", + jwks_cache_ttl=7200, + allowed_algorithms=["RS256"], + replay_prevention_enabled=False, + replay_cache_ttl=1800, + clock_skew_seconds=120, + ) + + assert len(config.trusted_idp_issuers) == 2 + assert config.jwks_uri == "https://idp.example.com/.well-known/jwks.json" + assert config.jwks_cache_ttl == 7200 + assert config.allowed_algorithms == ["RS256"] + assert config.replay_prevention_enabled is False + assert config.replay_cache_ttl == 1800 + assert config.clock_skew_seconds == 120 + + +# ============================================================================ +# Tests for IDJAGClaims +# ============================================================================ + + +def test_id_jag_claims_required_fields(valid_id_jag_claims): + """Test IDJAGClaims with all required fields.""" + claims = IDJAGClaims.model_validate({**valid_id_jag_claims, "typ": "oauth-id-jag+jwt"}) + + assert claims.typ == "oauth-id-jag+jwt" + assert claims.jti == "unique-jwt-id-12345" + assert claims.iss == "https://idp.example.com" + assert claims.sub == "user123" + assert claims.aud == "https://auth.mcp-server.example/" + assert claims.resource == "https://mcp-server.example/" + assert claims.client_id == "mcp-client-app" + assert claims.scope == "read write" + assert claims.email == "user@example.com" + + +def test_id_jag_claims_optional_fields(): + """Test IDJAGClaims without optional fields.""" + claims_data = { + "typ": "oauth-id-jag+jwt", + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.server.example/", + "resource": "https://server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + claims = IDJAGClaims.model_validate(claims_data) + assert claims.scope is None + assert claims.email is None + + +def test_id_jag_claims_extra_fields(): + """Test that IDJAGClaims allows extra fields.""" + claims_data = { + "typ": "oauth-id-jag+jwt", + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.server.example/", + "resource": "https://server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "custom_field": "custom_value", + "another_field": 123, + } + + claims = IDJAGClaims.model_validate(claims_data) + assert claims.model_extra.get("custom_field") == "custom_value" + assert claims.model_extra.get("another_field") == 123 + + +# ============================================================================ +# Tests for IDJAGValidator +# ============================================================================ + + +def test_id_jag_validator_initialization(jwt_validation_config): + """Test IDJAGValidator initialization.""" + validator = IDJAGValidator(jwt_validation_config) + + assert validator.config == jwt_validation_config + assert isinstance(validator.replay_store, ReplayPreventionStore) + assert validator.jwks_client is None # No JWKS URI provided + + +def test_id_jag_validator_with_jwks(): + """Test IDJAGValidator initialization with JWKS URI.""" + config = JWTValidationConfig( + trusted_idp_issuers=["https://idp.example.com"], + server_auth_issuer="https://auth.server.example/", + server_resource_id="https://server.example/", + jwks_uri="https://idp.example.com/.well-known/jwks.json", + ) + + validator = IDJAGValidator(config) + + assert validator.jwks_client is not None + + +def test_validate_id_jag_success(jwt_validation_config, create_id_jag): + """Test successful ID-JAG validation (without signature verification).""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + # Mock the JWT decode to skip signature verification + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "unique-jwt-id-12345", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "mcp-client-app", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read write", + } + + claims = validator.validate_id_jag(id_jag, expected_client_id="mcp-client-app") + + assert claims.jti == "unique-jwt-id-12345" + assert claims.iss == "https://idp.example.com" + assert claims.sub == "user123" + assert claims.client_id == "mcp-client-app" + + +def test_validate_id_jag_invalid_typ_header(jwt_validation_config, create_id_jag): + """Test validation fails with invalid typ header.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = jwt.encode( + {"iss": "https://idp.example.com"}, + "secret", + algorithm="HS256", + headers={"typ": "JWT"}, # Wrong typ + ) + + with pytest.raises(ValueError, match="Invalid typ header"): + validator.validate_id_jag(id_jag, expected_client_id="client") + + +def test_validate_id_jag_expired(jwt_validation_config, create_id_jag): + """Test validation fails for expired token.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag(claims={"exp": int(time.time()) - 100}) # Expired + + with patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + + # jwt.decode will raise ExpiredSignatureError + with pytest.raises(ValueError, match="ID-JAG has expired"): + validator.validate_id_jag(id_jag, expected_client_id="client") + + +def test_validate_id_jag_untrusted_issuer(jwt_validation_config, create_id_jag): + """Test validation fails for untrusted issuer.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "jti123", + "iss": "https://untrusted-idp.example.com", # Untrusted + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "client", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + with pytest.raises(ValueError, match="Untrusted issuer"): + validator.validate_id_jag(id_jag, expected_client_id="client") + + +def test_validate_id_jag_invalid_audience(jwt_validation_config, create_id_jag): + """Test validation fails for invalid audience.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://wrong-server.example/", # Wrong audience + "resource": "https://mcp-server.example/", + "client_id": "client", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + with pytest.raises(ValueError, match="Invalid audience"): + validator.validate_id_jag(id_jag, expected_client_id="client") + + +def test_validate_id_jag_invalid_resource(jwt_validation_config, create_id_jag): + """Test validation fails for invalid resource.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://wrong-server.example/", # Wrong resource + "client_id": "client", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + with pytest.raises(ValueError, match="Invalid resource"): + validator.validate_id_jag(id_jag, expected_client_id="client") + + +def test_validate_id_jag_client_id_mismatch(jwt_validation_config, create_id_jag): + """Test validation fails for client_id mismatch.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "wrong-client", # Doesn't match expected + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + with pytest.raises(ValueError, match="client_id mismatch"): + validator.validate_id_jag(id_jag, expected_client_id="expected-client") + + +def test_validate_id_jag_missing_jti(jwt_validation_config, create_id_jag): + """Test validation fails for missing jti.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + # Missing jti + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "client", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + with pytest.raises(ValueError, match="Missing jti claim"): + validator.validate_id_jag(id_jag, expected_client_id="client") + + +def test_validate_id_jag_replay_detection(jwt_validation_config, create_id_jag): + """Test replay attack detection.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "replay-jti-123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "client", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + # First validation should succeed + claims = validator.validate_id_jag(id_jag, expected_client_id="client") + assert claims.jti == "replay-jti-123" + + # Second validation with same jti should fail + with pytest.raises(ValueError, match="Token replay detected"): + validator.validate_id_jag(id_jag, expected_client_id="client") + + +def test_validate_id_jag_replay_disabled(create_id_jag): + """Test that replay detection can be disabled.""" + config = JWTValidationConfig( + trusted_idp_issuers=["https://idp.example.com"], + server_auth_issuer="https://auth.mcp-server.example/", + server_resource_id="https://mcp-server.example/", + replay_prevention_enabled=False, # Disabled + ) + + validator = IDJAGValidator(config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "client", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + # Should succeed multiple times + validator.validate_id_jag(id_jag, expected_client_id="client") + validator.validate_id_jag(id_jag, expected_client_id="client") + + +@pytest.mark.anyio +async def test_handle_jwt_bearer_grant_success(jwt_validation_config, create_id_jag): + """Test successful JWT bearer grant handling.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.return_value = { + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read write", + } + + result = await validator.handle_jwt_bearer_grant( + assertion=id_jag, + client_id="client123", + ) + + assert result["token_type"] == "Bearer" + assert "access_token" in result + assert result["expires_in"] == 3600 + assert result["scope"] == "read write" + + +@pytest.mark.anyio +async def test_handle_jwt_bearer_grant_validation_failure(jwt_validation_config, create_id_jag): + """Test JWT bearer grant with validation failure.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.get_unverified_header") as mock_header: + mock_header.return_value = {"typ": "wrong-typ", "alg": "HS256"} + + with pytest.raises(ValueError, match="Invalid typ header"): + await validator.handle_jwt_bearer_grant( + assertion=id_jag, + client_id="client123", + ) + From 9759c7a5513d12cb569124833c0f05cf67d07786 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Wed, 26 Nov 2025 03:31:30 +0530 Subject: [PATCH 02/16] Added test cases for missing lines of code. --- .../test_enterprise_managed_auth_client.py | 407 +++++++++++++++++- .../test_enterprise_managed_auth_server.py | 58 +++ 2 files changed, 464 insertions(+), 1 deletion(-) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index 7472830877..0114a2e5f1 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -1,7 +1,7 @@ """Tests for Enterprise Managed Authorization client-side implementation.""" import time -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, Mock, patch import httpx import jwt @@ -469,3 +469,408 @@ async def test_perform_authorization_not_implemented(mock_token_storage): with pytest.raises(NotImplementedError, match="not yet implemented"): await provider._perform_authorization() + +@pytest.mark.anyio +async def test_perform_authorization_with_valid_tokens(mock_token_storage): + """Test that _perform_authorization returns dummy request when tokens are valid.""" + from mcp.shared.auth import OAuthToken + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set valid tokens + provider.context.current_tokens = OAuthToken( + token_type="Bearer", + access_token="valid-token", + expires_in=3600, + ) + provider.context.token_expiry = time.time() + 3600 + + # Should return a dummy request + request = await provider._perform_authorization() + assert request.method == "GET" + assert str(request.url) == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_token_with_client_authentication(sample_id_token, sample_id_jag, mock_token_storage): + """Test token exchange with client authentication.""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with secret + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uris=["http://localhost:8080/callback"], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify client credentials were included + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert call_args[1]["data"]["client_secret"] == "test-client-secret" + + +@pytest.mark.anyio +async def test_exchange_token_http_error(sample_id_token, mock_token_storage): + """Test token exchange with HTTP error.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed")) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_non_json_error_response(sample_id_token, mock_token_storage): + """Test token exchange with non-JSON error response.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response with non-JSON content + mock_response = httpx.Response( + status_code=500, + content=b"Internal Server Error", + headers={"content-type": "text/plain"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error + with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_warning_for_non_na_token_type(sample_id_token, sample_id_jag, mock_token_storage): + """Test token exchange logs warning for non-N_A token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with different token_type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "Bearer", # Not N_A + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should succeed but log warning + import logging + with patch.object(logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning") as mock_warning: + id_jag = await provider.exchange_token_for_id_jag(mock_client) + assert id_jag == sample_id_jag + mock_warning.assert_called_once() + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_authentication(sample_id_jag, mock_token_storage): + """Test JWT bearer grant with client authentication.""" + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + from pydantic import HttpUrl + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with secret + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uris=["http://localhost:8080/callback"], + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=HttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify client credentials were included + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert call_args[1]["data"]["client_secret"] == "test-client-secret" + + +@pytest.mark.anyio +async def test_exchange_id_jag_error_response(sample_id_jag, mock_token_storage): + """Test JWT bearer grant with error response.""" + from mcp.shared.auth import OAuthMetadata + from pydantic import HttpUrl + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=HttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock error response + mock_response = httpx.Response( + status_code=400, + json={ + "error": "invalid_grant", + "error_description": "Invalid assertion", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="JWT bearer grant failed"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_id_jag_non_json_error(sample_id_jag, mock_token_storage): + """Test JWT bearer grant with non-JSON error response.""" + from mcp.shared.auth import OAuthMetadata + from pydantic import HttpUrl + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=HttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock error response with non-JSON content + mock_response = httpx.Response( + status_code=503, + content=b"Service Unavailable", + headers={"content-type": "text/html"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error + with pytest.raises(OAuthTokenError, match="JWT bearer grant failed: unknown_error"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_id_jag_http_error(sample_id_jag, mock_token_storage): + """Test JWT bearer grant with HTTP error.""" + from mcp.shared.auth import OAuthMetadata + from pydantic import HttpUrl + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=HttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(side_effect=httpx.ReadTimeout("Request timeout")) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="HTTP error during JWT bearer grant"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +def test_validate_token_exchange_params_missing_audience(): + """Test validation fails for missing audience.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="audience is required"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_resource(): + """Test validation fails for missing resource.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="", + ) + + with pytest.raises(ValueError, match="resource is required"): + validate_token_exchange_params(params) + + diff --git a/tests/server/auth/test_enterprise_managed_auth_server.py b/tests/server/auth/test_enterprise_managed_auth_server.py index 908d389387..81a4ba6c93 100644 --- a/tests/server/auth/test_enterprise_managed_auth_server.py +++ b/tests/server/auth/test_enterprise_managed_auth_server.py @@ -521,3 +521,61 @@ async def test_handle_jwt_bearer_grant_validation_failure(jwt_validation_config, client_id="client123", ) + +def test_validate_id_jag_with_jwks_client(create_id_jag): + """Test ID-JAG validation with JWKS client for signature verification.""" + from unittest.mock import MagicMock + + config = JWTValidationConfig( + trusted_idp_issuers=["https://idp.example.com"], + server_auth_issuer="https://auth.mcp-server.example/", + server_resource_id="https://mcp-server.example/", + jwks_uri="https://idp.example.com/.well-known/jwks.json", + ) + + validator = IDJAGValidator(config) + id_jag = create_id_jag() + + # Mock the JWKS client + mock_signing_key = MagicMock() + mock_signing_key.key = "mock-key" + + with patch.object(validator.jwks_client, "get_signing_key_from_jwt", return_value=mock_signing_key), \ + patch("jwt.decode") as mock_decode, \ + patch("jwt.get_unverified_header") as mock_header: + + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "RS256"} + mock_decode.return_value = { + "jti": "jti-with-jwks", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + } + + claims = validator.validate_id_jag(id_jag, expected_client_id="client123") + + # Verify JWKS client was called + validator.jwks_client.get_signing_key_from_jwt.assert_called_once_with(id_jag) + # Verify jwt.decode was called with the key + assert mock_decode.call_args[0][1] == "mock-key" + + +def test_validate_id_jag_invalid_token_error(jwt_validation_config, create_id_jag): + """Test validation handles InvalidTokenError.""" + validator = IDJAGValidator(jwt_validation_config) + id_jag = create_id_jag() + + with patch("jwt.get_unverified_header") as mock_header, \ + patch("jwt.decode") as mock_decode: + + mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} + mock_decode.side_effect = jwt.InvalidTokenError("Invalid token") + + with pytest.raises(ValueError, match="Invalid ID-JAG: Invalid token"): + validator.validate_id_jag(id_jag, expected_client_id="client123") + + From 7f80d329748dacbbdbb6b8ef638184c175540049 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Wed, 26 Nov 2025 04:15:32 +0530 Subject: [PATCH 03/16] - Added tests cases for few of the missing lines. src/mcp/client/auth/extensions/enterprise_managed_auth.py 232->235, 304->307. - Resolved pre-commit errors. --- .../test_enterprise_managed_auth_client.py | 144 ++++++++++++++++-- .../test_enterprise_managed_auth_server.py | 22 +-- 2 files changed, 147 insertions(+), 19 deletions(-) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index 0114a2e5f1..b3d05542f6 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -18,7 +18,6 @@ ) from mcp.shared.auth import OAuthClientMetadata - # ============================================================================ # Fixtures # ============================================================================ @@ -376,9 +375,10 @@ async def test_exchange_id_jag_for_access_token_success(sample_id_jag, mock_toke ) # Set up OAuth metadata - from mcp.shared.auth import OAuthMetadata from pydantic import HttpUrl + from mcp.shared.auth import OAuthMetadata + provider.context.oauth_metadata = OAuthMetadata( issuer=HttpUrl("https://auth.mcp-server.example/"), authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), @@ -559,6 +559,60 @@ async def test_exchange_token_with_client_authentication(sample_id_token, sample assert call_args[1]["data"]["client_secret"] == "test-client-secret" +@pytest.mark.anyio +async def test_exchange_token_with_client_id_only(sample_id_token, sample_id_jag, mock_token_storage): + """Test token exchange with client_id but no client_secret (covers branch 232->235).""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITHOUT secret (client_secret=None) + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret=None, # No secret + redirect_uris=["http://localhost:8080/callback"], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify client_id was included but NOT client_secret + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert "client_secret" not in call_args[1]["data"] + + @pytest.mark.anyio async def test_exchange_token_http_error(sample_id_token, mock_token_storage): """Test token exchange with HTTP error.""" @@ -656,7 +710,10 @@ async def test_exchange_token_warning_for_non_na_token_type(sample_id_token, sam # Should succeed but log warning import logging - with patch.object(logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning") as mock_warning: + + with patch.object( + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + ) as mock_warning: id_jag = await provider.exchange_token_for_id_jag(mock_client) assert id_jag == sample_id_jag mock_warning.assert_called_once() @@ -665,9 +722,10 @@ async def test_exchange_token_warning_for_non_na_token_type(sample_id_token, sam @pytest.mark.anyio async def test_exchange_id_jag_with_client_authentication(sample_id_jag, mock_token_storage): """Test JWT bearer grant with client authentication.""" - from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata from pydantic import HttpUrl + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + token_exchange_params = TokenExchangeParameters.from_id_token( id_token="dummy-token", mcp_server_auth_issuer="https://auth.mcp-server.example/", @@ -715,18 +773,86 @@ async def test_exchange_id_jag_with_client_authentication(sample_id_jag, mock_to # Perform JWT bearer grant token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + # Verify token was returned + assert token.access_token == "mcp-access-token-12345" + # Verify client credentials were included call_args = mock_client.post.call_args assert call_args[1]["data"]["client_id"] == "test-client-id" assert call_args[1]["data"]["client_secret"] == "test-client-secret" +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_id_only(sample_id_jag, mock_token_storage): + """Test JWT bearer grant with client_id but no client_secret (covers branch 304->307).""" + from pydantic import HttpUrl + + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=["http://localhost:8080/callback"], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITHOUT secret (client_secret=None) + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret=None, # No secret + redirect_uris=["http://localhost:8080/callback"], + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=HttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify token was returned correctly + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + + # Verify client_id was included but NOT client_secret + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert "client_secret" not in call_args[1]["data"] + + @pytest.mark.anyio async def test_exchange_id_jag_error_response(sample_id_jag, mock_token_storage): """Test JWT bearer grant with error response.""" - from mcp.shared.auth import OAuthMetadata from pydantic import HttpUrl + from mcp.shared.auth import OAuthMetadata + token_exchange_params = TokenExchangeParameters.from_id_token( id_token="dummy-token", mcp_server_auth_issuer="https://auth.mcp-server.example/", @@ -770,9 +896,10 @@ async def test_exchange_id_jag_error_response(sample_id_jag, mock_token_storage) @pytest.mark.anyio async def test_exchange_id_jag_non_json_error(sample_id_jag, mock_token_storage): """Test JWT bearer grant with non-JSON error response.""" - from mcp.shared.auth import OAuthMetadata from pydantic import HttpUrl + from mcp.shared.auth import OAuthMetadata + token_exchange_params = TokenExchangeParameters.from_id_token( id_token="dummy-token", mcp_server_auth_issuer="https://auth.mcp-server.example/", @@ -814,9 +941,10 @@ async def test_exchange_id_jag_non_json_error(sample_id_jag, mock_token_storage) @pytest.mark.anyio async def test_exchange_id_jag_http_error(sample_id_jag, mock_token_storage): """Test JWT bearer grant with HTTP error.""" - from mcp.shared.auth import OAuthMetadata from pydantic import HttpUrl + from mcp.shared.auth import OAuthMetadata + token_exchange_params = TokenExchangeParameters.from_id_token( id_token="dummy-token", mcp_server_auth_issuer="https://auth.mcp-server.example/", @@ -872,5 +1000,3 @@ def test_validate_token_exchange_params_missing_resource(): with pytest.raises(ValueError, match="resource is required"): validate_token_exchange_params(params) - - diff --git a/tests/server/auth/test_enterprise_managed_auth_server.py b/tests/server/auth/test_enterprise_managed_auth_server.py index 81a4ba6c93..669089e9ab 100644 --- a/tests/server/auth/test_enterprise_managed_auth_server.py +++ b/tests/server/auth/test_enterprise_managed_auth_server.py @@ -13,7 +13,6 @@ ReplayPreventionStore, ) - # ============================================================================ # Fixtures # ============================================================================ @@ -50,6 +49,7 @@ def valid_id_jag_claims(): @pytest.fixture def create_id_jag(valid_id_jag_claims): """Factory to create ID-JAG tokens.""" + def _create(claims=None, secret="test-secret"): claims_data = valid_id_jag_claims.copy() if claims: @@ -60,6 +60,7 @@ def _create(claims=None, secret="test-secret"): algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"}, ) + return _create @@ -540,10 +541,11 @@ def test_validate_id_jag_with_jwks_client(create_id_jag): mock_signing_key = MagicMock() mock_signing_key.key = "mock-key" - with patch.object(validator.jwks_client, "get_signing_key_from_jwt", return_value=mock_signing_key), \ - patch("jwt.decode") as mock_decode, \ - patch("jwt.get_unverified_header") as mock_header: - + with ( + patch.object(validator.jwks_client, "get_signing_key_from_jwt", return_value=mock_signing_key), + patch("jwt.decode") as mock_decode, + patch("jwt.get_unverified_header") as mock_header, + ): mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "RS256"} mock_decode.return_value = { "jti": "jti-with-jwks", @@ -558,6 +560,10 @@ def test_validate_id_jag_with_jwks_client(create_id_jag): claims = validator.validate_id_jag(id_jag, expected_client_id="client123") + # Verify claims were returned correctly + assert claims.jti == "jti-with-jwks" + assert claims.client_id == "client123" + # Verify JWKS client was called validator.jwks_client.get_signing_key_from_jwt.assert_called_once_with(id_jag) # Verify jwt.decode was called with the key @@ -569,13 +575,9 @@ def test_validate_id_jag_invalid_token_error(jwt_validation_config, create_id_ja validator = IDJAGValidator(jwt_validation_config) id_jag = create_id_jag() - with patch("jwt.get_unverified_header") as mock_header, \ - patch("jwt.decode") as mock_decode: - + with patch("jwt.get_unverified_header") as mock_header, patch("jwt.decode") as mock_decode: mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} mock_decode.side_effect = jwt.InvalidTokenError("Invalid token") with pytest.raises(ValueError, match="Invalid ID-JAG: Invalid token"): validator.validate_id_jag(id_jag, expected_client_id="client123") - - From 1ea72c5d74c29528b99d8b5b5717b7556abd789f Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Wed, 26 Nov 2025 04:34:21 +0530 Subject: [PATCH 04/16] - Fixed pre-commit errors. --- src/mcp/client/auth/extensions/__init__.py | 1 - src/mcp/client/auth/extensions/enterprise_managed_auth.py | 1 - src/mcp/server/auth/extensions/__init__.py | 1 - src/mcp/server/auth/extensions/enterprise_managed_auth.py | 1 - 4 files changed, 4 deletions(-) diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py index 7b3ece607d..56ba368ef8 100644 --- a/src/mcp/client/auth/extensions/__init__.py +++ b/src/mcp/client/auth/extensions/__init__.py @@ -17,4 +17,3 @@ "decode_id_jag", "validate_token_exchange_params", ] - diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py index d5fb9e56ef..fb21f771ec 100644 --- a/src/mcp/client/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -412,4 +412,3 @@ def validate_token_exchange_params( "urn:ietf:params:oauth:token-type:saml2", ]: raise ValueError(f"Invalid subject_token_type: {params.subject_token_type}") - diff --git a/src/mcp/server/auth/extensions/__init__.py b/src/mcp/server/auth/extensions/__init__.py index 2d75f43b9f..9230a43a5a 100644 --- a/src/mcp/server/auth/extensions/__init__.py +++ b/src/mcp/server/auth/extensions/__init__.py @@ -13,4 +13,3 @@ "JWTValidationConfig", "ReplayPreventionStore", ] - diff --git a/src/mcp/server/auth/extensions/enterprise_managed_auth.py b/src/mcp/server/auth/extensions/enterprise_managed_auth.py index 084394a127..3fd066d9f1 100644 --- a/src/mcp/server/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/server/auth/extensions/enterprise_managed_auth.py @@ -274,4 +274,3 @@ async def handle_jwt_bearer_grant( "expires_in": 3600, "scope": claims.scope, } - From c07b7b9b56c96123e2ddcc4b4fdbaa178b1cb422 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Wed, 26 Nov 2025 04:42:58 +0530 Subject: [PATCH 05/16] - Tried to fix the ruff error. --- .../auth/extensions/enterprise_managed_auth.py | 8 ++------ .../auth/extensions/enterprise_managed_auth.py | 14 ++++---------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py index fb21f771ec..adb5f0de1e 100644 --- a/src/mcp/client/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -241,9 +241,7 @@ async def exchange_token_for_id_jag( if response.status_code != 200: error_data = ( - response.json() - if response.headers.get("content-type", "").startswith("application/json") - else {} + response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) error = error_data.get("error", "unknown_error") error_description = error_data.get("error_description", "Token exchange failed") @@ -313,9 +311,7 @@ async def exchange_id_jag_for_access_token( if response.status_code != 200: error_data = ( - response.json() - if response.headers.get("content-type", "").startswith("application/json") - else {} + response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) error = error_data.get("error", "unknown_error") error_description = error_data.get("error_description", "JWT bearer grant failed") diff --git a/src/mcp/server/auth/extensions/enterprise_managed_auth.py b/src/mcp/server/auth/extensions/enterprise_managed_auth.py index 3fd066d9f1..07565f1fa4 100644 --- a/src/mcp/server/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/server/auth/extensions/enterprise_managed_auth.py @@ -117,9 +117,7 @@ def is_used(self, jti: str) -> bool: def _cleanup(self) -> None: """Remove expired entries.""" now = time.time() - self._used_jtis = { - jti: timestamp for jti, timestamp in self._used_jtis.items() if now - timestamp < self._ttl - } + self._used_jtis = {jti: timestamp for jti, timestamp in self._used_jtis.items() if now - timestamp < self._ttl} # ============================================================================ @@ -205,22 +203,18 @@ def validate_id_jag( # Step 5: Validate audience if claims.get("aud") != self.config.server_auth_issuer: raise ValueError( - f"Invalid audience: expected '{self.config.server_auth_issuer}', " - f"got '{claims.get('aud')}'" + f"Invalid audience: expected '{self.config.server_auth_issuer}', got '{claims.get('aud')}'" ) # Step 6: Validate resource if claims.get("resource") != self.config.server_resource_id: raise ValueError( - f"Invalid resource: expected '{self.config.server_resource_id}', " - f"got '{claims.get('resource')}'" + f"Invalid resource: expected '{self.config.server_resource_id}', got '{claims.get('resource')}'" ) # Step 7: Validate client_id if claims.get("client_id") != expected_client_id: - raise ValueError( - f"client_id mismatch: expected '{expected_client_id}', " f"got '{claims.get('client_id')}'" - ) + raise ValueError(f"client_id mismatch: expected '{expected_client_id}', got '{claims.get('client_id')}'") # Step 8: Check for replay (if enabled) jti = claims.get("jti") From f431b54218ecc85baac3d29551c289de8484ad61 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Wed, 26 Nov 2025 05:34:54 +0530 Subject: [PATCH 06/16] - Fixed ruff errors. --- tests/client/auth/test_enterprise_managed_auth_client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index b3d05542f6..0c02c809dd 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -553,6 +553,9 @@ async def test_exchange_token_with_client_authentication(sample_id_token, sample # Perform token exchange id_jag = await provider.exchange_token_for_id_jag(mock_client) + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + # Verify client credentials were included call_args = mock_client.post.call_args assert call_args[1]["data"]["client_id"] == "test-client-id" @@ -607,6 +610,9 @@ async def test_exchange_token_with_client_id_only(sample_id_token, sample_id_jag # Perform token exchange id_jag = await provider.exchange_token_for_id_jag(mock_client) + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + # Verify client_id was included but NOT client_secret call_args = mock_client.post.call_args assert call_args[1]["data"]["client_id"] == "test-client-id" From d4392ae574cdbed080f7841f944cd0c0f0f51c07 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Wed, 26 Nov 2025 13:46:13 +0530 Subject: [PATCH 07/16] - Removed server side changes for enterprise_managed_auth.py --- .../extensions/enterprise_managed_auth.py | 16 +- src/mcp/server/auth/extensions/__init__.py | 15 - .../extensions/enterprise_managed_auth.py | 270 -------- .../test_enterprise_managed_auth_server.py | 583 ------------------ 4 files changed, 8 insertions(+), 876 deletions(-) delete mode 100644 src/mcp/server/auth/extensions/__init__.py delete mode 100644 src/mcp/server/auth/extensions/enterprise_managed_auth.py delete mode 100644 tests/server/auth/test_enterprise_managed_auth_server.py diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py index adb5f0de1e..3abab48684 100644 --- a/src/mcp/client/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -229,7 +229,7 @@ async def exchange_token_for_id_jag( # Add client authentication if needed if self.context.client_info: token_data["client_id"] = self.context.client_info.client_id - if self.context.client_info.client_secret: + if self.context.client_info.client_secret is not None: token_data["client_secret"] = self.context.client_info.client_secret try: @@ -240,11 +240,11 @@ async def exchange_token_for_id_jag( ) if response.status_code != 200: - error_data = ( + error_data: dict[str, str] = ( response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) - error = error_data.get("error", "unknown_error") - error_description = error_data.get("error_description", "Token exchange failed") + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "Token exchange failed") raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}") # Parse response @@ -299,7 +299,7 @@ async def exchange_id_jag_for_access_token( # Add client authentication if self.context.client_info: token_data["client_id"] = self.context.client_info.client_id - if self.context.client_info.client_secret: + if self.context.client_info.client_secret is not None: token_data["client_secret"] = self.context.client_info.client_secret try: @@ -310,11 +310,11 @@ async def exchange_id_jag_for_access_token( ) if response.status_code != 200: - error_data = ( + error_data: dict[str, str] = ( response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) - error = error_data.get("error", "unknown_error") - error_description = error_data.get("error_description", "JWT bearer grant failed") + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "JWT bearer grant failed") raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}") # Parse OAuth token response diff --git a/src/mcp/server/auth/extensions/__init__.py b/src/mcp/server/auth/extensions/__init__.py deleted file mode 100644 index 9230a43a5a..0000000000 --- a/src/mcp/server/auth/extensions/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""MCP Server Auth Extensions.""" - -from mcp.server.auth.extensions.enterprise_managed_auth import ( - IDJAGClaims, - IDJAGValidator, - JWTValidationConfig, - ReplayPreventionStore, -) - -__all__ = [ - "IDJAGClaims", - "IDJAGValidator", - "JWTValidationConfig", - "ReplayPreventionStore", -] diff --git a/src/mcp/server/auth/extensions/enterprise_managed_auth.py b/src/mcp/server/auth/extensions/enterprise_managed_auth.py deleted file mode 100644 index 07565f1fa4..0000000000 --- a/src/mcp/server/auth/extensions/enterprise_managed_auth.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Server-side Enterprise Managed Authorization (SEP-990). - -Implements JWT validation for ID-JAG tokens and JWT bearer grant handling. -""" - -import logging -import time -from typing import Any - -import jwt -from jwt import PyJWKClient -from pydantic import BaseModel, Field - -logger = logging.getLogger(__name__) - - -# ============================================================================ -# Configuration Models -# ============================================================================ - - -class JWTValidationConfig(BaseModel): - """Configuration for JWT validation.""" - - trusted_idp_issuers: list[str] = Field( - ..., - description="List of trusted IdP issuer URLs", - ) - - server_auth_issuer: str = Field( - ..., - description="This server's authorization server issuer URL", - ) - - server_resource_id: str = Field( - ..., - description="This server's resource identifier", - ) - - jwks_uri: str | None = Field( - default=None, - description="JWKS URI for key verification (if single IdP)", - ) - - jwks_cache_ttl: int = Field( - default=3600, - description="JWKS cache TTL in seconds", - ) - - allowed_algorithms: list[str] = Field( - default=["RS256", "ES256"], - description="Allowed JWT signing algorithms", - ) - - replay_prevention_enabled: bool = Field( - default=True, - description="Enable JTI-based replay prevention", - ) - - replay_cache_ttl: int = Field( - default=3600, - description="Replay cache TTL in seconds", - ) - - clock_skew_seconds: int = Field( - default=60, - description="Allowed clock skew for exp/iat validation", - ) - - -class IDJAGClaims(BaseModel): - """Validated ID-JAG claims.""" - - model_config = {"extra": "allow"} - - # JWT header - typ: str - - # Required claims - jti: str - iss: str - sub: str - aud: str - resource: str - client_id: str - exp: int - iat: int - - # Optional claims - scope: str | None = None - email: str | None = None - - -# ============================================================================ -# Replay Prevention -# ============================================================================ - - -class ReplayPreventionStore: - """In-memory store for replay prevention (production should use Redis/similar).""" - - def __init__(self, ttl: int = 3600): - self._used_jtis: dict[str, float] = {} - self._ttl = ttl - - def mark_used(self, jti: str) -> None: - """Mark a JTI as used.""" - self._cleanup() - self._used_jtis[jti] = time.time() - - def is_used(self, jti: str) -> bool: - """Check if a JTI has been used.""" - self._cleanup() - return jti in self._used_jtis - - def _cleanup(self) -> None: - """Remove expired entries.""" - now = time.time() - self._used_jtis = {jti: timestamp for jti, timestamp in self._used_jtis.items() if now - timestamp < self._ttl} - - -# ============================================================================ -# JWT Validator -# ============================================================================ - - -class IDJAGValidator: - """Validator for ID-JAG tokens.""" - - def __init__(self, config: JWTValidationConfig): - self.config = config - self.replay_store = ReplayPreventionStore(ttl=config.replay_cache_ttl) - - # Initialize JWKS client if provided - self.jwks_client: PyJWKClient | None = None - if config.jwks_uri: - self.jwks_client = PyJWKClient( - config.jwks_uri, - cache_keys=True, - max_cached_keys=16, - cache_jwk_set=True, - lifespan=config.jwks_cache_ttl, - ) - - def validate_id_jag( - self, - id_jag: str, - expected_client_id: str, - ) -> IDJAGClaims: - """ - Validate an ID-JAG token. - - Args: - id_jag: The ID-JAG token to validate - expected_client_id: The client_id from client authentication - - Returns: - Validated ID-JAG claims - - Raises: - jwt.InvalidTokenError: If validation fails - ValueError: If claims are invalid - """ - # Step 1: Decode and get header - header = jwt.get_unverified_header(id_jag) - - # Validate typ header - if header.get("typ") != "oauth-id-jag+jwt": - raise ValueError(f"Invalid typ header: expected 'oauth-id-jag+jwt', got '{header.get('typ')}'") - - # Step 2: Get signing key - if self.jwks_client: - signing_key = self.jwks_client.get_signing_key_from_jwt(id_jag) - key = signing_key.key - else: - # For testing/development - decode without verification - logger.warning("No JWKS client configured - skipping signature verification") - key = None - - # Step 3: Decode and verify JWT - try: - claims = jwt.decode( - id_jag, - key, - algorithms=self.config.allowed_algorithms, - options={ - "verify_signature": key is not None, - "verify_exp": True, - "verify_iat": True, - }, - leeway=self.config.clock_skew_seconds, - ) - except jwt.ExpiredSignatureError: - raise ValueError("ID-JAG has expired") - except jwt.InvalidTokenError as e: - raise ValueError(f"Invalid ID-JAG: {e}") - - # Step 4: Validate issuer - if claims.get("iss") not in self.config.trusted_idp_issuers: - raise ValueError(f"Untrusted issuer: {claims.get('iss')}") - - # Step 5: Validate audience - if claims.get("aud") != self.config.server_auth_issuer: - raise ValueError( - f"Invalid audience: expected '{self.config.server_auth_issuer}', got '{claims.get('aud')}'" - ) - - # Step 6: Validate resource - if claims.get("resource") != self.config.server_resource_id: - raise ValueError( - f"Invalid resource: expected '{self.config.server_resource_id}', got '{claims.get('resource')}'" - ) - - # Step 7: Validate client_id - if claims.get("client_id") != expected_client_id: - raise ValueError(f"client_id mismatch: expected '{expected_client_id}', got '{claims.get('client_id')}'") - - # Step 8: Check for replay (if enabled) - jti = claims.get("jti") - if not jti: - raise ValueError("Missing jti claim") - - if self.config.replay_prevention_enabled: - if self.replay_store.is_used(jti): - raise ValueError(f"Token replay detected: jti '{jti}' already used") - self.replay_store.mark_used(jti) - - # Step 9: Create validated claims object - claims["typ"] = header["typ"] - return IDJAGClaims.model_validate(claims) - - async def handle_jwt_bearer_grant( - self, - assertion: str, - client_id: str, - ) -> dict[str, Any]: - """ - Handle JWT bearer grant request. - - Args: - assertion: The ID-JAG assertion - client_id: Authenticated client ID - - Returns: - Token response dict - - Raises: - ValueError: If validation fails - """ - # Validate ID-JAG - claims = self.validate_id_jag(assertion, client_id) - - # TODO: Generate and return access token - # This is where you'd integrate with your token generation logic - logger.info( - "JWT bearer grant validated successfully", - extra={ - "client_id": client_id, - "sub": claims.sub, - "scope": claims.scope, - }, - ) - - return { - "token_type": "Bearer", - "access_token": "generated_access_token_here", - "expires_in": 3600, - "scope": claims.scope, - } diff --git a/tests/server/auth/test_enterprise_managed_auth_server.py b/tests/server/auth/test_enterprise_managed_auth_server.py deleted file mode 100644 index 669089e9ab..0000000000 --- a/tests/server/auth/test_enterprise_managed_auth_server.py +++ /dev/null @@ -1,583 +0,0 @@ -"""Tests for Enterprise Managed Authorization server-side implementation.""" - -import time -from unittest.mock import patch - -import jwt -import pytest - -from src.mcp.server.auth.extensions.enterprise_managed_auth import ( - IDJAGClaims, - IDJAGValidator, - JWTValidationConfig, - ReplayPreventionStore, -) - -# ============================================================================ -# Fixtures -# ============================================================================ - - -@pytest.fixture -def jwt_validation_config(): - """Create a basic JWT validation config.""" - return JWTValidationConfig( - trusted_idp_issuers=["https://idp.example.com"], - server_auth_issuer="https://auth.mcp-server.example/", - server_resource_id="https://mcp-server.example/", - replay_prevention_enabled=True, - ) - - -@pytest.fixture -def valid_id_jag_claims(): - """Create valid ID-JAG claims.""" - return { - "jti": "unique-jwt-id-12345", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "mcp-client-app", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - "scope": "read write", - "email": "user@example.com", - } - - -@pytest.fixture -def create_id_jag(valid_id_jag_claims): - """Factory to create ID-JAG tokens.""" - - def _create(claims=None, secret="test-secret"): - claims_data = valid_id_jag_claims.copy() - if claims: - claims_data.update(claims) - return jwt.encode( - claims_data, - secret, - algorithm="HS256", - headers={"typ": "oauth-id-jag+jwt"}, - ) - - return _create - - -# ============================================================================ -# Tests for ReplayPreventionStore -# ============================================================================ - - -def test_replay_prevention_store_mark_and_check(): - """Test marking JTI as used and checking.""" - store = ReplayPreventionStore(ttl=3600) - - jti = "test-jti-123" - - # Initially not used - assert not store.is_used(jti) - - # Mark as used - store.mark_used(jti) - - # Now should be used - assert store.is_used(jti) - - -def test_replay_prevention_store_cleanup(): - """Test that expired JTIs are cleaned up.""" - store = ReplayPreventionStore(ttl=1) # 1 second TTL - - jti1 = "test-jti-1" - jti2 = "test-jti-2" - - # Mark first JTI - store.mark_used(jti1) - assert store.is_used(jti1) - - # Wait for expiry - time.sleep(1.1) - - # Mark second JTI (triggers cleanup) - store.mark_used(jti2) - - # First JTI should be cleaned up - assert not store.is_used(jti1) - - # Second JTI should still be there - assert store.is_used(jti2) - - -def test_replay_prevention_store_multiple_jtis(): - """Test storing multiple JTIs.""" - store = ReplayPreventionStore(ttl=3600) - - jtis = [f"jti-{i}" for i in range(10)] - - for jti in jtis: - store.mark_used(jti) - - for jti in jtis: - assert store.is_used(jti) - - -# ============================================================================ -# Tests for JWTValidationConfig -# ============================================================================ - - -def test_jwt_validation_config_defaults(): - """Test JWT validation config with default values.""" - config = JWTValidationConfig( - trusted_idp_issuers=["https://idp.example.com"], - server_auth_issuer="https://auth.server.example/", - server_resource_id="https://server.example/", - ) - - assert config.jwks_uri is None - assert config.jwks_cache_ttl == 3600 - assert config.allowed_algorithms == ["RS256", "ES256"] - assert config.replay_prevention_enabled is True - assert config.replay_cache_ttl == 3600 - assert config.clock_skew_seconds == 60 - - -def test_jwt_validation_config_custom_values(): - """Test JWT validation config with custom values.""" - config = JWTValidationConfig( - trusted_idp_issuers=["https://idp1.example.com", "https://idp2.example.com"], - server_auth_issuer="https://auth.server.example/", - server_resource_id="https://server.example/", - jwks_uri="https://idp.example.com/.well-known/jwks.json", - jwks_cache_ttl=7200, - allowed_algorithms=["RS256"], - replay_prevention_enabled=False, - replay_cache_ttl=1800, - clock_skew_seconds=120, - ) - - assert len(config.trusted_idp_issuers) == 2 - assert config.jwks_uri == "https://idp.example.com/.well-known/jwks.json" - assert config.jwks_cache_ttl == 7200 - assert config.allowed_algorithms == ["RS256"] - assert config.replay_prevention_enabled is False - assert config.replay_cache_ttl == 1800 - assert config.clock_skew_seconds == 120 - - -# ============================================================================ -# Tests for IDJAGClaims -# ============================================================================ - - -def test_id_jag_claims_required_fields(valid_id_jag_claims): - """Test IDJAGClaims with all required fields.""" - claims = IDJAGClaims.model_validate({**valid_id_jag_claims, "typ": "oauth-id-jag+jwt"}) - - assert claims.typ == "oauth-id-jag+jwt" - assert claims.jti == "unique-jwt-id-12345" - assert claims.iss == "https://idp.example.com" - assert claims.sub == "user123" - assert claims.aud == "https://auth.mcp-server.example/" - assert claims.resource == "https://mcp-server.example/" - assert claims.client_id == "mcp-client-app" - assert claims.scope == "read write" - assert claims.email == "user@example.com" - - -def test_id_jag_claims_optional_fields(): - """Test IDJAGClaims without optional fields.""" - claims_data = { - "typ": "oauth-id-jag+jwt", - "jti": "jti123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.server.example/", - "resource": "https://server.example/", - "client_id": "client123", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - claims = IDJAGClaims.model_validate(claims_data) - assert claims.scope is None - assert claims.email is None - - -def test_id_jag_claims_extra_fields(): - """Test that IDJAGClaims allows extra fields.""" - claims_data = { - "typ": "oauth-id-jag+jwt", - "jti": "jti123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.server.example/", - "resource": "https://server.example/", - "client_id": "client123", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - "custom_field": "custom_value", - "another_field": 123, - } - - claims = IDJAGClaims.model_validate(claims_data) - assert claims.model_extra.get("custom_field") == "custom_value" - assert claims.model_extra.get("another_field") == 123 - - -# ============================================================================ -# Tests for IDJAGValidator -# ============================================================================ - - -def test_id_jag_validator_initialization(jwt_validation_config): - """Test IDJAGValidator initialization.""" - validator = IDJAGValidator(jwt_validation_config) - - assert validator.config == jwt_validation_config - assert isinstance(validator.replay_store, ReplayPreventionStore) - assert validator.jwks_client is None # No JWKS URI provided - - -def test_id_jag_validator_with_jwks(): - """Test IDJAGValidator initialization with JWKS URI.""" - config = JWTValidationConfig( - trusted_idp_issuers=["https://idp.example.com"], - server_auth_issuer="https://auth.server.example/", - server_resource_id="https://server.example/", - jwks_uri="https://idp.example.com/.well-known/jwks.json", - ) - - validator = IDJAGValidator(config) - - assert validator.jwks_client is not None - - -def test_validate_id_jag_success(jwt_validation_config, create_id_jag): - """Test successful ID-JAG validation (without signature verification).""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - # Mock the JWT decode to skip signature verification - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "unique-jwt-id-12345", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "mcp-client-app", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - "scope": "read write", - } - - claims = validator.validate_id_jag(id_jag, expected_client_id="mcp-client-app") - - assert claims.jti == "unique-jwt-id-12345" - assert claims.iss == "https://idp.example.com" - assert claims.sub == "user123" - assert claims.client_id == "mcp-client-app" - - -def test_validate_id_jag_invalid_typ_header(jwt_validation_config, create_id_jag): - """Test validation fails with invalid typ header.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = jwt.encode( - {"iss": "https://idp.example.com"}, - "secret", - algorithm="HS256", - headers={"typ": "JWT"}, # Wrong typ - ) - - with pytest.raises(ValueError, match="Invalid typ header"): - validator.validate_id_jag(id_jag, expected_client_id="client") - - -def test_validate_id_jag_expired(jwt_validation_config, create_id_jag): - """Test validation fails for expired token.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag(claims={"exp": int(time.time()) - 100}) # Expired - - with patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - - # jwt.decode will raise ExpiredSignatureError - with pytest.raises(ValueError, match="ID-JAG has expired"): - validator.validate_id_jag(id_jag, expected_client_id="client") - - -def test_validate_id_jag_untrusted_issuer(jwt_validation_config, create_id_jag): - """Test validation fails for untrusted issuer.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "jti123", - "iss": "https://untrusted-idp.example.com", # Untrusted - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "client", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - with pytest.raises(ValueError, match="Untrusted issuer"): - validator.validate_id_jag(id_jag, expected_client_id="client") - - -def test_validate_id_jag_invalid_audience(jwt_validation_config, create_id_jag): - """Test validation fails for invalid audience.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "jti123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://wrong-server.example/", # Wrong audience - "resource": "https://mcp-server.example/", - "client_id": "client", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - with pytest.raises(ValueError, match="Invalid audience"): - validator.validate_id_jag(id_jag, expected_client_id="client") - - -def test_validate_id_jag_invalid_resource(jwt_validation_config, create_id_jag): - """Test validation fails for invalid resource.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "jti123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://wrong-server.example/", # Wrong resource - "client_id": "client", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - with pytest.raises(ValueError, match="Invalid resource"): - validator.validate_id_jag(id_jag, expected_client_id="client") - - -def test_validate_id_jag_client_id_mismatch(jwt_validation_config, create_id_jag): - """Test validation fails for client_id mismatch.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "jti123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "wrong-client", # Doesn't match expected - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - with pytest.raises(ValueError, match="client_id mismatch"): - validator.validate_id_jag(id_jag, expected_client_id="expected-client") - - -def test_validate_id_jag_missing_jti(jwt_validation_config, create_id_jag): - """Test validation fails for missing jti.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - # Missing jti - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "client", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - with pytest.raises(ValueError, match="Missing jti claim"): - validator.validate_id_jag(id_jag, expected_client_id="client") - - -def test_validate_id_jag_replay_detection(jwt_validation_config, create_id_jag): - """Test replay attack detection.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "replay-jti-123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "client", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - # First validation should succeed - claims = validator.validate_id_jag(id_jag, expected_client_id="client") - assert claims.jti == "replay-jti-123" - - # Second validation with same jti should fail - with pytest.raises(ValueError, match="Token replay detected"): - validator.validate_id_jag(id_jag, expected_client_id="client") - - -def test_validate_id_jag_replay_disabled(create_id_jag): - """Test that replay detection can be disabled.""" - config = JWTValidationConfig( - trusted_idp_issuers=["https://idp.example.com"], - server_auth_issuer="https://auth.mcp-server.example/", - server_resource_id="https://mcp-server.example/", - replay_prevention_enabled=False, # Disabled - ) - - validator = IDJAGValidator(config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "jti123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "client", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - # Should succeed multiple times - validator.validate_id_jag(id_jag, expected_client_id="client") - validator.validate_id_jag(id_jag, expected_client_id="client") - - -@pytest.mark.anyio -async def test_handle_jwt_bearer_grant_success(jwt_validation_config, create_id_jag): - """Test successful JWT bearer grant handling.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.decode") as mock_decode, patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.return_value = { - "jti": "jti123", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "client123", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - "scope": "read write", - } - - result = await validator.handle_jwt_bearer_grant( - assertion=id_jag, - client_id="client123", - ) - - assert result["token_type"] == "Bearer" - assert "access_token" in result - assert result["expires_in"] == 3600 - assert result["scope"] == "read write" - - -@pytest.mark.anyio -async def test_handle_jwt_bearer_grant_validation_failure(jwt_validation_config, create_id_jag): - """Test JWT bearer grant with validation failure.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.get_unverified_header") as mock_header: - mock_header.return_value = {"typ": "wrong-typ", "alg": "HS256"} - - with pytest.raises(ValueError, match="Invalid typ header"): - await validator.handle_jwt_bearer_grant( - assertion=id_jag, - client_id="client123", - ) - - -def test_validate_id_jag_with_jwks_client(create_id_jag): - """Test ID-JAG validation with JWKS client for signature verification.""" - from unittest.mock import MagicMock - - config = JWTValidationConfig( - trusted_idp_issuers=["https://idp.example.com"], - server_auth_issuer="https://auth.mcp-server.example/", - server_resource_id="https://mcp-server.example/", - jwks_uri="https://idp.example.com/.well-known/jwks.json", - ) - - validator = IDJAGValidator(config) - id_jag = create_id_jag() - - # Mock the JWKS client - mock_signing_key = MagicMock() - mock_signing_key.key = "mock-key" - - with ( - patch.object(validator.jwks_client, "get_signing_key_from_jwt", return_value=mock_signing_key), - patch("jwt.decode") as mock_decode, - patch("jwt.get_unverified_header") as mock_header, - ): - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "RS256"} - mock_decode.return_value = { - "jti": "jti-with-jwks", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "client123", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - } - - claims = validator.validate_id_jag(id_jag, expected_client_id="client123") - - # Verify claims were returned correctly - assert claims.jti == "jti-with-jwks" - assert claims.client_id == "client123" - - # Verify JWKS client was called - validator.jwks_client.get_signing_key_from_jwt.assert_called_once_with(id_jag) - # Verify jwt.decode was called with the key - assert mock_decode.call_args[0][1] == "mock-key" - - -def test_validate_id_jag_invalid_token_error(jwt_validation_config, create_id_jag): - """Test validation handles InvalidTokenError.""" - validator = IDJAGValidator(jwt_validation_config) - id_jag = create_id_jag() - - with patch("jwt.get_unverified_header") as mock_header, patch("jwt.decode") as mock_decode: - mock_header.return_value = {"typ": "oauth-id-jag+jwt", "alg": "HS256"} - mock_decode.side_effect = jwt.InvalidTokenError("Invalid token") - - with pytest.raises(ValueError, match="Invalid ID-JAG: Invalid token"): - validator.validate_id_jag(id_jag, expected_client_id="client123") From db2f02c19c558f263b14dca62cbd37c66d89cfbc Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Wed, 26 Nov 2025 21:22:55 +0530 Subject: [PATCH 08/16] - Added README.md changes for SEP-990 implementation for enterprise managed auth. --- README.md | 135 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/README.md b/README.md index 5e8129c96e..7f9164ff88 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ - [Writing MCP Clients](#writing-mcp-clients) - [Client Display Utilities](#client-display-utilities) - [OAuth Authentication for Clients](#oauth-authentication-for-clients) + - [Enterprise Managed Authorization](#enterprise-managed-authorization) - [Parsing Tool Results](#parsing-tool-results) - [MCP Primitives](#mcp-primitives) - [Server Capabilities](#server-capabilities) @@ -2356,6 +2357,140 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/). +#### Enterprise Managed Authorization + +The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports: + +- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token → ID-JAG) +- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG → Access Token) +- Integration with enterprise identity providers (Okta, Azure AD, etc.) + +**Key Components:** + +The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow: + +```python +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, + IDJAGClaims, + decode_id_jag, +) +from mcp.shared.auth import OAuthClientMetadata, OAuthToken +from mcp.client.auth import TokenStorage +``` + +**Token Exchange Flow:** + +1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD) +2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange +3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant +4. **Use Access Token** to call protected MCP server tools + +**Example Usage:** + +```python +import asyncio +import httpx +from pydantic import AnyUrl + +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, +) +from mcp.shared.auth import OAuthClientMetadata, OAuthToken +from mcp.client.auth import TokenStorage + +# Define token storage implementation +class SimpleTokenStorage(TokenStorage): + def __init__(self): + self._tokens = None + self._client_info = None + + async def get_tokens(self): + return self._tokens + + async def set_tokens(self, tokens): + self._tokens = tokens + + async def get_client_info(self): + return self._client_info + + async def set_client_info(self, client_info): + self._client_info = client_info + +async def main(): + # Step 1: Get ID token from your IdP (example with Okta) + id_token = await get_id_token_from_idp() # Your IdP authentication + + # Step 2: Configure token exchange parameters + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL + mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID + scope="mcp:tools mcp:resources", # Optional scopes + ) + + # Step 3: Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + client_name="Enterprise MCP Client", + client_id="your-client-id", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + # Step 4: Perform token exchange and get access token + async with httpx.AsyncClient() as client: + # Exchange ID token for ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + print(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Exchange ID-JAG for access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token( + client, id_jag + ) + print(f"Access token obtained, expires in: {access_token.expires_in}s") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +**Working with SAML Assertions:** + +If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions: + +```python +token_exchange_params = TokenExchangeParameters.from_saml_assertion( + saml_assertion=saml_assertion_string, + mcp_server_auth_issuer="https://your-idp.com", + mcp_server_resource_id="https://mcp-server.example.com", + scope="mcp:tools", +) +``` + +**Decoding and Inspecting ID-JAG Tokens:** + +You can decode ID-JAG tokens to inspect their claims: + +```python +from mcp.client.auth.extensions import decode_id_jag + +# Decode without signature verification (for inspection only) +claims = decode_id_jag(id_jag) +print(f"Subject: {claims.sub}") +print(f"Issuer: {claims.iss}") +print(f"Audience: {claims.aud}") +print(f"Client ID: {claims.client_id}") +print(f"Resource: {claims.resource}") +``` + ### Parsing Tool Results When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs. From 5fb2c0f248d00a5ed9ec1a6a3f64566f99ece61d Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Thu, 27 Nov 2025 02:20:50 +0530 Subject: [PATCH 09/16] - Resolved pyright checks error. --- .../extensions/enterprise_managed_auth.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py index 3abab48684..8ed172a360 100644 --- a/src/mcp/client/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -6,7 +6,7 @@ """ import logging -from typing import Any +from collections.abc import Awaitable, Callable import httpx from pydantic import BaseModel, Field @@ -166,8 +166,8 @@ def __init__( storage: TokenStorage, idp_token_endpoint: str, token_exchange_params: TokenExchangeParameters, - redirect_handler: Any = None, - callback_handler: Any = None, + redirect_handler: Callable[[str], Awaitable[None]] | None = None, + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, timeout: float = 300.0, ) -> None: """ @@ -228,7 +228,8 @@ async def exchange_token_for_id_jag( # Add client authentication if needed if self.context.client_info: - token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_id is not None: + token_data["client_id"] = self.context.client_info.client_id if self.context.client_info.client_secret is not None: token_data["client_secret"] = self.context.client_info.client_secret @@ -240,11 +241,11 @@ async def exchange_token_for_id_jag( ) if response.status_code != 200: - error_data: dict[str, str] = ( + error_data: dict[str, object] = ( response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) - error: str = error_data.get("error", "unknown_error") - error_description: str = error_data.get("error_description", "Token exchange failed") + error = str(error_data.get("error", "unknown_error")) + error_description = str(error_data.get("error_description", "Token exchange failed")) raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}") # Parse response @@ -298,7 +299,8 @@ async def exchange_id_jag_for_access_token( # Add client authentication if self.context.client_info: - token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_id is not None: + token_data["client_id"] = self.context.client_info.client_id if self.context.client_info.client_secret is not None: token_data["client_secret"] = self.context.client_info.client_secret @@ -310,11 +312,11 @@ async def exchange_id_jag_for_access_token( ) if response.status_code != 200: - error_data: dict[str, str] = ( + error_data: dict[str, object] = ( response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) - error: str = error_data.get("error", "unknown_error") - error_description: str = error_data.get("error_description", "JWT bearer grant failed") + error = str(error_data.get("error", "unknown_error")) + error_description = str(error_data.get("error_description", "JWT bearer grant failed")) raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}") # Parse OAuth token response From 005bad4bbde141d9f0a12789ba28ee0996bdb00f Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Thu, 27 Nov 2025 10:43:24 +0530 Subject: [PATCH 10/16] - Resolved README.md file fixes for removing unused imports. --- README.md | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/README.md b/README.md index 7f9164ff88..fab9e05ef6 100644 --- a/README.md +++ b/README.md @@ -2369,17 +2369,6 @@ The SDK includes support for Enterprise Managed Authorization (SEP-990), which e The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow: -```python -from mcp.client.auth.extensions import ( - EnterpriseAuthOAuthClientProvider, - TokenExchangeParameters, - IDJAGClaims, - decode_id_jag, -) -from mcp.shared.auth import OAuthClientMetadata, OAuthToken -from mcp.client.auth import TokenStorage -``` - **Token Exchange Flow:** 1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD) @@ -2398,7 +2387,7 @@ from mcp.client.auth.extensions import ( EnterpriseAuthOAuthClientProvider, TokenExchangeParameters, ) -from mcp.shared.auth import OAuthClientMetadata, OAuthToken +from mcp.shared.auth import OAuthClientMetadata from mcp.client.auth import TokenStorage # Define token storage implementation From 73b12b76feb1625aa9a216f375d887b880e98268 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Thu, 27 Nov 2025 14:51:05 +0530 Subject: [PATCH 11/16] - Resolved pyright errors. --- .../extensions/enterprise_managed_auth.py | 18 +-- .../test_enterprise_managed_auth_client.py | 138 ++++++++---------- 2 files changed, 73 insertions(+), 83 deletions(-) diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py index 8ed172a360..e55283e968 100644 --- a/src/mcp/client/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -6,7 +6,7 @@ """ import logging -from collections.abc import Awaitable, Callable +from typing import Any import httpx from pydantic import BaseModel, Field @@ -166,8 +166,8 @@ def __init__( storage: TokenStorage, idp_token_endpoint: str, token_exchange_params: TokenExchangeParameters, - redirect_handler: Callable[[str], Awaitable[None]] | None = None, - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, + redirect_handler: Any = None, + callback_handler: Any = None, timeout: float = 300.0, ) -> None: """ @@ -241,11 +241,11 @@ async def exchange_token_for_id_jag( ) if response.status_code != 200: - error_data: dict[str, object] = ( + error_data: dict[str, str] = ( response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) - error = str(error_data.get("error", "unknown_error")) - error_description = str(error_data.get("error_description", "Token exchange failed")) + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "Token exchange failed") raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}") # Parse response @@ -312,11 +312,11 @@ async def exchange_id_jag_for_access_token( ) if response.status_code != 200: - error_data: dict[str, object] = ( + error_data: dict[str, str] = ( response.json() if response.headers.get("content-type", "").startswith("application/json") else {} ) - error = str(error_data.get("error", "unknown_error")) - error_description = str(error_data.get("error_description", "JWT bearer grant failed")) + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "JWT bearer grant failed") raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}") # Parse OAuth token response diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index 0c02c809dd..a0e53e771b 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -1,11 +1,13 @@ """Tests for Enterprise Managed Authorization client-side implementation.""" import time +from typing import Any from unittest.mock import AsyncMock, Mock, patch import httpx import jwt import pytest +from pydantic import AnyUrl, AnyHttpUrl from mcp.client.auth import OAuthTokenError from mcp.client.auth.extensions.enterprise_managed_auth import ( @@ -24,7 +26,7 @@ @pytest.fixture -def sample_id_token(): +def sample_id_token() -> str: """Generate a sample ID token for testing.""" payload = { "iss": "https://idp.example.com", @@ -38,7 +40,7 @@ def sample_id_token(): @pytest.fixture -def sample_id_jag(): +def sample_id_jag() -> str: """Generate a sample ID-JAG token for testing.""" payload = { "jti": "unique-jwt-id-12345", @@ -61,7 +63,7 @@ def sample_id_jag(): @pytest.fixture -def mock_token_storage(): +def mock_token_storage() -> Any: """Create a mock token storage.""" storage = Mock() storage.get_tokens = AsyncMock(return_value=None) @@ -188,7 +190,7 @@ def test_token_exchange_response_id_jag_property(): # ============================================================================ -def test_decode_id_jag(sample_id_jag): +def test_decode_id_jag(sample_id_jag: str): """Test decoding ID-JAG token.""" claims = decode_id_jag(sample_id_jag) @@ -220,7 +222,7 @@ def test_id_jag_claims_with_extra_fields(): claims = IDJAGClaims.model_validate(claims_data) assert claims.email == "user@example.com" # Extra field should be preserved - assert claims.model_extra.get("custom_claim") == "custom_value" + assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value" # ============================================================================ @@ -229,7 +231,7 @@ def test_id_jag_claims_with_extra_fields(): @pytest.mark.anyio -async def test_exchange_token_for_id_jag_success(sample_id_token, sample_id_jag, mock_token_storage): +async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): """Test successful token exchange for ID-JAG.""" # Create provider token_exchange_params = TokenExchangeParameters.from_id_token( @@ -242,7 +244,7 @@ async def test_exchange_token_for_id_jag_success(sample_id_token, sample_id_jag, provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], client_name="Test Client", ), storage=mock_token_storage, @@ -283,7 +285,7 @@ async def test_exchange_token_for_id_jag_success(sample_id_token, sample_id_jag, @pytest.mark.anyio -async def test_exchange_token_for_id_jag_error(sample_id_token, mock_token_storage): +async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any): """Test token exchange failure handling.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token=sample_id_token, @@ -294,7 +296,7 @@ async def test_exchange_token_for_id_jag_error(sample_id_token, mock_token_stora provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -319,7 +321,7 @@ async def test_exchange_token_for_id_jag_error(sample_id_token, mock_token_stora @pytest.mark.anyio -async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token, mock_token_storage): +async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any): """Test token exchange with unexpected token type.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token=sample_id_token, @@ -330,7 +332,7 @@ async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token, provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -356,7 +358,7 @@ async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token, @pytest.mark.anyio -async def test_exchange_id_jag_for_access_token_success(sample_id_jag, mock_token_storage): +async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any): """Test successful JWT bearer grant to get access token.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token="dummy-token", @@ -367,7 +369,7 @@ async def test_exchange_id_jag_for_access_token_success(sample_id_jag, mock_toke provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -375,14 +377,12 @@ async def test_exchange_id_jag_for_access_token_success(sample_id_jag, mock_toke ) # Set up OAuth metadata - from pydantic import HttpUrl - from mcp.shared.auth import OAuthMetadata provider.context.oauth_metadata = OAuthMetadata( - issuer=HttpUrl("https://auth.mcp-server.example/"), - authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), - token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), ) # Mock HTTP response @@ -418,7 +418,7 @@ async def test_exchange_id_jag_for_access_token_success(sample_id_jag, mock_toke @pytest.mark.anyio -async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag, mock_token_storage): +async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any): """Test JWT bearer grant fails without OAuth metadata.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token="dummy-token", @@ -429,7 +429,7 @@ async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag, mock_ provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -447,7 +447,7 @@ async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag, mock_ @pytest.mark.anyio -async def test_perform_authorization_not_implemented(mock_token_storage): +async def test_perform_authorization_not_implemented(mock_token_storage: Any): """Test that _perform_authorization raises NotImplementedError.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token="dummy-token", @@ -458,7 +458,7 @@ async def test_perform_authorization_not_implemented(mock_token_storage): provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -471,7 +471,7 @@ async def test_perform_authorization_not_implemented(mock_token_storage): @pytest.mark.anyio -async def test_perform_authorization_with_valid_tokens(mock_token_storage): +async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any): """Test that _perform_authorization returns dummy request when tokens are valid.""" from mcp.shared.auth import OAuthToken @@ -484,7 +484,7 @@ async def test_perform_authorization_with_valid_tokens(mock_token_storage): provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -497,7 +497,7 @@ async def test_perform_authorization_with_valid_tokens(mock_token_storage): access_token="valid-token", expires_in=3600, ) - provider.context.token_expiry = time.time() + 3600 + provider.context.token_expiry_time = time.time() + 3600 # Should return a dummy request request = await provider._perform_authorization() @@ -506,7 +506,7 @@ async def test_perform_authorization_with_valid_tokens(mock_token_storage): @pytest.mark.anyio -async def test_exchange_token_with_client_authentication(sample_id_token, sample_id_jag, mock_token_storage): +async def test_exchange_token_with_client_authentication(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): """Test token exchange with client authentication.""" from mcp.shared.auth import OAuthClientInformationFull @@ -520,7 +520,7 @@ async def test_exchange_token_with_client_authentication(sample_id_token, sample provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], client_name="Test Client", ), storage=mock_token_storage, @@ -532,7 +532,7 @@ async def test_exchange_token_with_client_authentication(sample_id_token, sample provider.context.client_info = OAuthClientInformationFull( client_id="test-client-id", client_secret="test-client-secret", - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ) # Mock HTTP response @@ -563,7 +563,7 @@ async def test_exchange_token_with_client_authentication(sample_id_token, sample @pytest.mark.anyio -async def test_exchange_token_with_client_id_only(sample_id_token, sample_id_jag, mock_token_storage): +async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): """Test token exchange with client_id but no client_secret (covers branch 232->235).""" from mcp.shared.auth import OAuthClientInformationFull @@ -577,7 +577,7 @@ async def test_exchange_token_with_client_id_only(sample_id_token, sample_id_jag provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], client_name="Test Client", ), storage=mock_token_storage, @@ -589,7 +589,7 @@ async def test_exchange_token_with_client_id_only(sample_id_token, sample_id_jag provider.context.client_info = OAuthClientInformationFull( client_id="test-client-id", client_secret=None, # No secret - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ) # Mock HTTP response @@ -620,7 +620,7 @@ async def test_exchange_token_with_client_id_only(sample_id_token, sample_id_jag @pytest.mark.anyio -async def test_exchange_token_http_error(sample_id_token, mock_token_storage): +async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any): """Test token exchange with HTTP error.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token=sample_id_token, @@ -631,7 +631,7 @@ async def test_exchange_token_http_error(sample_id_token, mock_token_storage): provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -647,7 +647,7 @@ async def test_exchange_token_http_error(sample_id_token, mock_token_storage): @pytest.mark.anyio -async def test_exchange_token_non_json_error_response(sample_id_token, mock_token_storage): +async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any): """Test token exchange with non-JSON error response.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token=sample_id_token, @@ -658,7 +658,7 @@ async def test_exchange_token_non_json_error_response(sample_id_token, mock_toke provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -681,7 +681,7 @@ async def test_exchange_token_non_json_error_response(sample_id_token, mock_toke @pytest.mark.anyio -async def test_exchange_token_warning_for_non_na_token_type(sample_id_token, sample_id_jag, mock_token_storage): +async def test_exchange_token_warning_for_non_na_token_type(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): """Test token exchange logs warning for non-N_A token type.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token=sample_id_token, @@ -692,7 +692,7 @@ async def test_exchange_token_warning_for_non_na_token_type(sample_id_token, sam provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -726,10 +726,8 @@ async def test_exchange_token_warning_for_non_na_token_type(sample_id_token, sam @pytest.mark.anyio -async def test_exchange_id_jag_with_client_authentication(sample_id_jag, mock_token_storage): +async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any): """Test JWT bearer grant with client authentication.""" - from pydantic import HttpUrl - from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata token_exchange_params = TokenExchangeParameters.from_id_token( @@ -741,7 +739,7 @@ async def test_exchange_id_jag_with_client_authentication(sample_id_jag, mock_to provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -752,14 +750,14 @@ async def test_exchange_id_jag_with_client_authentication(sample_id_jag, mock_to provider.context.client_info = OAuthClientInformationFull( client_id="test-client-id", client_secret="test-client-secret", - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ) # Set up OAuth metadata provider.context.oauth_metadata = OAuthMetadata( - issuer=HttpUrl("https://auth.mcp-server.example/"), - authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), - token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), ) # Mock HTTP response @@ -789,10 +787,8 @@ async def test_exchange_id_jag_with_client_authentication(sample_id_jag, mock_to @pytest.mark.anyio -async def test_exchange_id_jag_with_client_id_only(sample_id_jag, mock_token_storage): +async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any): """Test JWT bearer grant with client_id but no client_secret (covers branch 304->307).""" - from pydantic import HttpUrl - from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata token_exchange_params = TokenExchangeParameters.from_id_token( @@ -804,7 +800,7 @@ async def test_exchange_id_jag_with_client_id_only(sample_id_jag, mock_token_sto provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -815,14 +811,14 @@ async def test_exchange_id_jag_with_client_id_only(sample_id_jag, mock_token_sto provider.context.client_info = OAuthClientInformationFull( client_id="test-client-id", client_secret=None, # No secret - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ) # Set up OAuth metadata provider.context.oauth_metadata = OAuthMetadata( - issuer=HttpUrl("https://auth.mcp-server.example/"), - authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), - token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), ) # Mock HTTP response @@ -853,10 +849,8 @@ async def test_exchange_id_jag_with_client_id_only(sample_id_jag, mock_token_sto @pytest.mark.anyio -async def test_exchange_id_jag_error_response(sample_id_jag, mock_token_storage): +async def test_exchange_id_jag_error_response(sample_id_jag: str, mock_token_storage: Any): """Test JWT bearer grant with error response.""" - from pydantic import HttpUrl - from mcp.shared.auth import OAuthMetadata token_exchange_params = TokenExchangeParameters.from_id_token( @@ -868,7 +862,7 @@ async def test_exchange_id_jag_error_response(sample_id_jag, mock_token_storage) provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -877,9 +871,9 @@ async def test_exchange_id_jag_error_response(sample_id_jag, mock_token_storage) # Set up OAuth metadata provider.context.oauth_metadata = OAuthMetadata( - issuer=HttpUrl("https://auth.mcp-server.example/"), - authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), - token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), ) # Mock error response @@ -900,10 +894,8 @@ async def test_exchange_id_jag_error_response(sample_id_jag, mock_token_storage) @pytest.mark.anyio -async def test_exchange_id_jag_non_json_error(sample_id_jag, mock_token_storage): +async def test_exchange_id_jag_non_json_error(sample_id_jag: str, mock_token_storage: Any): """Test JWT bearer grant with non-JSON error response.""" - from pydantic import HttpUrl - from mcp.shared.auth import OAuthMetadata token_exchange_params = TokenExchangeParameters.from_id_token( @@ -915,7 +907,7 @@ async def test_exchange_id_jag_non_json_error(sample_id_jag, mock_token_storage) provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -924,9 +916,9 @@ async def test_exchange_id_jag_non_json_error(sample_id_jag, mock_token_storage) # Set up OAuth metadata provider.context.oauth_metadata = OAuthMetadata( - issuer=HttpUrl("https://auth.mcp-server.example/"), - authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), - token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), ) # Mock error response with non-JSON content @@ -945,10 +937,8 @@ async def test_exchange_id_jag_non_json_error(sample_id_jag, mock_token_storage) @pytest.mark.anyio -async def test_exchange_id_jag_http_error(sample_id_jag, mock_token_storage): +async def test_exchange_id_jag_http_error(sample_id_jag: str, mock_token_storage: Any): """Test JWT bearer grant with HTTP error.""" - from pydantic import HttpUrl - from mcp.shared.auth import OAuthMetadata token_exchange_params = TokenExchangeParameters.from_id_token( @@ -960,7 +950,7 @@ async def test_exchange_id_jag_http_error(sample_id_jag, mock_token_storage): provider = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example/", client_metadata=OAuthClientMetadata( - redirect_uris=["http://localhost:8080/callback"], + redirect_uris=[AnyUrl("http://localhost:8080/callback")], ), storage=mock_token_storage, idp_token_endpoint="https://idp.example.com/oauth2/token", @@ -969,9 +959,9 @@ async def test_exchange_id_jag_http_error(sample_id_jag, mock_token_storage): # Set up OAuth metadata provider.context.oauth_metadata = OAuthMetadata( - issuer=HttpUrl("https://auth.mcp-server.example/"), - authorization_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/authorize"), - token_endpoint=HttpUrl("https://auth.mcp-server.example/oauth2/token"), + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), ) mock_client = Mock(spec=httpx.AsyncClient) From 82147782cdca41c94f5689aa853c9ff68c708c7c Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Thu, 27 Nov 2025 15:27:24 +0530 Subject: [PATCH 12/16] - Added new test cases for the missing code lines. --- .../test_enterprise_managed_auth_client.py | 144 +++++++++++++++++- 1 file changed, 139 insertions(+), 5 deletions(-) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index a0e53e771b..7b96fbf5a0 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -7,7 +7,7 @@ import httpx import jwt import pytest -from pydantic import AnyUrl, AnyHttpUrl +from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthTokenError from mcp.client.auth.extensions.enterprise_managed_auth import ( @@ -506,7 +506,10 @@ async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any): @pytest.mark.anyio -async def test_exchange_token_with_client_authentication(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): +async def test_exchange_token_with_client_authentication( + sample_id_token: str, sample_id_jag: str, + mock_token_storage: Any +): """Test token exchange with client authentication.""" from mcp.shared.auth import OAuthClientInformationFull @@ -563,7 +566,10 @@ async def test_exchange_token_with_client_authentication(sample_id_token: str, s @pytest.mark.anyio -async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): +async def test_exchange_token_with_client_id_only( + sample_id_token: str, sample_id_jag: str, + mock_token_storage: Any +): """Test token exchange with client_id but no client_secret (covers branch 232->235).""" from mcp.shared.auth import OAuthClientInformationFull @@ -681,7 +687,9 @@ async def test_exchange_token_non_json_error_response(sample_id_token: str, mock @pytest.mark.anyio -async def test_exchange_token_warning_for_non_na_token_type(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): +async def test_exchange_token_warning_for_non_na_token_type( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): """Test token exchange logs warning for non-N_A token type.""" token_exchange_params = TokenExchangeParameters.from_id_token( id_token=sample_id_token, @@ -718,7 +726,7 @@ async def test_exchange_token_warning_for_non_na_token_type(sample_id_token: str import logging with patch.object( - logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" ) as mock_warning: id_jag = await provider.exchange_token_for_id_jag(mock_client) assert id_jag == sample_id_jag @@ -972,6 +980,132 @@ async def test_exchange_id_jag_http_error(sample_id_jag: str, mock_token_storage await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) +@pytest.mark.anyio +async def test_exchange_token_with_client_info_but_no_client_id( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange when client_info exists but client_id is None (covers line 231).""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with client_id=None + provider.context.client_info = OAuthClientInformationFull( + client_id=None, # This should skip the client_id assignment + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client_id was not included (None), but client_secret was included + call_args = mock_client.post.call_args + assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_secret"] == "test-secret" + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_info_but_no_client_id( + sample_id_jag: str, mock_token_storage: Any +): + """Test ID-JAG exchange when client_info exists but client_id is None (covers line 302).""" + from pydantic import AnyHttpUrl + + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set client info with client_id=None + provider.context.client_info = OAuthClientInformationFull( + client_id=None, # This should skip the client_id assignment + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Verify client_id was not included (None), but client_secret was included + call_args = mock_client.post.call_args + assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_secret"] == "test-secret" + + def test_validate_token_exchange_params_missing_audience(): """Test validation fails for missing audience.""" params = TokenExchangeParameters( From 04ffe5a56e10c49c134b1ede98d600b09ae26527 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Thu, 27 Nov 2025 15:55:09 +0530 Subject: [PATCH 13/16] - Fixed the failing test cases. --- tests/client/auth/test_enterprise_managed_auth_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index 7b96fbf5a0..72b977b954 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -1035,7 +1035,7 @@ async def test_exchange_token_with_client_info_but_no_client_id( # Verify client_id was not included (None), but client_secret was included call_args = mock_client.post.call_args - assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_id"] is None assert call_args[1]["data"]["client_secret"] == "test-secret" @@ -1102,7 +1102,7 @@ async def test_exchange_id_jag_with_client_info_but_no_client_id( # Verify client_id was not included (None), but client_secret was included call_args = mock_client.post.call_args - assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_id"] is None assert call_args[1]["data"]["client_secret"] == "test-secret" From 09c05aa6bc5c2d0270d008e6b58ab7c6e167a295 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Thu, 27 Nov 2025 16:13:24 +0530 Subject: [PATCH 14/16] - Fixed the test cases. --- .../test_enterprise_managed_auth_client.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index 72b977b954..4a10f4f664 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -507,8 +507,7 @@ async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any): @pytest.mark.anyio async def test_exchange_token_with_client_authentication( - sample_id_token: str, sample_id_jag: str, - mock_token_storage: Any + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any ): """Test token exchange with client authentication.""" from mcp.shared.auth import OAuthClientInformationFull @@ -566,10 +565,7 @@ async def test_exchange_token_with_client_authentication( @pytest.mark.anyio -async def test_exchange_token_with_client_id_only( - sample_id_token: str, sample_id_jag: str, - mock_token_storage: Any -): +async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): """Test token exchange with client_id but no client_secret (covers branch 232->235).""" from mcp.shared.auth import OAuthClientInformationFull @@ -688,7 +684,7 @@ async def test_exchange_token_non_json_error_response(sample_id_token: str, mock @pytest.mark.anyio async def test_exchange_token_warning_for_non_na_token_type( - sample_id_token: str, sample_id_jag: str, mock_token_storage: Any + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any ): """Test token exchange logs warning for non-N_A token type.""" token_exchange_params = TokenExchangeParameters.from_id_token( @@ -726,7 +722,7 @@ async def test_exchange_token_warning_for_non_na_token_type( import logging with patch.object( - logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" ) as mock_warning: id_jag = await provider.exchange_token_for_id_jag(mock_client) assert id_jag == sample_id_jag @@ -1035,14 +1031,12 @@ async def test_exchange_token_with_client_info_but_no_client_id( # Verify client_id was not included (None), but client_secret was included call_args = mock_client.post.call_args - assert call_args[1]["data"]["client_id"] is None + assert "client_id" not in call_args[1]["data"] assert call_args[1]["data"]["client_secret"] == "test-secret" @pytest.mark.anyio -async def test_exchange_id_jag_with_client_info_but_no_client_id( - sample_id_jag: str, mock_token_storage: Any -): +async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any): """Test ID-JAG exchange when client_info exists but client_id is None (covers line 302).""" from pydantic import AnyHttpUrl @@ -1102,7 +1096,7 @@ async def test_exchange_id_jag_with_client_info_but_no_client_id( # Verify client_id was not included (None), but client_secret was included call_args = mock_client.post.call_args - assert call_args[1]["data"]["client_id"] is None + assert "client_id" not in call_args[1]["data"] assert call_args[1]["data"]["client_secret"] == "test-secret" From 28bb315f9022c3b4ad6eaf27a98010a0a4007397 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Fri, 12 Dec 2025 18:00:54 +0530 Subject: [PATCH 15/16] - Added typing for request payload structures TokenExchangeRequestData and JWTBearerGrantRequestData. - Added snippet file for adding code to the README.md file. - Added new section in README.md file to add information regarding: "how to use the access token once you get it" and "How does this work when the client ID is expired?". --- README.md | 209 +++++++++++++++--- .../clients/enterprise_managed_auth_client.py | 204 +++++++++++++++++ .../extensions/enterprise_managed_auth.py | 46 ++-- .../test_enterprise_managed_auth_client.py | 10 - 4 files changed, 417 insertions(+), 52 deletions(-) create mode 100644 examples/snippets/clients/enterprise_managed_auth_client.py diff --git a/README.md b/README.md index fab9e05ef6..ded8056f04 100644 --- a/README.md +++ b/README.md @@ -2376,42 +2376,140 @@ The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provide 3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant 4. **Use Access Token** to call protected MCP server tools +**Using the Access Token with MCP Server:** + +1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server +2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions. + +**Handling Token Expiration and Refresh:** + +Access tokens have a limited lifetime and will expire. When tokens expire: + +- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires +- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP +- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production) +- **Error Handling**: Catch authentication errors and retry with refreshed tokens + +**Important Notes:** + +- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange +- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts +- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes +- **Security**: Never log or expose access tokens or ID tokens in production environments + **Example Usage:** + ```python import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any + import httpx from pydantic import AnyUrl +from mcp import ClientSession +from mcp.client.auth import OAuthTokenError, TokenStorage from mcp.client.auth.extensions import ( EnterpriseAuthOAuthClientProvider, TokenExchangeParameters, ) -from mcp.shared.auth import OAuthClientMetadata -from mcp.client.auth import TokenStorage +from mcp.client.sse import sse_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from mcp.types import CallToolResult + + +# Placeholder function for IdP authentication +async def get_id_token_from_idp() -> str: + """ + Placeholder function to get ID token from your IdP. + In production, implement actual IdP authentication flow. + """ + raise NotImplementedError("Implement your IdP authentication flow here") + # Define token storage implementation class SimpleTokenStorage(TokenStorage): - def __init__(self): - self._tokens = None - self._client_info = None - - async def get_tokens(self): + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: return self._tokens - - async def set_tokens(self, tokens): + + async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens - - async def get_client_info(self): + + async def get_client_info(self) -> OAuthClientInformationFull | None: return self._client_info - - async def set_client_info(self, client_info): + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info -async def main(): + +def is_token_expired(access_token: OAuthToken) -> bool: + """Check if the access token has expired.""" + if access_token.expires_in: + # Calculate expiration time + issued_at = datetime.now(timezone.utc) + expiration_time = issued_at + timedelta(seconds=access_token.expires_in) + return datetime.now(timezone.utc) >= expiration_time + return False + + +async def refresh_access_token( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> OAuthToken: + """Refresh the access token when it expires.""" + try: + # Update token exchange parameters with fresh ID token + enterprise_auth.token_exchange_params.subject_token = id_token + + # Re-exchange for new ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + + # Get new access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + return access_token + except Exception as e: + print(f"Token refresh failed: {e}") + # Re-authenticate with IdP if ID token is also expired + id_token = await get_id_token_from_idp() + return await refresh_access_token(enterprise_auth, client, id_token) + + +async def call_tool_with_retry( + session: ClientSession, + tool_name: str, + arguments: dict[str, Any], + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> CallToolResult | None: + """Call a tool with automatic retry on token expiration.""" + max_retries = 1 + + for attempt in range(max_retries + 1): + try: + result = await session.call_tool(tool_name, arguments) + return result + except OAuthTokenError: + if attempt < max_retries: + print("Token expired, refreshing...") + # Refresh token and reconnect + _access_token = await refresh_access_token(enterprise_auth, client, id_token) + # Note: In production, you'd need to reconnect the session here + else: + raise + return None + + +async def main() -> None: # Step 1: Get ID token from your IdP (example with Okta) id_token = await get_id_token_from_idp() # Your IdP authentication - + # Step 2: Configure token exchange parameters token_exchange_params = TokenExchangeParameters.from_id_token( id_token=id_token, @@ -2419,13 +2517,12 @@ async def main(): mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID scope="mcp:tools mcp:resources", # Optional scopes ) - + # Step 3: Create enterprise auth provider enterprise_auth = EnterpriseAuthOAuthClientProvider( server_url="https://mcp-server.example.com", client_metadata=OAuthClientMetadata( client_name="Enterprise MCP Client", - client_id="your-client-id", redirect_uris=[AnyUrl("http://localhost:3000/callback")], grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], response_types=["token"], @@ -2434,23 +2531,85 @@ async def main(): idp_token_endpoint="https://your-idp.com/oauth2/v1/token", token_exchange_params=token_exchange_params, ) - - # Step 4: Perform token exchange and get access token + async with httpx.AsyncClient() as client: - # Exchange ID token for ID-JAG + # Step 4: Exchange ID token for ID-JAG id_jag = await enterprise_auth.exchange_token_for_id_jag(client) print(f"Obtained ID-JAG: {id_jag[:50]}...") - - # Exchange ID-JAG for access token - access_token = await enterprise_auth.exchange_id_jag_for_access_token( - client, id_jag - ) + + # Step 5: Exchange ID-JAG for access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) print(f"Access token obtained, expires in: {access_token.expires_in}s") + # Step 6: Check if token is expired (for demonstration) + if is_token_expired(access_token): + print("Token is expired, refreshing...") + access_token = await refresh_access_token(enterprise_auth, client, id_token) + + # Step 7: Use the access token to connect to MCP server + headers = {"Authorization": f"Bearer {access_token.access_token}"} + + async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Call tools with automatic retry on token expiration + result = await call_tool_with_retry( + session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token + ) + if result: + print(f"Tool result: {result.content}") + + # List available resources + resources = await session.list_resources() + for resource in resources.resources: + print(f"Resource: {resource.uri}") + + +async def maintain_active_session( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + mcp_server_url: str, +) -> None: + """Maintain an active session with automatic token refresh.""" + id_token_var = await get_id_token_from_idp() + + async with httpx.AsyncClient() as client: + while True: + try: + # Update token exchange params with current ID token + enterprise_auth.token_exchange_params.subject_token = id_token_var + + # Get access token + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + + # Calculate refresh time (refresh before expiration) + refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300 + + # Use the token for MCP operations + headers = {"Authorization": f"Bearer {access_token.access_token}"} + async with sse_client(mcp_server_url, headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Perform operations... + # Schedule refresh before token expires + await asyncio.sleep(refresh_in) + + except Exception as e: + print(f"Session error: {e}") + # Re-authenticate with IdP + id_token_var = await get_id_token_from_idp() + await asyncio.sleep(5) # Wait before retry + + if __name__ == "__main__": asyncio.run(main()) ``` +_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_ + + **Working with SAML Assertions:** If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions: diff --git a/examples/snippets/clients/enterprise_managed_auth_client.py b/examples/snippets/clients/enterprise_managed_auth_client.py new file mode 100644 index 0000000000..48cb115dc9 --- /dev/null +++ b/examples/snippets/clients/enterprise_managed_auth_client.py @@ -0,0 +1,204 @@ +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any + +import httpx +from pydantic import AnyUrl + +from mcp import ClientSession +from mcp.client.auth import OAuthTokenError, TokenStorage +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, +) +from mcp.client.sse import sse_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from mcp.types import CallToolResult + + +# Placeholder function for IdP authentication +async def get_id_token_from_idp() -> str: + """ + Placeholder function to get ID token from your IdP. + In production, implement actual IdP authentication flow. + """ + raise NotImplementedError("Implement your IdP authentication flow here") + + +# Define token storage implementation +class SimpleTokenStorage(TokenStorage): + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +def is_token_expired(access_token: OAuthToken) -> bool: + """Check if the access token has expired.""" + if access_token.expires_in: + # Calculate expiration time + issued_at = datetime.now(timezone.utc) + expiration_time = issued_at + timedelta(seconds=access_token.expires_in) + return datetime.now(timezone.utc) >= expiration_time + return False + + +async def refresh_access_token( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> OAuthToken: + """Refresh the access token when it expires.""" + try: + # Update token exchange parameters with fresh ID token + enterprise_auth.token_exchange_params.subject_token = id_token + + # Re-exchange for new ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + + # Get new access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + return access_token + except Exception as e: + print(f"Token refresh failed: {e}") + # Re-authenticate with IdP if ID token is also expired + id_token = await get_id_token_from_idp() + return await refresh_access_token(enterprise_auth, client, id_token) + + +async def call_tool_with_retry( + session: ClientSession, + tool_name: str, + arguments: dict[str, Any], + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> CallToolResult | None: + """Call a tool with automatic retry on token expiration.""" + max_retries = 1 + + for attempt in range(max_retries + 1): + try: + result = await session.call_tool(tool_name, arguments) + return result + except OAuthTokenError: + if attempt < max_retries: + print("Token expired, refreshing...") + # Refresh token and reconnect + _access_token = await refresh_access_token(enterprise_auth, client, id_token) + # Note: In production, you'd need to reconnect the session here + else: + raise + return None + + +async def main() -> None: + # Step 1: Get ID token from your IdP (example with Okta) + id_token = await get_id_token_from_idp() # Your IdP authentication + + # Step 2: Configure token exchange parameters + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL + mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID + scope="mcp:tools mcp:resources", # Optional scopes + ) + + # Step 3: Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + client_name="Enterprise MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + async with httpx.AsyncClient() as client: + # Step 4: Exchange ID token for ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + print(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Step 5: Exchange ID-JAG for access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + print(f"Access token obtained, expires in: {access_token.expires_in}s") + + # Step 6: Check if token is expired (for demonstration) + if is_token_expired(access_token): + print("Token is expired, refreshing...") + access_token = await refresh_access_token(enterprise_auth, client, id_token) + + # Step 7: Use the access token to connect to MCP server + headers = {"Authorization": f"Bearer {access_token.access_token}"} + + async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Call tools with automatic retry on token expiration + result = await call_tool_with_retry( + session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token + ) + if result: + print(f"Tool result: {result.content}") + + # List available resources + resources = await session.list_resources() + for resource in resources.resources: + print(f"Resource: {resource.uri}") + + +async def maintain_active_session( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + mcp_server_url: str, +) -> None: + """Maintain an active session with automatic token refresh.""" + id_token_var = await get_id_token_from_idp() + + async with httpx.AsyncClient() as client: + while True: + try: + # Update token exchange params with current ID token + enterprise_auth.token_exchange_params.subject_token = id_token_var + + # Get access token + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + + # Calculate refresh time (refresh before expiration) + refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300 + + # Use the token for MCP operations + headers = {"Authorization": f"Bearer {access_token.access_token}"} + async with sse_client(mcp_server_url, headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Perform operations... + # Schedule refresh before token expires + await asyncio.sleep(refresh_in) + + except Exception as e: + print(f"Session error: {e}") + # Re-authenticate with IdP + id_token_var = await get_id_token_from_idp() + await asyncio.sleep(5) # Wait before retry + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py index e55283e968..032b388de4 100644 --- a/src/mcp/client/auth/extensions/enterprise_managed_auth.py +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -6,10 +6,12 @@ """ import logging -from typing import Any +from collections.abc import Awaitable, Callable import httpx +import jwt from pydantic import BaseModel, Field +from typing_extensions import TypedDict from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage from mcp.shared.auth import OAuthClientMetadata, OAuthToken @@ -17,9 +19,27 @@ logger = logging.getLogger(__name__) -# ============================================================================ -# Data Models -# ============================================================================ +class TokenExchangeRequestData(TypedDict, total=False): + """Type definition for RFC 8693 Token Exchange request data.""" + + grant_type: str + requested_token_type: str + audience: str + resource: str + subject_token: str + subject_token_type: str + scope: str + client_id: str + client_secret: str + + +class JWTBearerGrantRequestData(TypedDict, total=False): + """Type definition for RFC 7523 JWT Bearer Grant request data.""" + + grant_type: str + assertion: str + client_id: str + client_secret: str class TokenExchangeParameters(BaseModel): @@ -166,8 +186,8 @@ def __init__( storage: TokenStorage, idp_token_endpoint: str, token_exchange_params: TokenExchangeParameters, - redirect_handler: Any = None, - callback_handler: Any = None, + redirect_handler: Callable[[str], Awaitable[None]] | None = None, + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, timeout: float = 300.0, ) -> None: """ @@ -214,7 +234,7 @@ async def exchange_token_for_id_jag( logger.info("Starting token exchange for ID-JAG") # Build token exchange request - token_data = { + token_data: TokenExchangeRequestData = { "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "requested_token_type": self.token_exchange_params.requested_token_type, "audience": self.token_exchange_params.audience, @@ -292,7 +312,7 @@ async def exchange_id_jag_for_access_token( token_endpoint = str(self.context.oauth_metadata.token_endpoint) # Build JWT bearer grant request - token_data = { + token_data: JWTBearerGrantRequestData = { "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": id_jag, } @@ -353,18 +373,12 @@ async def _perform_authorization(self) -> httpx.Request: ) -# ============================================================================ -# Helper Functions -# ============================================================================ - - -def decode_id_jag(id_jag: str, verify: bool = False) -> IDJAGClaims: +def decode_id_jag(id_jag: str) -> IDJAGClaims: """ Decode an ID-JAG token without verification. Args: id_jag: The ID-JAG token string - verify: Whether to verify signature (requires key) Returns: Decoded ID-JAG claims @@ -372,8 +386,6 @@ def decode_id_jag(id_jag: str, verify: bool = False) -> IDJAGClaims: Note: For verification, use server-side validation instead. """ - import jwt - # Decode without verification for inspection claims = jwt.decode(id_jag, options={"verify_signature": False}) header = jwt.get_unverified_header(id_jag) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index 4a10f4f664..df3c1a22fa 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -149,11 +149,6 @@ def test_validate_token_exchange_params_missing_subject_token(): validate_token_exchange_params(params) -# ============================================================================ -# Tests for TokenExchangeResponse -# ============================================================================ - - def test_token_exchange_response_parsing(): """Test parsing token exchange response.""" response_json = """{ @@ -185,11 +180,6 @@ def test_token_exchange_response_id_jag_property(): assert response.id_jag == "the-id-jag-token" -# ============================================================================ -# Tests for IDJAGClaims -# ============================================================================ - - def test_decode_id_jag(sample_id_jag: str): """Test decoding ID-JAG token.""" claims = decode_id_jag(sample_id_jag) From 84162df24ad4efd748addb9fe680650cf71ac8a2 Mon Sep 17 00:00:00 2001 From: BinoyOza-okta Date: Fri, 12 Dec 2025 18:33:12 +0530 Subject: [PATCH 16/16] - Updated test case to include IDJAGClaims type model to verify payload. --- .../test_enterprise_managed_auth_client.py | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py index df3c1a22fa..7291897662 100644 --- a/tests/client/auth/test_enterprise_managed_auth_client.py +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -20,10 +20,6 @@ ) from mcp.shared.auth import OAuthClientMetadata -# ============================================================================ -# Fixtures -# ============================================================================ - @pytest.fixture def sample_id_token() -> str: @@ -42,22 +38,23 @@ def sample_id_token() -> str: @pytest.fixture def sample_id_jag() -> str: """Generate a sample ID-JAG token for testing.""" - payload = { - "jti": "unique-jwt-id-12345", - "iss": "https://idp.example.com", - "sub": "user123", - "aud": "https://auth.mcp-server.example/", - "resource": "https://mcp-server.example/", - "client_id": "mcp-client-app", - "exp": int(time.time()) + 300, - "iat": int(time.time()), - "scope": "read write", - } - token = jwt.encode(payload, "secret", algorithm="HS256") + # Create typed claims using IDJAGClaims model + claims = IDJAGClaims( + typ="oauth-id-jag+jwt", + jti="unique-jwt-id-12345", + iss="https://idp.example.com", + sub="user123", + aud="https://auth.mcp-server.example/", + resource="https://mcp-server.example/", + client_id="mcp-client-app", + exp=int(time.time()) + 300, + iat=int(time.time()), + scope="read write", + email=None, # Optional field + ) - # Manually add typ to header - header = jwt.get_unverified_header(token) - header["typ"] = "oauth-id-jag+jwt" + # Dump to dict for JWT encoding (exclude typ as it goes in header) + payload = claims.model_dump(exclude={"typ"}, exclude_none=True) return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"}) @@ -73,11 +70,6 @@ def mock_token_storage() -> Any: return storage -# ============================================================================ -# Tests for TokenExchangeParameters -# ============================================================================ - - def test_token_exchange_params_from_id_token(): """Test creating TokenExchangeParameters from ID token.""" params = TokenExchangeParameters.from_id_token(