Skip to content

Commit 0d679c8

Browse files
gustavocidornelas authored and whoseoyster committed
Completes OPEN-5841 Support function calling when stream=True
1 parent cc72ddd commit 0d679c8

File tree

1 file changed

+46
-8
lines changed

1 file changed

+46
-8
lines changed

openlayer/llm_monitors.py

Lines changed: 46 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -156,8 +156,10 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
156156
}
157157
else:
158158
function_call = {
159-
"name": output_tool_calls[0].name,
160-
"arguments": json.loads(output_function_call.arguments),
159+
"name": output_tool_calls[0].function.name,
160+
"arguments": json.loads(
161+
output_tool_calls[0].function.arguments
162+
),
161163
}
162164
output_data = function_call
163165
else:
@@ -193,17 +195,47 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
193195
chunks = self.create_chat_completion(*args, **kwargs)
194196

195197
def stream_chunks():
196-
collected_messages = []
198+
collected_output_data = []
199+
collected_function_call = {
200+
"name": "",
201+
"arguments": "",
202+
}
203+
raw_outputs = []
197204
start_time = time.time()
205+
end_time = None
198206
first_token_time = None
199207
num_of_completion_tokens = None
200208
latency = None
201209
try:
202210
i = 0
203211
for i, chunk in enumerate(chunks):
212+
raw_outputs.append(chunk.model_dump())
204213
if i == 0:
205214
first_token_time = time.time()
206-
collected_messages.append(chunk.choices[0].delta.content)
215+
216+
delta = chunk.choices[0].delta
217+
218+
if delta.content:
219+
collected_output_data.append(delta.content)
220+
elif delta.function_call:
221+
if delta.function_call.name:
222+
collected_function_call[
223+
"name"
224+
] += delta.function_call.name
225+
if delta.function_call.arguments:
226+
collected_function_call[
227+
"arguments"
228+
] += delta.function_call.arguments
229+
elif delta.tool_calls:
230+
if delta.tool_calls[0].function.name:
231+
collected_function_call["name"] += delta.tool_calls[
232+
0
233+
].function.name
234+
if delta.tool_calls[0].function.arguments:
235+
collected_function_call[
236+
"arguments"
237+
] += delta.tool_calls[0].function.arguments
238+
207239
yield chunk
208240
if i > 0:
209241
num_of_completion_tokens = i + 1
@@ -215,12 +247,18 @@ def stream_chunks():
215247
finally:
216248
# Try to add step to the trace
217249
try:
218-
collected_messages = [
250+
collected_output_data = [
219251
message
220-
for message in collected_messages
252+
for message in collected_output_data
221253
if message is not None
222254
]
223-
output_data = "".join(collected_messages)
255+
if collected_output_data:
256+
output_data = "".join(collected_output_data)
257+
else:
258+
collected_function_call["arguments"] = json.loads(
259+
collected_function_call["arguments"]
260+
)
261+
output_data = collected_function_call
224262
completion_cost = self.get_cost_estimate(
225263
model=kwargs.get("model"),
226264
num_input_tokens=0,
@@ -244,7 +282,7 @@ def stream_chunks():
244282
completion_tokens=num_of_completion_tokens,
245283
model=kwargs.get("model"),
246284
model_parameters=kwargs.get("model_parameters"),
247-
raw_output=None,
285+
raw_output=raw_outputs,
248286
metadata={
249287
"timeToFirstToken": (
250288
(first_token_time - start_time) * 1000

0 commit comments

Comments
 (0)