From 8ce851d8c0c823aa948a716e2be3570dc2e02250 Mon Sep 17 00:00:00 2001 From: James LePage <36246732+Jameswlepage@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:01:36 -0500 Subject: [PATCH 1/2] feat: add embedding generation support --- README.md | 26 ++ cli.php | 94 ++++-- src/AiClient.php | 56 +++ src/Builders/PromptBuilder.php | 150 ++++++++- src/Embeddings/DTO/Embedding.php | 160 +++++++++ src/Operations/DTO/EmbeddingOperation.php | 185 ++++++++++ .../OpenAi/OpenAiEmbeddingModel.php | 185 ++++++++++ .../OpenAi/OpenAiModelMetadataDirectory.php | 18 +- .../OpenAi/OpenAiProvider.php | 3 + .../WithEmbeddingOperationsInterface.php | 25 ++ .../Models/DTO/ModelRequirements.php | 2 +- .../EmbeddingGenerationModelInterface.php | 26 ++ ...ddingGenerationOperationModelInterface.php | 26 ++ src/Results/DTO/EmbeddingResult.php | 318 ++++++++++++++++++ tests/mocks/MockProvider.php | 8 +- tests/traits/MockModelCreationTrait.php | 199 +++++++++++ tests/unit/AiClientTest.php | 28 ++ tests/unit/Builders/PromptBuilderTest.php | 57 ++++ tests/unit/Embeddings/DTO/EmbeddingTest.php | 103 ++++++ .../Operations/DTO/EmbeddingOperationTest.php | 107 ++++++ .../OpenAi/OpenAiEmbeddingModelTest.php | 77 +++++ .../Models/DTO/ModelRequirementsTest.php | 23 ++ tests/unit/Providers/ProviderRegistryTest.php | 17 + .../unit/Results/DTO/EmbeddingResultTest.php | 167 +++++++++ 24 files changed, 2032 insertions(+), 28 deletions(-) create mode 100644 src/Embeddings/DTO/Embedding.php create mode 100644 src/Operations/DTO/EmbeddingOperation.php create mode 100644 src/ProviderImplementations/OpenAi/OpenAiEmbeddingModel.php create mode 100644 src/Providers/Models/Contracts/WithEmbeddingOperationsInterface.php create mode 100644 src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php create mode 100644 src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationOperationModelInterface.php create mode 100644 src/Results/DTO/EmbeddingResult.php create mode 100644 tests/unit/Embeddings/DTO/EmbeddingTest.php create mode 100644 tests/unit/Operations/DTO/EmbeddingOperationTest.php create mode 100644 tests/unit/ProviderImplementations/OpenAi/OpenAiEmbeddingModelTest.php create mode 100644 tests/unit/Results/DTO/EmbeddingResultTest.php diff --git a/README.md b/README.md index 8c4bb32a..d3b53cd0 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,36 @@ $imageFile = AiClient::prompt('Generate an illustration of the PHP elephant in t ->generateImage(); ``` +### Embedding generation using any compatible model + +```php +use WordPress\AiClient\AiClient; + +$vectors = AiClient::prompt() + ->withEmbeddingInputs('Summarize this document', 'Summarize that document') + ->generateEmbeddings(); + +// Or work with the detailed result object: +$result = AiClient::prompt(['Embed this input']) + ->generateEmbeddingsResult(); +``` + See the [`PromptBuilder` class](https://github.com/WordPress/php-ai-client/blob/trunk/src/Builders/PromptBuilder.php) and its public methods for all the ways you can configure the prompt. **More documentation is coming soon.** +## CLI usage + +This repository ships with a thin CLI wrapper for quick experiments: + +``` +php cli.php 'Explain WordPress in one sentence' +php cli.php 'Create a postcard photo of the WordPress logo' --outputFormat=image-json +php cli.php '["Embed this document", "And this one"]' --capability=embeddings --outputFormat=embeddings-vectors +``` + +Available embedding output formats are `embeddings-vectors` (default), `embedding-first-vector`, and `embeddings-json`. Use `--capability=embeddings` to explicitly request embeddings while still supporting the existing image/text detection flags. + ## Further reading For more information on the requirements and guiding principles, please review: diff --git a/cli.php b/cli.php index 4716e131..75c86013 100755 --- a/cli.php +++ b/cli.php @@ -9,6 +9,7 @@ * GOOGLE_API_KEY=123456 php cli.php 'Your prompt here' --providerId=google --modelId=gemini-2.5-flash * OPENAI_API_KEY=123456 php cli.php 'Your prompt here' --providerId=openai * GOOGLE_API_KEY=123456 OPENAI_API_KEY=123456 php cli.php 'Your prompt here' + * OPENAI_API_KEY=123456 php cli.php '["Embed this", "and this"]' --capability=embeddings --outputFormat=embeddings-json */ declare(strict_types=1); @@ -97,11 +98,32 @@ function logError(string $message, int $exit_code = 1): void } } -// Provider ID, model ID, and output format. +// Provider ID, model ID, and capability/output format. $providerId = $named_args['providerId'] ?? null; $modelId = $named_args['modelId'] ?? null; $modelPreference = $named_args['modelPreference'] ?? null; -$outputFormat = $named_args['outputFormat'] ?? 'message-text'; +$capabilityInput = $named_args['capability'] ?? 'text'; +$capability = is_string($capabilityInput) ? strtolower($capabilityInput) : 'text'; +$validCapabilities = ['text', 'image', 'embeddings']; +if (!in_array($capability, $validCapabilities, true)) { + logWarning(sprintf('Invalid capability "%s". Defaulting to "text".', (string) $capabilityInput)); + $capability = 'text'; +} +$defaultOutputFormat = 'message-text'; +if ($capability === 'image') { + $defaultOutputFormat = 'image-json'; +} elseif ($capability === 'embeddings') { + $defaultOutputFormat = 'embeddings-vectors'; +} +$outputFormat = $named_args['outputFormat'] ?? $defaultOutputFormat; +$imageOutputFormats = ['image-json', 'image-base64']; +if ($capability === 'embeddings' && in_array($outputFormat, $imageOutputFormats, true)) { + logWarning('Image output formats are not supported for embeddings. Using embeddings-vectors.'); + $outputFormat = 'embeddings-vectors'; +} +if ($capability !== 'embeddings' && in_array($outputFormat, $imageOutputFormats, true)) { + $capability = 'image'; +} // Any model configuration options. $schema = ModelConfig::getJsonSchema()['properties']; @@ -142,7 +164,8 @@ function logError(string $message, int $exit_code = 1): void try { $modelConfig = ModelConfig::fromArray($model_config_data); - $promptBuilder = AiClient::prompt($promptInput); + $initialPrompt = $capability === 'embeddings' ? null : $promptInput; + $promptBuilder = AiClient::prompt($initialPrompt); $promptBuilder = $promptBuilder->usingModelConfig($modelConfig); if ($providerId && $modelId) { $providerClassName = AiClient::defaultRegistry()->getProviderClassName($providerId); @@ -163,6 +186,9 @@ static function ($item) { ); $promptBuilder = $promptBuilder->usingModelPreference(...$modelPreference); } + if ($capability === 'embeddings') { + $promptBuilder = $promptBuilder->withEmbeddingInputs($promptInput); + } } catch (InvalidArgumentException $e) { logError('Invalid arguments while trying to set up prompt builder: ' . $e->getMessage()); } catch (ResponseException $e) { @@ -170,36 +196,60 @@ static function ($item) { } try { - if ($outputFormat === 'image-json' || $outputFormat === 'image-base64') { + $generationAction = 'generate text result'; + if ($capability === 'image') { + $generationAction = 'generate image result'; + } elseif ($capability === 'embeddings') { + $generationAction = 'generate embeddings result'; + } + + if ($capability === 'image') { $result = $promptBuilder->generateImageResult(); + } elseif ($capability === 'embeddings') { + $result = $promptBuilder->generateEmbeddingsResult(); } else { $result = $promptBuilder->generateTextResult(); } } catch (InvalidArgumentException $e) { - logError('Invalid arguments while trying to generate text result: ' . $e->getMessage()); + logError('Invalid arguments while trying to ' . $generationAction . ': ' . $e->getMessage()); } catch (ResponseException $e) { - logError('Request failed while trying to generate text result: ' . $e->getMessage()); + logError('Request failed while trying to ' . $generationAction . ': ' . $e->getMessage()); } logInfo("Using provider ID: \"{$result->getProviderMetadata()->getId()}\""); logInfo("Using model ID: \"{$result->getModelMetadata()->getId()}\""); -switch ($outputFormat) { - case 'result-json': - $output = json_encode($result, JSON_PRETTY_PRINT); - break; - case 'candidates-json': - $output = json_encode($result->getCandidates(), JSON_PRETTY_PRINT); - break; - case 'image-json': - $output = json_encode($result->toFile(), JSON_PRETTY_PRINT); - break; - case 'image-base64': - $output = $result->toFile()->getBase64Data(); - break; - case 'message-text': - default: - $output = $result->toText(); +if ($capability === 'embeddings') { + switch ($outputFormat) { + case 'embeddings-json': + $output = json_encode($result, JSON_PRETTY_PRINT); + break; + case 'embedding-first-vector': + $output = json_encode($result->toVector(), JSON_PRETTY_PRINT); + break; + case 'embeddings-vectors': + default: + $output = json_encode($result->toVectors(), JSON_PRETTY_PRINT); + break; + } +} else { + switch ($outputFormat) { + case 'result-json': + $output = json_encode($result, JSON_PRETTY_PRINT); + break; + case 'candidates-json': + $output = json_encode($result->getCandidates(), JSON_PRETTY_PRINT); + break; + case 'image-json': + $output = json_encode($result->toFile(), JSON_PRETTY_PRINT); + break; + case 'image-base64': + $output = $result->toFile()->getBase64Data(); + break; + case 'message-text': + default: + $output = $result->toText(); + } } printOutput($output); diff --git a/src/AiClient.php b/src/AiClient.php index ab9ceadd..7101645d 100644 --- a/src/AiClient.php +++ b/src/AiClient.php @@ -7,6 +7,7 @@ use WordPress\AiClient\Builders\PromptBuilder; use WordPress\AiClient\Common\Exception\InvalidArgumentException; use WordPress\AiClient\Common\Exception\RuntimeException; +use WordPress\AiClient\Operations\DTO\EmbeddingOperation; use WordPress\AiClient\ProviderImplementations\Anthropic\AnthropicProvider; use WordPress\AiClient\ProviderImplementations\Google\GoogleProvider; use WordPress\AiClient\ProviderImplementations\OpenAi\OpenAiProvider; @@ -16,6 +17,7 @@ use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Results\DTO\EmbeddingResult; use WordPress\AiClient\Results\DTO\GenerativeAiResult; /** @@ -298,6 +300,60 @@ public static function generateSpeechResult( return self::getConfiguredPromptBuilder($prompt, $modelOrConfig, $registry)->generateSpeechResult(); } + /** + * Generates embeddings for the given input using the traditional API. + * + * @since 0.2.0 + * + * @param Prompt $input The input to embed. Accepts strings, messages, or arrays of those types. + * @param ModelInterface|ModelConfig|null $modelOrConfig Optional model specification. + * @param ProviderRegistry|null $registry Optional custom registry. + * @return EmbeddingResult The embeddings result. + */ + public static function generateEmbeddingsResult( + $input, + $modelOrConfig = null, + ?ProviderRegistry $registry = null + ): EmbeddingResult { + self::validateModelOrConfigParameter($modelOrConfig); + $builder = self::prompt(null, $registry); + if ($modelOrConfig instanceof ModelInterface) { + $builder->usingModel($modelOrConfig); + } elseif ($modelOrConfig instanceof ModelConfig) { + $builder->usingModelConfig($modelOrConfig); + } + + $builder->withEmbeddingInputs($input); + return $builder->generateEmbeddingsResult(); + } + + /** + * Generates an embeddings operation for asynchronous processing. + * + * @since 0.2.0 + * + * @param Prompt $input The input to embed. + * @param ModelInterface|ModelConfig|null $modelOrConfig Optional model specification. + * @param ProviderRegistry|null $registry Optional custom registry. + * @return EmbeddingOperation The embeddings operation. + */ + public static function generateEmbeddingsOperation( + $input, + $modelOrConfig = null, + ?ProviderRegistry $registry = null + ): EmbeddingOperation { + self::validateModelOrConfigParameter($modelOrConfig); + $builder = self::prompt(null, $registry); + if ($modelOrConfig instanceof ModelInterface) { + $builder->usingModel($modelOrConfig); + } elseif ($modelOrConfig instanceof ModelConfig) { + $builder->usingModelConfig($modelOrConfig); + } + + $builder->withEmbeddingInputs($input); + return $builder->generateEmbeddingsOperation(); + } + /** * Creates a new message builder for fluent API usage. * diff --git a/src/Builders/PromptBuilder.php b/src/Builders/PromptBuilder.php index 047fb4ec..aad7858f 100644 --- a/src/Builders/PromptBuilder.php +++ b/src/Builders/PromptBuilder.php @@ -13,17 +13,21 @@ use WordPress\AiClient\Messages\DTO\UserMessage; use WordPress\AiClient\Messages\Enums\MessageRoleEnum; use WordPress\AiClient\Messages\Enums\ModalityEnum; +use WordPress\AiClient\Operations\DTO\EmbeddingOperation; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; use WordPress\AiClient\Providers\Models\DTO\RequiredOption; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; +use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface; +use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationOperationModelInterface; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Results\DTO\EmbeddingResult; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Tools\DTO\FunctionDeclaration; use WordPress\AiClient\Tools\DTO\FunctionResponse; @@ -55,6 +59,11 @@ class PromptBuilder */ protected array $messages = []; + /** + * @var list Embedding-specific input messages. + */ + protected array $embeddingInputs = []; + /** * @var ModelInterface|null The model to use for generation. */ @@ -196,6 +205,25 @@ public function withHistory(Message ...$messages): self return $this; } + /** + * Adds embedding inputs. + * + * Accepts the same input shapes as the prompt constructor, or arrays of those shapes. + * + * @since 0.2.0 + * + * @param mixed ...$inputs The embedding inputs to add. + * @return self + */ + public function withEmbeddingInputs(...$inputs): self + { + foreach ($inputs as $input) { + $this->appendEmbeddingInput($input); + } + + return $this; + } + /** * Sets the model to use for generation. * @@ -667,7 +695,15 @@ private function isSupported(?CapabilityEnum $intendedCapability = null): bool } // Build requirements with the specified capability - $requirements = ModelRequirements::fromPromptData($intendedCapability, $this->messages, $this->modelConfig); + $messagesForRequirements = $this->messages; + if ($intendedCapability->isEmbeddingGeneration() && !empty($this->embeddingInputs)) { + $messagesForRequirements = $this->embeddingInputs; + } + $requirements = ModelRequirements::fromPromptData( + $intendedCapability, + $messagesForRequirements, + $this->modelConfig + ); // If the model has been set, check if it meets the requirements if ($this->model !== null) { @@ -859,12 +895,75 @@ public function generateResult(?CapabilityEnum $capability = null): GenerativeAi throw new RuntimeException('Output modality "video" is not yet supported.'); } + if ($capability->isEmbeddingGeneration()) { + throw new RuntimeException( + 'Embedding generation results must be retrieved using generateEmbeddingsResult().' + ); + } + // TODO: Add support for other capabilities when interfaces are available throw new RuntimeException( sprintf('Capability "%s" is not yet supported for generation.', $capability->value) ); } + /** + * Generates embeddings for the configured inputs. + * + * @since 0.2.0 + * + * @return EmbeddingResult The embedding result. + */ + public function generateEmbeddingsResult(): EmbeddingResult + { + $inputs = $this->resolveEmbeddingInputs(); + $model = $this->getConfiguredModel(CapabilityEnum::embeddingGeneration(), $inputs); + + if (!$model instanceof EmbeddingGenerationModelInterface) { + throw new RuntimeException( + sprintf('Model "%s" does not support embedding generation.', $model->metadata()->getId()) + ); + } + + return $model->generateEmbeddingsResult($inputs); + } + + /** + * Generates an embeddings operation for asynchronous processing. + * + * @since 0.2.0 + * + * @return EmbeddingOperation The embedding operation. + */ + public function generateEmbeddingsOperation(): EmbeddingOperation + { + $inputs = $this->resolveEmbeddingInputs(); + $model = $this->getConfiguredModel(CapabilityEnum::embeddingGeneration(), $inputs); + + if (!$model instanceof EmbeddingGenerationOperationModelInterface) { + throw new RuntimeException( + sprintf( + 'Model "%s" does not support embedding generation operations.', + $model->metadata()->getId() + ) + ); + } + + return $model->generateEmbeddingsOperation($inputs); + } + + /** + * Generates embeddings and returns their vector representations. + * + * @since 0.2.0 + * + * @return list> The embedding vectors. + */ + public function generateEmbeddings(): array + { + return $this->generateEmbeddingsResult()->toVectors(); + } + /** * Generates a text result from the prompt. * @@ -1094,6 +1193,32 @@ protected function appendPartToMessages(MessagePart $part): void $this->messages[] = new UserMessage([$part]); } + /** + * Normalizes embedding inputs into messages. + * + * @since 0.2.0 + * + * @param mixed $input The embedding input to append. + * @return void + */ + private function appendEmbeddingInput($input): void + { + if ( + is_array($input) + && !Message::isArrayShape($input) + && !MessagePart::isArrayShape($input) + && array_is_list($input) + ) { + foreach ($input as $nestedInput) { + $this->appendEmbeddingInput($nestedInput); + } + return; + } + + $message = $this->parseMessage($input, MessageRoleEnum::user()); + $this->embeddingInputs[] = $message; + } + /** * Gets the model to use for generation. * @@ -1103,12 +1228,14 @@ protected function appendPartToMessages(MessagePart $part): void * @since 0.1.0 * * @param CapabilityEnum $capability The capability the model will be using. + * @param list|null $messageContext Optional custom message context for requirements inference. * @return ModelInterface The model to use. * @throws InvalidArgumentException If no suitable model is found or set model doesn't meet requirements. */ - private function getConfiguredModel(CapabilityEnum $capability): ModelInterface + private function getConfiguredModel(CapabilityEnum $capability, ?array $messageContext = null): ModelInterface { - $requirements = ModelRequirements::fromPromptData($capability, $this->messages, $this->modelConfig); + $messages = $messageContext ?? $this->messages; + $requirements = ModelRequirements::fromPromptData($capability, $messages, $this->modelConfig); if ($this->model !== null) { // Explicit model was provided via usingModel(); just update config and bind dependencies. @@ -1398,6 +1525,23 @@ private function validateMessages(): void } } + /** + * Resolves embedding inputs, falling back to the current prompt messages if none were specified. + * + * @since 0.2.0 + * + * @return list The embedding input messages. + */ + private function resolveEmbeddingInputs(): array + { + if (!empty($this->embeddingInputs)) { + return $this->embeddingInputs; + } + + $this->validateMessages(); + return $this->messages; + } + /** * Checks if the value is a list of Message objects. * diff --git a/src/Embeddings/DTO/Embedding.php b/src/Embeddings/DTO/Embedding.php new file mode 100644 index 00000000..5b07748f --- /dev/null +++ b/src/Embeddings/DTO/Embedding.php @@ -0,0 +1,160 @@ +, + * dimension: int + * } + * + * @extends AbstractDataTransferObject + */ +class Embedding extends AbstractDataTransferObject +{ + public const KEY_VECTOR = 'vector'; + public const KEY_DIMENSION = 'dimension'; + + /** + * @var list The embedding vector values. + */ + private array $vector; + + /** + * @var int The dimensionality of the embedding vector. + */ + private int $dimension; + + /** + * Constructor. + * + * @since 0.2.0 + * + * @param list $vector The embedding vector values. + * @param int $dimension The dimensionality of the vector. + * + * @throws InvalidArgumentException If vector validation fails. + */ + public function __construct(array $vector, int $dimension) + { + if (!array_is_list($vector)) { + throw new InvalidArgumentException('Embedding vector must be a list array.'); + } + + $normalizedVector = []; + foreach ($vector as $value) { + if (!is_float($value) && !is_int($value)) { + throw new InvalidArgumentException('Embedding vector values must be numeric.'); + } + $normalizedVector[] = (float) $value; + } + + if ($dimension <= 0) { + throw new InvalidArgumentException('Embedding dimension must be greater than zero.'); + } + + if (count($normalizedVector) !== $dimension) { + throw new InvalidArgumentException( + sprintf( + 'Embedding dimension mismatch: expected %d values, got %d.', + $dimension, + count($normalizedVector) + ) + ); + } + + $this->vector = $normalizedVector; + $this->dimension = $dimension; + } + + /** + * Gets the embedding vector values. + * + * @since 0.2.0 + * + * @return list The embedding vector. + */ + public function getVector(): array + { + return $this->vector; + } + + /** + * Gets the dimensionality of the embedding vector. + * + * @since 0.2.0 + * + * @return int The embedding dimension. + */ + public function getDimension(): int + { + return $this->dimension; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public static function getJsonSchema(): array + { + return [ + 'type' => 'object', + 'properties' => [ + self::KEY_VECTOR => [ + 'type' => 'array', + 'items' => [ + 'type' => 'number', + ], + 'description' => 'The embedding vector values.', + ], + self::KEY_DIMENSION => [ + 'type' => 'integer', + 'minimum' => 1, + 'description' => 'The dimensionality of the embedding vector.', + ], + ], + 'required' => [self::KEY_VECTOR, self::KEY_DIMENSION], + 'additionalProperties' => false, + ]; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + * + * @return EmbeddingArrayShape + */ + public function toArray(): array + { + return [ + self::KEY_VECTOR => $this->vector, + self::KEY_DIMENSION => $this->dimension, + ]; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public static function fromArray(array $array): self + { + static::validateFromArrayData($array, [self::KEY_VECTOR, self::KEY_DIMENSION]); + + return new self( + $array[self::KEY_VECTOR], + $array[self::KEY_DIMENSION] + ); + } +} diff --git a/src/Operations/DTO/EmbeddingOperation.php b/src/Operations/DTO/EmbeddingOperation.php new file mode 100644 index 00000000..9911312d --- /dev/null +++ b/src/Operations/DTO/EmbeddingOperation.php @@ -0,0 +1,185 @@ + + */ +class EmbeddingOperation extends AbstractDataTransferObject implements OperationInterface +{ + public const KEY_ID = 'id'; + public const KEY_STATE = 'state'; + public const KEY_RESULT = 'result'; + + /** + * @var string Unique identifier for this operation. + */ + private string $id; + + /** + * @var OperationStateEnum The current state of the operation. + */ + private OperationStateEnum $state; + + /** + * @var EmbeddingResult|null The result once the operation completes. + */ + private ?EmbeddingResult $result; + + /** + * Constructor. + * + * @since 0.2.0 + * + * @param string $id Unique identifier for this operation. + * @param OperationStateEnum $state The current state of the operation. + * @param EmbeddingResult|null $result The result once the operation completes. + */ + public function __construct(string $id, OperationStateEnum $state, ?EmbeddingResult $result = null) + { + $this->id = $id; + $this->state = $state; + $this->result = $result; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function getId(): string + { + return $this->id; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function getState(): OperationStateEnum + { + return $this->state; + } + + /** + * Gets the operation result. + * + * @since 0.2.0 + * + * @return EmbeddingResult|null The embedding result or null if not yet complete. + */ + public function getResult(): ?EmbeddingResult + { + return $this->result; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public static function getJsonSchema(): array + { + return [ + 'oneOf' => [ + [ + 'type' => 'object', + 'properties' => [ + self::KEY_ID => [ + 'type' => 'string', + 'description' => 'Unique identifier for this operation.', + ], + self::KEY_STATE => [ + 'type' => 'string', + 'const' => OperationStateEnum::succeeded()->value, + ], + self::KEY_RESULT => EmbeddingResult::getJsonSchema(), + ], + 'required' => [self::KEY_ID, self::KEY_STATE, self::KEY_RESULT], + 'additionalProperties' => false, + ], + [ + 'type' => 'object', + 'properties' => [ + self::KEY_ID => [ + 'type' => 'string', + 'description' => 'Unique identifier for this operation.', + ], + self::KEY_STATE => [ + 'type' => 'string', + 'enum' => [ + OperationStateEnum::starting()->value, + OperationStateEnum::processing()->value, + OperationStateEnum::failed()->value, + OperationStateEnum::canceled()->value, + ], + 'description' => 'The current state of the operation.', + ], + ], + 'required' => [self::KEY_ID, self::KEY_STATE], + 'additionalProperties' => false, + ], + ], + ]; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + * + * @return EmbeddingOperationArrayShape + */ + public function toArray(): array + { + $data = [ + self::KEY_ID => $this->id, + self::KEY_STATE => $this->state->value, + ]; + + if ($this->result !== null) { + $data[self::KEY_RESULT] = $this->result->toArray(); + } + + return $data; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public static function fromArray(array $array): self + { + static::validateFromArrayData($array, [self::KEY_ID, self::KEY_STATE]); + + $state = OperationStateEnum::from($array[self::KEY_STATE]); + + if ($state->isSucceeded()) { + static::validateFromArrayData($array, [self::KEY_RESULT]); + } + + $result = null; + if (isset($array[self::KEY_RESULT])) { + $result = EmbeddingResult::fromArray($array[self::KEY_RESULT]); + } + + return new self($array[self::KEY_ID], $state, $result); + } +} diff --git a/src/ProviderImplementations/OpenAi/OpenAiEmbeddingModel.php b/src/ProviderImplementations/OpenAi/OpenAiEmbeddingModel.php new file mode 100644 index 00000000..6c1ca5c9 --- /dev/null +++ b/src/ProviderImplementations/OpenAi/OpenAiEmbeddingModel.php @@ -0,0 +1,185 @@ +, index?: int} + * @phpstan-type UsageData array{prompt_tokens?: int, total_tokens?: int} + * @phpstan-type ResponseData array{id?: string, data?: list, usage?: UsageData, model?: string} + */ +class OpenAiEmbeddingModel extends AbstractApiBasedModel implements EmbeddingGenerationModelInterface +{ + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function generateEmbeddingsResult(array $input): EmbeddingResult + { + $httpTransporter = $this->getHttpTransporter(); + + $params = $this->prepareGenerateEmbeddingsParams($input); + + $request = new Request( + HttpMethodEnum::POST(), + OpenAiProvider::url('embeddings'), + ['Content-Type' => 'application/json'], + $params + ); + + $request = $this->getRequestAuthentication()->authenticateRequest($request); + + $response = $httpTransporter->send($request); + ResponseUtil::throwIfNotSuccessful($response); + + return $this->parseResponseToEmbeddingResult($response); + } + + /** + * Prepares the request payload for the embeddings endpoint. + * + * @param list $input The embedding inputs. + * @return array + */ + private function prepareGenerateEmbeddingsParams(array $input): array + { + if (!array_is_list($input)) { + throw new InvalidArgumentException('Embedding input must be provided as a list of messages.'); + } + + $preparedInput = array_map( + fn(Message $message): string => $this->messageToText($message), + $input + ); + + $params = [ + 'model' => $this->metadata()->getId(), + 'input' => $preparedInput, + ]; + + $customOptions = $this->getConfig()->getCustomOptions(); + foreach ($customOptions as $key => $value) { + if (isset($params[$key])) { + throw new InvalidArgumentException( + sprintf('The custom option "%s" conflicts with an existing parameter.', $key) + ); + } + $params[$key] = $value; + } + + return $params; + } + + /** + * Converts a message to a text payload accepted by the embeddings API. + * + * @param Message $message The message to convert. + * @return string + */ + private function messageToText(Message $message): string + { + $parts = []; + /** @var MessagePart $part */ + foreach ($message->getParts() as $part) { + $text = $part->getText(); + if ($text !== null) { + $parts[] = $text; + } + } + + if (empty($parts)) { + throw new InvalidArgumentException( + 'Embedding input messages must contain at least one text part.' + ); + } + + return implode("\n\n", $parts); + } + + /** + * Parses the embeddings response into an EmbeddingResult. + * + * @param Response $response The API response. + * @return EmbeddingResult + */ + private function parseResponseToEmbeddingResult(Response $response): EmbeddingResult + { + /** @var ResponseData $responseData */ + $responseData = $response->getData(); + if ($responseData === null) { + throw ResponseException::fromInvalidData( + $this->providerMetadata()->getName(), + 'body', + 'Response body must contain JSON data.' + ); + } + + if (!isset($responseData['data']) || !is_array($responseData['data']) || empty($responseData['data'])) { + throw ResponseException::fromMissingData($this->providerMetadata()->getName(), 'data'); + } + + $embeddings = []; + foreach ($responseData['data'] as $index => $embeddingData) { + if ( + !is_array($embeddingData) || + !isset($embeddingData['embedding']) || + !is_array($embeddingData['embedding']) + ) { + throw ResponseException::fromInvalidData( + $this->providerMetadata()->getName(), + sprintf('data[%d].embedding', $index), + 'The value must be an array of floats.' + ); + } + + $embeddings[] = new Embedding( + $embeddingData['embedding'], + count($embeddingData['embedding']) + ); + } + + $usageData = $responseData['usage'] ?? []; + $promptTokens = isset($usageData['prompt_tokens']) ? (int) $usageData['prompt_tokens'] : 0; + $totalTokens = isset($usageData['total_tokens']) ? (int) $usageData['total_tokens'] : $promptTokens; + + $tokenUsage = new TokenUsage($promptTokens, 0, $totalTokens); + + $resultId = isset($responseData['id']) && is_string($responseData['id']) + ? $responseData['id'] + : sprintf('%s-embeddings', $this->metadata()->getId()); + + $additionalData = []; + if (isset($responseData['model']) && is_string($responseData['model'])) { + $additionalData['model'] = $responseData['model']; + } + + return new EmbeddingResult( + $resultId, + $embeddings, + $tokenUsage, + $this->providerMetadata(), + $this->metadata(), + $additionalData + ); + } +} diff --git a/src/ProviderImplementations/OpenAi/OpenAiModelMetadataDirectory.php b/src/ProviderImplementations/OpenAi/OpenAiModelMetadataDirectory.php index fc522260..02686eb4 100644 --- a/src/ProviderImplementations/OpenAi/OpenAiModelMetadataDirectory.php +++ b/src/ProviderImplementations/OpenAi/OpenAiModelMetadataDirectory.php @@ -164,6 +164,13 @@ protected function parseResponseToModelMetadataList(Response $response): array new SupportedOption(OptionEnum::outputSpeechVoice()), new SupportedOption(OptionEnum::customOptions()), ]; + $embeddingCapabilities = [ + CapabilityEnum::embeddingGeneration(), + ]; + $embeddingOptions = [ + new SupportedOption(OptionEnum::inputModalities(), [[ModalityEnum::text()]]), + new SupportedOption(OptionEnum::customOptions()), + ]; $modelsData = (array) $responseData['data']; @@ -179,9 +186,15 @@ static function (array $modelData) use ( $dalleImageOptions, $gptImageOptions, $ttsCapabilities, - $ttsOptions + $ttsOptions, + $embeddingCapabilities, + $embeddingOptions ): ModelMetadata { $modelId = $modelData['id']; + /** @var list $modelCaps */ + $modelCaps = []; + /** @var list $modelOptions */ + $modelOptions = []; if ( str_starts_with($modelId, 'dall-e-') || str_starts_with($modelId, 'gpt-image-') @@ -198,6 +211,9 @@ static function (array $modelData) use ( ) { $modelCaps = $ttsCapabilities; $modelOptions = $ttsOptions; + } elseif (str_contains($modelId, 'embedding')) { + $modelCaps = $embeddingCapabilities; + $modelOptions = $embeddingOptions; } elseif ( (str_starts_with($modelId, 'gpt-') || str_starts_with($modelId, 'o1-')) && !str_contains($modelId, '-instruct') diff --git a/src/ProviderImplementations/OpenAi/OpenAiProvider.php b/src/ProviderImplementations/OpenAi/OpenAiProvider.php index 6857feea..9514e0cb 100644 --- a/src/ProviderImplementations/OpenAi/OpenAiProvider.php +++ b/src/ProviderImplementations/OpenAi/OpenAiProvider.php @@ -54,6 +54,9 @@ protected static function createModel( 'OpenAI text to speech conversion model class is not yet implemented.' ); } + if ($capability->isEmbeddingGeneration()) { + return new OpenAiEmbeddingModel($modelMetadata, $providerMetadata); + } } throw new RuntimeException( diff --git a/src/Providers/Models/Contracts/WithEmbeddingOperationsInterface.php b/src/Providers/Models/Contracts/WithEmbeddingOperationsInterface.php new file mode 100644 index 00000000..6572515e --- /dev/null +++ b/src/Providers/Models/Contracts/WithEmbeddingOperationsInterface.php @@ -0,0 +1,25 @@ + 1) { + if (!$capability->isEmbeddingGeneration() && count($messages) > 1) { $capabilities[] = CapabilityEnum::chatHistory(); } diff --git a/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php new file mode 100644 index 00000000..18acd0f9 --- /dev/null +++ b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationModelInterface.php @@ -0,0 +1,26 @@ + $input The input documents/messages to embed. + * @return EmbeddingResult The generated embeddings result. + */ + public function generateEmbeddingsResult(array $input): EmbeddingResult; +} diff --git a/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationOperationModelInterface.php b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationOperationModelInterface.php new file mode 100644 index 00000000..a8ae3834 --- /dev/null +++ b/src/Providers/Models/EmbeddingGeneration/Contracts/EmbeddingGenerationOperationModelInterface.php @@ -0,0 +1,26 @@ + $input The input documents/messages to embed. + * @return EmbeddingOperation The created operation. + */ + public function generateEmbeddingsOperation(array $input): EmbeddingOperation; +} diff --git a/src/Results/DTO/EmbeddingResult.php b/src/Results/DTO/EmbeddingResult.php new file mode 100644 index 00000000..3ae3961c --- /dev/null +++ b/src/Results/DTO/EmbeddingResult.php @@ -0,0 +1,318 @@ +, + * tokenUsage: TokenUsageArrayShape, + * providerMetadata: ProviderMetadataArrayShape, + * modelMetadata: ModelMetadataArrayShape, + * additionalData?: array + * } + * + * @extends AbstractDataTransferObject + */ +class EmbeddingResult extends AbstractDataTransferObject implements ResultInterface +{ + public const KEY_ID = 'id'; + public const KEY_EMBEDDINGS = 'embeddings'; + public const KEY_TOKEN_USAGE = 'tokenUsage'; + public const KEY_PROVIDER_METADATA = 'providerMetadata'; + public const KEY_MODEL_METADATA = 'modelMetadata'; + public const KEY_ADDITIONAL_DATA = 'additionalData'; + + /** + * @var string Unique identifier for this result. + */ + private string $id; + + /** + * @var list Embeddings returned by the provider. + */ + private array $embeddings; + + /** + * @var TokenUsage Token usage statistics. + */ + private TokenUsage $tokenUsage; + + /** + * @var ProviderMetadata Provider metadata. + */ + private ProviderMetadata $providerMetadata; + + /** + * @var ModelMetadata Model metadata. + */ + private ModelMetadata $modelMetadata; + + /** + * @var array Provider-specific metadata. + */ + private array $additionalData; + + /** + * Constructor. + * + * @since 0.2.0 + * + * @param string $id Unique identifier for this result. + * @param list $embeddings Embeddings returned by the provider. + * @param TokenUsage $tokenUsage Token usage statistics. + * @param ProviderMetadata $providerMetadata Provider metadata. + * @param ModelMetadata $modelMetadata Model metadata. + * @param array $additionalData Provider-specific metadata. + */ + public function __construct( + string $id, + array $embeddings, + TokenUsage $tokenUsage, + ProviderMetadata $providerMetadata, + ModelMetadata $modelMetadata, + array $additionalData = [] + ) { + if (empty($embeddings)) { + throw new InvalidArgumentException('At least one embedding must be provided.'); + } + + if (!array_is_list($embeddings)) { + throw new InvalidArgumentException('Embeddings must be provided as a list array.'); + } + + $this->id = $id; + $this->embeddings = $embeddings; + $this->tokenUsage = $tokenUsage; + $this->providerMetadata = $providerMetadata; + $this->modelMetadata = $modelMetadata; + $this->additionalData = $additionalData; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function getId(): string + { + return $this->id; + } + + /** + * Gets the embeddings. + * + * @since 0.2.0 + * + * @return list The embeddings. + */ + public function getEmbeddings(): array + { + return $this->embeddings; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function getTokenUsage(): TokenUsage + { + return $this->tokenUsage; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function getProviderMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function getModelMetadata(): ModelMetadata + { + return $this->modelMetadata; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public function getAdditionalData(): array + { + return $this->additionalData; + } + + /** + * Gets the number of embeddings. + * + * @since 0.2.0 + * + * @return int The number of embeddings. + */ + public function getEmbeddingCount(): int + { + return count($this->embeddings); + } + + /** + * Checks if multiple embeddings were returned. + * + * @since 0.2.0 + * + * @return bool True if more than one embedding is available. + */ + public function hasMultipleEmbeddings(): bool + { + return $this->getEmbeddingCount() > 1; + } + + /** + * Returns the first embedding vector. + * + * @since 0.2.0 + * + * @return list The first embedding vector. + */ + public function toVector(): array + { + return $this->embeddings[0]->getVector(); + } + + /** + * Returns all embedding vectors. + * + * @since 0.2.0 + * + * @return list> All embedding vectors. + */ + public function toVectors(): array + { + return array_map( + static fn(Embedding $embedding): array => $embedding->getVector(), + $this->embeddings + ); + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public static function getJsonSchema(): array + { + return [ + 'type' => 'object', + 'properties' => [ + self::KEY_ID => [ + 'type' => 'string', + 'description' => 'Unique identifier for this result.', + ], + self::KEY_EMBEDDINGS => [ + 'type' => 'array', + 'items' => Embedding::getJsonSchema(), + 'minItems' => 1, + 'description' => 'Embeddings returned by the provider.', + ], + self::KEY_TOKEN_USAGE => TokenUsage::getJsonSchema(), + self::KEY_PROVIDER_METADATA => ProviderMetadata::getJsonSchema(), + self::KEY_MODEL_METADATA => ModelMetadata::getJsonSchema(), + self::KEY_ADDITIONAL_DATA => [ + 'type' => 'object', + 'additionalProperties' => true, + 'description' => 'Provider-specific metadata.', + ], + ], + 'required' => [ + self::KEY_ID, + self::KEY_EMBEDDINGS, + self::KEY_TOKEN_USAGE, + self::KEY_PROVIDER_METADATA, + self::KEY_MODEL_METADATA, + ], + 'additionalProperties' => false, + ]; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + * + * @return EmbeddingResultArrayShape + */ + public function toArray(): array + { + return [ + self::KEY_ID => $this->id, + self::KEY_EMBEDDINGS => array_map( + static fn(Embedding $embedding): array => $embedding->toArray(), + $this->embeddings + ), + self::KEY_TOKEN_USAGE => $this->tokenUsage->toArray(), + self::KEY_PROVIDER_METADATA => $this->providerMetadata->toArray(), + self::KEY_MODEL_METADATA => $this->modelMetadata->toArray(), + self::KEY_ADDITIONAL_DATA => $this->additionalData, + ]; + } + + /** + * {@inheritDoc} + * + * @since 0.2.0 + */ + public static function fromArray(array $array): self + { + static::validateFromArrayData($array, [ + self::KEY_ID, + self::KEY_EMBEDDINGS, + self::KEY_TOKEN_USAGE, + self::KEY_PROVIDER_METADATA, + self::KEY_MODEL_METADATA, + ]); + + $embeddings = array_values( + array_map( + static fn(array $embedding): Embedding => Embedding::fromArray($embedding), + $array[self::KEY_EMBEDDINGS] + ) + ); + + $additionalData = $array[self::KEY_ADDITIONAL_DATA] ?? []; + + return new self( + $array[self::KEY_ID], + $embeddings, + TokenUsage::fromArray($array[self::KEY_TOKEN_USAGE]), + ProviderMetadata::fromArray($array[self::KEY_PROVIDER_METADATA]), + ModelMetadata::fromArray($array[self::KEY_MODEL_METADATA]), + $additionalData + ); + } +} diff --git a/tests/mocks/MockProvider.php b/tests/mocks/MockProvider.php index 8414046d..e0235da2 100644 --- a/tests/mocks/MockProvider.php +++ b/tests/mocks/MockProvider.php @@ -78,7 +78,13 @@ public static function modelMetadataDirectory(): ModelMetadataDirectoryInterface 'Mock Text Model', [CapabilityEnum::textGeneration()], [] - ) + ), + 'mock-embedding-model' => new ModelMetadata( + 'mock-embedding-model', + 'Mock Embedding Model', + [CapabilityEnum::embeddingGeneration()], + [] + ), ]; static::$modelMetadataDirectory = new MockModelMetadataDirectory($mockModels); diff --git a/tests/traits/MockModelCreationTrait.php b/tests/traits/MockModelCreationTrait.php index ecf033c8..171d8a3e 100644 --- a/tests/traits/MockModelCreationTrait.php +++ b/tests/traits/MockModelCreationTrait.php @@ -5,18 +5,24 @@ namespace WordPress\AiClient\Tests\traits; use Generator; +use WordPress\AiClient\Embeddings\DTO\Embedding; use WordPress\AiClient\Messages\DTO\MessagePart; use WordPress\AiClient\Messages\DTO\ModelMessage; +use WordPress\AiClient\Operations\DTO\EmbeddingOperation; +use WordPress\AiClient\Operations\Enums\OperationStateEnum; use WordPress\AiClient\Providers\DTO\ProviderMetadata; use WordPress\AiClient\Providers\Enums\ProviderTypeEnum; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; +use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface; +use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationOperationModelInterface; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Providers\ProviderRegistry; use WordPress\AiClient\Results\DTO\Candidate; +use WordPress\AiClient\Results\DTO\EmbeddingResult; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; @@ -114,6 +120,65 @@ protected function createTestImageModelMetadata( ); } + /** + * Creates a test model metadata instance for embedding generation. + * + * @param string $id Optional model ID. + * @param string $name Optional model name. + * @return ModelMetadata + */ + protected function createTestEmbeddingModelMetadata( + string $id = 'test-embedding-model', + string $name = 'Test Embedding Model' + ): ModelMetadata { + return new ModelMetadata( + $id, + $name, + [CapabilityEnum::embeddingGeneration()], + [] + ); + } + + /** + * Creates a test embedding result. + * + * @param list> $vectors Optional embedding vectors. + * @return EmbeddingResult + */ + protected function createTestEmbeddingResult(array $vectors = [[0.1, 0.2], [0.3, 0.4]]): EmbeddingResult + { + $embeddings = array_map( + static fn(array $vector): Embedding => new Embedding($vector, count($vector)), + $vectors + ); + + $tokenUsage = new TokenUsage(10, 0, 10); + $providerMetadata = new ProviderMetadata('mock', 'Mock Provider', ProviderTypeEnum::cloud()); + $modelMetadata = $this->createTestEmbeddingModelMetadata(); + + return new EmbeddingResult( + 'test-embedding-result', + $embeddings, + $tokenUsage, + $providerMetadata, + $modelMetadata + ); + } + + /** + * Creates a test embedding operation. + * + * @return EmbeddingOperation + */ + protected function createTestEmbeddingOperation(): EmbeddingOperation + { + return new EmbeddingOperation( + 'test-embedding-operation', + OperationStateEnum::succeeded(), + $this->createTestEmbeddingResult() + ); + } + /** * Creates a mock text generation model using anonymous class. * @@ -253,6 +318,140 @@ public function generateImageResult(array $prompt): GenerativeAiResult }; } + /** + * Creates a mock embedding generation model using an anonymous class. + * + * @param EmbeddingResult $result The result to return from generation. + * @param ModelMetadata|null $metadata Optional metadata (uses default if not provided). + * @return ModelInterface&EmbeddingGenerationModelInterface The mock model. + */ + protected function createMockEmbeddingGenerationModel( + EmbeddingResult $result, + ?ModelMetadata $metadata = null + ): ModelInterface { + $metadata = $metadata ?? $this->createTestEmbeddingModelMetadata(); + + $providerMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + return new class ( + $metadata, + $providerMetadata, + $result + ) implements ModelInterface, EmbeddingGenerationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + private EmbeddingResult $result; + private ModelConfig $config; + + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + EmbeddingResult $result + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->result = $result; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateEmbeddingsResult(array $input): EmbeddingResult + { + return $this->result; + } + }; + } + + /** + * Creates a mock embedding generation operation model using an anonymous class. + * + * @param EmbeddingOperation $operation The operation to return. + * @param ModelMetadata|null $metadata Optional metadata (uses default if not provided). + * @return ModelInterface&EmbeddingGenerationOperationModelInterface The mock model. + */ + protected function createMockEmbeddingOperationModel( + EmbeddingOperation $operation, + ?ModelMetadata $metadata = null + ): ModelInterface { + $metadata = $metadata ?? $this->createTestEmbeddingModelMetadata(); + + $providerMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + return new class ( + $metadata, + $providerMetadata, + $operation + ) implements ModelInterface, EmbeddingGenerationOperationModelInterface { + private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; + private EmbeddingOperation $operation; + private ModelConfig $config; + + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + EmbeddingOperation $operation + ) { + $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->operation = $operation; + $this->config = new ModelConfig(); + } + + public function metadata(): ModelMetadata + { + return $this->metadata; + } + + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + public function getConfig(): ModelConfig + { + return $this->config; + } + + public function generateEmbeddingsOperation(array $input): EmbeddingOperation + { + return $this->operation; + } + }; + } + /** * Creates a mock model that doesn't implement any generation interfaces. * diff --git a/tests/unit/AiClientTest.php b/tests/unit/AiClientTest.php index 8aef3f0b..f8451335 100644 --- a/tests/unit/AiClientTest.php +++ b/tests/unit/AiClientTest.php @@ -742,4 +742,32 @@ public function testGetConfiguredPromptBuilderHelperIntegration(): void $this->expectExceptionMessageMatches('/No models found that support/'); AiClient::generateResult($prompt, null, $this->createMockEmptyRegistry()); } + + /** + * Tests generateEmbeddingsResult with string input and provided model. + */ + public function testGenerateEmbeddingsResultWithString(): void + { + $embeddingResult = $this->createTestEmbeddingResult(); + $mockModel = $this->createMockEmbeddingGenerationModel($embeddingResult); + $registry = $this->createRegistryWithMockProvider(); + + $result = AiClient::generateEmbeddingsResult('Embed this text', $mockModel, $registry); + + $this->assertSame($embeddingResult, $result); + } + + /** + * Tests generateEmbeddingsOperation with provided model. + */ + public function testGenerateEmbeddingsOperation(): void + { + $operation = $this->createTestEmbeddingOperation(); + $mockModel = $this->createMockEmbeddingOperationModel($operation); + $registry = $this->createRegistryWithMockProvider(); + + $result = AiClient::generateEmbeddingsOperation('Operation doc', $mockModel, $registry); + + $this->assertSame($operation, $result); + } } diff --git a/tests/unit/Builders/PromptBuilderTest.php b/tests/unit/Builders/PromptBuilderTest.php index de23054a..2567a710 100644 --- a/tests/unit/Builders/PromptBuilderTest.php +++ b/tests/unit/Builders/PromptBuilderTest.php @@ -3390,4 +3390,61 @@ public function testMethodChainingWithNewMethods(): void $this->assertTrue($config->getLogprobs()); $this->assertEquals(3, $config->getTopLogprobs()); } + + /** + * Tests generating embeddings result with explicit embedding inputs. + */ + public function testGenerateEmbeddingsResultWithExplicitInputs(): void + { + $embeddingResult = $this->createTestEmbeddingResult(); + $mockModel = $this->createMockEmbeddingGenerationModel($embeddingResult); + + $builder = new PromptBuilder($this->registry); + $builder->usingModel($mockModel)->withEmbeddingInputs('Doc 1', 'Doc 2'); + + $this->assertSame($embeddingResult, $builder->generateEmbeddingsResult()); + } + + /** + * Tests embedding generation falls back to existing prompt messages. + */ + public function testGenerateEmbeddingsResultFallsBackToPromptMessages(): void + { + $embeddingResult = $this->createTestEmbeddingResult([[0.5, 0.6]]); + $mockModel = $this->createMockEmbeddingGenerationModel($embeddingResult); + + $builder = new PromptBuilder($this->registry, 'Embed this text'); + $builder->usingModel($mockModel); + + $this->assertSame($embeddingResult, $builder->generateEmbeddingsResult()); + } + + /** + * Tests generating embedding operations via the builder. + */ + public function testGenerateEmbeddingsOperation(): void + { + $operation = $this->createTestEmbeddingOperation(); + $mockModel = $this->createMockEmbeddingOperationModel($operation); + + $builder = new PromptBuilder($this->registry); + $builder->usingModel($mockModel)->withEmbeddingInputs('Doc for operation'); + + $this->assertSame($operation, $builder->generateEmbeddingsOperation()); + } + + /** + * Tests generateEmbeddings convenience helper. + */ + public function testGenerateEmbeddingsHelperReturnsVectors(): void + { + $vectors = [[0.9, 0.1]]; + $embeddingResult = $this->createTestEmbeddingResult($vectors); + $mockModel = $this->createMockEmbeddingGenerationModel($embeddingResult); + + $builder = new PromptBuilder($this->registry); + $builder->usingModel($mockModel)->withEmbeddingInputs('Only doc'); + + $this->assertSame($vectors, $builder->generateEmbeddings()); + } } diff --git a/tests/unit/Embeddings/DTO/EmbeddingTest.php b/tests/unit/Embeddings/DTO/EmbeddingTest.php new file mode 100644 index 00000000..c3dfa1c5 --- /dev/null +++ b/tests/unit/Embeddings/DTO/EmbeddingTest.php @@ -0,0 +1,103 @@ +assertSame($vector, $embedding->getVector()); + $this->assertSame(3, $embedding->getDimension()); + } + + /** + * Tests constructor normalizes integer values to floats. + */ + public function testConstructorNormalizesNumericValues(): void + { + $embedding = new Embedding([1, 2, 3], 3); + + $this->assertSame([1.0, 2.0, 3.0], $embedding->getVector()); + } + + /** + * Tests constructor validates dimension length. + */ + public function testConstructorValidatesDimension(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Embedding dimension mismatch'); + + new Embedding([0.1, 0.2], 3); + } + + /** + * Tests constructor rejects non-numeric values. + */ + public function testConstructorRejectsNonNumericValues(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Embedding vector values must be numeric.'); + + new Embedding([0.1, 'invalid'], 2); + } + + /** + * Tests array transformation produces the expected payload. + */ + public function testArrayTransformation(): void + { + $embedding = new Embedding([0.5, 0.25], 2); + + $this->assertSame( + [ + Embedding::KEY_VECTOR => [0.5, 0.25], + Embedding::KEY_DIMENSION => 2, + ], + $embedding->toArray() + ); + } + + /** + * Tests fromArray creates a matching embedding instance. + */ + public function testFromArrayCreatesEmbedding(): void + { + $data = [ + Embedding::KEY_VECTOR => [0.1, 0.2, 0.3], + Embedding::KEY_DIMENSION => 3, + ]; + + $embedding = Embedding::fromArray($data); + + $this->assertSame($data[Embedding::KEY_VECTOR], $embedding->getVector()); + $this->assertSame($data[Embedding::KEY_DIMENSION], $embedding->getDimension()); + } + + /** + * Tests JSON schema definition. + */ + public function testJsonSchema(): void + { + $schema = Embedding::getJsonSchema(); + + $this->assertSame('object', $schema['type']); + $this->assertArrayHasKey(Embedding::KEY_VECTOR, $schema['properties']); + $this->assertArrayHasKey(Embedding::KEY_DIMENSION, $schema['properties']); + } +} diff --git a/tests/unit/Operations/DTO/EmbeddingOperationTest.php b/tests/unit/Operations/DTO/EmbeddingOperationTest.php new file mode 100644 index 00000000..318efc37 --- /dev/null +++ b/tests/unit/Operations/DTO/EmbeddingOperationTest.php @@ -0,0 +1,107 @@ +assertSame('op-1', $operation->getId()); + $this->assertTrue($operation->getState()->isProcessing()); + $this->assertNull($operation->getResult()); + } + + /** + * Tests creating an operation with a completed result. + */ + public function testCreateWithResult(): void + { + $result = $this->createEmbeddingResult(); + $operation = new EmbeddingOperation('op-2', OperationStateEnum::succeeded(), $result); + + $this->assertSame('op-2', $operation->getId()); + $this->assertTrue($operation->getState()->isSucceeded()); + $this->assertSame($result, $operation->getResult()); + } + + /** + * Tests operation implements the interface. + */ + public function testImplementsInterface(): void + { + $operation = new EmbeddingOperation('op-3', OperationStateEnum::starting()); + + $this->assertInstanceOf(OperationInterface::class, $operation); + } + + /** + * Tests array transformation round-trip. + */ + public function testArrayTransformation(): void + { + $operation = new EmbeddingOperation( + 'op-4', + OperationStateEnum::succeeded(), + $this->createEmbeddingResult() + ); + + $array = $operation->toArray(); + $rehydrated = EmbeddingOperation::fromArray($array); + + $this->assertSame($operation->getId(), $rehydrated->getId()); + $this->assertTrue($rehydrated->getState()->isSucceeded()); + $this->assertSame( + $operation->getResult()->toVectors(), + $rehydrated->getResult()->toVectors() + ); + } + + /** + * Tests JSON schema definition. + */ + public function testJsonSchema(): void + { + $schema = EmbeddingOperation::getJsonSchema(); + + $this->assertArrayHasKey('oneOf', $schema); + $this->assertCount(2, $schema['oneOf']); + } +} diff --git a/tests/unit/ProviderImplementations/OpenAi/OpenAiEmbeddingModelTest.php b/tests/unit/ProviderImplementations/OpenAi/OpenAiEmbeddingModelTest.php new file mode 100644 index 00000000..4dfab8ba --- /dev/null +++ b/tests/unit/ProviderImplementations/OpenAi/OpenAiEmbeddingModelTest.php @@ -0,0 +1,77 @@ +setHttpTransporter($transporter); + $model->setRequestAuthentication(new MockRequestAuthentication('test-token')); + + $responseBody = json_encode([ + 'id' => 'embed-123', + 'model' => 'text-embedding-3-small', + 'data' => [ + ['embedding' => [0.1, 0.2], 'index' => 0], + ['embedding' => [0.3, 0.4], 'index' => 1], + ], + 'usage' => [ + 'prompt_tokens' => 20, + 'total_tokens' => 20, + ], + ]); + $transporter->setResponseToReturn(new Response(200, [], $responseBody)); + + $messages = [ + new UserMessage([new MessagePart('First document')]), + new UserMessage([new MessagePart('Second document')]), + ]; + + $result = $model->generateEmbeddingsResult($messages); + + $this->assertSame([[0.1, 0.2], [0.3, 0.4]], $result->toVectors()); + $this->assertEquals('embed-123', $result->getId()); + $this->assertEquals(20, $result->getTokenUsage()->getPromptTokens()); + + $request = $transporter->getLastRequest(); + $this->assertNotNull($request); + $this->assertStringEndsWith('/embeddings', $request->getUri()); + + $payload = $request->getData(); + $this->assertIsArray($payload); + $this->assertSame('text-embedding-3-small', $payload['model']); + $this->assertSame(['First document', 'Second document'], $payload['input']); + } +} diff --git a/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php b/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php index 1cb7897e..221a1a33 100644 --- a/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php +++ b/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php @@ -632,4 +632,27 @@ public function testFromPromptDataWithModelConfigOptions(): void $this->assertTrue($hasMaxTokens, 'Max tokens option should be present'); $this->assertTrue($hasTopP, 'Top P option should be present'); } + + /** + * Tests fromPromptData does not mark embeddings as requiring chat history. + * + * @return void + */ + public function testFromPromptDataForEmbeddings(): void + { + $messages = [ + new UserMessage([new MessagePart('First doc')]), + new UserMessage([new MessagePart('Second doc')]), + ]; + + $requirements = ModelRequirements::fromPromptData( + CapabilityEnum::embeddingGeneration(), + $messages, + new ModelConfig() + ); + + $capabilities = $requirements->getRequiredCapabilities(); + $this->assertCount(1, $capabilities); + $this->assertTrue($capabilities[0]->isEmbeddingGeneration()); + } } diff --git a/tests/unit/Providers/ProviderRegistryTest.php b/tests/unit/Providers/ProviderRegistryTest.php index 6c08fbae..5a038ba5 100644 --- a/tests/unit/Providers/ProviderRegistryTest.php +++ b/tests/unit/Providers/ProviderRegistryTest.php @@ -196,6 +196,23 @@ public function testFindProviderModelsMetadataForSupportWithRegisteredProvider() $this->assertCount(1, $results); } + /** + * Tests discovery for embedding-capable models. + * + * @return void + */ + public function testFindModelsMetadataForEmbeddingSupport(): void + { + $this->registry->registerProvider(MockProvider::class); + + $requirements = new ModelRequirements([CapabilityEnum::embeddingGeneration()], []); + $results = $this->registry->findModelsMetadataForSupport($requirements); + + $this->assertNotEmpty($results); + $embeddingModels = $results[0]->getModels(); + $this->assertSame('mock-embedding-model', $embeddingModels[0]->getId()); + } + /** * Tests findProviderModelsMetadataForSupport with unregistered provider. * diff --git a/tests/unit/Results/DTO/EmbeddingResultTest.php b/tests/unit/Results/DTO/EmbeddingResultTest.php new file mode 100644 index 00000000..8311b66d --- /dev/null +++ b/tests/unit/Results/DTO/EmbeddingResultTest.php @@ -0,0 +1,167 @@ + + */ + private function createEmbeddings(): array + { + return [ + new Embedding([0.1, 0.2], 2), + new Embedding([0.3, 0.4], 2), + ]; + } + + /** + * Tests creating an embedding result with valid data. + */ + public function testCreateEmbeddingResult(): void + { + $tokenUsage = new TokenUsage(10, 0, 10); + $result = new EmbeddingResult( + 'embedding-result', + $this->createEmbeddings(), + $tokenUsage, + $this->createProviderMetadata(), + $this->createModelMetadata() + ); + + $this->assertSame('embedding-result', $result->getId()); + $this->assertCount(2, $result->getEmbeddings()); + $this->assertSame($tokenUsage, $result->getTokenUsage()); + $this->assertInstanceOf(ResultInterface::class, $result); + } + + /** + * Tests constructor enforces at least one embedding. + */ + public function testConstructorRejectsEmptyEmbeddings(): void + { + $tokenUsage = new TokenUsage(5, 0, 5); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('At least one embedding must be provided.'); + + new EmbeddingResult( + 'empty', + [], + $tokenUsage, + $this->createProviderMetadata(), + $this->createModelMetadata() + ); + } + + /** + * Tests helper methods for embedding vectors. + */ + public function testVectorHelpers(): void + { + $result = new EmbeddingResult( + 'vectors', + $this->createEmbeddings(), + new TokenUsage(3, 0, 3), + $this->createProviderMetadata(), + $this->createModelMetadata() + ); + + $this->assertSame([0.1, 0.2], $result->toVector()); + $this->assertSame([[0.1, 0.2], [0.3, 0.4]], $result->toVectors()); + $this->assertTrue($result->hasMultipleEmbeddings()); + $this->assertSame(2, $result->getEmbeddingCount()); + } + + /** + * Tests additional data handling. + */ + public function testAdditionalDataExposure(): void + { + $metadata = ['dimension' => 1536]; + $result = new EmbeddingResult( + 'with-metadata', + $this->createEmbeddings(), + new TokenUsage(1536, 0, 1536), + $this->createProviderMetadata(), + $this->createModelMetadata(), + $metadata + ); + + $this->assertSame($metadata, $result->getAdditionalData()); + } + + /** + * Tests JSON schema definition. + */ + public function testJsonSchema(): void + { + $schema = EmbeddingResult::getJsonSchema(); + + $this->assertSame('object', $schema['type']); + $this->assertArrayHasKey(EmbeddingResult::KEY_EMBEDDINGS, $schema['properties']); + $this->assertArrayHasKey(EmbeddingResult::KEY_TOKEN_USAGE, $schema['properties']); + $this->assertContains( + EmbeddingResult::KEY_ID, + $schema['required'] + ); + } + + /** + * Tests array conversion round-trips correctly. + */ + public function testArrayTransformation(): void + { + $result = new EmbeddingResult( + 'to-array', + $this->createEmbeddings(), + new TokenUsage(42, 0, 42), + $this->createProviderMetadata(), + $this->createModelMetadata() + ); + + $array = $result->toArray(); + $hydrated = EmbeddingResult::fromArray($array); + + $this->assertSame($result->getId(), $hydrated->getId()); + $this->assertSame($result->toVectors(), $hydrated->toVectors()); + $this->assertSame( + $result->getTokenUsage()->getTotalTokens(), + $hydrated->getTokenUsage()->getTotalTokens() + ); + } +} From d0f83c0fa596137e570220c4161d646b50bee7fc Mon Sep 17 00:00:00 2001 From: James LePage <36246732+Jameswlepage@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:12:27 -0500 Subject: [PATCH 2/2] chore: sort new embedding use statements --- src/Builders/PromptBuilder.php | 2 +- tests/traits/MockModelCreationTrait.php | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Builders/PromptBuilder.php b/src/Builders/PromptBuilder.php index aad7858f..a6fde23d 100644 --- a/src/Builders/PromptBuilder.php +++ b/src/Builders/PromptBuilder.php @@ -19,9 +19,9 @@ use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; use WordPress\AiClient\Providers\Models\DTO\RequiredOption; -use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface; use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationOperationModelInterface; +use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; diff --git a/tests/traits/MockModelCreationTrait.php b/tests/traits/MockModelCreationTrait.php index 171d8a3e..c85a6d02 100644 --- a/tests/traits/MockModelCreationTrait.php +++ b/tests/traits/MockModelCreationTrait.php @@ -15,9 +15,9 @@ use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; -use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationModelInterface; use WordPress\AiClient\Providers\Models\EmbeddingGeneration\Contracts\EmbeddingGenerationOperationModelInterface; +use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Providers\ProviderRegistry;