From add851beab88b1948e99369b9975107b792c7403 Mon Sep 17 00:00:00 2001
From: jerrylien
Date: Thu, 29 Feb 2024 12:26:08 +0800
Subject: [PATCH] Add GEMINI Pro langchain genai support

---
 backend/app/agent.py   |  15 +++-
 backend/app/llms.py    |   7 +-
 backend/poetry.lock    |  57 +++++++++++++-
 backend/pyproject.toml | 165 +++++++++++++++++++++--------------------
 4 files changed, 158 insertions(+), 86 deletions(-)

diff --git a/backend/app/agent.py b/backend/app/agent.py
index d9e885ce..9ad69fe7 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -37,7 +37,8 @@ class AgentType(str, Enum):
     AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
-    GEMINI = "GEMINI"
+    GEMINI = "GEMINI (VertexAI)"
+    GEMINI_GENAI = "GEMINI (GenAI)"
 
 
 DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
@@ -81,6 +82,11 @@ def get_agent_executor(
         return get_google_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
+    elif agent == AgentType.GEMINI_GENAI:
+        llm = get_google_llm(genai=True)
+        return get_google_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     else:
         raise ValueError("Unexpected agent type")
 
@@ -143,7 +149,8 @@ class LLMType(str, Enum):
     AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
-    GEMINI = "GEMINI"
+    GEMINI = "GEMINI (VertexAI)"
+    GEMINI_GENAI = "GEMINI (GenAI)"
     MIXTRAL = "Mixtral"
 
 
@@ -163,6 +170,8 @@ def get_chatbot(
         llm = get_anthropic_llm(bedrock=True)
     elif llm_type == LLMType.GEMINI:
         llm = get_google_llm()
+    elif llm_type == LLMType.GEMINI_GENAI:
+        llm = get_google_llm(genai=True)
     elif llm_type == LLMType.MIXTRAL:
         llm = get_mixtral_fireworks()
     else:
@@ -236,6 +245,8 @@ def __init__(
             llm = get_anthropic_llm(bedrock=True)
         elif llm_type == LLMType.GEMINI:
             llm = get_google_llm()
+        elif llm_type == LLMType.GEMINI_GENAI:
+            llm = get_google_llm(genai=True)
         elif llm_type == LLMType.MIXTRAL:
             llm = get_mixtral_fireworks()
         else:
diff --git a/backend/app/llms.py b/backend/app/llms.py
index 40f1d9a2..0b6735a3 100644
--- a/backend/app/llms.py
+++ b/backend/app/llms.py
@@ -6,6 +6,7 @@
 import boto3
 from langchain_community.chat_models import BedrockChat, ChatAnthropic, ChatFireworks
 from langchain_google_vertexai import ChatVertexAI
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
 logger = logging.getLogger(__name__)
@@ -66,7 +67,11 @@ def get_anthropic_llm(bedrock: bool = False):
 
 
 @lru_cache(maxsize=1)
-def get_google_llm():
+def get_google_llm(genai: bool = False):
+    if genai:
+        return ChatGoogleGenerativeAI(
+            model="gemini-pro", convert_system_message_to_human=True
+        )
     return ChatVertexAI(
         model_name="gemini-pro", convert_system_message_to_human=True, streaming=True
     )
diff --git a/backend/poetry.lock b/backend/poetry.lock
index d368bb04..ecd1cd86 100644
--- a/backend/poetry.lock
+++ b/backend/poetry.lock
@@ -880,6 +880,22 @@ smb = ["smbprotocol"]
 ssh = ["paramiko"]
 tqdm = ["tqdm"]
 
+[[package]]
+name = "google-ai-generativelanguage"
+version = "0.4.0"
+description = "Google Ai Generativelanguage API client library"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"},
+    {file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+
 [[package]]
 name = "google-api-core"
 version = "2.15.0"
@@ -1140,6 +1156,27 @@ files = [
 [package.extras]
 testing = ["pytest"]
 
+[[package]]
+name = "google-generativeai"
+version = "0.3.2"
+description = "Google Generative AI High level API client library and tools."
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "google_generativeai-0.3.2-py3-none-any.whl", hash = "sha256:8761147e6e167141932dc14a7b7af08f2310dd56668a78d206c19bb8bd85bcd7"},
+]
+
+[package.dependencies]
+google-ai-generativelanguage = "0.4.0"
+google-api-core = "*"
+google-auth = "*"
+protobuf = "*"
+tqdm = "*"
+typing-extensions = "*"
+
+[package.extras]
+dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
+
 [[package]]
 name = "google-resumable-media"
 version = "2.7.0"
@@ -1605,6 +1642,24 @@ tenacity = ">=8.1.0,<9.0.0"
 [package.extras]
 extended-testing = ["jinja2 (>=3,<4)"]
 
+[[package]]
+name = "langchain-google-genai"
+version = "0.0.9"
+description = "An integration package connecting Google's genai package and LangChain"
+optional = false
+python-versions = ">=3.9,<4.0"
+files = [
+    {file = "langchain_google_genai-0.0.9-py3-none-any.whl", hash = "sha256:82c0ca9540132a59b09fc38ff249a2dd06f8a587ed37c291a4fe7678d5566d15"},
+    {file = "langchain_google_genai-0.0.9.tar.gz", hash = "sha256:466a228032bb06b0c1def822e57cbf2dfe9e4d1cc91dffa473a3025eb760f0ef"},
+]
+
+[package.dependencies]
+google-generativeai = ">=0.3.1,<0.4.0"
+langchain-core = ">=0.1,<0.2"
+
+[package.extras]
+images = ["pillow (>=10.1.0,<11.0.0)"]
+
 [[package]]
 name = "langchain-google-vertexai"
 version = "0.0.3"
@@ -3628,4 +3683,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9.0"
-content-hash = "08d25c6ac983f211b011a85e0e941c7afe5d1877c41b252e151f4b667a50ee02"
+content-hash = "e19fc6cede90f552da97aa99c263214cf519db3727ca2cd248763abc798744ef"
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index daa06032..d93d030e 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -1,82 +1,83 @@
-[tool.poetry]
-name = "opengpts"
-version = "0.1.0"
-description = ""
-authors = ["Your Name "]
-readme = "README.md"
-packages = [{include = "app"}]
-
-[tool.poetry.dependencies]
-python = "^3.9.0"
-sse-starlette = "^1.6.5"
-tomli-w = "^1.0.0"
-uvicorn = "^0.23.2"
-fastapi = "^0.103.2"
-langserve = "0.0.32"
-# Uncomment if you need to work from a development branch
-# This will only work for local development though!
-# langchain = { git = "git@github.com:langchain-ai/langchain.git/", branch = "nc/subclass-runnable-binding" , subdirectory = "libs/langchain"}
-orjson = "^3.9.10"
-redis = "^5.0.1"
-python-multipart = "^0.0.6"
-tiktoken = "^0.5.1"
-langchain = ">=0.0.338"
-langgraph = "^0.0.23"
-pydantic = "<2.0"
-python-magic = "^0.4.27"
-langchain-openai = "^0.0.5"
-beautifulsoup4 = "^4.12.3"
-boto3 = "^1.34.28"
-duckduckgo-search = "^4.2"
-arxiv = "^2.1.0"
-kay = "^0.1.2"
-xmltodict = "^0.13.0"
-wikipedia = "^1.4.0"
-langchain-google-vertexai = "^0.0.3"
-setuptools = "^69.0.3"
-pdfminer-six = "^20231228"
-langchain-robocorp = "^0.0.3"
-fireworks-ai = "^0.11.2"
-anthropic = "^0.13.0"
-httpx = { version = "0.25.2", extras = ["socks"] }
-
-[tool.poetry.group.dev.dependencies]
-uvicorn = "^0.23.2"
-pygithub = "^2.1.1"
-
-[tool.poetry.group.lint.dependencies]
-ruff = "^0.1.4"
-codespell = "^2.2.0"
-
-[tool.poetry.group.test.dependencies]
-pytest = "^7.2.1"
-pytest-cov = "^4.0.0"
-pytest-asyncio = "^0.21.1"
-pytest-mock = "^3.11.1"
-pytest-socket = "^0.6.0"
-pytest-watch = "^4.2.0"
-pytest-timeout = "^2.2.0"
-
-[tool.coverage.run]
-omit = [
-    "tests/*",
-]
-
-[tool.pytest.ini_options]
-# --strict-markers will raise errors on unknown marks.
-# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
-#
-# https://docs.pytest.org/en/7.1.x/reference/reference.html
-# --strict-config any warnings encountered while parsing the `pytest`
-# section of the configuration file raise errors.
-addopts = "--strict-markers --strict-config --durations=5 -vv"
-# Use global timeout of 30 seconds for now.
-# Most tests should be closer to ~100 ms, but some of the tests involve
-# parsing files. We can adjust on a per test basis later on.
-timeout = 30
-asyncio_mode = "auto"
-
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+[tool.poetry]
+name = "opengpts"
+version = "0.1.0"
+description = ""
+authors = ["Your Name "]
+readme = "README.md"
+packages = [{include = "app"}]
+
+[tool.poetry.dependencies]
+python = "^3.9.0"
+sse-starlette = "^1.6.5"
+tomli-w = "^1.0.0"
+uvicorn = "^0.23.2"
+fastapi = "^0.103.2"
+langserve = "0.0.32"
+# Uncomment if you need to work from a development branch
+# This will only work for local development though!
+# langchain = { git = "git@github.com:langchain-ai/langchain.git/", branch = "nc/subclass-runnable-binding" , subdirectory = "libs/langchain"}
+orjson = "^3.9.10"
+redis = "^5.0.1"
+python-multipart = "^0.0.6"
+tiktoken = "^0.5.1"
+langchain = ">=0.0.338"
+langgraph = "^0.0.23"
+pydantic = "<2.0"
+python-magic = "^0.4.27"
+langchain-openai = "^0.0.5"
+beautifulsoup4 = "^4.12.3"
+boto3 = "^1.34.28"
+duckduckgo-search = "^4.2"
+arxiv = "^2.1.0"
+kay = "^0.1.2"
+xmltodict = "^0.13.0"
+wikipedia = "^1.4.0"
+langchain-google-vertexai = "^0.0.3"
+setuptools = "^69.0.3"
+pdfminer-six = "^20231228"
+langchain-robocorp = "^0.0.3"
+fireworks-ai = "^0.11.2"
+anthropic = "^0.13.0"
+httpx = { version = "0.25.2", extras = ["socks"] }
+langchain-google-genai = "0.0.9"
+
+[tool.poetry.group.dev.dependencies]
+uvicorn = "^0.23.2"
+pygithub = "^2.1.1"
+
+[tool.poetry.group.lint.dependencies]
+ruff = "^0.1.4"
+codespell = "^2.2.0"
+
+[tool.poetry.group.test.dependencies]
+pytest = "^7.2.1"
+pytest-cov = "^4.0.0"
+pytest-asyncio = "^0.21.1"
+pytest-mock = "^3.11.1"
+pytest-socket = "^0.6.0"
+pytest-watch = "^4.2.0"
+pytest-timeout = "^2.2.0"
+
+[tool.coverage.run]
+omit = [
+    "tests/*",
+]
+
+[tool.pytest.ini_options]
+# --strict-markers will raise errors on unknown marks.
+# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
+#
+# https://docs.pytest.org/en/7.1.x/reference/reference.html
+# --strict-config any warnings encountered while parsing the `pytest`
+# section of the configuration file raise errors.
+addopts = "--strict-markers --strict-config --durations=5 -vv"
+# Use global timeout of 30 seconds for now.
+# Most tests should be closer to ~100 ms, but some of the tests involve
+# parsing files. We can adjust on a per test basis later on.
+timeout = 30
+asyncio_mode = "auto"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
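
Note (not part of the patch): a minimal sketch of how the new GenAI code path can be exercised, assuming langchain-google-genai 0.0.9 is installed and a GOOGLE_API_KEY is set in the environment; ChatGoogleGenerativeAI reads that variable, while the existing ChatVertexAI path typically authenticates through Google Cloud application-default credentials. Only the module path and model name mirror the patch; the rest is illustrative.

    # Sketch only: call the patched helper directly, outside the agent executor.
    # Assumes GOOGLE_API_KEY is exported so the GenAI client can authenticate.
    from langchain_core.messages import HumanMessage

    from app.llms import get_google_llm

    llm = get_google_llm(genai=True)  # returns ChatGoogleGenerativeAI(model="gemini-pro", ...)
    reply = llm.invoke([HumanMessage(content="Say hello in one sentence.")])
    print(reply.content)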