Merged
2 changes: 2 additions & 0 deletions requirements.txt
@@ -48,3 +48,5 @@ uvicorn==0.27.0
 uvloop==0.19.0
 watchfiles==0.21.0
 websockets==12.0
+langchain-astradb==0.3.3
+astrapy==1.5.2
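
For reference, a minimal sketch of what these two pins enable together: astrapy supplies CollectionVectorServiceOptions for server-side ("vectorize") embeddings, and langchain-astradb wraps the Astra DB collection as a LangChain vector store. The collection name and environment variable names below are assumptions that mirror the rag.py changes further down; this is a sketch, not the project's API.

import os

from astrapy.info import CollectionVectorServiceOptions
from langchain_astradb import AstraDBVectorStore

# Server-side embeddings: Astra DB computes vectors with NVIDIA's
# NV-Embed-QA model, so no local embedding model is needed for this store.
vector_options = CollectionVectorServiceOptions(
    provider="nvidia",
    model_name="NV-Embed-QA",
)

# Collection name and credentials mirror the rag.py diff below (assumed set).
store = AstraDBVectorStore(
    collection_name="wikidata",
    collection_vector_service_options=vector_options,
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)

# similarity_search_with_relevance_scores returns (Document, score) pairs.
for doc, score in store.similarity_search_with_relevance_scores("Ada Lovelace", k=3):
    print(f"{score:.3f}  {doc.page_content[:80]}")
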
155 changes: 98 additions & 57 deletions wikidatachat/rag.py
@@ -24,13 +24,22 @@
     setup_document_stream_from_list
 )
 
+from langchain_astradb import AstraDBVectorStore
+from astrapy.info import CollectionVectorServiceOptions
+import json
+
 # Retrieve the SERAPI API key from environment variables.
 SERAPI_API_KEY = os.environ.get("SERAPI_API_KEY")
 EMBEDDING_MODEL = os.environ.get(
     'EMBEDDING_MODEL',
     'svalabs/german-gpl-adapted-covid'
 )
 
+# Retrieve the DataStax API keys from environment variables.
+COLLECTION_NAME = os.environ.get('COLLECTION_NAME')
+ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
+ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT')
+ASTRA_DB_KEYSPACE = os.environ.get('ASTRA_DB_KEYSPACE')
+
 class RetreivalAugmentedGenerationPipeline:
     def __init__(self, embedding_model=EMBEDDING_MODEL, device='cpu'):
@@ -51,6 +60,20 @@ def __init__(self, embedding_model=EMBEDDING_MODEL, device='cpu'):
             device=self.device
         )
 
+        collection_vector_service_options = CollectionVectorServiceOptions(
+            provider="nvidia",
+            model_name="NV-Embed-QA"
+        )
+
+        # Initialize the graph store
+        self.graph_store = AstraDBVectorStore(
+            collection_name="wikidata",
+            collection_vector_service_options=collection_vector_service_options,
+            token=ASTRA_DB_APPLICATION_TOKEN,
+            api_endpoint=ASTRA_DB_API_ENDPOINT,
+            namespace=ASTRA_DB_KEYSPACE,
+        )
+
     def process_query(
         self, query: str, top_k: int = 10, lang: str = 'de',
         content_key: str = None, meta_keys: list = [],
@@ -71,63 +94,81 @@ def process_query(
         Returns:
             The first answer from the generated answers.
         """
-        if wikidata_kwargs is None:
-            # Default Wikidata query parameters.
-            wikidata_kwargs = {
-                'timeout': 10,
-                'n_cores': cpu_count(),
-                'verbose': False,
-                'api_url': 'https://www.wikidata.org/w',
-                'wikidata_base': '"wikidata.org"',
-                'return_list': True
-            }
-
-        # Create a Document object from the query.
-        query_document = Document(content=query)
-
-        # Embed the query document.
-        query_embedded = self.embedder.run([query_document])
-
-        # Extract the embedding of the query document.
-        query_embedding = query_embedded['documents'][0].embedding
-
-        # Retrieve Wikidata statements related to the query.
-        wikidata_statements = get_wikidata_statements_from_query(
-            query,
-            lang=lang,
-            serapi_api_key=SERAPI_API_KEY,
-            **wikidata_kwargs
-        )
-
-        # Log the retrieved Wikidata statements for debugging.
-        self.logger.debug(f'{wikidata_statements=}')
-        for wds_ in wikidata_statements:
-            # Log each Wikidata statement for debugging.
-            self.logger.debug(f'{wds_=}')
-
-        # Setup the document stream from the list of Wikidata statements.
-        _, retriever = setup_document_stream_from_list(
-            dict_list=wikidata_statements,
-            content_key=content_key,
-            meta_keys=meta_keys,
-            embedder=self.embedder,
-            embedding_similarity_function=embedding_similarity_function,
-            device=self.device
-        )
-
-        # Run the retriever to find relevant documents
-        # based on the query embedding.
-        retriever_results = retriever.run(
-            query_embedding=list(query_embedding),
-            filters=None,
-            top_k=top_k,
-            scale_score=None,
-            return_embedding=None
-        )
+        retriever_results = []
+        try:
+            results = self.graph_store.similarity_search_with_relevance_scores(query, k=top_k)
+
+            retriever_results = [
+                Document(
+                    content=r[0].page_content,
+                    score=r[1],
+                    meta={'qid': r[0].metadata['QID']}
+                )
+                for r in results]
+        except Exception as e:
+            print(e)
+
+        # If DataStax fails, use SERAPI instead
+        if len(retriever_results) == 0:
+            if wikidata_kwargs is None:
+                # Default Wikidata query parameters.
+                wikidata_kwargs = {
+                    'timeout': 10,
+                    'n_cores': cpu_count(),
+                    'verbose': False,
+                    'api_url': 'https://www.wikidata.org/w',
+                    'wikidata_base': '"wikidata.org"',
+                    'return_list': True
+                }
+
+            # Create a Document object from the query.
+            query_document = Document(content=query)
+
+            # Embed the query document.
+            query_embedded = self.embedder.run([query_document])
+
+            # Extract the embedding of the query document.
+            query_embedding = query_embedded['documents'][0].embedding
+
+            # Retrieve Wikidata statements related to the query.
+            wikidata_statements = get_wikidata_statements_from_query(
+                query,
+                lang=lang,
+                serapi_api_key=SERAPI_API_KEY,
+                **wikidata_kwargs
+            )
+
+            # Log the retrieved Wikidata statements for debugging.
+            self.logger.debug(f'{wikidata_statements=}')
+            for wds_ in wikidata_statements:
+                # Log each Wikidata statement for debugging.
+                self.logger.debug(f'{wds_=}')
+
+            # Setup the document stream from the list of Wikidata statements.
+            _, retriever = setup_document_stream_from_list(
+                dict_list=wikidata_statements,
+                content_key=content_key,
+                meta_keys=meta_keys,
+                embedder=self.embedder,
+                embedding_similarity_function=embedding_similarity_function,
+                device=self.device
+            )
+
+            # Run the retriever to find relevant documents
+            # based on the query embedding.
+            retriever_results = retriever.run(
+                query_embedding=list(query_embedding),
+                filters=None,
+                top_k=top_k,
+                scale_score=None,
+                return_embedding=None
+            )
+            retriever_results = retriever_results['documents']
 
         # Log the start of retriever results for debugging.
         self.logger.debug('retriever results:')
-        for retriever_result_ in retriever_results['documents']:
+        for retriever_result_ in retriever_results:
             # Log each retriever result for debugging.
             self.logger.debug(retriever_result_)
 
@@ -140,8 +181,8 @@ def process_query(
         # Build the user prompt based on the retrieved documents
         # and the original query.
         user_prompt_build = user_prompt_builder.run(
-            question=query_document.content,
-            documents=retriever_results['documents']
+            question=query,
+            documents=retriever_results
         )
 
         # Extract the constructed prompt.
@@ -167,10 +208,10 @@ def process_query(
         # Build the answer based on the language model's response
         # and the retrieved documents.
         answer_build = answer_builder.run(
-            query=query_document.content,
+            query=query,
             replies=response['replies'],
             meta=[r.meta for r in response['replies']],
-            documents=retriever_results['documents']
+            documents=retriever_results
         )
 
         # Log the constructed answer for debugging.
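
For context on the new control flow: the Astra DB vector store is queried first, and the original SerpAPI/Wikidata path now runs only when that lookup raises or returns nothing. A hedged sketch of exercising the updated pipeline follows; the class and method names are taken from the diff, while the query text and arguments are illustrative, and the Astra DB environment variables must be set for __init__ to succeed.

from wikidatachat.rag import RetreivalAugmentedGenerationPipeline

# Class name (including the "Retreival" spelling) as defined in rag.py.
pipeline = RetreivalAugmentedGenerationPipeline(device="cpu")

# Astra DB is tried first; on an exception or an empty result list,
# the SerpAPI/Wikidata retriever is built in memory as a fallback.
answer = pipeline.process_query(
    "Wer war Ada Lovelace?",  # lang='de' is the default
    top_k=10,
)

# Per the docstring, process_query returns the first generated answer.
print(answer)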