Fixes a problem with tools not working with the new Converse endpoint using Anthropic models #76

Open · wants to merge 7 commits into base: main
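For context, a minimal sketch of the call path this PR targets: binding a tool to ChatBedrockConverse and invoking an Anthropic model through the Bedrock Converse API. The model ID and the get_weather tool below are illustrative assumptions, not taken from the PR.

# Hypothetical reproduction sketch; the model ID and tool are assumptions, not from the PR.
from langchain_aws import ChatBedrockConverse
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Return a canned weather report for a city."""
    return f"It is always sunny in {city}."


llm = ChatBedrockConverse(model="anthropic.claude-3-sonnet-20240229-v1:0")
llm_with_tools = llm.bind_tools([get_weather])

# Before this change, the tool specs forwarded to the Converse endpoint could
# reach the service in a shape it rejected for Anthropic models.
ai_msg = llm_with_tools.invoke("What is the weather in Paris?")
print(ai_msg.tool_calls)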
27 changes: 21 additions & 6 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -268,7 +268,7 @@ class Joke(BaseModel):
     max_tokens: Optional[int] = None
     """Max tokens to generate."""

-    stop_sequences: Optional[List[str]] = Field(None, alias="stop")
+    stop_sequences: Optional[List[str]] = Field(default=None, alias="stop")
     """Stop generation if any of these substrings occurs."""

     temperature: Optional[float] = None
@@ -308,12 +308,17 @@ class Joke(BaseModel):
     have an ARN associated with them.
     """

-    endpoint_url: Optional[str] = Field(None, alias="base_url")
+    endpoint_url: Optional[str] = Field(default=None, alias="base_url")
     """Needed if you don't want to default to us-east-1 endpoint"""

     config: Any = None
     """An optional botocore.config.Config instance to pass to the client."""

+    formatted_tools: List[
+        Dict[Literal["toolSpec"], Dict[str, Union[Dict[str, Any], str]]]
+    ] = Field(default_factory=list, exclude=True)
+    """Formatted tools to be stored and used in the toolConfig parameter."""
+
     class Config:
         """Configuration for this pydantic object."""

@@ -413,7 +418,8 @@ def bind_tools(
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         if tool_choice:
             kwargs["tool_choice"] = _format_tool_choice(tool_choice)
-        return self.bind(tools=_format_tools(tools), **kwargs)
+        self.formatted_tools = _format_tools(tools)
+        return self.bind(tools=self.formatted_tools, **kwargs)

     def with_structured_output(
         self,
@@ -467,8 +473,7 @@ def _converse_params(
             }
         if not toolConfig and tools:
             toolChoice = _format_tool_choice(toolChoice) if toolChoice else None
-            toolConfig = {"tools": _format_tools(tools), "toolChoice": toolChoice}
-
+            toolConfig = {"tools": self.formatted_tools, "toolChoice": toolChoice}
         return _drop_none(
             {
                 "modelId": modelId or self.model_id,
@@ -648,7 +653,7 @@ def _anthropic_to_bedrock(
                 {
                     "toolUse": {
                         "toolUseId": block["id"],
-                        "input": block["input"],
+                        "input": _try_to_convert_to_dict(block["input"]),
                         "name": block["name"],
                     }
                 }
@@ -852,3 +857,13 @@ def _format_openai_image_url(image_url: str) -> Dict:
         "format": match.group("media_type"),
         "source": {"bytes": _b64str_to_bytes(match.group("data"))},
     }
+
+
+def _try_to_convert_to_dict(tool_use_input: Any) -> Any:
+    """Attempt to convert the toolUse.input to a dictionary."""
+    if isinstance(tool_use_input, str):
+        try:
+            return json.loads(tool_use_input)
+        except json.JSONDecodeError:
+            return tool_use_input
+    return tool_use_input
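
For reference, a quick sketch of how the new helper behaves on the input shapes it can receive; the values below are illustrative, not taken from the PR or its tests.

# Illustrative values only: a JSON-encoded string is parsed into a dict,
# anything else (including invalid JSON) is returned unchanged.
_try_to_convert_to_dict('{"city": "Paris"}')  # -> {"city": "Paris"}
_try_to_convert_to_dict("not valid json")     # -> "not valid json"
_try_to_convert_to_dict({"city": "Paris"})    # -> {"city": "Paris"}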