forked from VRSEN/agency-swarm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_agency.py
372 lines (291 loc) · 15.3 KB
/
test_agency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import inspect
import json
import os
import shutil
import sys
import time
import unittest
from openai.types.beta.threads.runs import ToolCall
from agency_swarm.tools import CodeInterpreter, Retrieval
sys.path.insert(0, '../agency-swarm')
from agency_swarm.util import create_agent_template
from agency_swarm import set_openai_key, Agent, Agency, AgencyEventHandler
from typing_extensions import override
from agency_swarm.tools import BaseTool
class AgencyTest(unittest.TestCase):
    """End-to-end tests for ``Agency`` built from generated agent templates.

    Test methods are numbered because later tests reuse class-level state
    (``agency``, ``loaded_thread_ids``, ...) produced by earlier ones;
    unittest runs test methods in alphabetical order.

    NOTE(review): these tests create and delete real assistants
    (``init_oai`` / ``delete``), so they presumably need a configured OpenAI
    API key and network access. They also use paths relative to the current
    working directory (``./data/...``, ``./test_agents``).
    """

    TestTool = None
    agency = None
    agent2 = None
    agent1 = None
    ceo = None
    num_schemas = None
    num_files = None

    # testing loading agents from db
    loaded_thread_ids = None
    loaded_agents_settings = None
    settings_callbacks = None
    threads_callbacks = None

    @classmethod
    def setUpClass(cls):
        """Generate the test agents on disk and instantiate them.

        Creates ``./test_agents`` (recreated from scratch if it already exists)
        with CEO, TestAgent1 and TestAgent2 templates. It copies
        ``./data/files`` into TestAgent1's files and ``./data/schemas`` into
        TestAgent2's schemas, and records how many were copied. It also
        installs in-memory "load"/"save" callbacks that capture the agents'
        settings and thread ids on the class.
        """
        cls.num_files = 0
        cls.num_schemas = 0
        cls.ceo = None
        cls.agent1 = None
        cls.agent2 = None
        cls.agency = None

        # testing loading agents from db: the callbacks below stand in for a
        # database by storing whatever the agency saves on the class itself.
        cls.loaded_thread_ids = {}
        cls.loaded_agents_settings = []

        def save_settings_callback(settings):
            cls.loaded_agents_settings = settings

        cls.settings_callbacks = {
            "load": lambda: cls.loaded_agents_settings,
            "save": save_settings_callback,
        }

        def save_thread_callback(agents_and_thread_ids):
            cls.loaded_thread_ids = agents_and_thread_ids

        cls.threads_callbacks = {
            "load": lambda: cls.loaded_thread_ids,
            "save": save_thread_callback,
        }

        # Always start from an empty ./test_agents directory.
        if os.path.exists("./test_agents"):
            shutil.rmtree("./test_agents")
        os.mkdir("./test_agents")

        # create init file so ./test_agents is importable as a package
        with open("./test_agents/__init__.py", "w") as f:
            f.write("")

        # create agent templates in test_agents
        create_agent_template("CEO", "CEO Test Agent", path="./test_agents",
                              instructions="Your task is to tell TestAgent1 to say test to another test agent. If the "
                                           "agent, does not respond or something goes wrong please say 'error' and "
                                           "nothing else. Otherwise say 'success' and nothing else.")
        create_agent_template("TestAgent1", "Test Agent 1", path="./test_agents",
                              instructions="Your task is to say test to another test agent using SendMessage tool. "
                                           "If the agent, does not "
                                           "respond or something goes wrong please say 'error' and nothing else. "
                                           "Otherwise say 'success' and nothing else.", code_interpreter=True)
        create_agent_template("TestAgent2", "Test Agent 2", path="./test_agents",
                              instructions="Please respond to the user that test was a success.")

        sys.path.insert(0, './test_agents')

        # copy files from data/files to test_agents/TestAgent1/files
        for file_name in os.listdir("./data/files"):
            shutil.copyfile(os.path.join("./data/files", file_name),
                            os.path.join("./test_agents/TestAgent1/files", file_name))
            cls.num_files += 1

        # copy schemas from data/schemas to test_agents/TestAgent2/schemas
        for file_name in os.listdir("./data/schemas"):
            shutil.copyfile(os.path.join("./data/schemas", file_name),
                            os.path.join("./test_agents/TestAgent2/schemas", file_name))
            cls.num_schemas += 1

        class TestTool(BaseTool):
            """
            A simple test tool that returns "Test Successful" to demonstrate the functionality of a custom tool within the Agency Swarm framework.
            """

            # This tool does not require any input fields, but you can define them similarly for other tools.

            def run(self):
                """
                Executes the test tool's main functionality. In this case, it simply returns a success message.
                """
                self.shared_state.set("test_tool_used", True)
                return "Test Successful"

        cls.TestTool = TestTool

        # Imported only now because the package is generated above.
        from test_agents import CEO, TestAgent1, TestAgent2
        cls.ceo = CEO()

        cls.agent1 = TestAgent1()
        cls.agent1.add_tool(Retrieval)

        cls.agent2 = TestAgent2()
        cls.agent2.add_tool(cls.TestTool)

    def test_1_init_agency(self):
        """it should initialize agency with agents"""
        self.__class__.agency = Agency([
            self.__class__.ceo,
            [self.__class__.ceo, self.__class__.agent1],
            [self.__class__.agent1, self.__class__.agent2]],
            shared_instructions="This is a shared instruction",
            settings_callbacks=self.__class__.settings_callbacks,
            threads_callbacks=self.__class__.threads_callbacks,
        )

        self.check_all_agents_settings()

    def test_2_load_agent(self):
        """it should load existing assistant from settings"""
        from test_agents import TestAgent1
        agent3 = TestAgent1()
        agent3.add_shared_instructions(self.__class__.agency.shared_instructions)
        agent3.tools = self.__class__.agent1.tools
        agent3 = agent3.init_oai()

        print("agent3", agent3.assistant.model_dump())
        print("agent1", self.__class__.agent1.assistant.model_dump())

        self.assertEqual(self.__class__.agent1.id, agent3.id)

        # check that assistant settings match
        self.assertTrue(agent3._check_parameters(self.__class__.agent1.assistant.model_dump()))

        self.check_agent_settings(agent3)

    def test_3_load_agent_id(self):
        """it should load existing assistant from id"""
        agent3 = Agent(id=self.__class__.agent1.id)
        agent3.tools = self.__class__.agent1.tools
        agent3 = agent3.init_oai()

        print("agent3", agent3.assistant.model_dump())
        print("agent1", self.__class__.agent1.assistant.model_dump())

        self.assertEqual(self.__class__.agent1.id, agent3.id)

        # check that assistant settings match
        self.assertTrue(agent3._check_parameters(self.__class__.agent1.assistant.model_dump()))

        self.check_agent_settings(agent3)

    def test_4_agent_communication(self):
        """it should communicate between agents"""
        print("TestAgent1 tools", self.__class__.agent1.tools)
        message = self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.", yield_messages=False)

        self.assertNotIn('error', message.lower())

        self.check_persisted_state(self.__class__.agency)

    def test_5_agent_communication_stream(self):
        """it should communicate between agents using streaming"""
        print("TestAgent1 tools", self.__class__.agent1.tools)

        # Flags flipped by the event handler below while the stream runs.
        test_tool_used = False
        test_agent2_used = False

        class EventHandler(AgencyEventHandler):
            @override
            def on_text_created(self, text) -> None:
                nonlocal test_agent2_used
                # get the name of the agent that is sending the message
                if self.recipient_agent_name == "TestAgent2":
                    test_agent2_used = True

            @override
            def on_tool_call_done(self, tool_call: ToolCall) -> None:
                nonlocal test_tool_used
                if tool_call.function.name == "TestTool":
                    test_tool_used = True

        # The handler class (not an instance) is passed, as in the original call.
        self.__class__.agency.get_completion_stream("Please tell TestAgent1 to tell TestAgent 2 to use test tool.",
                                                    event_handler=EventHandler)

        self.assertTrue(test_tool_used)
        self.assertTrue(test_agent2_used)

        # TestTool.run records its use in shared state.
        self.assertTrue(self.__class__.TestTool.shared_state.get("test_tool_used"))

        self.check_persisted_state(self.__class__.agency)

    def test_6_load_from_db(self):
        """it should load agents from db"""
        # Snapshot what the callbacks held before building a second agency.
        previous_loaded_thread_ids = self.__class__.loaded_thread_ids
        previous_loaded_agents_settings = self.__class__.loaded_agents_settings

        from test_agents import CEO, TestAgent1, TestAgent2
        agent1 = TestAgent1()
        agent1.add_tool(Retrieval)

        agent2 = TestAgent2()
        agent2.add_tool(self.__class__.TestTool)

        ceo = CEO()

        # check that agents are loaded. The second agency is pointed at
        # ./settings2.json; afterwards that file replaces settings.json.
        agency = Agency([
            ceo,
            [ceo, agent1],
            [agent1, agent2]],
            shared_instructions="This is a shared instruction",
            settings_path="./settings2.json",
            settings_callbacks=self.__class__.settings_callbacks,
            threads_callbacks=self.__class__.threads_callbacks,
        )

        # check that settings are the same
        self.assertEqual(len(agency.agents), len(self.__class__.agency.agents))

        os.remove("settings.json")
        os.rename("settings2.json", "settings.json")

        self.check_all_agents_settings()

        # check that threads and agents are persisted by the callbacks...
        self.check_persisted_state(agency)

        # ...and match what was stored before this agency was created
        for agent_name, threads in agency.agents_and_threads.items():
            for other_agent_name, thread in threads.items():
                self.assertIn(thread.id, previous_loaded_thread_ids[agent_name][other_agent_name])

        previous_agent_ids = [settings['id'] for settings in previous_loaded_agents_settings]
        for agent in agency.agents:
            self.assertIn(agent.id, previous_agent_ids)

    def test_7_init_async_agency(self):
        """it should initialize agency with agents"""
        # reset loaded thread ids
        self.__class__.loaded_thread_ids = {}

        self.__class__.agency = Agency([
            self.__class__.ceo,
            [self.__class__.ceo, self.__class__.agent1],
            [self.__class__.agent1, self.__class__.agent2]],
            shared_instructions="This is a shared instruction",
            settings_callbacks=self.__class__.settings_callbacks,
            threads_callbacks=self.__class__.threads_callbacks,
            async_mode='threading',
        )

        self.check_all_agents_settings(async_mode=True)

    def test_8_async_agent_communication(self):
        """it should communicate between agents asynchronously"""
        print("TestAgent1 tools", self.__class__.agent1.tools)
        self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.",
                                             yield_messages=False)

        # Give the background (threading) run time to progress before polling.
        time.sleep(10)

        message = self.__class__.agency.get_completion("Please check response. If the GetResponse function output includes `TestAgent1's Response` (for example, that the message was sent to Test Agent 2, the process or the task has started, initiated, etc.), say 'success'. If the function output does not include `TestAgent1's Response`, or if you get a System Notification, or an error instead, say 'error'.",
                                                       yield_messages=False)

        self.assertNotIn('error', message.lower())

        self.check_persisted_state(self.__class__.agency)

    # --- Helper methods ---

    def get_class_folder_path(self):
        """Return the absolute directory containing this test class's source file."""
        return os.path.abspath(os.path.dirname(inspect.getfile(self.__class__)))

    def check_persisted_state(self, agency):
        """Assert that ``agency``'s threads and agents were captured by the save callbacks.

        Every thread id in ``agency.agents_and_threads`` must be found in the
        matching entry of ``loaded_thread_ids``. Every agent id must appear in
        ``loaded_agents_settings``.
        """
        loaded_thread_ids = self.__class__.loaded_thread_ids
        for agent_name, threads in agency.agents_and_threads.items():
            for other_agent_name, thread in threads.items():
                self.assertIn(thread.id, loaded_thread_ids[agent_name][other_agent_name])

        saved_agent_ids = [settings['id'] for settings in self.__class__.loaded_agents_settings]
        for agent in agency.agents:
            self.assertIn(agent.id, saved_agent_ids)

    def check_agent_settings(self, agent, async_mode=False):
        """Assert that ``agent`` was persisted and its assistant is configured as expected.

        Checks that the agent's settings file exists. Any entry in that file
        with the agent's id must pass ``agent._check_parameters``, and so must
        the assistant's own ``model_dump()``. The files and tools are then
        checked per agent name.

        :param agent: the agent to check (a test agent or a reloaded copy).
        :param async_mode: True when the agency was built with ``async_mode``;
            TestAgent1 and CEO are then expected to carry one extra tool.
        """
        try:
            settings_path = agent.get_settings_path()
            self.assertTrue(os.path.exists(settings_path), settings_path)
            with open(settings_path, 'r') as f:
                settings = json.load(f)
            for assistant_settings in settings:
                if assistant_settings['id'] == agent.id:
                    self.assertTrue(agent._check_parameters(assistant_settings))

            assistant = agent.assistant
            self.assertTrue(assistant)
            self.assertTrue(agent._check_parameters(assistant.model_dump()))

            if agent.name == "TestAgent1":
                num_tools = 4 if async_mode else 3

                self.assertEqual(len(assistant.file_ids), self.__class__.num_files)
                for file_id in assistant.file_ids:
                    self.assertIn(file_id, agent.file_ids)

                # Expected order: code interpreter, retrieval, SendMessage
                # (+ GetResponse in async mode).
                print("assistant tools", assistant.tools)
                self.assertEqual(len(assistant.tools), num_tools)
                self.assertEqual(len(agent.tools), num_tools)
                self.assertEqual(assistant.tools[0].type, "code_interpreter")
                self.assertEqual(assistant.tools[1].type, "retrieval")
                self.assertEqual(assistant.tools[2].type, "function")
                self.assertEqual(assistant.tools[2].function.name, "SendMessage")
                if async_mode:
                    self.assertEqual(assistant.tools[3].type, "function")
                    self.assertEqual(assistant.tools[3].function.name, "GetResponse")

            elif agent.name == "TestAgent2":
                # presumably one function per copied schema file plus TestTool
                self.assertEqual(len(assistant.tools), self.__class__.num_schemas + 1)
                agent_tool_names = [agent_tool.__name__ for agent_tool in agent.tools]
                for tool in assistant.tools:
                    self.assertEqual(tool.type, "function")
                    self.assertIn(tool.function.name, agent_tool_names)

            elif agent.name == "CEO":
                num_tools = 2 if async_mode else 1
                self.assertEqual(len(assistant.file_ids), 0)
                self.assertEqual(len(assistant.tools), num_tools)
        except Exception:
            print("Error checking agent settings ", agent.name)
            raise

    def check_all_agents_settings(self, async_mode=False):
        """Run ``check_agent_settings`` for the CEO, TestAgent1 and TestAgent2 class fixtures."""
        self.check_agent_settings(self.__class__.ceo, async_mode=async_mode)
        self.check_agent_settings(self.__class__.agent1, async_mode=async_mode)
        self.check_agent_settings(self.__class__.agent2, async_mode=async_mode)

    @classmethod
    def tearDownClass(cls):
        """Remove generated files and delete the assistants created in setUpClass.

        The assistants are deleted in a ``finally`` clause, so a failure while
        cleaning up local files cannot leave them behind.
        """
        try:
            shutil.rmtree("./test_agents", ignore_errors=True)
            try:
                os.remove("./settings.json")
            except FileNotFoundError:
                # e.g. test_6 failed after removing settings.json but before
                # renaming settings2.json into its place
                pass
        finally:
            for agent in (cls.ceo, cls.agent1, cls.agent2):
                if agent is not None:
                    agent.delete()
# Allow running this file directly: `python test_agency.py`.
if __name__ == '__main__':
    unittest.main()