feature: Add gemini backend (#52)
* feat: Add gemini enums

* feat: Add cli arguments for gemini

* feat: Add first version gemini

* fix: Add break on empty text

* feat: Added tests and snapshots

* feat: Update readme

* fix: Commented out unused model attributes

* fix: Config URL for Gemini cannot be different

* docs: Removed gemini-url from docs

* test: Removed unused vars in model from tests

* refactor: Clean imports in test

* docs: Fix config example, removed unused variables

* fix: Add config set to gemini tests

* fix: Added config set to gemini test completions

* refactor: Change completion test body to raw string

* chore: Run fmt lint
aislasq authored Mar 12, 2024
1 parent d363b35 commit 55f7e7d
Showing 10 changed files with 477 additions and 3 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -173,7 +173,7 @@ Commands:
Options:
-b, --backend <backend>
-          The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai]
+          The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai, gemini]
--backend-health-check-timeout <backend-health-check-timeout>
Time to wait in milliseconds before timing out when doing a healthcheck for a backend. [default: 1000] [env: OATMEAL_BACKEND_HEALTH_CHECK_TIMEOUT=]
-m, --model <model>
@@ -194,6 +194,8 @@ Options:
OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. [default: https://api.openai.com] [env: OATMEAL_OPENAI_URL=]
--open-ai-token <open-ai-token>
OpenAI API token when using the OpenAI backend. [env: OATMEAL_OPENAI_TOKEN=]
--gemini-token <gemini-token>
Gemini API token when using the Gemini backend. [env: OATMEAL_GEMINI_TOKEN=]
-h, --help
Print help
-V, --version
@@ -261,6 +263,7 @@ The following model backends are supported:
- [OpenAI](https://chat.openai.com) (Or any compatible proxy/API)
- [Ollama](https://github.com/jmorganca/ollama)
- [LangChain/LangServe](https://python.langchain.com/docs/langserve) (Experimental)
- [Gemini](https://gemini.google.com) (Experimental)

### Editors

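Taken together, the options added above are enough to try the new backend from the command line. A minimal invocation sketch, assuming a valid API key; the model name is illustrative and should be one of the names the backend lists for your key:

    OATMEAL_GEMINI_TOKEN="<your-api-key>" oatmeal --backend gemini --model models/gemini-pro
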
5 changes: 4 additions & 1 deletion config.example.toml
@@ -1,4 +1,4 @@
-# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai]
+# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai, gemini]
backend = "ollama"

# Time to wait in milliseconds before timing out when doing a healthcheck for a backend.
@@ -22,6 +22,9 @@ ollama-url = "http://localhost:11434"
# OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy.
open-ai-url = "https://api.openai.com"

# Gemini API token when using the Gemini backend.
# gemini-token = ""

# Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti]
theme = "base16-onedark"

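The same selection can be made persistent in the config file instead of on the command line. A minimal sketch combining the two keys shown above; the token value is a placeholder:

    backend = "gemini"
    gemini-token = "<your-api-key>"
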
8 changes: 8 additions & 0 deletions src/application/cli.rs
@@ -398,6 +398,14 @@ pub fn build() -> Command {
.num_args(1)
.help("OpenAI API token when using the OpenAI backend.")
.global(true),
)
.arg(
Arg::new(ConfigKey::GeminiToken.to_string())
.long(ConfigKey::GeminiToken.to_string())
.env("OATMEAL_GEMINI_TOKEN")
.num_args(1)
.help("Google Gemini API token when using the Gemini backend.")
.global(true),
);
}

2 changes: 2 additions & 0 deletions src/configuration/config.rs
@@ -33,6 +33,7 @@ pub enum ConfigKey {
OllamaURL,
OpenAiToken,
OpenAiURL,
GeminiToken,
SessionID,
Theme,
ThemeFile,
@@ -99,6 +100,7 @@ impl Config {
ConfigKey::OllamaURL => "http://localhost:11434",
ConfigKey::OpenAiToken => "",
ConfigKey::OpenAiURL => "https://api.openai.com",
ConfigKey::GeminiToken => "",
ConfigKey::Theme => "base16-onedark",
ConfigKey::ThemeFile => "",

1 change: 1 addition & 0 deletions src/domain/models/backend.rs
@@ -19,6 +19,7 @@ pub enum BackendName {
LangChain,
Ollama,
OpenAI,
Gemini,
}

impl BackendName {
249 changes: 249 additions & 0 deletions src/infrastructure/backends/gemini.rs
@@ -0,0 +1,249 @@
#[cfg(test)]
#[path = "gemini_test.rs"]
mod tests;

use std::time::Duration;

use anyhow::bail;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::TryStreamExt;
use serde::Deserialize;
use serde::Serialize;
use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc;
use tokio_util::io::StreamReader;

use crate::configuration::Config;
use crate::configuration::ConfigKey;
use crate::domain::models::Author;
use crate::domain::models::Backend;
use crate::domain::models::BackendName;
use crate::domain::models::BackendPrompt;
use crate::domain::models::BackendResponse;
use crate::domain::models::Event;

fn convert_err(err: reqwest::Error) -> std::io::Error {
let err_msg = err.to_string();
return std::io::Error::new(std::io::ErrorKind::Interrupted, err_msg);
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Model {
name: String,
supported_generation_methods: Vec<String>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ModelListResponse {
models: Vec<Model>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ContentPartsBlob {
mime_type: String,
data: String,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum ContentParts {
Text(String),
InlineData(ContentPartsBlob),
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct Content {
role: String,
parts: Vec<ContentParts>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct CompletionRequest {
contents: Vec<Content>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
text: String,
}

pub struct Gemini {
url: String,
token: String,
timeout: String,
}

impl Default for Gemini {
fn default() -> Gemini {
return Gemini {
url: "https://generativelanguage.googleapis.com".to_string(),
token: Config::get(ConfigKey::GeminiToken),
timeout: Config::get(ConfigKey::BackendHealthCheckTimeout),
};
}
}

#[async_trait]
impl Backend for Gemini {
fn name(&self) -> BackendName {
return BackendName::Gemini;
}

#[allow(clippy::implicit_return)]
async fn health_check(&self) -> Result<()> {
if self.url.is_empty() {
bail!("Gemini URL is not defined");
}
if self.token.is_empty() {
bail!("Gemini token is not defined");
}

let url = format!(
"{url}/v1beta/{model}?key={key}",
url = self.url,
model = Config::get(ConfigKey::Model),
key = self.token
);

let res = reqwest::Client::new()
.get(&url)
.timeout(Duration::from_millis(self.timeout.parse::<u64>()?))
.send()
.await;

if res.is_err() {
tracing::error!(error = ?res.unwrap_err(), "Gemini is not reachable");
bail!("Gemini is not reachable");
}

let status = res.unwrap().status().as_u16();
if status >= 400 {
tracing::error!(status = status, "Gemini health check failed");
bail!("Gemini health check failed");
}

return Ok(());
}

#[allow(clippy::implicit_return)]
async fn list_models(&self) -> Result<Vec<String>> {
let res = reqwest::Client::new()
.get(format!(
"{url}/v1beta/models?key={key}",
url = self.url,
key = self.token
))
.send()
.await?
.json::<ModelListResponse>()
.await?;

let mut models: Vec<String> = res
.models
.iter()
.filter(|model| {
model
.supported_generation_methods
.contains(&"generateContent".to_string())
})
.map(|model| {
return model.name.to_string();
})
.collect();

models.sort();

return Ok(models);
}

#[allow(clippy::implicit_return)]
async fn get_completion<'a>(
&self,
prompt: BackendPrompt,
tx: &'a mpsc::UnboundedSender<Event>,
) -> Result<()> {
let mut contents: Vec<Content> = vec![];
if !prompt.backend_context.is_empty() {
contents = serde_json::from_str(&prompt.backend_context)?;
}
contents.push(Content {
role: "user".to_string(),
parts: vec![ContentParts::Text(prompt.text)],
});

let req = CompletionRequest {
contents: contents.clone(),
};

let res = reqwest::Client::new()
.post(format!(
"{url}/v1beta/{model}:streamGenerateContent?key={key}",
url = self.url,
model = Config::get(ConfigKey::Model),
key = self.token,
))
.json(&req)
.send()
.await?;

if !res.status().is_success() {
tracing::error!(
status = res.status().as_u16(),
"Failed to make completion request to Gemini"
);
bail!(format!(
"Failed to make completion request to Gemini, {}",
res.status().as_u16()
));
}
let stream = res.bytes_stream().map_err(convert_err);
let mut lines_reader = StreamReader::new(stream).lines();

let mut last_message = "".to_string();
while let Ok(line) = lines_reader.next_line().await {
if line.is_none() {
break;
}

let cleaned_line = line.unwrap().trim().to_string();
if !cleaned_line.starts_with("\"text\":") {
continue;
}

let ores: GenerateContentResponse =
serde_json::from_str(&format!("{{ {text} }}", text = cleaned_line)).unwrap();

if ores.text.is_empty() || ores.text == "\n" {
break;
}

last_message += &ores.text;
let msg = BackendResponse {
author: Author::Model,
text: ores.text,
done: false,
context: None,
};
tx.send(Event::BackendPromptResponse(msg))?;
}

contents.push(Content {
role: "model".to_string(),
parts: vec![ContentParts::Text(last_message.clone())],
});

let msg = BackendResponse {
author: Author::Model,
text: "".to_string(),
done: true,
context: Some(serde_json::to_string(&contents)?),
};
tx.send(Event::BackendPromptResponse(msg))?;

return Ok(());
}
}
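
A note on the line-oriented parsing in get_completion: when called without an alt=sse query parameter, the v1beta streamGenerateContent endpoint responds with a pretty-printed JSON array of chunk objects, so each chunk's "text" field arrives on a line of its own. The code scans for those lines, wraps each one in braces, and deserializes it as a GenerateContentResponse. An abridged, illustrative chunk (field layout assumed from the v1beta API at the time of writing):

    [{
      "candidates": [{
        "content": {
          "parts": [{
            "text": "Hello, how can I help?"
          }],
          "role": "model"
        }
      }]
    },
    ...]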