From 5777c0076576de369bba7fe9ffd73abe38447dd3 Mon Sep 17 00:00:00 2001 From: Larko <59736843+Larkooo@users.noreply.github.com> Date: Thu, 19 Sep 2024 10:38:07 -0400 Subject: [PATCH] opt(torii-grpc): parallelize queries (#2443) * opt(torii-grpc): parallelize queries * feat :parallelize keys clause * feat: parallelize composite * opt hashmap * optimize fetching * itertools * refactor: queries & keys clause * fmt * fix: multiple mmeber clauses same column * fmt * clean --------- Co-authored-by: glihm --- Cargo.lock | 1 + crates/torii/grpc/Cargo.toml | 1 + crates/torii/grpc/src/server/mod.rs | 411 +++++++++++++--------------- 3 files changed, 186 insertions(+), 227 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 380d8e681e..bb7394efc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14820,6 +14820,7 @@ dependencies = [ "futures-util", "hex", "hyper 0.14.30", + "itertools 0.12.1", "katana-runner", "num-traits 0.2.19", "parking_lot 0.12.3", diff --git a/crates/torii/grpc/Cargo.toml b/crates/torii/grpc/Cargo.toml index 64b0a90ecb..c4eb6021e7 100644 --- a/crates/torii/grpc/Cargo.toml +++ b/crates/torii/grpc/Cargo.toml @@ -13,6 +13,7 @@ futures.workspace = true num-traits.workspace = true parking_lot.workspace = true rayon.workspace = true +itertools.workspace = true starknet-crypto.workspace = true starknet.workspace = true thiserror.workspace = true diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index b528173968..3b497c6def 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -21,6 +21,7 @@ use proto::world::{ RetrieveEventsRequest, RetrieveEventsResponse, SubscribeModelsRequest, SubscribeModelsResponse, UpdateEntitiesSubscriptionRequest, }; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sqlx::prelude::FromRow; use sqlx::sqlite::SqliteRow; use sqlx::{Pool, Row, Sqlite}; @@ -183,29 +184,16 @@ impl DojoWorld { async fn entities_all( &self, + table: &str, + model_relation_table: &str, + entity_relation_column: &str, limit: u32, offset: u32, ) -> Result<(Vec, u32), Error> { self.query_by_hashed_keys( - ENTITIES_TABLE, - ENTITIES_MODEL_RELATION_TABLE, - ENTITIES_ENTITY_RELATION_COLUMN, - None, - Some(limit), - Some(offset), - ) - .await - } - - async fn event_messages_all( - &self, - limit: u32, - offset: u32, - ) -> Result<(Vec, u32), Error> { - self.query_by_hashed_keys( - EVENT_MESSAGES_TABLE, - EVENT_MESSAGES_MODEL_RELATION_TABLE, - EVENT_MESSAGES_ENTITY_RELATION_COLUMN, + table, + model_relation_table, + entity_relation_column, None, Some(limit), Some(offset), @@ -227,6 +215,93 @@ impl DojoWorld { row_events.iter().map(map_row_to_event).collect() } + async fn fetch_entities( + &self, + table: &str, + entity_relation_column: &str, + entities: Vec<(String, String)>, + ) -> Result, Error> { + // Group entities by their model combinations + let mut model_groups: HashMap> = HashMap::new(); + for (entity_id, models_str) in entities { + model_groups.entry(models_str).or_default().push(entity_id); + } + + let mut all_entities = Vec::new(); + + let mut tx = self.pool.begin().await?; + + // Create a temporary table to store entity IDs due to them potentially exceeding + // SQLite's parameters limit which is 999 + sqlx::query( + "CREATE TEMPORARY TABLE temp_entity_ids (id TEXT PRIMARY KEY, model_group TEXT)", + ) + .execute(&mut *tx) + .await?; + + // Insert all entity IDs into the temporary table + for (model_ids, entity_ids) in &model_groups { + for chunk in entity_ids.chunks(999) { + let placeholders 
= chunk.iter().map(|_| "(?, ?)").collect::>().join(","); + let query = format!( + "INSERT INTO temp_entity_ids (id, model_group) VALUES {}", + placeholders + ); + let mut query = sqlx::query(&query); + for id in chunk { + query = query.bind(id).bind(model_ids); + } + query.execute(&mut *tx).await?; + } + } + + for (models_str, _) in model_groups { + let model_ids = + models_str.split(',').map(|id| Felt::from_str(id).unwrap()).collect::>(); + let schemas = + self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); + + let (entity_query, arrays_queries, _) = build_sql_query( + &schemas, + table, + entity_relation_column, + Some(&format!( + "[{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)" + )), + Some(&format!( + "[{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)" + )), + None, + None, + )?; + + let rows = sqlx::query(&entity_query).bind(&models_str).fetch_all(&mut *tx).await?; + + let mut arrays_rows = HashMap::new(); + for (name, array_query) in arrays_queries { + let array_rows = + sqlx::query(&array_query).bind(&models_str).fetch_all(&mut *tx).await?; + arrays_rows.insert(name, array_rows); + } + + let arrays_rows = Arc::new(arrays_rows); + let schemas = Arc::new(schemas); + + let group_entities: Result, Error> = rows + .par_iter() + .map(|row| map_row_to_entity(row, &arrays_rows, (*schemas).clone())) + .collect(); + + all_entities.extend(group_entities?); + } + + sqlx::query("DROP TABLE temp_entity_ids").execute(&mut *tx).await?; + + tx.commit().await?; + + Ok(all_entities) + } + pub(crate) async fn query_by_hashed_keys( &self, table: &str, @@ -265,7 +340,7 @@ impl DojoWorld { return Ok((Vec::new(), 0)); } - // query to filter with limit and offset + // Query to get entity IDs and their model IDs let mut query = format!( r#" SELECT {table}.id, group_concat({model_relation_table}.model_id) as model_ids @@ -288,38 +363,7 @@ impl DojoWorld { let db_entities: Vec<(String, String)> = sqlx::query_as(&query).bind(limit).bind(offset).fetch_all(&self.pool).await?; - let mut entities = Vec::with_capacity(db_entities.len()); - for (entity_id, models_str) in db_entities { - let model_ids: Vec = models_str - .split(',') - .map(Felt::from_str) - .collect::>() - .map_err(ParseError::FromStr)?; - let schemas = - self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); - - let (entity_query, arrays_queries, _) = build_sql_query( - &schemas, - table, - entity_relation_column, - Some(&format!("{table}.id = ?")), - Some(&format!("{table}.id = ?")), - None, - None, - )?; - - let row = - sqlx::query(&entity_query).bind(entity_id.clone()).fetch_one(&self.pool).await?; - let mut arrays_rows = HashMap::new(); - for (name, query) in arrays_queries { - let rows = - sqlx::query(&query).bind(entity_id.clone()).fetch_all(&self.pool).await?; - arrays_rows.insert(name, rows); - } - - entities.push(map_row_to_entity(&row, &arrays_rows, schemas.clone())?); - } - + let entities = self.fetch_entities(table, entity_relation_column, db_entities).await?; Ok((entities, total_count)) } @@ -432,36 +476,7 @@ impl DojoWorld { .fetch_all(&self.pool) .await?; - let mut entities = Vec::with_capacity(db_entities.len()); - for (entity_id, models_strs) in &db_entities { - let model_ids: Vec = models_strs - .split(',') - .map(Felt::from_str) - .collect::>() - .map_err(ParseError::FromStr)?; - let schemas = - self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); - - let (entity_query, arrays_queries, _) = 
build_sql_query( - &schemas, - table, - entity_relation_column, - Some(&format!("{table}.id = ?")), - Some(&format!("{table}.id = ?")), - None, - None, - )?; - - let row = sqlx::query(&entity_query).bind(entity_id).fetch_one(&self.pool).await?; - let mut arrays_rows = HashMap::new(); - for (name, query) in arrays_queries { - let rows = sqlx::query(&query).bind(entity_id).fetch_all(&self.pool).await?; - arrays_rows.insert(name, rows); - } - - entities.push(map_row_to_entity(&row, &arrays_rows, schemas.clone())?); - } - + let entities = self.fetch_entities(table, entity_relation_column, db_entities).await?; Ok((entities, total_count)) } @@ -582,11 +597,16 @@ impl DojoWorld { arrays_rows.insert(name, rows); } - let entities_collection = db_entities - .iter() - .map(|row| map_row_to_entity(row, &arrays_rows, schemas.clone())) - .collect::, Error>>()?; - Ok((entities_collection, total_count)) + let arrays_rows = Arc::new(arrays_rows); + let entities_collection: Result, Error> = db_entities + .par_iter() + .map(|row| { + let schemas_clone = schemas.clone(); + let arrays_rows_clone = arrays_rows.clone(); + map_row_to_entity(row, &arrays_rows_clone, schemas_clone) + }) + .collect(); + Ok((entities_collection?, total_count)) } pub(crate) async fn query_by_composite( @@ -644,36 +664,7 @@ impl DojoWorld { let db_entities: Vec<(String, String)> = db_query.fetch_all(&self.pool).await?; - let mut entities = Vec::with_capacity(db_entities.len()); - for (entity_id, models_str) in &db_entities { - let model_ids: Vec = models_str - .split(',') - .map(Felt::from_str) - .collect::>() - .map_err(ParseError::FromStr)?; - let schemas = - self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); - - let (entity_query, arrays_queries, _) = build_sql_query( - &schemas, - table, - entity_relation_column, - Some(&format!("[{table}].id = ?")), - Some(&format!("[{table}].id = ?")), - None, - None, - )?; - - let row = sqlx::query(&entity_query).bind(entity_id).fetch_one(&self.pool).await?; - let mut arrays_rows = HashMap::new(); - for (name, query) in arrays_queries { - let rows = sqlx::query(&query).bind(entity_id).fetch_all(&self.pool).await?; - arrays_rows.insert(name, rows); - } - - entities.push(map_row_to_entity(&row, &arrays_rows, schemas.clone())?); - } - + let entities = self.fetch_entities(table, entity_relation_column, db_entities).await?; Ok((entities, total_count)) } @@ -736,39 +727,47 @@ impl DojoWorld { async fn retrieve_entities( &self, + table: &str, + model_relation_table: &str, + entity_relation_column: &str, query: proto::types::Query, ) -> Result { let (entities, total_count) = match query.clause { - None => self.entities_all(query.limit, query.offset).await?, + None => { + self.entities_all( + table, + model_relation_table, + entity_relation_column, + query.limit, + query.offset, + ) + .await? + } Some(clause) => { let clause_type = clause.clause_type.ok_or(QueryError::MissingParam("clause_type".into()))?; match clause_type { ClauseType::HashedKeys(hashed_keys) => { - if hashed_keys.hashed_keys.is_empty() { - return Err(QueryError::MissingParam("ids".into()).into()); - } - self.query_by_hashed_keys( - ENTITIES_TABLE, - ENTITIES_MODEL_RELATION_TABLE, - ENTITIES_ENTITY_RELATION_COLUMN, - Some(hashed_keys), + table, + model_relation_table, + entity_relation_column, + if hashed_keys.hashed_keys.is_empty() { + None + } else { + Some(hashed_keys) + }, Some(query.limit), Some(query.offset), ) .await? 
} ClauseType::Keys(keys) => { - if keys.keys.is_empty() { - return Err(QueryError::MissingParam("keys".into()).into()); - } - self.query_by_keys( - ENTITIES_TABLE, - ENTITIES_MODEL_RELATION_TABLE, - ENTITIES_ENTITY_RELATION_COLUMN, + table, + model_relation_table, + entity_relation_column, &keys, Some(query.limit), Some(query.offset), @@ -777,9 +776,9 @@ impl DojoWorld { } ClauseType::Member(member) => { self.query_by_member( - ENTITIES_TABLE, - ENTITIES_MODEL_RELATION_TABLE, - ENTITIES_ENTITY_RELATION_COLUMN, + table, + model_relation_table, + entity_relation_column, member, Some(query.limit), Some(query.offset), @@ -788,9 +787,9 @@ impl DojoWorld { } ClauseType::Composite(composite) => { self.query_by_composite( - ENTITIES_TABLE, - ENTITIES_MODEL_RELATION_TABLE, - ENTITIES_ENTITY_RELATION_COLUMN, + table, + model_relation_table, + entity_relation_column, composite, Some(query.limit), Some(query.offset), @@ -813,76 +812,6 @@ impl DojoWorld { .await } - async fn retrieve_event_messages( - &self, - query: proto::types::Query, - ) -> Result { - let (entities, total_count) = match query.clause { - None => self.event_messages_all(query.limit, query.offset).await?, - Some(clause) => { - let clause_type = - clause.clause_type.ok_or(QueryError::MissingParam("clause_type".into()))?; - - match clause_type { - ClauseType::HashedKeys(hashed_keys) => { - if hashed_keys.hashed_keys.is_empty() { - return Err(QueryError::MissingParam("ids".into()).into()); - } - - self.query_by_hashed_keys( - EVENT_MESSAGES_TABLE, - EVENT_MESSAGES_MODEL_RELATION_TABLE, - EVENT_MESSAGES_ENTITY_RELATION_COLUMN, - Some(hashed_keys), - Some(query.limit), - Some(query.offset), - ) - .await? - } - ClauseType::Keys(keys) => { - if keys.keys.is_empty() { - return Err(QueryError::MissingParam("keys".into()).into()); - } - - self.query_by_keys( - EVENT_MESSAGES_TABLE, - EVENT_MESSAGES_MODEL_RELATION_TABLE, - EVENT_MESSAGES_ENTITY_RELATION_COLUMN, - &keys, - Some(query.limit), - Some(query.offset), - ) - .await? - } - ClauseType::Member(member) => { - self.query_by_member( - EVENT_MESSAGES_TABLE, - EVENT_MESSAGES_MODEL_RELATION_TABLE, - EVENT_MESSAGES_ENTITY_RELATION_COLUMN, - member, - Some(query.limit), - Some(query.offset), - ) - .await? - } - ClauseType::Composite(composite) => { - self.query_by_composite( - EVENT_MESSAGES_TABLE, - EVENT_MESSAGES_MODEL_RELATION_TABLE, - ENTITIES_ENTITY_RELATION_COLUMN, - composite, - Some(query.limit), - Some(query.offset), - ) - .await? - } - } - } - }; - - Ok(RetrieveEntitiesResponse { entities, total_count }) - } - async fn retrieve_events( &self, query: &proto::types::EventQuery, @@ -938,20 +867,26 @@ fn map_row_to_entity( // this builds a sql safe regex pattern to match against for keys fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result { - let keys = clause - .keys - .iter() - .map(|bytes| { - if bytes.is_empty() { - return Ok("0x[0-9a-fA-F]+".to_string()); - } - Ok(format!("{:#x}", Felt::from_bytes_be_slice(bytes))) - }) - .collect::, Error>>()?; + const KEY_PATTERN: &str = "0x[0-9a-fA-F]+"; + + let keys = if clause.keys.is_empty() { + vec![KEY_PATTERN.to_string()] + } else { + clause + .keys + .iter() + .map(|bytes| { + if bytes.is_empty() { + return Ok(KEY_PATTERN.to_string()); + } + Ok(format!("{:#x}", Felt::from_bytes_be_slice(bytes))) + }) + .collect::, Error>>()? 
+ }; let mut keys_pattern = format!("^{}", keys.join("/")); if clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 { - keys_pattern += "(/0x[0-9a-fA-F]+)*"; + keys_pattern += &format!("({})*", KEY_PATTERN); } keys_pattern += "/$"; @@ -970,6 +905,9 @@ fn build_composite_clause( let mut having_clauses = Vec::new(); let mut bind_values = Vec::new(); + // HashMap to track the number of joins per model + let mut model_counters: HashMap = HashMap::new(); + for clause in &composite.clauses { match clause.clause_type.as_ref().unwrap() { ClauseType::HashedKeys(hashed_keys) => { @@ -1016,15 +954,22 @@ fn build_composite_clause( (format!("[{model}]"), format!("external_{}", member.member)) }; - let (namespace, model) = member + let (namespace, model_name) = member .model .split_once('-') .ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?; - let model_id = compute_selector_from_names(namespace, model); + let model_id = compute_selector_from_names(namespace, model_name); + + // Generate a unique alias for each model + let counter = model_counters.entry(model.clone()).or_insert(0); + *counter += 1; + let alias = + if *counter == 1 { model.clone() } else { format!("{model}_{}", *counter - 1) }; + join_clauses.push(format!( - "LEFT JOIN {table_name} ON [{table}].id = {table_name}.entity_id" + "LEFT JOIN {table_name} AS [{alias}] ON [{table}].id = [{alias}].entity_id" )); - where_clauses.push(format!("{table_name}.{column_name} {comparison_operator} ?")); + where_clauses.push(format!("[{alias}].{column_name} {comparison_operator} ?")); having_clauses.push(format!( "INSTR(group_concat({model_relation_table}.model_id), '{:#x}') > 0", model_id @@ -1137,8 +1082,15 @@ impl proto::world::world_server::World for DojoWorld { .query .ok_or_else(|| Status::invalid_argument("Missing query argument"))?; - let entities = - self.retrieve_entities(query).await.map_err(|e| Status::internal(e.to_string()))?; + let entities = self + .retrieve_entities( + ENTITIES_TABLE, + ENTITIES_MODEL_RELATION_TABLE, + ENTITIES_ENTITY_RELATION_COLUMN, + query, + ) + .await + .map_err(|e| Status::internal(e.to_string()))?; Ok(Response::new(entities)) } @@ -1181,7 +1133,12 @@ impl proto::world::world_server::World for DojoWorld { .ok_or_else(|| Status::invalid_argument("Missing query argument"))?; let entities = self - .retrieve_event_messages(query) + .retrieve_entities( + EVENT_MESSAGES_TABLE, + EVENT_MESSAGES_MODEL_RELATION_TABLE, + EVENT_MESSAGES_ENTITY_RELATION_COLUMN, + query, + ) .await .map_err(|e| Status::internal(e.to_string()))?;
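The core of this change is the new `fetch_entities` helper: instead of building and running one query per entity, it groups entity IDs by their model combination, stages the IDs in a `temp_entity_ids` temporary table (to avoid exceeding SQLite's bind-parameter limit), and then issues one `build_sql_query` per group. Below is a minimal standalone sketch of the grouping and chunking step; `plan_temp_table_inserts` and `MAX_BIND_PARAMS` are illustrative names, not part of the patch.

```rust
use std::collections::HashMap;

/// Historical default for SQLite's bind-parameter limit; the real value depends on the
/// SQLite build, so this constant is only illustrative.
const MAX_BIND_PARAMS: usize = 999;

/// Group (entity_id, model_ids_csv) pairs by model combination and plan chunked
/// `INSERT INTO temp_entity_ids (id, model_group) VALUES ...` statements whose
/// placeholder count stays below the limit. (Illustrative helper, not from the patch.)
fn plan_temp_table_inserts(entities: Vec<(String, String)>) -> Vec<(String, Vec<String>)> {
    let mut model_groups: HashMap<String, Vec<String>> = HashMap::new();
    for (entity_id, models_str) in entities {
        model_groups.entry(models_str).or_default().push(entity_id);
    }

    // Each row binds two values (id, model_group), so halve the parameter budget.
    let rows_per_chunk = MAX_BIND_PARAMS / 2;
    let mut statements = Vec::new();
    for (model_ids, entity_ids) in &model_groups {
        for chunk in entity_ids.chunks(rows_per_chunk) {
            let placeholders = chunk.iter().map(|_| "(?, ?)").collect::<Vec<_>>().join(",");
            let sql =
                format!("INSERT INTO temp_entity_ids (id, model_group) VALUES {placeholders}");
            // Bind values in the same order the placeholders were emitted.
            let binds = chunk.iter().flat_map(|id| [id.clone(), model_ids.clone()]).collect();
            statements.push((sql, binds));
        }
    }
    statements
}

fn main() {
    let entities = vec![
        ("0x1".to_string(), "0xa,0xb".to_string()),
        ("0x2".to_string(), "0xa,0xb".to_string()),
        ("0x3".to_string(), "0xc".to_string()),
    ];
    for (sql, binds) in plan_temp_table_inserts(entities) {
        println!("{sql}  -- {} bind values", binds.len());
    }
}
```

In the patch, each group's query then filters with `[{table}].id IN (SELECT id FROM temp_entity_ids WHERE model_group = ?)`, and the temporary table is dropped before the transaction commits.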
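Row decoding is also taken off the sequential path: once a group's rows are fetched, they are mapped to entities with rayon's `par_iter`, sharing the schemas and array rows behind `Arc`s. Here is a minimal sketch of that pattern, with `Row`, `Schema`, `Entity`, and a simplified `map_row_to_entity` as stand-ins for the real types.

```rust
use std::sync::Arc;

use rayon::prelude::*;

// Stand-ins for the real SQLite row, schema, and entity types (illustrative only).
struct Row(u64);
struct Schema;
struct Entity(u64);

// Simplified stand-in for `map_row_to_entity`: decode one row against the shared schemas.
fn map_row_to_entity(row: &Row, _schemas: &Arc<Vec<Schema>>) -> Result<Entity, String> {
    Ok(Entity(row.0))
}

fn main() -> Result<(), String> {
    let rows: Vec<Row> = (0..10_000u64).map(Row).collect();
    // Read-only decode context is shared across worker threads instead of cloned per row.
    let schemas = Arc::new(vec![Schema]);

    // `par_iter` fans per-row decoding out over rayon's thread pool; collecting into
    // `Result<Vec<_>, _>` short-circuits on the first error, like the sequential loop did.
    let entities: Vec<Entity> = rows
        .par_iter()
        .map(|row| map_row_to_entity(row, &schemas))
        .collect::<Result<Vec<_>, _>>()?;

    println!("decoded {} entities", entities.len());
    Ok(())
}
```

The same pattern appears both in `fetch_entities` and in the keys-clause path, where the previously sequential `map_row_to_entity` loop over `db_entities` now runs on the rayon thread pool.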
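Finally, `build_composite_clause` now tracks how many times each model has been referenced (`model_counters`) and gives every reference its own join alias, which is what fixes multiple member clauses hitting the same column. A sketch of the aliasing scheme, assuming model tables are joined on `entity_id` as in the patch; `alias_for` is an illustrative name.

```rust
use std::collections::HashMap;

/// Return a unique SQL alias for each reference to the same model: the first use keeps the
/// model name, later uses get `Model_1`, `Model_2`, ... (`alias_for` is illustrative, not a
/// name from the patch.)
fn alias_for(model_counters: &mut HashMap<String, usize>, model: &str) -> String {
    let counter = model_counters.entry(model.to_string()).or_insert(0);
    *counter += 1;
    if *counter == 1 { model.to_string() } else { format!("{model}_{}", *counter - 1) }
}

fn main() {
    let mut counters = HashMap::new();
    let table = "entities";
    let model = "ns-Position";

    // Two member clauses on the same model now produce two independent joins and filters.
    for column in ["external_x", "external_y"] {
        let alias = alias_for(&mut counters, model);
        println!("LEFT JOIN [{model}] AS [{alias}] ON [{table}].id = [{alias}].entity_id");
        println!("  -- where fragment: [{alias}].{column} = ?");
    }
}
```

Because the alias is reused in both the `LEFT JOIN` and the corresponding `WHERE` fragment, each member clause filters through its own copy of the model table instead of colliding on a single join.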