Skip to content

Commit

Permalink
Fix Qdrant source storage (#534)
Browse files: browse the repository at this point in the history
  • Loading branch information
pmeier authored Jan 10, 2025
1 parent 4385620 commit 37d0998
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions ragna/source_storages/_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,16 @@ def __init__(self) -> None:

from qdrant_client import QdrantClient

url = os.getenv("QDRANT_URL")
api_key = os.getenv("QDRANT_API_KEY")
path = ragna.local_root() / "qdrant"

# Cannot pass both url and path
self._client = (
QdrantClient(url=url, api_key=api_key) if url else QdrantClient(path=path)
)
if (url := os.environ.get("QDRANT_URL")) is not None:
kwargs = dict(url=url, api_key=os.environ.get("QDRANT_API_KEY"))
else:
kwargs = dict(path=str(ragna.local_root() / "qdrant"))
self._client = QdrantClient(**kwargs) # type: ignore[arg-type]

def list_corpuses(self) -> list[str]:
    """Return the name of every collection reported by the Qdrant client.

    The names are returned in the order the client lists the collections.
    """
    response = self._client.get_collections()
    names = []
    for collection in response.collections:
        names.append(collection.name)
    return names

def _ensure_table(self, corpus_name: str, *, create: bool = False):
def _ensure_table(self, corpus_name: str, *, create: bool = False) -> None:
table_names = self.list_corpuses()
no_corpuses = not table_names
non_existing_corpus = corpus_name not in table_names
Expand All @@ -91,6 +88,7 @@ def list_metadata(
if corpus_name is None:
corpus_names = self.list_corpuses()
else:
self._ensure_table(corpus_name)
corpus_names = [corpus_name]

metadata = {}
Expand All @@ -101,7 +99,7 @@ def list_metadata(

corpus_metadata = defaultdict(set)
for point in points:
for key, value in point.payload.items():
for key, value in cast(dict[str, Any], point.payload).items():
if any(
[
(key.startswith("__") and key.endswith("__")),
Expand Down Expand Up @@ -142,7 +140,10 @@ def store(
points.append(
models.PointStruct(
id=str(uuid.uuid4()),
vector=self._embedding_function([chunk.text])[0],
vector=cast(
list[float],
self._embedding_function([chunk.text])[0].tolist(),
),
payload={
"document_id": str(document.id),
"document_name": document.name,
Expand All @@ -158,7 +159,9 @@ def store(

self._client.upsert(collection_name=corpus_name, points=points)

def _build_condition(self, operator, key, value):
def _build_condition(
self, operator: MetadataOperator, key: str, value: Any
) -> models.FieldCondition:
from qdrant_client import models

# See https://qdrant.tech/documentation/concepts/filtering/#range
Expand All @@ -184,7 +187,7 @@ def _build_condition(self, operator, key, value):

def _translate_metadata_filter(
self, metadata_filter: MetadataFilter
) -> models.Filter:
) -> models.Filter | models.FieldCondition:
from qdrant_client import models

if metadata_filter.operator is MetadataOperator.RAW:
Expand Down Expand Up @@ -247,12 +250,14 @@ def retrieve(
return self._take_sources_up_to_max_tokens(
(
Source(
id=point.id,
document_id=point.payload["document_id"],
document_name=point.payload["document_name"],
location=point.payload["__page_numbers__"],
content=point.payload[self.DOC_CONTENT_KEY],
num_tokens=point.payload["__num_tokens__"],
id=cast(str, point.id),
document_id=(payload := cast(dict[str, Any], point.payload))[
"document_id"
],
document_name=payload["document_name"],
location=payload["__page_numbers__"],
content=payload[self.DOC_CONTENT_KEY],
num_tokens=payload["__num_tokens__"],
)
for point in points
),
Expand Down

0 comments on commit 37d0998

Please sign in to comment.