Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for a new runnable for InlineAgents #340

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
BedrockAgentAction,
BedrockAgentFinish,
BedrockAgentsRunnable,
BedrockInlineAgentsRunnable,
)

__all__ = ["BedrockAgentAction", "BedrockAgentFinish", "BedrockAgentsRunnable"]
__all__ = ["BedrockAgentAction", "BedrockAgentFinish", "BedrockAgentsRunnable", "BedrockInlineAgentsRunnable"]
291 changes: 291 additions & 0 deletions libs/aws/langchain_aws/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,294 @@ def _parse_intermediate_steps(
return session_id, session_state

return None, None

class KnowledgebaseConfiguration(TypedDict, total=False):
description: str
knowledgeBaseId: str
retrievalConfiguration: Dict

class InlineAgentConfiguration(TypedDict, total=False):
"""Configurations for an Inline Agent."""
foundation_model: str
instruction: str
enable_trace: Optional[bool]
tools: List[BaseTool]
enable_human_input: Optional[bool]
enable_code_interpreter: Optional[bool]
customer_encryption_key_arn: Optional[str]
idle_session_ttl_in_seconds: Optional[int]
guardrail_configuration: Optional[GuardrailConfiguration]
knowledge_bases: Optional[KnowledgebaseConfiguration]
prompt_override_configuration: Optional[Dict]
inline_session_state: Optional[Dict]

class BedrockInlineAgentsRunnable(RunnableSerializable[Dict, OutputType]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The contract seems to deviate from how other Runnables are implemented. Notice here in the Anthropic Chat Model that they use message as input and output: https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/language_models/chat_models.py#L127

If we support message in and message out then this will easily plugin to LCEL, LG and other constructs they have.

"""
Invoke a Bedrock Inline Agent
"""

client: Any
"""Boto3 client"""
region_name: Optional[str] = None
"""Region"""
credentials_profile_name: Optional[str] = None
"""Credentials to use to invoke the agent"""
endpoint_url: Optional[str] = None
"""Endpoint URL"""
inline_agent_config: Optional[InlineAgentConfiguration] = None

@model_validator(mode="before")
@classmethod
def validate_client(cls, values: dict) -> Any:
if values.get("client") is not None:
return values

try:
client_params, session = get_boto_session(
credentials_profile_name=values["credentials_profile_name"],
region_name=values["region_name"],
endpoint_url=values["endpoint_url"],
)

values["client"] = session.client("bedrock-agent-runtime", **client_params)
return values
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except UnknownServiceError as e:
raise ModuleNotFoundError(
"Ensure that you have installed the latest boto3 package "
"that contains the API for `bedrock-runtime-agent`."
) from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e

@classmethod
def create(
cls,
*,
credentials_profile_name: Optional[str] = None,
region_name: Optional[str] = None,
endpoint_url: Optional[str] = None,
inline_agent_config: Optional[InlineAgentConfiguration] = None,
**kwargs: Any,
) -> BedrockInlineAgentsRunnable:
"""
Creates a Bedrock Inline Agent Runnable that can be used with an AgentExecutor
or with LangGraph.

Args:
credentials_profile_name: The profile name to use if different from default
region_name: Region for the Bedrock agent
endpoint_url: Endpoint URL for bedrock agent runtime
enable_trace: Boolean flag to specify whether trace should be enabled when
invoking the agent
**kwargs: Additional arguments
Returns:
BedrockInlineAgentsRunnable configured to invoke the Bedrock inline agent
"""
try:
client_params, session = get_boto_session(
credentials_profile_name=credentials_profile_name,
region_name=region_name,
endpoint_url=endpoint_url,
)
client = session.client("bedrock-agent-runtime", **client_params)

return cls(
client=client,
region_name=region_name,
credentials_profile_name=credentials_profile_name,
endpoint_url=endpoint_url,
inline_agent_config=inline_agent_config,
**kwargs,
)
except Exception as e:
raise ValueError(
f"Error creating BedrockInlineAgentsRunnable: {str(e)}"
) from e

# Check: can the input be of TypedDict:
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
"""
Invoke the Bedrock Inline agent.

Args:
input: The input dictionary containing:
input_text: The input text to the agent
session_id: The session id to use. If not provided, a new session will
be started
end_session: Boolean indicating whether to end a session or not
inline_agent_config: Optional configuration to override default config
config: Optional RunnableConfig
**kwargs: Additional arguments

Returns:
Union[List[BedrockAgentAction], BedrockAgentFinish]
"""
config = ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)

try:

# Merge configurations, giving priority to invoke-time config
runtime_config = input.get("inline_agent_config", {})
self.inline_agent_config = self.inline_agent_config or {} # Ensure it's never None to unpack
effective_config = {**self.inline_agent_config, **runtime_config}

# Convert tools to action groups format
action_groups = self._get_action_groups(
tools=effective_config.get("tools", []),
enableHumanInput=effective_config.get("enable_human_input", False),
enableCodeInterpreter=effective_config.get("enable_code_interpreter", False),
)

# Prepare the invoke_inline_agent request
agent_input: Dict[str, Any] = {
"foundationModel": effective_config["foundation_model"],
"instruction": effective_config["instruction"],
"actionGroups": action_groups,
"enableTrace": effective_config["enable_trace"],
"endSession": bool(input.get("end_session", False)),
}

# Add optional configurations
optional_params_name_map = {
"customerEncryptionKeyArn": "customer_encryption_key_arn",
"idleSessionTTLInSeconds": "idle_session_ttl_in_seconds",
"guardrailConfiguration": "guardrail_configuration",
"knowledgeBases": "knowledge_bases",
"promptOverrideConfiguration": "prompt_override_configuration",
"inlineSessionState": "inline_session_state",
}

for param_name, input_key in optional_params_name_map.items():
if effective_config.get(input_key):
agent_input[param_name] = effective_config[input_key]

session_id = None
if input.get("intermediate_steps"):
session_id, session_state = self._parse_intermediate_steps(
input.get("intermediate_steps")
)
if session_state:
agent_input["inlineSessionState"] = session_state
else:
agent_input["inputText"] = input.get("input_text", "")

# Use existing session_id from input, or from intermediate steps, or generate new one
session_id = input.get("session_id") or session_id or str(uuid.uuid4())

# Make the InvokeInlineAgent request to bedrock
output = self.client.invoke_inline_agent(
sessionId=session_id, **agent_input
)

except Exception as e:
run_manager.on_chain_error(e)
raise e

try:
response = parse_agent_response(output)
except Exception as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(response)
return response

def _get_action_groups(self, tools: List[BaseTool], enableHumanInput: bool, enableCodeInterpreter: bool) -> List:
action_groups = []
tools_by_action_group = defaultdict(list)

for tool in tools:
action_group_name, _ = _get_action_group_and_function_names(tool)
tools_by_action_group[action_group_name].append(tool)

for action_group_name, functions in tools_by_action_group.items():
action_groups.append(
{
"actionGroupName": action_group_name,
"actionGroupExecutor": {"customControl": "RETURN_CONTROL"},
"functionSchema": {
"functions": [
_tool_to_function(function) for function in functions
]
},
}
)

if enableHumanInput:
action_groups.append(
{
"actionGroupName": "UserInputAction",
"parentActionGroupSignature": "AMAZON.UserInput",
}
)

if enableCodeInterpreter:
action_groups.append(
{
"actionGroupName": "CodeInterpreterAction",
"parentActionGroupSignature": "AMAZON.CodeInterpreter",
}
)
return action_groups

# ToDo: move to common.
def _parse_intermediate_steps(
self, intermediate_steps: List[Tuple[BedrockAgentAction, str]]
) -> Tuple[Union[str, None], Union[Dict[str, Any], None]]:
"""Parse intermediate steps for inline agent invocation"""
last_step = max(0, len(intermediate_steps) - 1)
action = intermediate_steps[last_step][0]
tool_invoked = action.tool
messages = action.messages
session_id = action.session_id

if tool_invoked:
action_group_name = _DEFAULT_ACTION_GROUP_NAME
function_name = tool_invoked
tool_name_split = tool_invoked.split("::")
if len(tool_name_split) > 1:
action_group_name = tool_name_split[0]
function_name = tool_name_split[1]

if messages:
last_message = max(0, len(messages) - 1)
message = messages[last_message]
if type(message) is AIMessage:
response = intermediate_steps[last_step][1]
session_state = {
"invocationId": json.loads(message.content)
.get("returnControl", {})
.get("invocationId", ""),
"returnControlInvocationResults": [
{
"functionResult": {
"actionGroup": action_group_name,
"function": function_name,
"responseBody": {"TEXT": {"body": response}},
}
}
],
}

return session_id, session_state

return None, None
78 changes: 78 additions & 0 deletions libs/aws/tests/integration_tests/agents/test_bedrock_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
BedrockAgentAction,
BedrockAgentFinish,
BedrockAgentsRunnable,
BedrockInlineAgentsRunnable,
)


Expand Down Expand Up @@ -198,6 +199,83 @@ def get_mortgage_rate(asset_holder_id: str, asset_value: str) -> str:
if agent:
_delete_agent(agent.agent_id)

# @pytest.mark.skip
def test_inline_agent_with_executor():
@tool
def get_weather(location: str = "") -> str:
"""
Get the weather of a location

Args:
location: location of the place
"""
if location.lower() == "seattle":
return f"It is raining in {location}"
return f"It is hot and humid in {location}"

foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
tools = [get_weather]
try:
runnable = BedrockInlineAgentsRunnable.create(
region_name="us-west-2",
inline_agent_config={
"foundation_model": foundation_model,
"instruction": "You are an agent who helps with getting weather for a given location",
"tools": tools,
"enable_trace": True,
}
)

agent_executor = AgentExecutor(agent=runnable, tools=tools)
output = agent_executor.invoke({
"input_text": "what is the weather in Seattle?"
})

assert output["output"] == "It is raining in Seattle"

except Exception as ex:
raise ex

# @pytest.mark.skip
def test_inline_agent():
@tool
def get_weather(location: str = "") -> str:
"""
Get the weather of a location

Args:
location: location of the place
"""
if location.lower() == "seattle":
return f"It is raining in {location}"
return f"It is hot and humid in {location}"

foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
tools = [get_weather]
try:
runnable = BedrockInlineAgentsRunnable.create(
region_name="us-west-2"
)

output = runnable.invoke(
input={
"input_text": "what is the weather in Seattle?",
"inline_agent_config":{
"foundation_model": foundation_model,
"instruction": "You are an agent who helps with getting weather for a given location",
"tools": tools,
"enable_trace": True,
}
}
)

# Check if the agent called for tool invocation
assert isinstance(output, list)
assert output[0].tool == "get_weather"
assert "Seattle" in output[0].tool_input["location"]

except Exception as ex:
raise ex

@pytest.mark.skip
def test_weather_agent():
Expand Down
Loading