Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
OPENAI_API_KEY=
OPENAI_API_KEY=
EMBEDDING_MODEL_NAME=
EMBEDDING_MODEL_TYPE=
EMBEDDING_MODEL_VECTOR_SIZE=
OLLAMA_BASE_URL=
51 changes: 24 additions & 27 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
from enum import Enum
from typing import List, Optional, Union
from typing import List, Union

from dotenv import load_dotenv
from fastapi import FastAPI, Response, status
Expand Down Expand Up @@ -31,13 +32,25 @@ class EmbeddingModelType(Enum):
OPENAI = 3


MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "sentence-transformers/gtr-t5-large")
MODEL_TYPE = EmbeddingModelType(int(os.getenv("EMBEDDING_MODEL_TYPE", "1")))
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None)

embedding_model = None

if MODEL_TYPE == EmbeddingModelType.SENTENCE_TRANSFORMERS:
embedding_model = SentenceTransformerEmbeddingModel(model=MODEL_NAME)
elif MODEL_TYPE == EmbeddingModelType.OLLAMA:
embedding_model = OllamaEmbeddingModel(model=MODEL_NAME, base_url=OLLAMA_BASE_URL)
elif MODEL_TYPE == EmbeddingModelType.OPENAI:
embedding_model = OpenAIEmbeddingModel(model=MODEL_NAME)


class RequestSchemaForEmbeddings(BaseModel):
"""Request Schema"""

type_model: EmbeddingModelType
name_model: str
texts: Union[str, List[str]]
base_url: Optional[str] = None


class RequestSchemaForTextSplitter(BaseModel):
Expand Down Expand Up @@ -68,29 +81,13 @@ async def generate_embeddings(item: RequestSchemaForEmbeddings):
Generates the embedding vectors for the text/documents
based on different models
"""
type_model = item.type_model
name_model = item.name_model
base_url = item.base_url
texts = item.texts

def generate(em_model, texts):
if isinstance(texts, str):
return em_model.embed_query(text=texts)
elif isinstance(texts, list):
return em_model.embed_documents(texts=texts)
return None

if type_model == EmbeddingModelType.SENTENCE_TRANSFORMERS:
embedding_model = SentenceTransformerEmbeddingModel(model=name_model)
return generate(em_model=embedding_model, texts=texts)

elif type_model == EmbeddingModelType.OLLAMA:
embedding_model = OllamaEmbeddingModel(model=name_model, base_url=base_url)
return generate(em_model=embedding_model, texts=texts)

elif type_model == EmbeddingModelType.OPENAI:
embedding_model = OpenAIEmbeddingModel(model=name_model)
return generate(em_model=embedding_model, texts=texts)

if embedding_model:
if isinstance(item.texts, str):
return embedding_model.embed_query(text=item.texts)
elif isinstance(item.texts, list):
return embedding_model.embed_documents(texts=item.texts)
return []


@app.post("/split_docs_based_on_tokens")
Expand Down
7 changes: 4 additions & 3 deletions reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from sentence_transformers import CrossEncoder
from torch import Tensor

cross_encoder_model = CrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-2-v2", max_length=512)

def get_scores(query: str, documents: List[str], model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"):

def get_scores(query: str, documents: List[str]):
"""Get the scores"""
model = CrossEncoder(model_name=model_name, max_length=512)
doc_tuple = [(query, doc) for doc in documents]
scores = model.predict(doc_tuple)
scores = cross_encoder_model.predict(doc_tuple)
return F.softmax(Tensor(scores), dim=0).tolist()
4 changes: 3 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from huggingface_hub import snapshot_download

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)


def download_model(embedding_model: str, models_path: str):
Expand All @@ -21,6 +21,8 @@ def check_models(sent_embedding_model: str):
models_path = Path("/opt/models")
models_info_path = models_path / "model_info.json"

logging.info("Checking models status.")

if not os.path.exists(models_path):
os.makedirs(models_path)

Expand Down
Loading