-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDhruva-AI.py
108 lines (89 loc) · 4.55 KB
/
Dhruva-AI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import os
from dotenv import load_dotenv
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from groq import Groq
from langchain_groq import ChatGroq
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import FAISS
import pickle
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
# Page-level setup: request Streamlit's "wide" layout for the app.
st.set_page_config(layout="wide")
# Load settings through python-dotenv. The API keys read later with
# os.getenv (GROQ_API, NVIDIA_API) presumably come from a local .env file
# -- TODO confirm against the deployment setup.
load_dotenv()
# Component #1 - Document Upload
# Sidebar form that stores user-uploaded files in DOCS_DIR, the folder the
# vector-store step loads its documents from.
with st.sidebar:
    DOCS_DIR = os.path.abspath("./uploaded_docs")
    # exist_ok=True replaces the separate exists() test and its
    # check-then-create race.
    os.makedirs(DOCS_DIR, exist_ok=True)

    st.subheader("Add to the Knowledge Base")
    with st.form("my-form", clear_on_submit=True):
        uploaded_files = st.file_uploader(
            "Upload a file to the Knowledge Base:", accept_multiple_files=True
        )
        submitted = st.form_submit_button("Upload!")

    if uploaded_files and submitted:
        for uploaded_file in uploaded_files:
            # The file name is supplied by the client: keep only its final
            # path component so it cannot address anything outside DOCS_DIR.
            safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
            if safe_name in ("", ".", ".."):
                st.error(f"File {uploaded_file.name} has an invalid name and was skipped.")
                continue

            with open(os.path.join(DOCS_DIR, safe_name), "wb") as f:
                f.write(uploaded_file.read())
            # Report success only once the bytes are actually on disk.
            st.success(f"File {uploaded_file.name} uploaded successfully!")
# Component #2 - Embedding Model and LLM
# Chat model used by the response chain; its key is read from the GROQ_API
# environment variable.
llm = ChatGroq(
    api_key=os.environ.get("GROQ_API"),
    model="llama3-8b-8192",
)

# Embedding model handed to FAISS when the vector store is built; its key is
# read from the NVIDIA_API environment variable.
# NOTE(review): model_type is fixed to "passage" -- confirm this is also the
# intended setting when user questions are embedded for retrieval.
document_embedder = NVIDIAEmbeddings(
    api_key=os.environ.get("NVIDIA_API"),
    model="nvidia/nv-embedqa-e5-v5",
    model_type="passage",
)
# Component #3 - Vector Database Store
# Produces `vectorstore` for the chat step: either the pickled store from a
# previous run, a store freshly built from the files in DOCS_DIR, or None.
with st.sidebar:
    use_existing_vector_store = st.radio(
        "Use existing vector store if available", ["Yes", "No"], horizontal=True
    )

vector_store_path = "vectorstore.pkl"
vector_store_exists = os.path.exists(vector_store_path)
vectorstore = None
# Documents are loaded lazily: parsing the whole folder on every Streamlit
# rerun is wasted work when a saved store is used or retrieval is off.
raw_documents = []

if use_existing_vector_store != "Yes":
    # The chat step only retrieves context when this option is "Yes", so "No"
    # means retrieval is off -- say that instead of claiming there are no
    # documents.
    with st.sidebar:
        st.info("Vector store is turned off; answers will not use the uploaded documents.")
elif vector_store_exists:
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load a vectorstore.pkl that this application wrote itself.
    with open(vector_store_path, "rb") as f:
        vectorstore = pickle.load(f)
    with st.sidebar:
        st.success("Existing vector store loaded successfully.")
else:
    raw_documents = DirectoryLoader(DOCS_DIR).load()
    documents = []
    with st.sidebar:
        if raw_documents:
            with st.spinner("Splitting documents into chunks..."):
                text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=200)
                documents = text_splitter.split_documents(raw_documents)

        # Build only when there is at least one chunk to embed; an empty chunk
        # list gives the index nothing to size itself from.
        if documents:
            with st.spinner("Adding document chunks to vector database..."):
                vectorstore = FAISS.from_documents(documents, document_embedder)
            with st.spinner("Saving vector store"):
                with open(vector_store_path, "wb") as f:
                    pickle.dump(vectorstore, f)
            st.success("Vector store created and saved.")
        else:
            st.warning("No documents available to process!", icon="⚠️")
# Component #4 - LLM Response Generation and Chat
# Chat UI: replays the stored conversation, optionally augments the new
# question with retrieved document context, and streams the model's answer.
st.subheader("Chat with your AI Assistant, Dhruva AI!")

# Create the chat history in session state the first time it is needed.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Re-draw every stored message so the conversation stays visible.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# The assistant's name in the system prompt matches the name shown in the UI.
prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant named Dhruva AI. If provided with context, use it to inform your responses. If no context is available, use your general knowledge to provide a helpful response."),
    ("human", "{input}"),
])
chain = prompt_template | llm | StrOutputParser()

user_input = st.chat_input("Ask me anything related to uploaded docs?")
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        # Retrieve supporting passages only when a vector store is loaded and
        # the sidebar option allows using it.
        context = ""
        if vectorstore is not None and use_existing_vector_store == "Yes":
            retriever = vectorstore.as_retriever()
            docs = retriever.invoke(user_input)
            context = "\n\n".join(doc.page_content for doc in docs)

        # With nothing retrieved, send the bare question rather than an empty
        # "Context:" header, so the model falls back to general knowledge.
        if context.strip():
            augmented_user_input = f"Context: {context}\n\nQuestion: {user_input}\n"
        else:
            augmented_user_input = f"Question: {user_input}\n"

        # Stream the answer, showing a cursor glyph while tokens arrive.
        for response in chain.stream({"input": augmented_user_input}):
            full_response += response
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})