Add rerank document compressor #331

Open: wants to merge 10 commits into base: main

@jpfcabral jpfcabral commented Jan 16, 2025

This PR resolves #298

Added:

Some snippets:

  • Example 1 (from documents):
from langchain_core.documents import Document
from langchain_aws import BedrockRerank

# Initialize the reranker (model_arn must already hold the ARN of a Bedrock rerank model)
reranker = BedrockRerank(model_arn=model_arn)

# List of documents to rerank
documents = [
    Document(page_content="LangChain is a powerful library for LLMs."),
    Document(page_content="AWS Bedrock enables access to AI models."),
    Document(page_content="Artificial intelligence is transforming the world."),
]

# Query for reranking
query = "What is AWS Bedrock?"

# Call the rerank method
results = reranker.compress_documents(documents, query)

# Display the most relevant documents
for doc in results:
    print(f"Content: {doc.page_content}")
    print(f"Score: {doc.metadata['relevance_score']}")
  • Example 2 (with contextual compression retriever):
from langchain_aws import BedrockEmbeddings
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_aws import BedrockRerank

# Create a vector store using FAISS with Bedrock embeddings
documents = [
    Document(page_content="LangChain integrates LLM models."),
    Document(page_content="AWS Bedrock provides cloud-based AI models."),
    Document(page_content="Machine learning can be used for predictions."),
]
embeddings = BedrockEmbeddings()
vectorstore = FAISS.from_documents(documents, embeddings)

# Create the document compressor using BedrockRerank
reranker = BedrockRerank(model_arn=model_arn)

# Create the retriever with contextual compression
retriever = ContextualCompressionRetriever(
    base_compressor=reranker,
    base_retriever=vectorstore.as_retriever(),
)

# Execute a query
query = "How does AWS Bedrock work?"
retrieved_docs = retriever.invoke(query)

# Display the most relevant documents
for doc in retrieved_docs:
    print(f"Content: {doc.page_content}")
    print(f"Score: {doc.metadata.get('relevance_score', 'N/A')}")
  • Example 3 (from list):
from langchain_aws import BedrockRerank

# Initialize BedrockRerank (model_arn must already hold the ARN of a Bedrock rerank model)
reranker = BedrockRerank(model_arn=model_arn)

# Unstructured documents
documents = [
    "LangChain is used to integrate LLM models.",
    "AWS Bedrock provides access to cloud-based models.",
    "Machine learning is revolutionizing the world.",
]

# Query
query = "What is the role of AWS Bedrock?"

# Rerank the documents
results = reranker.rerank(query=query, documents=documents)

# Display the results
for res in results:
    print(f"Index: {res['index']}, Score: {res['relevance_score']}")
    print(f"Document: {documents[res['index']]}")
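For readers who want the reranked texts rather than indices, here is a minimal sketch of mapping the result dictionaries from Example 3 back to the input list. The helper name `attach_scores` is hypothetical and not part of this PR, and the sample scores are invented for illustration; only the `index` and `relevance_score` keys come from the example above.

```python
from typing import Dict, List, Tuple, Union


def attach_scores(
    documents: List[str],
    results: List[Dict[str, Union[int, float]]],
) -> List[Tuple[str, float]]:
    """Pair each result with its source text, highest score first."""
    # Sort defensively in case the caller passes results in arbitrary order.
    ranked = sorted(results, key=lambda r: r["relevance_score"], reverse=True)
    return [(documents[int(r["index"])], float(r["relevance_score"])) for r in ranked]


# Invented values, standing in for what reranker.rerank(...) might return.
sample_docs = ["doc a", "doc b", "doc c"]
sample_results = [
    {"index": 0, "relevance_score": 0.2},
    {"index": 1, "relevance_score": 0.9},
]
pairs = attach_scores(sample_docs, sample_results)
```

Documents that the reranker did not return (here `"doc c"`) simply do not appear in the output.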

@jpfcabral jpfcabral changed the title Adding rerank on langchain format Fixes #298 Adding rerank on langchain format Jan 16, 2025
@jpfcabral jpfcabral changed the title Fixes #298 Adding rerank on langchain format Closes #298 Adding rerank on langchain format Jan 16, 2025
@jpfcabral jpfcabral changed the title Closes #298 Adding rerank on langchain format Adding rerank #298 Jan 16, 2025
@jpfcabral jpfcabral changed the title Adding rerank #298 Adding rerank Jan 16, 2025
@jpfcabral jpfcabral changed the title Adding rerank Adding rerank as a retriever Jan 16, 2025
@jpfcabral jpfcabral force-pushed the main branch 2 times, most recently from df8c35a to 70f4e2d Compare January 16, 2025 19:52
@mgvalverde

Hi @jpfcabral, interesting contribution!

I noticed that the default region is set to aws_region="us-west-2". From what I understand, if a region is configured through a profile, leaving the region unspecified should use the profile's region. However, it always defaults to "us-west-2" instead.
Would it make sense to replace that value with aws_region=None?
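A minimal sketch of the behavior being asked for, assuming the class builds a boto3 session internally: only pass a region when the caller supplied one, so that boto3 can fall back to the profile or environment. The helper name `build_session_kwargs` and its parameter names are hypothetical, not taken from the PR.

```python
from typing import Dict, Optional


def build_session_kwargs(
    region: Optional[str] = None,
    profile: Optional[str] = None,
) -> Dict[str, str]:
    """Collect only the session arguments the caller actually set."""
    kwargs: Dict[str, str] = {}
    if profile is not None:
        kwargs["profile_name"] = profile
    if region is not None:
        # Omitting this key lets boto3 resolve the region on its own.
        kwargs["region_name"] = region
    return kwargs


# Intended use (requires boto3, so left as a comment here):
#   session = boto3.Session(**build_session_kwargs(profile="dev"))
no_region = build_session_kwargs(profile="dev")
explicit = build_session_kwargs(region="eu-west-1", profile="dev")
```

With a hard-coded `"us-west-2"` default, the `region_name` key would always be present and the profile's region would never be consulted.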

@jpfcabral
Author

Fair point, @mgvalverde. I just changed it in bbc0243.

@3coins
Collaborator

3coins commented Jan 31, 2025

@jpfcabral
It seems like the model is called directly via invoke here, but this rerank API seems to have more options for controlling the re-ranking. Is this the right API to implement?
https://docs.aws.amazon.com/bedrock/latest/userguide/rerank-use.html

@jpfcabral
Author

jpfcabral commented Jan 31, 2025

@3coins
I was using the bedrock-runtime client to call the model, so we couldn't use the additional options for controlling the re-ranking.

Following the documentation you sent, I need to call it through bedrock-agent-runtime so that we can configure the request with the rerankingConfiguration attribute. With that, the user can pass additional information, if needed, in the additional_model_request_fields parameter.

Note: the bedrock-agent-runtime rerank operation only supports requests with a model ARN, and it has no method for getting the ARN from a model ID. To work around this, I created a function that instantiates a simple bedrock client to get the session information and model ID and return the model ARN.

I just added commit 13ad88f with the changes above. Let me know if anything needs fixing.
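For context, a rough sketch of what a request to the bedrock-agent-runtime rerank operation might look like when built from plain strings. The field names follow my reading of the Rerank API documentation linked above and should be checked against it; the function name `build_rerank_request` is hypothetical and is not the code in commit 13ad88f.

```python
from typing import Any, Dict, List, Optional


def build_rerank_request(
    query: str,
    documents: List[str],
    model_arn: str,
    top_n: int = 3,
    additional_model_request_fields: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble keyword arguments for a rerank call (assumed request shape)."""
    model_config: Dict[str, Any] = {"modelArn": model_arn}
    if additional_model_request_fields:
        model_config["additionalModelRequestFields"] = additional_model_request_fields
    return {
        "queries": [{"type": "TEXT", "textQuery": {"text": query}}],
        "sources": [
            {
                "type": "INLINE",
                "inlineDocumentSource": {
                    "type": "TEXT",
                    "textDocument": {"text": text},
                },
            }
            for text in documents
        ],
        "rerankingConfiguration": {
            "type": "BEDROCK_RERANKING_MODEL",
            "bedrockRerankingConfiguration": {
                "numberOfResults": top_n,
                "modelConfiguration": model_config,
            },
        },
    }


# "arn:example" is a placeholder, not a real model ARN.
request = build_rerank_request("What is AWS Bedrock?", ["doc a", "doc b"], "arn:example", top_n=1)
```

The resulting dictionary would then be unpacked into the client call, for example `client.rerank(**request)` on a `bedrock-agent-runtime` client.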

Collaborator

@3coins 3coins left a comment

@jpfcabral
Thanks for submitting this PR and for the quick turnaround on the updates.
Great job on adding the examples in the PR description. It might be more useful to add a notebook with those samples in the samples/document_compressors directory.

Also, to keep the module organization consistent with community, would it be better to put this under document_compressors rather than rerank?

Comment on lines 50 to 55
def _get_model_arn(self) -> str:
    """Fetch the ARN of the reranker model using the model ID."""
    session = self._get_session()
    client = session.client("bedrock", self.aws_region)
    response = client.get_foundation_model(modelIdentifier=self.model_id)
    return response["modelDetails"]["modelArn"]
Collaborator

Rather than relying on the API to fetch the model ARN, I suggest we stay consistent with the rerank API and take model_arn as input instead of model_id.

Suggested change
-def _get_model_arn(self) -> str:
-    """Fetch the ARN of the reranker model using the model ID."""
-    session = self._get_session()
-    client = session.client("bedrock", self.aws_region)
-    response = client.get_foundation_model(modelIdentifier=self.model_id)
-    return response["modelDetails"]["modelArn"]

Author

Solved on f45340b
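With model_arn as a required input, callers who only know a model ID need to build the ARN themselves. A minimal sketch, assuming the usual ARN layout for Bedrock foundation models in the standard `aws` partition; the helper name `foundation_model_arn` is hypothetical, and the layout should be verified against the AWS documentation for the model in question.

```python
def foundation_model_arn(region_name: str, model_id: str) -> str:
    """Build a foundation-model ARN (assumed layout, standard 'aws' partition)."""
    if not region_name or not model_id:
        raise ValueError("region_name and model_id are both required")
    # Foundation-model ARNs carry no account ID, hence the empty field.
    return f"arn:aws:bedrock:{region_name}::foundation-model/{model_id}"


arn = foundation_model_arn("us-west-2", "amazon.rerank-v1:0")
```

This keeps the class itself free of an extra `get_foundation_model` API call, which is the point of the suggestion above.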

"""Bedrock client to use for compressing documents."""
top_n: Optional[int] = 3
"""Number of documents to return."""
model_id: Optional[str] = "amazon.rerank-v1:0"
Collaborator

Update this to model_arn.

Suggested change
-model_id: Optional[str] = "amazon.rerank-v1:0"
+model_arn: str

Author

Solved on f45340b

Comment on lines 21 to 23
aws_region: str = Field(
    default_factory=from_env("AWS_DEFAULT_REGION", default=None)
)
Collaborator

Though I like the short name here, we have been using region_name in other places, and I would prefer to keep it consistent here as well.

Suggested change
-aws_region: str = Field(
-    default_factory=from_env("AWS_DEFAULT_REGION", default=None)
-)
+region_name: str = Field(
+    default_factory=from_env("AWS_DEFAULT_REGION", default=None)
+)

Author

Solved on 0489125

Comment on lines 25 to 27
aws_profile: Optional[str] = Field(
    default_factory=from_env("AWS_PROFILE", default=None)
)
Collaborator

As with the region, I would like this to be consistent with other implementations in langchain_aws.

Suggested change
-aws_profile: Optional[str] = Field(
-    default_factory=from_env("AWS_PROFILE", default=None)
-)
+credentials_profile_name: Optional[str] = Field(
+    default_factory=from_env("AWS_PROFILE", default=None)
+)

Author

Solved on 0489125
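To illustrate how the renamed fields behave, here is a standard-library sketch of environment-backed defaults. It uses a dataclass in place of the pydantic `Field` and `from_env` helpers quoted above, and the class name `RerankSettings` is hypothetical; it is not the implementation in this PR.

```python
import os
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class RerankSettings:
    """Hypothetical stand-in for the compressor's AWS-related fields."""

    model_arn: str
    # Read at instantiation time, so later environment changes are picked up.
    region_name: Optional[str] = field(
        default_factory=lambda: os.environ.get("AWS_DEFAULT_REGION")
    )
    credentials_profile_name: Optional[str] = field(
        default_factory=lambda: os.environ.get("AWS_PROFILE")
    )


os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
os.environ.pop("AWS_PROFILE", None)
from_environment = RerankSettings(model_arn="arn:example")  # placeholder ARN
overridden = RerankSettings(model_arn="arn:example", region_name="us-east-1")
```

An explicit argument always wins over the environment, and a missing variable yields `None` rather than a hard-coded region.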

@jpfcabral
Author

Based on the recommendations:

  • Samples added (5c8aca3)
  • Rename module from rerank to document_compressors (66f0d64)

@jpfcabral jpfcabral requested a review from 3coins February 1, 2025 02:10
@jpfcabral jpfcabral changed the title Adding rerank as a retriever Add rerank document compressor Feb 1, 2025
Development

Successfully merging this pull request may close these issues.

Support bedrock rerank API
3 participants