From 55f7e7dbf57040a78aa410235b46577f60064533 Mon Sep 17 00:00:00 2001
From: Andres <9813380+aislasq@users.noreply.github.com>
Date: Mon, 11 Mar 2024 18:54:45 -0600
Subject: [PATCH] feature: Add gemini backend (#52)

* feat: Add gemini enums
* feat: Add cli arguments for gemini
* feat: Add first version gemini
* fix: Add break on empty text
* feat: Added tests and snapshots
* feat: Update readme
* fix: Commented out unused model attributes
* fix: Config URL for Gemini cannot be different
* docs: Removed gemini-url from docs
* test: Removed unused vars in model from tests
* refactor: Clean imports in test
* docs: Fix config example, removed unused variables
* fix: Add config set to gemini tests
* fix: Added config set to gemini test completions
* refactor: Change completion test body to raw string
* chore: Run fmt lint
---
 README.md                                     |   5 +-
 config.example.toml                           |   5 +-
 src/application/cli.rs                        |   8 +
 src/configuration/config.rs                   |   2 +
 src/domain/models/backend.rs                  |   1 +
 src/infrastructure/backends/gemini.rs         | 249 ++++++++++++++++++
 src/infrastructure/backends/gemini_test.rs    | 195 ++++++++++++++
 src/infrastructure/backends/mod.rs            |   5 +
 ...g__tests__it_serializes_to_valid_toml.snap |   5 +-
 ...s__gemini__tests__it_gets_completions.snap |   5 +
 10 files changed, 477 insertions(+), 3 deletions(-)
 create mode 100644 src/infrastructure/backends/gemini.rs
 create mode 100644 src/infrastructure/backends/gemini_test.rs
 create mode 100644 test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap

diff --git a/README.md b/README.md
index db3aa8f..072ab1b 100644
--- a/README.md
+++ b/README.md
@@ -173,7 +173,7 @@ Commands:
 Options:
   -b, --backend <backend>
-          The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai]
+          The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai, gemini]
       --backend-health-check-timeout <backend-health-check-timeout>
           Time to wait in milliseconds before timing out when doing a healthcheck for a backend. [default: 1000] [env: OATMEAL_BACKEND_HEALTH_CHECK_TIMEOUT=]
   -m, --model <model>
@@ -194,6 +194,8 @@ Options:
           OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. [default: https://api.openai.com] [env: OATMEAL_OPENAI_URL=]
       --open-ai-token <open-ai-token>
           OpenAI API token when using the OpenAI backend. [env: OATMEAL_OPENAI_TOKEN=]
+      --gemini-token <gemini-token>
+          Google Gemini API token when using the Gemini backend. [env: OATMEAL_GEMINI_TOKEN=]
   -h, --help
           Print help
   -V, --version
@@ -261,6 +263,7 @@ The following model backends are supported:
 - [OpenAI](https://chat.openai.com) (Or any compatible proxy/API)
 - [Ollama](https://github.com/jmorganca/ollama)
 - [LangChain/LangServe](https://python.langchain.com/docs/langserve) (Experimental)
+- [Gemini](https://gemini.google.com) (Experimental)
 
 ### Editors
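The help text above maps directly onto how the new backend is selected at the command line. A minimal invocation, assuming a valid API key in `GEMINI_API_KEY` (the variable name here is only a placeholder):

```sh
# Pass the backend and token explicitly...
oatmeal --backend gemini --gemini-token "$GEMINI_API_KEY"

# ...or let clap pick both up from the environment.
OATMEAL_BACKEND=gemini OATMEAL_GEMINI_TOKEN="$GEMINI_API_KEY" oatmeal
```

Note there is deliberately no `--gemini-url` flag: per the commit list above, the Gemini URL cannot be overridden, unlike the OpenAI and Ollama backends.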
diff --git a/config.example.toml b/config.example.toml
index 9939e60..f79a209 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -1,4 +1,4 @@
-# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai]
+# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai, gemini]
 backend = "ollama"
 
 # Time to wait in milliseconds before timing out when doing a healthcheck for a backend.
@@ -22,6 +22,9 @@ ollama-url = "http://localhost:11434"
 # OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy.
 open-ai-url = "https://api.openai.com"
 
+# Google Gemini API token when using the Gemini backend.
+# gemini-token = ""
+
 # Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti]
 theme = "base16-onedark"
diff --git a/src/application/cli.rs b/src/application/cli.rs
index 9019e97..1414f2b 100644
--- a/src/application/cli.rs
+++ b/src/application/cli.rs
@@ -398,6 +398,14 @@ pub fn build() -> Command {
                 .num_args(1)
                 .help("OpenAI API token when using the OpenAI backend.")
                 .global(true),
+        )
+        .arg(
+            Arg::new(ConfigKey::GeminiToken.to_string())
+                .long(ConfigKey::GeminiToken.to_string())
+                .env("OATMEAL_GEMINI_TOKEN")
+                .num_args(1)
+                .help("Google Gemini API token when using the Gemini backend.")
+                .global(true),
         );
 }
diff --git a/src/configuration/config.rs b/src/configuration/config.rs
index 6494722..68ba929 100644
--- a/src/configuration/config.rs
+++ b/src/configuration/config.rs
@@ -33,6 +33,7 @@ pub enum ConfigKey {
     OllamaURL,
     OpenAiToken,
     OpenAiURL,
+    GeminiToken,
     SessionID,
     Theme,
     ThemeFile,
@@ -99,6 +100,7 @@ impl Config {
             ConfigKey::OllamaURL => "http://localhost:11434",
             ConfigKey::OpenAiToken => "",
             ConfigKey::OpenAiURL => "https://api.openai.com",
+            ConfigKey::GeminiToken => "",
             ConfigKey::Theme => "base16-onedark",
             ConfigKey::ThemeFile => "",
diff --git a/src/domain/models/backend.rs b/src/domain/models/backend.rs
index 1f7b0f6..6c4fc2f 100644
--- a/src/domain/models/backend.rs
+++ b/src/domain/models/backend.rs
@@ -19,6 +19,7 @@ pub enum BackendName {
     LangChain,
     Ollama,
     OpenAI,
+    Gemini,
 }
 
 impl BackendName {
diff --git a/src/infrastructure/backends/gemini.rs b/src/infrastructure/backends/gemini.rs
new file mode 100644
index 0000000..166bc25
--- /dev/null
+++ b/src/infrastructure/backends/gemini.rs
@@ -0,0 +1,249 @@
+#[cfg(test)]
+#[path = "gemini_test.rs"]
+mod tests;
+
+use std::time::Duration;
+
+use anyhow::bail;
+use anyhow::Result;
+use async_trait::async_trait;
+use futures::stream::TryStreamExt;
+use serde::Deserialize;
+use serde::Serialize;
+use tokio::io::AsyncBufReadExt;
+use tokio::sync::mpsc;
+use tokio_util::io::StreamReader;
+
+use crate::configuration::Config;
+use crate::configuration::ConfigKey;
+use crate::domain::models::Author;
+use crate::domain::models::Backend;
+use crate::domain::models::BackendName;
+use crate::domain::models::BackendPrompt;
+use crate::domain::models::BackendResponse;
+use crate::domain::models::Event;
+
+fn convert_err(err: reqwest::Error) -> std::io::Error {
+    let err_msg = err.to_string();
+    return std::io::Error::new(std::io::ErrorKind::Interrupted, err_msg);
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct Model {
+    name: String,
+    supported_generation_methods: Vec<String>,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+struct ModelListResponse {
+    models: Vec<Model>,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct ContentPartsBlob {
+    mime_type: String,
+    data: String,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+enum ContentParts {
+    Text(String),
+    InlineData(ContentPartsBlob),
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+struct Content {
+    role: String,
+    parts: Vec<ContentParts>,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+struct CompletionRequest {
+    contents: Vec<Content>,
+}
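Taken together these types model the request body; the response-side `GenerateContentResponse` follows below. A quick way to see the wire shape they serialize to (a sketch reusing the structs above; the prompt text is illustrative only):

```rust
// Builds a single-turn request and prints the JSON Gemini receives:
// {"contents":[{"role":"user","parts":[{"text":"Say hi to the world"}]}]}
let req = CompletionRequest {
    contents: vec![Content {
        role: "user".to_string(),
        parts: vec![ContentParts::Text("Say hi to the world".to_string())],
    }],
};
println!("{}", serde_json::to_string(&req).unwrap());
```

The `rename_all = "camelCase"` attributes do the field mapping: `mime_type` goes out as `mimeType`, and the externally tagged `ContentParts` variants become the `"text"` and `"inlineData"` keys visible in the test snapshot at the end of this patch.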
+
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct GenerateContentResponse {
+    text: String,
+}
+
+pub struct Gemini {
+    url: String,
+    token: String,
+    timeout: String,
+}
+
+impl Default for Gemini {
+    fn default() -> Gemini {
+        return Gemini {
+            url: "https://generativelanguage.googleapis.com".to_string(),
+            token: Config::get(ConfigKey::GeminiToken),
+            timeout: Config::get(ConfigKey::BackendHealthCheckTimeout),
+        };
+    }
+}
+
+#[async_trait]
+impl Backend for Gemini {
+    fn name(&self) -> BackendName {
+        return BackendName::Gemini;
+    }
+
+    #[allow(clippy::implicit_return)]
+    async fn health_check(&self) -> Result<()> {
+        if self.url.is_empty() {
+            bail!("Gemini URL is not defined");
+        }
+        if self.token.is_empty() {
+            bail!("Gemini token is not defined");
+        }
+
+        let url = format!(
+            "{url}/v1beta/{model}?key={key}",
+            url = self.url,
+            model = Config::get(ConfigKey::Model),
+            key = self.token
+        );
+
+        let res = reqwest::Client::new()
+            .get(&url)
+            .timeout(Duration::from_millis(self.timeout.parse::<u64>()?))
+            .send()
+            .await;
+
+        if res.is_err() {
+            tracing::error!(error = ?res.unwrap_err(), "Gemini is not reachable");
+            bail!("Gemini is not reachable");
+        }
+
+        let status = res.unwrap().status().as_u16();
+        if status >= 400 {
+            tracing::error!(status = status, "Gemini health check failed");
+            bail!("Gemini health check failed");
+        }
+
+        return Ok(());
+    }
+
+    #[allow(clippy::implicit_return)]
+    async fn list_models(&self) -> Result<Vec<String>> {
+        let res = reqwest::Client::new()
+            .get(format!(
+                "{url}/v1beta/models?key={key}",
+                url = self.url,
+                key = self.token
+            ))
+            .send()
+            .await?
+            .json::<ModelListResponse>()
+            .await?;
+
+        let mut models: Vec<String> = res
+            .models
+            .iter()
+            .filter(|model| {
+                model
+                    .supported_generation_methods
+                    .contains(&"generateContent".to_string())
+            })
+            .map(|model| {
+                return model.name.to_string();
+            })
+            .collect();
+
+        models.sort();
+
+        return Ok(models);
+    }
+
+    #[allow(clippy::implicit_return)]
+    async fn get_completion<'a>(
+        &self,
+        prompt: BackendPrompt,
+        tx: &'a mpsc::UnboundedSender<Event>,
+    ) -> Result<()> {
+        let mut contents: Vec<Content> = vec![];
+        if !prompt.backend_context.is_empty() {
+            contents = serde_json::from_str(&prompt.backend_context)?;
+        }
+        contents.push(Content {
+            role: "user".to_string(),
+            parts: vec![ContentParts::Text(prompt.text)],
+        });
+
+        let req = CompletionRequest {
+            contents: contents.clone(),
+        };
+
+        let res = reqwest::Client::new()
+            .post(format!(
+                "{url}/v1beta/{model}:streamGenerateContent?key={key}",
+                url = self.url,
+                model = Config::get(ConfigKey::Model),
+                key = self.token,
+            ))
+            .json(&req)
+            .send()
+            .await?;
+
+        if !res.status().is_success() {
+            tracing::error!(
+                status = res.status().as_u16(),
+                "Failed to make completion request to Gemini"
+            );
+            bail!(format!(
+                "Failed to make completion request to Gemini, {}",
+                res.status().as_u16()
+            ));
+        }
+
+        let stream = res.bytes_stream().map_err(convert_err);
+        let mut lines_reader = StreamReader::new(stream).lines();
+
+        let mut last_message = "".to_string();
+        while let Ok(line) = lines_reader.next_line().await {
+            if line.is_none() {
+                break;
+            }
+
+            let cleaned_line = line.unwrap().trim().to_string();
+            if !cleaned_line.starts_with("\"text\":") {
+                continue;
+            }
+
+            let ores: GenerateContentResponse =
+                serde_json::from_str(&format!("{{ {text} }}", text = cleaned_line)).unwrap();
+
+            if ores.text.is_empty() || ores.text == "\n" {
+                break;
+            }
+
+            last_message += &ores.text;
+            let msg = BackendResponse {
+                author: Author::Model,
+                text: ores.text,
+                done: false,
+                context: None,
+            };
+            tx.send(Event::BackendPromptResponse(msg))?;
+        }
+
+        contents.push(Content {
+            role: "model".to_string(),
+            parts: vec![ContentParts::Text(last_message.clone())],
+        });
+
+        let msg = BackendResponse {
+            author: Author::Model,
+            text: "".to_string(),
+            done: true,
+            context: Some(serde_json::to_string(&contents)?),
+        };
+        tx.send(Event::BackendPromptResponse(msg))?;
+
+        return Ok(());
+    }
+}
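Rather than deserializing the chunked JSON array that `streamGenerateContent` returns, the loop above scans the pretty-printed body line by line, keeps only lines that start with `"text":`, and wraps each one in braces to make it a parseable object. A standalone sketch of that rewrapping trick, using the same `serde_json` call as the loop (the struct mirrors the private one defined above):

```rust
use serde::Deserialize;

// Mirrors the private GenerateContentResponse struct in gemini.rs.
#[derive(Deserialize)]
struct GenerateContentResponse {
    text: String,
}

fn main() {
    // One line as it appears in Gemini's pretty-printed streaming output.
    let cleaned_line = r#""text": "Hello ""#.trim().to_string();
    // Wrapping the key/value pair in braces yields a complete JSON object.
    let ores: GenerateContentResponse =
        serde_json::from_str(&format!("{{ {text} }}", text = cleaned_line)).unwrap();
    assert_eq!(ores.text, "Hello ");
}
```

The trade-off is fragility: the parser depends on the server emitting one key per line, and the break on empty or newline-only `text` (the "break on empty text" commit) ends the stream early rather than skipping the chunk.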
diff --git a/src/infrastructure/backends/gemini_test.rs b/src/infrastructure/backends/gemini_test.rs
new file mode 100644
index 0000000..afc6ee6
--- /dev/null
+++ b/src/infrastructure/backends/gemini_test.rs
@@ -0,0 +1,195 @@
+use anyhow::bail;
+use anyhow::Result;
+use test_utils::insta_snapshot;
+use tokio::sync::mpsc;
+
+use super::Config;
+use super::Content;
+use super::ContentParts;
+use super::Gemini;
+use super::Model;
+use super::ModelListResponse;
+use crate::configuration::ConfigKey;
+use crate::domain::models::Author;
+use crate::domain::models::Backend;
+use crate::domain::models::BackendPrompt;
+use crate::domain::models::BackendResponse;
+use crate::domain::models::Event;
+
+impl Gemini {
+    fn with_url(url: String) -> Gemini {
+        return Gemini {
+            url,
+            token: "abc".to_string(),
+            timeout: "200".to_string(),
+        };
+    }
+}
+
+fn to_res(action: Option<Event>) -> Result<BackendResponse> {
+    let act = match action.unwrap() {
+        Event::BackendPromptResponse(res) => res,
+        _ => bail!("Wrong type from recv"),
+    };
+
+    return Ok(act);
+}
+
+#[tokio::test]
+async fn it_successfully_health_checks() {
+    Config::set(ConfigKey::Model, "model-1");
+    let mut server = mockito::Server::new();
+    let mock = server
+        .mock("GET", "/v1beta/model-1?key=abc")
+        .with_status(200)
+        .create();
+
+    let backend = Gemini::with_url(server.url());
+    let res = backend.health_check().await;
+
+    assert!(res.is_ok());
+    mock.assert();
+}
+
+#[tokio::test]
+async fn it_successfully_health_checks_with_official_api() {
+    Config::set(ConfigKey::Model, "models/gemini-pro");
+    let token = match std::env::var("OATMEAL_GEMINI_TOKEN") {
+        Ok(token) => token,
+        Err(_) => {
+            println!("There is no token defined in the environment, skipping test");
+            return;
+        }
+    };
+    let backend = Gemini {
+        url: "https://generativelanguage.googleapis.com".to_string(),
+        token,
+        timeout: "500".to_string(),
+    };
+
+    let res = backend.health_check().await;
+    assert!(res.is_ok());
+}
+
+#[tokio::test]
+async fn it_fails_health_checks() {
+    Config::set(ConfigKey::Model, "model-1");
+    let mut server = mockito::Server::new();
+    let mock = server
+        .mock("GET", "/v1beta/model-1?key=abc")
+        .with_status(500)
+        .create();
+
+    let backend = Gemini::with_url(server.url());
+    let res = backend.health_check().await;
+
+    assert!(res.is_err());
+    mock.assert();
+}
+
+#[tokio::test]
+async fn it_lists_models() -> Result<()> {
+    let body = serde_json::to_string(&ModelListResponse {
+        models: vec![
+            Model {
+                name: "first".to_string(),
+                supported_generation_methods: vec!["generateContent".to_string()],
+            },
+            Model {
+                name: "second".to_string(),
+                supported_generation_methods: vec!["generateContent".to_string()],
+            },
+        ],
+    })?;
+
+    let mut server = mockito::Server::new();
+    let mock = server
+        .mock("GET", "/v1beta/models?key=abc")
+        .with_status(200)
+        .with_body(body)
+        .create();
+
+    let backend = Gemini::with_url(server.url());
+    let res = backend.list_models().await?;
+    mock.assert();
+
+    assert_eq!(res, vec!["first".to_string(), "second".to_string()]);
+
+    return Ok(());
+}
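Everything up to this point runs offline against mockito; the one live test keys off the real endpoint and skips itself when no token is exported. Assuming a standard cargo setup, the new tests can be exercised with something like:

```sh
# Offline: runs the mockito-backed gemini tests only (name filter).
cargo test gemini

# Opt-in: also hits the official API; skipped when the variable is unset.
OATMEAL_GEMINI_TOKEN="<your-key>" cargo test it_successfully_health_checks_with_official_api
```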
+
+#[tokio::test]
+async fn it_gets_completions() -> Result<()> {
+    Config::set(ConfigKey::Model, "model-1");
+    let body = r#"
+{
+    "contents": [
+        {
+            "parts": [
+                {
+                    "text": "Hello "
+                }
+            ]
+        },
+        {
+            "parts": [
+                {
+                    "text": "World"
+                }
+            ]
+        },
+        {
+            "parts": [
+                {
+                    "text": ""
+                }
+            ]
+        }
+    ]
+}
+    "#;
+    let prompt = BackendPrompt {
+        text: "Say hi to the world".to_string(),
+        backend_context: serde_json::to_string(&vec![Content {
+            role: "model".to_string(),
+            parts: vec![ContentParts::Text("Hello".to_string())],
+        }])?,
+    };
+
+    let mut server = mockito::Server::new();
+    let mock = server
+        .mock("POST", "/v1beta/model-1:streamGenerateContent?key=abc")
+        .with_status(200)
+        .with_body(body)
+        .create();
+
+    let (tx, mut rx) = mpsc::unbounded_channel::<Event>();
+
+    let backend = Gemini::with_url(server.url());
+    backend.get_completion(prompt, &tx).await?;
+
+    mock.assert();
+
+    let first_recv = to_res(rx.recv().await)?;
+    let second_recv = to_res(rx.recv().await)?;
+    let third_recv = to_res(rx.recv().await)?;
+
+    assert_eq!(first_recv.author, Author::Model);
+    assert_eq!(first_recv.text, "Hello ".to_string());
+    assert!(!first_recv.done);
+    assert_eq!(first_recv.context, None);
+
+    assert_eq!(second_recv.author, Author::Model);
+    assert_eq!(second_recv.text, "World".to_string());
+    assert!(!second_recv.done);
+    assert_eq!(second_recv.context, None);
+
+    assert_eq!(third_recv.author, Author::Model);
+    assert_eq!(third_recv.text, "".to_string());
+    assert!(third_recv.done);
+    insta_snapshot(|| {
+        insta::assert_toml_snapshot!(third_recv.context);
+    });
+
+    return Ok(());
+}
diff --git a/src/infrastructure/backends/mod.rs b/src/infrastructure/backends/mod.rs
index 8f992e8..ee2fb68 100644
--- a/src/infrastructure/backends/mod.rs
+++ b/src/infrastructure/backends/mod.rs
@@ -1,3 +1,4 @@
+pub mod gemini;
 pub mod langchain;
 pub mod ollama;
 pub mod openai;
@@ -23,6 +24,10 @@ impl BackendManager {
         return Ok(Box::<OpenAI>::default());
     }
 
+        if name == BackendName::Gemini {
+            return Ok(Box::<Gemini>::default());
+        }
+
         bail!(format!("No backend implemented for {name}"))
     }
 }
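With the module registered and the if-chain extended, resolving the enum to a concrete backend stays a one-liner for callers. A minimal sketch of the call site; the accessor name `get` is an assumption from context, since only the body of the dispatch chain appears in this hunk:

```rust
// Hypothetical call site (accessor name assumed, not part of this patch):
// resolve the configured backend name to a boxed trait object.
let backend: Box<dyn Backend> = BackendManager::get(BackendName::Gemini)?;
assert_eq!(backend.name(), BackendName::Gemini);
```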
diff --git a/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap b/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap
index 37fa6b8..8746b7f 100644
--- a/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap
+++ b/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap
@@ -3,7 +3,7 @@ source: src/configuration/config_test.rs
 expression: res
 ---
 '''
-# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai]
+# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai, gemini]
 backend = "ollama"
 
 # Time to wait in milliseconds before timing out when doing a healthcheck for a backend.
@@ -27,6 +27,9 @@ ollama-url = "http://localhost:11434"
 # OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy.
 open-ai-url = "https://api.openai.com"
 
+# Google Gemini API token when using the Gemini backend.
+# gemini-token = ""
+
 # Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti]
 theme = "base16-onedark"
diff --git a/test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap b/test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap
new file mode 100644
index 0000000..78aa7ac
--- /dev/null
+++ b/test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap
@@ -0,0 +1,5 @@
+---
+source: src/infrastructure/backends/gemini_test.rs
+expression: third_recv.context
+---
+'[{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"Say hi to the world"}]},{"role":"model","parts":[{"text":"Hello World"}]}]'
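Pretty-printed, the snapshotted context shows the round-trip the completion test sets up: the seeded model turn from `backend_context`, the user prompt appended by `get_completion`, and the accumulated model reply pushed after the stream ends:

```json
[
  { "role": "model", "parts": [{ "text": "Hello" }] },
  { "role": "user",  "parts": [{ "text": "Say hi to the world" }] },
  { "role": "model", "parts": [{ "text": "Hello World" }] }
]
```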