Skip to content

Commit

Permalink
add all models as fallbacks (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Mar 13, 2024
1 parent 6296d0a commit a07f8b5
Showing 1 changed file with 31 additions and 28 deletions.
59 changes: 31 additions & 28 deletions backend/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,39 +236,42 @@ def cohere_response_synthesizer(input: dict) -> RunnableSequence:
)


llm = ChatOpenAI(
model="gpt-3.5-turbo-0125",
gpt_3_5 = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0, streaming=True)
claude_3_sonnet = ChatAnthropic(
model="claude-3-sonnet-20240229",
temperature=0,
streaming=True,
).configurable_alternatives(
max_tokens=4096,
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"),
)
fireworks_mixtral = ChatFireworks(
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0,
max_tokens=16384,
fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"),
)
gemini_pro = ChatGoogleGenerativeAI(
model="gemini-pro",
temperature=0,
max_tokens=16384,
convert_system_message_to_human=True,
google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"),
)
cohere_command = ChatCohere(
model="command",
temperature=0,
cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"),
)
llm = gpt_3_5.configurable_alternatives(
# This gives this field an id
# When configuring the end runnable, we can then use this id to configure this field
ConfigurableField(id="llm"),
default_key="openai_gpt_3_5_turbo",
anthropic_claude_3_sonnet=ChatAnthropic(
model="claude-3-sonnet-20240229",
temperature=0,
max_tokens=4096,
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY", "not_provided"),
),
fireworks_mixtral=ChatFireworks(
model="accounts/fireworks/models/mixtral-8x7b-instruct",
temperature=0,
max_tokens=16384,
fireworks_api_key=os.environ.get("FIREWORKS_API_KEY", "not_provided"),
),
google_gemini_pro=ChatGoogleGenerativeAI(
model="gemini-pro",
temperature=0,
max_tokens=16384,
convert_system_message_to_human=True,
google_api_key=os.environ.get("GOOGLE_API_KEY", "not_provided"),
),
cohere_command=ChatCohere(
model="command",
temperature=0,
cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"),
),
anthropic_claude_3_sonnet=claude_3_sonnet,
fireworks_mixtral=fireworks_mixtral,
google_gemini_pro=gemini_pro,
cohere_command=cohere_command,
).with_fallbacks(
[gpt_3_5, claude_3_sonnet, fireworks_mixtral, gemini_pro, cohere_command]
)

retriever = get_retriever()
Expand Down

0 comments on commit a07f8b5

Please sign in to comment.