Skip to content

Commit 9631781

Browse files
Motta Kinmochow13
authored andcommitted
Cast source to specific type; update tests to use _map_messages
1 parent 6e6f667 commit 9631781

File tree

2 files changed

+78
-24
lines changed

2 files changed

+78
-24
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import anyio.to_thread
1313
from botocore.exceptions import ClientError
14+
from mypy_boto3_bedrock_runtime.type_defs import DocumentSourceTypeDef
1415
from typing_extensions import ParamSpec, assert_never
1516

1617
from pydantic_ai import (
@@ -743,15 +744,15 @@ async def _map_user_prompt( # noqa: C901
743744
if item.kind == 'image-url':
744745
format = item.media_type.split('/')[1]
745746
assert format in ('jpeg', 'png', 'gif', 'webp'), f'Unsupported image format: {format}'
746-
image: ImageBlockTypeDef = {'format': format, 'source': cast(Any, source)}
747+
image: ImageBlockTypeDef = {'format': format, 'source': cast(DocumentSourceTypeDef, source)}
747748
content.append({'image': image})
748749

749750
elif item.kind == 'document-url':
750751
name = f'Document {next(document_count)}'
751752
document: DocumentBlockTypeDef = {
752753
'name': name,
753754
'format': item.format,
754-
'source': cast(Any, source),
755+
'source': cast(DocumentSourceTypeDef, source),
755756
}
756757
content.append({'document': document})
757758

@@ -768,7 +769,7 @@ async def _map_user_prompt( # noqa: C901
768769
'wmv',
769770
'three_gp',
770771
), f'Unsupported video format: {format}'
771-
video: VideoBlockTypeDef = {'format': format, 'source': cast(Any, source)}
772+
video: VideoBlockTypeDef = {'format': format, 'source': cast(DocumentSourceTypeDef, source)}
772773
content.append({'video': video})
773774
elif isinstance(item, AudioUrl): # pragma: no cover
774775
raise NotImplementedError('Audio is not supported yet.')

tests/models/test_bedrock.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -739,42 +739,95 @@ async def test_text_document_url_input(allow_model_requests: None, bedrock_provi
739739
)
740740

741741

742-
@pytest.mark.vcr()
743-
async def test_s3_image_url_input(allow_model_requests: None, bedrock_provider: BedrockProvider):
742+
async def test_s3_image_url_input(bedrock_provider: BedrockProvider):
744743
"""Test that s3:// image URLs are passed directly to Bedrock API without downloading."""
745-
m = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
746-
agent = Agent(m, system_prompt='You are a helpful chatbot.')
744+
model = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
747745
image_url = ImageUrl(url='s3://my-bucket/images/test-image.jpg', media_type='image/jpeg')
748746

749-
result = await agent.run(['What is in this image?', image_url])
750-
assert result.output == snapshot(
751-
'The image shows a scenic landscape with mountains in the background and a clear blue sky above.'
747+
req = [
748+
ModelRequest(parts=[UserPromptPart(content=['What is in this image?', image_url])]),
749+
]
750+
751+
_, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage]
752+
753+
assert bedrock_messages == snapshot(
754+
[
755+
{
756+
'role': 'user',
757+
'content': [
758+
{'text': 'What is in this image?'},
759+
{
760+
'image': {
761+
'format': 'jpeg',
762+
'source': {'s3Location': {'uri': 's3://my-bucket/images/test-image.jpg'}},
763+
}
764+
},
765+
],
766+
}
767+
]
752768
)
753769

754770

755-
@pytest.mark.vcr()
756-
async def test_s3_video_url_input(allow_model_requests: None, bedrock_provider: BedrockProvider):
771+
async def test_s3_video_url_input(bedrock_provider: BedrockProvider):
757772
"""Test that s3:// video URLs are passed directly to Bedrock API."""
758-
m = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
759-
agent = Agent(m, system_prompt='You are a helpful chatbot.')
773+
model = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
760774
video_url = VideoUrl(url='s3://my-bucket/videos/test-video.mp4', media_type='video/mp4')
761775

762-
result = await agent.run(['Describe this video', video_url])
763-
assert result.output == snapshot(
764-
'The video shows a time-lapse of a sunset over the ocean with waves gently rolling onto the shore.'
776+
# Create a ModelRequest with the S3 video URL
777+
req = [
778+
ModelRequest(parts=[UserPromptPart(content=['Describe this video', video_url])]),
779+
]
780+
781+
# Call the mapping function directly
782+
_, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage]
783+
784+
assert bedrock_messages == snapshot(
785+
[
786+
{
787+
'role': 'user',
788+
'content': [
789+
{'text': 'Describe this video'},
790+
{
791+
'video': {
792+
'format': 'mp4',
793+
'source': {'s3Location': {'uri': 's3://my-bucket/videos/test-video.mp4'}},
794+
}
795+
},
796+
],
797+
}
798+
]
765799
)
766800

767801

768-
@pytest.mark.vcr()
769-
async def test_s3_document_url_input(allow_model_requests: None, bedrock_provider: BedrockProvider):
802+
async def test_s3_document_url_input(bedrock_provider: BedrockProvider):
770803
"""Test that s3:// document URLs are passed directly to Bedrock API."""
771-
m = BedrockConverseModel('anthropic.claude-v2', provider=bedrock_provider)
772-
agent = Agent(m, system_prompt='You are a helpful chatbot.')
804+
model = BedrockConverseModel('anthropic.claude-v2', provider=bedrock_provider)
773805
document_url = DocumentUrl(url='s3://my-bucket/documents/test-doc.pdf', media_type='application/pdf')
774806

775-
result = await agent.run(['What is the main content on this document?', document_url])
776-
assert result.output == snapshot(
777-
'Based on the provided document, the main content discusses best practices for cloud storage and data management.'
807+
# Create a ModelRequest with the S3 document URL
808+
req = [
809+
ModelRequest(parts=[UserPromptPart(content=['What is the main content on this document?', document_url])]),
810+
]
811+
812+
# Call the mapping function directly
813+
_, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage]
814+
815+
assert bedrock_messages == snapshot(
816+
[
817+
{
818+
'role': 'user',
819+
'content': [
820+
{'text': 'What is the main content on this document?'},
821+
{
822+
'document': {
823+
'format': 'pdf',
824+
'name': 'Document 1',
825+
'source': {'s3Location': {'uri': 's3://my-bucket/documents/test-doc.pdf'}},
826+
}
827+
},
828+
],
829+
}
830+
]
778831
)
779832

780833

0 commit comments

Comments
 (0)