Skip to content

Commit

Permalink
keep handlers in place & write graphql
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Mar 3, 2025
1 parent 2d1a6f3 commit 04b45aa
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 55 deletions.
12 changes: 5 additions & 7 deletions crates/torii/server/src/handlers/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ pub(crate) const LOG_TARGET: &str = "torii::server::handlers::graphql";

#[derive(Debug)]
pub struct GraphQLHandler {
client_ip: IpAddr,
graphql_addr: Option<SocketAddr>,
pub(crate) graphql_addr: Option<SocketAddr>,
}

impl GraphQLHandler {
pub fn new(client_ip: IpAddr, graphql_addr: Option<SocketAddr>) -> Self {
Self { client_ip, graphql_addr }
pub fn new(graphql_addr: Option<SocketAddr>) -> Self {
Self { graphql_addr }
}
}

Expand All @@ -26,11 +25,10 @@ impl Handler for GraphQLHandler {
req.uri().path().starts_with("/graphql")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
async fn handle(&self, req: Request<Body>, client_addr: IpAddr) -> Response<Body> {
if let Some(addr) = self.graphql_addr {
let graphql_addr = format!("http://{}", addr);
match crate::proxy::GRAPHQL_PROXY_CLIENT.call(self.client_ip, &graphql_addr, req).await
{
match crate::proxy::GRAPHQL_PROXY_CLIENT.call(client_addr, &graphql_addr, req).await {
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "GraphQL proxy error: {:?}", _error);
Expand Down
9 changes: 4 additions & 5 deletions crates/torii/server/src/handlers/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ pub(crate) const LOG_TARGET: &str = "torii::server::handlers::grpc";

#[derive(Debug)]
pub struct GrpcHandler {
client_ip: IpAddr,
grpc_addr: Option<SocketAddr>,
}

impl GrpcHandler {
pub fn new(client_ip: IpAddr, grpc_addr: Option<SocketAddr>) -> Self {
Self { client_ip, grpc_addr }
pub fn new(grpc_addr: Option<SocketAddr>) -> Self {
Self { grpc_addr }
}
}

Expand All @@ -30,10 +29,10 @@ impl Handler for GrpcHandler {
.unwrap_or(false)
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
async fn handle(&self, req: Request<Body>, client_addr: IpAddr) -> Response<Body> {
if let Some(grpc_addr) = self.grpc_addr {
let grpc_addr = format!("http://{}", grpc_addr);
match crate::proxy::GRPC_PROXY_CLIENT.call(self.client_ip, &grpc_addr, req).await {
match crate::proxy::GRPC_PROXY_CLIENT.call(client_addr, &grpc_addr, req).await {
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "{:?}", _error);
Expand Down
7 changes: 4 additions & 3 deletions crates/torii/server/src/handlers/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlx::{Row, SqlitePool};
use tokio_tungstenite::tungstenite::Message;
use std::net::IpAddr;

use super::sql::map_row_to_json;
use super::Handler;
Expand Down Expand Up @@ -219,7 +220,7 @@ impl McpHandler {
AND m.name = ?
ORDER BY m.name, p.cid"
.to_string(),
None => "SELECT
_ => "SELECT
m.name as table_name,
p.*
FROM sqlite_master m
Expand All @@ -231,7 +232,7 @@ impl McpHandler {

let rows = match table_filter {
Some(table) => sqlx::query(&schema_query).bind(table).fetch_all(&*self.pool).await,
None => sqlx::query(&schema_query).fetch_all(&*self.pool).await,
_ => sqlx::query(&schema_query).fetch_all(&*self.pool).await,
};

match rows {
Expand Down Expand Up @@ -393,7 +394,7 @@ impl Handler for McpHandler {
.unwrap_or(false)
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
async fn handle(&self, req: Request<Body>, _client_addr: IpAddr) -> Response<Body> {
if hyper_tungstenite::is_upgrade_request(&req) {
let (response, websocket) = hyper_tungstenite::upgrade(req, None)
.expect("Failed to upgrade WebSocket connection");
Expand Down
3 changes: 2 additions & 1 deletion crates/torii/server/src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ pub mod sql;
pub mod static_files;

use hyper::{Body, Request, Response};
use std::net::IpAddr;

#[async_trait::async_trait]
pub trait Handler: Send + Sync + std::fmt::Debug {
// Check if this handler should handle the given request
fn should_handle(&self, req: &Request<Body>) -> bool;

// Handle the request
async fn handle(&self, req: Request<Body>) -> Response<Body>;
async fn handle(&self, req: Request<Body>, client_addr: IpAddr) -> Response<Body>;
}
3 changes: 2 additions & 1 deletion crates/torii/server/src/handlers/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use http::header::CONTENT_TYPE;
use hyper::{Body, Method, Request, Response, StatusCode};
use include_str;
use sqlx::{Column, Row, SqlitePool, TypeInfo};
use std::net::IpAddr;

use super::Handler;

Expand Down Expand Up @@ -112,7 +113,7 @@ impl Handler for SqlHandler {
req.uri().path().starts_with("/sql")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
async fn handle(&self, req: Request<Body>, _client_addr: IpAddr) -> Response<Body> {
self.handle_request(req).await
}
}
Expand Down
9 changes: 4 additions & 5 deletions crates/torii/server/src/handlers/static_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ pub(crate) const LOG_TARGET: &str = "torii::server::handlers::static";

#[derive(Debug)]
pub struct StaticHandler {
client_ip: IpAddr,
artifacts_addr: Option<SocketAddr>,
}

impl StaticHandler {
pub fn new(client_ip: IpAddr, artifacts_addr: Option<SocketAddr>) -> Self {
Self { client_ip, artifacts_addr }
pub fn new(artifacts_addr: Option<SocketAddr>) -> Self {
Self { artifacts_addr }
}
}

Expand All @@ -25,11 +24,11 @@ impl Handler for StaticHandler {
req.uri().path().starts_with("/static")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
async fn handle(&self, req: Request<Body>, client_addr: IpAddr) -> Response<Body> {
if let Some(artifacts_addr) = self.artifacts_addr {
let artifacts_addr = format!("http://{}", artifacts_addr);
match crate::proxy::GRAPHQL_PROXY_CLIENT
.call(self.client_ip, &artifacts_addr, req)
.call(client_addr, &artifacts_addr, req)
.await
{
Ok(response) => response,
Expand Down
52 changes: 19 additions & 33 deletions crates/torii/server/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,7 @@ lazy_static::lazy_static! {
pub struct Proxy {
addr: SocketAddr,
allowed_origins: Option<Vec<String>>,
grpc_addr: Option<SocketAddr>,
artifacts_addr: Option<SocketAddr>,
graphql_addr: Arc<RwLock<Option<SocketAddr>>>,
pool: Arc<SqlitePool>,
handlers: Option<Vec<Box<dyn Handler>>>,
handlers: Arc<RwLock<Vec<Box<dyn Handler>>>>,
}

impl Proxy {
Expand All @@ -80,20 +76,20 @@ impl Proxy {
artifacts_addr: Option<SocketAddr>,
pool: Arc<SqlitePool>,
) -> Self {
Self {
addr,
allowed_origins,
grpc_addr,
graphql_addr: Arc::new(RwLock::new(graphql_addr)),
artifacts_addr,
pool,
handlers: None
}
let handlers: Arc<RwLock<Vec<Box<dyn Handler>>>> = Arc::new(RwLock::new(vec![
Box::new(GraphQLHandler::new(graphql_addr)),
Box::new(GrpcHandler::new(grpc_addr)),
Box::new(McpHandler::new(pool.clone())),
Box::new(SqlHandler::new(pool.clone())),
Box::new(StaticHandler::new(artifacts_addr)),
]));

Self { addr, allowed_origins, handlers }
}

pub async fn set_graphql_addr(&self, addr: SocketAddr) {
let mut graphql_addr = self.graphql_addr.write().await;
*graphql_addr = Some(addr);
let mut handlers = self.handlers.write().await;
handlers[0] = Box::new(GraphQLHandler::new(Some(addr)));
}

pub async fn start(
Expand All @@ -102,14 +98,9 @@ impl Proxy {
) -> Result<(), hyper::Error> {
let addr = self.addr;
let allowed_origins = self.allowed_origins.clone();
let grpc_addr = self.grpc_addr;
let graphql_addr = self.graphql_addr.clone();
let artifacts_addr = self.artifacts_addr;
let pool = self.pool.clone();

let make_svc = make_service_fn(move |conn: &AddrStream| {
let remote_addr = conn.remote_addr().ip();
let handlers = self.handlers.clone().unwrap_or_default();

let cors = CorsLayer::new()
.max_age(DEFAULT_MAX_AGE)
Expand Down Expand Up @@ -144,14 +135,12 @@ impl Proxy {
),
});

let pool_clone = pool.clone();
let graphql_addr_clone = graphql_addr.clone();
let handlers = self.handlers.clone();
let service = ServiceBuilder::new().option_layer(cors).service_fn(move |req| {
let pool = pool_clone.clone();
let graphql_addr = graphql_addr_clone.clone();
let handlers = handlers.clone();
async move {
let graphql_addr = graphql_addr.read().await;
handle(remote_addr, grpc_addr, artifacts_addr, *graphql_addr, pool, req).await
let handlers = handlers.read().await;
handle(remote_addr, req, &handlers).await
}
});

Expand All @@ -170,15 +159,12 @@ impl Proxy {

async fn handle(
client_ip: IpAddr,
grpc_addr: Option<SocketAddr>,
artifacts_addr: Option<SocketAddr>,
graphql_addr: Option<SocketAddr>,
pool: Arc<SqlitePool>,
req: Request<Body>,
handlers: &Vec<Box<dyn Handler>>,
) -> Result<Response<Body>, Infallible> {
for handler in handlers {
for handler in handlers.iter() {
if handler.should_handle(&req) {
return Ok(handler.handle(req).await);
return Ok(handler.handle(req, client_ip).await);
}
}

Expand Down

0 comments on commit 04b45aa

Please sign in to comment.