use Doctrine\DBAL\DriverManager;
use Doctrine\DBAL\Tools\DsnParser;
use Symfony\AI\Fixtures\Movies;
use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory;
use Symfony\AI\Store\Bridge\Postgres\HybridStore;
use Symfony\AI\Store\Bridge\Postgres\ReciprocalRankFusion;
use Symfony\AI\Store\Bridge\Postgres\TextSearch\Bm25TextSearchStrategy;
use Symfony\AI\Store\Bridge\Postgres\TextSearch\PostgresTextSearchStrategy;
use Symfony\AI\Store\Document\Loader\InMemoryLoader;
use Symfony\AI\Store\Document\Metadata;
use Symfony\AI\Store\Document\TextDocument;
use Symfony\AI\Store\Document\Vectorizer;
use Symfony\AI\Store\Exception\RuntimeException;
use Symfony\AI\Store\Indexer;
use Symfony\Component\Uid\Uuid;

require_once dirname(__DIR__).'/bootstrap.php';

// Prints a heading followed by one numbered line per result (title and score).
// Shared by every query below so the three listings cannot drift apart.
$printResults = static function (string $heading, iterable $results): void {
    echo $heading;
    foreach ($results as $position => $result) {
        $metadata = $result->metadata->getArrayCopy();
        echo sprintf(
            " %d. %s (Score: %.4f)\n",
            $position + 1,
            $metadata['title'] ?? 'Unknown',
            $result->score ?? 0.0
        );
    }
};

echo "=== PostgreSQL Hybrid Search Demo ===\n\n";
echo "Demonstrates HybridStore with configurable search strategies:\n";
echo "- Native PostgreSQL FTS vs BM25\n";
echo "- Semantic ratio adjustment\n";
echo "- Custom RRF scoring\n\n";

$connection = DriverManager::getConnection((new DsnParser())->parse(env('POSTGRES_URI')));
$pdo = $connection->getNativeConnection();

if (!$pdo instanceof PDO) {
    throw new RuntimeException('Unable to get native PDO connection from Doctrine DBAL.');
}

echo "=== Using BM25 Text Search Strategy ===\n\n";

$store = new HybridStore(
    connection: $pdo,
    tableName: 'hybrid_movies',
    textSearchStrategy: new Bm25TextSearchStrategy('en'),
    rrf: new ReciprocalRankFusion(k: 60, normalizeScores: true),
    semanticRatio: 0.5,
);

// Build one text document per movie; the same text is stored as the document
// content and under the "content" metadata key.
$documents = [];
foreach (Movies::all() as $movie) {
    $content = 'Title: '.$movie['title'].\PHP_EOL.'Director: '.$movie['director'].\PHP_EOL.'Description: '.$movie['description'];

    $documents[] = new TextDocument(
        id: Uuid::v4(),
        content: $content,
        metadata: new Metadata(array_merge($movie, ['content' => $content])),
    );
}

// Initialize the table
$store->setup();

// Create embeddings for documents
$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client());
$vectorizer = new Vectorizer($platform, 'text-embedding-3-small', logger());
$indexer = new Indexer(new InMemoryLoader($documents), $vectorizer, $store, logger: logger());
$indexer->index($documents);

// Create a query embedding
$queryText = 'futuristic technology and artificial intelligence';
echo "Query: \"$queryText\"\n\n";
$queryEmbedding = $vectorizer->vectorize($queryText);

// Test different semantic ratios to compare results
$ratios = [
    ['ratio' => 0.0, 'description' => '100% Full-text search (keyword matching)'],
    ['ratio' => 0.5, 'description' => 'Balanced hybrid (RRF: 50% semantic + 50% FTS)'],
    ['ratio' => 1.0, 'description' => '100% Semantic search (vector similarity)'],
];

foreach ($ratios as $config) {
    echo "--- {$config['description']} ---\n";

    // Override the semantic ratio for this specific query
    $results = $store->query($queryEmbedding, [
        'semanticRatio' => $config['ratio'],
        'q' => 'technology', // Full-text search keyword
        'limit' => 3,
    ]);

    $printResults("Top 3 results:\n", $results);
    echo "\n";
}

echo "--- Custom query with pure semantic search ---\n";
echo "Query: Movies about space exploration\n";
$spaceEmbedding = $vectorizer->vectorize('space exploration and cosmic adventures');
$results = $store->query($spaceEmbedding, [
    'semanticRatio' => 1.0, // Pure semantic search
    'limit' => 3,
]);

$printResults("Top 3 results:\n", $results);
echo "\n";

// Cleanup
$store->drop();

echo "=== Comparing with Native PostgreSQL FTS ===\n\n";

$storeFts = new HybridStore(
    connection: $pdo,
    tableName: 'hybrid_movies_fts',
    textSearchStrategy: new PostgresTextSearchStrategy(),
    semanticRatio: 0.5,
);

$storeFts->setup();
$indexer = new Indexer(new InMemoryLoader($documents), $vectorizer, $storeFts, logger: logger());
$indexer->index($documents);

$resultsFts = $storeFts->query($queryEmbedding, [
    'semanticRatio' => 0.5,
    'q' => 'technology',
    'limit' => 3,
]);

$printResults("Top 3 results (Native FTS):\n", $resultsFts);

$storeFts->drop();

echo "\n=== Summary ===\n";
echo "- semanticRatio = 0.0: Pure keyword matching\n";
echo "- semanticRatio = 0.5: Balanced hybrid (RRF)\n";
echo "- semanticRatio = 1.0: Pure semantic search\n";
echo "\nText Search Strategies:\n";
echo "- PostgresTextSearchStrategy: Native FTS (ts_rank_cd)\n";
echo "- Bm25TextSearchStrategy: BM25 ranking (requires pg_bm25 extension)\n";
Default: "en"') + ->defaultValue('en') + ->end() + ->integerNode('rrf_k') + ->info('RRF (Reciprocal Rank Fusion) constant. Higher = more equal weighting. Default: 60 (Supabase)') + ->defaultValue(60) + ->min(1) + ->end() + ->floatNode('default_max_score') + ->info('Default maximum distance threshold for filtering results (optional)') + ->defaultNull() + ->end() + ->floatNode('default_min_score') + ->info('Default minimum RRF score threshold for filtering results (optional)') + ->defaultNull() + ->end() + ->booleanNode('normalize_scores') + ->info('Normalize scores to 0-100 range for better readability') + ->defaultTrue() + ->end() + ->floatNode('fuzzy_primary_threshold') + ->info('Primary threshold for fuzzy matching (pg_trgm word_similarity). Higher = stricter. Default: 0.25') + ->defaultValue(0.25) + ->min(0.0) + ->max(1.0) + ->end() + ->floatNode('fuzzy_secondary_threshold') + ->info('Secondary threshold for fuzzy matching with double validation. Catches more typos. Default: 0.2') + ->defaultValue(0.2) + ->min(0.0) + ->max(1.0) + ->end() + ->floatNode('fuzzy_strict_threshold') + ->info('Strict similarity threshold for double validation to eliminate false positives. Default: 0.15') + ->defaultValue(0.15) + ->min(0.0) + ->max(1.0) + ->end() + ->floatNode('fuzzy_weight') + ->info('Weight of fuzzy matching vs FTS in hybrid search. 0.0 = disabled, 0.5 = equal (recommended), 1.0 = fuzzy only') + ->defaultValue(0.5) + ->min(0.0) + ->max(1.0) + ->end() + ->arrayNode('searchable_attributes') + ->info('Searchable attributes with field-specific boosting (similar to Meilisearch). Each attribute creates a separate tsvector column.') + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->floatNode('boost') + ->info('Boost multiplier for this field (e.g., 2.0 = twice as important). 
Default: 1.0') + ->defaultValue(1.0) + ->min(0.0) + ->end() + ->scalarNode('metadata_key') + ->info('JSON path to extract value from metadata (e.g., "title", "description")') + ->isRequired() + ->cannotBeEmpty() + ->end() + ->end() + ->end() + ->end() + ->stringNode('dbal_connection')->cannotBeEmpty()->end() + ->end() + ->validate() + ->ifTrue(static fn ($v) => !isset($v['dsn']) && !isset($v['dbal_connection']) && !isset($v['connection'])) + ->thenInvalid('Either "dsn", "dbal_connection", or "connection" must be configured.') + ->end() + ->validate() + ->ifTrue(static fn ($v) => (int) isset($v['dsn']) + (int) isset($v['dbal_connection']) + (int) isset($v['connection']) > 1) + ->thenInvalid('Only one of "dsn", "dbal_connection", or "connection" can be configured.') + ->end() + ->end() + ->end() ->end() ->end() ->arrayNode('message_store') diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index a602b339a..64af332fb 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -79,6 +79,7 @@ use Symfony\AI\Store\Bridge\MongoDb\Store as MongoDbStore; use Symfony\AI\Store\Bridge\Neo4j\Store as Neo4jStore; use Symfony\AI\Store\Bridge\Pinecone\Store as PineconeStore; +use Symfony\AI\Store\Bridge\Postgres\HybridStore; use Symfony\AI\Store\Bridge\Postgres\Store as PostgresStore; use Symfony\AI\Store\Bridge\Qdrant\Store as QdrantStore; use Symfony\AI\Store\Bridge\Redis\Store as RedisStore; @@ -1366,6 +1367,113 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde } } + if ('postgres_hybrid' === $type) { + foreach ($stores as $name => $store) { + $definition = new Definition(HybridStore::class); + + // Handle connection (PDO service reference, DBAL connection, or DSN) + if (\array_key_exists('connection', $store)) { + // Direct PDO service reference + $serviceId = ltrim($store['connection'], '@'); + $connection = new Reference($serviceId); + $arguments = [ + $connection, + $store['table_name'], + 
]; + } elseif (\array_key_exists('dbal_connection', $store)) { + // DBAL connection - extract native PDO + $connection = (new Definition(\PDO::class)) + ->setFactory([new Reference($store['dbal_connection']), 'getNativeConnection']); + $arguments = [ + $connection, + $store['table_name'], + ]; + } else { + // Create new PDO instance from DSN + $pdo = new Definition(\PDO::class); + $pdo->setArguments([ + $store['dsn'], + $store['username'] ?? null, + $store['password'] ?? null], + ); + + $arguments = [ + $pdo, + $store['table_name'], + ]; + } + + // Add optional parameters + if (\array_key_exists('vector_field', $store)) { + $arguments[2] = $store['vector_field']; + } + + if (\array_key_exists('content_field', $store)) { + $arguments[3] = $store['content_field']; + } + + if (\array_key_exists('semantic_ratio', $store)) { + $arguments[4] = $store['semantic_ratio']; + } + + if (\array_key_exists('distance', $store)) { + $arguments[5] = $store['distance']; + } + + if (\array_key_exists('language', $store)) { + $arguments[6] = $store['language']; + } + + if (\array_key_exists('rrf_k', $store)) { + $arguments[7] = $store['rrf_k']; + } + + if (\array_key_exists('default_max_score', $store)) { + $arguments[8] = $store['default_max_score']; + } + + if (\array_key_exists('default_min_score', $store)) { + $arguments[9] = $store['default_min_score']; + } + + if (\array_key_exists('normalize_scores', $store)) { + $arguments[10] = $store['normalize_scores']; + } + + if (\array_key_exists('fuzzy_primary_threshold', $store)) { + $arguments[11] = $store['fuzzy_primary_threshold']; + } + + if (\array_key_exists('fuzzy_secondary_threshold', $store)) { + $arguments[12] = $store['fuzzy_secondary_threshold']; + } + + if (\array_key_exists('fuzzy_strict_threshold', $store)) { + $arguments[13] = $store['fuzzy_strict_threshold']; + } + + if (\array_key_exists('fuzzy_weight', $store)) { + $arguments[14] = $store['fuzzy_weight']; + } + + if (\array_key_exists('searchable_attributes', 
$store)) { + $arguments[15] = $store['searchable_attributes']; + } + + if (\array_key_exists('bm25_language', $store)) { + $arguments[16] = $store['bm25_language']; + } + + $definition + ->addTag('ai.store') + ->setArguments($arguments); + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + $container->registerAliasForArgument('ai.store.'.$type.'.'.$name, StoreInterface::class, $name); + $container->registerAliasForArgument('ai.store.'.$type.'.'.$name, StoreInterface::class, $type.'_'.$name); + } + } + if ('supabase' === $type) { foreach ($stores as $name => $store) { $arguments = [ diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index 37e515b92..b6cfa100f 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -540,6 +540,75 @@ public function testPostgresStoreWithDifferentConnectionCanBeConfigured() $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); } + public function testPostgresHybridStoreWithDsnCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'postgres_hybrid' => [ + 'hybrid_db' => [ + 'dsn' => 'pgsql:host=localhost;port=5432;dbname=testdb', + 'username' => 'app', + 'password' => 'mypass', + 'table_name' => 'hybrid_vectors', + 'semantic_ratio' => 0.7, + 'language' => 'english', + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.postgres_hybrid.hybrid_db')); + $definition = $container->getDefinition('ai.store.postgres_hybrid.hybrid_db'); + $this->assertInstanceOf(Definition::class, $definition->getArgument(0)); + $this->assertSame('hybrid_vectors', $definition->getArgument(1)); + } + + public function testPostgresHybridStoreWithDbalConnectionCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'postgres_hybrid' => [ + 'hybrid_db' => [ + 
'dbal_connection' => 'my_connection', + 'table_name' => 'hybrid_vectors', + 'rrf_k' => 100, + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.postgres_hybrid.hybrid_db')); + $definition = $container->getDefinition('ai.store.postgres_hybrid.hybrid_db'); + $this->assertInstanceOf(Definition::class, $definition->getArgument(0)); + $this->assertSame('hybrid_vectors', $definition->getArgument(1)); + $this->assertSame(100, $definition->getArgument(7)); + } + + public function testPostgresHybridStoreWithConnectionReferenceCanBeConfigured() + { + $container = $this->buildContainer([ + 'ai' => [ + 'store' => [ + 'postgres_hybrid' => [ + 'hybrid_db' => [ + 'connection' => '@my_pdo_service', + 'table_name' => 'hybrid_vectors', + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.store.postgres_hybrid.hybrid_db')); + $definition = $container->getDefinition('ai.store.postgres_hybrid.hybrid_db'); + $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); + $this->assertSame('my_pdo_service', (string) $definition->getArgument(0)); + } + public function testConfigurationWithUseAttributeAsKeyWorksWithoutNormalizeKeys() { // Test that configurations using useAttributeAsKey work correctly diff --git a/src/store/src/Bridge/Postgres/HybridStore.php b/src/store/src/Bridge/Postgres/HybridStore.php new file mode 100644 index 000000000..d1dbc9aca --- /dev/null +++ b/src/store/src/Bridge/Postgres/HybridStore.php @@ -0,0 +1,719 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Store\Bridge\Postgres; + +use Symfony\AI\Platform\Vector\NullVector; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Platform\Vector\VectorInterface; +use Symfony\AI\Store\Bridge\Postgres\TextSearch\PostgresTextSearchStrategy; +use Symfony\AI\Store\Bridge\Postgres\TextSearch\TextSearchStrategyInterface; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\InvalidArgumentException; +use Symfony\AI\Store\ManagedStoreInterface; +use Symfony\AI\Store\StoreInterface; +use Symfony\Component\Uid\Uuid; + +/** + * Hybrid Search Store for PostgreSQL combining vector similarity and full-text search. + * + * Uses Reciprocal Rank Fusion (RRF) to combine multiple search signals: + * - Vector similarity (pgvector) + * - Full-text search (configurable: native PostgreSQL or BM25) + * - Fuzzy matching (pg_trgm) for typo tolerance + * + * @see https://supabase.com/docs/guides/ai/hybrid-search + * + * @author Ahmed EBEN HASSINE + */ +final class HybridStore implements ManagedStoreInterface, StoreInterface +{ + private readonly ReciprocalRankFusion $rrf; + private readonly TextSearchStrategyInterface $textSearchStrategy; + + /** + * @param string $vectorFieldName Name of the vector field + * @param string $contentFieldName Name of the text field for FTS + * @param float $semanticRatio Ratio between semantic and keyword search (0.0 to 1.0) + * @param Distance $distance Distance metric for vector similarity + * @param string $language PostgreSQL text search configuration + * @param TextSearchStrategyInterface|null $textSearchStrategy Text search strategy (defaults to native PostgreSQL) + * @param ReciprocalRankFusion|null $rrf RRF calculator (defaults to k=60, normalized) + * @param float|null $defaultMaxScore Default max distance for vector search + * @param float|null $defaultMinScore Default min RRF score threshold + * @param float $fuzzyPrimaryThreshold Primary threshold for fuzzy 
matching + * @param float $fuzzySecondaryThreshold Secondary threshold for fuzzy matching + * @param float $fuzzyStrictThreshold Strict threshold for double validation + * @param float $fuzzyWeight Weight of fuzzy matching (0.0 to 1.0) + * @param array $searchableAttributes Searchable attributes with boosting config + */ + public function __construct( + private readonly \PDO $connection, + private readonly string $tableName, + private readonly string $vectorFieldName = 'embedding', + private readonly string $contentFieldName = 'content', + private readonly float $semanticRatio = 1.0, + private readonly Distance $distance = Distance::L2, + private readonly string $language = 'simple', + ?TextSearchStrategyInterface $textSearchStrategy = null, + ?ReciprocalRankFusion $rrf = null, + private readonly ?float $defaultMaxScore = null, + private readonly ?float $defaultMinScore = null, + private readonly float $fuzzyPrimaryThreshold = 0.25, + private readonly float $fuzzySecondaryThreshold = 0.2, + private readonly float $fuzzyStrictThreshold = 0.15, + private readonly float $fuzzyWeight = 0.5, + private readonly array $searchableAttributes = [], + ) { + if ($semanticRatio < 0.0 || $semanticRatio > 1.0) { + throw new InvalidArgumentException(\sprintf('The semantic ratio must be between 0.0 and 1.0, "%s" given.', $semanticRatio)); + } + + if ($fuzzyWeight < 0.0 || $fuzzyWeight > 1.0) { + throw new InvalidArgumentException(\sprintf('The fuzzy weight must be between 0.0 and 1.0, "%s" given.', $fuzzyWeight)); + } + + $this->textSearchStrategy = $textSearchStrategy ?? new PostgresTextSearchStrategy(); + $this->rrf = $rrf ?? 
new ReciprocalRankFusion(); + } + + /** + * @param array{vector_type?: string, vector_size?: positive-int, index_method?: string, index_opclass?: string} $options + */ + public function setup(array $options = []): void + { + // Enable pgvector extension + $this->connection->exec('CREATE EXTENSION IF NOT EXISTS vector'); + + // Enable pg_trgm extension for fuzzy matching + $this->connection->exec('CREATE EXTENSION IF NOT EXISTS pg_trgm'); + + // Build tsvector columns + $tsvectorColumns = $this->buildTsvectorColumns(); + + // Create main table + $this->connection->exec( + \sprintf( + 'CREATE TABLE IF NOT EXISTS %s ( + id UUID PRIMARY KEY, + metadata JSONB, + %s TEXT NOT NULL, + %s %s(%d) NOT NULL%s + )', + $this->tableName, + $this->contentFieldName, + $this->vectorFieldName, + $options['vector_type'] ?? 'vector', + $options['vector_size'] ?? 1536, + $tsvectorColumns, + ), + ); + + // Add search_text field for fuzzy matching + $this->connection->exec( + \sprintf( + 'ALTER TABLE %s ADD COLUMN IF NOT EXISTS search_text TEXT', + $this->tableName, + ), + ); + + // Create trigger for search_text auto-update + $this->createSearchTextTrigger(); + + // Create vector index + $this->connection->exec( + \sprintf( + 'CREATE INDEX IF NOT EXISTS %s_%s_idx ON %s USING %s (%s %s)', + $this->tableName, + $this->vectorFieldName, + $this->tableName, + $options['index_method'] ?? 'ivfflat', + $this->vectorFieldName, + $options['index_opclass'] ?? 
'vector_cosine_ops', + ), + ); + + // Execute text search strategy setup (only if not using searchableAttributes) + if ([] === $this->searchableAttributes) { + foreach ($this->textSearchStrategy->getSetupSql($this->tableName, $this->contentFieldName, $this->language) as $sql) { + $this->connection->exec($sql); + } + } else { + // Create GIN indexes for tsvector columns when using searchableAttributes + $this->createTsvectorIndexes(); + } + + // Create trigram index for fuzzy matching + $this->connection->exec( + \sprintf( + 'CREATE INDEX IF NOT EXISTS %s_search_text_trgm_idx ON %s USING gin(search_text gin_trgm_ops)', + $this->tableName, + $this->tableName, + ), + ); + } + + public function drop(): void + { + $this->connection->exec(\sprintf('DROP TABLE IF EXISTS %s', $this->tableName)); + } + + public function add(VectorDocument ...$documents): void + { + $statement = $this->connection->prepare( + \sprintf( + 'INSERT INTO %1$s (id, metadata, %2$s, %3$s) + VALUES (:id, :metadata, :content, :vector) + ON CONFLICT (id) DO UPDATE SET + metadata = EXCLUDED.metadata, + %2$s = EXCLUDED.%2$s, + %3$s = EXCLUDED.%3$s', + $this->tableName, + $this->contentFieldName, + $this->vectorFieldName, + ), + ); + + foreach ($documents as $document) { + $statement->execute([ + 'id' => $document->id->toRfc4122(), + 'metadata' => json_encode($document->metadata->getArrayCopy(), \JSON_THROW_ON_ERROR), + 'content' => $document->metadata->getText() ?? '', + 'vector' => $this->toPgvector($document->vector), + ]); + } + } + + /** + * Hybrid search combining vector similarity and full-text search. + * + * Note: When results come from FTS-only or fuzzy-only matches (no vector similarity), + * the VectorDocument will contain a NullVector. Check with `$doc->vector instanceof NullVector` + * before calling getData() or getDimensions() on the vector. 
+ * + * @param array{ + * q?: string, + * semanticRatio?: float, + * limit?: int, + * where?: string, + * params?: array, + * maxScore?: float, + * minScore?: float, + * includeScoreBreakdown?: bool, + * boostFields?: array + * } $options + * + * @return VectorDocument[] + */ + public function query(Vector $vector, array $options = []): array + { + $semanticRatio = $this->validateSemanticRatio($options['semanticRatio'] ?? $this->semanticRatio); + $queryText = $options['q'] ?? ''; + $limit = $options['limit'] ?? 5; + + // Build WHERE clause and params + [$whereClause, $params] = $this->buildWhereClause($vector, $options, $semanticRatio); + + // Choose query strategy + $sql = $this->buildQuery($semanticRatio, $queryText, $whereClause, $limit); + + if ('' !== $queryText && $semanticRatio < 1.0) { + $params['query'] = $queryText; + } + + // Execute query + $statement = $this->connection->prepare($sql); + $statement->execute([...$params, ...($options['params'] ?? [])]); + + // Process results + $documents = $this->processResults( + $statement->fetchAll(\PDO::FETCH_ASSOC), + $options['includeScoreBreakdown'] ?? false, + ); + + // Apply boosting + if (isset($options['boostFields']) && [] !== $options['boostFields']) { + $documents = $this->applyBoostFields($documents, $options['boostFields']); + } + + // Apply minimum score filter + $minScore = $options['minScore'] ?? $this->defaultMinScore; + if (null !== $minScore) { + $documents = array_values(array_filter( + $documents, + fn (VectorDocument $doc) => $doc->score >= $minScore + )); + } + + return $documents; + } + + /** + * Get the text search strategy being used. + */ + public function getTextSearchStrategy(): TextSearchStrategyInterface + { + return $this->textSearchStrategy; + } + + /** + * Get the RRF calculator being used. 
+ */ + public function getRrf(): ReciprocalRankFusion + { + return $this->rrf; + } + + private function buildTsvectorColumns(): string + { + if ([] !== $this->searchableAttributes) { + $columns = ''; + foreach ($this->searchableAttributes as $fieldName => $config) { + $metadataKey = $config['metadata_key']; + $columns .= \sprintf( + ",\n %s_tsv tsvector GENERATED ALWAYS AS (to_tsvector('%s', COALESCE(metadata->>'%s', ''))) STORED", + $fieldName, + $this->language, + $metadataKey + ); + } + + return $columns; + } + + // When not using searchableAttributes, let the TextSearchStrategy handle tsvector columns + return ''; + } + + private function createSearchTextTrigger(): void + { + $this->connection->exec( + "CREATE OR REPLACE FUNCTION update_search_text() + RETURNS TRIGGER AS \$\$ + BEGIN + NEW.search_text := COALESCE(NEW.metadata->>'title', ''); + RETURN NEW; + END; + \$\$ LANGUAGE plpgsql;" + ); + + $this->connection->exec( + \sprintf( + 'DROP TRIGGER IF EXISTS trigger_update_search_text ON %s; + CREATE TRIGGER trigger_update_search_text + BEFORE INSERT OR UPDATE ON %s + FOR EACH ROW + EXECUTE FUNCTION update_search_text();', + $this->tableName, + $this->tableName, + ), + ); + } + + private function createTsvectorIndexes(): void + { + if ([] !== $this->searchableAttributes) { + foreach ($this->searchableAttributes as $fieldName => $config) { + $this->connection->exec( + \sprintf( + 'CREATE INDEX IF NOT EXISTS %s_%s_tsv_idx ON %s USING gin(%s_tsv)', + $this->tableName, + $fieldName, + $this->tableName, + $fieldName, + ), + ); + } + } else { + $this->connection->exec( + \sprintf( + 'CREATE INDEX IF NOT EXISTS %s_content_tsv_idx ON %s USING gin(content_tsv)', + $this->tableName, + $this->tableName, + ), + ); + } + } + + private function validateSemanticRatio(float $ratio): float + { + if ($ratio < 0.0 || $ratio > 1.0) { + throw new InvalidArgumentException(\sprintf('The semantic ratio must be between 0.0 and 1.0, "%s" given.', $ratio)); + } + + return $ratio; + } + + 
/** + * @param array $options + * + * @return array{string, array} + */ + private function buildWhereClause(Vector $vector, array $options, float $semanticRatio): array + { + $where = []; + $params = []; + + $maxScore = $options['maxScore'] ?? $this->defaultMaxScore; + + if ($semanticRatio > 0.0 || null !== $maxScore) { + $params['embedding'] = $this->toPgvector($vector); + } + + if (null !== $maxScore) { + $where[] = \sprintf( + '(%s %s :embedding) <= :maxScore', + $this->vectorFieldName, + $this->distance->getComparisonSign() + ); + $params['maxScore'] = $maxScore; + } + + if (isset($options['where']) && '' !== $options['where']) { + $where[] = '('.$options['where'].')'; + } + + $whereClause = $where ? 'WHERE '.implode(' AND ', $where) : ''; + + return [$whereClause, $params]; + } + + private function buildQuery(float $semanticRatio, string $queryText, string $whereClause, int $limit): string + { + if (1.0 === $semanticRatio || '' === $queryText) { + return $this->buildVectorOnlyQuery($whereClause, $limit); + } + + if (0.0 === $semanticRatio) { + return $this->buildFtsOnlyQuery($whereClause, $limit); + } + + return $this->buildHybridQuery($whereClause, $limit, $semanticRatio); + } + + private function buildVectorOnlyQuery(string $whereClause, int $limit): string + { + return \sprintf( + 'SELECT id, %s AS embedding, metadata, (%s %s :embedding) AS score + FROM %s + %s + ORDER BY score ASC + LIMIT %d', + $this->vectorFieldName, + $this->vectorFieldName, + $this->distance->getComparisonSign(), + $this->tableName, + $whereClause, + $limit, + ); + } + + private function buildFtsOnlyQuery(string $whereClause, int $limit): string + { + $ftsCte = $this->textSearchStrategy->buildSearchCte( + $this->tableName, + $this->contentFieldName, + $this->language, + ); + $cteAlias = $this->textSearchStrategy->getCteAlias(); + $scoreColumn = $this->textSearchStrategy->getScoreColumn(); + + return \sprintf( + 'WITH %s + SELECT id, NULL AS embedding, metadata, %s AS score + FROM %s + 
%s + ORDER BY %s DESC + LIMIT %d', + $ftsCte, + $scoreColumn, + $cteAlias, + $whereClause ? 'WHERE id IN (SELECT id FROM '.$this->tableName.' '.$whereClause.')' : '', + $scoreColumn, + $limit, + ); + } + + private function buildHybridQuery(string $whereClause, int $limit, float $semanticRatio): string + { + $ftsCte = $this->textSearchStrategy->buildSearchCte( + $this->tableName, + $this->contentFieldName, + $this->language, + ); + $ftsAlias = $this->textSearchStrategy->getCteAlias(); + $ftsRankColumn = $this->textSearchStrategy->getRankColumn(); + $ftsScoreColumn = $this->textSearchStrategy->getScoreColumn(); + $ftsNormalizedScore = $this->textSearchStrategy->getNormalizedScoreExpression($ftsScoreColumn); + + // Calculate weights + $ftsWeight = (1.0 - $semanticRatio) * (1.0 - $this->fuzzyWeight); + $fuzzyWeightCalculated = (1.0 - $semanticRatio) * $this->fuzzyWeight; + + // Build fuzzy filter + $fuzzyFilter = $this->buildFuzzyFilter(); + $fuzzyWhereClause = $this->addFilterToWhereClause($whereClause, $fuzzyFilter); + + // Build RRF expressions using the RRF class + $vectorContribution = $this->rrf->buildSqlExpression( + 'v.rank_ix', + '(1.0 - LEAST(v.distance / 2.0, 1.0))', + $semanticRatio, + ); + $ftsContribution = $this->rrf->buildSqlExpression( + "b.{$ftsRankColumn}", + $ftsNormalizedScore, + $ftsWeight, + ); + $fuzzyContribution = $this->rrf->buildSqlExpression( + 'fz.rank_ix', + 'fz.fuzzy_similarity', + $fuzzyWeightCalculated, + ); + + return \sprintf( + 'WITH vector_scores AS ( + SELECT id, %s AS embedding, metadata, + (%s %s :embedding) AS distance, + ROW_NUMBER() OVER (ORDER BY %s %s :embedding) AS rank_ix + FROM %s + %s + ), + %s, + fuzzy_scores AS ( + SELECT id, metadata, + word_similarity(:query, search_text) AS fuzzy_similarity, + ROW_NUMBER() OVER (ORDER BY word_similarity(:query, search_text) DESC) AS rank_ix + FROM %s + %s + ), + combined_results AS ( + SELECT + COALESCE(v.id, b.id, fz.id) as id, + v.embedding, + COALESCE(v.metadata, b.metadata, 
fz.metadata) as metadata, + (%s + %s + %s) AS score, + v.rank_ix AS vector_rank, + b.%s AS fts_rank, + v.distance AS vector_distance, + b.%s AS fts_score, + fz.rank_ix AS fuzzy_rank, + fz.fuzzy_similarity AS fuzzy_score, + %s AS vector_contribution, + %s AS fts_contribution, + %s AS fuzzy_contribution + FROM vector_scores v + FULL OUTER JOIN %s b ON v.id = b.id + FULL OUTER JOIN fuzzy_scores fz ON COALESCE(v.id, b.id) = fz.id + WHERE v.id IS NOT NULL OR b.id IS NOT NULL OR fz.id IS NOT NULL + ) + SELECT * FROM ( + SELECT DISTINCT ON (metadata->>\'title\') * + FROM combined_results + ORDER BY metadata->>\'title\', score DESC + ) unique_results + ORDER BY score DESC + LIMIT %d', + $this->vectorFieldName, + $this->vectorFieldName, + $this->distance->getComparisonSign(), + $this->vectorFieldName, + $this->distance->getComparisonSign(), + $this->tableName, + $whereClause, + $ftsCte, + $this->tableName, + $fuzzyWhereClause, + $vectorContribution, + $ftsContribution, + $fuzzyContribution, + $ftsRankColumn, + $ftsScoreColumn, + $vectorContribution, + $ftsContribution, + $fuzzyContribution, + $ftsAlias, + $limit, + ); + } + + private function buildFuzzyFilter(): string + { + return \sprintf( + '( + word_similarity(:query, search_text) > %f + OR ( + word_similarity(:query, search_text) > %f + AND similarity(:query, search_text) > %f + ) + )', + $this->fuzzyPrimaryThreshold, + $this->fuzzySecondaryThreshold, + $this->fuzzyStrictThreshold + ); + } + + private function addFilterToWhereClause(string $whereClause, string $filter): string + { + if ('' === $whereClause) { + return "WHERE $filter"; + } + + $whereClause = rtrim($whereClause); + + if (str_starts_with($whereClause, 'WHERE ')) { + return "$whereClause AND $filter"; + } + + return "WHERE $filter AND ".ltrim($whereClause); + } + + /** + * @param array> $results + * + * @return VectorDocument[] + */ + private function processResults(array $results, bool $includeBreakdown): array + { + $documents = []; + + foreach ($results 
/**
 * Multiplies document scores by (1 + boost) for every configured metadata field
 * that is present on the document and inside the optional [min, max] range, then
 * re-sorts the documents by descending score.
 *
 * Each boost that was applied is recorded under the "_applied_boosts" metadata key
 * of the returned document.
 *
 * @param VectorDocument[]                                                       $documents
 * @param array<string, array{boost?: float|int, min?: mixed, max?: mixed}> $boostFields Keyed by metadata field name
 *
 * @return VectorDocument[] New document instances, re-indexed and ordered by score (highest first)
 */
private function applyBoostFields(array $documents, array $boostFields): array
{
    $documents = array_map(static function (VectorDocument $doc) use ($boostFields): VectorDocument {
        $metadata = $doc->metadata;
        $score = $doc->score;
        $appliedBoosts = [];

        foreach ($boostFields as $field => $boostConfig) {
            if (!isset($metadata[$field])) {
                continue;
            }

            $value = $metadata[$field];
            $boost = $boostConfig['boost'] ?? 0.0;

            // Compare as float: a strict check against 0.0 lets an integer 0 through,
            // which would record a no-op boost in "_applied_boosts".
            if (0.0 === (float) $boost) {
                continue;
            }

            if (isset($boostConfig['min']) && $value < $boostConfig['min']) {
                continue;
            }

            if (isset($boostConfig['max']) && $value > $boostConfig['max']) {
                continue;
            }

            $multiplier = 1.0 + $boost;
            $score *= $multiplier;
            $appliedBoosts[$field] = [
                'value' => $value,
                'boost' => $boost,
                'multiplier' => $multiplier,
            ];
        }

        if ([] !== $appliedBoosts) {
            // Metadata is an object shared with the input document; clone it so the
            // bookkeeping entry does not leak into the document that was passed in.
            $metadata = clone $metadata;
            $metadata['_applied_boosts'] = $appliedBoosts;
        }

        return new VectorDocument(
            id: $doc->id,
            vector: $doc->vector,
            metadata: $metadata,
            score: $score,
        );
    }, $documents);

    usort($documents, static fn (VectorDocument $a, VectorDocument $b): int => $b->score <=> $a->score);

    return $documents;
}
/**
 * Returns the DDL statements that prepare a table for native full-text search.
 *
 * The first statement adds a stored generated "content_tsv" tsvector column computed
 * from the content field with the given text search configuration; the second one
 * creates a GIN index on that column. Both use IF NOT EXISTS.
 *
 * @return string[] SQL statements in execution order
 */
public function getSetupSql(string $tableName, string $contentFieldName, string $language): array
{
    $addTsvectorColumn = \sprintf(
        "ALTER TABLE %s ADD COLUMN IF NOT EXISTS content_tsv tsvector
                GENERATED ALWAYS AS (to_tsvector('%s', %s)) STORED",
        $tableName,
        $language,
        $contentFieldName,
    );

    // The table name is used twice: as index name prefix and as index target.
    $createGinIndex = \sprintf(
        'CREATE INDEX IF NOT EXISTS %1$s_content_tsv_idx ON %1$s USING gin(content_tsv)',
        $tableName,
    );

    return [$addTsvectorColumn, $createGinIndex];
}
/**
 * Reciprocal Rank Fusion (RRF) calculator for combining multiple search rankings.
 *
 * Every ranking source contributes weight * score / (k + rank) for a document and the
 * contributions of all sources are added up. With normalization enabled, the result is
 * rescaled so that a raw value of 1 / (k + 1) maps to 100.
 *
 * @see https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
 *
 * @author Ahmed EBEN HASSINE
 */
final class ReciprocalRankFusion
{
    /**
     * @param int  $k               RRF constant (default: 60). Higher values give more equal weighting between results.
     * @param bool $normalizeScores Whether to normalize scores to 0-100 range (default: true)
     *
     * @throws \Symfony\AI\Store\Exception\InvalidArgumentException When $k is negative
     */
    public function __construct(
        private readonly int $k = 60,
        private readonly bool $normalizeScores = true,
    ) {
        // With a negative k, "k + rank" (and "k + 1" in normalize()) can reach zero,
        // which ends in a division by zero both in PHP and in the generated SQL.
        if ($k < 0) {
            throw new \Symfony\AI\Store\Exception\InvalidArgumentException(\sprintf('The RRF constant k must be greater than or equal to 0, "%d" given.', $k));
        }
    }

    /**
     * Calculate RRF score for a single result with multiple rankings.
     *
     * @param array<array{rank: int|null, score: float, weight: float}> $rankings
     *                                                                            Each entry contains: rank (1-based or null), score (normalized 0-1), weight (0-1)
     *
     * @return float The combined RRF score (normalized when normalization is enabled)
     */
    public function calculate(array $rankings): float
    {
        $score = 0.0;

        foreach ($rankings as $ranking) {
            // A source that did not rank the document contributes nothing.
            if (null === ($ranking['rank'] ?? null)) {
                continue;
            }

            $score += $this->rawContribution($ranking['rank'], $ranking['score'], $ranking['weight']);
        }

        return $this->normalizeScores ? $this->normalize($score) : $score;
    }

    /**
     * Calculate individual contribution for a ranking.
     *
     * @param int   $rank   The rank (1-based position)
     * @param float $score  The normalized score (0-1)
     * @param float $weight The weight for this ranking source (0-1)
     *
     * @return float The contribution (normalized when normalization is enabled)
     */
    public function calculateContribution(int $rank, float $score, float $weight): float
    {
        $contribution = $this->rawContribution($rank, $score, $weight);

        return $this->normalizeScores ? $this->normalize($contribution) : $contribution;
    }

    /**
     * Normalize a score to 0-100 range.
     *
     * The score is divided by 1/(k+1), the value a single source with rank 1,
     * score 1 and weight 1 produces, and multiplied by 100.
     */
    public function normalize(float $score): float
    {
        $maxScore = 1.0 / ($this->k + 1);

        return ($score / $maxScore) * 100;
    }

    /**
     * Denormalize a score from 0-100 range back to raw RRF score.
     */
    public function denormalize(float $normalizedScore): float
    {
        $maxScore = 1.0 / ($this->k + 1);

        return ($normalizedScore / 100) * $maxScore;
    }

    /**
     * Build SQL expression for RRF calculation.
     *
     * @param string $rankColumn  The column containing the rank
     * @param string $scoreExpr   SQL expression for the normalized score (0-1)
     * @param float  $weight      The weight for this ranking source
     * @param string $nullDefault Default value when rank is NULL (default: '0.0')
     *
     * @return string A COALESCE(...) expression yielding the raw (non-normalized) contribution
     */
    public function buildSqlExpression(
        string $rankColumn,
        string $scoreExpr,
        float $weight,
        string $nullDefault = '0.0',
    ): string {
        // %F instead of %f: %f is locale-aware and would emit a decimal comma
        // under some LC_NUMERIC settings, producing invalid SQL.
        return \sprintf(
            'COALESCE(1.0 / (%d + %s) * %s * %F, %s)',
            $this->k,
            $rankColumn,
            $scoreExpr,
            $weight,
            $nullDefault,
        );
    }

    /**
     * Build SQL expression for combining multiple RRF contributions.
     *
     * @param array<array{rank_column: string, score_expr: string, weight: float}> $sources
     *
     * @return string The parenthesised sum of all source expressions, "(0.0)" when there are none
     */
    public function buildCombinedSqlExpression(array $sources): string
    {
        if ([] === $sources) {
            // "()" is not a valid SQL expression; an empty sum is zero.
            return '(0.0)';
        }

        $expressions = array_map(
            fn (array $source): string => $this->buildSqlExpression(
                $source['rank_column'],
                $source['score_expr'],
                $source['weight'],
            ),
            $sources,
        );

        return '('.implode(' + ', $expressions).')';
    }

    public function getK(): int
    {
        return $this->k;
    }

    public function isNormalized(): bool
    {
        return $this->normalizeScores;
    }

    /**
     * Raw RRF term of one ranking source: score * weight / (k + rank).
     */
    private function rawContribution(int|float $rank, float $score, float $weight): float
    {
        return (1.0 / ($this->k + $rank)) * $score * $weight;
    }
}
/**
 * BM25 full-text search strategy using plpgsql_bm25 extension.
 *
 * The search CTE calls the bm25topk() SQL function and joins its results back to
 * the store table through the "title" metadata entry.
 *
 * Requirements:
 * - plpgsql_bm25 extension must be installed
 *
 * @see https://github.com/pgsql-bm25/plpgsql_bm25
 *
 * @author Ahmed EBEN HASSINE
 */
final class Bm25TextSearchStrategy implements TextSearchStrategyInterface
{
    private const CTE_ALIAS = 'bm25_with_metadata';
    private const RANK_COLUMN = 'bm25_rank';
    private const SCORE_COLUMN = 'bm25_score';

    /**
     * @param string $bm25Language BM25 language code ('en', 'fr', 'es', etc.)
     * @param int    $topK         Number of results to retrieve from BM25 (default: 100)
     */
    public function __construct(
        private readonly string $bm25Language = 'en',
        private readonly int $topK = 100,
    ) {
    }

    /**
     * Returns no statements: this strategy adds no columns or indexes of its own.
     */
    public function getSetupSql(string $tableName, string $contentFieldName, string $language): array
    {
        // BM25 doesn't require additional table setup, it uses the content field directly
        // The index is managed internally by the bm25topk function
        return [];
    }

    /**
     * Builds two CTEs: "bm25_search" (raw bm25topk() hits with score and rank) and the
     * aliased CTE that maps each hit back to a table row, keeping one row per rank.
     *
     * The $language argument is not used here; the language passed to bm25topk() is
     * the one given to the constructor.
     */
    public function buildSearchCte(
        string $tableName,
        string $contentFieldName,
        string $language,
        string $queryParam = ':query',
    ): string {
        // NOTE(review): hits are matched to rows by a title extracted with the
        // case-sensitive pattern 'title: ...' and compared to metadata->>'title'.
        // This presumably requires documents to carry such a line and titles to be
        // unique - confirm against the indexed content format.
        return \sprintf(
            "bm25_search AS (
                SELECT
                    SUBSTRING(bm25.doc FROM 'title: ([^\n]+)') as extracted_title,
                    bm25.doc,
                    bm25.score as %s,
                    ROW_NUMBER() OVER (ORDER BY bm25.score DESC) as %s
                FROM bm25topk(
                    '%s',
                    '%s',
                    %s,
                    %d,
                    '',
                    '%s'
                ) AS bm25
            ),
            %s AS (
                SELECT DISTINCT ON (b.%s)
                    m.id,
                    m.metadata,
                    m.%s,
                    b.%s,
                    b.%s
                FROM bm25_search b
                INNER JOIN %s m ON (m.metadata->>'title') = b.extracted_title
                ORDER BY b.%s, m.id
            )",
            self::SCORE_COLUMN,
            self::RANK_COLUMN,
            $tableName,
            $contentFieldName,
            $queryParam,
            $this->topK,
            $this->bm25Language,
            self::CTE_ALIAS,
            self::RANK_COLUMN,
            $contentFieldName,
            self::SCORE_COLUMN,
            self::RANK_COLUMN,
            $tableName,
            self::RANK_COLUMN,
        );
    }

    public function getCteAlias(): string
    {
        return self::CTE_ALIAS;
    }

    public function getRankColumn(): string
    {
        return self::RANK_COLUMN;
    }

    public function getScoreColumn(): string
    {
        return self::SCORE_COLUMN;
    }

    /**
     * Returns an expression dividing the score by 10 and capping the result at 1.0.
     */
    public function getNormalizedScoreExpression(string $scoreColumn): string
    {
        // BM25 scores are typically in 0-10+ range, normalize to 0-1
        return \sprintf('LEAST(%s / 10.0, 1.0)', $scoreColumn);
    }

    public function getRequiredExtensions(): array
    {
        return ['plpgsql_bm25'];
    }

    /**
     * Reports whether a function named "bm25topk" is registered in pg_proc.
     *
     * @return bool False when the function is missing or the lookup query fails
     */
    public function isAvailable(\PDO $connection): bool
    {
        try {
            $stmt = $connection->query(
                "SELECT 1 FROM pg_proc WHERE proname = 'bm25topk' LIMIT 1"
            );

            // Outside PDO::ERRMODE_EXCEPTION a failed query() returns false instead of
            // throwing; calling fetchColumn() on it would raise an uncaught \Error.
            if (false === $stmt) {
                return false;
            }

            return false !== $stmt->fetchColumn();
        } catch (\PDOException) {
            return false;
        }
    }
}
/**
 * Full-text search strategy built on PostgreSQL's tsvector/tsquery functions,
 * ranking matches with ts_rank_cd.
 *
 * It reports no required extensions and is always considered available.
 *
 * @author Ahmed EBEN HASSINE
 */
final class PostgresTextSearchStrategy implements TextSearchStrategyInterface
{
    private const CTE_ALIAS = 'fts_search';
    private const RANK_COLUMN = 'fts_rank';
    private const SCORE_COLUMN = 'fts_score';

    /**
     * Returns the DDL adding a stored generated "content_tsv" column and its GIN index.
     */
    public function getSetupSql(string $tableName, string $contentFieldName, string $language): array
    {
        $addTsvectorColumn = \sprintf(
            "ALTER TABLE %s ADD COLUMN IF NOT EXISTS content_tsv tsvector
                GENERATED ALWAYS AS (to_tsvector('%s', %s)) STORED",
            $tableName,
            $language,
            $contentFieldName,
        );

        $createGinIndex = \sprintf(
            'CREATE INDEX IF NOT EXISTS %1$s_content_tsv_idx ON %1$s USING gin(content_tsv)',
            $tableName,
        );

        return [$addTsvectorColumn, $createGinIndex];
    }

    /**
     * Builds the CTE selecting id, metadata, the content field, the ts_rank_cd score
     * and a 1-based rank for every row whose content_tsv matches the query.
     */
    public function buildSearchCte(
        string $tableName,
        string $contentFieldName,
        string $language,
        string $queryParam = ':query',
    ): string {
        // The same tsquery is used for scoring, ranking and filtering.
        $tsQuery = \sprintf("plainto_tsquery('%s', %s)", $language, $queryParam);
        $rankExpression = \sprintf('ts_rank_cd(content_tsv, %s)', $tsQuery);

        return \sprintf(
            '%s AS (
                SELECT
                    id,
                    metadata,
                    %s,
                    %s AS %s,
                    ROW_NUMBER() OVER (
                        ORDER BY %s DESC
                    ) AS %s
                FROM %s
                WHERE content_tsv @@ %s
            )',
            self::CTE_ALIAS,
            $contentFieldName,
            $rankExpression,
            self::SCORE_COLUMN,
            $rankExpression,
            self::RANK_COLUMN,
            $tableName,
            $tsQuery,
        );
    }

    public function getCteAlias(): string
    {
        return self::CTE_ALIAS;
    }

    public function getRankColumn(): string
    {
        return self::RANK_COLUMN;
    }

    public function getScoreColumn(): string
    {
        return self::SCORE_COLUMN;
    }

    /**
     * Caps the score at 1.0, since ts_rank_cd values can exceed 1.
     */
    public function getNormalizedScoreExpression(string $scoreColumn): string
    {
        return 'LEAST('.$scoreColumn.', 1.0)';
    }

    /**
     * No additional extensions are required.
     */
    public function getRequiredExtensions(): array
    {
        return [];
    }

    /**
     * Always true: this strategy does not inspect the connection.
     */
    public function isAvailable(\PDO $connection): bool
    {
        return true;
    }
}
/**
 * Lists the PostgreSQL extensions this strategy depends on.
 *
 * An empty array means no additional extension is needed.
 *
 * @return string[] List of required extensions
 */
public function getRequiredExtensions(): array;
/**
 * A semantic ratio above 1.0 must be rejected by the constructor.
 */
public function testConstructorValidatesSemanticRatio()
{
    $connection = $this->createMock(\PDO::class);

    $this->expectException(InvalidArgumentException::class);
    $this->expectExceptionMessage('The semantic ratio must be between 0.0 and 1.0');

    new HybridStore($connection, 'test_table', semanticRatio: 1.5);
}
/**
 * setup() must issue exactly ten exec() calls, in a fixed order: extensions, table,
 * search_text column/function/trigger, vector index, text search strategy DDL and
 * finally the trigram index.
 */
public function testSetupCreatesTableWithFullTextSearchSupport()
{
    $pdo = $this->createMock(\PDO::class);
    $store = new HybridStore($pdo, 'hybrid_table');

    // Calls that must match a statement exactly, keyed by 1-based call number.
    $exactStatements = [
        1 => 'CREATE EXTENSION IF NOT EXISTS vector',
        2 => 'CREATE EXTENSION IF NOT EXISTS pg_trgm',
    ];

    // Calls that only need to contain the listed fragments.
    $fragmentsByCall = [
        3 => [
            'CREATE TABLE IF NOT EXISTS hybrid_table',
            'content TEXT NOT NULL',
            'embedding vector(1536) NOT NULL',
        ],
        4 => ['ALTER TABLE hybrid_table ADD COLUMN IF NOT EXISTS search_text TEXT'],
        5 => ['CREATE OR REPLACE FUNCTION update_search_text()'],
        6 => ['CREATE TRIGGER trigger_update_search_text'],
        7 => ['CREATE INDEX IF NOT EXISTS hybrid_table_embedding_idx'],
        // TextSearchStrategy adds content_tsv column via ALTER TABLE
        8 => [
            'ALTER TABLE hybrid_table ADD COLUMN IF NOT EXISTS content_tsv',
            "GENERATED ALWAYS AS (to_tsvector('simple', content)) STORED",
        ],
        // TextSearchStrategy creates GIN index for content_tsv
        9 => [
            'CREATE INDEX IF NOT EXISTS hybrid_table_content_tsv_idx',
            'USING gin(content_tsv)',
        ],
    ];

    // Every call after the ninth is checked against the trigram index statement.
    $trailingFragments = [
        'CREATE INDEX IF NOT EXISTS hybrid_table_search_text_trgm_idx',
        'USING gin(search_text gin_trgm_ops)',
    ];

    $calls = 0;
    $pdo->expects($this->exactly(10))
        ->method('exec')
        ->willReturnCallback(function (string $sql) use (&$calls, $exactStatements, $fragmentsByCall, $trailingFragments): int {
            ++$calls;

            if (isset($exactStatements[$calls])) {
                $this->assertSame($exactStatements[$calls], $sql);

                return 0;
            }

            foreach ($fragmentsByCall[$calls] ?? $trailingFragments as $fragment) {
                $this->assertStringContainsString($fragment, $sql);
            }

            return 0;
        });

    $store->setup();
}
/**
 * Adding two documents must prepare the statement once and execute it once per
 * document, in the order the documents were passed.
 */
public function testAddMultipleDocuments()
{
    $pdo = $this->createMock(\PDO::class);
    $statement = $this->createMock(\PDOStatement::class);

    $store = new HybridStore($pdo, 'hybrid_table');

    $pdo->expects($this->once())
        ->method('prepare')
        ->willReturn($statement);

    $firstId = Uuid::v4();
    $secondId = Uuid::v4();

    // Expected [id, content] pair per execute() call; calls after the first are
    // all compared against the second document.
    $expectedFirst = [$firstId->toRfc4122(), 'First document'];
    $expectedLater = [$secondId->toRfc4122(), 'Second document'];

    $executions = 0;
    $statement->expects($this->exactly(2))
        ->method('execute')
        ->willReturnCallback(function (array $params) use (&$executions, $expectedFirst, $expectedLater): bool {
            ++$executions;
            [$expectedId, $expectedContent] = 1 === $executions ? $expectedFirst : $expectedLater;

            $this->assertSame($expectedId, $params['id']);
            $this->assertSame($expectedContent, $params['content']);

            return true;
        });

    $store->add(
        new VectorDocument($firstId, new Vector([0.1, 0.2, 0.3]), new Metadata(['_text' => 'First document'])),
        new VectorDocument($secondId, new Vector([0.4, 0.5, 0.6]), new Metadata(['_text' => 'Second document'])),
    );
}
/**
 * With a semantic ratio of 0.0 and the BM25 strategy, the prepared SQL must use the
 * BM25 CTEs and must not use the native ts_rank_cd ranking.
 */
public function testPureKeywordSearchWithBm25Strategy()
{
    $pdo = $this->createMock(\PDO::class);
    $statement = $this->createMock(\PDOStatement::class);

    $store = new HybridStore(
        $pdo,
        'hybrid_table',
        semanticRatio: 0.0,
        textSearchStrategy: new Bm25TextSearchStrategy(bm25Language: 'en'),
        rrf: new ReciprocalRankFusion(normalizeScores: false),
    );

    // Verify BM25 structure
    $requiredFragments = ['WITH', 'bm25_search AS', 'bm25topk(', 'bm25_with_metadata AS', 'DISTINCT ON'];

    $pdo->expects($this->once())
        ->method('prepare')
        ->with($this->callback(function ($sql) use ($requiredFragments) {
            foreach ($requiredFragments as $fragment) {
                $this->assertStringContainsString($fragment, $sql);
            }

            // Should NOT contain native FTS functions
            $this->assertStringNotContainsString('ts_rank_cd', $sql);

            return true;
        }))
        ->willReturn($statement);

    $statement->expects($this->once())
        ->method('execute')
        ->with($this->callback(static fn ($params) => 'PostgreSQL' === ($params['query'] ?? null)));

    $row = [
        'id' => Uuid::v4()->toRfc4122(),
        'embedding' => null,
        'metadata' => json_encode(['text' => 'PostgreSQL is awesome']),
        'score' => 0.5,
    ];

    $statement->expects($this->once())
        ->method('fetchAll')
        ->with(\PDO::FETCH_ASSOC)
        ->willReturn([$row]);

    $results = $store->query(new Vector([0.1, 0.2, 0.3]), ['q' => 'PostgreSQL']);

    $this->assertCount(1, $results);
    $this->assertSame(0.5, $results[0]->score);
}
Uuid::v4(); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => $uuid->toRfc4122(), + 'embedding' => '[0.1,0.2,0.3]', + 'metadata' => json_encode(['text' => 'PostgreSQL database']), + 'score' => 0.025, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), ['q' => 'PostgreSQL', 'semanticRatio' => 0.5]); + + $this->assertCount(1, $results); + $this->assertSame(0.025, $results[0]->score); + } + + /** A store built with defaultMaxScore 0.8 must emit the "(embedding <-> :embedding) <= :maxScore" WHERE clause and bind maxScore = 0.8 when query() is called without options. */ public function testQueryWithDefaultMaxScore() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new HybridStore( + $pdo, + 'hybrid_table', + semanticRatio: 1.0, + defaultMaxScore: 0.8 + ); + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) { + $this->assertStringContainsString('WHERE (embedding <-> :embedding) <= :maxScore', $sql); + + return true; + })) + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute') + ->with($this->callback(function ($params) { + return isset($params['embedding']) + && isset($params['maxScore']) + && 0.8 === $params['maxScore']; + })); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3])); + + $this->assertCount(0, $results); + } + + /** A 'maxScore' query option (0.5) must be bound in place of the constructor default (0.8). */ public function testQueryWithMaxScoreOverride() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new HybridStore( + $pdo, + 'hybrid_table', + semanticRatio: 1.0, + defaultMaxScore: 0.8 + ); + + $pdo->expects($this->once()) + ->method('prepare') + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute') + ->with($this->callback(function ($params) { + // Should use override value 0.5, not default 0.8 + return isset($params['maxScore']) && 0.5 === $params['maxScore']; + })); + + 
$statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), ['maxScore' => 0.5]); + + $this->assertCount(0, $results); + } + + /** With defaultMinScore 0.5 and score normalization disabled, of the two mocked rows (scores 0.8 and 0.3) only the 0.8 row is expected in the results. */ public function testQueryWithMinScoreFilter() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $rrf = new ReciprocalRankFusion(normalizeScores: false); + $store = new HybridStore( + $pdo, + 'hybrid_table', + semanticRatio: 1.0, + rrf: $rrf, + defaultMinScore: 0.5 + ); + + $pdo->expects($this->once()) + ->method('prepare') + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $uuid1 = Uuid::v4(); + $uuid2 = Uuid::v4(); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => $uuid1->toRfc4122(), + 'embedding' => '[0.1,0.2,0.3]', + 'metadata' => json_encode(['text' => 'High score']), + 'score' => 0.8, + ], + [ + 'id' => $uuid2->toRfc4122(), + 'embedding' => '[0.4,0.5,0.6]', + 'metadata' => json_encode(['text' => 'Low score']), + 'score' => 0.3, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3])); + + // Only high score result should be returned + $this->assertCount(1, $results); + $this->assertSame(0.8, $results[0]->score); + } + + /** A ReciprocalRankFusion built with k = 100 must put the "100 +" constant into the prepared hybrid SQL. */ public function testQueryWithCustomRRFK() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $rrf = new ReciprocalRankFusion(k: 100); + $store = new HybridStore($pdo, 'hybrid_table', semanticRatio: 0.5, rrf: $rrf); + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) { + // Check for RRF constant 100 in the formula + $this->assertStringContainsString('100 +', $sql); + + return true; + })) + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $statement->expects($this->once()) + ->method('fetchAll') + 
->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $store->query(new Vector([0.1, 0.2, 0.3]), ['q' => 'test']); + } + + /** A 'semanticRatio' query option of 1.5 must raise InvalidArgumentException with the "must be between 0.0 and 1.0" message. */ public function testQueryInvalidSemanticRatioInOptions() + { + $pdo = $this->createMock(\PDO::class); + $store = new HybridStore($pdo, 'hybrid_table'); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('The semantic ratio must be between 0.0 and 1.0'); + + $store->query(new Vector([0.1, 0.2, 0.3]), ['semanticRatio' => 1.5]); + } + + /** drop() must issue exactly one exec() call carrying "DROP TABLE IF EXISTS hybrid_table". */ public function testDrop() + { + $pdo = $this->createMock(\PDO::class); + $store = new HybridStore($pdo, 'hybrid_table'); + + $pdo->expects($this->once()) + ->method('exec') + ->with('DROP TABLE IF EXISTS hybrid_table'); + + $store->drop(); + } + + /** A 'limit' query option of 10 must show up as "LIMIT 10" in the prepared SQL. */ public function testQueryWithCustomLimit() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new HybridStore($pdo, 'hybrid_table', semanticRatio: 1.0); + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) { + $this->assertStringContainsString('LIMIT 10', $sql); + + return true; + })) + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $store->query(new Vector([0.1, 0.2, 0.3]), ['limit' => 10]); + } + + /** With semanticRatio 0.0, an empty fetchAll() result must come back from query() as zero results. */ public function testPureKeywordSearchReturnsEmptyWhenNoMatch() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new HybridStore($pdo, 'hybrid_table', semanticRatio: 0.0); + + $pdo->expects($this->once()) + ->method('prepare') + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), ['q' => 'zzzzzzzzzzzzz']); + + 
$this->assertCount(0, $results); + } + + /** The fuzzy thresholds given to the constructor (0.3, 0.25, 0.2) must appear in the prepared SQL formatted with six decimals, together with a fuzzy_scores CTE and the word_similarity(:query, search_text) call. */ public function testFuzzyMatchingWithWordSimilarity() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new HybridStore( + $pdo, + 'hybrid_table', + semanticRatio: 0.5, + fuzzyWeight: 0.3, + fuzzyPrimaryThreshold: 0.3, + fuzzySecondaryThreshold: 0.25, + fuzzyStrictThreshold: 0.2 + ); + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) { + // Verify fuzzy_scores CTE exists + $this->assertStringContainsString('fuzzy_scores AS', $sql); + + // Verify word_similarity function is used + $this->assertStringContainsString('word_similarity(:query, search_text)', $sql); + + // Verify custom thresholds are applied + $this->assertStringContainsString('0.300000', $sql); + $this->assertStringContainsString('0.250000', $sql); + $this->assertStringContainsString('0.200000', $sql); + + return true; + })) + ->willReturn($statement); + + $statement->expects($this->once())->method('execute'); + $statement->expects($this->once())->method('fetchAll')->willReturn([]); + + $store->query(new Vector([0.1, 0.2, 0.3]), ['q' => 'test']); + } + + /** setup() with two searchableAttributes is expected to call exec() exactly 10 times: the 3rd statement must declare the title_tsv and overview_tsv generated columns (and no generic content_tsv column), the 8th and 9th must create GIN indexes. NOTE(review): the checks are keyed on the position of each statement, so reordering the statements issued by setup() breaks this test even if the resulting schema is the same. */ public function testSearchableAttributesWithBoost() + { + $pdo = $this->createMock(\PDO::class); + + $searchableAttributes = [ + 'title' => ['boost' => 2.0, 'metadata_key' => 'title'], + 'overview' => ['boost' => 1.0, 'metadata_key' => 'overview'], + ]; + + $store = new HybridStore( + $pdo, + 'hybrid_table', + searchableAttributes: $searchableAttributes + ); + + $pdo->expects($this->exactly(10)) + ->method('exec') + ->willReturnCallback(function (string $sql): int { + static $callCount = 0; + ++$callCount; + + if (3 === $callCount) { + // Verify separate tsvector columns for each attribute + $this->assertStringContainsString('title_tsv tsvector GENERATED ALWAYS AS', $sql); + $this->assertStringContainsString('overview_tsv tsvector GENERATED ALWAYS AS', $sql); + + // Should NOT contain generic content_tsv + 
$this->assertStringNotContainsString('content_tsv tsvector GENERATED ALWAYS AS (to_tsvector(\'simple\', content)) STORED', $sql); + } elseif ($callCount >= 8 && $callCount <= 9) { + // Verify separate GIN indexes + $this->assertStringContainsString('_tsv_idx', $sql); + $this->assertStringContainsString('USING gin(', $sql); + } + + return 0; + }); + + $store->setup(); + } + + /** With semanticRatio 0.4 and fuzzyWeight 0.5, the prepared SQL must contain the fuzzy_scores and combined_results CTEs and the "COALESCE(1.0 / (" term. NOTE(review): none of the assertions checks that the 0.5 weight itself reaches the SQL. */ public function testFuzzyWeightParameter() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new HybridStore( + $pdo, + 'hybrid_table', + semanticRatio: 0.4, + fuzzyWeight: 0.5 + ); + + $pdo->expects($this->once()) + ->method('prepare') + ->with($this->callback(function ($sql) { + $this->assertStringContainsString('fuzzy_scores AS', $sql); + $this->assertStringContainsString('combined_results AS', $sql); + $this->assertStringContainsString('COALESCE(1.0 / (', $sql); + + return true; + })) + ->willReturn($statement); + + $statement->expects($this->once())->method('execute'); + $statement->expects($this->once())->method('fetchAll')->willReturn([]); + + $store->query(new Vector([0.1, 0.2, 0.3]), ['q' => 'test']); + } + + /** A 'boostFields' option (popularity: min 50, boost 0.5) must raise the 0.5 score of the popularity-100 row to 0.75 and record '_applied_boosts' in its metadata, while the popularity-10 row keeps its 0.6 score. */ public function testBoostFieldsApplied() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $rrf = new ReciprocalRankFusion(normalizeScores: false); + $store = new HybridStore($pdo, 'hybrid_table', semanticRatio: 1.0, rrf: $rrf); + + $pdo->expects($this->once()) + ->method('prepare') + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $uuid1 = Uuid::v4(); + $uuid2 = Uuid::v4(); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => $uuid1->toRfc4122(), + 'embedding' => '[0.1,0.2,0.3]', + 'metadata' => json_encode(['text' => 'Popular', 'popularity' => 100]), + 'score' => 0.5, + ], + [ + 'id' => $uuid2->toRfc4122(), + 'embedding' => '[0.4,0.5,0.6]', + 
'metadata' => json_encode(['text' => 'Unpopular', 'popularity' => 10]), + 'score' => 0.6, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), [ + 'boostFields' => [ + 'popularity' => ['min' => 50, 'boost' => 0.5], + ], + ]); + + $this->assertCount(2, $results); + + // First result should be boosted (popularity >= 50) + // Original score 0.5 * 1.5 = 0.75 + $this->assertSame(0.75, $results[0]->score); + $this->assertArrayHasKey('_applied_boosts', $results[0]->metadata->getArrayCopy()); + + // Second result should not be boosted (popularity < 50) + $this->assertSame(0.6, $results[1]->score); + } + + /** With 'includeScoreBreakdown' => true, the rank and score columns of the mocked row (vector, fts and fuzzy) must be exposed under the '_score_breakdown' metadata key. */ public function testScoreBreakdownIncluded() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $rrf = new ReciprocalRankFusion(normalizeScores: false); + $store = new HybridStore($pdo, 'hybrid_table', semanticRatio: 0.5, rrf: $rrf); + + $pdo->expects($this->once()) + ->method('prepare') + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $uuid = Uuid::v4(); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => $uuid->toRfc4122(), + 'embedding' => '[0.1,0.2,0.3]', + 'metadata' => json_encode(['text' => 'Test']), + 'score' => 0.025, + 'vector_rank' => 1, + 'fts_rank' => 2, + 'vector_distance' => 0.1, + 'fts_score' => 0.8, + 'vector_contribution' => 0.015, + 'fts_contribution' => 0.01, + 'fuzzy_rank' => 3, + 'fuzzy_score' => 0.7, + 'fuzzy_contribution' => 0.005, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), [ + 'q' => 'test', + 'includeScoreBreakdown' => true, + ]); + + $this->assertCount(1, $results); + + $metadata = $results[0]->metadata->getArrayCopy(); + $this->assertArrayHasKey('_score_breakdown', $metadata); + + $breakdown = $metadata['_score_breakdown']; + $this->assertSame(1, $breakdown['vector_rank']); + $this->assertSame(2, $breakdown['fts_rank']); + 
$this->assertSame(3, $breakdown['fuzzy_rank']); + $this->assertSame(0.1, $breakdown['vector_distance']); + $this->assertSame(0.8, $breakdown['fts_score']); + $this->assertSame(0.7, $breakdown['fuzzy_score']); + } + + /** A row whose 'embedding' column is null must be returned carrying a NullVector. */ public function testNullVectorForFtsOnlyResults() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + $store = new HybridStore($pdo, 'hybrid_table', semanticRatio: 0.0); + + $pdo->expects($this->once()) + ->method('prepare') + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $uuid = Uuid::v4(); + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => $uuid->toRfc4122(), + 'embedding' => null, + 'metadata' => json_encode(['text' => 'FTS only result']), + 'score' => 0.5, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3]), ['q' => 'FTS']); + + $this->assertCount(1, $results); + $this->assertInstanceOf(NullVector::class, $results[0]->vector); + } + + /** With normalizeScores enabled, the score returned by query() must match $rrf->normalize() applied to the raw score of the mocked row, within a 0.01 tolerance. */ public function testScoreNormalization() + { + $pdo = $this->createMock(\PDO::class); + $statement = $this->createMock(\PDOStatement::class); + + // Enable normalization (default) + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: true); + $store = new HybridStore($pdo, 'hybrid_table', semanticRatio: 1.0, rrf: $rrf); + + $pdo->expects($this->once()) + ->method('prepare') + ->willReturn($statement); + + $statement->expects($this->once()) + ->method('execute'); + + $uuid = Uuid::v4(); + + // Raw RRF score + $rawScore = 0.01639; // Approximately 1/(60+1) = theoretical max + + $statement->expects($this->once()) + ->method('fetchAll') + ->with(\PDO::FETCH_ASSOC) + ->willReturn([ + [ + 'id' => $uuid->toRfc4122(), + 'embedding' => '[0.1,0.2,0.3]', + 'metadata' => json_encode(['text' => 'Test']), + 'score' => $rawScore, + ], + ]); + + $results = $store->query(new Vector([0.1, 0.2, 0.3])); + + $this->assertCount(1, $results); + + // Score 
should be normalized to approximately 100 + $expectedNormalized = $rrf->normalize($rawScore); + $this->assertEqualsWithDelta($expectedNormalized, $results[0]->score, 0.01); + } + + /** Collapses every whitespace run in $query to a single space and trims the result. NOTE(review): preg_replace() returns null on a PCRE error, and passing null to trim() is deprecated since PHP 8.1 (a TypeError under strict_types) - confirm whether a guard is wanted. */ private function normalizeQuery(string $query): string + { + $normalized = preg_replace('/\s+/', ' ', $query); + + return trim($normalized); + } +} diff --git a/src/store/tests/Bridge/Postgres/ReciprocalRankFusionTest.php b/src/store/tests/Bridge/Postgres/ReciprocalRankFusionTest.php new file mode 100644 index 000000000..e6b588684 --- /dev/null +++ b/src/store/tests/Bridge/Postgres/ReciprocalRankFusionTest.php @@ -0,0 +1,221 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests\Bridge\Postgres; + +use PHPUnit\Framework\TestCase; +use Symfony\AI\Store\Bridge\Postgres\ReciprocalRankFusion; + +/** Unit tests for ReciprocalRankFusion: constructor defaults, score calculation, normalization round-trips and SQL expression building. */ final class ReciprocalRankFusionTest extends TestCase +{ + /** Without arguments, k is 60 and normalization is enabled. */ public function testDefaultConstruction() + { + $rrf = new ReciprocalRankFusion(); + + $this->assertSame(60, $rrf->getK()); + $this->assertTrue($rrf->isNormalized()); + } + + /** The k and normalizeScores constructor arguments are reported back by getK() and isNormalized(). */ public function testCustomConstruction() + { + $rrf = new ReciprocalRankFusion(k: 100, normalizeScores: false); + + $this->assertSame(100, $rrf->getK()); + $this->assertFalse($rrf->isNormalized()); + } + + /** A single ranking at rank 1 with score 1.0 and weight 1.0 must yield 1 / (60 + 1) when normalization is off. */ public function testCalculateSingleRanking() + { + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: false); + + $score = $rrf->calculate([ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 1.0], + ]); + + // 1/(60+1) * 1.0 * 1.0 = 0.01639... 
+ $this->assertEqualsWithDelta(1 / 61, $score, 0.0001); + } + + /** Contributions of several rankings are summed; each one is expected to equal 1 / (k + rank) * score * weight. */ public function testCalculateMultipleRankings() + { + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: false); + + $score = $rrf->calculate([ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 0.5], + 'fts' => ['rank' => 2, 'score' => 0.8, 'weight' => 0.5], + ]); + + // (1/(60+1) * 1.0 * 0.5) + (1/(60+2) * 0.8 * 0.5) + $expected = (1 / 61 * 1.0 * 0.5) + (1 / 62 * 0.8 * 0.5); + $this->assertEqualsWithDelta($expected, $score, 0.0001); + } + + /** A ranking whose rank is null must contribute nothing to the total. */ public function testCalculateSkipsNullRank() + { + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: false); + + $score = $rrf->calculate([ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 0.5], + 'fts' => ['rank' => null, 'score' => 0.8, 'weight' => 0.5], + ]); + + // Only vector contribution + $expected = 1 / 61 * 1.0 * 0.5; + $this->assertEqualsWithDelta($expected, $score, 0.0001); + } + + /** With normalization on, the rank-1 / score 1.0 / weight 1.0 case must come out at roughly 100. */ public function testCalculateWithNormalization() + { + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: true); + + $score = $rrf->calculate([ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 1.0], + ]); + + // Should be normalized to ~100 (since rank=1 with full score/weight gives max RRF) + $this->assertEqualsWithDelta(100.0, $score, 0.01); + } + + /** calculateContribution(rank: 1, score: 1.0, weight: 0.5) must equal (1 / 61) * 1.0 * 0.5 for k = 60. */ public function testCalculateContribution() + { + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: false); + + $contribution = $rrf->calculateContribution(rank: 1, score: 1.0, weight: 0.5); + + $expected = (1 / 61) * 1.0 * 0.5; + $this->assertEqualsWithDelta($expected, $contribution, 0.0001); + } + + /** normalize() must map the raw score 1 / 61 (k = 60) to roughly 100. */ public function testNormalize() + { + $rrf = new ReciprocalRankFusion(k: 60); + + $maxRawScore = 1 / 61; // Theoretical maximum + $normalized = $rrf->normalize($maxRawScore); + + $this->assertEqualsWithDelta(100.0, $normalized, 0.01); + } + + /** denormalize(100.0) must give back roughly 1 / 61 for k = 60. */ public function testDenormalize() + { + $rrf = new ReciprocalRankFusion(k: 60); + + $denormalized = $rrf->denormalize(100.0); + + 
$this->assertEqualsWithDelta(1 / 61, $denormalized, 0.0001); + } + + /** denormalize(normalize(x)) must return x (checked for 0.008, tolerance 0.0001). */ public function testNormalizeAndDenormalizeAreInverse() + { + $rrf = new ReciprocalRankFusion(k: 60); + + $original = 0.008; + $normalized = $rrf->normalize($original); + $denormalized = $rrf->denormalize($normalized); + + $this->assertEqualsWithDelta($original, $denormalized, 0.0001); + } + + /** buildSqlExpression() must produce a COALESCE(1.0 / (60 + rankColumn) ... , 0.0) term that includes the score expression and the weight printed with six decimals. */ public function testBuildSqlExpression() + { + $rrf = new ReciprocalRankFusion(k: 60); + + $sql = $rrf->buildSqlExpression( + rankColumn: 'v.rank_ix', + scoreExpr: '(1.0 - v.distance)', + weight: 0.7 + ); + + $this->assertStringContainsString('COALESCE(1.0 / (60 + v.rank_ix)', $sql); + $this->assertStringContainsString('(1.0 - v.distance)', $sql); + $this->assertStringContainsString('0.700000', $sql); + $this->assertStringContainsString(', 0.0)', $sql); + } + + /** A custom nullDefault ('-1.0') must replace the default fallback inside the generated expression. */ public function testBuildSqlExpressionWithCustomNullDefault() + { + $rrf = new ReciprocalRankFusion(k: 60); + + $sql = $rrf->buildSqlExpression( + rankColumn: 'rank', + scoreExpr: 'score', + weight: 1.0, + nullDefault: '-1.0' + ); + + $this->assertStringContainsString(', -1.0)', $sql); + } + + /** buildCombinedSqlExpression() must join the per-ranking terms with " + " and reference each rank column after the k constant. */ public function testBuildCombinedSqlExpression() + { + $rrf = new ReciprocalRankFusion(k: 60); + + $sql = $rrf->buildCombinedSqlExpression([ + ['rank_column' => 'v.rank', 'score_expr' => 'v.score', 'weight' => 0.5], + ['rank_column' => 'f.rank', 'score_expr' => 'f.score', 'weight' => 0.5], + ]); + + $this->assertStringContainsString('(', $sql); + $this->assertStringContainsString(' + ', $sql); + $this->assertStringContainsString('60 + v.rank', $sql); + $this->assertStringContainsString('60 + f.rank', $sql); + } + + /** For the same ranking, k = 60 must score higher than k = 100 (1 / 61 versus 1 / 101). */ public function testDifferentKValues() + { + $rrf60 = new ReciprocalRankFusion(k: 60, normalizeScores: false); + $rrf100 = new ReciprocalRankFusion(k: 100, normalizeScores: false); + + $rankings = [ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 1.0], + ]; + + $score60 = $rrf60->calculate($rankings); + $score100 = 
$rrf100->calculate($rankings); + + // Higher k means lower individual contributions + $this->assertGreaterThan($score100, $score60); + + // Verify exact values + $this->assertEqualsWithDelta(1 / 61, $score60, 0.0001); + $this->assertEqualsWithDelta(1 / 101, $score100, 0.0001); + } + + /** Halving the weight of the only ranking must halve the calculated score. */ public function testWeightAffectsScore() + { + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: false); + + $scoreFullWeight = $rrf->calculate([ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 1.0], + ]); + + $scoreHalfWeight = $rrf->calculate([ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 0.5], + ]); + + $this->assertEqualsWithDelta($scoreFullWeight / 2, $scoreHalfWeight, 0.0001); + } + + /** A result at rank 10 must score lower than the same result at rank 1. */ public function testLowerRankGivesLowerScore() + { + $rrf = new ReciprocalRankFusion(k: 60, normalizeScores: false); + + $scoreRank1 = $rrf->calculate([ + 'vector' => ['rank' => 1, 'score' => 1.0, 'weight' => 1.0], + ]); + + $scoreRank10 = $rrf->calculate([ + 'vector' => ['rank' => 10, 'score' => 1.0, 'weight' => 1.0], + ]); + + $this->assertGreaterThan($scoreRank10, $scoreRank1); + } +}