From 486e4023522d2c899db67ddcd6792dc3a35c73ac Mon Sep 17 00:00:00 2001 From: Ranjan Shrestha Date: Thu, 21 Nov 2024 10:08:37 +0545 Subject: [PATCH] Load the models in the startup; --- .env.sample | 6 +++++- app.py | 51 ++++++++++++++++++++++++--------------------------- reranker.py | 7 ++++--- utils.py | 4 +++- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/.env.sample b/.env.sample index 9847a1d..6578ff8 100644 --- a/.env.sample +++ b/.env.sample @@ -1 +1,5 @@ -OPENAI_API_KEY= \ No newline at end of file +OPENAI_API_KEY= +EMBEDDING_MODEL_NAME= +EMBEDDING_MODEL_TYPE= +EMBEDDING_MODEL_VECTOR_SIZE= +OLLAMA_BASE_URL= diff --git a/app.py b/app.py index d853dd9..fe4b642 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,7 @@ import json +import os from enum import Enum -from typing import List, Optional, Union +from typing import List, Union from dotenv import load_dotenv from fastapi import FastAPI, Response, status @@ -31,13 +32,25 @@ class EmbeddingModelType(Enum): OPENAI = 3 +MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "sentence-transformers/gtr-t5-large") +MODEL_TYPE = EmbeddingModelType(int(os.getenv("EMBEDDING_MODEL_TYPE", "1"))) +OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None) + +embedding_model = None + +if MODEL_TYPE == EmbeddingModelType.SENTENCE_TRANSFORMERS: + embedding_model = SentenceTransformerEmbeddingModel(model=MODEL_NAME) +elif MODEL_TYPE == EmbeddingModelType.OLLAMA: + embedding_model = OllamaEmbeddingModel(model=MODEL_NAME, base_url=OLLAMA_BASE_URL) +elif MODEL_TYPE == EmbeddingModelType.OPENAI: + embedding_model = OpenAIEmbeddingModel(model=MODEL_NAME) + + class RequestSchemaForEmbeddings(BaseModel): """Request Schema""" - type_model: EmbeddingModelType - name_model: str texts: Union[str, List[str]] - base_url: Optional[str] = None class RequestSchemaForTextSplitter(BaseModel): @@ -68,29 +81,13 @@ async def generate_embeddings(item: RequestSchemaForEmbeddings): Generates the embedding vectors for the text/documents based on different models """ - type_model = item.type_model - name_model = item.name_model - base_url = item.base_url - texts = item.texts - - def generate(em_model, texts): - if isinstance(texts, str): - return em_model.embed_query(text=texts) - elif isinstance(texts, list): - return em_model.embed_documents(texts=texts) - return None - - if type_model == EmbeddingModelType.SENTENCE_TRANSFORMERS: - embedding_model = SentenceTransformerEmbeddingModel(model=name_model) - return generate(em_model=embedding_model, texts=texts) - - elif type_model == EmbeddingModelType.OLLAMA: - embedding_model = OllamaEmbeddingModel(model=name_model, base_url=base_url) - return generate(em_model=embedding_model, texts=texts) - - elif type_model == EmbeddingModelType.OPENAI: - embedding_model = OpenAIEmbeddingModel(model=name_model) - return generate(em_model=embedding_model, texts=texts) + + if embedding_model: + if isinstance(item.texts, str): + return embedding_model.embed_query(text=item.texts) + elif isinstance(item.texts, list): + return embedding_model.embed_documents(texts=item.texts) + return [] @app.post("/split_docs_based_on_tokens") diff --git a/reranker.py b/reranker.py index 5020414..2497f98 100644 --- a/reranker.py +++ b/reranker.py @@ -4,10 +4,11 @@ from sentence_transformers import CrossEncoder from torch import Tensor +cross_encoder_model = CrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-2-v2", max_length=512) -def get_scores(query: str, documents: List[str], model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"): + +def get_scores(query: str, documents: List[str]): """Get the scores""" - model = CrossEncoder(model_name=model_name, max_length=512) doc_tuple = [(query, doc) for doc in documents] - scores = model.predict(doc_tuple) + scores = cross_encoder_model.predict(doc_tuple) return F.softmax(Tensor(scores), dim=0).tolist() diff --git a/utils.py b/utils.py index 5d32be6..58ee57e 100644 --- a/utils.py +++ b/utils.py @@ -6,7 +6,7 @@ from huggingface_hub import snapshot_download logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) def download_model(embedding_model: str, models_path: str): @@ -21,6 +21,8 @@ def check_models(sent_embedding_model: str): models_path = Path("/opt/models") models_info_path = models_path / "model_info.json" + logging.info("Checking models status.") + if not os.path.exists(models_path): os.makedirs(models_path)