diff --git a/backend/app/agent.py b/backend/app/agent.py
index a633c071..c5acffec 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -64,6 +64,7 @@ class AgentType(str, Enum):
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
     GEMINI = "GEMINI"
+    MIXTRAL = "Mixtral"
     OLLAMA = "Ollama"


@@ -175,36 +176,25 @@ def __init__(
         )


-class LLMType(str, Enum):
-    GPT_35_TURBO = "GPT 3.5 Turbo"
-    GPT_4 = "GPT 4 Turbo"
-    AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
-    CLAUDE2 = "Claude 2"
-    BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
-    GEMINI = "GEMINI"
-    MIXTRAL = "Mixtral"
-    OLLAMA = "Ollama"
-
-
 def get_chatbot(
-    llm_type: LLMType,
+    llm_type: AgentType,
     system_message: str,
 ):
-    if llm_type == LLMType.GPT_35_TURBO:
+    if llm_type == AgentType.GPT_35_TURBO:
         llm = get_openai_llm()
-    elif llm_type == LLMType.GPT_4:
+    elif llm_type == AgentType.GPT_4:
         llm = get_openai_llm(gpt_4=True)
-    elif llm_type == LLMType.AZURE_OPENAI:
+    elif llm_type == AgentType.AZURE_OPENAI:
         llm = get_openai_llm(azure=True)
-    elif llm_type == LLMType.CLAUDE2:
+    elif llm_type == AgentType.CLAUDE2:
         llm = get_anthropic_llm()
-    elif llm_type == LLMType.BEDROCK_CLAUDE2:
+    elif llm_type == AgentType.BEDROCK_CLAUDE2:
         llm = get_anthropic_llm(bedrock=True)
-    elif llm_type == LLMType.GEMINI:
+    elif llm_type == AgentType.GEMINI:
         llm = get_google_llm()
-    elif llm_type == LLMType.MIXTRAL:
+    elif llm_type == AgentType.MIXTRAL:
         llm = get_mixtral_fireworks()
-    elif llm_type == LLMType.OLLAMA:
+    elif llm_type == AgentType.OLLAMA:
         llm = get_ollama_llm()
     else:
         raise ValueError("Unexpected llm type")
@@ -212,14 +202,14 @@ def get_chatbot(


 class ConfigurableChatBot(RunnableBinding):
-    llm: LLMType
+    llm: AgentType
     system_message: str = DEFAULT_SYSTEM_MESSAGE
     user_id: Optional[str] = None

     def __init__(
         self,
         *,
-        llm: LLMType = LLMType.GPT_35_TURBO,
+        llm: AgentType = AgentType.GPT_35_TURBO,
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
         kwargs: Optional[Mapping[str, Any]] = None,
         config: Optional[Mapping[str, Any]] = None,
@@ -238,7 +228,7 @@ def __init__(


 chatbot = (
-    ConfigurableChatBot(llm=LLMType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
+    ConfigurableChatBot(llm=AgentType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
     .configurable_fields(
         llm=ConfigurableField(id="llm_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
@@ -248,7 +238,7 @@ def __init__(


 class ConfigurableRetrieval(RunnableBinding):
-    llm_type: LLMType
+    llm_type: AgentType
     system_message: str = DEFAULT_SYSTEM_MESSAGE
     assistant_id: Optional[str] = None
     thread_id: Optional[str] = None
@@ -257,7 +247,7 @@ class ConfigurableRetrieval(RunnableBinding):
     def __init__(
         self,
         *,
-        llm_type: LLMType = LLMType.GPT_35_TURBO,
+        llm_type: AgentType = AgentType.GPT_35_TURBO,
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
         assistant_id: Optional[str] = None,
         thread_id: Optional[str] = None,
@@ -267,21 +257,21 @@ def __init__(
     ) -> None:
         others.pop("bound", None)
         retriever = get_retriever(assistant_id, thread_id)
-        if llm_type == LLMType.GPT_35_TURBO:
+        if llm_type == AgentType.GPT_35_TURBO:
             llm = get_openai_llm()
-        elif llm_type == LLMType.GPT_4:
+        elif llm_type == AgentType.GPT_4:
             llm = get_openai_llm(gpt_4=True)
-        elif llm_type == LLMType.AZURE_OPENAI:
+        elif llm_type == AgentType.AZURE_OPENAI:
             llm = get_openai_llm(azure=True)
-        elif llm_type == LLMType.CLAUDE2:
+        elif llm_type == AgentType.CLAUDE2:
             llm = get_anthropic_llm()
-        elif llm_type == LLMType.BEDROCK_CLAUDE2:
+        elif llm_type == AgentType.BEDROCK_CLAUDE2:
             llm = get_anthropic_llm(bedrock=True)
-        elif llm_type == LLMType.GEMINI:
+        elif llm_type == AgentType.GEMINI:
             llm = get_google_llm()
-        elif llm_type == LLMType.MIXTRAL:
+        elif llm_type == AgentType.MIXTRAL:
             llm = get_mixtral_fireworks()
-        elif llm_type == LLMType.OLLAMA:
+        elif llm_type == AgentType.OLLAMA:
             llm = get_ollama_llm()
         else:
             raise ValueError("Unexpected llm type")
@@ -296,7 +286,7 @@ def __init__(


 chat_retrieval = (
-    ConfigurableRetrieval(llm_type=LLMType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
+    ConfigurableRetrieval(llm_type=AgentType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
     .configurable_fields(
         llm_type=ConfigurableField(id="llm_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
@@ -319,7 +309,7 @@ def __init__(
         thread_id=None,
     )
     .configurable_fields(
-        agent=ConfigurableField(id="agent_type", name="Agent Type"),
+        agent=ConfigurableField(id="agent_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
         interrupt_before_action=ConfigurableField(
             id="interrupt_before_action",