diff --git a/src/chat/src/MessageNormalizer.php b/src/chat/src/MessageNormalizer.php index e55eda9da..dce2a6a6c 100644 --- a/src/chat/src/MessageNormalizer.php +++ b/src/chat/src/MessageNormalizer.php @@ -28,6 +28,9 @@ use Symfony\Component\Serializer\Exception\InvalidArgumentException; use Symfony\Component\Serializer\Normalizer\DenormalizerInterface; use Symfony\Component\Serializer\Normalizer\NormalizerInterface; +use Symfony\Component\Uid\AbstractUid; +use Symfony\Component\Uid\TimeBasedUidInterface; +use Symfony\Component\Uid\Uuid; /** * @author Guillaume Loulier @@ -71,12 +74,17 @@ public function denormalize(mixed $data, string $type, ?string $format = null, a default => throw new LogicException(\sprintf('Unknown message type "%s".', $type)), }; - $message->getMetadata()->set([ + /** @var AbstractUid&TimeBasedUidInterface&Uuid $existingUuid */ + $existingUuid = Uuid::fromString($data['id']); + + $messageWithExistingUuid = $message->withId($existingUuid); + + $messageWithExistingUuid->getMetadata()->set([ ...$data['metadata'], 'addedAt' => $data['addedAt'], ]); - return $message; + return $messageWithExistingUuid; } public function supportsDenormalization(mixed $data, string $type, ?string $format = null, array $context = []): bool diff --git a/src/chat/tests/MessageNormalizerTest.php b/src/chat/tests/MessageNormalizerTest.php index a96a0192d..a7c15a03b 100644 --- a/src/chat/tests/MessageNormalizerTest.php +++ b/src/chat/tests/MessageNormalizerTest.php @@ -54,10 +54,11 @@ public function testItCanNormalize() public function testItCanDenormalize() { + $uuid = Uuid::v7()->toRfc4122(); $normalizer = new MessageNormalizer(); $message = $normalizer->denormalize([ - 'id' => Uuid::v7()->toRfc4122(), + 'id' => $uuid, 'type' => UserMessage::class, 'content' => '', 'contentAsBase64' => [ @@ -71,6 +72,7 @@ public function testItCanDenormalize() 'addedAt' => (new \DateTimeImmutable())->getTimestamp(), ], MessageInterface::class); + $this->assertSame($uuid, $message->getId()->toRfc4122()); $this->assertSame(Role::User, $message->getRole()); $this->assertArrayHasKey('addedAt', $message->getMetadata()->all()); } diff --git a/src/platform/src/Message/AssistantMessage.php b/src/platform/src/Message/AssistantMessage.php index 3bb7370b2..f6e6b0e20 100644 --- a/src/platform/src/Message/AssistantMessage.php +++ b/src/platform/src/Message/AssistantMessage.php @@ -13,8 +13,6 @@ use Symfony\AI\Platform\Metadata\MetadataAwareTrait; use Symfony\AI\Platform\Result\ToolCall; -use Symfony\Component\Uid\AbstractUid; -use Symfony\Component\Uid\TimeBasedUidInterface; use Symfony\Component\Uid\Uuid; /** @@ -22,10 +20,9 @@ */ final class AssistantMessage implements MessageInterface { + use IdentifierAwareTrait; use MetadataAwareTrait; - private readonly AbstractUid&TimeBasedUidInterface $id; - /** * @param ?ToolCall[] $toolCalls */ @@ -41,11 +38,6 @@ public function getRole(): Role return Role::Assistant; } - public function getId(): AbstractUid&TimeBasedUidInterface - { - return $this->id; - } - public function hasToolCalls(): bool { return null !== $this->toolCalls && [] !== $this->toolCalls; diff --git a/src/platform/src/Message/IdentifierAwareTrait.php b/src/platform/src/Message/IdentifierAwareTrait.php new file mode 100644 index 000000000..0885be3fc --- /dev/null +++ b/src/platform/src/Message/IdentifierAwareTrait.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +use Symfony\Component\Uid\AbstractUid; +use Symfony\Component\Uid\TimeBasedUidInterface; + +/** + * @author Guillaume Loulier + */ +trait IdentifierAwareTrait +{ + private AbstractUid&TimeBasedUidInterface $id; + + public function withId(AbstractUid&TimeBasedUidInterface $id): self + { + $clone = clone $this; + $clone->id = $id; + + return $clone; + } + + public function getId(): AbstractUid&TimeBasedUidInterface + { + return $this->id; + } +} diff --git a/src/platform/src/Message/MessageInterface.php b/src/platform/src/Message/MessageInterface.php index 7ca9229c6..88c891a50 100644 --- a/src/platform/src/Message/MessageInterface.php +++ b/src/platform/src/Message/MessageInterface.php @@ -25,6 +25,8 @@ public function getRole(): Role; public function getId(): AbstractUid&TimeBasedUidInterface; + public function withId(AbstractUid&TimeBasedUidInterface $id): self; + /** * @return string|ContentInterface[]|null */ diff --git a/src/platform/src/Message/SystemMessage.php b/src/platform/src/Message/SystemMessage.php index d0650773a..a5b0b0a89 100644 --- a/src/platform/src/Message/SystemMessage.php +++ b/src/platform/src/Message/SystemMessage.php @@ -12,8 +12,6 @@ namespace Symfony\AI\Platform\Message; use Symfony\AI\Platform\Metadata\MetadataAwareTrait; -use Symfony\Component\Uid\AbstractUid; -use Symfony\Component\Uid\TimeBasedUidInterface; use Symfony\Component\Uid\Uuid; /** @@ -21,10 +19,9 @@ */ final class SystemMessage implements MessageInterface { + use IdentifierAwareTrait; use MetadataAwareTrait; - private readonly AbstractUid&TimeBasedUidInterface $id; - public function __construct( private readonly string $content, ) { @@ -36,11 +33,6 @@ public function getRole(): Role return Role::System; } - public function getId(): AbstractUid&TimeBasedUidInterface - { - return $this->id; - } - public function getContent(): string { return $this->content; diff --git a/src/platform/src/Message/ToolCallMessage.php b/src/platform/src/Message/ToolCallMessage.php index c664c6b7b..348c8cc76 100644 --- a/src/platform/src/Message/ToolCallMessage.php +++ b/src/platform/src/Message/ToolCallMessage.php @@ -13,8 +13,6 @@ use Symfony\AI\Platform\Metadata\MetadataAwareTrait; use Symfony\AI\Platform\Result\ToolCall; -use Symfony\Component\Uid\AbstractUid; -use Symfony\Component\Uid\TimeBasedUidInterface; use Symfony\Component\Uid\Uuid; /** @@ -22,10 +20,9 @@ */ final class ToolCallMessage implements MessageInterface { + use IdentifierAwareTrait; use MetadataAwareTrait; - private readonly AbstractUid&TimeBasedUidInterface $id; - public function __construct( private readonly ToolCall $toolCall, private readonly string $content, @@ -38,11 +35,6 @@ public function getRole(): Role return Role::ToolCall; } - public function getId(): AbstractUid&TimeBasedUidInterface - { - return $this->id; - } - public function getToolCall(): ToolCall { return $this->toolCall; diff --git a/src/platform/src/Message/UserMessage.php b/src/platform/src/Message/UserMessage.php index 445af64c3..a4691f6e0 100644 --- a/src/platform/src/Message/UserMessage.php +++ b/src/platform/src/Message/UserMessage.php @@ -17,8 +17,6 @@ use Symfony\AI\Platform\Message\Content\ImageUrl; use Symfony\AI\Platform\Message\Content\Text; use Symfony\AI\Platform\Metadata\MetadataAwareTrait; -use Symfony\Component\Uid\AbstractUid; -use Symfony\Component\Uid\TimeBasedUidInterface; use Symfony\Component\Uid\Uuid; /** @@ -26,6 +24,7 @@ */ final class UserMessage implements MessageInterface { + use IdentifierAwareTrait; use MetadataAwareTrait; /** @@ -33,8 +32,6 @@ final class UserMessage implements MessageInterface */ private readonly array $content; - private readonly AbstractUid&TimeBasedUidInterface $id; - public function __construct( ContentInterface ...$content, ) { @@ -47,11 +44,6 @@ public function getRole(): Role return Role::User; } - public function getId(): AbstractUid&TimeBasedUidInterface - { - return $this->id; - } - /** * @return ContentInterface[] */ diff --git a/src/platform/tests/ContractTest.php b/src/platform/tests/ContractTest.php index 64bf3ee5d..3f6f23120 100644 --- a/src/platform/tests/ContractTest.php +++ b/src/platform/tests/ContractTest.php @@ -23,15 +23,13 @@ use Symfony\AI\Platform\Message\Content\Image; use Symfony\AI\Platform\Message\Content\ImageUrl; use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\IdentifierAwareTrait; use Symfony\AI\Platform\Message\Message; use Symfony\AI\Platform\Message\MessageBag; use Symfony\AI\Platform\Message\MessageInterface; use Symfony\AI\Platform\Message\Role; use Symfony\AI\Platform\Metadata\MetadataAwareTrait; use Symfony\AI\Platform\Model; -use Symfony\Component\Uid\AbstractUid; -use Symfony\Component\Uid\TimeBasedUidInterface; -use Symfony\Component\Uid\Uuid; final class ContractTest extends TestCase { @@ -174,6 +172,7 @@ public static function providePayloadTestCases(): iterable ]; $customSerializableMessage = new class implements MessageInterface, \JsonSerializable { + use IdentifierAwareTrait; use MetadataAwareTrait; public function getRole(): Role @@ -181,11 +180,6 @@ public function getRole(): Role return Role::User; } - public function getId(): AbstractUid&TimeBasedUidInterface - { - return Uuid::v7(); - } - public function getContent(): array { return [new Text('This is a custom serializable message.')];