Skip to content

Commit

Permalink
fix: mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jpfcabral committed Feb 6, 2025
1 parent d907e05 commit d69a71f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
28 changes: 13 additions & 15 deletions libs/aws/langchain_aws/document_compressors/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,21 @@ class BedrockRerank(BaseDocumentCompressor):
arbitrary_types_allowed=True,
)

@model_validator(mode="after")
def initialize_client(self) -> Self:
@model_validator(mode="before")
@classmethod
def initialize_client(cls, values: Dict[str, Any]) -> Any:
"""Initialize the AWS Bedrock client."""
if not self.client:
session = self._get_session()
self.client = session.client(
if not values.get("client"):
session = (
boto3.Session(profile_name=values.get("credentials_profile_name"))
if values.get("credentials_profile_name", None)
else boto3.Session()
)
values["client"] = session.client(
"bedrock-agent-runtime",
region_name=self.region_name
)
return self

def _get_session(self):
return (
boto3.Session(profile_name=self.credentials_profile_name)
if self.credentials_profile_name
else boto3.Session()
)
region_name=values.get("region_name"),
)
return values

def rerank(
self,
Expand Down
9 changes: 4 additions & 5 deletions libs/aws/tests/unit_tests/document_compressors/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@pytest.fixture
def reranker():
def reranker() -> BedrockRerank:
reranker = BedrockRerank(
model_arn="arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
region_name="us-east-1",
Expand All @@ -16,15 +16,14 @@ def reranker():
return reranker

@patch("boto3.Session")
def test_initialize_client(mock_boto_session, reranker):
def test_initialize_client(mock_boto_session: MagicMock, reranker: BedrockRerank) -> None:
session_instance = MagicMock()
mock_boto_session.return_value = session_instance
session_instance.client.return_value = MagicMock()
reranker.initialize_client()
assert reranker.client is not None

@patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank")
def test_rerank(mock_rerank, reranker):
def test_rerank(mock_rerank: MagicMock, reranker: BedrockRerank) -> None:
mock_rerank.return_value = [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.8},
Expand All @@ -41,7 +40,7 @@ def test_rerank(mock_rerank, reranker):
assert results[1]["relevance_score"] == 0.8

@patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank")
def test_compress_documents(mock_rerank, reranker):
def test_compress_documents(mock_rerank: MagicMock, reranker: BedrockRerank) -> None:
mock_rerank.return_value = [
{"index": 0, "relevance_score": 0.95},
{"index": 1, "relevance_score": 0.85},
Expand Down

0 comments on commit d69a71f

Please sign in to comment.