From 227619102e51d9e750e5b7c6b37937a1c5c8d12a Mon Sep 17 00:00:00 2001 From: "P. Taylor Goetz" Date: Wed, 10 Apr 2024 15:21:41 -0400 Subject: [PATCH 1/2] Consolidate LLMType into AgentType --- backend/app/agent.py | 61 ++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/backend/app/agent.py b/backend/app/agent.py index a633c071..10521372 100644 --- a/backend/app/agent.py +++ b/backend/app/agent.py @@ -64,9 +64,9 @@ class AgentType(str, Enum): CLAUDE2 = "Claude 2" BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)" GEMINI = "GEMINI" + MIXTRAL = "Mixtral" OLLAMA = "Ollama" - DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." CHECKPOINTER = PostgresCheckpoint(at=CheckpointAt.END_OF_STEP) @@ -175,36 +175,25 @@ def __init__( ) -class LLMType(str, Enum): - GPT_35_TURBO = "GPT 3.5 Turbo" - GPT_4 = "GPT 4 Turbo" - AZURE_OPENAI = "GPT 4 (Azure OpenAI)" - CLAUDE2 = "Claude 2" - BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)" - GEMINI = "GEMINI" - MIXTRAL = "Mixtral" - OLLAMA = "Ollama" - - def get_chatbot( - llm_type: LLMType, + llm_type: AgentType, system_message: str, ): - if llm_type == LLMType.GPT_35_TURBO: + if llm_type == AgentType.GPT_35_TURBO: llm = get_openai_llm() - elif llm_type == LLMType.GPT_4: + elif llm_type == AgentType.GPT_4: llm = get_openai_llm(gpt_4=True) - elif llm_type == LLMType.AZURE_OPENAI: + elif llm_type == AgentType.AZURE_OPENAI: llm = get_openai_llm(azure=True) - elif llm_type == LLMType.CLAUDE2: + elif llm_type == AgentType.CLAUDE2: llm = get_anthropic_llm() - elif llm_type == LLMType.BEDROCK_CLAUDE2: + elif llm_type == AgentType.BEDROCK_CLAUDE2: llm = get_anthropic_llm(bedrock=True) - elif llm_type == LLMType.GEMINI: + elif llm_type == AgentType.GEMINI: llm = get_google_llm() - elif llm_type == LLMType.MIXTRAL: + elif llm_type == AgentType.MIXTRAL: llm = get_mixtral_fireworks() - elif llm_type == LLMType.OLLAMA: + elif llm_type == AgentType.OLLAMA: llm = get_ollama_llm() else: raise ValueError("Unexpected llm type") @@ -212,14 +201,14 @@ def get_chatbot( class ConfigurableChatBot(RunnableBinding): - llm: LLMType + llm: AgentType system_message: str = DEFAULT_SYSTEM_MESSAGE user_id: Optional[str] = None def __init__( self, *, - llm: LLMType = LLMType.GPT_35_TURBO, + llm: AgentType = AgentType.GPT_35_TURBO, system_message: str = DEFAULT_SYSTEM_MESSAGE, kwargs: Optional[Mapping[str, Any]] = None, config: Optional[Mapping[str, Any]] = None, @@ -238,7 +227,7 @@ def __init__( chatbot = ( - ConfigurableChatBot(llm=LLMType.GPT_35_TURBO, checkpoint=CHECKPOINTER) + ConfigurableChatBot(llm=AgentType.GPT_35_TURBO, checkpoint=CHECKPOINTER) .configurable_fields( llm=ConfigurableField(id="llm_type", name="LLM Type"), system_message=ConfigurableField(id="system_message", name="Instructions"), @@ -248,7 +237,7 @@ def __init__( class ConfigurableRetrieval(RunnableBinding): - llm_type: LLMType + llm_type: AgentType system_message: str = DEFAULT_SYSTEM_MESSAGE assistant_id: Optional[str] = None thread_id: Optional[str] = None @@ -257,7 +246,7 @@ class ConfigurableRetrieval(RunnableBinding): def __init__( self, *, - llm_type: LLMType = LLMType.GPT_35_TURBO, + llm_type: AgentType = AgentType.GPT_35_TURBO, system_message: str = DEFAULT_SYSTEM_MESSAGE, assistant_id: Optional[str] = None, thread_id: Optional[str] = None, @@ -267,21 +256,21 @@ def __init__( ) -> None: others.pop("bound", None) retriever = get_retriever(assistant_id, thread_id) - if llm_type == LLMType.GPT_35_TURBO: + if llm_type == AgentType.GPT_35_TURBO: llm = get_openai_llm() - elif llm_type == LLMType.GPT_4: + elif llm_type == AgentType.GPT_4: llm = get_openai_llm(gpt_4=True) - elif llm_type == LLMType.AZURE_OPENAI: + elif llm_type == AgentType.AZURE_OPENAI: llm = get_openai_llm(azure=True) - elif llm_type == LLMType.CLAUDE2: + elif llm_type == AgentType.CLAUDE2: llm = get_anthropic_llm() - elif llm_type == LLMType.BEDROCK_CLAUDE2: + elif llm_type == AgentType.BEDROCK_CLAUDE2: llm = get_anthropic_llm(bedrock=True) - elif llm_type == LLMType.GEMINI: + elif llm_type == AgentType.GEMINI: llm = get_google_llm() - elif llm_type == LLMType.MIXTRAL: + elif llm_type == AgentType.MIXTRAL: llm = get_mixtral_fireworks() - elif llm_type == LLMType.OLLAMA: + elif llm_type == AgentType.OLLAMA: llm = get_ollama_llm() else: raise ValueError("Unexpected llm type") @@ -296,7 +285,7 @@ def __init__( chat_retrieval = ( - ConfigurableRetrieval(llm_type=LLMType.GPT_35_TURBO, checkpoint=CHECKPOINTER) + ConfigurableRetrieval(llm_type=AgentType.GPT_35_TURBO, checkpoint=CHECKPOINTER) .configurable_fields( llm_type=ConfigurableField(id="llm_type", name="LLM Type"), system_message=ConfigurableField(id="system_message", name="Instructions"), @@ -319,7 +308,7 @@ def __init__( thread_id=None, ) .configurable_fields( - agent=ConfigurableField(id="agent_type", name="Agent Type"), + agent=ConfigurableField(id="agent_type", name="LLM Type"), system_message=ConfigurableField(id="system_message", name="Instructions"), interrupt_before_action=ConfigurableField( id="interrupt_before_action", From dd1245a3fcfc6e5f439a49e52695e07254589156 Mon Sep 17 00:00:00 2001 From: "P. Taylor Goetz" Date: Wed, 10 Apr 2024 15:29:08 -0400 Subject: [PATCH 2/2] linting --- backend/app/agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/app/agent.py b/backend/app/agent.py index 10521372..c5acffec 100644 --- a/backend/app/agent.py +++ b/backend/app/agent.py @@ -67,6 +67,7 @@ class AgentType(str, Enum): MIXTRAL = "Mixtral" OLLAMA = "Ollama" + DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." CHECKPOINTER = PostgresCheckpoint(at=CheckpointAt.END_OF_STEP)