From d14141aafc2059274e054625c88ee9765b858548 Mon Sep 17 00:00:00 2001 From: Amit Ghadge Date: Sat, 1 Feb 2025 12:14:46 -0800 Subject: [PATCH 1/3] [Fix] added missed client parameter in ChatBedrock Update the ChatBedrock class to accept the user initiated Bedrock client. Signed-off-by: Amit Ghadge --- libs/aws/langchain_aws/chat_models/bedrock.py | 1 + .../chat_models/test_bedrock_converse.py | 29 ++++++++++++++++++- .../chat_models/test_bedrock_converse.py | 1 + 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index b05ebea5..23933506 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -855,6 +855,7 @@ def _as_converse(self) -> ChatBedrockConverse: if self.temperature is not None: kwargs["temperature"] = self.temperature return ChatBedrockConverse( + client=self.client, model=self.model_id, region_name=self.region_name, credentials_profile_name=self.credentials_profile_name, diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index 4841b775..34933a72 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -1,7 +1,7 @@ """Standard LangChain interface tests""" from typing import Literal, Type - +from unittest.mock import Mock import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage @@ -74,6 +74,33 @@ def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None super().test_tool_message_histories_list_content(model) +class TestBedrockNovaStandardWithClient(ChatModelIntegrationTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrockConverse + + @property + def chat_model_params(self) -> dict: + client = Mock() + return {"client": client, "model": "us.amazon.nova-pro-v1:0"} + + @property + def standard_chat_model_params(self) -> dict: + return {"max_tokens": 300, "stop": []} + + @property + def tool_choice_value(self) -> str: + return "auto" + + @pytest.mark.xfail(reason="Tool choice 'Any' not supported.") + def test_structured_few_shot_examples(self, model: BaseChatModel) -> None: + super().test_structured_few_shot_examples(model) + + @pytest.mark.xfail(reason="Human messages following AI messages not supported.") + def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: + super().test_tool_message_histories_list_content(model) + + class TestBedrockCohereStandard(ChatModelIntegrationTests): @property def chat_model_class(self) -> Type[BaseChatModel]: diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index d61830dc..d21b78d1 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -36,6 +36,7 @@ def chat_model_class(self) -> Type[BaseChatModel]: @property def chat_model_params(self) -> dict: return { + "client": None, "model": "anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "us-west-1", } From 3c9446777300ee74ee4c1ed3f86f6c228c27a7f8 Mon Sep 17 00:00:00 2001 From: Amit Ghadge Date: Wed, 5 Feb 2025 09:14:58 -0800 Subject: [PATCH 2/3] revert the test_bedrock_converse.py changes --- .../chat_models/test_bedrock_converse.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index 34933a72..cdc41b34 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -1,7 +1,6 @@ """Standard LangChain interface tests""" from typing import Literal, Type -from unittest.mock import Mock import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage @@ -74,33 +73,6 @@ def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None super().test_tool_message_histories_list_content(model) -class TestBedrockNovaStandardWithClient(ChatModelIntegrationTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatBedrockConverse - - @property - def chat_model_params(self) -> dict: - client = Mock() - return {"client": client, "model": "us.amazon.nova-pro-v1:0"} - - @property - def standard_chat_model_params(self) -> dict: - return {"max_tokens": 300, "stop": []} - - @property - def tool_choice_value(self) -> str: - return "auto" - - @pytest.mark.xfail(reason="Tool choice 'Any' not supported.") - def test_structured_few_shot_examples(self, model: BaseChatModel) -> None: - super().test_structured_few_shot_examples(model) - - @pytest.mark.xfail(reason="Human messages following AI messages not supported.") - def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: - super().test_tool_message_histories_list_content(model) - - class TestBedrockCohereStandard(ChatModelIntegrationTests): @property def chat_model_class(self) -> Type[BaseChatModel]: From 5e1e8822a300c8212ede02c8311a33bae1f11a4f Mon Sep 17 00:00:00 2001 From: Amit Ghadge Date: Wed, 5 Feb 2025 09:15:32 -0800 Subject: [PATCH 3/3] revert test_bedrock_converse.py --- .../tests/integration_tests/chat_models/test_bedrock_converse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index cdc41b34..4841b775 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -1,6 +1,7 @@ """Standard LangChain interface tests""" from typing import Literal, Type + import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage