Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/amazon nova #292

Merged
merged 8 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator

from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
from langchain_aws.function_calling import (
Expand Down Expand Up @@ -407,6 +407,15 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "bedrock"]

@model_validator(mode="before")
@classmethod
def set_beta_use_converse_api(cls, values: Dict) -> Any:
    """Default ``beta_use_converse_api`` from the model identifier.

    Registered as a ``mode="before"`` validator, so it sees the raw
    constructor values. When ``beta_use_converse_api`` was not supplied it
    is set to ``True`` if the model identifier (the ``model_id`` key,
    falling back to ``model``) contains ``"nova"``, and ``False`` otherwise.
    An explicitly supplied value is never overridden.

    Args:
        values: Raw values passed to the model constructor.

    Returns:
        The same ``values`` mapping, updated in place.
    """
    if "beta_use_converse_api" in values:
        return values

    model_id = values.get("model_id", values.get("model"))
    # The identifier can be missing or None at this point because field
    # validation has not run yet; ``"nova" in None`` would raise TypeError
    # instead of letting the regular validation report the problem.
    values["beta_use_converse_api"] = (
        isinstance(model_id, str) and "nova" in model_id
    )
    return values

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
Expand Down Expand Up @@ -826,7 +835,16 @@ def _as_converse(self) -> ChatBedrockConverse:
kwargs = {
k: v
for k, v in (self.model_kwargs or {}).items()
if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p")
if k
in (
"stop",
"stop_sequences",
"max_tokens",
"temperature",
"top_p",
"additional_model_request_fields",
"additional_model_response_field_paths",
)
}
if self.max_tokens:
kwargs["max_tokens"] = self.max_tokens
Expand Down
139 changes: 111 additions & 28 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ class Joke(BaseModel):
"""Which types of tool_choice values the model supports.

Inferred if not specified. Inferred as ('auto', 'any', 'tool') if a 'claude-3'
model is used, ('auto', 'any') if a 'mistral-large' model is used, empty otherwise.
model is used, ('auto', 'any') if a 'mistral-large' model is used,
('auto',) if a 'nova' model is used, empty otherwise.
"""

model_config = ConfigDict(
Expand All @@ -406,11 +407,13 @@ def set_disable_streaming(cls, values: Dict) -> Any:
model_parts[-2] if len(model_parts) > 1 else model_parts[0]
)

# As of 09/15/24 Anthropic and Cohere models support streamed tool calling
# As of 12/03/24:
# Anthropic, Cohere and Amazon Nova models support streamed tool calling
if "disable_streaming" not in values:
values["disable_streaming"] = (
False
if values["provider"] in ["anthropic", "cohere"]
or (values["provider"] == "amazon" and "nova" in model_id)
else "tool_calling"
)
return values
Expand All @@ -419,13 +422,16 @@ def set_disable_streaming(cls, values: Dict) -> Any:
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""

# As of 08/05/24 only claude-3 and mistral-large models support tool choice:
# As of 12/03/24:
# only claude-3, mistral-large, and nova models support tool choice:
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
if self.supports_tool_choice_values is None:
if "claude-3" in self.model_id:
self.supports_tool_choice_values = ("auto", "any", "tool")
elif "mistral-large" in self.model_id:
self.supports_tool_choice_values = ("auto", "any")
elif "nova" in self.model_id:
self.supports_tool_choice_values = ["auto"]
else:
self.supports_tool_choice_values = ()

Expand Down Expand Up @@ -552,6 +558,7 @@ def bind_tools(
f"for the latest documentation on models that support tool choice."
)
kwargs["tool_choice"] = _format_tool_choice(tool_choice)

return self.bind(tools=formatted_tools, **kwargs)

def with_structured_output(
Expand Down Expand Up @@ -690,7 +697,7 @@ def _messages_to_bedrock(
# system message then alternating human/ai messages.
messages = merge_message_runs(messages)
for msg in messages:
content = _anthropic_to_bedrock(msg.content)
content = _lc_content_to_bedrock(msg.content)
if isinstance(msg, HumanMessage):
# If there's a human, tool, human message sequence, the
# tool message will be merged with the first human message, so the second
Expand Down Expand Up @@ -736,13 +743,11 @@ def _extract_response_metadata(response: Dict[str, Any]) -> Dict[str, Any]:


def _parse_response(response: Dict[str, Any]) -> AIMessage:
anthropic_content = _bedrock_to_anthropic(
response.pop("output")["message"]["content"]
)
tool_calls = _extract_tool_calls(anthropic_content)
lc_content = _bedrock_to_lc(response.pop("output")["message"]["content"])
tool_calls = _extract_tool_calls(lc_content)
usage = UsageMetadata(_camel_to_snake_keys(response.pop("usage"))) # type: ignore[misc]
return AIMessage(
content=_str_if_single_text_block(anthropic_content), # type: ignore[arg-type]
content=_str_if_single_text_block(lc_content), # type: ignore[arg-type]
usage_metadata=usage,
response_metadata=_extract_response_metadata(response),
tool_calls=tool_calls,
Expand All @@ -759,7 +764,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
)
elif "contentBlockStart" in event:
block = {
**_bedrock_to_anthropic([event["contentBlockStart"]["start"]])[0],
**_bedrock_to_lc([event["contentBlockStart"]["start"]])[0],
"index": event["contentBlockStart"]["contentBlockIndex"],
}
tool_call_chunks = []
Expand All @@ -775,7 +780,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
return AIMessageChunk(content=[block], tool_call_chunks=tool_call_chunks)
elif "contentBlockDelta" in event:
block = {
**_bedrock_to_anthropic([event["contentBlockDelta"]["delta"]])[0],
**_bedrock_to_lc([event["contentBlockDelta"]["delta"]])[0],
"index": event["contentBlockDelta"]["contentBlockIndex"],
}
tool_call_chunks = []
Expand Down Expand Up @@ -811,7 +816,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
raise ValueError(f"Received unsupported stream event:\n\n{event}")


def _anthropic_to_bedrock(
def _lc_content_to_bedrock(
content: Union[str, List[Union[str, Dict[str, Any]]]],
) -> List[Dict[str, Any]]:
if isinstance(content, str):
Expand Down Expand Up @@ -845,6 +850,36 @@ def _anthropic_to_bedrock(
bedrock_content.append(
{"image": _format_openai_image_url(block["imageUrl"]["url"])}
)
elif block["type"] == "video":
# Assume block is already in bedrock format.
if "video" in block:
bedrock_content.append({"video": block["video"]})
else:
if block["source"]["type"] == "base64":
bedrock_content.append(
{
"video": {
"format": block["source"]["mediaType"].split("/")[1],
"source": {
"bytes": _b64str_to_bytes(block["source"]["data"])
},
}
}
)
elif block["source"]["type"] == "s3Location":
bedrock_content.append(
{
"video": {
"format": block["source"]["mediaType"].split("/")[1],
"source": {"s3Location": block["source"]["data"]},
}
}
)
elif block["type"] == "video_url":
# Support OpenAI-style data-URL format for video as well.
bedrock_content.append(
{"video": _format_openai_video_url(block["videoUrl"]["url"])}
)
elif block["type"] == "document":
# Assume block in bedrock document format
bedrock_content.append({"document": block["document"]})
Expand All @@ -863,7 +898,7 @@ def _anthropic_to_bedrock(
{
"toolResult": {
"toolUseId": block["toolUseId"],
"content": _anthropic_to_bedrock(block["content"]),
"content": _lc_content_to_bedrock(block["content"]),
"status": "error" if block.get("isError") else "success",
}
}
Expand All @@ -879,16 +914,16 @@ def _anthropic_to_bedrock(
return [block for block in bedrock_content if block.get("text", True)]


def _bedrock_to_anthropic(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
anthropic_content = []
def _bedrock_to_lc(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
lc_content = []
for block in _camel_to_snake_keys(content):
if "text" in block:
anthropic_content.append({"type": "text", "text": block["text"]})
lc_content.append({"type": "text", "text": block["text"]})
elif "tool_use" in block:
block["tool_use"]["id"] = block["tool_use"].pop("tool_use_id", None)
anthropic_content.append({"type": "tool_use", **block["tool_use"]})
lc_content.append({"type": "tool_use", **block["tool_use"]})
elif "image" in block:
anthropic_content.append(
lc_content.append(
{
"type": "image",
"source": {
Expand All @@ -898,20 +933,48 @@ def _bedrock_to_anthropic(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]
},
}
)
elif "video" in block:
if "bytes" in block["video"]["source"]:
lc_content.append(
{
"type": "video",
"source": {
"media_type": f"video/{block['video']['format']}",
"type": "base64",
"data": _bytes_to_b64_str(
block["video"]["source"]["bytes"]
),
},
}
)
if "s3location" in block["video"]["source"]:
lc_content.append(
{
"type": "video",
"source": {
"media_type": f"video/{block['video']['format']}",
"type": "s3Location",
"data": block["video"]["source"]["s3location"],
},
}
)
elif "document" in block:
# Request syntax assumes bedrock format; returning in same bedrock format
lc_content.append({"type": "document", **block})
elif "tool_result" in block:
anthropic_content.append(
lc_content.append(
{
"type": "tool_result",
"tool_use_id": block["tool_result"]["tool_use_id"],
"is_error": block["tool_result"].get("status") == "error",
"content": _bedrock_to_anthropic(block["tool_result"]["content"]),
"content": _bedrock_to_lc(block["tool_result"]["content"]),
}
)
# Only occurs in content blocks of a tool_result:
elif "json" in block:
anthropic_content.append({"type": "json", **block})
lc_content.append({"type": "json", **block})
elif "guard_content" in block:
anthropic_content.append(
lc_content.append(
{
"type": "guard_content",
"guard_content": {
Expand All @@ -926,7 +989,7 @@ def _bedrock_to_anthropic(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]
"'text', 'tool_use', 'image', or 'tool_result' keys. Received:\n\n"
f"{block}"
)
return anthropic_content
return lc_content


def _format_tools(
Expand Down Expand Up @@ -1026,11 +1089,11 @@ def _bytes_to_b64_str(bytes_: bytes) -> str:


def _str_if_single_text_block(
anthropic_content: List[Dict[str, Any]],
content: List[Dict[str, Any]],
) -> Union[str, List[Dict[str, Any]]]:
if len(anthropic_content) == 1 and anthropic_content[0]["type"] == "text":
return anthropic_content[0]["text"]
return anthropic_content
if len(content) == 1 and content[0]["type"] == "text":
return content[0]["text"]
return content


def _upsert_tool_calls_to_bedrock_content(
Expand Down Expand Up @@ -1072,10 +1135,30 @@ def _format_openai_image_url(image_url: str) -> Dict:
match = re.match(regex, image_url)
if match is None:
raise ValueError(
"Bedrock does not currently support OpenAI-format image URLs, only "
"The image URL provided is not supported. Expected image URL format is "
"base64-encoded images. Example: data:image/png;base64,'/9j/4AAQSk'..."
)
return {
"format": match.group("media_type"),
"source": {"bytes": _b64str_to_bytes(match.group("data"))},
}


def _format_openai_video_url(video_url: str) -> Dict:
"""
Formats a video of format data:video/mp4;base64,{b64_string}
to a dict for bedrock api.

And throws an error if url is not a b64 video.
"""
regex = r"^data:video/(?P<media_type>.+);base64,(?P<data>.+)$"
match = re.match(regex, video_url)
if match is None:
raise ValueError(
"The video URL provided is not supported. Expected video URL format is "
"base64-encoded video. Example: data:video/mp4;base64,'/9j/4AAQSk'..."
)
return {
"format": match.group("media_type"),
"source": {"bytes": _b64str_to_bytes(match.group("data"))},
}
Loading
Loading