forked from VRSEN/agency-swarm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented assistant initialization
- Loading branch information
Showing
18 changed files
with
339 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .util.oai import set_openai_key |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
class Agency: | ||
def __init__(self): | ||
self.agents = [] | ||
self.agents_and_threads = [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .base_agent import BaseAgent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import json | ||
import os | ||
import inspect | ||
from abc import ABC, abstractmethod | ||
from typing import Dict, List, Literal, Union, Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
from agency_swarm.tools import BaseTool | ||
from agency_swarm.util.oai import get_openai_client | ||
|
||
from deepdiff import DeepDiff | ||
|
||
import sys | ||
|
||
|
||
class DefaultTools(BaseModel): | ||
type: Literal["code_interpreter", "retrieval"] | ||
|
||
|
||
class BaseAgent(ABC): | ||
id: str = None | ||
name: str = None | ||
description: str = None | ||
instructions: str = None # can be file path | ||
tools: List[Union[BaseTool, DefaultTools]] = [] | ||
files: Union[List[str], str] = [] # can be file or folder path | ||
metadata: Dict[str, str] = {} | ||
model: str = "gpt-4-1106-preview" | ||
|
||
def __init__(self): | ||
if os.path.isfile(self.instructions): | ||
self.instructions = self._read_instructions() | ||
|
||
if isinstance(self.files, str): | ||
if os.path.isdir(self.files): | ||
self.files = os.listdir(self.files) | ||
self.files = [os.path.join(self.files, file) for file in self.files] | ||
self.files = self._upload_files() | ||
elif os.path.isfile(self.files): | ||
self.files = [self.files] | ||
self.files = self._upload_files() | ||
|
||
for i, tool in enumerate(self.tools): | ||
if isinstance(tool, BaseTool): | ||
self.tools[i]['type'] = "function" | ||
self.tools[i]["function"] = tool.openai_schema() | ||
|
||
if not self.name: | ||
self.name = self.__class__.__name__ | ||
|
||
self.client = get_openai_client() | ||
self._init_assistant() | ||
|
||
def _init_assistant(self): | ||
# check if settings.json exists | ||
path = self.get_settings_path() | ||
|
||
# load assistant from id | ||
if self.id: | ||
self.assistant = self.client.beta.assistants.retrieve(self.id) | ||
# update assistant if parameters are different | ||
if not self._check_parameters(self.assistant.model_dump()): | ||
self._update_assistant() | ||
return | ||
|
||
# load assistant from settings | ||
if os.path.exists(path): | ||
with open(path, 'r') as f: | ||
settings = json.load(f) | ||
# iterate settings and find the assistant with the same name | ||
for assistant_settings in settings: | ||
if assistant_settings['name'] == self.name: | ||
self.assistant = self.client.beta.assistants.retrieve(assistant_settings['id']) | ||
self.id = assistant_settings['id'] | ||
# update assistant if parameters are different | ||
if not self._check_parameters(self.assistant.model_dump()): | ||
self._update_assistant() | ||
|
||
return | ||
# create assistant if settings.json does not exist or assistant with the same name does not exist | ||
self.assistant = self.client.beta.assistants.create( | ||
name=self.name, | ||
description=self.description, | ||
instructions=self.instructions, | ||
tools=self.tools, | ||
file_ids=self.files, | ||
metadata=self.metadata, | ||
model=self.model | ||
) | ||
|
||
self.id = self.assistant.id | ||
|
||
self._save_settings() | ||
|
||
def _update_assistant(self): | ||
params = self.get_params() | ||
params = {k: v for k, v in params.items() if v is not None} | ||
self.assistant = self.client.beta.assistants.update( | ||
self.id, | ||
**params, | ||
) | ||
self._update_settings() | ||
|
||
def _check_parameters(self, assistant_settings): | ||
if self.name != assistant_settings['name']: | ||
print("name") | ||
return False | ||
if self.description != assistant_settings['description']: | ||
print("description") | ||
return False | ||
if self.instructions != assistant_settings['instructions']: | ||
print("instructions") | ||
return False | ||
if DeepDiff(self.tools, assistant_settings['tools'], ignore_order=True) != {}: | ||
print("tools") | ||
return False | ||
if set(self.files) != set(assistant_settings['file_ids']): | ||
print("files") | ||
return False | ||
if DeepDiff(self.metadata, assistant_settings['metadata'], ignore_order=True) != {}: | ||
print("metadata") | ||
return False | ||
if self.model != assistant_settings['model']: | ||
print("model") | ||
return False | ||
return True | ||
|
||
def _save_settings(self): | ||
path = self.get_settings_path() | ||
# check if settings.json exists | ||
if not os.path.isfile(path): | ||
with open(path, 'w') as f: | ||
json.dump([self.assistant.model_dump()], f, indent=4) | ||
else: | ||
settings = [] | ||
with open(path, 'r') as f: | ||
settings = json.load(f) | ||
settings.append(self.assistant.model_dump()) | ||
with open(path, 'w') as f: | ||
json.dump(settings, f, indent=4) | ||
|
||
def _update_settings(self): | ||
path = os.path.join(self.get_class_folder_path(), 'settings.json') | ||
# check if settings.json exists | ||
if os.path.isfile(path): | ||
settings = [] | ||
with open(path, 'r') as f: | ||
settings = json.load(f) | ||
for i, assistant_settings in enumerate(settings): | ||
if assistant_settings['id'] == self.id: | ||
settings[i] = self.assistant.model_dump() | ||
break | ||
with open(path, 'w') as f: | ||
json.dump(settings, f, indent=4) | ||
|
||
def _read_instructions(self): | ||
with open(self.instructions, 'r') as f: | ||
return f.read() | ||
|
||
def _upload_files(self): | ||
file_ids = [] | ||
for file in self.files: | ||
file = self.client.files.create(file=open(file, 'rb'), purpose="assistants") | ||
file_ids.append(file.id) | ||
return file_ids | ||
|
||
def get_settings_path(self): | ||
return os.path.join(self.get_class_folder_path(), 'settings.json') | ||
|
||
def get_class_folder_path(self): | ||
return os.path.abspath(os.path.dirname(inspect.getfile(self.__class__))) | ||
|
||
def get_params(self): | ||
return { | ||
"name": self.name, | ||
"description": self.description, | ||
"instructions": self.instructions, | ||
"tools": self.tools, | ||
"file_ids": self.files, | ||
"metadata": self.metadata, | ||
"model": self.model | ||
} | ||
|
||
def delete_assistant(self): | ||
self.client.beta.assistants.delete(self.id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from typing import Dict | ||
|
||
from agency_swarm.tools import BaseTool | ||
from agency_swarm.agents import BaseAgent | ||
|
||
|
||
class Ceo(BaseAgent): | ||
pass |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .base_tool import BaseTool |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from abc import ABC, abstractmethod | ||
from instructor import OpenAISchema | ||
|
||
|
||
class BaseTool(OpenAISchema, ABC): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
@abstractmethod | ||
def run(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Literal | ||
|
||
from agency_swarm.tools import BaseTool | ||
from pydantic import Field | ||
|
||
|
||
class SendMessage(BaseTool): | ||
"""Send messages to other specialized agents in this group chat.""" | ||
recepient: Literal['code_assistant'] = Field(..., description="code_assistant is a world class programming AI capable of executing python code.") | ||
message: str = Field(..., | ||
description="Specify the task required for the recipient agent to complete. Focus instead on clarifying what the task entails, rather than providing detailed instructions.") | ||
|
||
def run(self): | ||
pass |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import openai | ||
import threading | ||
import os | ||
|
||
client_lock = threading.Lock() | ||
client = None | ||
|
||
|
||
def get_openai_client(): | ||
global client | ||
with client_lock: | ||
if client is None: | ||
# Check if the API key is set | ||
api_key = openai.api_key or os.getenv('OPENAI_API_KEY') | ||
if api_key is None: | ||
raise ValueError("OpenAI API key is not set. Please set it using set_openai_key.") | ||
client = openai.OpenAI(api_key=api_key) | ||
return client | ||
|
||
|
||
def set_openai_key(key): | ||
if not key: | ||
raise ValueError("Invalid API key. The API key cannot be empty.") | ||
openai.api_key = key |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
openai==1.3.0 | ||
instructor==0.3.4 | ||
deepdiff==6.7.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from setuptools import setup, find_packages | ||
|
||
# Read the contents of your requirements file | ||
with open('requirements.txt') as f: | ||
requirements = f.read().splitlines() | ||
|
||
setup( | ||
name='agency', # Replace with your package's name | ||
version='0.1.0', # Initial version of your package | ||
author='VRSEN', # Replace with your name | ||
author_email='[email protected]', # Replace with your email address | ||
description='This project allows anyone', # Provide a short description | ||
long_description=open('README.md').read(), # Long description read from the README.md | ||
long_description_content_type='text/markdown', # Content type of the long description | ||
url='https://github.com/yourusername/your_package_name', # Replace with the URL of your package's repository | ||
packages=find_packages(), # Automatically find all packages and subpackages | ||
install_requires=requirements, | ||
classifiers=[ | ||
# Classifiers help users find your project by categorizing it | ||
'Development Status :: 3 - Alpha', # Change as appropriate | ||
'Intended Audience :: Developers', | ||
'Topic :: Software Development :: Build Tools', | ||
'License :: OSI Approved :: MIT License', # Choose the appropriate license | ||
'Programming Language :: Python :: 3', # Specify which pyhton versions you support | ||
'Programming Language :: Python :: 3.7', | ||
'Programming Language :: Python :: 3.8', | ||
'Programming Language :: Python :: 3.9', | ||
], | ||
python_requires='>=3.7', # Specify the Python version requirements | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import unittest | ||
|
||
from agency_swarm import set_openai_key | ||
from .test_agent.test_agent import TestAgent | ||
import sys | ||
import os | ||
import json | ||
|
||
sys.path.insert(0, '../agency_swarm') | ||
|
||
|
||
class MyTestCase(unittest.TestCase): | ||
agent = None | ||
|
||
def setUp(self): | ||
set_openai_key("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") | ||
|
||
# it should create new settings file and init agent | ||
def test_init_agent(self): | ||
self.agent = TestAgent() | ||
self.assertTrue(self.agent.id) | ||
|
||
self.settings_path = self.agent.get_settings_path() | ||
self.assertTrue(os.path.exists(self.settings_path)) | ||
|
||
# find assistant in settings by id | ||
with open(self.settings_path, 'r') as f: | ||
settings = json.load(f) | ||
for assistant_settings in settings: | ||
if assistant_settings['id'] == self.agent.id: | ||
self.assertTrue(self.agent._check_parameters(assistant_settings)) | ||
|
||
# it should load assistant from settings | ||
def test_load_agent(self): | ||
self.agent = TestAgent() | ||
agent2 = TestAgent() | ||
self.assertEqual(self.agent.id, agent2.id) | ||
|
||
def tearDown(self): | ||
# delete assistant from openai | ||
self.agent.delete_assistant() | ||
|
||
os.remove(self.agent.get_settings_path()) | ||
|
||
pass | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
You are a test agent. Always respond with "test". |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from agency_swarm.agents import BaseAgent | ||
|
||
|
||
class TestAgent(BaseAgent): | ||
description = "Test Agent" | ||
instructions = "./instructions.md" |