diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py
index 29f4bbc6..80d54cb5 100644
--- a/libs/aws/langchain_aws/__init__.py
+++ b/libs/aws/langchain_aws/__init__.py
@@ -14,6 +14,7 @@
     InMemorySemanticCache,
     InMemoryVectorStore,
 )
+from langchain_aws.rerank.rerank import BedrockRerank
 
 __all__ = [
     "BedrockEmbeddings",
@@ -29,4 +30,5 @@
     "NeptuneGraph",
     "InMemoryVectorStore",
     "InMemorySemanticCache",
+    "BedrockRerank",
 ]
diff --git a/libs/aws/langchain_aws/rerank/__init__.py b/libs/aws/langchain_aws/rerank/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/aws/langchain_aws/rerank/rerank.py b/libs/aws/langchain_aws/rerank/rerank.py
new file mode 100644
index 00000000..aabed315
--- /dev/null
+++ b/libs/aws/langchain_aws/rerank/rerank.py
@@ -0,0 +1,137 @@
+import json
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+import boto3
+from langchain_core.callbacks.manager import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.utils import from_env
+from pydantic import ConfigDict, Field, model_validator
+from typing_extensions import Self
+
+
+class BedrockRerank(BaseDocumentCompressor):
+    """Document compressor that uses AWS Bedrock Rerank API."""
+
+    client: Any = None
+    """Bedrock client to use for compressing documents."""
+    top_n: Optional[int] = 3
+    """Number of documents to return."""
+    model: Optional[str] = "amazon.rerank-v1:0"
+    """Model to use for reranking. Default is amazon.rerank-v1:0."""
+    aws_region: str = Field(
+        default_factory=from_env("AWS_DEFAULT_REGION", default="us-west-2")
+    )
+    """AWS region to initialize the Bedrock client."""
+    aws_profile: Optional[str] = Field(
+        default_factory=from_env("AWS_PROFILE", default=None)
+    )
+    """AWS profile for authentication, optional."""
+
+    model_config = ConfigDict(
+        extra="forbid",
+        arbitrary_types_allowed=True,
+    )
+
+    @model_validator(mode="after")
+    def initialize_client(self) -> Self:
+        """Create a ``bedrock-runtime`` client unless one was supplied.
+
+        Returns:
+            The validated instance, with ``client`` populated.
+        """
+        # Compare against None so a caller-supplied client is never replaced
+        # merely because the object happens to evaluate as falsy.
+        if self.client is None:
+            session = (
+                boto3.Session(profile_name=self.aws_profile)
+                if self.aws_profile
+                else boto3.Session()
+            )
+            self.client = session.client("bedrock-runtime", region_name=self.aws_region)
+        return self
+
+    def rerank(
+        self,
+        documents: Sequence[Union[str, Document, dict]],
+        query: str,
+        *,
+        top_n: Optional[int] = None,
+        model: Optional[str] = None,
+    ) -> List[Dict[str, Any]]:
+        """Rank documents by their relevance to the query.
+
+        Args:
+            documents: A sequence of documents to rerank. Strings are sent as-is,
+                ``Document`` objects contribute their ``page_content`` and dicts
+                are JSON-encoded.
+            query: The query to use for reranking.
+            top_n: The number of top-ranked results to return. Defaults to self.top_n.
+            model: The model to use for reranking. Defaults to self.model.
+
+        Returns:
+            One dict per result returned by the model, each holding the ``index``
+            of the document in ``documents`` and its ``relevance_score``.
+        """
+        if not documents:
+            return []
+
+        # Serialize documents for the Bedrock API
+        serialized_documents = [
+            json.dumps(doc)
+            if isinstance(doc, dict)
+            else doc.page_content
+            if isinstance(doc, Document)
+            else doc
+            for doc in documents
+        ]
+
+        payload: Dict[str, Any] = {
+            "query": query,
+            "documents": serialized_documents,
+        }
+        # Leave ``top_n`` out entirely when no limit is configured instead of
+        # sending an explicit JSON null.
+        limit = top_n or self.top_n
+        if limit is not None:
+            payload["top_n"] = limit
+
+        response = self.client.invoke_model(
+            modelId=model or self.model,
+            accept="application/json",
+            contentType="application/json",
+            body=json.dumps(payload),
+        )
+
+        response_body = json.loads(response.get("body").read())
+        return [
+            {"index": result["index"], "relevance_score": result["relevance_score"]}
+            for result in response_body["results"]
+        ]
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """
+        Compress documents using Bedrock's rerank API.
+
+        Args:
+            documents: A sequence of documents to compress.
+            query: The query to use for compressing the documents.
+            callbacks: Unused by this implementation.
+
+        Returns:
+            Copies of the reranked documents, in the order returned by the
+            model, each with ``relevance_score`` added to its metadata.
+        """
+        compressed = []
+        for res in self.rerank(documents, query):
+            doc = documents[res["index"]]
+            # Copy so the caller's documents are not mutated.
+            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
+            doc_copy.metadata["relevance_score"] = res["relevance_score"]
+            compressed.append(doc_copy)
+        return compressed
diff --git a/libs/aws/tests/unit_tests/rerank/__init__.py b/libs/aws/tests/unit_tests/rerank/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/aws/tests/unit_tests/rerank/test_rerank.py b/libs/aws/tests/unit_tests/rerank/test_rerank.py
new file mode 100644
index 00000000..ac99df93
--- /dev/null
+++ b/libs/aws/tests/unit_tests/rerank/test_rerank.py
@@ -0,0 +1,95 @@
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+from langchain_core.documents import Document
+
+from langchain_aws import BedrockRerank
+
+
+# Mock setup
+@pytest.fixture
+def mock_bedrock_client():
+    mock_client = MagicMock()
+    mock_client.invoke_model.return_value = {
+        "body": MagicMock(
+            read=MagicMock(
+                return_value=json.dumps(
+                    {
+                        "results": [
+                            {"index": 0, "relevance_score": 0.95},
+                            {"index": 1, "relevance_score": 0.90},
+                        ]
+                    }
+                ).encode("utf-8")
+            )
+        )
+    }
+    return mock_client
+
+
+@pytest.fixture
+def bedrock_rerank(mock_bedrock_client):
+    return BedrockRerank(client=mock_bedrock_client)
+
+
+# Test initialize_client
+# boto3.Session is patched so these tests do not depend on the AWS profiles
+# or credentials configured on the machine running them.
+@patch("boto3.Session")
+def test_initialize_client_with_profile(mock_session):
+    bedrock_rerank = BedrockRerank(aws_profile="default")
+    mock_session.assert_called_once_with(profile_name="default")
+    mock_session.return_value.client.assert_called_once_with(
+        "bedrock-runtime", region_name=bedrock_rerank.aws_region
+    )
+    assert bedrock_rerank.client is mock_session.return_value.client.return_value
+
+
+@patch("boto3.Session")
+def test_initialize_client_without_profile(mock_session, monkeypatch):
+    monkeypatch.delenv("AWS_PROFILE", raising=False)
+    bedrock_rerank = BedrockRerank()
+    mock_session.assert_called_once_with()
+    assert bedrock_rerank.client is mock_session.return_value.client.return_value
+
+
+# Test rerank method
+def test_rerank_success(bedrock_rerank):
+    documents = ["doc1", "doc2", "doc3"]
+    query = "Test query"
+    results = bedrock_rerank.rerank(documents, query)
+    assert len(results) == 2
+    assert results[0]["index"] == 0
+    assert results[0]["relevance_score"] == 0.95
+
+
+def test_rerank_request_body(bedrock_rerank, mock_bedrock_client):
+    bedrock_rerank.rerank(["doc1", Document(page_content="doc2")], "query", top_n=1)
+    kwargs = mock_bedrock_client.invoke_model.call_args.kwargs
+    assert kwargs["modelId"] == "amazon.rerank-v1:0"
+    assert json.loads(kwargs["body"]) == {
+        "query": "query",
+        "documents": ["doc1", "doc2"],
+        "top_n": 1,
+    }
+
+
+def test_rerank_empty_documents(bedrock_rerank, mock_bedrock_client):
+    results = bedrock_rerank.rerank([], "query")
+    assert results == []
+    mock_bedrock_client.invoke_model.assert_not_called()
+
+
+# Test compress_documents method
+def test_compress_documents(bedrock_rerank):
+    documents = [
+        Document(page_content="doc1"),
+        Document(page_content="doc2"),
+        Document(page_content="doc3"),
+    ]
+    query = "Test query"
+    compressed = bedrock_rerank.compress_documents(documents, query)
+    assert len(compressed) == 2
+    assert compressed[0].metadata["relevance_score"] == 0.95
+    assert compressed[1].metadata["relevance_score"] == 0.90