Skip to content

Commit 67bc07d

Browse files
committed
Refactor StreamResult to use PlatformStreamResult and improve metadata handling
1 parent 8a7acca commit 67bc07d

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

src/agent/src/Toolbox/StreamResult.php

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
use Symfony\AI\Platform\Message\Message;
1515
use Symfony\AI\Platform\Result\BaseResult;
16+
use Symfony\AI\Platform\Result\StreamResult as PlatformStreamResult;
1617
use Symfony\AI\Platform\Result\ToolCallResult;
1718

1819
/**
@@ -21,15 +22,15 @@
2122
final class StreamResult extends BaseResult
2223
{
2324
public function __construct(
24-
private readonly \Generator $generator,
25+
private readonly PlatformStreamResult $sourceStreamResult,
2526
private readonly \Closure $handleToolCallsCallback,
2627
) {
2728
}
2829

2930
public function getContent(): \Generator
3031
{
3132
$streamedResult = '';
32-
foreach ($this->generator as $value) {
33+
foreach ($this->sourceStreamResult->getContent() as $value) {
3334
if ($value instanceof ToolCallResult) {
3435
$innerResult = ($this->handleToolCallsCallback)($value, Message::ofAssistant($streamedResult));
3536

@@ -48,17 +49,31 @@ public function getContent(): \Generator
4849
yield from $content;
4950
}
5051

51-
break;
52-
}
52+
if ($innerResult->getMetadata()->has('calls')) {
53+
$innerCalls = $innerResult->getMetadata()->get('calls');
54+
$previousCalls = $this->getMetadata()->get('calls', []);
55+
$calls = array_merge($previousCalls, $innerCalls);
56+
} else {
57+
$calls[] = $innerResult->getMetadata()->all();
58+
}
5359

54-
if (!\is_string($value)) {
55-
yield $value;
56-
break;
60+
if ($calls !== ['calls' => []]) {
61+
$this->getMetadata()->add('calls', $calls);
62+
}
63+
64+
continue;
5765
}
5866

5967
$streamedResult .= $value;
6068

6169
yield $value;
70+
6271
}
72+
73+
// Attach the metadata from the platform stream to the agent after the stream has been fully processed
74+
// and the post-result metadata, such as usage, has been received.
75+
$calls = $this->getMetadata()->get('calls', []);
76+
$calls[] = $this->sourceStreamResult->getMetadata()->all();
77+
$this->getMetadata()->add('calls', $calls);
6378
}
6479
}

0 commit comments

Comments
 (0)