Skip to content

Commit

Permalink
cleanup initializing tools
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Mar 4, 2025
1 parent 828684e commit aedb9ca
Showing 1 changed file with 69 additions and 51 deletions.
120 changes: 69 additions & 51 deletions crates/torii/server/src/handlers/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde_json::{json, Number, Value};
use sqlx::{Row, SqlitePool};
use tokio::sync::{broadcast, RwLock};
use tokio_tungstenite::tungstenite::Message;
use tracing::warn;
use uuid::Uuid;

use super::sql::map_row_to_json;
Expand Down Expand Up @@ -84,21 +85,68 @@ struct ResourceCapabilities {
#[derive(Clone, Debug)]
struct SseSession {
tx: broadcast::Sender<JsonRpcResponse>,
session_id: String,
_session_id: String,
}

#[derive(Clone, Debug)]
struct Tool {
name: &'static str,
description: &'static str,
input_schema: Value,
}

#[derive(Clone, Debug)]
struct Resource {
name: &'static str,
}

#[derive(Clone, Debug)]
pub struct McpHandler {
pool: Arc<SqlitePool>,
// Map of session IDs to SSE sessions
sse_sessions: Arc<RwLock<std::collections::HashMap<String, SseSession>>>,
tools: Vec<Tool>,
resources: Vec<Resource>,
}

impl McpHandler {
pub fn new(pool: Arc<SqlitePool>) -> Self {
let tools = vec![
Tool {
name: "query",
description: "Execute a SQL query on the database",
input_schema: json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "SQL query to execute"
}
},
"required": ["query"]
}),
},
Tool {
name: "schema",
description: "Retrieve the database schema including tables, columns, and their types",
input_schema: json!({
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Optional table name to get schema for. If omitted, returns schema for all tables."
}
}
}),
},
];

let resources = vec![]; // Add resources as needed

Self {
pool,
sse_sessions: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
tools,
resources,
}
}

Expand Down Expand Up @@ -141,39 +189,17 @@ impl McpHandler {
}

fn handle_tools_list(&self, id: Value) -> JsonRpcResponse {
let tools_json: Vec<Value> = self.tools.iter().map(|tool| {
json!({
"name": tool.name,
"description": tool.description,
"inputSchema": tool.input_schema
})
}).collect();

JsonRpcResponse::ok(
id,
json!({
"tools": [
{
"name": "query",
// "description": "Execute a SQL query on the database",
"inputSchema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "SQL query to execute"
}
},
"required": ["query"]
}
},
{
"name": "schema",
// "description": "Retrieve the database schema including tables, columns, and their types",
"inputSchema": {
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Optional table name to get schema for. If omitted, returns schema for all tables."
}
}
}
}
]
}),
json!({ "tools": tools_json }),
)
}

Expand Down Expand Up @@ -213,7 +239,7 @@ impl McpHandler {
if let Err(e) =
write.send(Message::Text(serde_json::to_string(&response).unwrap())).await
{
eprintln!("Error sending message: {}", e);
warn!("Error sending message: {}", e);
break;
}
}
Expand All @@ -233,7 +259,7 @@ impl McpHandler {
let mut sessions = self.sse_sessions.write().await;
sessions.insert(
session_id.clone(),
SseSession { tx: tx.clone(), session_id: session_id.clone() },
SseSession { tx: tx.clone(), _session_id: session_id.clone() },
);
}

Expand Down Expand Up @@ -261,7 +287,7 @@ impl McpHandler {
))
}
Err(e) => {
eprintln!("Error serializing message: {}", e);
warn!("Error serializing message: {}", e);
// Format error event with proper SSE format
Some((
Ok::<_, hyper::Error>(hyper::body::Bytes::from(format!(
Expand Down Expand Up @@ -346,12 +372,7 @@ impl McpHandler {

// Forward the response to the SSE channel
if let Err(e) = tx.send(response.clone()) {
eprintln!("Error forwarding response to SSE: {}", e);
} else {
eprintln!(
"Successfully sent response to SSE channel: {:?}",
response.id
);
warn!("Error forwarding response to SSE: {}", e);
}

Response::builder()
Expand Down Expand Up @@ -387,12 +408,7 @@ impl McpHandler {

// Forward the response to the SSE channel
if let Err(e) = tx.send(response.clone()) {
eprintln!("Error forwarding response to SSE: {}", e);
} else {
eprintln!(
"Successfully sent response to SSE channel: {:?}",
response.id
);
warn!("Error forwarding response to SSE: {}", e);
}

Response::builder()
Expand Down Expand Up @@ -578,11 +594,13 @@ impl McpHandler {

// New method to handle resources/list
fn handle_resources_list(&self, id: Value) -> JsonRpcResponse {
let resources_json: Vec<Value> = self.resources.iter().map(|resource| {
json!({ "name": resource.name })
}).collect();

JsonRpcResponse::ok(
id,
json!({
"resources": []
}),
json!({ "resources": resources_json }),
)
}

Expand Down

0 comments on commit aedb9ca

Please sign in to comment.