Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
christos-h committed Feb 26, 2025
1 parent a332eeb commit 6c14d60
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 59 deletions.
1 change: 1 addition & 0 deletions examples/agent/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions examples/agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ reqwest = { version = "0.11.24", default-features = false, features = [
"rustls-tls",
] }
rig-core = "0.8.0"
schemars = "0.8.21"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
tokio = { version = "1.25.0", features = ["macros", "rt-multi-thread"] }
Expand Down
172 changes: 113 additions & 59 deletions examples/agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use rig::{
providers::openai,
tool::Tool,
};
use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_json::json;
use url::Url;
Expand Down Expand Up @@ -148,11 +149,8 @@ async fn main() {

let openai = openai::Client::from_env();
let model = openai.completion_model(&opt.model);
let node_service = LineraNodeService::<openai::CompletionModel>::new(
opt.node_service_url.parse().unwrap(),
model.clone(),
)
.unwrap();
let node_service =
LineraNodeService::new(opt.node_service_url.parse().unwrap(), opt.model.clone()).unwrap();

let system_graphql_def = node_service.system_graphql_definition().await.unwrap();
let graphql_context = format!(
Expand All @@ -171,16 +169,18 @@ async fn main() {
chat(agent).await;
}

#[derive(Debug, Deserialize)]
#[serde(tag = "tag ")]
enum LineraNodeServiceArgs {
#[derive(Debug, Deserialize, JsonSchema)]
struct LineraNodeServiceArgs {
input: QueryType,
}

#[derive(Debug, Deserialize, JsonSchema)]
enum QueryType {
QuerySystem {
query: String,
},
AddApplication {
application_id: String,
},
QueryApplication {
chain_id: String,
application_id: String,
query: String,
},
Expand All @@ -192,18 +192,22 @@ struct LineraNodeServiceOutput {
errors: Option<Vec<serde_json::Value>>,
}

struct LineraNodeService<M: CompletionModel> {
struct LineraNodeService {
url: Url,
client: Client,
model: M,
model_name: String,
ensemble_tx: tokio::sync::mpsc::Sender<EnsembleQuery>,
}

impl<M: CompletionModel> LineraNodeService<M> {
fn new(url: Url, model: M) -> Result<Self> {
impl LineraNodeService {
fn new(url: Url, model_name: String) -> Result<Self> {
let (tx, rx) = tokio::sync::mpsc::channel(100);
tokio::spawn(run_ensemble(rx));
Ok(Self {
url,
client: Client::new(),
model,
model_name,
ensemble_tx: tx,
})
}

Expand All @@ -225,7 +229,7 @@ impl<M: CompletionModel> LineraNodeService<M> {
}
}

impl<M: CompletionModel> Tool for LineraNodeService<M> {
impl Tool for LineraNodeService {
const NAME: &'static str = "Linera";
type Error = reqwest::Error;
type Args = LineraNodeServiceArgs;
Expand All @@ -235,22 +239,14 @@ impl<M: CompletionModel> Tool for LineraNodeService<M> {
ToolDefinition {
name: "Linera".to_string(),
description: "Interact with a Linera wallet via GraphQL".to_string(),
parameters: json!({

"tag": { "type": "object", "oneOf": [
{ "properties": { "query": { "type": "string" } }, "required": ["query"] },
{ "properties": { "application_id": { "type": "string" } }, "required": ["application_id"] },
{ "properties": { "application_id": { "type": "string" }, "query": { "type": "string" } }, "required": ["application_id", "query"] }
]}

}),
parameters: serde_json::to_value(schema_for!(LineraNodeServiceArgs)).unwrap(),
}
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
eprintln!("Args: {:?}", args);
match args {
LineraNodeServiceArgs::QuerySystem { query } => {
match args.input {
QueryType::QuerySystem { query } => {
let response = self
.client
.post(self.url.clone())
Expand All @@ -259,26 +255,23 @@ impl<M: CompletionModel> Tool for LineraNodeService<M> {
.await?;
response.json().await
}
LineraNodeServiceArgs::AddApplication { application_id } => {
unimplemented!();
}
LineraNodeServiceArgs::QueryApplication {
QueryType::QueryApplication {
chain_id,
application_id,
query,
} => {
let url = self.url.join(&application_id).unwrap();
let graphql_def = self.get_graphql_definition(url).await.unwrap();
let linera_app_service = LineraApplicationService::new(self.url.clone(), application_id, self.client.clone());
println!("Got here.");
let agent = AgentBuilder::new(self.model.clone())
.preamble(PREAMBLE)
.context(LINERA_CONTEXT)
.context(&graphql_def)
.tool(linera_app_service)
.build();
let runtime = tokio::runtime::Runtime::new().unwrap();
let response =
runtime.block_on(async { agent.chat(query.as_str(), vec![]).await.unwrap() });
let (tx, rx) = tokio::sync::oneshot::channel::<String>();
let ensemble_query = EnsembleQuery {
url: self.url.clone(),
chain_id,
application_id,
query,
model_name: self.model_name.clone(),
graphql_def: self.get_graphql_definition(self.url.clone()).await.unwrap(),
sender: tx,
};
self.ensemble_tx.send(ensemble_query).await.unwrap();
let response = rx.await.unwrap();
Ok(LineraNodeServiceOutput {
data: json!(response),
errors: None,
Expand All @@ -288,27 +281,88 @@ impl<M: CompletionModel> Tool for LineraNodeService<M> {
}
}

struct EnsembleQuery {
url: Url,
chain_id: String,
application_id: String,
query: String,
model_name: String,
graphql_def: String,
sender: tokio::sync::oneshot::Sender<String>,
}

async fn run_ensemble(mut rx: tokio::sync::mpsc::Receiver<EnsembleQuery>) {
while let Some(ensemble_query) = rx.recv().await {
let linera_app_service = LineraApplicationService::new(
ensemble_query.url,
ensemble_query.chain_id,
ensemble_query.application_id,
Client::new(),
);
let openai = openai::Client::from_env();
let model = openai.completion_model(&ensemble_query.model_name);
let agent = AgentBuilder::new(model)
.preamble(PREAMBLE)
.context(LINERA_CONTEXT)
.context(&ensemble_query.graphql_def)
.tool(linera_app_service)
.build();
let mut backoff = tokio::time::Duration::from_secs(2);
let mut attempts = 0;
let response = loop {
match agent.chat(ensemble_query.query.as_str(), vec![]).await {
Ok(response) => break response,
Err(e) => {
attempts += 1;
eprintln!(
"Error occurred: {}. Retrying in {} seconds...",
e,
backoff.as_secs()
);
if attempts >= 5 {
panic!("Failed after 5 attempts");
}
tokio::time::sleep(backoff).await;
backoff *= 2;
}
}
};

if let Err(e) = ensemble_query.sender.send(response) {
eprintln!("Error sending response from ensemble: {}", e);
}
}
}

#[derive(Debug, Deserialize)]
struct LineraApplicationServiceArgs {
query: String,
}

struct LineraApplicationService {
name: String,
url: Url,
client: Client,
}

impl LineraApplicationService {
fn new(node_url: Url, application_id: String, client: Client) -> Self {
Self {
url: node_url.join(&application_id).unwrap(),
client,
}
fn new(node_url: Url, chain_id: String, application_id: String, client: Client) -> Self {
let name = format!("Application {} Tool", application_id);
let url = node_url
.join("chains")
.unwrap()
.join(&chain_id)
.unwrap()
.join("applications")
.unwrap()
.join(&application_id)
.unwrap();
Self { name, url, client }
}
}

impl Tool for LineraApplicationService {
const NAME: &'static str = ""; // Don't need this.
const NAME: &'static str = "Linera Application Service";
type Error = reqwest::Error;
type Args = LineraApplicationServiceArgs; // GraphQL
type Output = LineraNodeServiceOutput;
Expand All @@ -335,14 +389,14 @@ impl Tool for LineraApplicationService {
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let query = args.query;
let response = self
.client
.post(self.url.clone())
.json(&json!({ "query": query }))
.send()
.await?;
response.json().await
let query = args.query;
let response = self
.client
.post(self.url.clone())
.json(&json!({ "query": query }))
.send()
.await?;
response.json().await
}
}

Expand Down

0 comments on commit 6c14d60

Please sign in to comment.