@@ -251,145 +251,110 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
251251 stream := g .client .Chat .Completions .NewStreaming (ctx , * g .request )
252252 defer stream .Close ()
253253
254- var fullResponse ai.ModelResponse
255- fullResponse .Message = & ai.Message {
256- Role : ai .RoleModel ,
257- Content : make ([]* ai.Part , 0 ),
258- }
259-
260- // Initialize request and usage
261- fullResponse .Request = & ai.ModelRequest {}
262- fullResponse .Usage = & ai.GenerationUsage {
263- InputTokens : 0 ,
264- OutputTokens : 0 ,
265- TotalTokens : 0 ,
266- }
267-
268- var currentToolCall * ai.ToolRequest
269- var currentArguments string
270- var toolCallCollects []struct {
271- toolCall * ai.ToolRequest
272- args string
273- }
254+ // Use openai-go's accumulator to collect the complete response
255+ acc := & openai.ChatCompletionAccumulator {}
274256
275257 for stream .Next () {
276258 chunk := stream .Current ()
277- if len (chunk .Choices ) > 0 {
278- choice := chunk .Choices [0 ]
279- modelChunk := & ai.ModelResponseChunk {}
280-
281- switch choice .FinishReason {
282- case "tool_calls" , "stop" :
283- fullResponse .FinishReason = ai .FinishReasonStop
284- case "length" :
285- fullResponse .FinishReason = ai .FinishReasonLength
286- case "content_filter" :
287- fullResponse .FinishReason = ai .FinishReasonBlocked
288- case "function_call" :
289- fullResponse .FinishReason = ai .FinishReasonOther
290- default :
291- fullResponse .FinishReason = ai .FinishReasonUnknown
292- }
259+ acc .AddChunk (chunk )
293260
294- // handle tool calls
295- for _ , toolCall := range choice .Delta .ToolCalls {
296- // first tool call (= current tool call is nil) contains the tool call name
297- if currentToolCall != nil && toolCall .ID != "" && currentToolCall .Ref != toolCall .ID {
298- toolCallCollects = append (toolCallCollects , struct {
299- toolCall * ai.ToolRequest
300- args string
301- }{
302- toolCall : currentToolCall ,
303- args : currentArguments ,
304- })
305- currentToolCall = nil
306- currentArguments = ""
307- }
261+ if len (chunk .Choices ) == 0 {
262+ continue
263+ }
308264
309- if currentToolCall == nil {
310- currentToolCall = & ai.ToolRequest {
311- Name : toolCall .Function .Name ,
312- Ref : toolCall .ID ,
313- }
314- }
265+ // Create chunk for callback
266+ modelChunk := & ai.ModelResponseChunk {}
315267
316- if toolCall .Function .Arguments != "" {
317- currentArguments += toolCall .Function .Arguments
318- }
268+ // Handle content delta
269+ if chunk .Choices [0 ].Delta .Content != "" {
270+ modelChunk .Content = append (modelChunk .Content , ai .NewTextPart (chunk .Choices [0 ].Delta .Content ))
271+ }
319272
273+ // Handle tool call deltas
274+ for _ , toolCall := range chunk .Choices [0 ].Delta .ToolCalls {
275+ // Send the incremental tool call part in the chunk
276+ if toolCall .Function .Name != "" || toolCall .Function .Arguments != "" {
320277 modelChunk .Content = append (modelChunk .Content , ai .NewToolRequestPart (& ai.ToolRequest {
321- Name : currentToolCall .Name ,
278+ Name : toolCall . Function .Name ,
322279 Input : toolCall .Function .Arguments ,
323- Ref : currentToolCall . Ref ,
280+ Ref : toolCall . ID ,
324281 }))
325282 }
283+ }
326284
327- // when tool call is complete
328- if choice .FinishReason == "tool_calls" && currentToolCall != nil {
329- // parse accumulated arguments string
330- for _ , toolcall := range toolCallCollects {
331- args , err := jsonStringToMap (toolcall .args )
332- if err != nil {
333- return nil , fmt .Errorf ("could not parse tool args: %w" , err )
334- }
335- toolcall .toolCall .Input = args
336- fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (toolcall .toolCall ))
337- }
338- if currentArguments != "" {
339- args , err := jsonStringToMap (currentArguments )
340- if err != nil {
341- return nil , fmt .Errorf ("could not parse tool args: %w" , err )
342- }
343- currentToolCall .Input = args
344- }
345- fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (currentToolCall ))
346- }
347-
348- content := chunk .Choices [0 ].Delta .Content
349- // when starting a tool call, the content is empty
350- if content != "" {
351- modelChunk .Content = append (modelChunk .Content , ai .NewTextPart (content ))
352- fullResponse .Message .Content = append (fullResponse .Message .Content , modelChunk .Content ... )
353- }
354-
285+ // Call the chunk handler with incremental data
286+ if len (modelChunk .Content ) > 0 {
355287 if err := handleChunk (ctx , modelChunk ); err != nil {
356288 return nil , fmt .Errorf ("callback error: %w" , err )
357289 }
358-
359- fullResponse .Usage .InputTokens += int (chunk .Usage .PromptTokens )
360- fullResponse .Usage .OutputTokens += int (chunk .Usage .CompletionTokens )
361- fullResponse .Usage .TotalTokens += int (chunk .Usage .TotalTokens )
362290 }
363291 }
364292
365293 if err := stream .Err (); err != nil {
366294 return nil , fmt .Errorf ("stream error: %w" , err )
367295 }
368296
369- return & fullResponse , nil
297+ // Convert accumulated ChatCompletion to ai.ModelResponse
298+ return convertChatCompletionToModelResponse (& acc .ChatCompletion )
370299}
371300
372- // generateComplete generates a complete model response
373- func (g * ModelGenerator ) generateComplete (ctx context.Context , req * ai.ModelRequest ) (* ai.ModelResponse , error ) {
374- completion , err := g .client .Chat .Completions .New (ctx , * g .request )
375- if err != nil {
376- return nil , fmt .Errorf ("failed to create completion: %w" , err )
301+ // convertChatCompletionToModelResponse converts openai.ChatCompletion to ai.ModelResponse
302+ func convertChatCompletionToModelResponse (completion * openai.ChatCompletion ) (* ai.ModelResponse , error ) {
303+ if len (completion .Choices ) == 0 {
304+ return nil , fmt .Errorf ("no choices in completion" )
305+ }
306+
307+ choice := completion .Choices [0 ]
308+
309+ // Build usage information with detailed token breakdown
310+ usage := & ai.GenerationUsage {
311+ InputTokens : int (completion .Usage .PromptTokens ),
312+ OutputTokens : int (completion .Usage .CompletionTokens ),
313+ TotalTokens : int (completion .Usage .TotalTokens ),
314+ }
315+
316+ // Add reasoning tokens (thoughts tokens) if available
317+ if completion .Usage .CompletionTokensDetails .ReasoningTokens > 0 {
318+ usage .ThoughtsTokens = int (completion .Usage .CompletionTokensDetails .ReasoningTokens )
319+ }
320+
321+ // Add cached tokens if available
322+ if completion .Usage .PromptTokensDetails .CachedTokens > 0 {
323+ usage .CachedContentTokens = int (completion .Usage .PromptTokensDetails .CachedTokens )
324+ }
325+
326+ // Add audio tokens to custom field if available
327+ if completion .Usage .CompletionTokensDetails .AudioTokens > 0 {
328+ if usage .Custom == nil {
329+ usage .Custom = make (map [string ]float64 )
330+ }
331+ usage .Custom ["audioTokens" ] = float64 (completion .Usage .CompletionTokensDetails .AudioTokens )
332+ }
333+
334+ // Add prediction tokens to custom field if available
335+ if completion .Usage .CompletionTokensDetails .AcceptedPredictionTokens > 0 {
336+ if usage .Custom == nil {
337+ usage .Custom = make (map [string ]float64 )
338+ }
339+ usage .Custom ["acceptedPredictionTokens" ] = float64 (completion .Usage .CompletionTokensDetails .AcceptedPredictionTokens )
340+ }
341+ if completion .Usage .CompletionTokensDetails .RejectedPredictionTokens > 0 {
342+ if usage .Custom == nil {
343+ usage .Custom = make (map [string ]float64 )
344+ }
345+ usage .Custom ["rejectedPredictionTokens" ] = float64 (completion .Usage .CompletionTokensDetails .RejectedPredictionTokens )
377346 }
378347
379348 resp := & ai.ModelResponse {
380- Request : req ,
381- Usage : & ai.GenerationUsage {
382- InputTokens : int (completion .Usage .PromptTokens ),
383- OutputTokens : int (completion .Usage .CompletionTokens ),
384- TotalTokens : int (completion .Usage .TotalTokens ),
385- },
349+ Request : & ai.ModelRequest {},
350+ Usage : usage ,
386351 Message : & ai.Message {
387- Role : ai .RoleModel ,
352+ Role : ai .RoleModel ,
353+ Content : make ([]* ai.Part , 0 ),
388354 },
389355 }
390356
391- choice := completion .Choices [0 ]
392-
357+ // Map finish reason
393358 switch choice .FinishReason {
394359 case "stop" , "tool_calls" :
395360 resp .FinishReason = ai .FinishReasonStop
@@ -403,30 +368,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequ
403368 resp .FinishReason = ai .FinishReasonUnknown
404369 }
405370
406- // handle tool calls
407- var toolRequestParts []* ai.Part
371+ // Set finish message if there's a refusal
372+ if choice .Message .Refusal != "" {
373+ resp .FinishMessage = choice .Message .Refusal
374+ resp .FinishReason = ai .FinishReasonBlocked
375+ }
376+
377+ // Add text content
378+ if choice .Message .Content != "" {
379+ resp .Message .Content = append (resp .Message .Content , ai .NewTextPart (choice .Message .Content ))
380+ }
381+
382+ // Add tool calls
408383 for _ , toolCall := range choice .Message .ToolCalls {
409384 args , err := jsonStringToMap (toolCall .Function .Arguments )
410385 if err != nil {
411- return nil , err
386+ return nil , fmt . Errorf ( "could not parse tool args: %w" , err )
412387 }
413- toolRequestParts = append (toolRequestParts , ai .NewToolRequestPart (& ai.ToolRequest {
388+ resp . Message . Content = append (resp . Message . Content , ai .NewToolRequestPart (& ai.ToolRequest {
414389 Ref : toolCall .ID ,
415390 Name : toolCall .Function .Name ,
416391 Input : args ,
417392 }))
418393 }
419394
420- // content and tool call may exist simultaneously
421- if completion .Choices [0 ].Message .Content != "" {
422- resp .Message .Content = append (resp .Message .Content , ai .NewTextPart (completion .Choices [0 ].Message .Content ))
395+ // Store additional metadata in custom field if needed
396+ if completion .SystemFingerprint != "" {
397+ resp .Custom = map [string ]any {
398+ "systemFingerprint" : completion .SystemFingerprint ,
399+ "model" : completion .Model ,
400+ "id" : completion .ID ,
401+ }
402+ }
403+
404+ return resp , nil
405+ }
406+
407+ // generateComplete generates a complete model response
408+ func (g * ModelGenerator ) generateComplete (ctx context.Context , req * ai.ModelRequest ) (* ai.ModelResponse , error ) {
409+ completion , err := g .client .Chat .Completions .New (ctx , * g .request )
410+ if err != nil {
411+ return nil , fmt .Errorf ("failed to create completion: %w" , err )
423412 }
424413
425- if len ( toolRequestParts ) > 0 {
426- resp . Message . Content = append ( resp . Message . Content , toolRequestParts ... )
427- return resp , nil
414+ resp , err := convertChatCompletionToModelResponse ( completion )
415+ if err != nil {
416+ return nil , err
428417 }
429418
419+ // Set the original request
420+ resp .Request = req
421+
430422 return resp , nil
431423}
432424
0 commit comments