Commit
feat: Embeddings pipeline improvements (#33)
* feat: Embeddings pipeline improvements

* fix: Reranking

* fix: Pinecone delete

* feat: Added strategy option

* chore: Merging

* chore: Merging

* walkthrough

* chore: Merging

---------

Co-authored-by: Ismail Pelaseyed <[email protected]>
simjak and homanp authored Feb 12, 2024
1 parent b1817cc commit 5a0b995
Showing 5 changed files with 56 additions and 26 deletions.
1 change: 1 addition & 0 deletions api/delete.py
@@ -1,6 +1,7 @@
 from fastapi import APIRouter
 
 from models.delete import RequestPayload, ResponsePayload
+from service.embedding import get_encoder
 from service.vector_database import VectorService, get_vector_service
 from service.embedding import get_encoder
24 changes: 22 additions & 2 deletions dev/embedding.ipynb
@@ -40,7 +40,27 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "docs = await embedding_service.generate_chunks()"
+    "elements = await embedding_service._download_and_extract_elements(file, strategy=\"auto\")\n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for element in elements:\n",
+    "    print(type(element))\n",
+    "    # print(f\"Text: {element.text}\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "docs = await embedding_service.generate_chunks(strategy=\"auto\")"
+   ]
+  },
   {
@@ -68,7 +88,7 @@
    "    print(colored_text)\n",
    "    concatenated_document += chunk + \" \"\n",
    "\n",
-    "print(\"\\nConcatenated Document:\\n\", concatenated_document)"
+    "# print(\"\\nConcatenated Document:\\n\", concatenated_document)"
    ]
   },
   {
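Note: the element objects printed in the new cell are unstructured's typed Element subclasses (Title, NarrativeText, Table, and so on). A minimal standalone sketch of the same inspection, assuming a placeholder input file rather than anything from this repo:

# Inspect the typed elements unstructured returns from partitioning.
# "example.pdf" is a placeholder path, not a file from this commit.
from unstructured.partition.auto import partition

elements = partition(filename="example.pdf", strategy="auto")
for element in elements:
    print(type(element).__name__, "-", getattr(element, "text", "")[:60])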
10 changes: 10 additions & 0 deletions dev/walkthrough.ipynb
@@ -91,6 +91,16 @@
    "data"
   ]
  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = query_response.json().get('data', [])\n",
+    "data"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
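Note: the added cell reads the response defensively, falling back to an empty list when the "data" key is absent. A small sketch of the same pattern with requests, where the endpoint URL and payload are placeholders rather than this project's actual API:

# Defensive JSON access: .get("data", []) avoids a KeyError when the
# response has no "data" field. URL and payload are illustrative only.
import requests

query_response = requests.post(
    "http://localhost:8000/api/v1/query",  # placeholder endpoint
    json={"input": "What is RAG?"},        # placeholder payload
)
data = query_response.json().get("data", [])
print(data)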
33 changes: 9 additions & 24 deletions service/embedding.py
@@ -8,6 +8,7 @@
 import requests
 from tqdm import tqdm
 from unstructured.chunking.title import chunk_by_title
+from unstructured.documents.elements import Element
 from unstructured.partition.auto import partition
 
 import encoders
@@ -41,8 +42,8 @@ def _get_datasource_suffix(self, type: str) -> str:
         raise ValueError("Unsupported datasource type")
 
     async def _download_and_extract_elements(
-        self, file, strategy="hi_res"
-    ) -> List[Any]:
+        self, file, strategy: Optional[str] = "hi_res"
+    ) -> List[Element]:
         """
         Downloads the file and extracts elements using the partition function.
         Returns a list of unstructured elements.
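Note: the hunk above only shows the new signature; the body is collapsed. A minimal sketch of how such a method can forward strategy to unstructured's partition(), where the download and temp-file handling are assumptions rather than this repo's code:

# Hedged sketch: threading a strategy option through to unstructured's
# partition(). Only the strategy parameter comes from the diff above;
# everything else is illustrative.
import tempfile
from typing import List, Optional

import requests
from unstructured.documents.elements import Element
from unstructured.partition.auto import partition


def download_and_extract_elements(url: str, strategy: Optional[str] = "hi_res") -> List[Element]:
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # unstructured accepts "auto", "fast", "hi_res", or "ocr_only" here.
    with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp:  # suffix is a placeholder
        tmp.write(response.content)
        tmp.flush()
        return partition(filename=tmp.name, strategy=strategy)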
@@ -84,33 +85,17 @@ async def generate_document(
         except Exception as e:
             logger.error(f"Error loading document {file.url}: {e}")
 
-    async def generate_summary_document(
-        self, documents: List[BaseDocument]
-    ) -> List[BaseDocument]:
-        pbar = tqdm(total=len(documents), desc="Summarizing documents")
-        pages = {}
-        for document in documents:
-            page_number = document.metadata.get("page_number")
-            if page_number not in pages:
-                doc = copy.deepcopy(document)
-                doc.text = await completion(document=doc)
-                pages[page_number] = doc
-            else:
-                pages[page_number].text += document.text
-            pbar.update()
-        pbar.close()
-        summary_documents = list(pages.values())
-        return summary_documents
-
-    async def generate_chunks(self) -> List[BaseDocumentChunk]:
+    async def generate_chunks(self, strategy: Optional[str]) -> List[BaseDocumentChunk]:
         doc_chunks = []
         for file in tqdm(self.files, desc="Generating chunks"):
             try:
-                elements = await self._download_and_extract_elements(file)
+                elements = await self._download_and_extract_elements(file, strategy)
                 document = await self.generate_document(file, elements)
                 if not document:
                     continue
-                chunks = chunk_by_title(elements)
+                chunks = chunk_by_title(
+                    elements, max_characters=500, combine_text_under_n_chars=0
+                )
                 for chunk in chunks:
                     # Ensure all metadata values are of a type acceptable to Pinecone
                     sanitized_metadata = {
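Note: the new chunk_by_title arguments cap each chunk at 500 characters and disable merging of short sections (combine_text_under_n_chars=0). A runnable sketch of the same call, with a placeholder input document:

# chunk_by_title with the limits introduced in this commit; "example.pdf"
# is a placeholder document.
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.auto import partition

elements = partition(filename="example.pdf", strategy="auto")
chunks = chunk_by_title(
    elements,
    max_characters=500,            # hard upper bound per chunk
    combine_text_under_n_chars=0,  # never merge small sections together
)
for chunk in chunks:
    print(len(chunk.text), repr(chunk.text[:60]))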
@@ -225,4 +210,4 @@ def get_encoder(*, encoder_type: EncoderEnum) -> encoders.BaseEncoder:
     encoder_class = encoder_mapping.get(encoder_type)
     if encoder_class is None:
         raise ValueError(f"Unsupported encoder: {encoder_type}")
-    return encoder_class()
\ No newline at end of file
+    return encoder_class()
14 changes: 14 additions & 0 deletions service/vector_database.py
@@ -37,6 +37,20 @@ async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]:
     # async def convert_to_rerank_format():
     #     pass
 
+    # TODO: make it default method instead of abstract
+    # async def convert_to_rerank_format(self, chunks: List[BaseDocumentChunk]):
+    #     docs = [
+    #         {
+    #             "content": chunk.text,
+    #             "page_label": (
+    #                 chunk.metadata.get("page_number", "") if chunk.metadata else ""
+    #             ),
+    #             "file_url": chunk.doc_url,
+    #         }
+    #         for chunk in chunks
+    #     ]
+    #     return docs
+
     @abstractmethod
     async def delete(self, file_url: str):
         pass
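Note: the TODO in the added comment suggests making convert_to_rerank_format a concrete default rather than an abstract method. A sketch of what that could look like, where the base-class shape is inferred from the diff and BaseDocumentChunk is only referenced, not defined here:

# Sketch of the TODO: a shared default implementation on the ABC, so
# subclasses inherit it instead of each re-implementing it.
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class VectorService(ABC):
    def convert_to_rerank_format(
        self, chunks: List["BaseDocumentChunk"]  # forward ref; defined elsewhere
    ) -> List[Dict[str, Any]]:
        return [
            {
                "content": chunk.text,
                "page_label": (
                    chunk.metadata.get("page_number", "") if chunk.metadata else ""
                ),
                "file_url": chunk.doc_url,
            }
            for chunk in chunks
        ]

    @abstractmethod
    async def delete(self, file_url: str):
        ...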
