Skip to content

Commit

Permalink
Add rerank document compressor (#331)
Browse files Browse the repository at this point in the history
Fixes #298 

Added:
- BedrockRerank based on
[BaseDocumentCompressor](https://python.langchain.com/api_reference/core/documents/langchain_core.documents.compressor.BaseDocumentCompressor.html)
so we can use
[ContextualCompressionRetriever](https://python.langchain.com/api_reference/langchain/retrievers/langchain.retrievers.contextual_compression.ContextualCompressionRetriever.html)
- Import from root import i.e: `from langchain_aws import BedrockRerank`
- Unit tests


Some snippets:

- Example 1 (from documents):
```python3
from langchain_core.documents import Document
from langchain_aws import BedrockRerank

# Initialize the class
reranker = BedrockRerank(model_arn=model_arn)

# List of documents to rerank
documents = [
    Document(page_content="LangChain is a powerful library for LLMs."),
    Document(page_content="AWS Bedrock enables access to AI models."),
    Document(page_content="Artificial intelligence is transforming the world."),
]

# Query for reranking
query = "What is AWS Bedrock?"

# Call the rerank method
results = reranker.compress_documents(documents, query)

# Display the most relevant documents
for doc in results:
    print(f"Content: {doc.page_content}")
    print(f"Score: {doc.metadata['relevance_score']}")
```

- Example 2 (with contextual compression retriever):

```python3
from langchain_aws import BedrockEmbeddings
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_aws import BedrockRerank

# Create a vector store using FAISS with Bedrock embeddings
documents = [
    Document(page_content="LangChain integrates LLM models."),
    Document(page_content="AWS Bedrock provides cloud-based AI models."),
    Document(page_content="Machine learning can be used for predictions."),
]
embeddings = BedrockEmbeddings()
vectorstore = FAISS.from_documents(documents, embeddings)

# Create the document compressor using BedrockRerank
reranker = BedrockRerank(model_arn=model_arn)

# Create the retriever with contextual compression
retriever = ContextualCompressionRetriever(
    base_compressor=reranker,
    base_retriever=vectorstore.as_retriever(),
)

# Execute a query
query = "How does AWS Bedrock work?"
retrieved_docs = retriever.invoke(query)

# Display the most relevant documents
for doc in retrieved_docs:
    print(f"Content: {doc.page_content}")
    print(f"Score: {doc.metadata.get('relevance_score', 'N/A')}")
```

- Example 3 (from list):

```python3
from langchain_aws import BedrockRerank

# Initialize BedrockRerank
reranker = BedrockRerank(model_arn=model_arn)

# Unstructured documents
documents = [
    "LangChain is used to integrate LLM models.",
    "AWS Bedrock provides access to cloud-based models.",
    "Machine learning is revolutionizing the world.",
]

# Query
query = "What is the role of AWS Bedrock?"

# Rerank the documents
results = reranker.rerank(query=query, documents=documents)

# Display the results
for res in results:
    print(f"Index: {res['index']}, Score: {res['relevance_score']}")
    print(f"Document: {documents[res['index']]}")
```
  • Loading branch information
jpfcabral authored Feb 6, 2025
1 parent cd1d4c6 commit c5ec714
Show file tree
Hide file tree
Showing 6 changed files with 411 additions and 0 deletions.
2 changes: 2 additions & 0 deletions libs/aws/langchain_aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
create_neptune_sparql_qa_chain,
)
from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse
from langchain_aws.document_compressors.rerank import BedrockRerank
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from langchain_aws.llms import BedrockLLM, SagemakerEndpoint
Expand Down Expand Up @@ -48,4 +49,5 @@ def setup_logging():
"NeptuneGraph",
"InMemoryVectorStore",
"InMemorySemanticCache",
"BedrockRerank"
]
Empty file.
134 changes: 134 additions & 0 deletions libs/aws/langchain_aws/document_compressors/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

import boto3
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils import from_env
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self


class BedrockRerank(BaseDocumentCompressor):
"""Document compressor that uses AWS Bedrock Rerank API."""

model_arn: str
"""The ARN of the reranker model."""
client: Any = None
"""Bedrock client to use for compressing documents."""
top_n: Optional[int] = 3
"""Number of documents to return."""
region_name: str = Field(
default_factory=from_env("AWS_DEFAULT_REGION", default=None)
)
"""AWS region to initialize the Bedrock client."""
credentials_profile_name: Optional[str] = Field(
default_factory=from_env("AWS_PROFILE", default=None)
)
"""AWS profile for authentication, optional."""

model_config = ConfigDict(
extra="forbid",
arbitrary_types_allowed=True,
)

@model_validator(mode="before")
@classmethod
def initialize_client(cls, values: Dict[str, Any]) -> Any:
"""Initialize the AWS Bedrock client."""
if not values.get("client"):
session = (
boto3.Session(profile_name=values.get("credentials_profile_name"))
if values.get("credentials_profile_name", None)
else boto3.Session()
)
values["client"] = session.client(
"bedrock-agent-runtime",
region_name=values.get("region_name"),
)
return values

def rerank(
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
top_n: Optional[int] = None,
additional_model_request_fields: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Returns an ordered list of documents based on their relevance to the query.
Args:
query: The query to use for reranking.
documents: A sequence of documents to rerank.
top_n: The number of top-ranked results to return. Defaults to self.top_n.
additional_model_request_fields: A dictionary of additional fields to pass to the model.
Returns:
List[Dict[str, Any]]: A list of ranked documents with relevance scores.
"""
if len(documents) == 0:
return []

# Serialize documents for the Bedrock API
serialized_documents = [
{"textDocument": {"text": doc.page_content}, "type": "TEXT"}
if isinstance(doc, Document)
else {"textDocument": {"text": doc}, "type": "TEXT"}
if isinstance(doc, str)
else {"jsonDocument": doc, "type": "JSON"}
for doc in documents
]

request_body = {
"queries": [{"textQuery": {"text": query}, "type": "TEXT"}],
"rerankingConfiguration": {
"bedrockRerankingConfiguration": {
"modelConfiguration": {
"modelArn": self.model_arn,
"additionalModelRequestFields": additional_model_request_fields
or {},
},
"numberOfResults": top_n or self.top_n,
},
"type": "BEDROCK_RERANKING_MODEL",
},
"sources": [
{"inlineDocumentSource": doc, "type": "INLINE"}
for doc in serialized_documents
],
}

response = self.client.rerank(**request_body)
response_body = response.get("results", [])

results = [
{"index": result["index"], "relevance_score": result["relevanceScore"]}
for result in response_body
]

return results

def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
Compress documents using Bedrock's rerank API.
Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.
Returns:
A sequence of compressed documents.
"""
compressed = []
for res in self.rerank(documents, query):
doc = documents[res["index"]]
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
doc_copy.metadata["relevance_score"] = res["relevance_score"]
compressed.append(doc_copy)
return compressed
Empty file.
55 changes: 55 additions & 0 deletions libs/aws/tests/unit_tests/document_compressors/test_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.documents import Document

from langchain_aws.document_compressors.rerank import BedrockRerank


@pytest.fixture
def reranker() -> BedrockRerank:
reranker = BedrockRerank(
model_arn="arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
region_name="us-east-1",
)
reranker.client = MagicMock()
return reranker

@patch("boto3.Session")
def test_initialize_client(mock_boto_session: MagicMock, reranker: BedrockRerank) -> None:
session_instance = MagicMock()
mock_boto_session.return_value = session_instance
session_instance.client.return_value = MagicMock()
assert reranker.client is not None

@patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank")
def test_rerank(mock_rerank: MagicMock, reranker: BedrockRerank) -> None:
mock_rerank.return_value = [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.8},
]

documents = [Document(page_content="Doc 1"), Document(page_content="Doc 2")]
query = "Example Query"
results = reranker.rerank(documents, query)

assert len(results) == 2
assert results[0]["index"] == 0
assert results[0]["relevance_score"] == 0.9
assert results[1]["index"] == 1
assert results[1]["relevance_score"] == 0.8

@patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank")
def test_compress_documents(mock_rerank: MagicMock, reranker: BedrockRerank) -> None:
mock_rerank.return_value = [
{"index": 0, "relevance_score": 0.95},
{"index": 1, "relevance_score": 0.85},
]

documents = [Document(page_content="Content 1"), Document(page_content="Content 2")]
query = "Relevant query"
compressed_docs = reranker.compress_documents(documents, query)

assert len(compressed_docs) == 2
assert compressed_docs[0].metadata["relevance_score"] == 0.95
assert compressed_docs[1].metadata["relevance_score"] == 0.85
Loading

0 comments on commit c5ec714

Please sign in to comment.