diff --git a/ragna/source_storages/_qdrant.py b/ragna/source_storages/_qdrant.py index fe494fbb..68a179f2 100644 --- a/ragna/source_storages/_qdrant.py +++ b/ragna/source_storages/_qdrant.py @@ -54,19 +54,16 @@ def __init__(self) -> None: from qdrant_client import QdrantClient - url = os.getenv("QDRANT_URL") - api_key = os.getenv("QDRANT_API_KEY") - path = ragna.local_root() / "qdrant" - - # Cannot pass both url and path - self._client = ( - QdrantClient(url=url, api_key=api_key) if url else QdrantClient(path=path) - ) + if (url := os.environ.get("QDRANT_URL")) is not None: + kwargs = dict(url=url, api_key=os.environ.get("QDRANT_API_KEY")) + else: + kwargs = dict(path=str(ragna.local_root() / "qdrant")) + self._client = QdrantClient(**kwargs) # type: ignore[arg-type] def list_corpuses(self) -> list[str]: return [c.name for c in self._client.get_collections().collections] - def _ensure_table(self, corpus_name: str, *, create: bool = False): + def _ensure_table(self, corpus_name: str, *, create: bool = False) -> None: table_names = self.list_corpuses() no_corpuses = not table_names non_existing_corpus = corpus_name not in table_names @@ -91,6 +88,7 @@ def list_metadata( if corpus_name is None: corpus_names = self.list_corpuses() else: + self._ensure_table(corpus_name) corpus_names = [corpus_name] metadata = {} @@ -101,7 +99,7 @@ def list_metadata( corpus_metadata = defaultdict(set) for point in points: - for key, value in point.payload.items(): + for key, value in cast(dict[str, Any], point.payload).items(): if any( [ (key.startswith("__") and key.endswith("__")), @@ -142,7 +140,10 @@ def store( points.append( models.PointStruct( id=str(uuid.uuid4()), - vector=self._embedding_function([chunk.text])[0], + vector=cast( + list[float], + self._embedding_function([chunk.text])[0].tolist(), + ), payload={ "document_id": str(document.id), "document_name": document.name, @@ -158,7 +159,9 @@ def store( self._client.upsert(collection_name=corpus_name, points=points) - def _build_condition(self, operator, key, value): + def _build_condition( + self, operator: MetadataOperator, key: str, value: Any + ) -> models.FieldCondition: from qdrant_client import models # See https://qdrant.tech/documentation/concepts/filtering/#range @@ -184,7 +187,7 @@ def _build_condition(self, operator, key, value): def _translate_metadata_filter( self, metadata_filter: MetadataFilter - ) -> models.Filter: + ) -> models.Filter | models.FieldCondition: from qdrant_client import models if metadata_filter.operator is MetadataOperator.RAW: @@ -247,12 +250,14 @@ def retrieve( return self._take_sources_up_to_max_tokens( ( Source( - id=point.id, - document_id=point.payload["document_id"], - document_name=point.payload["document_name"], - location=point.payload["__page_numbers__"], - content=point.payload[self.DOC_CONTENT_KEY], - num_tokens=point.payload["__num_tokens__"], + id=cast(str, point.id), + document_id=(payload := cast(dict[str, Any], point.payload))[ + "document_id" + ], + document_name=payload["document_name"], + location=payload["__page_numbers__"], + content=payload[self.DOC_CONTENT_KEY], + num_tokens=payload["__num_tokens__"], ) for point in points ),