Skip to content

Commit 5ba691d

Browse files
g-eojDouweM
andauthored
Add Agent.output_json_schema() method (#3454)
Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent 8ac6436 commit 5ba691d

File tree

5 files changed

+623
-0
lines changed

5 files changed

+623
-0
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
1010

1111
from pydantic import Json, TypeAdapter, ValidationError
12+
from pydantic._internal._typing_extra import get_function_type_hints
1213
from pydantic_core import SchemaValidator, to_json
1314
from typing_extensions import Self, TypedDict, TypeVar
1415

@@ -1012,3 +1013,34 @@ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem
10121013
else:
10131014
outputs_flat.append(cast(_OutputSpecItem[T], output))
10141015
return outputs_flat
1016+
1017+
1018+
def types_from_output_spec(output_spec: OutputSpec[T]) -> Sequence[T | type[str]]:
1019+
outputs: Sequence[OutputSpec[T]]
1020+
if isinstance(output_spec, Sequence):
1021+
outputs = output_spec
1022+
else:
1023+
outputs = (output_spec,)
1024+
1025+
outputs_flat: list[T | type[str]] = []
1026+
for output in outputs:
1027+
if isinstance(output, NativeOutput):
1028+
outputs_flat.extend(types_from_output_spec(output.outputs))
1029+
elif isinstance(output, PromptedOutput):
1030+
outputs_flat.extend(types_from_output_spec(output.outputs))
1031+
elif isinstance(output, TextOutput):
1032+
outputs_flat.append(str)
1033+
elif isinstance(output, ToolOutput):
1034+
outputs_flat.extend(types_from_output_spec(output.output))
1035+
elif union_types := _utils.get_union_args(output):
1036+
outputs_flat.extend(union_types)
1037+
elif inspect.isfunction(output) or inspect.ismethod(output):
1038+
type_hints = get_function_type_hints(output)
1039+
if return_annotation := type_hints.get('return', None):
1040+
outputs_flat.extend(types_from_output_spec(return_annotation))
1041+
else:
1042+
outputs_flat.append(str)
1043+
else:
1044+
outputs_flat.append(cast(T, output))
1045+
1046+
return outputs_flat

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast, overload
1010

1111
import anyio
12+
from pydantic import TypeAdapter
1213
from typing_extensions import Self, TypeIs, TypeVar
1314

1415
from pydantic_graph import End
@@ -23,6 +24,8 @@
2324
result,
2425
usage as _usage,
2526
)
27+
from .._json_schema import JsonSchema
28+
from .._output import types_from_output_spec
2629
from .._tool_manager import ToolManager
2730
from ..builtin_tools import AbstractBuiltinTool
2831
from ..output import OutputDataT, OutputSpec
@@ -123,6 +126,28 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
123126
"""
124127
raise NotImplementedError
125128

129+
def output_json_schema(self, output_type: OutputSpec[OutputDataT | RunOutputDataT] | None = None) -> JsonSchema:
130+
"""The output return JSON schema."""
131+
if output_type is None:
132+
output_type = self.output_type
133+
134+
return_types = types_from_output_spec(output_spec=output_type)
135+
136+
json_schemas: list[JsonSchema] = []
137+
for return_type in return_types:
138+
json_schema = TypeAdapter(return_type).json_schema(mode='serialization')
139+
if json_schema not in json_schemas:
140+
json_schemas.append(json_schema)
141+
142+
if len(json_schemas) == 1:
143+
return json_schemas[0]
144+
else:
145+
json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)
146+
json_schema: JsonSchema = {'anyOf': json_schemas}
147+
if all_defs:
148+
json_schema['$defs'] = all_defs
149+
return json_schema
150+
126151
@overload
127152
async def run(
128153
self,

pydantic_ai_slim/pydantic_ai/agent/wrapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
models,
1111
usage as _usage,
1212
)
13+
from .._json_schema import JsonSchema
1314
from ..builtin_tools import AbstractBuiltinTool
1415
from ..output import OutputDataT, OutputSpec
1516
from ..run import AgentRun
@@ -68,6 +69,9 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
6869
async def __aexit__(self, *args: Any) -> bool | None:
6970
return await self.wrapped.__aexit__(*args)
7071

72+
def output_json_schema(self, output_type: OutputSpec[OutputDataT | RunOutputDataT] | None = None) -> JsonSchema:
73+
return self.wrapped.output_json_schema(output_type=output_type)
74+
7175
@overload
7276
def iter(
7377
self,

tests/test_agent.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5256,6 +5256,15 @@ def foo() -> str:
52565256
assert wrapper_agent.name == 'wrapped'
52575257
assert wrapper_agent.output_type == agent.output_type
52585258
assert wrapper_agent.event_stream_handler == agent.event_stream_handler
5259+
assert wrapper_agent.output_json_schema() == snapshot(
5260+
{
5261+
'type': 'object',
5262+
'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'string'}},
5263+
'title': 'Foo',
5264+
'required': ['a', 'b'],
5265+
}
5266+
)
5267+
assert wrapper_agent.output_json_schema(output_type=str) == snapshot({'type': 'string'})
52595268

52605269
bar_toolset = FunctionToolset()
52615270

0 commit comments

Comments
 (0)