From 296162d14168de79d039420e9836c02f04280da4 Mon Sep 17 00:00:00 2001 From: Larko <59736843+Larkooo@users.noreply.github.com> Date: Tue, 4 Mar 2025 13:11:39 +0800 Subject: [PATCH] opt(torii-server): initializing handlers (#3078) * refactorC(torii-server): initializing handlers * keep handlers in place & write graphql * fmt * c --- Cargo.lock | 2 +- crates/torii/server/src/handlers/graphql.rs | 13 ++--- crates/torii/server/src/handlers/grpc.rs | 10 ++-- crates/torii/server/src/handlers/mcp.rs | 9 +-- crates/torii/server/src/handlers/mod.rs | 6 +- crates/torii/server/src/handlers/sql.rs | 4 +- .../torii/server/src/handlers/static_files.rs | 13 ++--- crates/torii/server/src/proxy.rs | 58 +++++++------------ 8 files changed, 49 insertions(+), 66 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 795bb501e5..d0e59341b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8153,7 +8153,7 @@ dependencies = [ [[package]] name = "katana-explorer" -version = "1.2.1" +version = "1.2.2" dependencies = [ "anyhow", "rust-embed", diff --git a/crates/torii/server/src/handlers/graphql.rs b/crates/torii/server/src/handlers/graphql.rs index c51e7cccd7..0f11d6964e 100644 --- a/crates/torii/server/src/handlers/graphql.rs +++ b/crates/torii/server/src/handlers/graphql.rs @@ -8,14 +8,14 @@ use super::Handler; pub(crate) const LOG_TARGET: &str = "torii::server::handlers::graphql"; +#[derive(Debug)] pub struct GraphQLHandler { - client_ip: IpAddr, - graphql_addr: Option, + pub(crate) graphql_addr: Option, } impl GraphQLHandler { - pub fn new(client_ip: IpAddr, graphql_addr: Option) -> Self { - Self { client_ip, graphql_addr } + pub fn new(graphql_addr: Option) -> Self { + Self { graphql_addr } } } @@ -25,11 +25,10 @@ impl Handler for GraphQLHandler { req.uri().path().starts_with("/graphql") } - async fn handle(&self, req: Request) -> Response { + async fn handle(&self, req: Request, client_addr: IpAddr) -> Response { if let Some(addr) = self.graphql_addr { let graphql_addr = format!("http://{}", addr); - match crate::proxy::GRAPHQL_PROXY_CLIENT.call(self.client_ip, &graphql_addr, req).await - { + match crate::proxy::GRAPHQL_PROXY_CLIENT.call(client_addr, &graphql_addr, req).await { Ok(response) => response, Err(_error) => { error!(target: LOG_TARGET, "GraphQL proxy error: {:?}", _error); diff --git a/crates/torii/server/src/handlers/grpc.rs b/crates/torii/server/src/handlers/grpc.rs index befa3e56a4..2c8a914bdb 100644 --- a/crates/torii/server/src/handlers/grpc.rs +++ b/crates/torii/server/src/handlers/grpc.rs @@ -8,14 +8,14 @@ use super::Handler; pub(crate) const LOG_TARGET: &str = "torii::server::handlers::grpc"; +#[derive(Debug)] pub struct GrpcHandler { - client_ip: IpAddr, grpc_addr: Option, } impl GrpcHandler { - pub fn new(client_ip: IpAddr, grpc_addr: Option) -> Self { - Self { client_ip, grpc_addr } + pub fn new(grpc_addr: Option) -> Self { + Self { grpc_addr } } } @@ -29,10 +29,10 @@ impl Handler for GrpcHandler { .unwrap_or(false) } - async fn handle(&self, req: Request) -> Response { + async fn handle(&self, req: Request, client_addr: IpAddr) -> Response { if let Some(grpc_addr) = self.grpc_addr { let grpc_addr = format!("http://{}", grpc_addr); - match crate::proxy::GRPC_PROXY_CLIENT.call(self.client_ip, &grpc_addr, req).await { + match crate::proxy::GRPC_PROXY_CLIENT.call(client_addr, &grpc_addr, req).await { Ok(response) => response, Err(_error) => { error!(target: LOG_TARGET, "{:?}", _error); diff --git a/crates/torii/server/src/handlers/mcp.rs b/crates/torii/server/src/handlers/mcp.rs index 22c9e4a64b..c6b64d7e74 100644 --- a/crates/torii/server/src/handlers/mcp.rs +++ b/crates/torii/server/src/handlers/mcp.rs @@ -1,3 +1,4 @@ +use std::net::IpAddr; use std::sync::Arc; use futures_util::{SinkExt, StreamExt}; @@ -75,7 +76,7 @@ struct ResourceCapabilities { list_changed: bool, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct McpHandler { pool: Arc, } @@ -219,7 +220,7 @@ impl McpHandler { AND m.name = ? ORDER BY m.name, p.cid" .to_string(), - None => "SELECT + _ => "SELECT m.name as table_name, p.* FROM sqlite_master m @@ -231,7 +232,7 @@ impl McpHandler { let rows = match table_filter { Some(table) => sqlx::query(&schema_query).bind(table).fetch_all(&*self.pool).await, - None => sqlx::query(&schema_query).fetch_all(&*self.pool).await, + _ => sqlx::query(&schema_query).fetch_all(&*self.pool).await, }; match rows { @@ -393,7 +394,7 @@ impl Handler for McpHandler { .unwrap_or(false) } - async fn handle(&self, req: Request) -> Response { + async fn handle(&self, req: Request, _client_addr: IpAddr) -> Response { if hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(req, None) .expect("Failed to upgrade WebSocket connection"); diff --git a/crates/torii/server/src/handlers/mod.rs b/crates/torii/server/src/handlers/mod.rs index d40ece8ad0..aa73962e31 100644 --- a/crates/torii/server/src/handlers/mod.rs +++ b/crates/torii/server/src/handlers/mod.rs @@ -4,13 +4,15 @@ pub mod mcp; pub mod sql; pub mod static_files; +use std::net::IpAddr; + use hyper::{Body, Request, Response}; #[async_trait::async_trait] -pub trait Handler: Send + Sync { +pub trait Handler: Send + Sync + std::fmt::Debug { // Check if this handler should handle the given request fn should_handle(&self, req: &Request) -> bool; // Handle the request - async fn handle(&self, req: Request) -> Response; + async fn handle(&self, req: Request, client_addr: IpAddr) -> Response; } diff --git a/crates/torii/server/src/handlers/sql.rs b/crates/torii/server/src/handlers/sql.rs index 563786f495..2959b58462 100644 --- a/crates/torii/server/src/handlers/sql.rs +++ b/crates/torii/server/src/handlers/sql.rs @@ -1,3 +1,4 @@ +use std::net::IpAddr; use std::sync::Arc; use base64::engine::general_purpose::STANDARD; @@ -9,6 +10,7 @@ use sqlx::{Column, Row, SqlitePool, TypeInfo}; use super::Handler; +#[derive(Debug)] pub struct SqlHandler { pool: Arc, } @@ -111,7 +113,7 @@ impl Handler for SqlHandler { req.uri().path().starts_with("/sql") } - async fn handle(&self, req: Request) -> Response { + async fn handle(&self, req: Request, _client_addr: IpAddr) -> Response { self.handle_request(req).await } } diff --git a/crates/torii/server/src/handlers/static_files.rs b/crates/torii/server/src/handlers/static_files.rs index 631b032e11..13c12d4838 100644 --- a/crates/torii/server/src/handlers/static_files.rs +++ b/crates/torii/server/src/handlers/static_files.rs @@ -7,14 +7,14 @@ use super::Handler; pub(crate) const LOG_TARGET: &str = "torii::server::handlers::static"; +#[derive(Debug)] pub struct StaticHandler { - client_ip: IpAddr, artifacts_addr: Option, } impl StaticHandler { - pub fn new(client_ip: IpAddr, artifacts_addr: Option) -> Self { - Self { client_ip, artifacts_addr } + pub fn new(artifacts_addr: Option) -> Self { + Self { artifacts_addr } } } @@ -24,13 +24,10 @@ impl Handler for StaticHandler { req.uri().path().starts_with("/static") } - async fn handle(&self, req: Request) -> Response { + async fn handle(&self, req: Request, client_addr: IpAddr) -> Response { if let Some(artifacts_addr) = self.artifacts_addr { let artifacts_addr = format!("http://{}", artifacts_addr); - match crate::proxy::GRAPHQL_PROXY_CLIENT - .call(self.client_ip, &artifacts_addr, req) - .await - { + match crate::proxy::GRAPHQL_PROXY_CLIENT.call(client_addr, &artifacts_addr, req).await { Ok(response) => response, Err(_error) => { error!(target: LOG_TARGET, "{:?}", _error); diff --git a/crates/torii/server/src/proxy.rs b/crates/torii/server/src/proxy.rs index bf1c57ae12..cd99e3caa1 100644 --- a/crates/torii/server/src/proxy.rs +++ b/crates/torii/server/src/proxy.rs @@ -64,10 +64,7 @@ lazy_static::lazy_static! { pub struct Proxy { addr: SocketAddr, allowed_origins: Option>, - grpc_addr: Option, - artifacts_addr: Option, - graphql_addr: Arc>>, - pool: Arc, + handlers: Arc>>>, } impl Proxy { @@ -79,19 +76,20 @@ impl Proxy { artifacts_addr: Option, pool: Arc, ) -> Self { - Self { - addr, - allowed_origins, - grpc_addr, - graphql_addr: Arc::new(RwLock::new(graphql_addr)), - artifacts_addr, - pool, - } + let handlers: Arc>>> = Arc::new(RwLock::new(vec![ + Box::new(GraphQLHandler::new(graphql_addr)), + Box::new(GrpcHandler::new(grpc_addr)), + Box::new(McpHandler::new(pool.clone())), + Box::new(SqlHandler::new(pool.clone())), + Box::new(StaticHandler::new(artifacts_addr)), + ])); + + Self { addr, allowed_origins, handlers } } pub async fn set_graphql_addr(&self, addr: SocketAddr) { - let mut graphql_addr = self.graphql_addr.write().await; - *graphql_addr = Some(addr); + let mut handlers = self.handlers.write().await; + handlers[0] = Box::new(GraphQLHandler::new(Some(addr))); } pub async fn start( @@ -100,13 +98,10 @@ impl Proxy { ) -> Result<(), hyper::Error> { let addr = self.addr; let allowed_origins = self.allowed_origins.clone(); - let grpc_addr = self.grpc_addr; - let graphql_addr = self.graphql_addr.clone(); - let artifacts_addr = self.artifacts_addr; - let pool = self.pool.clone(); let make_svc = make_service_fn(move |conn: &AddrStream| { let remote_addr = conn.remote_addr().ip(); + let cors = CorsLayer::new() .max_age(DEFAULT_MAX_AGE) .allow_methods([Method::GET, Method::POST]) @@ -140,14 +135,12 @@ impl Proxy { ), }); - let pool_clone = pool.clone(); - let graphql_addr_clone = graphql_addr.clone(); + let handlers = self.handlers.clone(); let service = ServiceBuilder::new().option_layer(cors).service_fn(move |req| { - let pool = pool_clone.clone(); - let graphql_addr = graphql_addr_clone.clone(); + let handlers = handlers.clone(); async move { - let graphql_addr = graphql_addr.read().await; - handle(remote_addr, grpc_addr, artifacts_addr, *graphql_addr, pool, req).await + let handlers = handlers.read().await; + handle(remote_addr, req, &handlers).await } }); @@ -166,23 +159,12 @@ impl Proxy { async fn handle( client_ip: IpAddr, - grpc_addr: Option, - artifacts_addr: Option, - graphql_addr: Option, - pool: Arc, req: Request, + handlers: &[Box], ) -> Result, Infallible> { - let handlers: Vec> = vec![ - Box::new(SqlHandler::new(pool.clone())), - Box::new(GraphQLHandler::new(client_ip, graphql_addr)), - Box::new(GrpcHandler::new(client_ip, grpc_addr)), - Box::new(StaticHandler::new(client_ip, artifacts_addr)), - Box::new(McpHandler::new(pool.clone())), - ]; - - for handler in handlers { + for handler in handlers.iter() { if handler.should_handle(&req) { - return Ok(handler.handle(req).await); + return Ok(handler.handle(req, client_ip).await); } }