Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
secrets.toml
secrets.toml
.env
.streamlit/secrets.toml

6 changes: 6 additions & 0 deletions .streamlit/example.secrets.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
openai_api_key = ""
pinecone_api_key = ""
pinecone_index = ""
pinecone_env = ""


3 changes: 3 additions & 0 deletions example.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
PINECONE_API_KEY=
PINECONE_INDEX_NAME=
OPENAI_API_KEY=
60 changes: 37 additions & 23 deletions rag_engine.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import os, tempfile
import pinecone
from pathlib import Path

from dotenv import load_dotenv
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain import OpenAI
from langchain.llms.openai import OpenAIChat
from langchain.document_loaders import DirectoryLoader
from langchain_community.llms.openai import OpenAIChat
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma, Pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_pinecone import Pinecone
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain

import streamlit as st

load_dotenv()

TMP_DIR = Path(__file__).resolve().parent.joinpath('data', 'tmp')
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath('data', 'vector_store')
Expand All @@ -42,20 +41,36 @@ def embeddings_on_local_vectordb(texts):
return retriever

def embeddings_on_pinecone(texts):
    """Embed the given documents into a Pinecone index and return a retriever.

    Parameters
    ----------
    texts : list
        Pre-split langchain Document chunks to upsert into the index.

    Returns
    -------
    A langchain retriever backed by the Pinecone vector store.

    Notes
    -----
    Uses the `langchain_pinecone.Pinecone` wrapper, which reads the Pinecone
    API key from the PINECONE_API_KEY environment variable (the old explicit
    `pinecone.init(...)` call was removed in this migration) — the key entered
    in `st.session_state.pinecone_api_key` is NOT passed here; confirm the
    env var is set (see example.env) or exported before this is called.
    """
    # Embeddings are billed per-token against the user's OpenAI key.
    embeddings = OpenAIEmbeddings(openai_api_key=st.session_state.openai_api_key)
    index_name = st.session_state.pinecone_index
    # from_documents both upserts the chunks and returns the store handle.
    docsearch = Pinecone.from_documents(texts, embeddings, index_name=index_name)
    return docsearch.as_retriever()

def query_llm(retriever, query):
    """Answer *query* against *retriever* with GPT-3.5 and record the exchange.

    Parameters
    ----------
    retriever :
        A langchain retriever (local Chroma or Pinecone) supplying context docs.
    query : str
        The user's question from the chat input.

    Returns
    -------
    str
        The model's answer text (also appended, paired with the query, to
        `st.session_state.messages` for the chat history display).
    """
    # TODO: export to session state so the UI can toggle source citations.
    with_source = False

    llm = ChatOpenAI(
        openai_api_key=st.session_state.openai_api_key,
        model_name='gpt-3.5-turbo',
        temperature=0.0,  # deterministic answers for RAG Q&A
    )
    # Plain "stuff" chain: retrieved docs are concatenated into the prompt.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
    # Variant that also returns the source documents' origins.
    qa_with_sources = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
    # Extract the answer string from the chain's result dict: RetrievalQA
    # keys it 'result', RetrievalQAWithSourcesChain keys it 'answer'.
    # (Appending the raw dict would break the chat history rendering, which
    # writes message[1] directly as the AI turn's text.)
    if with_source:
        result = qa_with_sources(query)['answer']
    else:
        result = qa(query)['result']
    st.session_state.messages.append((query, result))
    return result

Expand All @@ -70,7 +85,7 @@ def input_fields():
#
if "pinecone_api_key" in st.secrets:
st.session_state.pinecone_api_key = st.secrets.pinecone_api_key
else:
else:
st.session_state.pinecone_api_key = st.text_input("Pinecone API key", type="password")
#
if "pinecone_env" in st.secrets:
Expand Down Expand Up @@ -125,7 +140,7 @@ def boot():
#
for message in st.session_state.messages:
st.chat_message('human').write(message[0])
st.chat_message('ai').write(message[1])
st.chat_message('ai').write(message[1])
#
if query := st.chat_input():
st.chat_message("human").write(query)
Expand All @@ -135,4 +150,3 @@ def boot():
if __name__ == '__main__':
#
boot()

14 changes: 10 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
langchain==0.0.279
pinecone_client==2.2.2
pinecone_client==3.0.3
streamlit==1.26.0
unstructured
unstructured[pdf]
openai
openai==1.12.0
chromadb
tiktoken
tiktoken
langchain==0.1.8
langchain-community==0.0.21
langchain-core==0.1.25
langchain-openai==0.0.6
langchain-pinecone==0.0.2
langdetect==1.0.9
langsmith==0.1.5