Skip to content

Commit

Permalink
creating the new DemoChatBedrock POC.
Browse files Browse the repository at this point in the history
  • Loading branch information
Vishal Patil committed Jan 23, 2025
1 parent ac7ec07 commit 7d8c9c8
Show file tree
Hide file tree
Showing 5 changed files with 866 additions and 2 deletions.
6 changes: 5 additions & 1 deletion libs/aws/langchain_aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from langchain_aws.chat_model_adapter import BedrockClaudeAdapter, ModelAdapter
from langchain_aws.chains import (
create_neptune_opencypher_qa_chain,
create_neptune_sparql_qa_chain,
)
from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse
from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse, DemoChatBedrock
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from langchain_aws.llms import BedrockLLM, SagemakerEndpoint
Expand All @@ -20,6 +21,9 @@
"BedrockLLM",
"ChatBedrock",
"ChatBedrockConverse",
"DemoChatBedrock",
"ModelAdapter",
"BedrockClaudeAdapter",
"SagemakerEndpoint",
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
Expand Down
6 changes: 6 additions & 0 deletions libs/aws/langchain_aws/chat_model_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from langchain_aws.chat_model_adapter.demo_chat_adapter import (
BedrockClaudeAdapter,
ModelAdapter,
)

__all__ = ["ModelAdapter", "BedrockClaudeAdapter"]
325 changes: 325 additions & 0 deletions libs/aws/langchain_aws/chat_model_adapter/demo_chat_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
from typing import (
Any,
Iterator,
List,
Optional,
Sequence,
Union,
Dict,
Callable,
Literal,
Type,
TypeVar,
Tuple,
cast,
)

from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import (
BaseMessage,
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
ChatMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.pydantic import TypeBaseModel
from pydantic import BaseModel

from abc import ABC, abstractmethod
import re

# ModelAdapter might also need access to the data that the wrapper ChatModel class has
# for example, the provider or custom inputs passed in by the user


class ModelAdapter(ABC):
    """Strategy interface for adapting one provider's chat-model wire format.

    Concrete subclasses translate between LangChain's message/result types
    and a single provider's request/response payloads, so the wrapping chat
    model class can stay provider-agnostic.
    """

    @abstractmethod
    def convert_messages_to_payload(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Turn LangChain messages into the provider-specific request payload."""

    @abstractmethod
    def convert_response_to_chat_result(self, response: Any) -> ChatResult:
        """Turn a provider-specific response into a LangChain ``ChatResult``."""

    @abstractmethod
    def convert_stream_response_to_chunks(
        self, response: Any
    ) -> Iterator[ChatGenerationChunk]:
        """Turn a provider-specific stream into LangChain generation chunks."""

    @abstractmethod
    def format_tools(
        self, tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]]
    ) -> Any:
        """Render tool definitions in the provider-specific schema."""


# Example concrete implementation for a specific model
class BedrockClaudeAdapter(ModelAdapter):
    """Example ``ModelAdapter`` for Anthropic Claude models on Bedrock.

    Only prompt construction is implemented: LangChain messages are turned
    into the legacy Anthropic text-completion prompt (``Human:`` /
    ``Assistant:`` tags). Helpers targeting the Anthropic messages-API
    format (system string + role/content dicts) are also provided. Response
    and streaming conversion are unimplemented stubs for this POC.
    """

    # Maps LangChain message types / chunk class names to Anthropic roles.
    message_type_lookups = {
        "human": "user",
        "ai": "assistant",
        "AIMessageChunk": "assistant",
        "HumanMessageChunk": "user",
    }

    def convert_messages_to_payload(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Convert LangChain messages to a Claude text-completion prompt.

        Args:
            messages: Conversation history to serialize.
            stop: Currently unused; stop sequences must be applied by the
                caller. Kept for interface compatibility.
            **kwargs: Currently unused (e.g. ``max_tokens``).

        Returns:
            The prompt string for the Anthropic text-completion API.
        """
        # NOTE: an earlier draft built a messages-API payload dict here and
        # immediately discarded it; that dead code has been removed. The
        # return annotation is now ``str`` to match what is actually returned.
        return self.convert_messages_to_prompt_anthropic(messages=messages)

    def _convert_message(self, msg: BaseMessage) -> Dict[str, str]:
        """Convert a single LangChain message to a Claude-style role dict.

        Unknown message types fall back to the ``"user"`` role.
        """
        role_map = {"human": "user", "ai": "assistant", "system": "system"}
        return {
            "role": role_map.get(msg.type, "user"),
            # NOTE(review): this is a bare string; the Anthropic messages API
            # expects a dict with "type" and "text" fields — confirm before
            # using this helper against the messages API.
            "content": msg.content,
        }

    def convert_response_to_chat_result(self, response: Any) -> ChatResult:
        """Unimplemented stub; currently returns ``None``. TODO: implement."""
        pass

    def convert_stream_response_to_chunks(
        self, response: Any
    ) -> Iterator[ChatGenerationChunk]:
        """Unimplemented stub; currently returns ``None``. TODO: implement."""
        pass

    def format_tools(
        self, tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]]
    ) -> Any:
        """Unimplemented stub; currently returns ``None``. TODO: implement."""
        pass

    def _format_image(self, image_url: str) -> Dict:
        """Parse a ``data:image/...;base64,...`` URL into an Anthropic source dict.

        Returns:
            ``{"type": "base64", "media_type": "image/jpeg", "data": "..."}``

        Raises:
            ValueError: If ``image_url`` is not a base64-encoded data URL.
        """
        regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
        match = re.match(regex, image_url)
        if match is None:
            raise ValueError(
                "Anthropic only supports base64-encoded images currently."
                " Example: data:image/png;base64,'/9j/4AAQSk'..."
            )
        return {
            "type": "base64",
            "media_type": match.group("media_type"),
            "data": match.group("data"),
        }

    def _merge_messages(
        self,
        messages: Sequence[BaseMessage],
    ) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
        """Merge runs of human/tool messages into single human messages with content blocks."""  # noqa: E501
        merged: list = []
        for curr in messages:
            # Deep-copy so mutations below never touch the caller's messages.
            curr = curr.model_copy(deep=True)
            if isinstance(curr, ToolMessage):
                # Anthropic has no "tool" role: re-express tool output as a
                # human message carrying tool_result content blocks.
                if isinstance(curr.content, list) and all(
                    isinstance(block, dict) and block.get("type") == "tool_result"
                    for block in curr.content
                ):
                    curr = HumanMessage(curr.content)  # type: ignore[misc]
                else:
                    curr = HumanMessage(  # type: ignore[misc]
                        [
                            {
                                "type": "tool_result",
                                "content": curr.content,
                                "tool_use_id": curr.tool_call_id,
                            }
                        ]
                    )
            last = merged[-1] if merged else None
            if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
                # Consecutive human messages are folded into one message whose
                # content is a list of blocks (str content becomes a text block).
                if isinstance(last.content, str):
                    new_content: List = [{"type": "text", "text": last.content}]
                else:
                    new_content = last.content
                if isinstance(curr.content, str):
                    new_content.append({"type": "text", "text": curr.content})
                else:
                    new_content.extend(curr.content)
                last.content = new_content
            else:
                merged.append(curr)
        return merged

    def format_anthropic_messages(
        self,
        messages: List[BaseMessage],
    ) -> Tuple[Optional[str], List[Dict]]:
        """Format messages for the Anthropic messages API.

        Returns:
            A ``(system, messages)`` tuple: the optional system prompt string
            and the list of role/content dicts.

        Raises:
            ValueError: If a system message is not first, is not a string, or
                a content item has an unsupported type.
        """
        system: Optional[str] = None
        formatted_messages: List[Dict] = []

        merged_messages = self._merge_messages(messages)
        for i, message in enumerate(merged_messages):
            if message.type == "system":
                if i != 0:
                    raise ValueError(
                        "System message must be at beginning of message list."
                    )
                if not isinstance(message.content, str):
                    raise ValueError(
                        "System message must be a string, "
                        f"instead was: {type(message.content)}"
                    )
                system = message.content
                continue

            role = self.message_type_lookups[message.type]
            content: Union[str, List]

            if not isinstance(message.content, str):
                # parse as dict
                assert isinstance(
                    message.content, list
                ), "Anthropic message content must be str or list of dicts"

                # populate content
                content = []
                for item in message.content:
                    if isinstance(item, str):
                        content.append({"type": "text", "text": item})
                    elif isinstance(item, dict):
                        if "type" not in item:
                            raise ValueError("Dict content item must have a type key")
                        elif item["type"] == "image_url":
                            # convert OpenAI-style image_url to Anthropic source
                            source = self._format_image(item["image_url"]["url"])
                            content.append({"type": "image", "source": source})
                        elif item["type"] == "tool_use":
                            # If a tool_call with the same id as a tool_use
                            # content block exists, the tool_call is preferred:
                            # the block is skipped here. Conversion of
                            # tool_calls to tool_use blocks is not yet
                            # implemented for this POC.
                            if isinstance(message, AIMessage) and item["id"] in [
                                tc["id"] for tc in message.tool_calls
                            ]:
                                # content.extend(
                                #     _lc_tool_calls_to_anthropic_tool_use_blocks(...)
                                # )
                                pass
                            else:
                                item.pop("text", None)
                                content.append(item)
                        elif item["type"] == "text":
                            text = item.get("text", "")
                            # Only add non-empty strings for now as empty ones are not
                            # accepted.
                            # https://github.com/anthropics/anthropic-sdk-python/issues/461
                            if text.strip():
                                content.append({"type": "text", "text": text})
                        else:
                            content.append(item)
                    else:
                        raise ValueError(
                            f"Content items must be str or dict, instead was: {type(item)}"
                        )
            elif isinstance(message, AIMessage) and message.tool_calls:
                content = (
                    []
                    if not message.content
                    else [{"type": "text", "text": message.content}]
                )
                # Note: Anthropic can't have invalid tool calls as presently defined,
                # since the model already returns dicts args not JSON strings, and invalid
                # tool calls are those with invalid JSON for args.
                # content += _lc_tool_calls_to_anthropic_tool_use_blocks(message.tool_calls)
            else:
                content = message.content

            formatted_messages.append({"role": role, "content": content})
        return system, formatted_messages

    def _convert_one_message_to_text_anthropic(
        self,
        message: BaseMessage,
        human_prompt: str,
        ai_prompt: str,
    ) -> str:
        """Render one message as text with the appropriate role tag.

        Raises:
            ValueError: For message types other than Chat/Human/AI/System.
        """
        content = cast(str, message.content)
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AIMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemMessage):
            message_text = content
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def convert_messages_to_prompt_anthropic(
        self,
        messages: List[BaseMessage],
        *,
        human_prompt: str = "\n\nHuman:",
        ai_prompt: str = "\n\nAssistant:",
    ) -> str:
        """Format a list of messages into a full prompt for the Anthropic model
        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.
            human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
            ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
        Returns:
            str: Combined string with necessary human_prompt and ai_prompt tags.
        """

        messages = messages.copy()  # don't mutate the original list
        # The prompt must end with an Assistant turn; append an empty one if
        # missing. The empty-list guard fixes an IndexError on `messages[-1]`
        # when called with no messages.
        if not messages or not isinstance(messages[-1], AIMessage):
            messages.append(AIMessage(content=""))

        text = "".join(
            self._convert_one_message_to_text_anthropic(
                message, human_prompt, ai_prompt
            )
            for message in messages
        )

        # trim off the trailing ' ' that might come from the "Assistant: "
        return text.rstrip()
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from langchain_aws.chat_models.bedrock import ChatBedrock
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
from langchain_aws.chat_models.demo_chat import DemoChatBedrock

__all__ = ["ChatBedrock", "ChatBedrockConverse"]
__all__ = ["ChatBedrock", "ChatBedrockConverse", "DemoChatBedrock"]
Loading

0 comments on commit 7d8c9c8

Please sign in to comment.