Skip to content

Commit 060ddbd

Browse files
Bedrock Converse Streaming Support (#1565)
* Add more formatting to custom event validators * Add streamed responses to converse mock server * Add streaming fixtures for testing for converse * Rename other bedrock test files * Add tests for converse streaming * Instrument converse streaming * Move GeneratorProxy adjacent functions to mixin * Fix checking of supported models * Reorganize converse error tests * Port new converse botocore tests to aiobotocore * Instrument response streaming in aiobotocore converse * Fix suggestions from code review * Port in converse changes from strands PR * Delete commented code --------- Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent f181585 commit 060ddbd

14 files changed

+837
-641
lines changed

newrelic/hooks/external_aiobotocore.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ async def wrap_client__make_api_call(wrapped, instance, args, kwargs):
149149
bedrock_attrs = extract_bedrock_converse_attrs(
150150
args[1], response, response_headers, model, span_id, trace_id
151151
)
152+
153+
if response_streaming:
154+
# Wrap EventStream object here to intercept __aiter__ method instead of instrumenting class.
155+
# This class is used in numerous other services in botocore, and would cause conflicts.
156+
response["stream"] = stream = AsyncEventStreamWrapper(response["stream"])
157+
stream._nr_ft = ft or None
158+
stream._nr_bedrock_attrs = bedrock_attrs or {}
159+
stream._nr_model_extractor = stream_extractor or None
160+
stream._nr_is_converse = True
161+
return response
162+
152163
else:
153164
bedrock_attrs = {
154165
"request_id": response_headers.get("x-amzn-requestid"),

newrelic/hooks/external_botocore.py

Lines changed: 122 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,9 @@ def handle_bedrock_exception(
576576
}
577577

578578
if is_embedding:
579-
notice_error_attributes.update({"embedding_id": str(uuid.uuid4())})
579+
notice_error_attributes["embedding_id"] = str(uuid.uuid4())
580580
else:
581-
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
581+
notice_error_attributes["completion_id"] = str(uuid.uuid4())
582582

583583
if ft:
584584
ft.notice_error(attributes=notice_error_attributes)
@@ -766,7 +766,7 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
766766
if not transaction:
767767
return wrapped(*args, **kwargs)
768768

769-
settings = transaction.settings or global_settings
769+
settings = transaction.settings or global_settings()
770770
if not settings.ai_monitoring.enabled:
771771
return wrapped(*args, **kwargs)
772772

@@ -826,6 +826,16 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
826826
bedrock_attrs = extract_bedrock_converse_attrs(kwargs, response, response_headers, model, span_id, trace_id)
827827

828828
try:
829+
if response_streaming:
830+
# Wrap EventStream object here to intercept __iter__ method instead of instrumenting class.
831+
# This class is used in numerous other services in botocore, and would cause conflicts.
832+
response["stream"] = stream = EventStreamWrapper(response["stream"])
833+
stream._nr_ft = ft
834+
stream._nr_bedrock_attrs = bedrock_attrs
835+
stream._nr_model_extractor = stream_extractor
836+
stream._nr_is_converse = True
837+
return response
838+
829839
ft.__exit__(None, None, None)
830840
bedrock_attrs["duration"] = ft.duration * 1000
831841
run_bedrock_response_extractor(response_extractor, {}, bedrock_attrs, False, transaction)
@@ -846,39 +856,132 @@ def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, sp
846856

847857
# kwargs["messages"] can hold multiple requests and responses to maintain conversation history
848858
# We grab the last message (the newest request) in the list each time, so we don't duplicate recorded data
859+
_input_messages = kwargs.get("messages", [])
860+
_input_messages = _input_messages and (_input_messages[-1] or {})
861+
_input_messages = _input_messages.get("content", [])
849862
input_message_list.extend(
850-
[{"role": "user", "content": result["text"]} for result in kwargs["messages"][-1].get("content", [])]
863+
[{"role": "user", "content": result["text"]} for result in _input_messages if "text" in result]
851864
)
852865

853-
output_message_list = [
854-
{"role": "assistant", "content": result["text"]}
855-
for result in response.get("output").get("message").get("content", [])
856-
]
866+
output_message_list = None
867+
if "output" in response:
868+
output_message_list = [
869+
{"role": "assistant", "content": result["text"]}
870+
for result in response.get("output").get("message").get("content", [])
871+
]
857872

858873
bedrock_attrs = {
859874
"request_id": response_headers.get("x-amzn-requestid"),
860875
"model": model,
861876
"span_id": span_id,
862877
"trace_id": trace_id,
863878
"response.choices.finish_reason": response.get("stopReason"),
864-
"output_message_list": output_message_list,
865879
"request.max_tokens": kwargs.get("inferenceConfig", {}).get("maxTokens", None),
866880
"request.temperature": kwargs.get("inferenceConfig", {}).get("temperature", None),
867881
"input_message_list": input_message_list,
868882
}
883+
884+
if output_message_list is not None:
885+
bedrock_attrs["output_message_list"] = output_message_list
886+
869887
return bedrock_attrs
870888

871889

890+
class BedrockRecordEventMixin:
    """Recording helpers shared by the Bedrock streaming response proxies.

    The methods read state that the wrapping code attaches to the instance:

    * ``_nr_ft``: trace object timing the call; ``__exit__``, ``duration`` and
      ``notice_error`` are used on it. May be absent or ``None``.
    * ``_nr_bedrock_attrs``: dict of attributes accumulated for the recorded event.
    * ``_nr_model_extractor``: callable invoked as ``(chunk, bedrock_attrs)`` for
      invoke-style streams.
    * ``_nr_is_converse``: truthy when stream events use the Converse event shape.
    """

    def record_events_on_stop_iteration(self, transaction):
        """Close the trace and record the chat completion event for a finished stream.

        Does nothing when no trace is attached. When the cached attributes are empty
        the trace is still exited, but no event is recorded.
        """
        ft = getattr(self, "_nr_ft", None)
        if ft is None:
            # The wrappers may attach ``None`` in place of a trace; calling ``__exit__`` on it
            # would raise from inside the iterator's StopIteration handling.
            return

        bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
        ft.__exit__(None, None, None)

        # If there are no bedrock attrs exit early as there's no data to record.
        if not bedrock_attrs:
            return

        try:
            # Scaled by 1000 before recording (presumably seconds to milliseconds).
            bedrock_attrs["duration"] = ft.duration * 1000
            handle_chat_completion_event(transaction, bedrock_attrs)
        except Exception:
            _logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)

        # Clear cached data as this can be very large.
        bedrock_attrs.clear()

    def record_error(self, transaction, exc):
        """Report ``exc`` on the trace and record the chat completion event with error attributes.

        Does nothing when no trace is attached or when the cached attributes are empty.
        Failures while recording are logged and never propagated.

        NOTE: ``sys.exc_info()`` is forwarded to the trace's ``__exit__``, so this is
        expected to be called while ``exc`` is being handled (inside an ``except`` block).
        """
        ft = getattr(self, "_nr_ft", None)
        if ft is None:
            return

        try:
            error_attributes = getattr(self, "_nr_bedrock_attrs", {})

            # If there are no bedrock attrs exit early as there's no data to record.
            if not error_attributes:
                return

            error_attributes = bedrock_error_attributes(exc, error_attributes)
            notice_error_attributes = {
                "http.statusCode": error_attributes.get("http.statusCode"),
                "error.message": error_attributes.get("error.message"),
                "error.code": error_attributes.get("error.code"),
                "completion_id": str(uuid.uuid4()),
            }

            ft.notice_error(attributes=notice_error_attributes)

            ft.__exit__(*sys.exc_info())
            error_attributes["duration"] = ft.duration * 1000

            handle_chat_completion_event(transaction, error_attributes)

            # Clear cached data as this can be very large.
            error_attributes.clear()
        except Exception:
            _logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)

    def record_stream_chunk(self, event, transaction):
        """Dispatch a streamed event to the converse or invoke chunk handler.

        Falsy events are ignored. Any exception raised while processing the event is
        logged and suppressed so that instrumentation cannot break the caller's iteration.
        """
        if not event:
            return None

        try:
            if getattr(self, "_nr_is_converse", False):
                return self.converse_record_stream_chunk(event, transaction)
            return self.invoke_record_stream_chunk(event, transaction)
        except Exception:
            _logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
        return None

    def invoke_record_stream_chunk(self, event, transaction):
        """Decode an invoke-style chunk and pass it to the model extractor.

        ``event["chunk"]["bytes"]`` is decoded as UTF-8 JSON before being handed to
        ``_nr_model_extractor`` together with the cached attributes.
        """
        bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
        chunk = json.loads(event["chunk"]["bytes"].decode("utf-8"))
        self._nr_model_extractor(chunk, bedrock_attrs)
        # In Langchain, the bedrock iterator exits early if type is "content_block_stop".
        # So we need to call the record events here since stop iteration will not be raised.
        if chunk.get("type") == "content_block_stop":
            self.record_events_on_stop_iteration(transaction)

    def converse_record_stream_chunk(self, event, transaction):
        """Fold a Converse stream event into the cached attributes.

        ``contentBlockDelta`` events append their text to the single assistant entry in
        ``output_message_list``; ``messageStop`` events set the finish reason. Nothing is
        written when the cached attributes are empty or missing.
        """
        bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})

        # Empty attrs mean there is no event to build (never populated, or already recorded and
        # cleared). Bail out for every event type so a late ``messageStop`` cannot repopulate the
        # dict and cause a near-empty event to be recorded afterwards.
        if not bedrock_attrs:
            return

        if "contentBlockDelta" in event:
            delta = (event.get("contentBlockDelta") or {}).get("delta") or {}
            # ``or ""`` also covers an explicit ``None`` text value, which cannot be concatenated.
            content = delta.get("text") or ""
            if "output_message_list" not in bedrock_attrs:
                bedrock_attrs["output_message_list"] = [{"role": "assistant", "content": ""}]
            bedrock_attrs["output_message_list"][0]["content"] += content

        if "messageStop" in event:
            bedrock_attrs["response.choices.finish_reason"] = (event.get("messageStop") or {}).get("stopReason", "")
972+
973+
872974
class EventStreamWrapper(ObjectProxy):
    """Proxy around a synchronous event stream that passes instrumentation state to its iterator."""

    def __iter__(self):
        """Return a ``GeneratorProxy`` over the wrapped stream's iterator.

        Each ``_nr_*`` value attached to this wrapper is copied onto the proxy; a fallback
        is used for any value that was never set.
        """
        proxy = GeneratorProxy(self.__wrapped__.__iter__())
        inherited_state = (
            ("_nr_ft", None),
            ("_nr_bedrock_attrs", {}),
            ("_nr_model_extractor", NULL_EXTRACTOR),
            ("_nr_is_converse", False),
        )
        for attr_name, fallback in inherited_state:
            setattr(proxy, attr_name, getattr(self, attr_name, fallback))
        return proxy
879982

880983

881-
class GeneratorProxy(ObjectProxy):
984+
class GeneratorProxy(BedrockRecordEventMixin, ObjectProxy):
882985
def __init__(self, wrapped):
883986
super().__init__(wrapped)
884987

@@ -893,12 +996,12 @@ def __next__(self):
893996
return_val = None
894997
try:
895998
return_val = self.__wrapped__.__next__()
896-
record_stream_chunk(self, return_val, transaction)
999+
self.record_stream_chunk(return_val, transaction)
8971000
except StopIteration:
898-
record_events_on_stop_iteration(self, transaction)
1001+
self.record_events_on_stop_iteration(transaction)
8991002
raise
9001003
except Exception as exc:
901-
record_error(self, transaction, exc)
1004+
self.record_error(transaction, exc)
9021005
raise
9031006
return return_val
9041007

@@ -912,13 +1015,11 @@ def __aiter__(self):
9121015
g._nr_ft = getattr(self, "_nr_ft", None)
9131016
g._nr_bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
9141017
g._nr_model_extractor = getattr(self, "_nr_model_extractor", NULL_EXTRACTOR)
1018+
g._nr_is_converse = getattr(self, "_nr_is_converse", False)
9151019
return g
9161020

9171021

918-
class AsyncGeneratorProxy(ObjectProxy):
919-
def __init__(self, wrapped):
920-
super().__init__(wrapped)
921-
1022+
class AsyncGeneratorProxy(BedrockRecordEventMixin, ObjectProxy):
9221023
def __aiter__(self):
9231024
return self
9241025

@@ -929,83 +1030,19 @@ async def __anext__(self):
9291030
return_val = None
9301031
try:
9311032
return_val = await self.__wrapped__.__anext__()
932-
record_stream_chunk(self, return_val, transaction)
1033+
self.record_stream_chunk(return_val, transaction)
9331034
except StopAsyncIteration:
934-
record_events_on_stop_iteration(self, transaction)
1035+
self.record_events_on_stop_iteration(transaction)
9351036
raise
9361037
except Exception as exc:
937-
record_error(self, transaction, exc)
1038+
self.record_error(transaction, exc)
9381039
raise
9391040
return return_val
9401041

9411042
async def aclose(self):
9421043
return await super().aclose()
9431044

9441045

945-
def record_stream_chunk(self, return_val, transaction):
946-
if return_val:
947-
try:
948-
chunk = json.loads(return_val["chunk"]["bytes"].decode("utf-8"))
949-
self._nr_model_extractor(chunk, self._nr_bedrock_attrs)
950-
# In Langchain, the bedrock iterator exits early if type is "content_block_stop".
951-
# So we need to call the record events here since stop iteration will not be raised.
952-
_type = chunk.get("type")
953-
if _type == "content_block_stop":
954-
record_events_on_stop_iteration(self, transaction)
955-
except Exception:
956-
_logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
957-
958-
959-
def record_events_on_stop_iteration(self, transaction):
960-
if hasattr(self, "_nr_ft"):
961-
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
962-
self._nr_ft.__exit__(None, None, None)
963-
964-
# If there are no bedrock attrs exit early as there's no data to record.
965-
if not bedrock_attrs:
966-
return
967-
968-
try:
969-
bedrock_attrs["duration"] = self._nr_ft.duration * 1000
970-
handle_chat_completion_event(transaction, bedrock_attrs)
971-
except Exception:
972-
_logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)
973-
974-
# Clear cached data as this can be very large.
975-
self._nr_bedrock_attrs.clear()
976-
977-
978-
def record_error(self, transaction, exc):
979-
if hasattr(self, "_nr_ft"):
980-
try:
981-
ft = self._nr_ft
982-
error_attributes = getattr(self, "_nr_bedrock_attrs", {})
983-
984-
# If there are no bedrock attrs exit early as there's no data to record.
985-
if not error_attributes:
986-
return
987-
988-
error_attributes = bedrock_error_attributes(exc, error_attributes)
989-
notice_error_attributes = {
990-
"http.statusCode": error_attributes.get("http.statusCode"),
991-
"error.message": error_attributes.get("error.message"),
992-
"error.code": error_attributes.get("error.code"),
993-
}
994-
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
995-
996-
ft.notice_error(attributes=notice_error_attributes)
997-
998-
ft.__exit__(*sys.exc_info())
999-
error_attributes["duration"] = ft.duration * 1000
1000-
1001-
handle_chat_completion_event(transaction, error_attributes)
1002-
1003-
# Clear cached data as this can be very large.
1004-
error_attributes.clear()
1005-
except Exception:
1006-
_logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)
1007-
1008-
10091046
def handle_embedding_event(transaction, bedrock_attrs):
10101047
embedding_id = str(uuid.uuid4())
10111048

@@ -1551,6 +1588,7 @@ def wrap_serialize_to_request(wrapped, instance, args, kwargs):
15511588
response_streaming=True
15521589
),
15531590
("bedrock-runtime", "converse"): wrap_bedrock_runtime_converse(response_streaming=False),
1591+
("bedrock-runtime", "converse_stream"): wrap_bedrock_runtime_converse(response_streaming=True),
15541592
}
15551593

15561594

0 commit comments

Comments
 (0)