diff --git a/libs/redis/langchain_redis/vectorstores.py b/libs/redis/langchain_redis/vectorstores.py
index e211146..348c3bc 100644
--- a/libs/redis/langchain_redis/vectorstores.py
+++ b/libs/redis/langchain_redis/vectorstores.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Iterable, List, Optional, Tuple, Union, cast
+from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
 from langchain_core.documents import Document
@@ -739,7 +739,76 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[boo
             return self._index.drop_keys(keys) == len(ids)
         else:
             return False
+
+    def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
+        """Get documents by their IDs.
+
+        The returned documents are expected to have the ID field set to the ID
+        of the document in the vector store.
+
+        Fewer documents may be returned than requested if some IDs are not
+        found or if there are duplicated IDs.
+
+        Users should not assume that the order of the returned documents
+        matches the order of the input IDs. Instead, users should rely on the
+        ID field of the returned documents.
+
+        This method should **NOT** raise exceptions if no documents are found
+        for some IDs.
+
+        Args:
+            ids: List of ids to retrieve.
+
+        Returns:
+            List of Documents.
+        """
+        if self.config.storage_type == StorageType.HASH.value:
+            # Fetch full hash data for each document
+            if not ids:
+                full_docs = []
+            else:
+                with self._index.client.pipeline(transaction=False) as pipe:
+                    for doc_id in ids:
+                        pipe.hgetall(doc_id)
+                    full_docs = convert_bytes(pipe.execute())
+
+            return [
+                Document(
+                    id=doc_id,
+                    page_content=doc[self.config.content_field],
+                    metadata={
+                        k: v
+                        for k, v in doc.items()
+                        if k != self.config.content_field
+                    },
+                )
+                for doc_id, doc in zip(ids, full_docs)
+                if doc  # HGETALL returns an empty mapping for missing keys
+            ]
+        else:
+            # Fetch full JSON data for each document
+            if not ids:
+                full_docs = []
+            else:
+                with self._index.client.json().pipeline(transaction=False) as pipe:
+                    for doc_id in ids:
+                        pipe.get(doc_id, ".")
+                    full_docs = pipe.execute()
+
+            return [
+                Document(
+                    id=doc_id,
+                    page_content=doc[self.config.content_field],
+                    metadata={
+                        k: v
+                        for k, v in doc.items()
+                        if k != self.config.content_field
+                    },
+                )
+                for doc_id, doc in zip(ids, full_docs)
+                if doc is not None  # Handle potential missing documents
+            ]
 
     def similarity_search_by_vector(
         self,
         embedding: List[float],
@@ -822,49 +891,8 @@ def similarity_search_by_vector(
                 for doc in results
             ]
         else:
-            if self.config.storage_type == StorageType.HASH.value:
-                # Fetch full hash data for each document
-                if not results:
-                    full_docs = []
-                else:
-                    with self._index.client.pipeline(transaction=False) as pipe:
-                        for doc in results:
-                            pipe.hgetall(doc["id"])
-                        full_docs = convert_bytes(pipe.execute())
-
-                return [
-                    Document(
-                        page_content=doc[self.config.content_field],
-                        metadata={
-                            k: v
-                            for k, v in doc.items()
-                            if k != self.config.content_field
-                        },
-                    )
-                    for doc in full_docs
-                ]
-            else:
-                # Fetch full JSON data for each document
-                if not results:
-                    full_docs = []
-                else:
-                    with self._index.client.json().pipeline(transaction=False) as pipe:
-                        for doc in results:
-                            pipe.get(doc["id"], ".")
-                        full_docs = pipe.execute()
-
-                return [
-                    Document(
-                        page_content=doc[self.config.content_field],
-                        metadata={
-                            k: v
-                            for k, v in doc.items()
-                            if k != self.config.content_field
-                        },
-                    )
-                    for doc in full_docs
-                    if doc is not None  # Handle potential missing documents
-                ]
+            ids = [doc["id"] for doc in results]
+            return self.get_by_ids(ids)
 
     def similarity_search(
         self,
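For reviewers, a sketch of how the new method is intended to be used (the index name, key prefix, and keys below are hypothetical; `keys=` support in `from_texts` is assumed, as in the tests that follow):

    from langchain_openai import OpenAIEmbeddings
    from langchain_redis import RedisVectorStore

    store = RedisVectorStore.from_texts(
        ["foo", "bar"],
        OpenAIEmbeddings(),
        keys=["a", "b"],                  # hypothetical keys
        index_name="demo_index",          # hypothetical index name
        key_prefix="demo",                # stored keys become "demo:<key>"
        redis_url="redis://localhost:6379",
    )

    # IDs are full Redis keys; missing IDs are silently skipped
    docs = store.get_by_ids(["demo:a", "demo:does-not-exist"])
    assert [d.page_content for d in docs] == ["foo"]

Note that `ids` is positional-only (the `/` in the signature, matching the langchain-core base class), so `store.get_by_ids(ids=[...])` raises a TypeError; that is why the delegation in `similarity_search_by_vector` above passes it positionally.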
diff --git a/libs/redis/tests/integration_tests/test_vectorstores_hash.py b/libs/redis/tests/integration_tests/test_vectorstores_hash.py
index 8063fb7..b7fbdea 100644
--- a/libs/redis/tests/integration_tests/test_vectorstores_hash.py
+++ b/libs/redis/tests/integration_tests/test_vectorstores_hash.py
@@ -490,6 +490,40 @@ def test_similarity_search(redis_url: str) -> None:
     vector_store.index.delete(drop=True)
 
 
+def test_get_by_ids(redis_url: str) -> None:
+    """Test retrieving documents from the store by their IDs."""
+    # Create embeddings
+    embeddings = OpenAIEmbeddings()
+
+    # Create a unique index name for testing
+    index_name = f"test_index_{str(ULID())}"
+
+    texts = ["foo", "bar", "baz"]
+    keys = ["a", "b", "c"]
+
+    # Create the RedisVectorStore with explicit keys so the IDs are predictable
+    vector_store = RedisVectorStore.from_texts(
+        texts,
+        embeddings,
+        keys=keys,
+        index_name=index_name,
+        key_prefix="tst11",
+        redis_url=redis_url,
+    )
+
+    # Fetch a subset of the documents by their full Redis keys
+    ids = [f"tst11:{k}" for k in ["a", "c"]]
+
+    docs = vector_store.get_by_ids(ids)
+
+    result_texts = [doc.page_content for doc in docs]
+
+    assert sorted(result_texts) == ["baz", "foo"]
+
+    # Clean up
+    vector_store.index.delete(drop=True)
+
+
 def test_similarity_search_with_scores(redis_url: str) -> None:
     """Test end to end construction and search."""
     # Create embeddings
diff --git a/libs/redis/tests/integration_tests/test_vectorstores_json.py b/libs/redis/tests/integration_tests/test_vectorstores_json.py
index 099d1d1..e511493 100644
--- a/libs/redis/tests/integration_tests/test_vectorstores_json.py
+++ b/libs/redis/tests/integration_tests/test_vectorstores_json.py
@@ -501,6 +501,41 @@ def test_similarity_search(redis_url: str) -> None:
     vector_store.index.delete(drop=True)
 
 
+def test_get_by_ids(redis_url: str) -> None:
+    """Test retrieving documents from the store by their IDs."""
+    # Create embeddings
+    embeddings = OpenAIEmbeddings()
+
+    # Create a unique index name for testing
+    index_name = f"test_index_{str(ULID())}"
+
+    texts = ["foo", "bar", "baz"]
+    keys = ["a", "b", "c"]
+
+    # Create the RedisVectorStore with explicit keys so the IDs are predictable
+    vector_store = RedisVectorStore.from_texts(
+        texts,
+        embeddings,
+        keys=keys,
+        index_name=index_name,
+        key_prefix="tst11",
+        redis_url=redis_url,
+        storage_type="json",
+    )
+
+    # Fetch a subset of the documents by their full Redis keys
+    ids = [f"tst11:{k}" for k in ["a", "c"]]
+
+    docs = vector_store.get_by_ids(ids)
+
+    result_texts = [doc.page_content for doc in docs]
+
+    assert sorted(result_texts) == ["baz", "foo"]
+
+    # Clean up
+    vector_store.index.delete(drop=True)
+
+
 def test_similarity_search_with_scores(redis_url: str) -> None:
     """Test end to end construction and search."""
     # Create embeddings
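The two integration tests above mirror the two storage branches of `get_by_ids`. Stripped of the vector-store layer, the round trips they exercise look roughly like this against a bare redis-py client (keys are hypothetical; the JSON branch requires the RedisJSON module on the server):

    import redis

    r = redis.Redis.from_url("redis://localhost:6379")

    # HASH storage: one HGETALL per key, batched in a non-transactional
    # pipeline; missing keys come back as empty dicts
    with r.pipeline(transaction=False) as pipe:
        pipe.hgetall("tst11:a")
        pipe.hgetall("tst11:c")
        hash_docs = pipe.execute()

    # JSON storage: one JSON.GET per key; missing keys come back as None
    with r.json().pipeline(transaction=False) as pipe:
        pipe.get("tst11:a", ".")
        pipe.get("tst11:c", ".")
        json_docs = pipe.execute()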
diff --git a/libs/redis/tests/unit_tests/test_vectorstores.py b/libs/redis/tests/unit_tests/test_vectorstores.py
index ac7b8b0..1e35fe9 100644
--- a/libs/redis/tests/unit_tests/test_vectorstores.py
+++ b/libs/redis/tests/unit_tests/test_vectorstores.py
@@ -42,6 +42,34 @@ def __init__(self) -> None:
 
     def get(self, client: Any, doc_ids: List[str]) -> List[Dict[str, Any]]:
         return [self.data.get(doc_id, {}) for doc_id in doc_ids]
+
+    def hgetall(self, name: str) -> Dict[str, Any]:
+        return self.data.get(name, {})
+
+    # Minimal pipeline stub so get_by_ids() can batch hgetall() calls
+    # against the in-memory mock, mirroring redis-py's pipeline API
+    def pipeline(self, transaction: bool = True) -> "MockPipeline":
+        return MockPipeline(self)
+
+
+class MockPipeline:
+    def __init__(self, storage: MockStorage) -> None:
+        self._storage = storage
+        self._results: List[Dict[str, Any]] = []
+
+    def __enter__(self) -> "MockPipeline":
+        return self
+
+    def __exit__(self, *exc: Any) -> None:
+        return None
+
+    def hgetall(self, name: str) -> None:
+        # Queue the lookup; results come back from execute(), as in redis-py
+        self._results.append(self._storage.hgetall(name))
+
+    def execute(self) -> List[Dict[str, Any]]:
+        results, self._results = self._results, []
+        return results
 
 
 class MockSearchIndex:
@@ -63,8 +91,9 @@ def __init__(
         self.schema = MockSchema(
             schema["fields"] if schema and "fields" in schema else default_schema
         )
-        self.client = client or Mock()
+        self._storage = MockStorage()
+        self.client = self._storage
 
     def create(self, overwrite: bool = False) -> None:
         pass
@@ -148,6 +177,16 @@ def test_add_texts(self, vector_store: RedisVectorStore) -> None:
         assert len(keys) == 2
         assert all(key.startswith("key_") for key in keys)
 
+    def test_get_by_ids(self, vector_store: RedisVectorStore) -> None:
+        texts = ["Hello, world!", "Test document"]
+        metadatas = [{"source": "greeting"}, {"source": "test"}]
+        keys = vector_store.add_texts(texts, metadatas)
+        docs = vector_store.get_by_ids(keys)
+        assert len(keys) == len(docs)
+        assert all(isinstance(doc, Document) for doc in docs)
+        assert all(isinstance(doc.page_content, str) for doc in docs)
+        assert all(isinstance(doc.metadata, dict) for doc in docs)
+
     def test_similarity_search(self, vector_store: RedisVectorStore) -> None:
         vector_store.add_texts(["Hello, world!", "Test document"])
        results = vector_store.similarity_search("Hello", k=1)
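The mock changes only need to make `MockSearchIndex.client` quack like a redis-py client for the exact calls `get_by_ids` issues. A standalone illustration (assuming `MockStorage.data` starts as an empty dict and that the content field is "text", neither of which is shown in this hunk):

    storage = MockStorage()
    storage.data["key_1"] = {"text": "Hello, world!", "source": "greeting"}

    with storage.pipeline(transaction=False) as pipe:
        pipe.hgetall("key_1")
        pipe.hgetall("key_missing")
        print(pipe.execute())
        # [{'text': 'Hello, world!', 'source': 'greeting'}, {}]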