diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 342c1e6d..30b97d91 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -133,6 +133,9 @@ def get_completion_stream(self, message: str, event_handler: type(AgencyEventHan if self.async_mode: raise Exception("Streaming is not supported in async mode.") + if not inspect.isclass(event_handler): + raise Exception("Event handler must not be an instance.") + gen = self.main_thread.get_completion_stream(message=message, event_handler=event_handler, message_files=message_files, recipient_agent=recipient_agent) diff --git a/tests/demos/streaming_demo.py b/tests/demos/streaming_demo.py index 2a190f79..d06c96fe 100644 --- a/tests/demos/streaming_demo.py +++ b/tests/demos/streaming_demo.py @@ -25,7 +25,7 @@ def run(self): ]) def test_demo(self): - self.agency.run_demo() + self.agency.demo_gradio() if __name__ == '__main__': diff --git a/tests/test_agency.py b/tests/test_agency.py index 80744058..60cb4de2 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -6,15 +6,20 @@ import time import unittest +from openai.types.beta.threads.runs import ToolCall + from agency_swarm.tools import CodeInterpreter, Retrieval sys.path.insert(0, '../agency-swarm') from agency_swarm.util import create_agent_template -from agency_swarm import set_openai_key, Agent, Agency +from agency_swarm import set_openai_key, Agent, Agency, AgencyEventHandler +from typing_extensions import override +from agency_swarm.tools import BaseTool class AgencyTest(unittest.TestCase): + TestTool = None agency = None agent2 = None agent1 = None @@ -94,11 +99,29 @@ def save_thread_callback(agents_and_thread_ids): shutil.copyfile("./data/schemas/" + file, "./test_agents/TestAgent2/schemas/" + file) cls.num_schemas += 1 + class TestTool(BaseTool): + """ + A simple test tool that returns "Test Successful" to demonstrate the functionality of a custom tool within the Agency Swarm framework. + """ + + # This tool does not require any input fields, but you can define them similarly for other tools. + + def run(self): + """ + Executes the test tool's main functionality. In this case, it simply returns a success message. + """ + self.shared_state.set("test_tool_used", True) + + return "Test Successful" + + cls.TestTool = TestTool + from test_agents import CEO, TestAgent1, TestAgent2 cls.ceo = CEO() cls.agent1 = TestAgent1() cls.agent1.add_tool(Retrieval) cls.agent2 = TestAgent2() + cls.agent2.add_tool(cls.TestTool) def test_1_init_agency(self): """it should initialize agency with agents""" @@ -162,7 +185,44 @@ def test_4_agent_communication(self): for agent in self.__class__.agency.agents: self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) - def test_5_load_from_db(self): + def test_5_agent_communication_stream(self): + """it should communicate between agents using streaming""" + print("TestAgent1 tools", self.__class__.agent1.tools) + + test_tool_used = False + test_agent2_used = False + + class EventHandler(AgencyEventHandler): + @override + def on_text_created(self, text) -> None: + # get the name of the agent that is sending the message + if self.recipient_agent_name == "TestAgent2": + nonlocal test_agent2_used + test_agent2_used = True + + def on_tool_call_done(self, tool_call: ToolCall) -> None: + if tool_call.function.name == "TestTool": + nonlocal test_tool_used + test_tool_used = True + + message = self.__class__.agency.get_completion_stream("Please tell TestAgent1 to tell TestAgent 2 to use test tool.", + event_handler=EventHandler) + + self.assertFalse('error' in message.lower()) + + self.assertTrue(test_tool_used) + self.assertTrue(test_agent2_used) + + self.assertTrue(self.__class__.TestTool.shared_state.get("test_tool_used")) + + for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + for other_agent_name, thread in threads.items(): + self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + + for agent in self.__class__.agency.agents: + self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) + + def test_6_load_from_db(self): """it should load agents from db""" # os.rename("settings.json", "settings2.json") @@ -173,6 +233,8 @@ def test_5_load_from_db(self): agent1 = TestAgent1() agent1.add_tool(Retrieval) agent2 = TestAgent2() + agent2.add_tool(self.__class__.TestTool) + ceo = CEO() # check that agents are loaded @@ -205,7 +267,7 @@ def test_5_load_from_db(self): self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) self.assertTrue(agent.id in [settings['id'] for settings in previous_loaded_agents_settings]) - def test_6_init_async_agency(self): + def test_7_init_async_agency(self): """it should initialize agency with agents""" # reset loaded thread ids self.__class__.loaded_thread_ids = {} @@ -222,7 +284,7 @@ def test_6_init_async_agency(self): self.check_all_agents_settings(True) - def test_7_async_agent_communication(self): + def test_8_async_agent_communication(self): """it should communicate between agents asynchronously""" print("TestAgent1 tools", self.__class__.agent1.tools) self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.", @@ -278,7 +340,7 @@ def check_agent_settings(self, agent, async_mode=False): self.assertTrue(assistant.tools[3].type == "function") self.assertTrue(assistant.tools[3].function.name == "GetResponse") elif agent.name == "TestAgent2": - self.assertTrue(len(assistant.tools) == self.__class__.num_schemas) + self.assertTrue(len(assistant.tools) == self.__class__.num_schemas + 1) for tool in assistant.tools: self.assertTrue(tool.type == "function") self.assertTrue(tool.function.name in [tool.__name__ for tool in agent.tools]) @@ -287,7 +349,7 @@ def check_agent_settings(self, agent, async_mode=False): self.assertTrue(len(assistant.file_ids) == 0) self.assertTrue(len(assistant.tools) == num_tools) else: - raise Exception("Unknown agent name") + pass except Exception as e: print("Error checking agent settings ", agent.name) raise e