diff --git a/infrastructure/README.md b/infrastructure/README.md index 887e7ea5..a6e5ba0f 100644 --- a/infrastructure/README.md +++ b/infrastructure/README.md @@ -291,7 +291,8 @@ backend: retriever: RETRIEVER_THRESHOLD: 0.3 RETRIEVER_K_DOCUMENTS: 10 - RETRIEVER_TOTAL_K: 7 + # Canonical global cap (previously RETRIEVER_TOTAL_K / RETRIEVER_OVERALL_K_DOCUMENTS) + RETRIEVER_TOTAL_K_DOCUMENTS: 7 RETRIEVER_SUMMARY_THRESHOLD: 0.3 RETRIEVER_SUMMARY_K_DOCUMENTS: 10 RETRIEVER_TABLE_THRESHOLD: 0.3 @@ -489,7 +490,7 @@ Afterwards, the services are accessible from [http://rag.localhost](http://rag.l Note: The command above has only been tested on *Ubuntu 22.04 LTS*. -On *Windows* you can adjust the hosts file as described [here](https://docs.digitalocean.com/products/paperspace/machines/how-to/edit-windows-hosts-file/). +On *Windows* you can adjust the hosts file as described in the DigitalOcean guide on [editing the Windows hosts file](https://docs.digitalocean.com/products/paperspace/machines/how-to/edit-windows-hosts-file/). ### 2.2 Production Setup Instructions @@ -499,7 +500,7 @@ For deployment of the *NGINX Ingress Controller* and a cert-manager, the followi [base-setup](server-setup/base-setup/Chart.yaml) -The email [here](server-setup/base-setup/templates/cert-issuer.yaml) should be changed from `` to a real email address. +The email in the [cert-issuer template](server-setup/base-setup/templates/cert-issuer.yaml) should be changed from `` to a real email address. ## 3. Contributing diff --git a/infrastructure/rag/values.yaml b/infrastructure/rag/values.yaml index c8d95a29..78226195 100644 --- a/infrastructure/rag/values.yaml +++ b/infrastructure/rag/values.yaml @@ -167,7 +167,8 @@ backend: retriever: RETRIEVER_THRESHOLD: 0.3 RETRIEVER_K_DOCUMENTS: 10 - RETRIEVER_TOTAL_K: 7 + # Canonical global cap across all retrievers. 
Replaces legacy RETRIEVER_TOTAL_K / RETRIEVER_OVERALL_K_DOCUMENTS + RETRIEVER_TOTAL_K_DOCUMENTS: 7 RETRIEVER_SUMMARY_THRESHOLD: 0.3 RETRIEVER_SUMMARY_K_DOCUMENTS: 10 RETRIEVER_TABLE_THRESHOLD: 0.3 @@ -220,6 +221,7 @@ backend: reranker: RERANKER_K_DOCUMENTS: 5 RERANKER_MIN_RELEVANCE_SCORE: 0.001 + RERANKER_ENABLED: true chatHistory: CHAT_HISTORY_LIMIT: 4 CHAT_HISTORY_REVERSE: true diff --git a/libs/rag-core-api/src/rag_core_api/dependency_container.py b/libs/rag-core-api/src/rag_core_api/dependency_container.py index 181868ef..3f6f2ef2 100644 --- a/libs/rag-core-api/src/rag_core_api/dependency_container.py +++ b/libs/rag-core-api/src/rag_core_api/dependency_container.py @@ -130,7 +130,9 @@ class DependencyContainer(DeclarativeContainer): vectorstore=vectorstore, ) - flashrank_reranker = Singleton(FlashrankRerank, top_n=reranker_settings.k_documents) + flashrank_reranker = Singleton( + FlashrankRerank, top_n=reranker_settings.k_documents, score_threshold=reranker_settings.min_relevance_score + ) reranker = Singleton(FlashrankReranker, flashrank_reranker) information_pieces_uploader = Singleton(DefaultInformationPiecesUploader, vector_database) @@ -170,6 +172,9 @@ class DependencyContainer(DeclarativeContainer): CompositeRetriever, List(image_retriever, table_retriever, text_retriever, summary_retriever), reranker, + reranker_settings.enabled, + retriever_settings.total_k_documents, + reranker_settings.k_documents, ) information_piece_mapper = Singleton(InformationPieceMapper) diff --git a/libs/rag-core-api/src/rag_core_api/impl/graph/chat_graph.py b/libs/rag-core-api/src/rag_core_api/impl/graph/chat_graph.py index 7d65c978..fa087375 100644 --- a/libs/rag-core-api/src/rag_core_api/impl/graph/chat_graph.py +++ b/libs/rag-core-api/src/rag_core_api/impl/graph/chat_graph.py @@ -269,6 +269,13 @@ async def _retrieve_node(self, state: dict) -> dict: if document.metadata.get("type", ContentType.SUMMARY.value) != ContentType.SUMMARY.value ] + # If only summaries were retrieved 
(no concrete underlying documents), treat as "no documents" + if not information_pieces: + return { + self.ERROR_MESSAGES_KEY: [self._error_messages.no_documents_message], + self.FINISH_REASONS: ["No documents found"], + } + response["information_pieces"] = information_pieces response["langchain_documents"] = retrieved_documents diff --git a/libs/rag-core-api/src/rag_core_api/impl/retriever/composite_retriever.py b/libs/rag-core-api/src/rag_core_api/impl/retriever/composite_retriever.py index ed7501e7..9ab1bcc7 100644 --- a/libs/rag-core-api/src/rag_core_api/impl/retriever/composite_retriever.py +++ b/libs/rag-core-api/src/rag_core_api/impl/retriever/composite_retriever.py @@ -1,8 +1,19 @@ -"""Module for the CompositeRetriever class.""" +"""Module for the CompositeRetriever class. + +Performance notes / improvements (2025-10): + - Retriever invocations are now executed concurrently via ``asyncio.gather`` instead of + sequential awaits inside a for-loop. This reduces end-to-end latency roughly to the + slowest individual retriever call instead of the sum of all. + - Duplicate filtering now uses an O(1) set membership check instead of rebuilding a list + comprehension for every candidate (previously O(n^2)). + - Early pruning: when ``total_retrieved_k_documents`` is provided, the merged candidate list + is trimmed to that many documents before an optional reranker is invoked. +""" import logging +import asyncio from copy import deepcopy -from typing import Any, Optional +from typing import Any, Optional, Iterable from langchain_core.documents import Document from langchain_core.runnables import RunnableConfig @@ -22,6 +33,9 @@ def __init__( self, retrievers: list[RetrieverQuark], reranker: Optional[Reranker], + reranker_enabled: bool, + total_retrieved_k_documents: int | None = None, + reranker_k_documents: int | None = None, **kwargs, ): """ @@ -33,12 +47,22 @@ def __init__( A list of retriever quarks to be used by the composite retriever. 
reranker : Optional[Reranker] An optional reranker to rerank the retrieved results. + reranker_enabled : bool + A flag indicating whether the reranker is enabled. + total_retrieved_k_documents : int | None + The total number of documents to retrieve (default None, meaning no limit). + reranker_k_documents : int | None + The number of documents to retrieve for the reranker (default None, meaning no limit). **kwargs : dict Additional keyword arguments to be passed to the superclass initializer. """ super().__init__(**kwargs) self._reranker = reranker self._retrievers = retrievers + # Optional global cap (before reranking) on merged candidates. If None, no cap applied. + self._total_retrieved_k_documents = total_retrieved_k_documents + self._reranker_k_documents = reranker_k_documents + self._reranker_enabled = reranker_enabled def verify_readiness(self) -> None: """ @@ -67,7 +91,7 @@ async def ainvoke( retriever_input : str The input string to be processed by the retrievers. config : Optional[RunnableConfig] - Configuration for the retrievers (default None). + Configuration for the retrievers and reranker (default None). **kwargs : Any Additional keyword arguments. @@ -83,24 +107,148 @@ async def ainvoke( - Duplicate entries are removed based on their metadata ID. - If a reranker is available, the results are further processed by the reranker. """ - results = [] if config is None: config = RunnableConfig(metadata={"filter_kwargs": {}}) - for retriever in self._retrievers: - tmp_config = deepcopy(config) - results += await retriever.ainvoke(retriever_input, config=tmp_config) - # remove summaries - results = [x for x in results if x.metadata["type"] != ContentType.SUMMARY.value] + # Run all retrievers concurrently instead of sequentially. 
+ tasks = [r.ainvoke(retriever_input, config=deepcopy(config)) for r in self._retrievers] + retriever_outputs = await asyncio.gather(*tasks, return_exceptions=False) + # Flatten + results: list[Document] = [doc for group in retriever_outputs for doc in group] + + summary_docs: list[Document] = [d for d in results if d.metadata.get("type") == ContentType.SUMMARY.value] + + results = self._use_summaries(summary_docs, results) + + return_val = self._remove_duplicates(results) + + return_val = self._early_pruning(return_val) + + return await self._arerank_pruning(return_val, retriever_input, config) + + def _use_summaries(self, summary_docs: list[Document], results: list[Document]) -> list[Document]: + """Utilize summary documents to enhance retrieval results. + + Parameters + ---------- + summary_docs : list[Document] + A list of summary documents to use. + results : list[Document] + A list of retrieval results to enhance. + + Returns + ------- + list[Document] + The enhanced list of documents. + """ + try: + # Collect existing ids for fast membership tests + existing_ids: set[str] = {d.metadata.get("id") for d in results} + + # Gather related ids not yet present + + missing_related_ids: set[str] = set() + for sdoc in summary_docs: + related_list: Iterable[str] = sdoc.metadata.get("related", []) + [missing_related_ids.add(rid) for rid in related_list if rid and rid not in existing_ids] + + if missing_related_ids: + # Heuristic: use the first retriever's underlying vector database for lookup. + # All quarks share the same vector database instance in current design. + vector_db = None + if self._retrievers: + # Access protected member as an implementation detail – acceptable within package. 
+ vector_db = getattr(self._retrievers[0], "_vector_database", None) + if vector_db and hasattr(vector_db, "get_documents_by_ids"): + try: + expanded_docs: list[Document] = vector_db.get_documents_by_ids(list(missing_related_ids)) + # Merge while preserving original order precedence (append new ones) + results.extend(expanded_docs) + existing_ids.update(d.metadata.get("id") for d in expanded_docs) + logger.debug( + "Summary expansion added %d underlying documents (from %d summaries).", + len(expanded_docs), + len(summary_docs), + ) + except Exception: + logger.exception("Failed to expand summary related documents.") + else: + logger.debug("Vector database does not expose get_documents_by_ids; skipping summary expansion.") + finally: + # Remove summaries after expansion step + results = [x for x in results if x.metadata.get("type") != ContentType.SUMMARY.value] + return results + + def _remove_duplicates(self, documents: list[Document]) -> list[Document]: + """Remove duplicate documents from a list based on their IDs. + + Parameters + ---------- + documents : list[Document] + The list of documents to filter. + + Returns + ------- + list[Document] + The filtered list of documents with duplicates removed. + """ + seen_ids = set() + unique_docs = [] + for doc in documents: + doc_id = doc.metadata.get("id") + if doc_id not in seen_ids: + seen_ids.add(doc_id) + unique_docs.append(doc) + return unique_docs + + def _early_pruning(self, documents: list[Document]) -> list[Document]: + """Prune documents early based on certain criteria. - # remove duplicated entries - return_val = [] - for result in results: - if result.metadata["id"] in [x.metadata["id"] for x in return_val]: - continue - return_val.append(result) + Parameters + ---------- + documents : list[Document] + The list of documents to prune. 
- if self._reranker and results: - return_val = await self._reranker.ainvoke((return_val, retriever_input), config=config) + Returns + ------- + list[Document] + The pruned list of documents. + """ + # Optional early global pruning (only if configured and more than total_k) + if self._total_retrieved_k_documents is not None and len(documents) > self._total_retrieved_k_documents: + # If score metadata exists, use it to prune; otherwise keep ordering as-is. + if all("score" in d.metadata for d in documents): + documents = sorted(documents, key=lambda d: d.metadata["score"], reverse=True) + return documents[: self._total_retrieved_k_documents] + return documents + + async def _arerank_pruning( + self, documents: list[Document], retriever_input: str, config: Optional[RunnableConfig] = None + ) -> list[Document]: + """Prune documents by reranker. + + Parameters + ---------- + documents : list[Document] + The list of documents to prune. + retriever_input : str + The query string forwarded to the reranker. + config : Optional[RunnableConfig] + Configuration for the retrievers and reranker (default None). - return return_val + Returns + ------- + list[Document] + The pruned list of documents. + """ + if ( + self._reranker_k_documents is not None + and len(documents) > self._reranker_k_documents + and self._reranker_enabled and self._reranker is not None + ): + # Only invoke the reranker when one is configured, enabled, and there are more docs than it will output. 
+ try: + documents = await self._reranker.ainvoke((documents, retriever_input), config=config) + except Exception: # pragma: no cover - fail soft; return unreranked if reranker errors + logger.exception("Reranker failed; returning unreranked results.") + return documents diff --git a/libs/rag-core-api/src/rag_core_api/impl/settings/reranker_settings.py b/libs/rag-core-api/src/rag_core_api/impl/settings/reranker_settings.py index 2af6fd86..3b0df436 100644 --- a/libs/rag-core-api/src/rag_core_api/impl/settings/reranker_settings.py +++ b/libs/rag-core-api/src/rag_core_api/impl/settings/reranker_settings.py @@ -12,6 +12,10 @@ class RerankerSettings(BaseSettings): ---------- k_documents : int The number of documents to return after reranking (default 5). + min_relevance_score : float + Minimum relevance threshold to return (default 0.001). + enabled : bool + A flag indicating whether the reranker is enabled (default True). """ class Config: @@ -21,3 +25,5 @@ class Config: case_sensitive = False k_documents: int = Field(default=5) + min_relevance_score: float = Field(default=0.001) + enabled: bool = Field(default=True) diff --git a/libs/rag-core-api/src/rag_core_api/impl/settings/retriever_settings.py b/libs/rag-core-api/src/rag_core_api/impl/settings/retriever_settings.py index fa2a3250..29e5ac41 100644 --- a/libs/rag-core-api/src/rag_core_api/impl/settings/retriever_settings.py +++ b/libs/rag-core-api/src/rag_core_api/impl/settings/retriever_settings.py @@ -1,6 +1,14 @@ -"""Module that contains settings regarding the retriever.""" +"""Module that contains settings regarding the retriever. -from pydantic import Field +Notes +----- +`total_k_documents` is the canonical global cap across all retrievers. It replaces the +previous environment variable names `RETRIEVER_TOTAL_K` and `RETRIEVER_OVERALL_K_DOCUMENTS`. +For backward compatibility, those legacy names are still accepted if the canonical +`RETRIEVER_TOTAL_K_DOCUMENTS` is not set. 
+""" + +from pydantic import Field, AliasChoices from pydantic_settings import BaseSettings @@ -11,7 +19,7 @@ class RetrieverSettings(BaseSettings): The threshold value for the retriever (default 0.5). k_documents : int The number of documents to retrieve (default 10). - total_k : int + total_k_documents : int The total number of documents (default 10). table_threshold : float The threshold value for table retrieval (default 0.37). @@ -35,10 +43,19 @@ class Config: threshold: float = Field(default=0.5) k_documents: int = Field(default=10) - total_k: int = Field(default=10) table_threshold: float = Field(default=0.37) table_k_documents: int = Field(default=10) summary_threshold: float = Field(default=0.5) summary_k_documents: int = Field(default=10) image_threshold: float = Field(default=0.5) image_k_documents: int = Field(default=10) + # Canonical global cap (previously RETRIEVER_TOTAL_K / RETRIEVER_OVERALL_K_DOCUMENTS). + # Legacy names are fallbacks. Aliases bypass env_prefix, so full env var names are spelled out. + total_k_documents: int = Field( + default=10, + validation_alias=AliasChoices( + "RETRIEVER_TOTAL_K_DOCUMENTS", # canonical + "RETRIEVER_TOTAL_K", # legacy + "RETRIEVER_OVERALL_K_DOCUMENTS", # legacy + ), + ) diff --git a/libs/rag-core-api/src/rag_core_api/impl/vector_databases/qdrant_database.py b/libs/rag-core-api/src/rag_core_api/impl/vector_databases/qdrant_database.py index 5a123d32..f97df6b8 100644 --- a/libs/rag-core-api/src/rag_core_api/impl/vector_databases/qdrant_database.py +++ b/libs/rag-core-api/src/rag_core_api/impl/vector_databases/qdrant_database.py @@ -155,6 +155,30 @@ def get_specific_document(self, document_id: str) -> list[Document]: for search_result in requested[0] ] + def get_documents_by_ids(self, document_ids: list[str]) -> list[Document]: + """Batch fetch multiple documents by their IDs. + + Parameters + ---------- + document_ids : list[str] + A list of document IDs to retrieve. 
+ + Returns + ------- + list[Document] + A list of found documents. Missing IDs are ignored. + """ + if not document_ids: + return [] + # Look up each id individually and concatenate the matches in request order. + # This issues one get_specific_document call per id, which is simple but not batched. + # TODO: replace with a single filtered query (OR / match-any over the id field) + # if summary expansion ever requests many ids at once. + results: list[Document] = [] + for doc_id in document_ids: + results.extend(self.get_specific_document(doc_id)) + return results + def upload(self, documents: list[Document]) -> None: """ Save the given documents to the Qdrant database. diff --git a/libs/rag-core-api/tests/composite_retriever_test.py b/libs/rag-core-api/tests/composite_retriever_test.py new file mode 100644 index 00000000..1fd1048a --- /dev/null +++ b/libs/rag-core-api/tests/composite_retriever_test.py @@ -0,0 +1,164 @@ +"""Test internal helper methods of ``CompositeRetriever``. + +The goal of these tests is to verify the transformation semantics of: + - _use_summaries + - _remove_duplicates + - _early_pruning + - _arerank_pruning + +They operate with lightweight mock objects (no real vector DB / reranker logic). 
+""" + +from __future__ import annotations + +import asyncio +from typing import Iterable + +import pytest +from langchain_core.documents import Document + +from rag_core_api.impl.retriever.composite_retriever import CompositeRetriever +from rag_core_lib.impl.data_types.content_type import ContentType +from mocks.mock_vector_db import MockVectorDB +from mocks.mock_retriever_quark import MockRetrieverQuark +from mocks.mock_reranker import MockReranker + + +def _mk_doc( + doc_id: str, + score: float | None = None, + doc_type: ContentType = ContentType.TEXT, + related: Iterable[str] | None = None, +): + meta = {"id": doc_id, "type": doc_type.value} + if score is not None: + meta["score"] = score + if related is not None: + meta["related"] = list(related) + return Document(page_content=f"content-{doc_id}", metadata=meta) + + +@pytest.mark.asyncio +async def test_use_summaries_expands_and_removes_summary(): + """Expand a summary into its related documents. + + Verify that summary documents are removed and replaced by their related underlying documents. + """ + # Summary references an underlying doc not in initial results. + underlying = _mk_doc("doc1", score=0.9) + summary = _mk_doc("sum1", doc_type=ContentType.SUMMARY, related=["doc1"]) # type: ignore[arg-type] + vector_db = MockVectorDB({"doc1": underlying}) + retriever = MockRetrieverQuark([summary, underlying], vector_database=vector_db) + + cr = CompositeRetriever(retrievers=[retriever], reranker=None, reranker_enabled=False) + # Directly call _use_summaries for deterministic control + results = cr._use_summaries([summary], [summary]) + + # Underlying doc added (via expansion) & summary removed. + assert len(results) == 1 + assert results[0].metadata["id"] == "doc1" + assert all(d.metadata.get("type") != ContentType.SUMMARY.value for d in results) + + +def test_use_summaries_only_summary_no_related(): + """Drop a summary document that has no related documents. 
+ + Verify that the returned result is empty when no related ids are present. + """ + summary = _mk_doc("sum1", doc_type=ContentType.SUMMARY, related=[]) # type: ignore[arg-type] + retriever = MockRetrieverQuark([summary]) + cr = CompositeRetriever(retrievers=[retriever], reranker=None, reranker_enabled=False) + results = cr._use_summaries([summary], [summary]) + # Expect empty list after removal because there are no related expansions. + assert results == [] + + +def test_remove_duplicates_preserves_first_occurrence(): + """Preserve the first occurrence when duplicate ids are present. + + Verify that duplicate documents are removed while maintaining the original order. + """ + d1a = _mk_doc("a") + d1b = _mk_doc("a") # duplicate id + d2 = _mk_doc("b") + retriever = MockRetrieverQuark([d1a, d1b, d2]) + cr = CompositeRetriever(retrievers=[retriever], reranker=None, reranker_enabled=False) + unique = cr._remove_duplicates([d1a, d1b, d2]) + assert [d.metadata["id"] for d in unique] == ["a", "b"] + + +def test_early_pruning_sorts_by_score_when_all_have_score(): + """Sort by score and keep only the top-k documents. + + Verify that documents are sorted descending by score when all documents include scores. + """ + docs = [_mk_doc("a", score=0.7), _mk_doc("b", score=0.9), _mk_doc("c", score=0.8)] + retriever = MockRetrieverQuark(docs) + cr = CompositeRetriever( + retrievers=[retriever], reranker=None, reranker_enabled=False, total_retrieved_k_documents=2 + ) + pruned = cr._early_pruning(docs.copy()) + # Expect top two by score descending: b (0.9), c (0.8) + assert [d.metadata["id"] for d in pruned] == ["b", "c"] + + +def test_early_pruning_preserves_order_without_scores(): + """Preserve input order when pruning without score metadata. + + Verify that pruning keeps the original order when scores are absent. 
+ """ + docs = [_mk_doc("a"), _mk_doc("b"), _mk_doc("c")] # no scores + retriever = MockRetrieverQuark(docs) + cr = CompositeRetriever( + retrievers=[retriever], reranker=None, reranker_enabled=False, total_retrieved_k_documents=2 + ) + pruned = cr._early_pruning(docs.copy()) + assert [d.metadata["id"] for d in pruned] == ["a", "b"] + + +@pytest.mark.asyncio +async def test_arerank_pruning_invokes_reranker_when_needed(): + """Invoke the reranker when more than k documents are retrieved. + + Verify that the reranker is called and that the returned list is trimmed to ``reranker_k_documents``. + """ + docs = [_mk_doc("a", score=0.5), _mk_doc("b", score=0.7), _mk_doc("c", score=0.9)] + retriever = MockRetrieverQuark(docs) + reranker = MockReranker() + cr = CompositeRetriever( + retrievers=[retriever], + reranker=reranker, + reranker_enabled=True, + reranker_k_documents=2, + ) + pruned = await cr._arerank_pruning(docs.copy(), retriever_input="question") + # Reranker should be invoked and return top-2 by score (ids c, b) + assert reranker.invoked is True + assert [d.metadata["id"] for d in pruned] == ["c", "b"] + assert len(pruned) == 2 + + +@pytest.mark.asyncio +async def test_arerank_pruning_skips_when_not_needed(): + """Skip reranking when the retrieved docs are already within k. + + Verify that the reranker is not invoked when no pruning is required. + """ + docs = [_mk_doc("a", score=0.5), _mk_doc("b", score=0.7)] # already <= k + retriever = MockRetrieverQuark(docs) + reranker = MockReranker() + cr = CompositeRetriever( + retrievers=[retriever], + reranker=reranker, + reranker_enabled=True, + reranker_k_documents=3, + ) + pruned = await cr._arerank_pruning(docs.copy(), retriever_input="question") + # Not invoked because len(docs) <= reranker_k_documents + assert reranker.invoked is False + assert pruned == docs + + +# Convenience: allow running this test module directly for quick local dev. 
+if __name__ == "__main__": # pragma: no cover + raise SystemExit(pytest.main([__file__])) diff --git a/libs/rag-core-api/tests/mocks/mock_environment_variables.py b/libs/rag-core-api/tests/mocks/mock_environment_variables.py index 72a7bbbf..161bac4d 100644 --- a/libs/rag-core-api/tests/mocks/mock_environment_variables.py +++ b/libs/rag-core-api/tests/mocks/mock_environment_variables.py @@ -34,7 +34,9 @@ def mock_environment_variables() -> None: os.environ["RETRIEVER_THRESHOLD"] = "0.0" os.environ["RETRIEVER_K_DOCUMENTS"] = "10" - os.environ["RETRIEVER_TOTAL_K"] = "10" + # Canonical global cap env. Legacy aliases RETRIEVER_TOTAL_K and RETRIEVER_OVERALL_K_DOCUMENTS + # are intentionally not set here to exercise primary path. + os.environ["RETRIEVER_TOTAL_K_DOCUMENTS"] = "10" os.environ["RETRIEVER_TABLE_THRESHOLD"] = "0.0" os.environ["RETRIEVER_TABLE_K_DOCUMENTS"] = "10" os.environ["RETRIEVER_SUMMARY_THRESHOLD"] = "0.0" diff --git a/libs/rag-core-api/tests/mocks/mock_reranker.py b/libs/rag-core-api/tests/mocks/mock_reranker.py new file mode 100644 index 00000000..45c00c92 --- /dev/null +++ b/libs/rag-core-api/tests/mocks/mock_reranker.py @@ -0,0 +1,37 @@ +"""Provide a mock reranker for CompositeRetriever unit tests.""" + +__all__ = ["MockReranker"] + + +class MockReranker: + """Provide a simple reranker test double. + + The mock records whether it was invoked and returns a deterministic top-2 subset. + """ + + def __init__(self): + self.invoked = False + + async def ainvoke(self, payload, config=None): + """Return a reranked subset of the provided documents. + + Parameters + ---------- + payload : tuple + A ``(documents, query)`` tuple. + config : Any, optional + Optional runtime config passed through by the caller. + + Returns + ------- + list + The top two documents sorted by score when available. 
+ """ + self.invoked = True + documents, _query = payload + # Emulate a reranker: top 2 by 'score' when every doc has one; else the last 2 docs, reversed + if all("score" in d.metadata for d in documents): + docs_sorted = sorted(documents, key=lambda d: d.metadata["score"], reverse=True) + else: # pragma: no cover - fallback path + docs_sorted = list(reversed(documents)) + return docs_sorted[:2] diff --git a/libs/rag-core-api/tests/mocks/mock_retriever_quark.py b/libs/rag-core-api/tests/mocks/mock_retriever_quark.py new file mode 100644 index 00000000..af3f8198 --- /dev/null +++ b/libs/rag-core-api/tests/mocks/mock_retriever_quark.py @@ -0,0 +1,38 @@ +"""Provide a mock retriever quark for CompositeRetriever unit tests.""" + +from langchain_core.documents import Document + +from .mock_vector_db import MockVectorDB + +__all__ = ["MockRetrieverQuark"] + + +class MockRetrieverQuark: + """Provide a minimal stand-in for a RetrieverQuark. + + Exposes an ``ainvoke`` returning pre-seeded documents and a ``_vector_database`` attribute + referenced by summary expansion logic. + """ + + def __init__(self, documents: list[Document], vector_database: MockVectorDB | None = None): + self._documents = documents + self._vector_database = vector_database or MockVectorDB() + + def verify_readiness(self): # pragma: no cover - trivial + """Verify that the retriever is ready. + + Returns + ------- + None + Always returns ``None``. + """ + + async def ainvoke(self, *_args, **_kwargs): + """Return the pre-seeded documents. + + Returns + ------- + list[Document] + The documents passed to the constructor. + """ + return self._documents diff --git a/libs/rag-core-api/tests/mocks/mock_vector_db.py b/libs/rag-core-api/tests/mocks/mock_vector_db.py new file mode 100644 index 00000000..b1c96f2f --- /dev/null +++ b/libs/rag-core-api/tests/mocks/mock_vector_db.py @@ -0,0 +1,43 @@ +"""Provide a minimal vector database interface for tests. 
+ +Provides only the methods required by the CompositeRetriever unit tests: +- get_documents_by_ids: Used during summary expansion +- asearch: (async) provided as a defensive stub +""" + +from langchain_core.documents import Document + +__all__ = ["MockVectorDB"] + + +class MockVectorDB: + """Provide a minimal in-memory vector database test double.""" + + def __init__(self, docs_by_id: dict[str, Document] | None = None): + self.collection_available = True + self._docs_by_id = docs_by_id or {} + + def get_documents_by_ids(self, ids: list[str]) -> list[Document]: # pragma: no cover - simple mapping + """Return documents for the provided ids. + + Parameters + ---------- + ids : list[str] + Document ids to look up. + + Returns + ------- + list[Document] + Documents that exist in the in-memory mapping. + """ + return [self._docs_by_id[i] for i in ids if i in self._docs_by_id] + + async def asearch(self, *_, **__): # pragma: no cover - defensive stub + """Return an empty result for async search. + + Returns + ------- + list + Always returns an empty list. + """ + return []