diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index b05ebea5..23933506 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -855,6 +855,7 @@ def _as_converse(self) -> ChatBedrockConverse: if self.temperature is not None: kwargs["temperature"] = self.temperature return ChatBedrockConverse( + client=self.client, model=self.model_id, region_name=self.region_name, credentials_profile_name=self.credentials_profile_name, diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index 4841b775..34933a72 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -1,7 +1,7 @@ """Standard LangChain interface tests""" from typing import Literal, Type - +from unittest.mock import Mock import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage @@ -74,6 +74,33 @@ def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None super().test_tool_message_histories_list_content(model) +class TestBedrockNovaStandardWithClient(ChatModelIntegrationTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrockConverse + + @property + def chat_model_params(self) -> dict: + client = Mock() + return {"client": client, "model": "us.amazon.nova-pro-v1:0"} + + @property + def standard_chat_model_params(self) -> dict: + return {"max_tokens": 300, "stop": []} + + @property + def tool_choice_value(self) -> str: + return "auto" + + @pytest.mark.xfail(reason="Tool choice 'Any' not supported.") + def test_structured_few_shot_examples(self, model: BaseChatModel) -> None: + super().test_structured_few_shot_examples(model) + + @pytest.mark.xfail(reason="Human messages following AI messages not supported.") + def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: + super().test_tool_message_histories_list_content(model) + + class TestBedrockCohereStandard(ChatModelIntegrationTests): @property def chat_model_class(self) -> Type[BaseChatModel]: diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index d61830dc..d21b78d1 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -36,6 +36,7 @@ def chat_model_class(self) -> Type[BaseChatModel]: @property def chat_model_params(self) -> dict: return { + "client": None, "model": "anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "us-west-1", }