diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 587be2f..6ad29c5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,6 +20,12 @@ repos: pep8-naming==0.13.3, flake8-bugbear==23.5.9 ] + - repo: https://github.com/PyCQA/docformatter + rev: v1.7.1 + hooks: + - id: docformatter + additional_dependencies: [ tomli ] + args: [ --in-place, --black ] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: diff --git a/tests/__init__.py b/tests/__init__.py index 8415036..4f5709e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,7 +4,7 @@ class StubRpcError(grpc.RpcError): - def __init__(self, code: str, details: Optional[str]): + def __init__(self, code: grpc.StatusCode, details: Optional[str]): self._code = code self._details = details @@ -13,3 +13,13 @@ def code(self): def details(self): return self._details + + +class UnavailableRpcError(StubRpcError): + def __init__(self, details: Optional[str]): + super().__init__(grpc.StatusCode.UNAVAILABLE, details) + + +class NotFoundRpcError(StubRpcError): + def __init__(self, details: Optional[str]): + super().__init__(grpc.StatusCode.NOT_FOUND, details) diff --git a/tests/test_auth.py b/tests/test_auth.py index 4e6bf6d..0f826b6 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -52,7 +52,7 @@ def test_get_access_token_with_rpc_failure(self, channel_ready_future, grpc_auth channel_ready_future.return_value = self.done_future mock_grpc_auth = grpc_auth() mock_grpc_auth.GetAccessToken.side_effect = StubRpcError( - code="Unavailable", details="" + code=grpc.StatusCode.UNAVAILABLE, details="" ) auth_gateway = AuthGateway(self.client_config) diff --git a/tests/test_search_index.py b/tests/test_search_index.py index 9f43872..f330260 100644 --- a/tests/test_search_index.py +++ b/tests/test_search_index.py @@ -17,7 +17,7 @@ SearchIndexResponse, UpdateDocumentResponse, ) -from tests import StubRpcError +from tests import UnavailableRpcError from tigrisdb.errors 
import TigrisServerError from tigrisdb.search_index import SearchIndex from tigrisdb.types import ClientConfig, Document @@ -54,9 +54,7 @@ def test_search(self, grpc_search): def test_search_with_error(self, grpc_search): search_index = SearchIndex(self.index_name, grpc_search(), self.client_config) mock_grpc = grpc_search() - mock_grpc.Search.side_effect = StubRpcError( - code="Unavailable", details="operational failure" - ) + mock_grpc.Search.side_effect = UnavailableRpcError("operational failure") with self.assertRaisesRegex(TigrisServerError, "operational failure") as e: search_index.search(SearchQuery()) self.assertIsNotNone(e) @@ -88,9 +86,7 @@ def test_create_many_with_error(self, grpc_search): docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}] search_index = SearchIndex(self.index_name, grpc_search(), self.client_config) mock_grpc = grpc_search() - mock_grpc.Create.side_effect = StubRpcError( - code="Unavailable", details="operational failure" - ) + mock_grpc.Create.side_effect = UnavailableRpcError("operational failure") with self.assertRaisesRegex(TigrisServerError, "operational failure") as e: search_index.create_many(docs) @@ -132,9 +128,7 @@ def test_delete_many(self, grpc_search): def test_delete_many_with_error(self, grpc_search): search_index = SearchIndex(self.index_name, grpc_search(), self.client_config) mock_grpc = grpc_search() - mock_grpc.Delete.side_effect = StubRpcError( - code="Unavailable", details="operational failure" - ) + mock_grpc.Delete.side_effect = UnavailableRpcError("operational failure") with self.assertRaisesRegex(TigrisServerError, "operational failure") as e: search_index.delete_many(["id"]) @@ -178,8 +172,8 @@ def test_create_or_replace_many_with_error(self, grpc_search): docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}] search_index = SearchIndex(self.index_name, grpc_search(), self.client_config) mock_grpc = grpc_search() - mock_grpc.CreateOrReplace.side_effect = StubRpcError( - 
code="Unavailable", details="operational failure" + mock_grpc.CreateOrReplace.side_effect = UnavailableRpcError( + "operational failure" ) with self.assertRaisesRegex(TigrisServerError, "operational failure") as e: @@ -242,9 +236,7 @@ def test_get_many(self, grpc_search): def test_get_many_with_error(self, grpc_search): search_index = SearchIndex(self.index_name, grpc_search(), self.client_config) mock_grpc = grpc_search() - mock_grpc.Get.side_effect = StubRpcError( - code="Unavailable", details="operational failure" - ) + mock_grpc.Get.side_effect = UnavailableRpcError("operational failure") with self.assertRaisesRegex(TigrisServerError, "operational failure") as e: search_index.get_many(["id"]) @@ -288,9 +280,7 @@ def test_update_many_with_error(self, grpc_search): docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}] search_index = SearchIndex(self.index_name, grpc_search(), self.client_config) mock_grpc = grpc_search() - mock_grpc.Update.side_effect = StubRpcError( - code="Unavailable", details="operational failure" - ) + mock_grpc.Update.side_effect = UnavailableRpcError("operational failure") with self.assertRaisesRegex(TigrisServerError, "operational failure") as e: search_index.update_many(docs) diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py new file mode 100644 index 0000000..89e1ea7 --- /dev/null +++ b/tests/test_vector_store.py @@ -0,0 +1,111 @@ +from unittest import TestCase +from unittest.mock import MagicMock, Mock, call, patch + +from tests import NotFoundRpcError +from tigrisdb.errors import TigrisServerError +from tigrisdb.search import Search +from tigrisdb.search_index import SearchIndex +from tigrisdb.types.search import ( + DocMeta, + DocStatus, + IndexedDoc, + Query, + Result, + TextMatchInfo, + VectorField, +) +from tigrisdb.types.vector import Document +from tigrisdb.vector_store import VectorStore + +doc: Document = { + "text": "Hello world vector embed", + "embeddings": [1.2, 2.3, 4.5], + "metadata": 
{"category": "shoes"}, +} + + +class VectorStoreTest(TestCase): + def setUp(self) -> None: + self.mock_index = Mock(spec=SearchIndex) + self.mock_client = Mock(spec=Search) + with patch("tigrisdb.client.TigrisClient.__new__") as mock_tigris: + instance = MagicMock() + mock_tigris.return_value = instance + instance.get_search.return_value = self.mock_client + self.mock_client.get_index.return_value = self.mock_index + self.store = VectorStore("my_vectors") + + def test_add_documents_when_index_not_found(self): + # throw error on first call and succeed on second + self.mock_index.create_many.side_effect = [ + TigrisServerError("", NotFoundRpcError("search index not found")), + [DocStatus(id="1")], + ] + + resp = self.store.add_documents([doc]) + self.assertEqual([DocStatus(id="1")], resp) + self.assertEqual(self.mock_index.create_many.call_count, 2) + self.mock_index.create_many.assert_has_calls([call([doc]), call([doc])]) + + # create_or_update_index gets called once + expected_schema = { + "title": self.store.name, + "additionalProperties": False, + "type": "object", + "properties": { + "id": {"type": "string"}, + "text": {"type": "string"}, + "metadata": {"type": "object"}, + "embeddings": {"type": "array", "format": "vector", "dimensions": 3}, + }, + } + + self.mock_client.create_or_update_index.assert_called_once_with( + name=self.store.name, schema=expected_schema + ) + + def test_add_documents_when_index_exists(self): + self.mock_index.create_many.return_value = [DocStatus(id="1")] + resp = self.store.add_documents([doc]) + self.assertEqual([DocStatus(id="1")], resp) + + # no calls to create_or_update_index + self.mock_client.assert_not_called() + + def test_add_documents_when_project_not_found(self): + self.mock_index.create_many.side_effect = [ + TigrisServerError("", NotFoundRpcError("project not found")), + [DocStatus(id="1")], + ] + with self.assertRaisesRegex(TigrisServerError, "project not found"): + self.store.add_documents([doc]) + 
self.mock_index.create_many.assert_called_once_with([doc]) + + def test_delete_documents(self): + self.store.delete_documents(["id"]) + self.mock_index.delete_many.assert_called_once_with(["id"]) + + def test_get_documents(self): + self.store.get_documents(["id"]) + self.mock_index.get_many.assert_called_once_with(["id"]) + + def test_similarity_search(self): + self.mock_index.search.return_value = Result( + hits=[ + IndexedDoc( + doc=doc, + meta=DocMeta(text_match=TextMatchInfo(vector_distance=0.1234)), + ) + ] + ) + resp = self.store.similarity_search([1, 1, 1], 12) + self.assertEqual(1, len(resp)) + self.assertEqual(doc, resp[0].doc) + self.assertEqual(0.1234, resp[0].score) + + self.mock_index.search.assert_called_once_with( + query=Query( + vector_query=VectorField(field="embeddings", vector=[1, 1, 1]), + hits_per_page=12, + ) + ) diff --git a/tigrisdb/client.py b/tigrisdb/client.py index 8189f9b..0ad2c2a 100644 --- a/tigrisdb/client.py +++ b/tigrisdb/client.py @@ -26,7 +26,7 @@ def __init__(self, config: Optional[ClientConfig]): config = ClientConfig() self.__config = config if not config.server_url: - config.server_url = TigrisClient.__LOCAL_SERVER + config.server_url = os.getenv("TIGRIS_URI", TigrisClient.__LOCAL_SERVER) if config.server_url.startswith("https://"): config.server_url = config.server_url.replace("https://", "") if config.server_url.startswith("http://"): @@ -34,6 +34,16 @@ def __init__(self, config: Optional[ClientConfig]): if ":" not in config.server_url: config.server_url = f"{config.server_url}:443" + # initialize rest of config + if not config.project_name: + config.project_name = os.getenv("TIGRIS_PROJECT") + if not config.client_id: + config.client_id = os.getenv("TIGRIS_CLIENT_ID") + if not config.client_secret: + config.client_secret = os.getenv("TIGRIS_CLIENT_SECRET") + if not config.branch: + config.branch = os.getenv("TIGRIS_DB_BRANCH", "") + is_local_dev = any( map( lambda k: k in config.server_url, diff --git a/tigrisdb/errors.py 
b/tigrisdb/errors.py index 1ff1c81..31beca1 100644 --- a/tigrisdb/errors.py +++ b/tigrisdb/errors.py @@ -1,10 +1,10 @@ +from typing import cast + import grpc class TigrisException(Exception): - """ - Base class for all TigrisExceptions - """ + """Base class for all TigrisExceptions.""" msg: str @@ -17,4 +17,13 @@ def __init__(self, msg: str, **kwargs): # TODO: make this typesafe class TigrisServerError(TigrisException): def __init__(self, msg: str, e: grpc.RpcError): - super(TigrisServerError, self).__init__(msg, code=e.code(), details=e.details()) + if isinstance(e.code(), grpc.StatusCode): + self.code = cast(grpc.StatusCode, e.code()) + else: + self.code = grpc.StatusCode.UNKNOWN + + self.details = e.details() + super(TigrisServerError, self).__init__( + msg, code=self.code.name, details=self.details + ) + self.__suppress_context__ = True diff --git a/tigrisdb/types/__init__.py b/tigrisdb/types/__init__.py index 3393ec8..61adb8a 100644 --- a/tigrisdb/types/__init__.py +++ b/tigrisdb/types/__init__.py @@ -7,7 +7,7 @@ @dataclass class ClientConfig: - project_name: str + project_name: str = "" client_id: Optional[str] = None client_secret: Optional[str] = None branch: str = "" diff --git a/tigrisdb/types/search.py b/tigrisdb/types/search.py index 644573c..5298948 100644 --- a/tigrisdb/types/search.py +++ b/tigrisdb/types/search.py @@ -50,7 +50,6 @@ def query(self): return {self.field: self.vector} -# TODO: add filter, collation @dataclass class Query: q: str = "" diff --git a/tigrisdb/types/vector.py b/tigrisdb/types/vector.py new file mode 100644 index 0000000..bd4ae7c --- /dev/null +++ b/tigrisdb/types/vector.py @@ -0,0 +1,24 @@ +from dataclasses import InitVar, dataclass +from typing import Dict, List, TypedDict + +from tigrisdb.types.search import IndexedDoc, dataclass_default_proto_field + + +class Document(TypedDict, total=False): + id: str + text: str + embeddings: List[float] + metadata: Dict + + +@dataclass +class DocWithScore: + doc: Document = None + 
score: float = 0.0 + _h: InitVar[IndexedDoc] = dataclass_default_proto_field + + def __post_init__(self, _h: IndexedDoc): + if _h and _h.doc: + self.doc = _h.doc + if _h and _h.meta: + self.score = _h.meta.text_match.vector_distance diff --git a/tigrisdb/vector_store.py b/tigrisdb/vector_store.py new file mode 100644 index 0000000..30a6586 --- /dev/null +++ b/tigrisdb/vector_store.py @@ -0,0 +1,125 @@ +from typing import List, Optional + +import grpc + +from tigrisdb.client import TigrisClient +from tigrisdb.errors import TigrisServerError +from tigrisdb.types import ClientConfig +from tigrisdb.types.filters import Filter +from tigrisdb.types.search import DocStatus, IndexedDoc, Query, Result, VectorField +from tigrisdb.types.vector import Document, DocWithScore + + +class VectorStore: + def __init__(self, name: str, config: Optional[ClientConfig] = None): + self.client = TigrisClient(config).get_search() + self._index_name = name + self.index = self.client.get_index(name) + + @property + def name(self): + return self._index_name + + def create_index(self, dimension: int): + self.client.create_or_update_index( + name=self.name, + schema={ + "title": self.name, + "additionalProperties": False, + "type": "object", + "properties": { + "id": {"type": "string"}, + "text": {"type": "string"}, + "metadata": {"type": "object"}, + "embeddings": { + "type": "array", + "format": "vector", + "dimensions": dimension, + }, + }, + }, + ) + + def add_documents(self, docs: List[Document]) -> List[DocStatus]: + """Adds documents to index, if the index does not exist, create it. 
A `Document` + is a dictionary with following structure: + + ``` + { + "id": "optional id of a document", + "text": "Actual content to store", + "embeddings": "list of float values", + "metadata": "optional metadata as dict" + } + ``` + + - `id` is optional and automatically generated once documents are added to index + - If `id` is given, any existing documents with matching `id` are replaced + + :param docs: list of documents to add to index + :type docs: list[Document] + :raises TigrisServerError: thrown for errors other than a missing index + :return: List of `ids` for the added documents + :rtype: list[DocStatus] + """ + try: + return self.index.create_many(docs) + except TigrisServerError as e: + if ( + e.code == grpc.StatusCode.NOT_FOUND + and "search index not found" in e.details + ): + first_embedding = docs[0]["embeddings"] if docs else [] + inferred_dim = len(first_embedding) if first_embedding else 16 + self.create_index(inferred_dim) + return self.index.create_many(docs) + else: + raise e + + def delete_documents(self, ids: List[str]) -> List[DocStatus]: + """Delete documents from index. + + :param ids: list of document ids to delete + :type ids: list[str] + :return: `ids` of documents and deletion status for each + :rtype: list[DocStatus] + """ + return self.index.delete_many(ids) + + def get_documents(self, ids: List[str]) -> List[IndexedDoc]: + """Retrieve documents from index. It will only have document `ids` found in the + index. + + :param ids: list of document ids to retrieve + :type ids: list[str] + :return: list of documents and associated metadata + :rtype: list[IndexedDoc] + """ + return self.index.get_many(ids) + + def similarity_search( + self, vector: List[float], k: int = 10, filter_by: Optional[Filter] = None + ) -> List[DocWithScore]: + """Perform a similarity search and returns documents most similar to the given + vector with distance. 
+ + :param vector: Search for documents closest to this vector + :type vector: list[float] + :param k: number of documents to return, defaults to 10 + :type k: int, optional + :param filter_by: apply the filter to metadata to only return a subset of + documents, defaults to None + :type filter_by: Filter, optional + :return: list of documents with similarity score (distance from given vector) + :rtype: list[DocWithScore] + """ + q = Query( + vector_query=VectorField("embeddings", vector), + filter_by=filter_by, + hits_per_page=k, + ) + r = self.search(q) + return [DocWithScore(_h=hit) for hit in r.hits] + + def search(self, query: Query) -> Result: + return self.index.search(query=query)