diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 5551eddf5b..280dc539cc 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -606,6 +606,17 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
                         f'Model token limit ({max_tokens or "provider default"}) exceeded before any response was generated. Increase the `max_tokens` model setting, or simplify the prompt to result in a shorter response that will fit within the limit.'
                     )

+                # Check for a content filter on an empty response
+                if self.model_response.finish_reason == 'content_filter':
+                    details = self.model_response.provider_details or {}
+                    reason = details.get('finish_reason', 'content_filter')
+
+                    body = _messages.ModelMessagesTypeAdapter.dump_json([self.model_response]).decode()
+
+                    raise exceptions.ContentFilterError(
+                        f"Content filter triggered. Finish reason: '{reason}'", body=body
+                    )
+
                 # we got an empty response.
                 # this sometimes happens with anthropic (and perhaps other models)
                 # when the model has already returned text along side tool calls
diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py
index 2cdb850921..241bfe9ef4 100644
--- a/pydantic_ai_slim/pydantic_ai/exceptions.py
+++ b/pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -26,6 +26,7 @@
     'UsageLimitExceeded',
     'ModelAPIError',
     'ModelHTTPError',
+    'ContentFilterError',
     'IncompleteToolCall',
     'FallbackExceptionGroup',
 )
@@ -154,6 +155,10 @@ def __str__(self) -> str:
         return self.message


+class ContentFilterError(UnexpectedModelBehavior):
+    """Raised when content filtering is triggered by the model provider, resulting in an empty response."""
+
+
 class ModelAPIError(AgentRunError):
     """Raised when a model provider API request fails."""
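
Note: with the two files above, an empty response whose finish_reason is 'content_filter' now surfaces as a typed ContentFilterError rather than falling through to the generic empty-response handling. A minimal sketch of what calling code can do with it; the FunctionModel stand-in simulates a fully filtered provider response and is illustrative, not part of this diff:

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.exceptions import ContentFilterError
    from pydantic_ai.messages import ModelMessage, ModelResponse
    from pydantic_ai.models.function import AgentInfo, FunctionModel


    async def filtered(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Simulate a provider that filtered the entire response: no parts, content_filter finish reason.
        return ModelResponse(parts=[], model_name='demo', finish_reason='content_filter')


    async def main() -> None:
        agent = Agent(FunctionModel(function=filtered))
        try:
            await agent.run('any prompt')
        except ContentFilterError as exc:
            # exc.body carries the serialized ModelResponse, provider_details included.
            print('filtered:', exc.message)


    asyncio.run(main())

Because ContentFilterError subclasses UnexpectedModelBehavior, existing handlers that catch the broader type keep working.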
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py
index 3a6ff9e665..eec5a342fc 100644
--- a/pydantic_ai_slim/pydantic_ai/models/google.py
+++ b/pydantic_ai_slim/pydantic_ai/models/google.py
@@ -569,14 +569,13 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         raw_finish_reason = candidate.finish_reason
         if raw_finish_reason:  # pragma: no branch
             vendor_details = {'finish_reason': raw_finish_reason.value}
+            # Add safety ratings to provider details
+            if candidate.safety_ratings:
+                vendor_details['safety_ratings'] = [r.model_dump(by_alias=True) for r in candidate.safety_ratings]
             finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

         if candidate.content is None or candidate.content.parts is None:
-            if finish_reason == 'content_filter' and raw_finish_reason:
-                raise UnexpectedModelBehavior(
-                    f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
-                )
-            parts = []  # pragma: no cover
+            parts = []
         else:
             parts = candidate.content.parts or []

@@ -752,6 +751,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                 raw_finish_reason = candidate.finish_reason
                 if raw_finish_reason:
                     self.provider_details = {'finish_reason': raw_finish_reason.value}
+
+                    if candidate.safety_ratings:
+                        self.provider_details['safety_ratings'] = [
+                            r.model_dump(by_alias=True) for r in candidate.safety_ratings
+                        ]
+
                     self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

                 # Google streams the grounding metadata (including the web search queries and results)
@@ -777,12 +782,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                         yield self._parts_manager.handle_part(vendor_part_id=uuid4(), part=web_fetch_return)

                 if candidate.content is None or candidate.content.parts is None:
-                    if self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
-                        raise UnexpectedModelBehavior(
-                            f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
-                        )
-                    else:  # pragma: no cover
-                        continue
+                    continue

                 parts = candidate.content.parts
                 if not parts:
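
Note: on the Google side, the candidate's safety ratings now travel in provider_details instead of being dropped when a filter trips, so they end up inside the ContentFilterError body. A self-contained sketch of reading them back; the JSON literal is a trimmed, illustrative stand-in for what ModelMessagesTypeAdapter.dump_json produces:

    import json

    # Illustrative exception body: a list with one serialized ModelResponse.
    body = json.dumps(
        [
            {
                'finish_reason': 'content_filter',
                'provider_details': {
                    'finish_reason': 'SAFETY',
                    'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'blocked': True}],
                },
            }
        ]
    )

    details = json.loads(body)[0]['provider_details']
    blocked = [r['category'] for r in details.get('safety_ratings', []) if r.get('blocked')]
    print(blocked)  # ['HARM_CATEGORY_HATE_SPEECH']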
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 1a6137cc7b..163ecb6f3a 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -11,7 +11,7 @@
 from functools import cached_property
 from typing import Any, Literal, cast, overload

-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError
 from pydantic_core import to_json
 from typing_extensions import assert_never, deprecated

@@ -164,6 +164,36 @@
 _OPENAI_IMAGE_SIZES: tuple[_OPENAI_IMAGE_SIZE, ...] = _utils.get_args(_OPENAI_IMAGE_SIZE)


+class _AzureContentFilterResultDetail(BaseModel):
+    filtered: bool
+    severity: str | None = None
+    detected: bool | None = None
+
+
+class _AzureContentFilterResult(BaseModel):
+    hate: _AzureContentFilterResultDetail | None = None
+    self_harm: _AzureContentFilterResultDetail | None = None
+    sexual: _AzureContentFilterResultDetail | None = None
+    violence: _AzureContentFilterResultDetail | None = None
+    jailbreak: _AzureContentFilterResultDetail | None = None
+    profanity: _AzureContentFilterResultDetail | None = None
+
+
+class _AzureInnerError(BaseModel):
+    code: str
+    content_filter_result: _AzureContentFilterResult
+
+
+class _AzureError(BaseModel):
+    code: str
+    message: str
+    innererror: _AzureInnerError | None = None
+
+
+class _AzureErrorResponse(BaseModel):
+    error: _AzureError
+
+
 def _resolve_openai_image_generation_size(
     tool: ImageGenerationTool,
 ) -> _OPENAI_IMAGE_SIZE:
@@ -194,6 +224,36 @@ def _resolve_openai_image_generation_size(
     return mapped_size


+def _check_azure_content_filter(e: APIStatusError, system: str, model_name: str) -> ModelResponse | None:
+    """Check if the error is an Azure content filter error."""
+    # Assign to Any to avoid 'dict[Unknown, Unknown]' inference in strict mode
+    body_any: Any = e.body
+
+    if system == 'azure' and e.status_code == 400 and isinstance(body_any, dict):
+        try:
+            error_data = _AzureErrorResponse.model_validate(body_any)
+
+            if error_data.error.code == 'content_filter':
+                provider_details: dict[str, Any] = {'finish_reason': 'content_filter'}
+
+                if error_data.error.innererror:
+                    provider_details['content_filter_result'] = (
+                        error_data.error.innererror.content_filter_result.model_dump(exclude_none=True)
+                    )
+
+                return ModelResponse(
+                    parts=[],  # Empty parts to trigger the content filter error in the agent graph
+                    model_name=model_name,
+                    timestamp=_utils.now_utc(),
+                    provider_name=system,
+                    finish_reason='content_filter',
+                    provider_details=provider_details,
+                )
+        except ValidationError:
+            pass
+    return None
+
+
 class OpenAIChatModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""

@@ -532,6 +592,11 @@ async def request(
         response = await self._completions_create(
             messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
+
+        # Handle a ModelResponse returned directly (for content filters)
+        if isinstance(response, ModelResponse):
+            return response
+
         model_response = self._process_response(response)
         return model_response

@@ -570,7 +635,7 @@ async def _completions_create(
         stream: Literal[False],
         model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion: ...
+    ) -> chat.ChatCompletion | ModelResponse: ...

     async def _completions_create(
         self,
@@ -578,7 +643,7 @@ async def _completions_create(
         stream: bool,
         model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk] | ModelResponse:
         tools = self._get_tools(model_request_parameters)
         web_search_options = self._get_web_search_options(model_request_parameters)
@@ -642,6 +707,8 @@ async def _completions_create(
                 extra_body=model_settings.get('extra_body'),
             )
         except APIStatusError as e:
+            if model_response := _check_azure_content_filter(e, self.system, self.model_name):
+                return model_response
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
@@ -689,6 +756,7 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
             raise UnexpectedModelBehavior(f'Invalid response from {self.system} chat completions endpoint: {e}') from e

         choice = response.choices[0]
+
         items: list[ModelResponsePart] = []

         if thinking_parts := self._process_thinking(choice.message):
@@ -1217,6 +1285,11 @@ async def request(
         response = await self._responses_create(
             messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
+
+        # Handle a ModelResponse returned directly (for content filters)
+        if isinstance(response, ModelResponse):
+            return response
+
         return self._process_response(response, model_request_parameters)

     @asynccontextmanager
@@ -1397,7 +1470,7 @@ async def _responses_create(  # noqa: C901
         stream: bool,
         model_settings: OpenAIResponsesModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
+    ) -> responses.Response | AsyncStream[responses.ResponseStreamEvent] | ModelResponse:
         tools = (
             self._get_builtin_tools(model_request_parameters)
             + list(model_settings.get('openai_builtin_tools', []))
@@ -1499,6 +1572,9 @@ async def _responses_create(  # noqa: C901
                 extra_body=model_settings.get('extra_body'),
             )
         except APIStatusError as e:
+            if model_response := _check_azure_content_filter(e, self.system, self.model_name):
+                return model_response
+
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
@@ -2190,6 +2266,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         raw_finish_reason = (
             details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
         )
+
         if raw_finish_reason:  # pragma: no branch
             self.provider_details = {'finish_reason': raw_finish_reason}
             self.finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
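
Note: for reference, the Azure 400 payload that _check_azure_content_filter recognizes is shaped like the dict below (an illustrative example matching the pydantic models above; real responses may include more categories). A different code, a non-dict body, or a payload that fails validation falls through to the existing ModelHTTPError path:

    # Illustrative Azure content-filter error payload (not an exhaustive schema).
    azure_error_body = {
        'error': {
            'code': 'content_filter',  # must be exactly this for the check to fire
            'message': 'The content was filtered.',
            'innererror': {  # optional; when present, content_filter_result is copied into provider_details
                'code': 'ResponsibleAIPolicyViolation',
                'content_filter_result': {
                    'hate': {'filtered': True, 'severity': 'high'},
                    'jailbreak': {'filtered': False, 'detected': False},
                },
            },
        }
    }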
diff --git a/tests/models/test_google.py b/tests/models/test_google.py
index 0560dc883f..fdb7e163d5 100644
--- a/tests/models/test_google.py
+++ b/tests/models/test_google.py
@@ -3,6 +3,7 @@
 import asyncio
 import base64
 import datetime
+import json
 import os
 import re
 import tempfile
@@ -58,7 +59,13 @@
     WebFetchTool,
     WebSearchTool,
 )
-from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
+from pydantic_ai.exceptions import (
+    ContentFilterError,
+    ModelAPIError,
+    ModelHTTPError,
+    ModelRetry,
+    UserError,
+)
 from pydantic_ai.messages import (
     BuiltinToolCallEvent,  # pyright: ignore[reportDeprecated]
     BuiltinToolResultEvent,  # pyright: ignore[reportDeprecated]
@@ -998,9 +1005,26 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p
     )
     agent = Agent(m, instructions='You hate the world!', model_settings=settings)

-    with pytest.raises(UnexpectedModelBehavior, match="Content filter 'SAFETY' triggered"):
+    with pytest.raises(
+        ContentFilterError,
+        match="Content filter triggered. Finish reason: 'SAFETY'",
+    ) as exc_info:
         await agent.run('Tell me a joke about a Brazilians.')

+    # Verify that we captured the safety ratings in the exception body
+    assert exc_info.value.body is not None
+    body_json = json.loads(exc_info.value.body)
+    assert len(body_json) == 1
+    response_msg = body_json[0]
+
+    assert response_msg['finish_reason'] == 'content_filter'
+    details = response_msg['provider_details']
+    assert details['finish_reason'] == 'SAFETY'
+    assert len(details['safety_ratings']) > 0
+    # The first rating should reflect the blocking
+    assert details['safety_ratings'][0]['category'] == 'HARM_CATEGORY_HATE_SPEECH'
+    assert details['safety_ratings'][0]['blocked'] is True
+

 async def test_google_model_web_search_tool(allow_model_requests: None, google_provider: GoogleProvider):
     m = GoogleModel('gemini-2.5-pro', provider=google_provider)
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 993eebcc88..7228435db0 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -40,6 +40,7 @@
 )
 from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
 from pydantic_ai.builtin_tools import ImageGenerationTool, WebSearchTool
+from pydantic_ai.exceptions import ContentFilterError
 from pydantic_ai.models import ModelRequestParameters
 from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
 from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile
@@ -59,7 +60,7 @@
 )

 with try_import() as imports_successful:
-    from openai import APIConnectionError, APIStatusError, AsyncOpenAI
+    from openai import APIConnectionError, APIStatusError, AsyncAzureOpenAI, AsyncOpenAI
     from openai.types import chat
     from openai.types.chat.chat_completion import ChoiceLogprobs
     from openai.types.chat.chat_completion_chunk import (
@@ -84,6 +85,7 @@
         _resolve_openai_image_generation_size,  # pyright: ignore[reportPrivateUsage]
     )
     from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
+    from pydantic_ai.providers.azure import AzureProvider
     from pydantic_ai.providers.cerebras import CerebrasProvider
     from pydantic_ai.providers.google import GoogleProvider
     from pydantic_ai.providers.ollama import OllamaProvider
@@ -3560,6 +3562,183 @@ async def test_openai_reasoning_in_thinking_tags(allow_model_requests: None):
     )


+def test_azure_prompt_filter_error(allow_model_requests: None) -> None:
+    body = {
+        'error': {
+            'code': 'content_filter',
+            'message': 'The content was filtered.',
+            'innererror': {
+                'code': 'ResponsibleAIPolicyViolation',
+                'content_filter_result': {
+                    'hate': {'filtered': True, 'severity': 'high'},
+                    'self_harm': {'filtered': False, 'severity': 'safe'},
+                    'sexual': {'filtered': False, 'severity': 'safe'},
+                    'violence': {'filtered': False, 'severity': 'medium'},
+                    'jailbreak': {'filtered': False, 'detected': False},
+                    'profanity': {'filtered': False, 'detected': True},
+                },
+            },
+        }
+    }
+
+    mock_client = MockOpenAI.create_mock(
+        APIStatusError(
+            'content filter',
+            response=httpx.Response(status_code=400, request=httpx.Request('POST', 'https://example.com/v1')),
+            body=body,
+        )
+    )
+
+    m = OpenAIChatModel('gpt-5-mini', provider=AzureProvider(openai_client=cast(AsyncAzureOpenAI, mock_client)))
+    agent = Agent(m)
+
+    with pytest.raises(
+        ContentFilterError, match=r"Content filter triggered. Finish reason: 'content_filter'"
+    ) as exc_info:
+        agent.run_sync('bad prompt')
+
+    assert exc_info.value.body is not None
+
+    assert json.loads(exc_info.value.body) == snapshot(
+        [
+            {
+                'parts': [],
+                'usage': {
+                    'input_tokens': 0,
+                    'cache_write_tokens': 0,
+                    'cache_read_tokens': 0,
+                    'output_tokens': 0,
+                    'input_audio_tokens': 0,
+                    'cache_audio_read_tokens': 0,
+                    'output_audio_tokens': 0,
+                    'details': {},
+                },
+                'model_name': 'gpt-5-mini',
+                'timestamp': IsStr(),
+                'kind': 'response',
+                'provider_name': 'azure',
+                'provider_url': None,
+                'provider_details': {
+                    'finish_reason': 'content_filter',
+                    'content_filter_result': {
+                        'hate': {'filtered': True, 'severity': 'high'},
+                        'self_harm': {'filtered': False, 'severity': 'safe'},
+                        'sexual': {'filtered': False, 'severity': 'safe'},
+                        'violence': {'filtered': False, 'severity': 'medium'},
+                        'jailbreak': {'filtered': False, 'detected': False},
+                        'profanity': {'filtered': False, 'detected': True},
+                    },
+                },
+                'provider_response_id': None,
+                'finish_reason': 'content_filter',
+                'run_id': IsStr(),
+                'metadata': None,
+            }
+        ]
+    )
+
+
+def test_responses_azure_prompt_filter_error(allow_model_requests: None) -> None:
+    mock_client = MockOpenAIResponses.create_mock(
+        APIStatusError(
+            'content filter',
+            response=httpx.Response(status_code=400, request=httpx.Request('POST', 'https://example.com/v1')),
+            body={'error': {'code': 'content_filter', 'message': 'The content was filtered.'}},
+        )
+    )
+
+    m = OpenAIResponsesModel('gpt-5-mini', provider=AzureProvider(openai_client=cast(AsyncAzureOpenAI, mock_client)))
+    agent = Agent(m)
+
+    with pytest.raises(ContentFilterError, match=r"Content filter triggered. Finish reason: 'content_filter'"):
+        agent.run_sync('bad prompt')
+
+
+async def test_openai_response_filter_error(allow_model_requests: None):
+    c = completion_message(
+        ChatCompletionMessage(content=None, role='assistant'),
+    )
+    c.choices[0].finish_reason = 'content_filter'
+    c.model = 'gpt-5-mini'
+
+    mock_client = MockOpenAI.create_mock(c)
+    m = OpenAIChatModel('gpt-5-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    with pytest.raises(ContentFilterError, match=r"Content filter triggered. Finish reason: 'content_filter'"):
+        await agent.run('hello')
+
+
+async def test_openai_response_filter_with_partial_content(allow_model_requests: None):
+    """Test that NO exception is raised if content is returned, even if finish_reason is content_filter."""
+    c = completion_message(
+        ChatCompletionMessage(content='Partial', role='assistant'),
+    )
+    c.choices[0].finish_reason = 'content_filter'
+
+    mock_client = MockOpenAI.create_mock(c)
+    m = OpenAIChatModel('gpt-5-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    result = await agent.run('hello')
+    assert result.output == 'Partial'
+
+
+def test_azure_400_non_content_filter(allow_model_requests: None) -> None:
+    """Test a 400 error from Azure that is NOT a content filter (different code)."""
+    mock_client = MockOpenAI.create_mock(
+        APIStatusError(
+            'Bad Request',
+            response=httpx.Response(status_code=400, request=httpx.Request('POST', 'https://example.com/v1')),
+            body={'error': {'code': 'invalid_parameter', 'message': 'Invalid param.'}},
+        )
+    )
+    m = OpenAIChatModel('gpt-5-mini', provider=AzureProvider(openai_client=cast(AsyncAzureOpenAI, mock_client)))
+    agent = Agent(m)
+
+    with pytest.raises(ModelHTTPError) as exc_info:
+        agent.run_sync('hello')
+
+    assert exc_info.value.status_code == 400
+    assert not isinstance(exc_info.value, ContentFilterError)
+
+
+def test_azure_400_non_dict_body(allow_model_requests: None) -> None:
+    """Test a 400 error from Azure where the body is not a dictionary."""
+    mock_client = MockOpenAI.create_mock(
+        APIStatusError(
+            'Bad Request',
+            response=httpx.Response(status_code=400, request=httpx.Request('POST', 'https://example.com/v1')),
+            body='Raw string body',
+        )
+    )
+    m = OpenAIChatModel('gpt-5-mini', provider=AzureProvider(openai_client=cast(AsyncAzureOpenAI, mock_client)))
+    agent = Agent(m)
+
+    with pytest.raises(ModelHTTPError) as exc_info:
+        agent.run_sync('hello')
+
+    assert exc_info.value.status_code == 400
+
+
+def test_azure_400_malformed_error(allow_model_requests: None) -> None:
+    """Test a 400 error from Azure where the body is a dict but the error structure doesn't match."""
+    mock_client = MockOpenAI.create_mock(
+        APIStatusError(
+            'Bad Request',
+            response=httpx.Response(status_code=400, request=httpx.Request('POST', 'https://example.com/v1')),
+            body={'something_else': 'foo'},  # No 'error' key
+        )
+    )
+    m = OpenAIChatModel('gpt-5-mini', provider=AzureProvider(openai_client=cast(AsyncAzureOpenAI, mock_client)))
+    agent = Agent(m)
+
+    with pytest.raises(ModelHTTPError) as exc_info:
+        agent.run_sync('hello')
+
+    assert exc_info.value.status_code == 400
+
+
 async def test_openai_chat_instructions_after_system_prompts(allow_model_requests: None):
     """Test that instructions are inserted after all system prompts in mapped messages."""
     mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant')))
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 6ce2d91c54..c3bdae376f 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -65,6 +65,7 @@
     WebSearchTool,
     WebSearchUserLocation,
 )
+from pydantic_ai.exceptions import ContentFilterError
 from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import OutputObjectDefinition, StructuredDict, ToolOutput
@@ -7165,3 +7166,43 @@ async def test_dynamic_tool_in_run_call():
     assert isinstance(tool, WebSearchTool)
     assert tool.user_location is not None
     assert tool.user_location.get('city') == 'Berlin'
+
+
+async def test_central_content_filter_handling():
+    """
+    Test that the agent graph correctly raises ContentFilterError
+    when a model returns finish_reason='content_filter' AND empty content.
+    """
+
+    async def filtered_response(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        return ModelResponse(
+            parts=[],
+            model_name='test-model',
+            finish_reason='content_filter',
+            provider_details={'finish_reason': 'content_filter'},
+        )
+
+    model = FunctionModel(function=filtered_response, model_name='test-model')
+    agent = Agent(model)
+
+    with pytest.raises(ContentFilterError, match="Content filter triggered. Finish reason: 'content_filter'"):
+        await agent.run('Trigger filter')
+
+
+async def test_central_content_filter_with_partial_content():
+    """
+    Test that the agent graph returns partial content (does not raise an exception)
+    even if finish_reason='content_filter', provided parts are not empty.
+    """
+
+    async def filtered_response(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        return ModelResponse(
+            parts=[TextPart('Partially generated content...')], model_name='test-model', finish_reason='content_filter'
+        )
+
+    model = FunctionModel(function=filtered_response, model_name='test-model')
+    agent = Agent(model)
+
+    # Should NOT raise ContentFilterError
+    result = await agent.run('Trigger filter')
+    assert result.output == 'Partially generated content...'
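
Note: putting it together, a hedged end-to-end sketch of the Azure path; the endpoint, API version, key, and deployment name below are placeholders, not values from this diff:

    from pydantic_ai import Agent
    from pydantic_ai.exceptions import ContentFilterError
    from pydantic_ai.models.openai import OpenAIChatModel
    from pydantic_ai.providers.azure import AzureProvider

    model = OpenAIChatModel(
        'gpt-5-mini',  # placeholder deployment name
        provider=AzureProvider(
            azure_endpoint='https://example.openai.azure.com',  # placeholder
            api_version='2024-06-01',  # placeholder
            api_key='...',  # placeholder
        ),
    )
    agent = Agent(model)

    try:
        result = agent.run_sync('some prompt')
        print(result.output)
    except ContentFilterError as exc:
        # When Azure supplies innererror details, provider_details in exc.body
        # includes a per-category 'content_filter_result' breakdown.
        print('filtered:', exc.message)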