Skip to content

Commit

Permalink
fix: pydantic validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jpfcabral committed Feb 6, 2025
1 parent e3de880 commit 7194759
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
28 changes: 13 additions & 15 deletions libs/aws/langchain_aws/document_compressors/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,21 @@ class BedrockRerank(BaseDocumentCompressor):
arbitrary_types_allowed=True,
)

@model_validator(mode="after")
def initialize_client(self) -> Self:
@model_validator(mode="before")
@classmethod
def initialize_client(cls, values: Dict[str, Any]) -> Any:
"""Initialize the AWS Bedrock client."""
if not self.client:
session = self._get_session()
self.client = session.client(
if not values.get("client"):
session = (
boto3.Session(profile_name=values.get("credentials_profile_name"))
if values.get("credentials_profile_name", None)
else boto3.Session()
)
values["client"] = session.client(
"bedrock-agent-runtime",
region_name=self.region_name
)
return self

def _get_session(self):
return (
boto3.Session(profile_name=self.credentials_profile_name)
if self.credentials_profile_name
else boto3.Session()
)
region_name=values.get("region_name"),
)
return values

def rerank(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def test_initialize_client(mock_boto_session: MagicMock, reranker: BedrockRerank
session_instance = MagicMock()
mock_boto_session.return_value = session_instance
session_instance.client.return_value = MagicMock()
reranker.initialize_client()
assert reranker.client is not None

@patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank")
Expand Down

0 comments on commit 7194759

Please sign in to comment.