From a07f8b55b0e8835915feba9f9e2ca588d96e9769 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 12 Mar 2024 22:50:38 -0700 Subject: [PATCH] add all models as fallbacks (#290) --- backend/chain.py | 59 +++++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/backend/chain.py b/backend/chain.py index 982b923b7..3e7e428ec 100644 --- a/backend/chain.py +++ b/backend/chain.py @@ -236,39 +236,42 @@ def cohere_response_synthesizer(input: dict) -> RunnableSequence: ) -llm = ChatOpenAI( - model="gpt-3.5-turbo-0125", +gpt_3_5 = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, streaming=True) +claude_3_sonnet = ChatAnthropic( + model="claude-3-sonnet-20240229", temperature=0, - streaming=True, -).configurable_alternatives( + max_tokens=4096, + anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"), +) +fireworks_mixtral = ChatFireworks( + model="accounts/fireworks/models/mixtral-8x7b-instruct", + temperature=0, + max_tokens=16384, + fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"), +) +gemini_pro = ChatGoogleGenerativeAI( + model="gemini-pro", + temperature=0, + max_tokens=16384, + convert_system_message_to_human=True, + google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"), +) +cohere_command = ChatCohere( + model="command", + temperature=0, + cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"), +) +llm = gpt_3_5.configurable_alternatives( # This gives this field an id # When configuring the end runnable, we can then use this id to configure this field ConfigurableField(id="llm"), default_key="openai_gpt_3_5_turbo", - anthropic_claude_3_sonnet=ChatAnthropic( - model="claude-3-sonnet-20240229", - temperature=0, - max_tokens=4096, - anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"), - ), - fireworks_mixtral=ChatFireworks( - model="accounts/fireworks/models/mixtral-8x7b-instruct", - temperature=0, - max_tokens=16384, - fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"), - ), - google_gemini_pro=ChatGoogleGenerativeAI( - model="gemini-pro", - temperature=0, - max_tokens=16384, - convert_system_message_to_human=True, - google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"), - ), - cohere_command=ChatCohere( - model="command", - temperature=0, - cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"), - ), + anthropic_claude_3_sonnet=claude_3_sonnet, + fireworks_mixtral=fireworks_mixtral, + google_gemini_pro=gemini_pro, + cohere_command=cohere_command, +).with_fallbacks( + [gpt_3_5, claude_3_sonnet, fireworks_mixtral, gemini_pro, cohere_command] ) retriever = get_retriever()