Skip to content

Commit

Permalink
Adding support for a new runnable for InlineAgents
Browse files Browse the repository at this point in the history
  • Loading branch information
divekarshubham committed Jan 28, 2025
1 parent 784445c commit 5f9ad6c
Show file tree
Hide file tree
Showing 6 changed files with 1,508 additions and 1 deletion.
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
BedrockAgentAction,
BedrockAgentFinish,
BedrockAgentsRunnable,
BedrockInlineAgentsRunnable,
)

__all__ = ["BedrockAgentAction", "BedrockAgentFinish", "BedrockAgentsRunnable"]
__all__ = ["BedrockAgentAction", "BedrockAgentFinish", "BedrockAgentsRunnable", "BedrockInlineAgentsRunnable"]
291 changes: 291 additions & 0 deletions libs/aws/langchain_aws/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,294 @@ def _parse_intermediate_steps(
return session_id, session_state

return None, None

class KnowledgebaseConfiguration(TypedDict, total=False):
description: str
knowledgeBaseId: str
retrievalConfiguration: Dict

class InlineAgentConfiguration(TypedDict, total=False):
"""Configurations for an Inline Agent."""
foundation_model: str
instruction: str
enable_trace: Optional[bool]
tools: List[BaseTool]
enable_human_input: Optional[bool]
enable_code_interpreter: Optional[bool]
customer_encryption_key_arn: Optional[str]
idle_session_ttl_in_seconds: Optional[int]
guardrail_configuration: Optional[GuardrailConfiguration]
knowledge_bases: Optional[KnowledgebaseConfiguration]
prompt_override_configuration: Optional[Dict]
inline_session_state: Optional[Dict]

class BedrockInlineAgentsRunnable(RunnableSerializable[Dict, OutputType]):
"""
Invoke a Bedrock Inline Agent
"""

client: Any
"""Boto3 client"""
region_name: Optional[str] = None
"""Region"""
credentials_profile_name: Optional[str] = None
"""Credentials to use to invoke the agent"""
endpoint_url: Optional[str] = None
"""Endpoint URL"""
inline_agent_config: Optional[InlineAgentConfiguration] = None

@model_validator(mode="before")
@classmethod
def validate_client(cls, values: dict) -> Any:
if values.get("client") is not None:
return values

try:
client_params, session = get_boto_session(
credentials_profile_name=values["credentials_profile_name"],
region_name=values["region_name"],
endpoint_url=values["endpoint_url"],
)

values["client"] = session.client("bedrock-agent-runtime", **client_params)
return values
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except UnknownServiceError as e:
raise ModuleNotFoundError(
"Ensure that you have installed the latest boto3 package "
"that contains the API for `bedrock-runtime-agent`."
) from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e

@classmethod
def create(
cls,
*,
credentials_profile_name: Optional[str] = None,
region_name: Optional[str] = None,
endpoint_url: Optional[str] = None,
inline_agent_config: Optional[InlineAgentConfiguration] = None,
**kwargs: Any,
) -> BedrockInlineAgentsRunnable:
"""
Creates a Bedrock Inline Agent Runnable that can be used with an AgentExecutor
or with LangGraph.
Args:
credentials_profile_name: The profile name to use if different from default
region_name: Region for the Bedrock agent
endpoint_url: Endpoint URL for bedrock agent runtime
enable_trace: Boolean flag to specify whether trace should be enabled when
invoking the agent
**kwargs: Additional arguments
Returns:
BedrockInlineAgentsRunnable configured to invoke the Bedrock inline agent
"""
try:
client_params, session = get_boto_session(
credentials_profile_name=credentials_profile_name,
region_name=region_name,
endpoint_url=endpoint_url,
)
client = session.client("bedrock-agent-runtime", **client_params)

return cls(
client=client,
region_name=region_name,
credentials_profile_name=credentials_profile_name,
endpoint_url=endpoint_url,
inline_agent_config=inline_agent_config,
**kwargs,
)
except Exception as e:
raise ValueError(
f"Error creating BedrockInlineAgentsRunnable: {str(e)}"
) from e

# Check: can the input be of TypedDict:
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
"""
Invoke the Bedrock Inline agent.
Args:
input: The input dictionary containing:
input_text: The input text to the agent
session_id: The session id to use. If not provided, a new session will
be started
end_session: Boolean indicating whether to end a session or not
inline_agent_config: Optional configuration to override default config
config: Optional RunnableConfig
**kwargs: Additional arguments
Returns:
Union[List[BedrockAgentAction], BedrockAgentFinish]
"""
config = ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)

try:

# Merge configurations, giving priority to invoke-time config
runtime_config = input.get("inline_agent_config", {})
self.inline_agent_config = self.inline_agent_config or {} # Ensure it's never None to unpack
effective_config = {**self.inline_agent_config, **runtime_config}

# Convert tools to action groups format
action_groups = self._get_action_groups(
tools=effective_config.get("tools", []),
enableHumanInput=effective_config.get("enable_human_input", False),
enableCodeInterpreter=effective_config.get("enable_code_interpreter", False),
)

# Prepare the invoke_inline_agent request
agent_input: Dict[str, Any] = {
"foundationModel": effective_config["foundation_model"],
"instruction": effective_config["instruction"],
"actionGroups": action_groups,
"enableTrace": effective_config["enable_trace"],
"endSession": bool(input.get("end_session", False)),
}

# Add optional configurations
optional_params_name_map = {
"customerEncryptionKeyArn": "customer_encryption_key_arn",
"idleSessionTTLInSeconds": "idle_session_ttl_in_seconds",
"guardrailConfiguration": "guardrail_configuration",
"knowledgeBases": "knowledge_bases",
"promptOverrideConfiguration": "prompt_override_configuration",
"inlineSessionState": "inline_session_state",
}

for param_name, input_key in optional_params_name_map.items():
if effective_config.get(input_key):
agent_input[param_name] = effective_config[input_key]

session_id = None
if input.get("intermediate_steps"):
session_id, session_state = self._parse_intermediate_steps(
input.get("intermediate_steps")
)
if session_state:
agent_input["inlineSessionState"] = session_state
else:
agent_input["inputText"] = input.get("input_text", "")

# Use existing session_id from input, or from intermediate steps, or generate new one
session_id = input.get("session_id") or session_id or str(uuid.uuid4())

# Make the InvokeInlineAgent request to bedrock
output = self.client.invoke_inline_agent(
sessionId=session_id, **agent_input
)

except Exception as e:
run_manager.on_chain_error(e)
raise e

try:
response = parse_agent_response(output)
except Exception as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(response)
return response

def _get_action_groups(self, tools: List[BaseTool], enableHumanInput: bool, enableCodeInterpreter: bool) -> List:
action_groups = []
tools_by_action_group = defaultdict(list)

for tool in tools:
action_group_name, _ = _get_action_group_and_function_names(tool)
tools_by_action_group[action_group_name].append(tool)

for action_group_name, functions in tools_by_action_group.items():
action_groups.append(
{
"actionGroupName": action_group_name,
"actionGroupExecutor": {"customControl": "RETURN_CONTROL"},
"functionSchema": {
"functions": [
_tool_to_function(function) for function in functions
]
},
}
)

if enableHumanInput:
action_groups.append(
{
"actionGroupName": "UserInputAction",
"parentActionGroupSignature": "AMAZON.UserInput",
}
)

if enableCodeInterpreter:
action_groups.append(
{
"actionGroupName": "CodeInterpreterAction",
"parentActionGroupSignature": "AMAZON.CodeInterpreter",
}
)
return action_groups

# ToDo: move to common.
def _parse_intermediate_steps(
self, intermediate_steps: List[Tuple[BedrockAgentAction, str]]
) -> Tuple[Union[str, None], Union[Dict[str, Any], None]]:
"""Parse intermediate steps for inline agent invocation"""
last_step = max(0, len(intermediate_steps) - 1)
action = intermediate_steps[last_step][0]
tool_invoked = action.tool
messages = action.messages
session_id = action.session_id

if tool_invoked:
action_group_name = _DEFAULT_ACTION_GROUP_NAME
function_name = tool_invoked
tool_name_split = tool_invoked.split("::")
if len(tool_name_split) > 1:
action_group_name = tool_name_split[0]
function_name = tool_name_split[1]

if messages:
last_message = max(0, len(messages) - 1)
message = messages[last_message]
if type(message) is AIMessage:
response = intermediate_steps[last_step][1]
session_state = {
"invocationId": json.loads(message.content)
.get("returnControl", {})
.get("invocationId", ""),
"returnControlInvocationResults": [
{
"functionResult": {
"actionGroup": action_group_name,
"function": function_name,
"responseBody": {"TEXT": {"body": response}},
}
}
],
}

return session_id, session_state

return None, None
78 changes: 78 additions & 0 deletions libs/aws/tests/integration_tests/agents/test_bedrock_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
BedrockAgentAction,
BedrockAgentFinish,
BedrockAgentsRunnable,
BedrockInlineAgentsRunnable,
)


Expand Down Expand Up @@ -198,6 +199,83 @@ def get_mortgage_rate(asset_holder_id: str, asset_value: str) -> str:
if agent:
_delete_agent(agent.agent_id)

# @pytest.mark.skip
def test_inline_agent_with_executor():
@tool
def get_weather(location: str = "") -> str:
"""
Get the weather of a location
Args:
location: location of the place
"""
if location.lower() == "seattle":
return f"It is raining in {location}"
return f"It is hot and humid in {location}"

foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
tools = [get_weather]
try:
runnable = BedrockInlineAgentsRunnable.create(
region_name="us-west-2",
inline_agent_config={
"foundation_model": foundation_model,
"instruction": "You are an agent who helps with getting weather for a given location",
"tools": tools,
"enable_trace": True,
}
)

agent_executor = AgentExecutor(agent=runnable, tools=tools)
output = agent_executor.invoke({
"input_text": "what is the weather in Seattle?"
})

assert output["output"] == "It is raining in Seattle"

except Exception as ex:
raise ex

# @pytest.mark.skip
def test_inline_agent():
@tool
def get_weather(location: str = "") -> str:
"""
Get the weather of a location
Args:
location: location of the place
"""
if location.lower() == "seattle":
return f"It is raining in {location}"
return f"It is hot and humid in {location}"

foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
tools = [get_weather]
try:
runnable = BedrockInlineAgentsRunnable.create(
region_name="us-west-2"
)

output = runnable.invoke(
input={
"input_text": "what is the weather in Seattle?",
"inline_agent_config":{
"foundation_model": foundation_model,
"instruction": "You are an agent who helps with getting weather for a given location",
"tools": tools,
"enable_trace": True,
}
}
)

# Check if the agent called for tool invocation
assert isinstance(output, list)
assert output[0].tool == "get_weather"
assert "Seattle" in output[0].tool_input["location"]

except Exception as ex:
raise ex

@pytest.mark.skip
def test_weather_agent():
Expand Down
Loading

0 comments on commit 5f9ad6c

Please sign in to comment.