diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 94506908..16907691 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -268,7 +268,7 @@ class Joke(BaseModel): max_tokens: Optional[int] = None """Max tokens to generate.""" - stop_sequences: Optional[List[str]] = Field(None, alias="stop") + stop_sequences: Optional[List[str]] = Field(default=None, alias="stop") """Stop generation if any of these substrings occurs.""" temperature: Optional[float] = None @@ -308,12 +308,17 @@ class Joke(BaseModel): have an ARN associated with them. """ - endpoint_url: Optional[str] = Field(None, alias="base_url") + endpoint_url: Optional[str] = Field(default=None, alias="base_url") """Needed if you don't want to default to us-east-1 endpoint""" config: Any = None """An optional botocore.config.Config instance to pass to the client.""" + formatted_tools: List[ + Dict[Literal["toolSpec"], Dict[str, Union[Dict[str, Any], str]]] + ] = Field(default_factory=list, exclude=True) + """"Formatted tools to be stored and used in the toolConfig parameter.""" + class Config: """Configuration for this pydantic object.""" @@ -413,7 +418,8 @@ def bind_tools( ) -> Runnable[LanguageModelInput, BaseMessage]: if tool_choice: kwargs["tool_choice"] = _format_tool_choice(tool_choice) - return self.bind(tools=_format_tools(tools), **kwargs) + self.formatted_tools = _format_tools(tools) + return self.bind(tools=self.formatted_tools, **kwargs) def with_structured_output( self, @@ -467,8 +473,7 @@ def _converse_params( } if not toolConfig and tools: toolChoice = _format_tool_choice(toolChoice) if toolChoice else None - toolConfig = {"tools": _format_tools(tools), "toolChoice": toolChoice} - + toolConfig = {"tools": self.formatted_tools, "toolChoice": toolChoice} return _drop_none( { "modelId": modelId or self.model_id, @@ -648,7 +653,7 @@ def _anthropic_to_bedrock( { "toolUse": { "toolUseId": block["id"], - "input": block["input"], + "input": _try_to_convert_to_dict(block["input"]), "name": block["name"], } } @@ -852,3 +857,13 @@ def _format_openai_image_url(image_url: str) -> Dict: "format": match.group("media_type"), "source": {"bytes": _b64str_to_bytes(match.group("data"))}, } + + +def _try_to_convert_to_dict(tool_use_input: Any) -> Any: + """Attempt to convert the toolUse.input to a dictionary.""" + if isinstance(tool_use_input, str): + try: + return json.loads(tool_use_input) + except json.JSONDecodeError: + return tool_use_input + return tool_use_input