Skip to content

Commit

Permalink
refac: rag pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
tjbck committed Apr 27, 2024
1 parent 8f1563a commit ce9a5d1
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 152 deletions.
100 changes: 55 additions & 45 deletions backend/apps/rag/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@

from apps.rag.utils import (
get_model_path,
query_embeddings_doc,
get_embeddings_function,
query_embeddings_collection,
get_embedding_function,
query_doc,
query_doc_with_hybrid_search,
query_collection,
query_collection_with_hybrid_search,
)

from utils.misc import (
Expand Down Expand Up @@ -147,6 +149,15 @@ def update_reranking_model(
RAG_RERANKING_MODEL_AUTO_UPDATE,
)


app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)

origins = ["*"]


Expand Down Expand Up @@ -227,6 +238,14 @@ async def update_embedding_config(

update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)

return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
Expand Down Expand Up @@ -367,27 +386,22 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
embeddings_function = get_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)

return query_embeddings_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
hybrid_search=(
form_data.hybrid
if form_data.hybrid
else app.state.ENABLE_RAG_HYBRID_SEARCH
),
)
if app.state.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
reranking_function=app.state.sentence_transformer_rf,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e:
log.exception(e)
raise HTTPException(
Expand All @@ -410,27 +424,23 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
embeddings_function = get_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
if app.state.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
reranking_function=app.state.sentence_transformer_rf,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
)

return query_embeddings_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
hybrid_search=(
form_data.hybrid
if form_data.hybrid
else app.state.ENABLE_RAG_HYBRID_SEARCH
),
)
except Exception as e:
log.exception(e)
raise HTTPException(
Expand Down Expand Up @@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b

collection = CHROMA_CLIENT.create_collection(name=collection_name)

embedding_func = get_embeddings_function(
embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
Expand Down
Loading

0 comments on commit ce9a5d1

Please sign in to comment.