From 49f044f7477cdb29edc41df5c8e76eb97c1bd2b2 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 Feb 2024 15:48:22 -0800 Subject: [PATCH 01/63] wip - passing userauthcontext instead of http headers --- libsql-server/src/auth/parsers.rs | 29 +++++++++++++--- .../src/auth/user_auth_strategies/disabled.rs | 3 +- .../auth/user_auth_strategies/http_basic.rs | 28 +++++++++------ .../src/auth/user_auth_strategies/jwt.rs | 34 ++++++++++++++----- .../src/auth/user_auth_strategies/mod.rs | 4 +-- libsql-server/src/hrana/ws/session.rs | 15 ++------ libsql-server/src/http/user/extract.rs | 24 +++++++------ libsql-server/src/http/user/mod.rs | 10 +++--- libsql-server/src/rpc/proxy.rs | 7 ++-- libsql-server/src/rpc/replica_proxy.rs | 11 +++--- libsql-server/src/rpc/replication_log.rs | 9 +++-- libsql-server/tests/cluster/mod.rs | 15 ++++---- ...sts__hrana__batch__sample_request.snap.new | 8 +++++ 13 files changed, 120 insertions(+), 77 deletions(-) create mode 100644 libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index e09d930eae..31802824b6 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -4,6 +4,8 @@ use anyhow::{bail, Context as _, Result}; use axum::http::HeaderValue; use tonic::metadata::MetadataMap; +use super::UserAuthContext; + pub fn parse_http_basic_auth_arg(arg: &str) -> Result> { if arg == "always" { return Ok(None); @@ -34,13 +36,29 @@ pub fn parse_jwt_key(data: &str) -> Result { } } -pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Option { - metadata - .get(GRPC_AUTH_HEADER) - .map(|v| v.to_bytes().expect("Auth should always be ASCII")) - .map(|v| HeaderValue::from_maybe_shared(v).expect("Should already be valid header")) +pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { + + metadata.get(GRPC_AUTH_HEADER) + .map(|v| v.to_str().expect("Auth should be ASCII")) // fixme we should not use expect + .and_then(|auth_str| auth_string_to_auth_context(auth_str).ok()) + .context("Failed to parse grpc auth header") + .map_err(|err| tonic::Status::new(tonic::Code::InvalidArgument, format!("{}", err))) +} + +// todo this should be a constructor or a factory associates iwth userauthcontext +pub fn auth_string_to_auth_context( + auth_string: &str, +) -> Result { + + let(scheme, token) = auth_string.split_once(' ').context("malformed auth header string")?; + + Ok(UserAuthContext{ + scheme: Some(scheme.into()), + token: Some(token.into()), + }) } + pub fn parse_http_auth_header<'a>( expected_scheme: &str, auth_header: &'a Option, @@ -112,3 +130,4 @@ mod tests { ) } } + diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index 8ffa5e7028..4f2c12894a 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -24,7 +24,8 @@ mod tests { fn authenticates() { let strategy = Disabled::new(); let context = UserAuthContext { - user_credential: None, + scheme: None, + token: None, }; assert!(matches!( diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index 01ab33422a..782740bdb2 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -1,4 +1,4 @@ -use crate::auth::{parse_http_auth_header, AuthError, Authenticated}; +use crate::auth::{AuthError, Authenticated}; use super::{UserAuthContext, UserAuthStrategy}; @@ -7,17 +7,22 @@ pub struct HttpBasic { } impl UserAuthStrategy for HttpBasic { + fn authenticate(&self, context: UserAuthContext) -> Result { + tracing::trace!("executing http basic auth"); - - let param = parse_http_auth_header("basic", &context.user_credential)?; - + // NOTE: this naive comparison may leak information about the `expected_value` // using a timing attack - let actual_value = param.trim_end_matches('='); let expected_value = self.credential.trim_end_matches('='); - if actual_value == expected_value { + let creds_match = match context.token { + Some(s) => s.contains(expected_value), + None => expected_value.is_empty(), + }; + + + if creds_match { return Ok(Authenticated::FullAccess); } @@ -33,8 +38,6 @@ impl HttpBasic { #[cfg(test)] mod tests { - use axum::http::HeaderValue; - use super::*; const CREDENTIAL: &str = "d29qdGVrOnRoZWJlYXI="; @@ -46,7 +49,8 @@ mod tests { #[test] fn authenticates_with_valid_credential() { let context = UserAuthContext { - user_credential: HeaderValue::from_str(&format!("Basic {CREDENTIAL}")).ok(), + scheme: Some("basic".into()), + token: Some(CREDENTIAL.into()), }; assert!(matches!( @@ -60,7 +64,8 @@ mod tests { let credential = CREDENTIAL.trim_end_matches('='); let context = UserAuthContext { - user_credential: HeaderValue::from_str(&format!("Basic {credential}")).ok(), + scheme: Some("basic".into()), + token: Some(credential.into()), }; assert!(matches!( @@ -72,7 +77,8 @@ mod tests { #[test] fn errors_when_credentials_do_not_match() { let context = UserAuthContext { - user_credential: HeaderValue::from_str("Basic abc").ok(), + token: Some("abc".into()), + scheme: Some("basic".into()), }; assert_eq!( diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index e85bde43c4..e898a17786 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, Utc}; use crate::{ - auth::{parse_http_auth_header, AuthError, Authenticated, Authorized, Permission}, + auth::{AuthError, Authenticated, Authorized, Permission}, namespace::NamespaceName, }; @@ -14,8 +14,20 @@ pub struct Jwt { impl UserAuthStrategy for Jwt { fn authenticate(&self, context: UserAuthContext) -> Result { tracing::trace!("executing jwt auth"); - let param = parse_http_auth_header("bearer", &context.user_credential)?; - validate_jwt(&self.key, param) + + let Some(scheme) = context.scheme else { + return Err(AuthError::HttpAuthHeaderInvalid); + }; + + if !scheme.eq_ignore_ascii_case("bearer") { + return Err(AuthError::HttpAuthHeaderUnsupportedScheme); + } + + let Some(token) = context.token else { + return Err(AuthError::HttpAuthHeaderInvalid); + }; + + return validate_jwt(&self.key, &token); } } @@ -106,7 +118,6 @@ fn validate_jwt( mod tests { use std::time::Duration; - use axum::http::HeaderValue; use jsonwebtoken::{DecodingKey, EncodingKey}; use ring::signature::{Ed25519KeyPair, KeyPair}; use serde::Serialize; @@ -145,7 +156,8 @@ mod tests { let token = encode(&token, &enc); let context = UserAuthContext { - user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), + scheme: Some("bearer".into()), + token: token.into(), }; assert!(matches!( @@ -166,7 +178,8 @@ mod tests { let token = encode(&token, &enc); let context = UserAuthContext { - user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), + scheme: Some("bearer".into()), + token: token.into(), }; let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { @@ -185,7 +198,8 @@ mod tests { fn errors_when_jwt_token_invalid() { let (_enc, dec) = key_pair(); let context = UserAuthContext { - user_credential: HeaderValue::from_str("Bearer abc").ok(), + scheme: Some("bearer".into()), + token: Some("abc".into()), }; assert_eq!( @@ -207,7 +221,8 @@ mod tests { let token = encode(&token, &enc); let context = UserAuthContext { - user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), + scheme: Some("bearer".into()), + token: Some(token.into()), }; assert_eq!( @@ -231,7 +246,8 @@ mod tests { let token = encode(&token, &enc); let context = UserAuthContext { - user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), + scheme: Some("bearer".into()), + token: token.into(), }; let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index f3b8a2e5b2..9fb228c82f 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -2,7 +2,6 @@ pub mod disabled; pub mod http_basic; pub mod jwt; -use axum::http::HeaderValue; pub use disabled::*; pub use http_basic::*; pub use jwt::*; @@ -10,7 +9,8 @@ pub use jwt::*; use super::{AuthError, Authenticated}; pub struct UserAuthContext { - pub user_credential: Option, + pub scheme: Option, + pub token: Option, // token might not be required in some cases } pub trait UserAuthStrategy: Sync + Send { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index f88541e246..bbb851e165 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use std::sync::Arc; use anyhow::{anyhow, bail, Result}; -use axum::http::HeaderValue; use futures::future::BoxFuture; use tokio::sync::{mpsc, oneshot}; @@ -77,16 +76,11 @@ pub(super) async fn handle_initial_hello( .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - // Convert jwt token into a HeaderValue to be compatible with UserAuthStrategy - let user_credential = jwt - .clone() - .and_then(|t| HeaderValue::from_str(&format!("Bearer {t}")).ok()); - let auth = namespace_jwt_key .map(Jwt::new) .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()) - .authenticate(UserAuthContext { user_credential }) + .authenticate(UserAuthContext { scheme: Some("Bearer".into()), token: jwt }) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { @@ -115,16 +109,11 @@ pub(super) async fn handle_repeated_hello( .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - // Convert jwt token into a HeaderValue to be compatible with UserAuthStrategy - let user_credential = jwt - .clone() - .and_then(|t| HeaderValue::from_str(&format!("Bearer {t}")).ok()); - session.auth = namespace_jwt_key .map(Jwt::new) .map(Auth::new) .unwrap_or_else(|| server.user_auth_strategy.clone()) - .authenticate(UserAuthContext { user_credential }) + .authenticate(UserAuthContext { scheme: Some("Bearer".into()), token: jwt }) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 14f7138ee3..5bda56c82b 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -1,7 +1,8 @@ +use anyhow::Context; use axum::extract::FromRequestParts; use crate::{ - auth::{Jwt, UserAuthContext, UserAuthStrategy}, + auth::{parsers::auth_string_to_auth_context, Jwt, Auth}, connection::RequestContext, namespace::MakeNamespace, }; @@ -19,6 +20,8 @@ where parts: &mut axum::http::request::Parts, state: &AppState, ) -> std::result::Result { + + // start todo this block is same as the one in mod.rs let namespace = db_factory::namespace_from_headers( &parts.headers, state.disable_default_namespace, @@ -30,16 +33,17 @@ where .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - let auth_header = parts.headers.get(hyper::header::AUTHORIZATION); + let header = parts.headers.get(hyper::header::AUTHORIZATION).context("auth header not found")?; + let header_str = header.to_str().context("non ASCII auth token")?; + let context = auth_string_to_auth_context(header_str).context("auth header parsing failed")?; + + let auth = namespace_jwt_key + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()) + .authenticate(context)?; - let auth = match namespace_jwt_key { - Some(key) => Jwt::new(key).authenticate(UserAuthContext { - user_credential: auth_header.cloned(), - })?, - None => state.user_auth_strategy.authenticate(UserAuthContext { - user_credential: auth_header.cloned(), - })?, - }; + // end todo Ok(Self::new(auth, namespace, state.namespaces.meta_store())) } diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 29b52ad6f5..f1927b7e71 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -29,7 +29,7 @@ use tonic::transport::Server; use tower_http::{compression::CompressionLayer, cors}; -use crate::auth::user_auth_strategies::UserAuthContext; +use crate::auth::parsers::auth_string_to_auth_context; use crate::auth::{Auth, Authenticated, Jwt}; use crate::connection::{Connection, RequestContext}; use crate::database::Database; @@ -492,15 +492,15 @@ where .with(ns.clone(), |ns| ns.jwt_key()) .await??; - let auth_header = parts.headers.get(hyper::header::AUTHORIZATION); + let header = parts.headers.get(hyper::header::AUTHORIZATION).context("auth header not found")?; + let header_str = header.to_str().context("non ASCII auth token")?; + let context = auth_string_to_auth_context(header_str).context("auth header parsing failed")?; let auth = namespace_jwt_key .map(Jwt::new) .map(Auth::new) .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(UserAuthContext { - user_credential: auth_header.cloned(), - })?; + .authenticate(context)?; Ok(auth) } diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index dea8dfb7d6..8f8ff26fc1 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -16,7 +16,6 @@ use rusqlite::types::ValueRef; use uuid::Uuid; use crate::auth::parsers::parse_grpc_auth_header; -use crate::auth::user_auth_strategies::UserAuthContext; use crate::auth::{Auth, Authenticated, Jwt}; use crate::connection::{Connection, RequestContext}; use crate::database::{Database, PrimaryConnection}; @@ -332,10 +331,10 @@ impl ProxyService { )))?, }; + let context = parse_grpc_auth_header(req.metadata())?; + let auth = if let Some(auth) = auth { - auth.authenticate(UserAuthContext { - user_credential: parse_grpc_auth_header(req.metadata()), - })? + auth.authenticate(context)? } else { Authenticated::from_proxy_grpc_request(req)? }; diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index fc7b99f95b..ae70a03f77 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -8,7 +8,7 @@ use tonic::{transport::Channel, Request, Status}; use crate::{ auth::{ - parsers::parse_grpc_auth_header, user_auth_strategies::UserAuthContext, Auth, Jwt, + parsers::parse_grpc_auth_header, Auth, Jwt, UserAuthStrategy, }, namespace::{NamespaceStore, ReplicaNamespaceMaker}, @@ -46,12 +46,13 @@ impl ReplicaProxyService { .with(namespace.clone(), |ns| ns.jwt_key()) .await; - let user_credential = parse_grpc_auth_header(req.metadata()); + //todo julian figure this out + let auth_context = parse_grpc_auth_header(req.metadata())?; match namespace_jwt_key { Ok(Ok(Some(key))) => { let authenticated = - Jwt::new(key).authenticate(UserAuthContext { user_credential })?; + Jwt::new(key).authenticate(auth_context)?; authenticated.upgrade_grpc_request(req); Ok(()) @@ -59,7 +60,7 @@ impl ReplicaProxyService { Ok(Ok(None)) => { let authenticated = self .user_auth_strategy - .authenticate(UserAuthContext { user_credential })?; + .authenticate(auth_context)?; authenticated.upgrade_grpc_request(req); Ok(()) @@ -68,7 +69,7 @@ impl ReplicaProxyService { crate::error::Error::NamespaceDoesntExist(_) => { let authenticated = self .user_auth_strategy - .authenticate(UserAuthContext { user_credential })?; + .authenticate(auth_context)?; authenticated.upgrade_grpc_request(req); Ok(()) diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 3bf447dd89..a2389837a8 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -19,7 +19,6 @@ use tonic::transport::server::TcpConnectInfo; use tonic::Status; use uuid::Uuid; -use crate::auth::user_auth_strategies::UserAuthContext; use crate::auth::Jwt; use crate::auth::{parsers::parse_grpc_auth_header, Auth}; use crate::connection::config::DatabaseConfig; @@ -77,8 +76,7 @@ impl ReplicationLogService { .with(namespace.clone(), |ns| ns.jwt_key()) .await; - let user_credential = parse_grpc_auth_header(req.metadata()); - + let auth = match namespace_jwt_key { Ok(Ok(Some(key))) => Some(Auth::new(Jwt::new(key))), Ok(Ok(None)) => self.user_auth_strategy.clone(), @@ -94,9 +92,10 @@ impl ReplicationLogService { e )))?, }; - + + let user_credential = parse_grpc_auth_header(req.metadata())?; if let Some(auth) = auth { - auth.authenticate(UserAuthContext { user_credential })?; + auth.authenticate(user_credential)?; } Ok(()) diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index 8fc8fb5f4b..3c0e512da2 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -89,14 +89,14 @@ fn proxy_write() { sim.client("client", async { let db = - Database::open_remote_with_connector("http://replica0:8080", "", TurmoilConnector)?; + Database::open_remote_with_connector("http://replica0:8080", "dummy-auth", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; conn.execute("insert into test values (12)", ()).await?; // assert that the primary got the write - let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; + let db = Database::open_remote_with_connector("http://primary:8080", "dummy-auth", TurmoilConnector)?; let conn = db.connect()?; let mut rows = conn.query("select count(*) from test", ()).await?; @@ -121,7 +121,7 @@ fn replica_read_write() { sim.client("client", async { let db = - Database::open_remote_with_connector("http://replica0:8080", "", TurmoilConnector)?; + Database::open_remote_with_connector("http://replica0:8080", "dummy-auth", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -145,7 +145,7 @@ fn sync_many_replica() { let mut sim = Builder::new().build(); make_cluster(&mut sim, NUM_REPLICA, true); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; + let db = Database::open_remote_with_connector("http://primary:8080", "dummy-auth", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -195,7 +195,7 @@ fn sync_many_replica() { for i in 0..NUM_REPLICA { let db = Database::open_remote_with_connector( format!("http://replica{i}:8080"), - "", + "dummy-auth", TurmoilConnector, )?; let conn = db.connect()?; @@ -212,6 +212,7 @@ fn sync_many_replica() { sim.run().unwrap(); } + #[test] fn create_namespace() { let mut sim = Builder::new().build(); @@ -219,7 +220,7 @@ fn create_namespace() { sim.client("client", async { let db = - Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; + Database::open_remote_with_connector("http://foo.primary:8080", "dummy-auth", TurmoilConnector)?; let conn = db.connect()?; let Err(e) = conn.execute("create table test (x)", ()).await else { @@ -259,7 +260,7 @@ fn large_proxy_query() { make_cluster(&mut sim, 1, true); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector) + let db = Database::open_remote_with_connector("http://primary:8080", "dummy-auth", TurmoilConnector) .unwrap(); let conn = db.connect().unwrap(); diff --git a/libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new b/libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new new file mode 100644 index 0000000000..6ebe0bb0ec --- /dev/null +++ b/libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new @@ -0,0 +1,8 @@ +--- +source: libsql-server/tests/hrana/batch.rs +assertion_line: 32 +expression: resp.json_value().await.unwrap() +--- +{ + "error": "Internal Error: `auth header not found`" +} From 946c9eaebf8a96f8ba24ab4cbca72cf9609a2113 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 1 Mar 2024 11:07:00 -0800 Subject: [PATCH 02/63] revert comment out --- libsql-ffi/build.rs | 1140 +++++++++++++++++++++---------------------- 1 file changed, 570 insertions(+), 570 deletions(-) diff --git a/libsql-ffi/build.rs b/libsql-ffi/build.rs index fe7af4c03c..fd631aabfe 100644 --- a/libsql-ffi/build.rs +++ b/libsql-ffi/build.rs @@ -1,570 +1,570 @@ -// use std::env; -// use std::ffi::OsString; -// use std::fs::{self, OpenOptions}; -// use std::io::Write; -// use std::path::{Path, PathBuf}; -// use std::process::Command; - -// const LIB_NAME: &str = "libsql"; -// const BUNDLED_DIR: &str = "bundled"; -// const SQLITE_DIR: &str = "../libsql-sqlite3"; - -// fn main() { -// let target = env::var("TARGET").unwrap(); -// let host = env::var("HOST").unwrap(); - -// let is_apple = host.contains("apple") && target.contains("apple"); -// if is_apple { -// println!("cargo:rustc-link-lib=framework=Security"); -// } -// let out_dir = env::var("OUT_DIR").unwrap(); -// let out_path = Path::new(&out_dir).join("bindgen.rs"); - -// println!("cargo:rerun-if-changed={BUNDLED_DIR}/src/sqlite3.c"); - -// if cfg!(feature = "multiple-ciphers") { -// println!( -// "cargo:rerun-if-changed={BUNDLED_DIR}/SQLite3MultipleCiphers/build/libsqlite3mc_static.a" -// ); -// } - -// if std::env::var("LIBSQL_DEV").is_ok() { -// make_amalgation(); -// build_multiple_ciphers(&out_path); -// } - -// let bindgen_rs_path = if cfg!(feature = "session") { -// "bundled/bindings/session_bindgen.rs" -// } else { -// "bundled/bindings/bindgen.rs" -// }; - -// let dir = env!("CARGO_MANIFEST_DIR"); -// std::fs::copy(format!("{dir}/{bindgen_rs_path}"), &out_path).unwrap(); - -// println!("cargo:lib_dir={out_dir}"); - -// if cfg!(feature = "wasmtime-bindings") && !cfg!(feature = "multiple-ciphers") { -// build_bundled(&out_dir, &out_path); -// } - -// if cfg!(feature = "multiple-ciphers") { -// copy_multiple_ciphers(&out_dir, &out_path); -// return; -// } - -// build_bundled(&out_dir, &out_path); -// } - -// fn make_amalgation() { -// let flags = ["-DSQLITE_ENABLE_COLUMN_METADATA=1"]; - -// Command::new("make") -// .current_dir(SQLITE_DIR) -// .arg("clean") -// .output() -// .unwrap(); - -// Command::new("./configure") -// .current_dir(SQLITE_DIR) -// .env("CFLAGS", flags.join(" ")) -// .output() -// .unwrap(); -// Command::new("make") -// .current_dir(SQLITE_DIR) -// .output() -// .unwrap(); - -// std::fs::copy( -// (SQLITE_DIR.as_ref() as &Path).join("sqlite3.c"), -// (BUNDLED_DIR.as_ref() as &Path).join("src/sqlite3.c"), -// ) -// .unwrap(); -// std::fs::copy( -// (SQLITE_DIR.as_ref() as &Path).join("sqlite3.h"), -// (BUNDLED_DIR.as_ref() as &Path).join("src/sqlite3.h"), -// ) -// .unwrap(); -// } - -// pub fn build_bundled(out_dir: &str, out_path: &Path) { -// let bindgen_rs_path = if cfg!(feature = "session") { -// "bundled/bindings/session_bindgen.rs" -// } else { -// "bundled/bindings/bindgen.rs" -// }; - -// if std::env::var("LIBSQL_DEV").is_ok() { -// let header = HeaderLocation::FromPath(format!("{BUNDLED_DIR}/src/sqlite3.h")); -// bindings::write_to_out_dir(header, bindgen_rs_path.as_ref()); -// } - -// let dir = env!("CARGO_MANIFEST_DIR"); -// std::fs::copy(format!("{dir}/{bindgen_rs_path}"), out_path).unwrap(); - -// let mut cfg = cc::Build::new(); -// cfg.file(format!("{BUNDLED_DIR}/src/sqlite3.c")) -// .flag("-std=c11") -// .flag("-DSQLITE_CORE") -// .flag("-DSQLITE_DEFAULT_FOREIGN_KEYS=1") -// .flag("-DSQLITE_ENABLE_API_ARMOR") -// .flag("-DSQLITE_ENABLE_COLUMN_METADATA") -// .flag("-DSQLITE_ENABLE_DBSTAT_VTAB") -// .flag("-DSQLITE_ENABLE_FTS3") -// .flag("-DSQLITE_ENABLE_FTS3_PARENTHESIS") -// .flag("-DSQLITE_ENABLE_FTS5") -// .flag("-DSQLITE_ENABLE_JSON1") -// .flag("-DSQLITE_ENABLE_LOAD_EXTENSION=1") -// .flag("-DSQLITE_ENABLE_MEMORY_MANAGEMENT") -// .flag("-DSQLITE_ENABLE_RTREE") -// .flag("-DSQLITE_ENABLE_STAT2") -// .flag("-DSQLITE_ENABLE_STAT4") -// .flag("-DSQLITE_SOUNDEX") -// .flag("-DSQLITE_THREADSAFE=1") -// .flag("-DSQLITE_USE_URI") -// .flag("-DHAVE_USLEEP=1") -// .flag("-D_POSIX_THREAD_SAFE_FUNCTIONS") // cross compile with MinGW -// .warnings(false); - -// if cfg!(feature = "wasmtime-bindings") { -// cfg.flag("-DLIBSQL_ENABLE_WASM_RUNTIME=1"); -// } - -// if cfg!(feature = "bundled-sqlcipher") { -// cfg.flag("-DSQLITE_HAS_CODEC").flag("-DSQLITE_TEMP_STORE=2"); - -// let target = env::var("TARGET").unwrap(); -// let host = env::var("HOST").unwrap(); - -// let is_windows = host.contains("windows") && target.contains("windows"); -// let is_apple = host.contains("apple") && target.contains("apple"); - -// let lib_dir = env("OPENSSL_LIB_DIR").map(PathBuf::from); -// let inc_dir = env("OPENSSL_INCLUDE_DIR").map(PathBuf::from); -// let mut use_openssl = false; - -// let (lib_dir, inc_dir) = match (lib_dir, inc_dir) { -// (Some(lib_dir), Some(inc_dir)) => { -// use_openssl = true; -// (lib_dir, inc_dir) -// } -// (lib_dir, inc_dir) => match find_openssl_dir(&host, &target) { -// None => { -// if is_windows && !cfg!(feature = "bundled-sqlcipher-vendored-openssl") { -// panic!("Missing environment variable OPENSSL_DIR or OPENSSL_DIR is not set") -// } else { -// (PathBuf::new(), PathBuf::new()) -// } -// } -// Some(openssl_dir) => { -// let lib_dir = lib_dir.unwrap_or_else(|| openssl_dir.join("lib")); -// let inc_dir = inc_dir.unwrap_or_else(|| openssl_dir.join("include")); - -// assert!( -// Path::new(&lib_dir).exists(), -// "OpenSSL library directory does not exist: {}", -// lib_dir.to_string_lossy() -// ); - -// if !Path::new(&inc_dir).exists() { -// panic!( -// "OpenSSL include directory does not exist: {}", -// inc_dir.to_string_lossy() -// ); -// } - -// use_openssl = true; -// (lib_dir, inc_dir) -// } -// }, -// }; - -// if cfg!(feature = "bundled-sqlcipher-vendored-openssl") { -// cfg.include(env::var("DEP_OPENSSL_INCLUDE").unwrap()); -// // cargo will resolve downstream to the static lib in -// // openssl-sys -// } else if use_openssl { -// cfg.include(inc_dir.to_string_lossy().as_ref()); -// let lib_name = if is_windows { "libcrypto" } else { "crypto" }; -// println!("cargo:rustc-link-lib=dylib={}", lib_name); -// println!("cargo:rustc-link-search={}", lib_dir.to_string_lossy()); -// } else if is_apple { -// cfg.flag("-DSQLCIPHER_CRYPTO_CC"); -// println!("cargo:rustc-link-lib=framework=Security"); -// println!("cargo:rustc-link-lib=framework=CoreFoundation"); -// } else { -// // branch not taken on Windows, just `crypto` is fine. -// println!("cargo:rustc-link-lib=dylib=crypto"); -// } -// } - -// if cfg!(feature = "with-asan") { -// cfg.flag("-fsanitize=address"); -// } - -// // Target wasm32-wasi can't compile the default VFS -// if env::var("TARGET").map_or(false, |v| v == "wasm32-wasi") { -// cfg.flag("-DSQLITE_OS_OTHER") -// // https://github.com/rust-lang/rust/issues/74393 -// .flag("-DLONGDOUBLE_TYPE=double"); -// if cfg!(feature = "wasm32-wasi-vfs") { -// cfg.file("sqlite3/wasm32-wasi-vfs.c"); -// } -// } -// if cfg!(feature = "unlock_notify") { -// cfg.flag("-DSQLITE_ENABLE_UNLOCK_NOTIFY"); -// } -// if cfg!(feature = "preupdate_hook") { -// cfg.flag("-DSQLITE_ENABLE_PREUPDATE_HOOK"); -// } -// if cfg!(feature = "session") { -// cfg.flag("-DSQLITE_ENABLE_SESSION"); -// } - -// if let Ok(limit) = env::var("SQLITE_MAX_VARIABLE_NUMBER") { -// cfg.flag(&format!("-DSQLITE_MAX_VARIABLE_NUMBER={limit}")); -// } -// println!("cargo:rerun-if-env-changed=SQLITE_MAX_VARIABLE_NUMBER"); - -// if let Ok(limit) = env::var("SQLITE_MAX_EXPR_DEPTH") { -// cfg.flag(&format!("-DSQLITE_MAX_EXPR_DEPTH={limit}")); -// } -// println!("cargo:rerun-if-env-changed=SQLITE_MAX_EXPR_DEPTH"); - -// if let Ok(limit) = env::var("SQLITE_MAX_COLUMN") { -// cfg.flag(&format!("-DSQLITE_MAX_COLUMN={limit}")); -// } -// println!("cargo:rerun-if-env-changed=SQLITE_MAX_COLUMN"); - -// if let Ok(extras) = env::var("LIBSQLITE3_FLAGS") { -// for extra in extras.split_whitespace() { -// if extra.starts_with("-D") || extra.starts_with("-U") { -// cfg.flag(extra); -// } else if extra.starts_with("SQLITE_") { -// cfg.flag(&format!("-D{extra}")); -// } else { -// panic!("Don't understand {} in LIBSQLITE3_FLAGS", extra); -// } -// } -// } -// println!("cargo:rerun-if-env-changed=LIBSQLITE3_FLAGS"); - -// cfg.compile(LIB_NAME); - -// println!("cargo:lib_dir={out_dir}"); -// } - -// fn copy_multiple_ciphers(out_dir: &str, out_path: &Path) { -// let dylib = format!("{BUNDLED_DIR}/SQLite3MultipleCiphers/build/libsqlite3mc_static.a"); -// if !Path::new(&dylib).exists() { -// build_multiple_ciphers(out_path); -// } - -// std::fs::copy(dylib, format!("{out_dir}/libsqlite3mc.a")).unwrap(); -// println!("cargo:rustc-link-lib=static=sqlite3mc"); -// println!("cargo:rustc-link-search={out_dir}"); -// } - -// fn build_multiple_ciphers(out_path: &Path) { -// let bindgen_rs_path = if cfg!(feature = "session") { -// "bundled/bindings/session_bindgen.rs" -// } else { -// "bundled/bindings/bindgen.rs" -// }; -// if std::env::var("LIBSQL_DEV").is_ok() { -// let header = HeaderLocation::FromPath(format!("{BUNDLED_DIR}/src/sqlite3.h")); -// bindings::write_to_out_dir(header, bindgen_rs_path.as_ref()); -// } -// let dir = env!("CARGO_MANIFEST_DIR"); -// std::fs::copy(format!("{dir}/{bindgen_rs_path}"), out_path).unwrap(); - -// std::fs::copy( -// (BUNDLED_DIR.as_ref() as &Path) -// .join("src") -// .join("sqlite3.c"), -// (BUNDLED_DIR.as_ref() as &Path) -// .join("SQLite3MultipleCiphers") -// .join("src") -// .join("sqlite3.c"), -// ) -// .unwrap(); - -// let bundled_dir = fs::canonicalize(BUNDLED_DIR).unwrap(); - -// let build_dir = bundled_dir.join("SQLite3MultipleCiphers").join("build"); -// let _ = fs::remove_dir_all(build_dir.clone()); -// fs::create_dir_all(build_dir.clone()).unwrap(); - -// let mut cmake_opts: Vec<&str> = vec![]; - -// let cargo_build_target = env::var("CARGO_BUILD_TARGET").unwrap_or_default(); -// let cross_cc_var_name = format!("CC_{}", cargo_build_target.replace("-", "_")); -// let cross_cc = env::var(&cross_cc_var_name).ok(); - -// let cross_cxx_var_name = format!("CXX_{}", cargo_build_target.replace("-", "_")); -// let cross_cxx = env::var(&cross_cxx_var_name).ok(); - -// let toolchain_path = build_dir.join("toolchain.cmake"); -// let cmake_toolchain_opt = format!("-DCMAKE_TOOLCHAIN_FILE=toolchain.cmake"); - -// let mut toolchain_file = OpenOptions::new() -// .create(true) -// .write(true) -// .append(true) -// .open(toolchain_path.clone()) -// .unwrap(); - -// if let Some(ref cc) = cross_cc { -// if cc.contains("aarch64") && cc.contains("linux") { -// cmake_opts.push(&cmake_toolchain_opt); -// writeln!(toolchain_file, "set(CMAKE_SYSTEM_NAME \"Linux\")").unwrap(); -// writeln!(toolchain_file, "set(CMAKE_SYSTEM_PROCESSOR \"arm64\")").unwrap(); -// } -// } -// if let Some(cc) = cross_cc { -// writeln!(toolchain_file, "set(CMAKE_C_COMPILER {})", cc).unwrap(); -// } -// if let Some(cxx) = cross_cxx { -// writeln!(toolchain_file, "set(CMAKE_CXX_COMPILER {})", cxx).unwrap(); -// } - -// cmake_opts.push("-DCMAKE_BUILD_TYPE=Release"); -// cmake_opts.push("-DSQLITE3MC_STATIC=ON"); -// cmake_opts.push("-DCODEC_TYPE=AES256"); -// cmake_opts.push("-DSQLITE3MC_BUILD_SHELL=OFF"); -// cmake_opts.push("-DSQLITE_SHELL_IS_UTF8=OFF"); -// cmake_opts.push("-DSQLITE_USER_AUTHENTICATION=OFF"); -// cmake_opts.push("-DSQLITE_SECURE_DELETE=OFF"); -// cmake_opts.push("-DSQLITE_ENABLE_COLUMN_METADATA=ON"); -// cmake_opts.push("-DSQLITE_USE_URI=ON"); -// cmake_opts.push("-DCMAKE_POSITION_INDEPENDENT_CODE=ON"); - -// let mut cmake = Command::new("cmake"); -// cmake.current_dir("bundled/SQLite3MultipleCiphers/build"); -// cmake.args(cmake_opts.clone()); -// cmake.arg(".."); -// if cfg!(feature = "wasmtime-bindings") { -// cmake.arg("-DLIBSQL_ENABLE_WASM_RUNTIME=1"); -// } -// if cfg!(feature = "session") { -// cmake.arg("-DSQLITE_ENABLE_PREUPDATE_HOOK=ON"); -// cmake.arg("-DSQLITE_ENABLE_SESSION=ON"); -// } -// println!("Running `cmake` with options: {}", cmake_opts.join(" ")); -// let status = cmake.status().unwrap(); -// if !status.success() { -// panic!("Failed to run cmake with options: {}", cmake_opts.join(" ")); -// } - -// let mut make = Command::new("cmake"); -// make.current_dir("bundled/SQLite3MultipleCiphers/build"); -// make.args(&["--build", "."]); -// make.args(&["--config", "Release"]); -// if !make.status().unwrap().success() { -// panic!("Failed to run make"); -// } -// // The `msbuild` tool puts the output in a different place so let's move it. -// if Path::exists(&build_dir.join("Release/sqlite3mc_static.lib")) { -// fs::rename( -// build_dir.join("Release/sqlite3mc_static.lib"), -// build_dir.join("libsqlite3mc_static.a"), -// ) -// .unwrap(); -// } -// } - -// fn env(name: &str) -> Option { -// let prefix = env::var("TARGET").unwrap().to_uppercase().replace('-', "_"); -// let prefixed = format!("{prefix}_{name}"); -// let var = env::var_os(prefixed); - -// match var { -// None => env::var_os(name), -// _ => var, -// } -// } - -// fn find_openssl_dir(_host: &str, _target: &str) -> Option { -// let openssl_dir = env("OPENSSL_DIR"); -// openssl_dir.map(PathBuf::from) -// } - -// fn env_prefix() -> &'static str { -// if cfg!(any(feature = "sqlcipher", feature = "bundled-sqlcipher")) { -// "SQLCIPHER" -// } else { -// "SQLITE3" -// } -// } - -// pub enum HeaderLocation { -// FromEnvironment, -// Wrapper, -// FromPath(String), -// } - -// impl From for String { -// fn from(header: HeaderLocation) -> String { -// match header { -// HeaderLocation::FromEnvironment => { -// let prefix = env_prefix(); -// let mut header = env::var(format!("{prefix}_INCLUDE_DIR")).unwrap_or_else(|_| { -// panic!( -// "{}_INCLUDE_DIR must be set if {}_LIB_DIR is set", -// prefix, prefix -// ) -// }); -// header.push_str("/sqlite3.h"); -// header -// } -// HeaderLocation::Wrapper => "wrapper.h".into(), -// HeaderLocation::FromPath(path) => path, -// } -// } -// } - -// mod bindings { -// use super::HeaderLocation; -// use bindgen::callbacks::{IntKind, ParseCallbacks}; - -// use std::fs::OpenOptions; -// use std::io::Write; -// use std::path::Path; - -// #[derive(Debug)] -// struct SqliteTypeChooser; - -// impl ParseCallbacks for SqliteTypeChooser { -// fn int_macro(&self, _name: &str, value: i64) -> Option { -// if value >= i32::MIN as i64 && value <= i32::MAX as i64 { -// Some(IntKind::I32) -// } else { -// None -// } -// } -// fn item_name(&self, original_item_name: &str) -> Option { -// original_item_name -// .strip_prefix("sqlite3_index_info_") -// .map(|s| s.to_owned()) -// } -// } - -// // Are we generating the bundled bindings? Used to avoid emitting things -// // that would be problematic in bundled builds. This env var is set by -// // `upgrade.sh`. -// fn generating_bundled_bindings() -> bool { -// // Hacky way to know if we're generating the bundled bindings -// println!("cargo:rerun-if-env-changed=LIBSQLITE3_SYS_BUNDLING"); -// match std::env::var("LIBSQLITE3_SYS_BUNDLING") { -// Ok(v) => v != "0", -// Err(_) => false, -// } -// } - -// pub fn write_to_out_dir(header: HeaderLocation, out_path: &Path) { -// let header: String = header.into(); -// let mut output = Vec::new(); -// let mut bindings = bindgen::builder() -// .trust_clang_mangling(false) -// .header(header.clone()) -// .parse_callbacks(Box::new(SqliteTypeChooser)) -// .blocklist_function("sqlite3_auto_extension") -// .raw_line( -// r#"extern "C" { -// pub fn sqlite3_auto_extension( -// xEntryPoint: ::std::option::Option< -// unsafe extern "C" fn( -// db: *mut sqlite3, -// pzErrMsg: *mut *const ::std::os::raw::c_char, -// pThunk: *const sqlite3_api_routines, -// ) -> ::std::os::raw::c_int, -// >, -// ) -> ::std::os::raw::c_int; -// }"#, -// ) -// .blocklist_function("sqlite3_cancel_auto_extension") -// .raw_line( -// r#"extern "C" { -// pub fn sqlite3_cancel_auto_extension( -// xEntryPoint: ::std::option::Option< -// unsafe extern "C" fn( -// db: *mut sqlite3, -// pzErrMsg: *mut *const ::std::os::raw::c_char, -// pThunk: *const sqlite3_api_routines, -// ) -> ::std::os::raw::c_int, -// >, -// ) -> ::std::os::raw::c_int; -// }"#, -// ); - -// if cfg!(any(feature = "sqlcipher", feature = "bundled-sqlcipher")) { -// bindings = bindings.clang_arg("-DSQLITE_HAS_CODEC"); -// } -// if cfg!(feature = "unlock_notify") { -// bindings = bindings.clang_arg("-DSQLITE_ENABLE_UNLOCK_NOTIFY"); -// } -// if cfg!(feature = "preupdate_hook") { -// bindings = bindings.clang_arg("-DSQLITE_ENABLE_PREUPDATE_HOOK"); -// } -// if cfg!(feature = "session") { -// bindings = bindings.clang_arg("-DSQLITE_ENABLE_SESSION"); -// } - -// // When cross compiling unless effort is taken to fix the issue, bindgen -// // will find the wrong headers. There's only one header included by the -// // amalgamated `sqlite.h`: `stdarg.h`. -// // -// // Thankfully, there's almost no case where rust code needs to use -// // functions taking `va_list` (It's nearly impossible to get a `va_list` -// // in Rust unless you get passed it by C code for some reason). -// // -// // Arguably, we should never be including these, but we include them for -// // the cases where they aren't totally broken... -// let target_arch = std::env::var("TARGET").unwrap(); -// let host_arch = std::env::var("HOST").unwrap(); -// let is_cross_compiling = target_arch != host_arch; - -// // Note that when generating the bundled file, we're essentially always -// // cross compiling. -// if generating_bundled_bindings() || is_cross_compiling { -// // Get rid of va_list, as it's not -// bindings = bindings -// .blocklist_function("sqlite3_vmprintf") -// .blocklist_function("sqlite3_vsnprintf") -// .blocklist_function("sqlite3_str_vappendf") -// .blocklist_type("va_list") -// .blocklist_type("__builtin_va_list") -// .blocklist_type("__gnuc_va_list") -// .blocklist_type("__va_list_tag") -// .blocklist_item("__GNUC_VA_LIST"); -// } - -// bindings -// .layout_tests(false) -// .generate() -// .unwrap_or_else(|_| panic!("could not run bindgen on header {}", header)) -// .write(Box::new(&mut output)) -// .expect("could not write output of bindgen"); -// let mut output = String::from_utf8(output).expect("bindgen output was not UTF-8?!"); - -// // rusqlite's functions feature ors in the SQLITE_DETERMINISTIC flag when it -// // can. This flag was added in SQLite 3.8.3, but oring it in in prior -// // versions of SQLite is harmless. We don't want to not build just -// // because this flag is missing (e.g., if we're linking against -// // SQLite 3.7.x), so append the flag manually if it isn't present in bindgen's -// // output. -// if !output.contains("pub const SQLITE_DETERMINISTIC") { -// output.push_str("\npub const SQLITE_DETERMINISTIC: i32 = 2048;\n"); -// } - -// let mut file = OpenOptions::new() -// .write(true) -// .truncate(true) -// .create(true) -// .open(out_path) -// .unwrap_or_else(|_| panic!("Could not write to {:?}", out_path)); - -// file.write_all(output.as_bytes()) -// .unwrap_or_else(|_| panic!("Could not write to {:?}", out_path)); -// } -// } +use std::env; +use std::ffi::OsString; +use std::fs::{self, OpenOptions}; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::process::Command; + +const LIB_NAME: &str = "libsql"; +const BUNDLED_DIR: &str = "bundled"; +const SQLITE_DIR: &str = "../libsql-sqlite3"; + +fn main() { + let target = env::var("TARGET").unwrap(); + let host = env::var("HOST").unwrap(); + + let is_apple = host.contains("apple") && target.contains("apple"); + if is_apple { + println!("cargo:rustc-link-lib=framework=Security"); + } + let out_dir = env::var("OUT_DIR").unwrap(); + let out_path = Path::new(&out_dir).join("bindgen.rs"); + + println!("cargo:rerun-if-changed={BUNDLED_DIR}/src/sqlite3.c"); + + if cfg!(feature = "multiple-ciphers") { + println!( + "cargo:rerun-if-changed={BUNDLED_DIR}/SQLite3MultipleCiphers/build/libsqlite3mc_static.a" + ); + } + + if std::env::var("LIBSQL_DEV").is_ok() { + make_amalgation(); + build_multiple_ciphers(&out_path); + } + + let bindgen_rs_path = if cfg!(feature = "session") { + "bundled/bindings/session_bindgen.rs" + } else { + "bundled/bindings/bindgen.rs" + }; + + let dir = env!("CARGO_MANIFEST_DIR"); + std::fs::copy(format!("{dir}/{bindgen_rs_path}"), &out_path).unwrap(); + + println!("cargo:lib_dir={out_dir}"); + + if cfg!(feature = "wasmtime-bindings") && !cfg!(feature = "multiple-ciphers") { + build_bundled(&out_dir, &out_path); + } + + if cfg!(feature = "multiple-ciphers") { + copy_multiple_ciphers(&out_dir, &out_path); + return; + } + + build_bundled(&out_dir, &out_path); +} + +fn make_amalgation() { + let flags = ["-DSQLITE_ENABLE_COLUMN_METADATA=1"]; + + Command::new("make") + .current_dir(SQLITE_DIR) + .arg("clean") + .output() + .unwrap(); + + Command::new("./configure") + .current_dir(SQLITE_DIR) + .env("CFLAGS", flags.join(" ")) + .output() + .unwrap(); + Command::new("make") + .current_dir(SQLITE_DIR) + .output() + .unwrap(); + + std::fs::copy( + (SQLITE_DIR.as_ref() as &Path).join("sqlite3.c"), + (BUNDLED_DIR.as_ref() as &Path).join("src/sqlite3.c"), + ) + .unwrap(); + std::fs::copy( + (SQLITE_DIR.as_ref() as &Path).join("sqlite3.h"), + (BUNDLED_DIR.as_ref() as &Path).join("src/sqlite3.h"), + ) + .unwrap(); +} + +pub fn build_bundled(out_dir: &str, out_path: &Path) { + let bindgen_rs_path = if cfg!(feature = "session") { + "bundled/bindings/session_bindgen.rs" + } else { + "bundled/bindings/bindgen.rs" + }; + + if std::env::var("LIBSQL_DEV").is_ok() { + let header = HeaderLocation::FromPath(format!("{BUNDLED_DIR}/src/sqlite3.h")); + bindings::write_to_out_dir(header, bindgen_rs_path.as_ref()); + } + + let dir = env!("CARGO_MANIFEST_DIR"); + std::fs::copy(format!("{dir}/{bindgen_rs_path}"), out_path).unwrap(); + + let mut cfg = cc::Build::new(); + cfg.file(format!("{BUNDLED_DIR}/src/sqlite3.c")) + .flag("-std=c11") + .flag("-DSQLITE_CORE") + .flag("-DSQLITE_DEFAULT_FOREIGN_KEYS=1") + .flag("-DSQLITE_ENABLE_API_ARMOR") + .flag("-DSQLITE_ENABLE_COLUMN_METADATA") + .flag("-DSQLITE_ENABLE_DBSTAT_VTAB") + .flag("-DSQLITE_ENABLE_FTS3") + .flag("-DSQLITE_ENABLE_FTS3_PARENTHESIS") + .flag("-DSQLITE_ENABLE_FTS5") + .flag("-DSQLITE_ENABLE_JSON1") + .flag("-DSQLITE_ENABLE_LOAD_EXTENSION=1") + .flag("-DSQLITE_ENABLE_MEMORY_MANAGEMENT") + .flag("-DSQLITE_ENABLE_RTREE") + .flag("-DSQLITE_ENABLE_STAT2") + .flag("-DSQLITE_ENABLE_STAT4") + .flag("-DSQLITE_SOUNDEX") + .flag("-DSQLITE_THREADSAFE=1") + .flag("-DSQLITE_USE_URI") + .flag("-DHAVE_USLEEP=1") + .flag("-D_POSIX_THREAD_SAFE_FUNCTIONS") // cross compile with MinGW + .warnings(false); + + if cfg!(feature = "wasmtime-bindings") { + cfg.flag("-DLIBSQL_ENABLE_WASM_RUNTIME=1"); + } + + if cfg!(feature = "bundled-sqlcipher") { + cfg.flag("-DSQLITE_HAS_CODEC").flag("-DSQLITE_TEMP_STORE=2"); + + let target = env::var("TARGET").unwrap(); + let host = env::var("HOST").unwrap(); + + let is_windows = host.contains("windows") && target.contains("windows"); + let is_apple = host.contains("apple") && target.contains("apple"); + + let lib_dir = env("OPENSSL_LIB_DIR").map(PathBuf::from); + let inc_dir = env("OPENSSL_INCLUDE_DIR").map(PathBuf::from); + let mut use_openssl = false; + + let (lib_dir, inc_dir) = match (lib_dir, inc_dir) { + (Some(lib_dir), Some(inc_dir)) => { + use_openssl = true; + (lib_dir, inc_dir) + } + (lib_dir, inc_dir) => match find_openssl_dir(&host, &target) { + None => { + if is_windows && !cfg!(feature = "bundled-sqlcipher-vendored-openssl") { + panic!("Missing environment variable OPENSSL_DIR or OPENSSL_DIR is not set") + } else { + (PathBuf::new(), PathBuf::new()) + } + } + Some(openssl_dir) => { + let lib_dir = lib_dir.unwrap_or_else(|| openssl_dir.join("lib")); + let inc_dir = inc_dir.unwrap_or_else(|| openssl_dir.join("include")); + + assert!( + Path::new(&lib_dir).exists(), + "OpenSSL library directory does not exist: {}", + lib_dir.to_string_lossy() + ); + + if !Path::new(&inc_dir).exists() { + panic!( + "OpenSSL include directory does not exist: {}", + inc_dir.to_string_lossy() + ); + } + + use_openssl = true; + (lib_dir, inc_dir) + } + }, + }; + + if cfg!(feature = "bundled-sqlcipher-vendored-openssl") { + cfg.include(env::var("DEP_OPENSSL_INCLUDE").unwrap()); + // cargo will resolve downstream to the static lib in + // openssl-sys + } else if use_openssl { + cfg.include(inc_dir.to_string_lossy().as_ref()); + let lib_name = if is_windows { "libcrypto" } else { "crypto" }; + println!("cargo:rustc-link-lib=dylib={}", lib_name); + println!("cargo:rustc-link-search={}", lib_dir.to_string_lossy()); + } else if is_apple { + cfg.flag("-DSQLCIPHER_CRYPTO_CC"); + println!("cargo:rustc-link-lib=framework=Security"); + println!("cargo:rustc-link-lib=framework=CoreFoundation"); + } else { + // branch not taken on Windows, just `crypto` is fine. + println!("cargo:rustc-link-lib=dylib=crypto"); + } + } + + if cfg!(feature = "with-asan") { + cfg.flag("-fsanitize=address"); + } + + // Target wasm32-wasi can't compile the default VFS + if env::var("TARGET").map_or(false, |v| v == "wasm32-wasi") { + cfg.flag("-DSQLITE_OS_OTHER") + // https://github.com/rust-lang/rust/issues/74393 + .flag("-DLONGDOUBLE_TYPE=double"); + if cfg!(feature = "wasm32-wasi-vfs") { + cfg.file("sqlite3/wasm32-wasi-vfs.c"); + } + } + if cfg!(feature = "unlock_notify") { + cfg.flag("-DSQLITE_ENABLE_UNLOCK_NOTIFY"); + } + if cfg!(feature = "preupdate_hook") { + cfg.flag("-DSQLITE_ENABLE_PREUPDATE_HOOK"); + } + if cfg!(feature = "session") { + cfg.flag("-DSQLITE_ENABLE_SESSION"); + } + + if let Ok(limit) = env::var("SQLITE_MAX_VARIABLE_NUMBER") { + cfg.flag(&format!("-DSQLITE_MAX_VARIABLE_NUMBER={limit}")); + } + println!("cargo:rerun-if-env-changed=SQLITE_MAX_VARIABLE_NUMBER"); + + if let Ok(limit) = env::var("SQLITE_MAX_EXPR_DEPTH") { + cfg.flag(&format!("-DSQLITE_MAX_EXPR_DEPTH={limit}")); + } + println!("cargo:rerun-if-env-changed=SQLITE_MAX_EXPR_DEPTH"); + + if let Ok(limit) = env::var("SQLITE_MAX_COLUMN") { + cfg.flag(&format!("-DSQLITE_MAX_COLUMN={limit}")); + } + println!("cargo:rerun-if-env-changed=SQLITE_MAX_COLUMN"); + + if let Ok(extras) = env::var("LIBSQLITE3_FLAGS") { + for extra in extras.split_whitespace() { + if extra.starts_with("-D") || extra.starts_with("-U") { + cfg.flag(extra); + } else if extra.starts_with("SQLITE_") { + cfg.flag(&format!("-D{extra}")); + } else { + panic!("Don't understand {} in LIBSQLITE3_FLAGS", extra); + } + } + } + println!("cargo:rerun-if-env-changed=LIBSQLITE3_FLAGS"); + + cfg.compile(LIB_NAME); + + println!("cargo:lib_dir={out_dir}"); +} + +fn copy_multiple_ciphers(out_dir: &str, out_path: &Path) { + let dylib = format!("{BUNDLED_DIR}/SQLite3MultipleCiphers/build/libsqlite3mc_static.a"); + if !Path::new(&dylib).exists() { + build_multiple_ciphers(out_path); + } + + std::fs::copy(dylib, format!("{out_dir}/libsqlite3mc.a")).unwrap(); + println!("cargo:rustc-link-lib=static=sqlite3mc"); + println!("cargo:rustc-link-search={out_dir}"); +} + +fn build_multiple_ciphers(out_path: &Path) { + let bindgen_rs_path = if cfg!(feature = "session") { + "bundled/bindings/session_bindgen.rs" + } else { + "bundled/bindings/bindgen.rs" + }; + if std::env::var("LIBSQL_DEV").is_ok() { + let header = HeaderLocation::FromPath(format!("{BUNDLED_DIR}/src/sqlite3.h")); + bindings::write_to_out_dir(header, bindgen_rs_path.as_ref()); + } + let dir = env!("CARGO_MANIFEST_DIR"); + std::fs::copy(format!("{dir}/{bindgen_rs_path}"), out_path).unwrap(); + + std::fs::copy( + (BUNDLED_DIR.as_ref() as &Path) + .join("src") + .join("sqlite3.c"), + (BUNDLED_DIR.as_ref() as &Path) + .join("SQLite3MultipleCiphers") + .join("src") + .join("sqlite3.c"), + ) + .unwrap(); + + let bundled_dir = fs::canonicalize(BUNDLED_DIR).unwrap(); + + let build_dir = bundled_dir.join("SQLite3MultipleCiphers").join("build"); + let _ = fs::remove_dir_all(build_dir.clone()); + fs::create_dir_all(build_dir.clone()).unwrap(); + + let mut cmake_opts: Vec<&str> = vec![]; + + let cargo_build_target = env::var("CARGO_BUILD_TARGET").unwrap_or_default(); + let cross_cc_var_name = format!("CC_{}", cargo_build_target.replace("-", "_")); + let cross_cc = env::var(&cross_cc_var_name).ok(); + + let cross_cxx_var_name = format!("CXX_{}", cargo_build_target.replace("-", "_")); + let cross_cxx = env::var(&cross_cxx_var_name).ok(); + + let toolchain_path = build_dir.join("toolchain.cmake"); + let cmake_toolchain_opt = format!("-DCMAKE_TOOLCHAIN_FILE=toolchain.cmake"); + + let mut toolchain_file = OpenOptions::new() + .create(true) + .write(true) + .append(true) + .open(toolchain_path.clone()) + .unwrap(); + + if let Some(ref cc) = cross_cc { + if cc.contains("aarch64") && cc.contains("linux") { + cmake_opts.push(&cmake_toolchain_opt); + writeln!(toolchain_file, "set(CMAKE_SYSTEM_NAME \"Linux\")").unwrap(); + writeln!(toolchain_file, "set(CMAKE_SYSTEM_PROCESSOR \"arm64\")").unwrap(); + } + } + if let Some(cc) = cross_cc { + writeln!(toolchain_file, "set(CMAKE_C_COMPILER {})", cc).unwrap(); + } + if let Some(cxx) = cross_cxx { + writeln!(toolchain_file, "set(CMAKE_CXX_COMPILER {})", cxx).unwrap(); + } + + cmake_opts.push("-DCMAKE_BUILD_TYPE=Release"); + cmake_opts.push("-DSQLITE3MC_STATIC=ON"); + cmake_opts.push("-DCODEC_TYPE=AES256"); + cmake_opts.push("-DSQLITE3MC_BUILD_SHELL=OFF"); + cmake_opts.push("-DSQLITE_SHELL_IS_UTF8=OFF"); + cmake_opts.push("-DSQLITE_USER_AUTHENTICATION=OFF"); + cmake_opts.push("-DSQLITE_SECURE_DELETE=OFF"); + cmake_opts.push("-DSQLITE_ENABLE_COLUMN_METADATA=ON"); + cmake_opts.push("-DSQLITE_USE_URI=ON"); + cmake_opts.push("-DCMAKE_POSITION_INDEPENDENT_CODE=ON"); + + let mut cmake = Command::new("cmake"); + cmake.current_dir("bundled/SQLite3MultipleCiphers/build"); + cmake.args(cmake_opts.clone()); + cmake.arg(".."); + if cfg!(feature = "wasmtime-bindings") { + cmake.arg("-DLIBSQL_ENABLE_WASM_RUNTIME=1"); + } + if cfg!(feature = "session") { + cmake.arg("-DSQLITE_ENABLE_PREUPDATE_HOOK=ON"); + cmake.arg("-DSQLITE_ENABLE_SESSION=ON"); + } + println!("Running `cmake` with options: {}", cmake_opts.join(" ")); + let status = cmake.status().unwrap(); + if !status.success() { + panic!("Failed to run cmake with options: {}", cmake_opts.join(" ")); + } + + let mut make = Command::new("cmake"); + make.current_dir("bundled/SQLite3MultipleCiphers/build"); + make.args(&["--build", "."]); + make.args(&["--config", "Release"]); + if !make.status().unwrap().success() { + panic!("Failed to run make"); + } + // The `msbuild` tool puts the output in a different place so let's move it. + if Path::exists(&build_dir.join("Release/sqlite3mc_static.lib")) { + fs::rename( + build_dir.join("Release/sqlite3mc_static.lib"), + build_dir.join("libsqlite3mc_static.a"), + ) + .unwrap(); + } +} + +fn env(name: &str) -> Option { + let prefix = env::var("TARGET").unwrap().to_uppercase().replace('-', "_"); + let prefixed = format!("{prefix}_{name}"); + let var = env::var_os(prefixed); + + match var { + None => env::var_os(name), + _ => var, + } +} + +fn find_openssl_dir(_host: &str, _target: &str) -> Option { + let openssl_dir = env("OPENSSL_DIR"); + openssl_dir.map(PathBuf::from) +} + +fn env_prefix() -> &'static str { + if cfg!(any(feature = "sqlcipher", feature = "bundled-sqlcipher")) { + "SQLCIPHER" + } else { + "SQLITE3" + } +} + +pub enum HeaderLocation { + FromEnvironment, + Wrapper, + FromPath(String), +} + +impl From for String { + fn from(header: HeaderLocation) -> String { + match header { + HeaderLocation::FromEnvironment => { + let prefix = env_prefix(); + let mut header = env::var(format!("{prefix}_INCLUDE_DIR")).unwrap_or_else(|_| { + panic!( + "{}_INCLUDE_DIR must be set if {}_LIB_DIR is set", + prefix, prefix + ) + }); + header.push_str("/sqlite3.h"); + header + } + HeaderLocation::Wrapper => "wrapper.h".into(), + HeaderLocation::FromPath(path) => path, + } + } +} + +mod bindings { + use super::HeaderLocation; + use bindgen::callbacks::{IntKind, ParseCallbacks}; + + use std::fs::OpenOptions; + use std::io::Write; + use std::path::Path; + + #[derive(Debug)] + struct SqliteTypeChooser; + + impl ParseCallbacks for SqliteTypeChooser { + fn int_macro(&self, _name: &str, value: i64) -> Option { + if value >= i32::MIN as i64 && value <= i32::MAX as i64 { + Some(IntKind::I32) + } else { + None + } + } + fn item_name(&self, original_item_name: &str) -> Option { + original_item_name + .strip_prefix("sqlite3_index_info_") + .map(|s| s.to_owned()) + } + } + + // Are we generating the bundled bindings? Used to avoid emitting things + // that would be problematic in bundled builds. This env var is set by + // `upgrade.sh`. + fn generating_bundled_bindings() -> bool { + // Hacky way to know if we're generating the bundled bindings + println!("cargo:rerun-if-env-changed=LIBSQLITE3_SYS_BUNDLING"); + match std::env::var("LIBSQLITE3_SYS_BUNDLING") { + Ok(v) => v != "0", + Err(_) => false, + } + } + + pub fn write_to_out_dir(header: HeaderLocation, out_path: &Path) { + let header: String = header.into(); + let mut output = Vec::new(); + let mut bindings = bindgen::builder() + .trust_clang_mangling(false) + .header(header.clone()) + .parse_callbacks(Box::new(SqliteTypeChooser)) + .blocklist_function("sqlite3_auto_extension") + .raw_line( + r#"extern "C" { + pub fn sqlite3_auto_extension( + xEntryPoint: ::std::option::Option< + unsafe extern "C" fn( + db: *mut sqlite3, + pzErrMsg: *mut *const ::std::os::raw::c_char, + pThunk: *const sqlite3_api_routines, + ) -> ::std::os::raw::c_int, + >, + ) -> ::std::os::raw::c_int; +}"#, + ) + .blocklist_function("sqlite3_cancel_auto_extension") + .raw_line( + r#"extern "C" { + pub fn sqlite3_cancel_auto_extension( + xEntryPoint: ::std::option::Option< + unsafe extern "C" fn( + db: *mut sqlite3, + pzErrMsg: *mut *const ::std::os::raw::c_char, + pThunk: *const sqlite3_api_routines, + ) -> ::std::os::raw::c_int, + >, + ) -> ::std::os::raw::c_int; +}"#, + ); + + if cfg!(any(feature = "sqlcipher", feature = "bundled-sqlcipher")) { + bindings = bindings.clang_arg("-DSQLITE_HAS_CODEC"); + } + if cfg!(feature = "unlock_notify") { + bindings = bindings.clang_arg("-DSQLITE_ENABLE_UNLOCK_NOTIFY"); + } + if cfg!(feature = "preupdate_hook") { + bindings = bindings.clang_arg("-DSQLITE_ENABLE_PREUPDATE_HOOK"); + } + if cfg!(feature = "session") { + bindings = bindings.clang_arg("-DSQLITE_ENABLE_SESSION"); + } + + // When cross compiling unless effort is taken to fix the issue, bindgen + // will find the wrong headers. There's only one header included by the + // amalgamated `sqlite.h`: `stdarg.h`. + // + // Thankfully, there's almost no case where rust code needs to use + // functions taking `va_list` (It's nearly impossible to get a `va_list` + // in Rust unless you get passed it by C code for some reason). + // + // Arguably, we should never be including these, but we include them for + // the cases where they aren't totally broken... + let target_arch = std::env::var("TARGET").unwrap(); + let host_arch = std::env::var("HOST").unwrap(); + let is_cross_compiling = target_arch != host_arch; + + // Note that when generating the bundled file, we're essentially always + // cross compiling. + if generating_bundled_bindings() || is_cross_compiling { + // Get rid of va_list, as it's not + bindings = bindings + .blocklist_function("sqlite3_vmprintf") + .blocklist_function("sqlite3_vsnprintf") + .blocklist_function("sqlite3_str_vappendf") + .blocklist_type("va_list") + .blocklist_type("__builtin_va_list") + .blocklist_type("__gnuc_va_list") + .blocklist_type("__va_list_tag") + .blocklist_item("__GNUC_VA_LIST"); + } + + bindings + .layout_tests(false) + .generate() + .unwrap_or_else(|_| panic!("could not run bindgen on header {}", header)) + .write(Box::new(&mut output)) + .expect("could not write output of bindgen"); + let mut output = String::from_utf8(output).expect("bindgen output was not UTF-8?!"); + + // rusqlite's functions feature ors in the SQLITE_DETERMINISTIC flag when it + // can. This flag was added in SQLite 3.8.3, but oring it in in prior + // versions of SQLite is harmless. We don't want to not build just + // because this flag is missing (e.g., if we're linking against + // SQLite 3.7.x), so append the flag manually if it isn't present in bindgen's + // output. + if !output.contains("pub const SQLITE_DETERMINISTIC") { + output.push_str("\npub const SQLITE_DETERMINISTIC: i32 = 2048;\n"); + } + + let mut file = OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open(out_path) + .unwrap_or_else(|_| panic!("Could not write to {:?}", out_path)); + + file.write_all(output.as_bytes()) + .unwrap_or_else(|_| panic!("Could not write to {:?}", out_path)); + } +} From ecf2389ae3a84a38a1b7fcf59c18b0cf613b61c3 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 1 Mar 2024 11:10:09 -0800 Subject: [PATCH 03/63] fixed temp change --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index 836a7f0d34..bc55503ce3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "bottomless", "bottomless-cli", "libsql-replication", + "libsql-ffi", "vendored/rusqlite", "vendored/sqlite3-parser", From 9521d7297edc0fc3714a0f3179487323dd890e55 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 1 Mar 2024 12:56:27 -0800 Subject: [PATCH 04/63] Merge branch 'main' into jw/changing-how-jwt-is-passed-around --- libsql-ffi/build.rs | 2 +- libsql-server/src/auth/parsers.rs | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/libsql-ffi/build.rs b/libsql-ffi/build.rs index fd631aabfe..a1fc7caedc 100644 --- a/libsql-ffi/build.rs +++ b/libsql-ffi/build.rs @@ -6,7 +6,7 @@ use std::path::{Path, PathBuf}; use std::process::Command; const LIB_NAME: &str = "libsql"; -const BUNDLED_DIR: &str = "bundled"; +const BUNDLED_DIR: &str = "/Users/julian/src/github.com/Shopify/libsql/libsql-ffi/bundled"; const SQLITE_DIR: &str = "../libsql-sqlite3"; fn main() { diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index b89f130118..fbb72f7fba 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -39,14 +39,17 @@ pub fn parse_jwt_key(data: &str) -> Result { pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { // todo print metadata // tracing::trace!() + + return ""; let header = metadata.get(GRPC_AUTH_HEADER) .ok_or(tonic::Status::new(tonic::Code::Unauthenticated,""))?; + let header_str = header.to_str().context("Auth should be ASCII") .map_err(|err| tonic::Status::new(tonic::Code::InvalidArgument, ""))?; - auth_string_to_auth_context(header_str).context(format!("Failed parse grpc auth: {header_str}")) - .map_err(|err| tonic::Status::new(tonic::Code::InvalidArgument, ""))?; + return auth_string_to_auth_context(header_str).context(format!("Failed parse grpc auth: {header_str}")) + .map_err(|err| tonic::Status::new(tonic::Code::InvalidArgument, "")); } // todo this should be a constructor or a factory associates iwth userauthcontext From 83e91ffb4666926fb9e71abcb5f76c4bd54f61a1 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 1 Mar 2024 14:09:56 -0800 Subject: [PATCH 05/63] fixed failing tests --- libsql-server/src/auth/parsers.rs | 5 +---- libsql-server/src/rpc/proxy.rs | 4 ++-- libsql-server/src/rpc/replication_log.rs | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index fbb72f7fba..4c0c648d39 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -37,10 +37,7 @@ pub fn parse_jwt_key(data: &str) -> Result { } pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { - // todo print metadata - // tracing::trace!() - - return ""; +// todo fix error messages and chaining let header = metadata.get(GRPC_AUTH_HEADER) .ok_or(tonic::Status::new(tonic::Code::Unauthenticated,""))?; diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index 01539e802a..a1b0079894 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -334,9 +334,9 @@ impl ProxyService { )))?, }; - let context = parse_grpc_auth_header(req.metadata())?; - + let auth = if let Some(auth) = auth { + let context = parse_grpc_auth_header(req.metadata())?; auth.authenticate(context)? } else { Authenticated::from_proxy_grpc_request(req)? diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 492b8b46ad..737ca77832 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -93,8 +93,8 @@ impl ReplicationLogService { )))?, }; - let user_credential = parse_grpc_auth_header(req.metadata())?; if let Some(auth) = auth { + let user_credential = parse_grpc_auth_header(req.metadata())?; auth.authenticate(user_credential)?; } From b8bb476652aa38ca0af7606e79e6db5eefc94051 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 4 Mar 2024 08:47:18 -0800 Subject: [PATCH 06/63] next iteration of unit test fixing --- libsql-server/src/http/user/mod.rs | 29 ++++++++++++++---------- libsql-server/src/rpc/replica_proxy.rs | 13 ++++------- libsql-server/src/rpc/replication_log.rs | 3 ++- libsql-server/tests/hrana/batch.rs | 2 +- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 661335cd7a..0d24f980fd 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -30,7 +30,7 @@ use tonic::transport::Server; use tower_http::{compression::CompressionLayer, cors}; use crate::auth::parsers::auth_string_to_auth_context; -use crate::auth::{Auth, Authenticated, Jwt}; +use crate::auth::{Auth, Authenticated, Jwt, UserAuthContext}; use crate::connection::{Connection, RequestContext}; use crate::error::Error; use crate::hrana; @@ -460,23 +460,28 @@ impl FromRequestParts for Authenticated { state.disable_default_namespace, state.disable_namespaces, )?; - +// todo get rid of duplication - this and replication_log.rs let namespace_jwt_key = state .namespaces .with(ns.clone(), |ns| ns.jwt_key()) .await??; - let header = parts.headers.get(hyper::header::AUTHORIZATION).context("auth header not found")?; - let header_str = header.to_str().context("non ASCII auth token")?; - let context = auth_string_to_auth_context(header_str).context("auth header parsing failed")?; - + let auth = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; - - Ok(auth) + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + // todo julian chain these three instead of short circuit + // not all auth strategies need those + let context = parts.headers + .get(hyper::header::AUTHORIZATION).context("auth header not found") + .and_then(|h| h.to_str().context("non ascii auth token")) + .and_then(|t| auth_string_to_auth_context(t)) + .unwrap_or(UserAuthContext{scheme: None, token: None}); + + let authenticated = auth.authenticate(context)?; + Ok(authenticated) } } diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 2cb84e6e1a..e8a879f490 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -8,7 +8,7 @@ use tonic::{transport::Channel, Request, Status}; use crate::auth::parsers::parse_grpc_auth_header; -use crate::auth::{user_auth_strategies::UserAuthContext, Auth, Jwt, UserAuthStrategy}; +use crate::auth::{Auth, Jwt, UserAuthStrategy}; use crate::namespace::NamespaceStore; pub struct ReplicaProxyService { @@ -48,17 +48,12 @@ impl ReplicaProxyService { match namespace_jwt_key { Ok(Ok(Some(key))) => { - let authenticated = - Jwt::new(key).authenticate(auth_context)?; + let authenticated = Jwt::new(key).authenticate(auth_context)?; authenticated.upgrade_grpc_request(req); - Ok(()) } - Ok(Ok(None)) => { - let authenticated = self - .user_auth_strategy - .authenticate(auth_context)?; - + Ok(Ok(None)) => { // non jwt auth, we don't know if context matches it + let authenticated = self.user_auth_strategy.authenticate(auth_context)?; authenticated.upgrade_grpc_request(req); Ok(()) } diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 737ca77832..a7d7fc51b2 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -71,12 +71,13 @@ impl ReplicationLogService { req: &tonic::Request, namespace: NamespaceName, ) -> Result<(), Status> { + + // todo duplicate code let namespace_jwt_key = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) .await; - let auth = match namespace_jwt_key { Ok(Ok(Some(key))) => Some(Auth::new(Jwt::new(key))), Ok(Ok(None)) => self.user_auth_strategy.clone(), diff --git a/libsql-server/tests/hrana/batch.rs b/libsql-server/tests/hrana/batch.rs index 8fe23704c3..ec449463fe 100644 --- a/libsql-server/tests/hrana/batch.rs +++ b/libsql-server/tests/hrana/batch.rs @@ -5,7 +5,7 @@ use libsql_server::hrana_proto::{Batch, BatchStep, Stmt}; use crate::common::http::Client; use crate::common::net::TurmoilConnector; -#[test] +#[test] fn sample_request() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); From e322b1a1a015a96973124aa4a523fbaa992f55b6 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 4 Mar 2024 10:09:16 -0800 Subject: [PATCH 07/63] fixed remaining unit tests --- libsql-server/src/http/user/extract.rs | 21 ++++++++++++--------- libsql-server/src/http/user/mod.rs | 19 +++++++++---------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 60c480cb96..4eacbb93d3 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -2,7 +2,7 @@ use anyhow::Context; use axum::extract::FromRequestParts; use crate::{ - auth::{parsers::auth_string_to_auth_context, Jwt, Auth}, + auth::{parsers::auth_string_to_auth_context, Auth, Jwt, UserAuthContext}, connection::RequestContext, }; @@ -29,20 +29,23 @@ impl FromRequestParts for RequestContext { .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - let header = parts.headers.get(hyper::header::AUTHORIZATION).context("auth header not found")?; - let header_str = header.to_str().context("non ASCII auth token")?; - let context = auth_string_to_auth_context(header_str).context("auth header parsing failed")?; + // todo think how to decide if the particular auth stragegy even needs a context? + // we need it to understand whether to treat absence of auth header as error or not + let context = parts.headers + .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking + .and_then(|h| h.to_str().context("non ascii auth token")) + .and_then(|t| auth_string_to_auth_context(t)) + .unwrap_or(UserAuthContext{scheme: None, token: None}); - let auth = namespace_jwt_key + + let authenticated = namespace_jwt_key .map(Jwt::new) .map(Auth::new) .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; - - // end todo + .authenticate(context)?; Ok(Self::new( - auth, + authenticated, namespace, state.namespaces.meta_store().clone(), )) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 0d24f980fd..855e2225af 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -466,21 +466,20 @@ impl FromRequestParts for Authenticated { .with(ns.clone(), |ns| ns.jwt_key()) .await??; - - let auth = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()); - - // todo julian chain these three instead of short circuit - // not all auth strategies need those + // todo think how to decide if the particular auth stragegy even needs a context? + // we need it to understand whether to treat absence of auth header as error or not let context = parts.headers - .get(hyper::header::AUTHORIZATION).context("auth header not found") + .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking .and_then(|h| h.to_str().context("non ascii auth token")) .and_then(|t| auth_string_to_auth_context(t)) .unwrap_or(UserAuthContext{scheme: None, token: None}); - let authenticated = auth.authenticate(context)?; + + let authenticated = namespace_jwt_key + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()) + .authenticate(context)?; Ok(authenticated) } } From 6b517ea42ae30a13d82b60e29ffa8c671990f9b1 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 4 Mar 2024 10:20:30 -0800 Subject: [PATCH 08/63] reverted debug --- libsql-ffi/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql-ffi/build.rs b/libsql-ffi/build.rs index a1fc7caedc..fd631aabfe 100644 --- a/libsql-ffi/build.rs +++ b/libsql-ffi/build.rs @@ -6,7 +6,7 @@ use std::path::{Path, PathBuf}; use std::process::Command; const LIB_NAME: &str = "libsql"; -const BUNDLED_DIR: &str = "/Users/julian/src/github.com/Shopify/libsql/libsql-ffi/bundled"; +const BUNDLED_DIR: &str = "bundled"; const SQLITE_DIR: &str = "../libsql-sqlite3"; fn main() { From e3bd67c27300f08aafd4ab1a35c2743642ce2834 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 4 Mar 2024 11:49:25 -0800 Subject: [PATCH 09/63] removed accidentally added file --- .../tests__hrana__batch__sample_request.snap.new | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new diff --git a/libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new b/libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new deleted file mode 100644 index 6ebe0bb0ec..0000000000 --- a/libsql-server/tests/hrana/snapshots/tests__hrana__batch__sample_request.snap.new +++ /dev/null @@ -1,8 +0,0 @@ ---- -source: libsql-server/tests/hrana/batch.rs -assertion_line: 32 -expression: resp.json_value().await.unwrap() ---- -{ - "error": "Internal Error: `auth header not found`" -} From f038f391850f1c28aeb5398516559b66d7eb2120 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 4 Mar 2024 16:51:20 -0800 Subject: [PATCH 10/63] moved str to userauthcontext conversion to from trait --- libsql-server/src/auth/parsers.rs | 29 ++++--------------- .../src/auth/user_auth_strategies/mod.rs | 9 ++++++ libsql-server/src/http/user/extract.rs | 4 +-- libsql-server/src/http/user/mod.rs | 3 +- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 4c0c648d39..9c25480d2e 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -36,33 +36,16 @@ pub fn parse_jwt_key(data: &str) -> Result { } } +// this could be try_from trait pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { -// todo fix error messages and chaining - let header = metadata.get(GRPC_AUTH_HEADER) - .ok_or(tonic::Status::new(tonic::Code::Unauthenticated,""))?; - - let header_str = header.to_str().context("Auth should be ASCII") - .map_err(|err| tonic::Status::new(tonic::Code::InvalidArgument, ""))?; - - return auth_string_to_auth_context(header_str).context(format!("Failed parse grpc auth: {header_str}")) - .map_err(|err| tonic::Status::new(tonic::Code::InvalidArgument, "")); + return metadata + .get(GRPC_AUTH_HEADER).context("auth header not found") + .and_then(|h| h.to_str().context("non ascii auth token")) + .and_then(|t| t.try_into()) + .map_err(|e| tonic::Status::new(tonic::Code::InvalidArgument,format!("Failed parse grpc auth: {e}"))); } -// todo this should be a constructor or a factory associates iwth userauthcontext -pub fn auth_string_to_auth_context( - auth_string: &str, -) -> Result { - - let(scheme, token) = auth_string.split_once(' ').context("malformed auth header string")?; - - Ok(UserAuthContext{ - scheme: Some(scheme.into()), - token: Some(token.into()), - }) -} - - pub fn parse_http_auth_header<'a>( expected_scheme: &str, auth_header: &'a Option, diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 9fb228c82f..9bf190788d 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -2,6 +2,7 @@ pub mod disabled; pub mod http_basic; pub mod jwt; +use anyhow::Context; pub use disabled::*; pub use http_basic::*; pub use jwt::*; @@ -13,6 +14,14 @@ pub struct UserAuthContext { pub token: Option, // token might not be required in some cases } +impl TryFrom<&str> for UserAuthContext { + type Error = anyhow::Error; + + fn try_from(auth_string: &str) -> Result { + let (scheme, token) = auth_string.split_once(' ').context("malformed auth string`")?; + Ok(UserAuthContext{scheme: Some(scheme.into()), token: Some(token.into())}) + } +} pub trait UserAuthStrategy: Sync + Send { fn authenticate(&self, context: UserAuthContext) -> Result; } diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 4eacbb93d3..5f0f7d7a36 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -2,7 +2,7 @@ use anyhow::Context; use axum::extract::FromRequestParts; use crate::{ - auth::{parsers::auth_string_to_auth_context, Auth, Jwt, UserAuthContext}, + auth::{Auth, Jwt, UserAuthContext}, connection::RequestContext, }; @@ -34,7 +34,7 @@ impl FromRequestParts for RequestContext { let context = parts.headers .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking .and_then(|h| h.to_str().context("non ascii auth token")) - .and_then(|t| auth_string_to_auth_context(t)) + .and_then(|t| t.try_into()) .unwrap_or(UserAuthContext{scheme: None, token: None}); diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 855e2225af..b5b155704f 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -29,7 +29,6 @@ use tonic::transport::Server; use tower_http::{compression::CompressionLayer, cors}; -use crate::auth::parsers::auth_string_to_auth_context; use crate::auth::{Auth, Authenticated, Jwt, UserAuthContext}; use crate::connection::{Connection, RequestContext}; use crate::error::Error; @@ -471,7 +470,7 @@ impl FromRequestParts for Authenticated { let context = parts.headers .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking .and_then(|h| h.to_str().context("non ascii auth token")) - .and_then(|t| auth_string_to_auth_context(t)) + .and_then(|t| t.try_into()) .unwrap_or(UserAuthContext{scheme: None, token: None}); From 9bc376b8d1c9070d83d75504b891503a8b56afbc Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 4 Mar 2024 16:55:00 -0800 Subject: [PATCH 11/63] remove vague comment --- libsql-server/src/http/user/extract.rs | 2 -- libsql-server/src/http/user/mod.rs | 2 -- 2 files changed, 4 deletions(-) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 5f0f7d7a36..0ee4e45a04 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -29,8 +29,6 @@ impl FromRequestParts for RequestContext { .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - // todo think how to decide if the particular auth stragegy even needs a context? - // we need it to understand whether to treat absence of auth header as error or not let context = parts.headers .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking .and_then(|h| h.to_str().context("non ascii auth token")) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index b5b155704f..276fe047b2 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -465,8 +465,6 @@ impl FromRequestParts for Authenticated { .with(ns.clone(), |ns| ns.jwt_key()) .await??; - // todo think how to decide if the particular auth stragegy even needs a context? - // we need it to understand whether to treat absence of auth header as error or not let context = parts.headers .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking .and_then(|h| h.to_str().context("non ascii auth token")) From 37c130f81db97a83774845ec616027e16500fdc9 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Tue, 5 Mar 2024 09:07:37 -0800 Subject: [PATCH 12/63] cargo fmt --- libsql-server/src/auth/parsers.rs | 16 +++-- .../auth/user_auth_strategies/http_basic.rs | 5 +- .../src/auth/user_auth_strategies/mod.rs | 9 ++- libsql-server/src/hrana/ws/session.rs | 10 ++- libsql-server/src/http/user/extract.rs | 25 ++++--- libsql-server/src/http/user/mod.rs | 22 +++--- libsql-server/src/rpc/proxy.rs | 1 - libsql-server/src/rpc/replica_proxy.rs | 10 ++- libsql-server/src/rpc/replication_log.rs | 3 +- libsql-server/tests/cluster/mod.rs | 50 +++++++++---- .../tests/cluster/replica_restart.rs | 18 ++++- libsql-server/tests/cluster/replication.rs | 9 ++- libsql-server/tests/embedded_replica/local.rs | 7 +- libsql-server/tests/embedded_replica/mod.rs | 24 +++++-- libsql-server/tests/hrana/batch.rs | 26 +++++-- libsql-server/tests/hrana/transaction.rs | 18 ++++- libsql-server/tests/namespaces/dumps.rs | 35 +++++++--- libsql-server/tests/namespaces/mod.rs | 21 ++++-- libsql-server/tests/standalone/attach.rs | 14 ++-- libsql-server/tests/standalone/mod.rs | 70 +++++++++++++++---- 20 files changed, 280 insertions(+), 113 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 9c25480d2e..08d2af85fc 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -37,13 +37,20 @@ pub fn parse_jwt_key(data: &str) -> Result { } // this could be try_from trait -pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { - +pub(crate) fn parse_grpc_auth_header( + metadata: &MetadataMap, +) -> Result { return metadata - .get(GRPC_AUTH_HEADER).context("auth header not found") + .get(GRPC_AUTH_HEADER) + .context("auth header not found") .and_then(|h| h.to_str().context("non ascii auth token")) .and_then(|t| t.try_into()) - .map_err(|e| tonic::Status::new(tonic::Code::InvalidArgument,format!("Failed parse grpc auth: {e}"))); + .map_err(|e| { + tonic::Status::new( + tonic::Code::InvalidArgument, + format!("Failed parse grpc auth: {e}"), + ) + }); } pub fn parse_http_auth_header<'a>( @@ -125,4 +132,3 @@ mod tests { assert_eq!(out, Some("always".to_string())); } } - diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index 782740bdb2..3dba4fca3c 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -7,11 +7,9 @@ pub struct HttpBasic { } impl UserAuthStrategy for HttpBasic { - fn authenticate(&self, context: UserAuthContext) -> Result { - tracing::trace!("executing http basic auth"); - + // NOTE: this naive comparison may leak information about the `expected_value` // using a timing attack let expected_value = self.credential.trim_end_matches('='); @@ -20,7 +18,6 @@ impl UserAuthStrategy for HttpBasic { Some(s) => s.contains(expected_value), None => expected_value.is_empty(), }; - if creds_match { return Ok(Authenticated::FullAccess); diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 9bf190788d..1026fd68a6 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -18,8 +18,13 @@ impl TryFrom<&str> for UserAuthContext { type Error = anyhow::Error; fn try_from(auth_string: &str) -> Result { - let (scheme, token) = auth_string.split_once(' ').context("malformed auth string`")?; - Ok(UserAuthContext{scheme: Some(scheme.into()), token: Some(token.into())}) + let (scheme, token) = auth_string + .split_once(' ') + .context("malformed auth string`")?; + Ok(UserAuthContext { + scheme: Some(scheme.into()), + token: Some(token.into()), + }) } } pub trait UserAuthStrategy: Sync + Send { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index a160221bf7..3192f4fb14 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -80,7 +80,10 @@ pub(super) async fn handle_initial_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()) - .authenticate(UserAuthContext { scheme: Some("Bearer".into()), token: jwt }) + .authenticate(UserAuthContext { + scheme: Some("Bearer".into()), + token: jwt, + }) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { @@ -113,7 +116,10 @@ pub(super) async fn handle_repeated_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or_else(|| server.user_auth_strategy.clone()) - .authenticate(UserAuthContext { scheme: Some("Bearer".into()), token: jwt }) + .authenticate(UserAuthContext { + scheme: Some("Bearer".into()), + token: jwt, + }) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 0ee4e45a04..350239c839 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -16,7 +16,6 @@ impl FromRequestParts for RequestContext { parts: &mut axum::http::request::Parts, state: &AppState, ) -> std::result::Result { - // start todo this block is same as the one in mod.rs let namespace = db_factory::namespace_from_headers( &parts.headers, @@ -29,18 +28,22 @@ impl FromRequestParts for RequestContext { .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - let context = parts.headers - .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking - .and_then(|h| h.to_str().context("non ascii auth token")) - .and_then(|t| t.try_into()) - .unwrap_or(UserAuthContext{scheme: None, token: None}); - + let context = parts + .headers + .get(hyper::header::AUTHORIZATION) + .context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking + .and_then(|h| h.to_str().context("non ascii auth token")) + .and_then(|t| t.try_into()) + .unwrap_or(UserAuthContext { + scheme: None, + token: None, + }); let authenticated = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()) + .authenticate(context)?; Ok(Self::new( authenticated, diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 276fe047b2..fa39385ea7 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -459,24 +459,28 @@ impl FromRequestParts for Authenticated { state.disable_default_namespace, state.disable_namespaces, )?; -// todo get rid of duplication - this and replication_log.rs + // todo get rid of duplication - this and replication_log.rs let namespace_jwt_key = state .namespaces .with(ns.clone(), |ns| ns.jwt_key()) .await??; - let context = parts.headers - .get(hyper::header::AUTHORIZATION).context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking + let context = parts + .headers + .get(hyper::header::AUTHORIZATION) + .context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking .and_then(|h| h.to_str().context("non ascii auth token")) .and_then(|t| t.try_into()) - .unwrap_or(UserAuthContext{scheme: None, token: None}); - + .unwrap_or(UserAuthContext { + scheme: None, + token: None, + }); let authenticated = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()) + .authenticate(context)?; Ok(authenticated) } } diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index a1b0079894..7a1a30f603 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -334,7 +334,6 @@ impl ProxyService { )))?, }; - let auth = if let Some(auth) = auth { let context = parse_grpc_auth_header(req.metadata())?; auth.authenticate(context)? diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index e8a879f490..61ac11dc0f 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -6,7 +6,6 @@ use libsql_replication::rpc::proxy::{ use tokio_stream::StreamExt; use tonic::{transport::Channel, Request, Status}; - use crate::auth::parsers::parse_grpc_auth_header; use crate::auth::{Auth, Jwt, UserAuthStrategy}; use crate::namespace::NamespaceStore; @@ -43,7 +42,7 @@ impl ReplicaProxyService { .with(namespace.clone(), |ns| ns.jwt_key()) .await; - //todo julian figure this out + //todo julian figure this out let auth_context = parse_grpc_auth_header(req.metadata())?; match namespace_jwt_key { @@ -52,16 +51,15 @@ impl ReplicaProxyService { authenticated.upgrade_grpc_request(req); Ok(()) } - Ok(Ok(None)) => { // non jwt auth, we don't know if context matches it + Ok(Ok(None)) => { + // non jwt auth, we don't know if context matches it let authenticated = self.user_auth_strategy.authenticate(auth_context)?; authenticated.upgrade_grpc_request(req); Ok(()) } Err(e) => match e.as_ref() { crate::error::Error::NamespaceDoesntExist(_) => { - let authenticated = self - .user_auth_strategy - .authenticate(auth_context)?; + let authenticated = self.user_auth_strategy.authenticate(auth_context)?; authenticated.upgrade_grpc_request(req); Ok(()) diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index a7d7fc51b2..a4ed521d50 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -71,7 +71,6 @@ impl ReplicationLogService { req: &tonic::Request, namespace: NamespaceName, ) -> Result<(), Status> { - // todo duplicate code let namespace_jwt_key = self .namespaces @@ -93,7 +92,7 @@ impl ReplicationLogService { e )))?, }; - + if let Some(auth) = auth { let user_credential = parse_grpc_auth_header(req.metadata())?; auth.authenticate(user_credential)?; diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index d5569d18f0..e79bac7f67 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -88,15 +88,22 @@ fn proxy_write() { make_cluster(&mut sim, 1, true); sim.client("client", async { - let db = - Database::open_remote_with_connector("http://replica0:8080", "dummy-auth", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://replica0:8080", + "dummy-auth", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; conn.execute("insert into test values (12)", ()).await?; // assert that the primary got the write - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-auth", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-auth", + TurmoilConnector, + )?; let conn = db.connect()?; let mut rows = conn.query("select count(*) from test", ()).await?; @@ -120,8 +127,11 @@ fn replica_read_write() { make_cluster(&mut sim, 1, true); sim.client("client", async { - let db = - Database::open_remote_with_connector("http://replica0:8080", "dummy-auth", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://replica0:8080", + "dummy-auth", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -145,7 +155,11 @@ fn sync_many_replica() { let mut sim = Builder::new().build(); make_cluster(&mut sim, NUM_REPLICA, true); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -212,15 +226,17 @@ fn sync_many_replica() { sim.run().unwrap(); } - #[test] fn create_namespace() { let mut sim = Builder::new().build(); make_cluster(&mut sim, 0, false); sim.client("client", async { - let db = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy-auth", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy-auth", + TurmoilConnector, + )?; let conn = db.connect()?; let Err(e) = conn.execute("create table test (x)", ()).await else { @@ -260,8 +276,12 @@ fn large_proxy_query() { make_cluster(&mut sim, 1, true); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-auth", TurmoilConnector) - .unwrap(); + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-auth", + TurmoilConnector, + ) + .unwrap(); let conn = db.connect().unwrap(); conn.execute("create table test (x)", ()).await.unwrap(); @@ -271,8 +291,12 @@ fn large_proxy_query() { .unwrap(); } - let db = Database::open_remote_with_connector("http://replica0:8080", "dummy_token", TurmoilConnector) - .unwrap(); + let db = Database::open_remote_with_connector( + "http://replica0:8080", + "dummy_token", + TurmoilConnector, + ) + .unwrap(); let conn = db.connect().unwrap(); conn.execute_batch("begin immediate; select * from test limit (4000)") diff --git a/libsql-server/tests/cluster/replica_restart.rs b/libsql-server/tests/cluster/replica_restart.rs index 9665233385..2915309056 100644 --- a/libsql-server/tests/cluster/replica_restart.rs +++ b/libsql-server/tests/cluster/replica_restart.rs @@ -94,7 +94,11 @@ fn replica_restart() { sim.client("client", async move { let http = Client::new(); - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; // insert a few valued into the primary @@ -257,7 +261,11 @@ fn primary_regenerate_log_no_replica_restart() { sim.client("client", async move { let http = Client::new(); - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; // insert a few valued into the primary @@ -460,7 +468,11 @@ fn primary_regenerate_log_with_replica_restart() { sim.client("client", async move { let http = Client::new(); - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; // insert a few valued into the primary diff --git a/libsql-server/tests/cluster/replication.rs b/libsql-server/tests/cluster/replication.rs index be05fa3468..c3c330b185 100644 --- a/libsql-server/tests/cluster/replication.rs +++ b/libsql-server/tests/cluster/replication.rs @@ -213,9 +213,12 @@ fn replica_lazy_creation() { }); sim.client("client", async move { - let db = - Database::open_remote_with_connector("http://test.replica:8080", "dummy_token", TurmoilConnector) - .unwrap(); + let db = Database::open_remote_with_connector( + "http://test.replica:8080", + "dummy_token", + TurmoilConnector, + ) + .unwrap(); let conn = db.connect().unwrap(); assert_debug_snapshot!(conn.execute("create table test (x)", ()).await.unwrap_err()); let primary_http = Client::new(); diff --git a/libsql-server/tests/embedded_replica/local.rs b/libsql-server/tests/embedded_replica/local.rs index d98a237079..28e51a2b38 100644 --- a/libsql-server/tests/embedded_replica/local.rs +++ b/libsql-server/tests/embedded_replica/local.rs @@ -33,8 +33,11 @@ fn local_sync_with_writes() { let _path = tmp_embedded_path.join("embedded"); - let primary = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let primary = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = primary.connect()?; // Do enough writes to ensure that we can force the server to write some snapshots diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index 026ba03e71..d0d35ec594 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -360,8 +360,11 @@ fn replica_primary_reset() { }); sim.client("client", async move { - let primary = - Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let primary = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = primary.connect()?; // insert a few valued into the primary @@ -521,9 +524,12 @@ fn replica_no_resync_on_restart() { sim.client("client", async { // seed database { - let db = - Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector) - .unwrap(); + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + ) + .unwrap(); let conn = db.connect().unwrap(); conn.execute("create table test (x)", ()).await.unwrap(); for _ in 0..500 { @@ -624,8 +630,12 @@ fn replicate_with_snapshots() { .post("http://primary:9090/v1/namespaces/foo/create", json!({})) .await?; - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector) - .unwrap(); + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + ) + .unwrap(); let conn = db.connect().unwrap(); conn.execute("create table test (x)", ()).await.unwrap(); // insert enough to trigger snapshot creation. diff --git a/libsql-server/tests/hrana/batch.rs b/libsql-server/tests/hrana/batch.rs index ec449463fe..47c4162592 100644 --- a/libsql-server/tests/hrana/batch.rs +++ b/libsql-server/tests/hrana/batch.rs @@ -5,7 +5,7 @@ use libsql_server::hrana_proto::{Batch, BatchStep, Stmt}; use crate::common::http::Client; use crate::common::net::TurmoilConnector; -#[test] +#[test] fn sample_request() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); @@ -42,7 +42,11 @@ fn execute_individual_statements() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table t(x text)", ()).await?; @@ -68,7 +72,11 @@ fn execute_batch() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute_batch( @@ -102,7 +110,11 @@ fn multistatement_query() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; let mut rows = conn .query("select 1 + ?; select 'abc';", params![1]) @@ -123,7 +135,11 @@ fn affected_rows_and_last_rowid() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute( diff --git a/libsql-server/tests/hrana/transaction.rs b/libsql-server/tests/hrana/transaction.rs index ab3018fe2c..1c7a0fbe90 100644 --- a/libsql-server/tests/hrana/transaction.rs +++ b/libsql-server/tests/hrana/transaction.rs @@ -8,7 +8,11 @@ fn transaction_commit_and_rollback() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; // initialize tables @@ -50,7 +54,11 @@ fn multiple_concurrent_transactions() { let mut sim = turmoil::Builder::new().build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute_batch(r#"create table t(x text);"#).await?; @@ -108,7 +116,11 @@ fn transaction_timeout() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy_token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy_token", + TurmoilConnector, + )?; let conn = db.connect()?; // initialize tables diff --git a/libsql-server/tests/namespaces/dumps.rs b/libsql-server/tests/namespaces/dumps.rs index 899b00edbd..8a1d552288 100644 --- a/libsql-server/tests/namespaces/dumps.rs +++ b/libsql-server/tests/namespaces/dumps.rs @@ -51,8 +51,11 @@ fn load_namespace_from_dump_from_url() { assert_eq!(resp.status(), 200); assert_snapshot!(resp.body_string().await.unwrap()); - let foo = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo.connect()?; let mut rows = foo_conn.query("select count(*) from test", ()).await?; assert!(matches!( @@ -120,8 +123,11 @@ fn load_namespace_from_dump_from_file() { resp.json::().await.unwrap_or_default() ); - let foo = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo.connect()?; let mut rows = foo_conn.query("select count(*) from test", ()).await?; assert!(matches!( @@ -170,8 +176,11 @@ fn load_namespace_from_no_commit() { ); // namespace doesn't exist - let foo = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo.connect()?; assert!(foo_conn .query("select count(*) from test", ()) @@ -219,8 +228,11 @@ fn load_namespace_from_no_txn() { assert_json_snapshot!(resp.json_value().await.unwrap()); // namespace doesn't exist - let foo = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo.connect()?; assert!(foo_conn .query("select count(*) from test", ()) @@ -247,8 +259,11 @@ fn export_dump() { .await?; assert_eq!(resp.status(), StatusCode::OK); - let foo = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo.connect()?; foo_conn.execute("create table test (x)", ()).await?; foo_conn.execute("insert into test values (42)", ()).await?; diff --git a/libsql-server/tests/namespaces/mod.rs b/libsql-server/tests/namespaces/mod.rs index b6cc4557dc..ebb9324b13 100644 --- a/libsql-server/tests/namespaces/mod.rs +++ b/libsql-server/tests/namespaces/mod.rs @@ -57,8 +57,11 @@ fn fork_namespace() { .post("http://primary:9090/v1/namespaces/foo/create", json!({})) .await?; - let foo = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo.connect()?; foo_conn.execute("create table test (c)", ()).await?; @@ -68,8 +71,11 @@ fn fork_namespace() { .post("http://primary:9090/v1/namespaces/foo/fork/bar", ()) .await?; - let bar = - Database::open_remote_with_connector("http://bar.primary:8080", "dummy_token", TurmoilConnector)?; + let bar = Database::open_remote_with_connector( + "http://bar.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let bar_conn = bar.connect()?; // what's in foo is in bar as well @@ -113,8 +119,11 @@ fn delete_namespace() { .post("http://primary:9090/v1/namespaces/foo/create", json!({})) .await?; - let foo = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo.connect()?; foo_conn.execute("create table test (c)", ()).await?; diff --git a/libsql-server/tests/standalone/attach.rs b/libsql-server/tests/standalone/attach.rs index 1afa79f937..2be84124e1 100644 --- a/libsql-server/tests/standalone/attach.rs +++ b/libsql-server/tests/standalone/attach.rs @@ -33,8 +33,11 @@ fn attach_no_auth() { .await .unwrap(); - let foo_db = - Database::open_remote_with_connector("http://foo.primary:8080", "dummy_token", TurmoilConnector)?; + let foo_db = Database::open_remote_with_connector( + "http://foo.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let foo_conn = foo_db.connect().unwrap(); foo_conn .execute("CREATE TABLE foo_table (x)", ()) @@ -45,8 +48,11 @@ fn attach_no_auth() { .await .unwrap(); - let bar_db = - Database::open_remote_with_connector("http://bar.primary:8080", "dummy_token", TurmoilConnector)?; + let bar_db = Database::open_remote_with_connector( + "http://bar.primary:8080", + "dummy_token", + TurmoilConnector, + )?; let bar_conn = bar_db.connect().unwrap(); bar_conn .execute("CREATE TABLE bar_table (x)", ()) diff --git a/libsql-server/tests/standalone/mod.rs b/libsql-server/tests/standalone/mod.rs index 0dae6f62c1..a7bc9e0ef7 100644 --- a/libsql-server/tests/standalone/mod.rs +++ b/libsql-server/tests/standalone/mod.rs @@ -51,7 +51,11 @@ fn basic_query() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -77,7 +81,11 @@ fn basic_metrics() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -122,8 +130,11 @@ fn primary_serializability() { sim.client("writer", { let notify = notify.clone(); async move { - let db = - Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; conn.execute("insert into test values (12)", ()).await?; @@ -136,8 +147,11 @@ fn primary_serializability() { sim.client("reader", { async move { - let db = - Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; notify.notified().await; @@ -167,8 +181,11 @@ fn execute_transaction() { sim.client("writer", { let notify = notify.clone(); async move { - let db = - Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -193,8 +210,11 @@ fn execute_transaction() { sim.client("reader", { async move { - let db = - Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; notify.notified().await; @@ -234,7 +254,11 @@ fn basic_query_fail() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -260,7 +284,11 @@ fn begin_commit() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -296,7 +324,11 @@ fn begin_rollback() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -333,7 +365,11 @@ fn is_autocommit() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; assert!(conn.is_autocommit().await); @@ -376,7 +412,11 @@ fn random_rowid() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector("http://primary:8080", "dummy-token", TurmoilConnector)?; + let db = Database::open_remote_with_connector( + "http://primary:8080", + "dummy-token", + TurmoilConnector, + )?; let conn = db.connect()?; conn.execute( From b350ce52102ed6e53ba01a991d26ee9ddd914ca1 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Tue, 5 Mar 2024 10:22:51 -0800 Subject: [PATCH 13/63] cleaned up mod reimportss --- libsql-server/src/auth/parsers.rs | 1 - libsql-server/src/auth/user_auth_strategies/mod.rs | 9 +++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 08d2af85fc..c7ff54b9c6 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -36,7 +36,6 @@ pub fn parse_jwt_key(data: &str) -> Result { } } -// this could be try_from trait pub(crate) fn parse_grpc_auth_header( metadata: &MetadataMap, ) -> Result { diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 1026fd68a6..25ec0ed692 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -2,16 +2,17 @@ pub mod disabled; pub mod http_basic; pub mod jwt; +pub use disabled::Disabled; +pub use http_basic::HttpBasic; +pub use jwt::Jwt; + use anyhow::Context; -pub use disabled::*; -pub use http_basic::*; -pub use jwt::*; use super::{AuthError, Authenticated}; pub struct UserAuthContext { pub scheme: Option, - pub token: Option, // token might not be required in some cases + pub token: Option, } impl TryFrom<&str> for UserAuthContext { From d3ccabc8be9bc71a2d0ce8c8974b36b82fc1063e Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 6 Mar 2024 10:18:45 -0800 Subject: [PATCH 14/63] refactored cryptic matching in replica_proxy --- libsql-server/src/http/user/extract.rs | 1 - libsql-server/src/rpc/replica_proxy.rs | 46 ++++++++------------------ 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 350239c839..29421b5af4 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -16,7 +16,6 @@ impl FromRequestParts for RequestContext { parts: &mut axum::http::request::Parts, state: &AppState, ) -> std::result::Result { - // start todo this block is same as the one in mod.rs let namespace = db_factory::namespace_from_headers( &parts.headers, state.disable_default_namespace, diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 61ac11dc0f..350cf699c2 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -7,7 +7,7 @@ use tokio_stream::StreamExt; use tonic::{transport::Channel, Request, Status}; use crate::auth::parsers::parse_grpc_auth_header; -use crate::auth::{Auth, Jwt, UserAuthStrategy}; +use crate::auth::{Auth, Jwt}; use crate::namespace::NamespaceStore; pub struct ReplicaProxyService { @@ -37,43 +37,25 @@ impl ReplicaProxyService { async fn do_auth(&self, req: &mut Request) -> Result<(), Status> { let namespace = super::extract_namespace(self.disable_namespaces, req)?; - let namespace_jwt_key = self + let jwt_result = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) .await; - //todo julian figure this out - let auth_context = parse_grpc_auth_header(req.metadata())?; + let namespace_jwt_key = jwt_result + .and_then(|s|s); - match namespace_jwt_key { - Ok(Ok(Some(key))) => { - let authenticated = Jwt::new(key).authenticate(auth_context)?; - authenticated.upgrade_grpc_request(req); - Ok(()) - } - Ok(Ok(None)) => { - // non jwt auth, we don't know if context matches it - let authenticated = self.user_auth_strategy.authenticate(auth_context)?; - authenticated.upgrade_grpc_request(req); - Ok(()) - } - Err(e) => match e.as_ref() { - crate::error::Error::NamespaceDoesntExist(_) => { - let authenticated = self.user_auth_strategy.authenticate(auth_context)?; - authenticated.upgrade_grpc_request(req); - Ok(()) - } - _ => Err(Status::internal(format!( - "Error fetching jwt key for a namespace: {}", - e - ))), - }, - Ok(Err(e)) => Err(Status::internal(format!( - "Error fetching jwt key for a namespace: {}", - e - ))), - } + let auth_strategy = match namespace_jwt_key { + Ok(Some(key)) => Ok(Auth::new(Jwt::new(key))), + Ok(None) | Err(crate::error::Error::NamespaceDoesntExist(_)) => Ok(self.user_auth_strategy.clone()), + Err(e) => Err(Status::internal(format!("Can't fetch jwt key for a namespace: {}", e))), + }?; + + let auth_context = parse_grpc_auth_header(req.metadata())?; + auth_strategy.authenticate(auth_context)? + .upgrade_grpc_request(req); + return Ok(()); } } From 9b353cdc111b252e718e8c13c490b3ec90ae314b Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 6 Mar 2024 10:23:52 -0800 Subject: [PATCH 15/63] marked potentially duplicate code with // todo dupe #auth --- libsql-server/src/hrana/ws/session.rs | 2 ++ libsql-server/src/http/user/extract.rs | 2 +- libsql-server/src/rpc/proxy.rs | 2 +- libsql-server/src/rpc/replica_proxy.rs | 1 + libsql-server/src/rpc/replication_log.rs | 2 +- 5 files changed, 6 insertions(+), 3 deletions(-) diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 3192f4fb14..c1bfbdafbd 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -71,6 +71,7 @@ pub(super) async fn handle_initial_hello( jwt: Option, namespace: NamespaceName, ) -> Result { + // todo dupe #auth let namespace_jwt_key = server .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) @@ -107,6 +108,7 @@ pub(super) async fn handle_repeated_hello( min_version: Version::Hrana2, }) } + // todo dupe #auth let namespace_jwt_key = server .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 29421b5af4..b450193558 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -21,7 +21,7 @@ impl FromRequestParts for RequestContext { state.disable_default_namespace, state.disable_namespaces, )?; - + // todo dupe #auth let namespace_jwt_key = state .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index 7a1a30f603..be198d6282 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -312,7 +312,7 @@ impl ProxyService { req: &mut tonic::Request, ) -> Result { let namespace = super::extract_namespace(self.disable_namespaces, req)?; - + // todo dupe #auth let namespace_jwt_key = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 350cf699c2..646a71daa6 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -37,6 +37,7 @@ impl ReplicaProxyService { async fn do_auth(&self, req: &mut Request) -> Result<(), Status> { let namespace = super::extract_namespace(self.disable_namespaces, req)?; + // todo dupe #auth let jwt_result = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index a4ed521d50..0cee882ffb 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -71,7 +71,7 @@ impl ReplicationLogService { req: &tonic::Request, namespace: NamespaceName, ) -> Result<(), Status> { - // todo duplicate code + // todo dupe #auth let namespace_jwt_key = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) From e30d5a3afec08183265910be3db8ebf800b38b1f Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 6 Mar 2024 11:57:33 -0800 Subject: [PATCH 16/63] refactored context to custome errors --- libsql-server/src/auth/errors.rs | 9 +++++++++ libsql-server/src/auth/parsers.rs | 7 +++---- libsql-server/src/auth/user_auth_strategies/mod.rs | 8 +++----- libsql-server/src/http/user/extract.rs | 9 ++++----- libsql-server/src/http/user/mod.rs | 9 ++++----- libsql-server/tests/tests.rs | 2 +- 6 files changed, 24 insertions(+), 20 deletions(-) diff --git a/libsql-server/src/auth/errors.rs b/libsql-server/src/auth/errors.rs index 3e8a414960..5275153315 100644 --- a/libsql-server/src/auth/errors.rs +++ b/libsql-server/src/auth/errors.rs @@ -22,6 +22,12 @@ pub enum AuthError { JwtExpired, #[error("The JWT is immature (not valid yet)")] JwtImmature, + #[error("Auth string does not conform to ' ' form")] + AuthStringMalformed, + #[error("Expected authorization header but none given")] + AuthHeaderNotFound, + #[error("Non-ASCII auth header")] + AuthHeaderNonAscii, #[error("Authentication failed")] Other, } @@ -39,6 +45,9 @@ impl AuthError { Self::JwtInvalid => "AUTH_JWT_INVALID", Self::JwtExpired => "AUTH_JWT_EXPIRED", Self::JwtImmature => "AUTH_JWT_IMMATURE", + Self::AuthStringMalformed => "AUTH_HEADER_MALFORMED", + Self::AuthHeaderNotFound => "AUTH_HEADER_NOT_FOUND", + Self::AuthHeaderNonAscii => "AUTH_HEADER_MALFORMED", Self::Other => "AUTH_FAILED", } } diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index c7ff54b9c6..bd126b0041 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -40,10 +40,9 @@ pub(crate) fn parse_grpc_auth_header( metadata: &MetadataMap, ) -> Result { return metadata - .get(GRPC_AUTH_HEADER) - .context("auth header not found") - .and_then(|h| h.to_str().context("non ascii auth token")) - .and_then(|t| t.try_into()) + .get(GRPC_AUTH_HEADER).ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) + .and_then(|t| t.try_into()) .map_err(|e| { tonic::Status::new( tonic::Code::InvalidArgument, diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 25ec0ed692..5c980373f2 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -6,8 +6,6 @@ pub use disabled::Disabled; pub use http_basic::HttpBasic; pub use jwt::Jwt; -use anyhow::Context; - use super::{AuthError, Authenticated}; pub struct UserAuthContext { @@ -16,12 +14,12 @@ pub struct UserAuthContext { } impl TryFrom<&str> for UserAuthContext { - type Error = anyhow::Error; + type Error = AuthError; - fn try_from(auth_string: &str) -> Result { + fn try_from(auth_string: &str) -> Result { let (scheme, token) = auth_string .split_once(' ') - .context("malformed auth string`")?; + .ok_or(AuthError::AuthStringMalformed)?; Ok(UserAuthContext { scheme: Some(scheme.into()), token: Some(token.into()), diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index b450193558..1de9dd20a1 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -2,7 +2,7 @@ use anyhow::Context; use axum::extract::FromRequestParts; use crate::{ - auth::{Auth, Jwt, UserAuthContext}, + auth::{Auth, AuthError, Jwt, UserAuthContext}, connection::RequestContext, }; @@ -29,10 +29,9 @@ impl FromRequestParts for RequestContext { let context = parts .headers - .get(hyper::header::AUTHORIZATION) - .context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking - .and_then(|h| h.to_str().context("non ascii auth token")) - .and_then(|t| t.try_into()) + .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) + .and_then(|t| t.try_into()) .unwrap_or(UserAuthContext { scheme: None, token: None, diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index fa39385ea7..cab525e968 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -29,7 +29,7 @@ use tonic::transport::Server; use tower_http::{compression::CompressionLayer, cors}; -use crate::auth::{Auth, Authenticated, Jwt, UserAuthContext}; +use crate::auth::{Auth, AuthError, Authenticated, Jwt, UserAuthContext}; use crate::connection::{Connection, RequestContext}; use crate::error::Error; use crate::hrana; @@ -467,10 +467,9 @@ impl FromRequestParts for Authenticated { let context = parts .headers - .get(hyper::header::AUTHORIZATION) - .context("auth header not found") // todo this context is swallowed for now, gotta fix that but not with panicking - .and_then(|h| h.to_str().context("non ascii auth token")) - .and_then(|t| t.try_into()) + .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) + .and_then(|t| t.try_into()) .unwrap_or(UserAuthContext { scheme: None, token: None, diff --git a/libsql-server/tests/tests.rs b/libsql-server/tests/tests.rs index 497814660c..ac4c514b53 100644 --- a/libsql-server/tests/tests.rs +++ b/libsql-server/tests/tests.rs @@ -7,4 +7,4 @@ mod cluster; mod embedded_replica; mod hrana; mod namespaces; -mod standalone; +mod standalone; \ No newline at end of file From cadf4273cfe4fec40a8d452d43631ce737278b90 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 6 Mar 2024 12:18:23 -0800 Subject: [PATCH 17/63] added a factory to produce empty UserAuthContext --- libsql-server/src/auth/user_auth_strategies/disabled.rs | 5 +---- libsql-server/src/auth/user_auth_strategies/mod.rs | 6 ++++++ libsql-server/src/http/user/extract.rs | 6 +----- libsql-server/src/http/user/mod.rs | 7 ++----- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index 4f2c12894a..ef9aae9062 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -23,10 +23,7 @@ mod tests { #[test] fn authenticates() { let strategy = Disabled::new(); - let context = UserAuthContext { - scheme: None, - token: None, - }; + let context = UserAuthContext::empty(); assert!(matches!( strategy.authenticate(context).unwrap(), diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 5c980373f2..8661068dcb 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -13,6 +13,12 @@ pub struct UserAuthContext { pub token: Option, } +impl UserAuthContext { + pub fn empty() -> UserAuthContext{ + UserAuthContext{scheme: None, token: None} + } +} + impl TryFrom<&str> for UserAuthContext { type Error = AuthError; diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 1de9dd20a1..3d78834735 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -1,4 +1,3 @@ -use anyhow::Context; use axum::extract::FromRequestParts; use crate::{ @@ -32,10 +31,7 @@ impl FromRequestParts for RequestContext { .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) .and_then(|t| t.try_into()) - .unwrap_or(UserAuthContext { - scheme: None, - token: None, - }); + .unwrap_or(UserAuthContext::empty()); let authenticated = namespace_jwt_key .map(Jwt::new) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index cab525e968..0c9d63a425 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -459,7 +459,7 @@ impl FromRequestParts for Authenticated { state.disable_default_namespace, state.disable_namespaces, )?; - // todo get rid of duplication - this and replication_log.rs + // todo dupe #auth let namespace_jwt_key = state .namespaces .with(ns.clone(), |ns| ns.jwt_key()) @@ -470,10 +470,7 @@ impl FromRequestParts for Authenticated { .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) .and_then(|t| t.try_into()) - .unwrap_or(UserAuthContext { - scheme: None, - token: None, - }); + .unwrap_or(UserAuthContext::empty()); let authenticated = namespace_jwt_key .map(Jwt::new) From 5077466a1003ae29f82018ab4322dc6590300178 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 6 Mar 2024 13:23:05 -0800 Subject: [PATCH 18/63] added constructors for UserAuthContext --- .../auth/user_auth_strategies/http_basic.rs | 16 ++-------- .../src/auth/user_auth_strategies/jwt.rs | 25 +++------------ .../src/auth/user_auth_strategies/mod.rs | 31 ++++++++++++++----- libsql-server/src/hrana/ws/session.rs | 10 ++---- 4 files changed, 34 insertions(+), 48 deletions(-) diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index 3dba4fca3c..3e62d912bb 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -45,10 +45,7 @@ mod tests { #[test] fn authenticates_with_valid_credential() { - let context = UserAuthContext { - scheme: Some("basic".into()), - token: Some(CREDENTIAL.into()), - }; + let context = UserAuthContext::basic(CREDENTIAL); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -59,11 +56,7 @@ mod tests { #[test] fn authenticates_with_valid_trimmed_credential() { let credential = CREDENTIAL.trim_end_matches('='); - - let context = UserAuthContext { - scheme: Some("basic".into()), - token: Some(credential.into()), - }; + let context = UserAuthContext::basic(credential); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -73,10 +66,7 @@ mod tests { #[test] fn errors_when_credentials_do_not_match() { - let context = UserAuthContext { - token: Some("abc".into()), - scheme: Some("basic".into()), - }; + let context = UserAuthContext::basic("abc"); assert_eq!( strategy().authenticate(context).unwrap_err(), diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index 5a03f68d44..cec519ba69 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -145,10 +145,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = UserAuthContext { - scheme: Some("bearer".into()), - token: token.into(), - }; + let context = UserAuthContext::bearer(token.as_str()); assert!(matches!( strategy(dec).authenticate(context).unwrap(), @@ -167,10 +164,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = UserAuthContext { - scheme: Some("bearer".into()), - token: token.into(), - }; + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { panic!() @@ -187,10 +181,7 @@ mod tests { #[test] fn errors_when_jwt_token_invalid() { let (_enc, dec) = key_pair(); - let context = UserAuthContext { - scheme: Some("bearer".into()), - token: Some("abc".into()), - }; + let context = UserAuthContext::bearer("abc"); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -210,10 +201,7 @@ mod tests { let token = encode(&token, &enc); - let context = UserAuthContext { - scheme: Some("bearer".into()), - token: Some(token.into()), - }; + let context = UserAuthContext::bearer(token.as_str()); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -235,10 +223,7 @@ mod tests { let token = encode(&token, &enc); - let context = UserAuthContext { - scheme: Some("bearer".into()), - token: token.into(), - }; + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { panic!() diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 8661068dcb..3a5d06fb60 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -17,19 +17,36 @@ impl UserAuthContext { pub fn empty() -> UserAuthContext{ UserAuthContext{scheme: None, token: None} } + + pub fn basic(creds: &str) -> UserAuthContext { + UserAuthContext{scheme: Some("Basic".into()), token: Some(creds.into())} + } + + pub fn bearer(token: &str) -> UserAuthContext { + UserAuthContext{scheme: Some("Bearer".into()), token: Some(token.into())} + } + + pub fn bearer_opt(token: Option) -> UserAuthContext { + UserAuthContext{scheme: Some("Bearer".into()), token: token} + } + + fn new(scheme: &str, token: &str) -> UserAuthContext { + UserAuthContext{scheme: Some(scheme.into()), token: Some(token.into())} + } + + fn from_auth_str(auth_string: &str) -> Result { + let (scheme, token) = auth_string + .split_once(' ') + .ok_or(AuthError::AuthStringMalformed)?; + Ok(UserAuthContext::new(scheme, token)) + } } impl TryFrom<&str> for UserAuthContext { type Error = AuthError; fn try_from(auth_string: &str) -> Result { - let (scheme, token) = auth_string - .split_once(' ') - .ok_or(AuthError::AuthStringMalformed)?; - Ok(UserAuthContext { - scheme: Some(scheme.into()), - token: Some(token.into()), - }) + UserAuthContext::from_auth_str(auth_string) } } pub trait UserAuthStrategy: Sync + Send { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index c1bfbdafbd..eaa163286f 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -81,10 +81,7 @@ pub(super) async fn handle_initial_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()) - .authenticate(UserAuthContext { - scheme: Some("Bearer".into()), - token: jwt, - }) + .authenticate(UserAuthContext::bearer_opt(jwt)) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { @@ -118,10 +115,7 @@ pub(super) async fn handle_repeated_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or_else(|| server.user_auth_strategy.clone()) - .authenticate(UserAuthContext { - scheme: Some("Bearer".into()), - token: jwt, - }) + .authenticate(UserAuthContext::bearer_opt(jwt)) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) From 071dcf37693756b669177f9e7b04a530b79da809 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 6 Mar 2024 13:38:30 -0800 Subject: [PATCH 19/63] switched from try_into to using constructors --- libsql-server/src/auth/parsers.rs | 2 +- .../src/auth/user_auth_strategies/mod.rs | 23 ++++++++++--------- libsql-server/src/http/user/extract.rs | 2 +- libsql-server/src/http/user/mod.rs | 2 +- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index bd126b0041..a56a342443 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -42,7 +42,7 @@ pub(crate) fn parse_grpc_auth_header( return metadata .get(GRPC_AUTH_HEADER).ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) - .and_then(|t| t.try_into()) + .and_then(|t| UserAuthContext::from_auth_str(t)) .map_err(|e| { tonic::Status::new( tonic::Code::InvalidArgument, diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 3a5d06fb60..02eddde56f 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -9,11 +9,19 @@ pub use jwt::Jwt; use super::{AuthError, Authenticated}; pub struct UserAuthContext { - pub scheme: Option, - pub token: Option, + scheme: Option, + token: Option, } impl UserAuthContext { + pub fn scheme(&self) -> &Option { + &self.scheme + } + + pub fn token(&self) -> &Option { + &self.token + } + pub fn empty() -> UserAuthContext{ UserAuthContext{scheme: None, token: None} } @@ -30,11 +38,11 @@ impl UserAuthContext { UserAuthContext{scheme: Some("Bearer".into()), token: token} } - fn new(scheme: &str, token: &str) -> UserAuthContext { + pub fn new(scheme: &str, token: &str) -> UserAuthContext { UserAuthContext{scheme: Some(scheme.into()), token: Some(token.into())} } - fn from_auth_str(auth_string: &str) -> Result { + pub fn from_auth_str(auth_string: &str) -> Result { let (scheme, token) = auth_string .split_once(' ') .ok_or(AuthError::AuthStringMalformed)?; @@ -42,13 +50,6 @@ impl UserAuthContext { } } -impl TryFrom<&str> for UserAuthContext { - type Error = AuthError; - - fn try_from(auth_string: &str) -> Result { - UserAuthContext::from_auth_str(auth_string) - } -} pub trait UserAuthStrategy: Sync + Send { fn authenticate(&self, context: UserAuthContext) -> Result; } diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 3d78834735..719aff0359 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -30,7 +30,7 @@ impl FromRequestParts for RequestContext { .headers .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) - .and_then(|t| t.try_into()) + .and_then(|t| UserAuthContext::from_auth_str(t)) .unwrap_or(UserAuthContext::empty()); let authenticated = namespace_jwt_key diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 0c9d63a425..77822aa678 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -469,7 +469,7 @@ impl FromRequestParts for Authenticated { .headers .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) - .and_then(|t| t.try_into()) + .and_then(|t| UserAuthContext::from_auth_str(t)) .unwrap_or(UserAuthContext::empty()); let authenticated = namespace_jwt_key From 9f94e4ef1e0ef4f4df84ec9bc8b4ab447fb1b173 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 6 Mar 2024 13:39:21 -0800 Subject: [PATCH 20/63] cargo fmt --- libsql-server/src/auth/parsers.rs | 7 +++-- .../src/auth/user_auth_strategies/mod.rs | 29 ++++++++++++++----- libsql-server/src/http/user/extract.rs | 7 +++-- libsql-server/src/http/user/mod.rs | 7 +++-- libsql-server/src/rpc/replica_proxy.rs | 16 ++++++---- libsql-server/tests/tests.rs | 2 +- 6 files changed, 45 insertions(+), 23 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index a56a342443..d35d81be2c 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -40,9 +40,10 @@ pub(crate) fn parse_grpc_auth_header( metadata: &MetadataMap, ) -> Result { return metadata - .get(GRPC_AUTH_HEADER).ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) + .get(GRPC_AUTH_HEADER) + .ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .and_then(|t| UserAuthContext::from_auth_str(t)) .map_err(|e| { tonic::Status::new( tonic::Code::InvalidArgument, diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 02eddde56f..13c828436f 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -22,30 +22,45 @@ impl UserAuthContext { &self.token } - pub fn empty() -> UserAuthContext{ - UserAuthContext{scheme: None, token: None} + pub fn empty() -> UserAuthContext { + UserAuthContext { + scheme: None, + token: None, + } } pub fn basic(creds: &str) -> UserAuthContext { - UserAuthContext{scheme: Some("Basic".into()), token: Some(creds.into())} + UserAuthContext { + scheme: Some("Basic".into()), + token: Some(creds.into()), + } } pub fn bearer(token: &str) -> UserAuthContext { - UserAuthContext{scheme: Some("Bearer".into()), token: Some(token.into())} + UserAuthContext { + scheme: Some("Bearer".into()), + token: Some(token.into()), + } } pub fn bearer_opt(token: Option) -> UserAuthContext { - UserAuthContext{scheme: Some("Bearer".into()), token: token} + UserAuthContext { + scheme: Some("Bearer".into()), + token: token, + } } pub fn new(scheme: &str, token: &str) -> UserAuthContext { - UserAuthContext{scheme: Some(scheme.into()), token: Some(token.into())} + UserAuthContext { + scheme: Some(scheme.into()), + token: Some(token.into()), + } } pub fn from_auth_str(auth_string: &str) -> Result { let (scheme, token) = auth_string .split_once(' ') - .ok_or(AuthError::AuthStringMalformed)?; + .ok_or(AuthError::AuthStringMalformed)?; Ok(UserAuthContext::new(scheme, token)) } } diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 719aff0359..9293a8deae 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -28,9 +28,10 @@ impl FromRequestParts for RequestContext { let context = parts .headers - .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) + .get(hyper::header::AUTHORIZATION) + .ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .and_then(|t| UserAuthContext::from_auth_str(t)) .unwrap_or(UserAuthContext::empty()); let authenticated = namespace_jwt_key diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 77822aa678..c6679058b7 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -467,9 +467,10 @@ impl FromRequestParts for Authenticated { let context = parts .headers - .get(hyper::header::AUTHORIZATION).ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_|AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) + .get(hyper::header::AUTHORIZATION) + .ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .and_then(|t| UserAuthContext::from_auth_str(t)) .unwrap_or(UserAuthContext::empty()); let authenticated = namespace_jwt_key diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 646a71daa6..a44a0d87f5 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -43,18 +43,22 @@ impl ReplicaProxyService { .with(namespace.clone(), |ns| ns.jwt_key()) .await; - let namespace_jwt_key = jwt_result - .and_then(|s|s); - + let namespace_jwt_key = jwt_result.and_then(|s| s); let auth_strategy = match namespace_jwt_key { Ok(Some(key)) => Ok(Auth::new(Jwt::new(key))), - Ok(None) | Err(crate::error::Error::NamespaceDoesntExist(_)) => Ok(self.user_auth_strategy.clone()), - Err(e) => Err(Status::internal(format!("Can't fetch jwt key for a namespace: {}", e))), + Ok(None) | Err(crate::error::Error::NamespaceDoesntExist(_)) => { + Ok(self.user_auth_strategy.clone()) + } + Err(e) => Err(Status::internal(format!( + "Can't fetch jwt key for a namespace: {}", + e + ))), }?; let auth_context = parse_grpc_auth_header(req.metadata())?; - auth_strategy.authenticate(auth_context)? + auth_strategy + .authenticate(auth_context)? .upgrade_grpc_request(req); return Ok(()); } diff --git a/libsql-server/tests/tests.rs b/libsql-server/tests/tests.rs index ac4c514b53..497814660c 100644 --- a/libsql-server/tests/tests.rs +++ b/libsql-server/tests/tests.rs @@ -7,4 +7,4 @@ mod cluster; mod embedded_replica; mod hrana; mod namespaces; -mod standalone; \ No newline at end of file +mod standalone; From 284b36fd7a64244b3c9a7e0f63741701707cb767 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Thu, 7 Mar 2024 12:28:15 -0800 Subject: [PATCH 21/63] added tests for failing cases in parsers --- libsql-server/src/auth/parsers.rs | 29 +++++++++++++++++-- .../src/auth/user_auth_strategies/mod.rs | 1 + 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index d35d81be2c..5d4aa998ba 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -42,7 +42,7 @@ pub(crate) fn parse_grpc_auth_header( return metadata .get(GRPC_AUTH_HEADER) .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) // this never happens, guaranteed at type level .and_then(|t| UserAuthContext::from_auth_str(t)) .map_err(|e| { tonic::Status::new( @@ -82,7 +82,32 @@ mod tests { use crate::auth::{parse_http_auth_header, AuthError}; - use super::parse_http_basic_auth_arg; + use super::{parse_grpc_auth_header, parse_http_basic_auth_arg}; + + #[test] + fn parse_grpc_auth_header_returns_valid_context(){ + let mut map = tonic::metadata::MetadataMap::new(); + map.insert("x-authorization", "bearer 123".parse().unwrap()); + let context = parse_grpc_auth_header(&map).unwrap(); + assert_eq!(context.scheme().as_ref().unwrap(), "bearer"); + assert_eq!(context.token().as_ref().unwrap(), "123"); + } + + + #[test] + fn parse_grpc_auth_header_error_no_header(){ + let map = tonic::metadata::MetadataMap::new(); + let result = parse_grpc_auth_header(&map); + assert_eq!(result.unwrap_err().message(), "Failed parse grpc auth: Expected authorization header but none given"); + } + + #[test] + fn parse_grpc_auth_header_error_malformed_auth_str(){ + let mut map = tonic::metadata::MetadataMap::new(); + map.insert("x-authorization", "bearer123".parse().unwrap()); + let result = parse_grpc_auth_header(&map); + assert_eq!(result.unwrap_err().message(), "Failed parse grpc auth: Auth string does not conform to ' ' form") + } #[test] fn parse_http_auth_header_returns_auth_header_param_when_valid() { diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 13c828436f..41e20f8bad 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -8,6 +8,7 @@ pub use jwt::Jwt; use super::{AuthError, Authenticated}; +#[derive(Debug)] pub struct UserAuthContext { scheme: Option, token: Option, From d0ec05aa0f5d3a41795e6785d27b6e93486e3b84 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 8 Mar 2024 11:01:37 -0800 Subject: [PATCH 22/63] adding mamespace as param wip --- libsql-server/src/http/user/db_factory.rs | 30 ++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/libsql-server/src/http/user/db_factory.rs b/libsql-server/src/http/user/db_factory.rs index 257d8811c1..794a1f25ac 100644 --- a/libsql-server/src/http/user/db_factory.rs +++ b/libsql-server/src/http/user/db_factory.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use axum::extract::{FromRequestParts, Path}; use hyper::http::request::Parts; use hyper::HeaderMap; +use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY; use crate::auth::Authenticated; use crate::connection::MakeConnection; @@ -46,18 +47,25 @@ pub fn namespace_from_headers( return Ok(NamespaceName::default()); } - let host = headers - .get("host") - .ok_or_else(|| Error::InvalidHost("missing host header".into()))? - .as_bytes(); - let host_str = std::str::from_utf8(host) - .map_err(|_| Error::InvalidHost("host header is not valid UTF-8".into()))?; + headers + .get(NAMESPACE_METADATA_KEY) + .ok_or(Error::InvalidNamespace) + .and_then(|h| h.to_str().map_err(|_| Error::InvalidNamespace)) + .and_then(|n| NamespaceName::from_string(n.into())) + .or_else(|_| { + let host = headers + .get("host") + .ok_or_else(|| Error::InvalidHost("missing host header".into()))? + .as_bytes(); + let host_str = std::str::from_utf8(host) + .map_err(|_| Error::InvalidHost("host header is not valid UTF-8".into()))?; - match split_namespace(host_str) { - Ok(ns) => Ok(ns), - Err(_) if !disable_default_namespace => Ok(NamespaceName::default()), - Err(e) => Err(e), - } + match split_namespace(host_str) { + Ok(ns) => Ok(ns), + Err(_) if !disable_default_namespace => Ok(NamespaceName::default()), + Err(e) => Err(e), + } + }) } pub struct MakeConnectionExtractorPath(pub Arc>); From f716a28b5bfd8fa3ef9a20b28b20cc2a18e4c70f Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 8 Mar 2024 13:14:09 -0800 Subject: [PATCH 23/63] cargo fmt --- libsql-server/src/auth/parsers.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 5d4aa998ba..0118bb0d69 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -85,7 +85,7 @@ mod tests { use super::{parse_grpc_auth_header, parse_http_basic_auth_arg}; #[test] - fn parse_grpc_auth_header_returns_valid_context(){ + fn parse_grpc_auth_header_returns_valid_context() { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer 123".parse().unwrap()); let context = parse_grpc_auth_header(&map).unwrap(); @@ -93,20 +93,25 @@ mod tests { assert_eq!(context.token().as_ref().unwrap(), "123"); } - #[test] - fn parse_grpc_auth_header_error_no_header(){ + fn parse_grpc_auth_header_error_no_header() { let map = tonic::metadata::MetadataMap::new(); let result = parse_grpc_auth_header(&map); - assert_eq!(result.unwrap_err().message(), "Failed parse grpc auth: Expected authorization header but none given"); + assert_eq!( + result.unwrap_err().message(), + "Failed parse grpc auth: Expected authorization header but none given" + ); } #[test] - fn parse_grpc_auth_header_error_malformed_auth_str(){ + fn parse_grpc_auth_header_error_malformed_auth_str() { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer123".parse().unwrap()); let result = parse_grpc_auth_header(&map); - assert_eq!(result.unwrap_err().message(), "Failed parse grpc auth: Auth string does not conform to ' ' form") + assert_eq!( + result.unwrap_err().message(), + "Failed parse grpc auth: Auth string does not conform to ' ' form" + ) } #[test] From 843c9f7a553ae1773d2ad798a90eae8c1d8e1523 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 8 Mar 2024 13:17:46 -0800 Subject: [PATCH 24/63] added test for non-asci error --- libsql-server/src/auth/parsers.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 0118bb0d69..ca4e7b9734 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -42,7 +42,7 @@ pub(crate) fn parse_grpc_auth_header( return metadata .get(GRPC_AUTH_HEADER) .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) // this never happens, guaranteed at type level + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) .and_then(|t| UserAuthContext::from_auth_str(t)) .map_err(|e| { tonic::Status::new( @@ -103,6 +103,17 @@ mod tests { ); } + #[test] + fn parse_grpc_auth_header_error_non_ascii() { + let mut map = tonic::metadata::MetadataMap::new(); + map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); + let result = parse_grpc_auth_header(&map); + assert_eq!( + result.unwrap_err().message(), + "Failed parse grpc auth: Non-ASCII auth header" + ) + } + #[test] fn parse_grpc_auth_header_error_malformed_auth_str() { let mut map = tonic::metadata::MetadataMap::new(); From 047823b4627ee26e60bc9839f6d377fc354b0797 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 8 Mar 2024 15:48:50 -0800 Subject: [PATCH 25/63] incremental changes to make namespace as param work --- libsql/src/database.rs | 3 ++- libsql/src/database/builder.rs | 17 +++++++++++++---- libsql/src/local/database.rs | 4 ++++ libsql/src/replication/client.rs | 8 +++++++- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/libsql/src/database.rs b/libsql/src/database.rs index b9972aaa59..f0818d8ea7 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -184,7 +184,7 @@ cfg_replication! { None, OpenFlags::default(), encryption_config.clone(), - None + None, ).await?; Ok(Database { @@ -309,6 +309,7 @@ cfg_replication! { read_your_writes, encryption_config.clone(), periodic_sync, + None, None ).await?; diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 7b3df3b231..91afb8be22 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -54,7 +54,8 @@ impl Builder<()> { encryption_config: None, read_your_writes: true, periodic_sync: None, - http_request_callback: None + http_request_callback: None, + namespace: None }, } } @@ -159,6 +160,7 @@ cfg_replication! { read_your_writes: bool, periodic_sync: Option, http_request_callback: Option, + namespace: Option, } /// Local replica configuration type in [`Builder`]. @@ -220,6 +222,11 @@ cfg_replication! { } + pub fn namespace(mut self, namespace: &str) -> Builder { + self.inner.namespace = Some(namespace.into()); + self + } + #[doc(hidden)] pub fn version(mut self, version: String) -> Builder { self.inner.remote = self.inner.remote.version(version); @@ -240,7 +247,8 @@ cfg_replication! { encryption_config, read_your_writes, periodic_sync, - http_request_callback + http_request_callback, + namespace } = self.inner; let connector = if let Some(connector) = connector { @@ -267,7 +275,8 @@ cfg_replication! { read_your_writes, encryption_config.clone(), periodic_sync, - http_request_callback + http_request_callback, + namespace, ) .await?; @@ -333,7 +342,7 @@ cfg_replication! { version, flags, encryption_config.clone(), - http_request_callback + http_request_callback, ) .await? } else { diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index f27638ee8f..265f54e4c1 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -65,6 +65,7 @@ impl Database { encryption_config, periodic_sync, None, + None, ) .await } @@ -81,6 +82,7 @@ impl Database { encryption_config: Option, periodic_sync: Option, http_request_callback: Option, + namespace: Option ) -> Result { use std::path::PathBuf; @@ -95,6 +97,7 @@ impl Database { auth_token, version.as_deref(), http_request_callback, + namespace, ) .unwrap(); let path = PathBuf::from(db_path); @@ -166,6 +169,7 @@ impl Database { auth_token, version.as_deref(), http_request_callback, + None, ) .unwrap(); diff --git a/libsql/src/replication/client.rs b/libsql/src/replication/client.rs index 8ef5edaf04..16227d4f5a 100644 --- a/libsql/src/replication/client.rs +++ b/libsql/src/replication/client.rs @@ -47,6 +47,7 @@ impl Client { auth_token: impl AsRef, version: Option<&str>, http_request_callback: Option, + maybe_namespace: Option, ) -> anyhow::Result { let ver = version.unwrap_or(env!("CARGO_PKG_VERSION")); @@ -58,7 +59,12 @@ impl Client { .try_into() .context("Invalid auth token must be ascii")?; - let ns = split_namespace(origin.host().unwrap()).unwrap_or_else(|_| "default".to_string()); + + let ns = maybe_namespace.unwrap_or_else(|| + split_namespace(origin.host().unwrap()) + .unwrap_or_else(|_| "default".to_string()) + ); + let namespace = BinaryMetadataValue::from_bytes(ns.as_bytes()); let channel = GrpcChannel::new(connector, http_request_callback); From d32bd5f27594c1db04439e8545eeb86d2d174729 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 11 Mar 2024 08:00:04 -0700 Subject: [PATCH 26/63] fixed failing test --- libsql-server/src/http/user/db_factory.rs | 30 +++++++++++++-------- libsql-server/tests/embedded_replica/mod.rs | 2 +- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/libsql-server/src/http/user/db_factory.rs b/libsql-server/src/http/user/db_factory.rs index 257d8811c1..794a1f25ac 100644 --- a/libsql-server/src/http/user/db_factory.rs +++ b/libsql-server/src/http/user/db_factory.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use axum::extract::{FromRequestParts, Path}; use hyper::http::request::Parts; use hyper::HeaderMap; +use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY; use crate::auth::Authenticated; use crate::connection::MakeConnection; @@ -46,18 +47,25 @@ pub fn namespace_from_headers( return Ok(NamespaceName::default()); } - let host = headers - .get("host") - .ok_or_else(|| Error::InvalidHost("missing host header".into()))? - .as_bytes(); - let host_str = std::str::from_utf8(host) - .map_err(|_| Error::InvalidHost("host header is not valid UTF-8".into()))?; + headers + .get(NAMESPACE_METADATA_KEY) + .ok_or(Error::InvalidNamespace) + .and_then(|h| h.to_str().map_err(|_| Error::InvalidNamespace)) + .and_then(|n| NamespaceName::from_string(n.into())) + .or_else(|_| { + let host = headers + .get("host") + .ok_or_else(|| Error::InvalidHost("missing host header".into()))? + .as_bytes(); + let host_str = std::str::from_utf8(host) + .map_err(|_| Error::InvalidHost("host header is not valid UTF-8".into()))?; - match split_namespace(host_str) { - Ok(ns) => Ok(ns), - Err(_) if !disable_default_namespace => Ok(NamespaceName::default()), - Err(e) => Err(e), - } + match split_namespace(host_str) { + Ok(ns) => Ok(ns), + Err(_) if !disable_default_namespace => Ok(NamespaceName::default()), + Err(e) => Err(e), + } + }) } pub struct MakeConnectionExtractorPath(pub Arc>); diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index 6f414ea8ce..5417db359c 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -929,7 +929,7 @@ fn errors_on_bad_replica() { let db = libsql::Builder::new_remote_replica( path.to_str().unwrap(), "http://foo.primary:8080".to_string(), - "".to_string(), + "dummy_token".to_string(), ) .connector(TurmoilConnector) .build() From 412538333fec5c951727835bb0bcccd0175f4e15 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 13 Mar 2024 10:25:29 -0700 Subject: [PATCH 27/63] fixed log message --- libsql/examples/replica.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libsql/examples/replica.rs b/libsql/examples/replica.rs index 499b040161..b61e886407 100644 --- a/libsql/examples/replica.rs +++ b/libsql/examples/replica.rs @@ -10,12 +10,12 @@ async fn main() { let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or_else(|_| { println!("Using empty token since LIBSQL_TOKEN was not set"); - "".to_string() + "x".to_string() }); let url = std::env::var("LIBSQL_URL") .unwrap_or_else(|_| { - println!("Using empty token since LIBSQL_URL was not set"); + println!("Using http://localhost:8080 LIBSQL_URL was not set"); "http://localhost:8080".to_string() }) .replace("libsql", "https"); From 31050256854a2456b3784d51133c47860ac7c928 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Thu, 14 Mar 2024 10:40:14 -0700 Subject: [PATCH 28/63] removed unnecessary error mapping --- libsql-server/src/auth/parsers.rs | 25 ++++++------------- .../src/auth/user_auth_strategies/jwt.rs | 5 +--- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index ca4e7b9734..643a3fee1b 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -36,20 +36,12 @@ pub fn parse_jwt_key(data: &str) -> Result { } } -pub(crate) fn parse_grpc_auth_header( - metadata: &MetadataMap, -) -> Result { - return metadata +pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { + metadata .get(GRPC_AUTH_HEADER) .ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) .and_then(|t| UserAuthContext::from_auth_str(t)) - .map_err(|e| { - tonic::Status::new( - tonic::Code::InvalidArgument, - format!("Failed parse grpc auth: {e}"), - ) - }); } pub fn parse_http_auth_header<'a>( @@ -98,8 +90,8 @@ mod tests { let map = tonic::metadata::MetadataMap::new(); let result = parse_grpc_auth_header(&map); assert_eq!( - result.unwrap_err().message(), - "Failed parse grpc auth: Expected authorization header but none given" + result.unwrap_err().to_string(), + "Expected authorization header but none given" ); } @@ -108,10 +100,7 @@ mod tests { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); let result = parse_grpc_auth_header(&map); - assert_eq!( - result.unwrap_err().message(), - "Failed parse grpc auth: Non-ASCII auth header" - ) + assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header") } #[test] @@ -120,8 +109,8 @@ mod tests { map.insert("x-authorization", "bearer123".parse().unwrap()); let result = parse_grpc_auth_header(&map); assert_eq!( - result.unwrap_err().message(), - "Failed parse grpc auth: Auth string does not conform to ' ' form" + result.unwrap_err().to_string(), + "Auth string does not conform to ' ' form" ) } diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index b0ca865a20..0a0d163286 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -1,10 +1,7 @@ use chrono::{DateTime, Utc}; use crate::{ - auth::{ - authenticated::LegacyAuth, AuthError, Authenticated, Authorized, - Permission, - }, + auth::{authenticated::LegacyAuth, AuthError, Authenticated, Authorized, Permission}, namespace::NamespaceName, }; From bccfe77aab22533ec6f55f2646ac60067cb8dec0 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Thu, 14 Mar 2024 13:15:56 -0700 Subject: [PATCH 29/63] turned context to result --- libsql-server/src/auth/mod.rs | 5 +++- .../src/auth/user_auth_strategies/disabled.rs | 7 +++-- .../auth/user_auth_strategies/http_basic.rs | 13 ++++---- .../src/auth/user_auth_strategies/jwt.rs | 27 ++++++++++------- .../src/auth/user_auth_strategies/mod.rs | 5 +++- libsql-server/src/hrana/ws/session.rs | 4 +-- libsql-server/src/http/user/extract.rs | 5 ++-- libsql-server/src/http/user/mod.rs | 3 +- libsql-server/src/rpc/proxy.rs | 2 +- libsql-server/src/rpc/replica_proxy.rs | 2 +- libsql-server/src/rpc/replication_log.rs | 2 +- libsql-server/tests/cluster/mod.rs | 4 +-- .../tests/cluster/replica_restart.rs | 6 ++-- libsql-server/tests/cluster/replication.rs | 4 +-- libsql-server/tests/embedded_replica/local.rs | 4 +-- libsql-server/tests/embedded_replica/mod.rs | 30 +++++++++---------- libsql-server/tests/hrana/batch.rs | 8 ++--- libsql-server/tests/hrana/transaction.rs | 6 ++-- libsql-server/tests/namespaces/dumps.rs | 10 +++---- libsql-server/tests/namespaces/meta.rs | 16 +++++----- libsql-server/tests/namespaces/mod.rs | 6 ++-- libsql-server/tests/standalone/attach.rs | 4 +-- 22 files changed, 94 insertions(+), 79 deletions(-) diff --git a/libsql-server/src/auth/mod.rs b/libsql-server/src/auth/mod.rs index 26c6dfecc1..09468f4b3a 100644 --- a/libsql-server/src/auth/mod.rs +++ b/libsql-server/src/auth/mod.rs @@ -27,7 +27,10 @@ impl Auth { } } - pub fn authenticate(&self, context: UserAuthContext) -> Result { + pub fn authenticate( + &self, + context: Result, + ) -> Result { self.user_strategy.authenticate(context) } } diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index ef9aae9062..b95d52c061 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -4,7 +4,10 @@ use crate::auth::{AuthError, Authenticated}; pub struct Disabled {} impl UserAuthStrategy for Disabled { - fn authenticate(&self, _context: UserAuthContext) -> Result { + fn authenticate( + &self, + _context: Result, + ) -> Result { tracing::trace!("executing disabled auth"); Ok(Authenticated::FullAccess) } @@ -23,7 +26,7 @@ mod tests { #[test] fn authenticates() { let strategy = Disabled::new(); - let context = UserAuthContext::empty(); + let context = Ok(UserAuthContext::empty()); assert!(matches!( strategy.authenticate(context).unwrap(), diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index 3e62d912bb..fbb45d0912 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -7,14 +7,17 @@ pub struct HttpBasic { } impl UserAuthStrategy for HttpBasic { - fn authenticate(&self, context: UserAuthContext) -> Result { + fn authenticate( + &self, + context: Result, + ) -> Result { tracing::trace!("executing http basic auth"); // NOTE: this naive comparison may leak information about the `expected_value` // using a timing attack let expected_value = self.credential.trim_end_matches('='); - let creds_match = match context.token { + let creds_match = match context?.token { Some(s) => s.contains(expected_value), None => expected_value.is_empty(), }; @@ -45,7 +48,7 @@ mod tests { #[test] fn authenticates_with_valid_credential() { - let context = UserAuthContext::basic(CREDENTIAL); + let context = Ok(UserAuthContext::basic(CREDENTIAL)); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -56,7 +59,7 @@ mod tests { #[test] fn authenticates_with_valid_trimmed_credential() { let credential = CREDENTIAL.trim_end_matches('='); - let context = UserAuthContext::basic(credential); + let context = Ok(UserAuthContext::basic(credential)); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -66,7 +69,7 @@ mod tests { #[test] fn errors_when_credentials_do_not_match() { - let context = UserAuthContext::basic("abc"); + let context = Ok(UserAuthContext::basic("abc")); assert_eq!( strategy().authenticate(context).unwrap_err(), diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index 0a0d163286..da68e91df0 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -12,10 +12,19 @@ pub struct Jwt { } impl UserAuthStrategy for Jwt { - fn authenticate(&self, context: UserAuthContext) -> Result { + fn authenticate( + &self, + context: Result, + ) -> Result { tracing::trace!("executing jwt auth"); - let Some(scheme) = context.scheme else { + let ctx = context?; + + let UserAuthContext { + scheme: Some(scheme), + token: Some(token), + } = ctx + else { return Err(AuthError::HttpAuthHeaderInvalid); }; @@ -23,10 +32,6 @@ impl UserAuthStrategy for Jwt { return Err(AuthError::HttpAuthHeaderUnsupportedScheme); } - let Some(token) = context.token else { - return Err(AuthError::HttpAuthHeaderInvalid); - }; - return validate_jwt(&self.key, &token); } } @@ -150,7 +155,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = UserAuthContext::bearer(token.as_str()); + let context = Ok(UserAuthContext::bearer(token.as_str())); assert!(matches!( strategy(dec).authenticate(context).unwrap(), @@ -172,7 +177,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = UserAuthContext::bearer(token.as_str()); + let context = Ok(UserAuthContext::bearer(token.as_str())); let Authenticated::Legacy(a) = strategy(dec).authenticate(context).unwrap() else { panic!() @@ -185,7 +190,7 @@ mod tests { #[test] fn errors_when_jwt_token_invalid() { let (_enc, dec) = key_pair(); - let context = UserAuthContext::bearer("abc"); + let context = Ok(UserAuthContext::bearer("abc")); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -205,7 +210,7 @@ mod tests { let token = encode(&token, &enc); - let context = UserAuthContext::bearer(token.as_str()); + let context = Ok(UserAuthContext::bearer(token.as_str())); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -227,7 +232,7 @@ mod tests { let token = encode(&token, &enc); - let context = UserAuthContext::bearer(token.as_str()); + let context = Ok(UserAuthContext::bearer(token.as_str())); let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { panic!() diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 41e20f8bad..4f0f2ef786 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -67,5 +67,8 @@ impl UserAuthContext { } pub trait UserAuthStrategy: Sync + Send { - fn authenticate(&self, context: UserAuthContext) -> Result; + fn authenticate( + &self, + context: Result, + ) -> Result; } diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index eaa163286f..9cc654c59d 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -81,7 +81,7 @@ pub(super) async fn handle_initial_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()) - .authenticate(UserAuthContext::bearer_opt(jwt)) + .authenticate(Ok(UserAuthContext::bearer_opt(jwt))) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { @@ -115,7 +115,7 @@ pub(super) async fn handle_repeated_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or_else(|| server.user_auth_strategy.clone()) - .authenticate(UserAuthContext::bearer_opt(jwt)) + .authenticate(Ok(UserAuthContext::bearer_opt(jwt))) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 9293a8deae..522c1bb343 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -16,7 +16,7 @@ impl FromRequestParts for RequestContext { state: &AppState, ) -> std::result::Result { let namespace = db_factory::namespace_from_headers( - &parts.headers, + &parts.headers, state.disable_default_namespace, state.disable_namespaces, )?; @@ -31,8 +31,7 @@ impl FromRequestParts for RequestContext { .get(hyper::header::AUTHORIZATION) .ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) - .unwrap_or(UserAuthContext::empty()); + .and_then(|t| UserAuthContext::from_auth_str(t)); let authenticated = namespace_jwt_key .map(Jwt::new) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index d6bb33e795..3e394fc579 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -473,8 +473,7 @@ impl FromRequestParts for Authenticated { .get(hyper::header::AUTHORIZATION) .ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) - .unwrap_or(UserAuthContext::empty()); + .and_then(|t| UserAuthContext::from_auth_str(t)); let authenticated = namespace_jwt_key .map(Jwt::new) diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index be198d6282..b8e0d945dd 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -335,7 +335,7 @@ impl ProxyService { }; let auth = if let Some(auth) = auth { - let context = parse_grpc_auth_header(req.metadata())?; + let context = parse_grpc_auth_header(req.metadata()); auth.authenticate(context)? } else { Authenticated::from_proxy_grpc_request(req)? diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index a44a0d87f5..6945cfffed 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -56,7 +56,7 @@ impl ReplicaProxyService { ))), }?; - let auth_context = parse_grpc_auth_header(req.metadata())?; + let auth_context = parse_grpc_auth_header(req.metadata()); auth_strategy .authenticate(auth_context)? .upgrade_grpc_request(req); diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 0cee882ffb..24d7b9ca03 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -94,7 +94,7 @@ impl ReplicationLogService { }; if let Some(auth) = auth { - let user_credential = parse_grpc_auth_header(req.metadata())?; + let user_credential = parse_grpc_auth_header(req.metadata()); auth.authenticate(user_credential)?; } diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index bfe7c5e274..2517b386bb 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -163,7 +163,7 @@ fn sync_many_replica() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -301,7 +301,7 @@ fn large_proxy_query() { let db = Database::open_remote_with_connector( "http://replica0:8080", - "dummy_token", + "", TurmoilConnector, ) .unwrap(); diff --git a/libsql-server/tests/cluster/replica_restart.rs b/libsql-server/tests/cluster/replica_restart.rs index 574296d947..826e027d8b 100644 --- a/libsql-server/tests/cluster/replica_restart.rs +++ b/libsql-server/tests/cluster/replica_restart.rs @@ -98,7 +98,7 @@ fn replica_restart() { let http = Client::new(); let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -267,7 +267,7 @@ fn primary_regenerate_log_no_replica_restart() { let http = Client::new(); let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -476,7 +476,7 @@ fn primary_regenerate_log_with_replica_restart() { let http = Client::new(); let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; diff --git a/libsql-server/tests/cluster/replication.rs b/libsql-server/tests/cluster/replication.rs index 8bbc2a23f9..4444d1bde6 100644 --- a/libsql-server/tests/cluster/replication.rs +++ b/libsql-server/tests/cluster/replication.rs @@ -91,7 +91,7 @@ fn apply_partial_snapshot() { sim.client("client", async move { let primary = libsql::Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, ) .unwrap(); @@ -217,7 +217,7 @@ fn replica_lazy_creation() { sim.client("client", async move { let db = Database::open_remote_with_connector( "http://test.replica:8080", - "dummy_token", + "", TurmoilConnector, ) .unwrap(); diff --git a/libsql-server/tests/embedded_replica/local.rs b/libsql-server/tests/embedded_replica/local.rs index 28e51a2b38..fbba29c7dc 100644 --- a/libsql-server/tests/embedded_replica/local.rs +++ b/libsql-server/tests/embedded_replica/local.rs @@ -35,7 +35,7 @@ fn local_sync_with_writes() { let primary = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = primary.connect()?; @@ -70,7 +70,7 @@ fn local_sync_with_writes() { let db = Database::open_with_local_sync_remote_writes_connector( tmp_host_path.join("embedded").to_str().unwrap(), "http://foo.primary:8080".to_string(), - "dummy_token".to_string(), + "".to_string(), TurmoilConnector, None, ) diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index a9dee343e2..1d27f935bd 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -90,7 +90,7 @@ fn embedded_replica() { let db = Database::open_with_remote_sync_connector( path.to_str().unwrap(), "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, false, None, @@ -163,7 +163,7 @@ fn execute_batch() { let db = Database::open_with_remote_sync_connector( path.to_str().unwrap(), "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, false, None, @@ -370,7 +370,7 @@ fn replica_primary_reset() { sim.client("client", async move { let primary = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = primary.connect()?; @@ -387,7 +387,7 @@ fn replica_primary_reset() { let replica = Database::open_with_remote_sync_connector( tmp.path().join("data").display().to_string(), "http://primary:8080", - "dummy_token", + "", TurmoilConnector, false, None, @@ -446,7 +446,7 @@ fn replica_primary_reset() { let replica = Database::open_with_remote_sync_connector( tmp.path().join("data").display().to_string(), "http://primary:8080", - "dummy_token", + "", TurmoilConnector, false, None, @@ -534,7 +534,7 @@ fn replica_no_resync_on_restart() { { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, ) .unwrap(); @@ -554,7 +554,7 @@ fn replica_no_resync_on_restart() { let db = Database::open_with_remote_sync_connector( db_path.display().to_string(), "http://primary:8080", - "dummy_token", + "", TurmoilConnector, false, None, @@ -570,7 +570,7 @@ fn replica_no_resync_on_restart() { let db = Database::open_with_remote_sync_connector( db_path.display().to_string(), "http://primary:8080", - "dummy_token", + "", TurmoilConnector, false, None, @@ -640,7 +640,7 @@ fn replicate_with_snapshots() { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, ) .unwrap(); @@ -657,7 +657,7 @@ fn replicate_with_snapshots() { let db = Database::open_with_remote_sync_connector( tmp.path().join("data").display().to_string(), "http://primary:8080", - "dummy_token", + "", TurmoilConnector, false, None, @@ -726,7 +726,7 @@ fn read_your_writes() { let db = Database::open_with_remote_sync_connector( path.to_str().unwrap(), "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, true, None, @@ -771,7 +771,7 @@ fn proxy_write_returning_row() { let db = Database::open_with_remote_sync_connector( path.to_str().unwrap(), "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, true, None, @@ -818,7 +818,7 @@ fn freeze() { let db = Database::open_with_remote_sync_connector( path.to_str().unwrap(), "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, true, None, @@ -841,7 +841,7 @@ fn freeze() { let db = Database::open_with_remote_sync_connector( path.to_str().unwrap(), "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, true, None, @@ -891,7 +891,7 @@ fn sync_interval() { let db = libsql::Builder::new_remote_replica( path.to_str().unwrap(), "http://foo.primary:8080".to_string(), - "dummy_token".to_string(), + "".to_string(), ) .connector(TurmoilConnector) .sync_interval(Duration::from_millis(100)) diff --git a/libsql-server/tests/hrana/batch.rs b/libsql-server/tests/hrana/batch.rs index f52fb1f97b..3dcab9db8c 100644 --- a/libsql-server/tests/hrana/batch.rs +++ b/libsql-server/tests/hrana/batch.rs @@ -50,7 +50,7 @@ fn execute_individual_statements() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -82,7 +82,7 @@ fn execute_batch() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -122,7 +122,7 @@ fn multistatement_query() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -149,7 +149,7 @@ fn affected_rows_and_last_rowid() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; diff --git a/libsql-server/tests/hrana/transaction.rs b/libsql-server/tests/hrana/transaction.rs index 388ab174f0..00dc66feb7 100644 --- a/libsql-server/tests/hrana/transaction.rs +++ b/libsql-server/tests/hrana/transaction.rs @@ -12,7 +12,7 @@ fn transaction_commit_and_rollback() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -60,7 +60,7 @@ fn multiple_concurrent_transactions() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -122,7 +122,7 @@ fn transaction_timeout() { sim.client("client", async { let db = Database::open_remote_with_connector( "http://primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let conn = db.connect()?; diff --git a/libsql-server/tests/namespaces/dumps.rs b/libsql-server/tests/namespaces/dumps.rs index 97b7ba9b1a..9f9bfded0d 100644 --- a/libsql-server/tests/namespaces/dumps.rs +++ b/libsql-server/tests/namespaces/dumps.rs @@ -56,7 +56,7 @@ fn load_namespace_from_dump_from_url() { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -130,7 +130,7 @@ fn load_namespace_from_dump_from_file() { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -185,7 +185,7 @@ fn load_namespace_from_no_commit() { // namespace doesn't exist let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -239,7 +239,7 @@ fn load_namespace_from_no_txn() { // namespace doesn't exist let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -272,7 +272,7 @@ fn export_dump() { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; diff --git a/libsql-server/tests/namespaces/meta.rs b/libsql-server/tests/namespaces/meta.rs index b484947348..498d88d2e9 100644 --- a/libsql-server/tests/namespaces/meta.rs +++ b/libsql-server/tests/namespaces/meta.rs @@ -41,7 +41,7 @@ fn replicated_config() { { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -53,7 +53,7 @@ fn replicated_config() { { let foo = Database::open_remote_with_connector( "http://foo.replica1:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -91,7 +91,7 @@ fn meta_store() { { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -113,7 +113,7 @@ fn meta_store() { { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -135,7 +135,7 @@ fn meta_store() { { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -173,7 +173,7 @@ fn meta_attach() { { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -185,7 +185,7 @@ fn meta_attach() { { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -208,7 +208,7 @@ fn meta_attach() { { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; diff --git a/libsql-server/tests/namespaces/mod.rs b/libsql-server/tests/namespaces/mod.rs index c42898cdc1..4f80857f18 100644 --- a/libsql-server/tests/namespaces/mod.rs +++ b/libsql-server/tests/namespaces/mod.rs @@ -62,7 +62,7 @@ fn fork_namespace() { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; @@ -76,7 +76,7 @@ fn fork_namespace() { let bar = Database::open_remote_with_connector( "http://bar.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let bar_conn = bar.connect()?; @@ -126,7 +126,7 @@ fn delete_namespace() { let foo = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo.connect()?; diff --git a/libsql-server/tests/standalone/attach.rs b/libsql-server/tests/standalone/attach.rs index 3af41a7781..517905e19c 100644 --- a/libsql-server/tests/standalone/attach.rs +++ b/libsql-server/tests/standalone/attach.rs @@ -38,7 +38,7 @@ fn attach_no_auth() { let foo_db = Database::open_remote_with_connector( "http://foo.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let foo_conn = foo_db.connect().unwrap(); @@ -53,7 +53,7 @@ fn attach_no_auth() { let bar_db = Database::open_remote_with_connector( "http://bar.primary:8080", - "dummy_token", + "", TurmoilConnector, )?; let bar_conn = bar_db.connect().unwrap(); From 37dc3af8b673b3a339eedd0b69cc391bd2ab5980 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Thu, 14 Mar 2024 14:14:38 -0700 Subject: [PATCH 30/63] removing dummy tokens from tests --- libsql-server/tests/cluster/mod.rs | 51 +++++------------ .../tests/cluster/replica_restart.rs | 18 +----- libsql-server/tests/cluster/replication.rs | 9 +-- libsql-server/tests/embedded_replica/mod.rs | 24 +++----- libsql-server/tests/hrana/batch.rs | 24 ++------ libsql-server/tests/hrana/transaction.rs | 18 +----- libsql-server/tests/namespaces/dumps.rs | 35 ++++-------- libsql-server/tests/namespaces/mod.rs | 21 ++----- libsql-server/tests/standalone/attach.rs | 14 ++--- libsql-server/tests/standalone/mod.rs | 56 ++++--------------- 10 files changed, 64 insertions(+), 206 deletions(-) diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index 2517b386bb..688391a7a8 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -90,22 +90,15 @@ fn proxy_write() { make_cluster(&mut sim, 1, true); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://replica0:8080", - "dummy-auth", - TurmoilConnector, - )?; + let db = + Database::open_remote_with_connector("http://replica0:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; conn.execute("insert into test values (12)", ()).await?; // assert that the primary got the write - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-auth", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; let mut rows = conn.query("select count(*) from test", ()).await?; @@ -131,11 +124,8 @@ fn replica_read_write() { make_cluster(&mut sim, 1, true); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://replica0:8080", - "dummy-auth", - TurmoilConnector, - )?; + let db = + Database::open_remote_with_connector("http://replica0:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -161,11 +151,7 @@ fn sync_many_replica() { .build(); make_cluster(&mut sim, NUM_REPLICA, true); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -215,7 +201,7 @@ fn sync_many_replica() { for i in 0..NUM_REPLICA { let db = Database::open_remote_with_connector( format!("http://replica{i}:8080"), - "dummy-auth", + "", TurmoilConnector, )?; let conn = db.connect()?; @@ -240,11 +226,8 @@ fn create_namespace() { make_cluster(&mut sim, 0, false); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://foo.primary:8080", - "dummy-auth", - TurmoilConnector, - )?; + let db = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; let Err(e) = conn.execute("create table test (x)", ()).await else { @@ -284,12 +267,8 @@ fn large_proxy_query() { make_cluster(&mut sim, 1, true); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-auth", - TurmoilConnector, - ) - .unwrap(); + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector) + .unwrap(); let conn = db.connect().unwrap(); conn.execute("create table test (x)", ()).await.unwrap(); @@ -299,12 +278,8 @@ fn large_proxy_query() { .unwrap(); } - let db = Database::open_remote_with_connector( - "http://replica0:8080", - "", - TurmoilConnector, - ) - .unwrap(); + let db = Database::open_remote_with_connector("http://replica0:8080", "", TurmoilConnector) + .unwrap(); let conn = db.connect().unwrap(); conn.execute_batch("begin immediate; select * from test limit (4000)") diff --git a/libsql-server/tests/cluster/replica_restart.rs b/libsql-server/tests/cluster/replica_restart.rs index 826e027d8b..11c78d8ced 100644 --- a/libsql-server/tests/cluster/replica_restart.rs +++ b/libsql-server/tests/cluster/replica_restart.rs @@ -96,11 +96,7 @@ fn replica_restart() { sim.client("client", async move { let http = Client::new(); - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; // insert a few valued into the primary @@ -265,11 +261,7 @@ fn primary_regenerate_log_no_replica_restart() { sim.client("client", async move { let http = Client::new(); - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; // insert a few valued into the primary @@ -474,11 +466,7 @@ fn primary_regenerate_log_with_replica_restart() { sim.client("client", async move { let http = Client::new(); - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; // insert a few valued into the primary diff --git a/libsql-server/tests/cluster/replication.rs b/libsql-server/tests/cluster/replication.rs index 4444d1bde6..27d053bc4b 100644 --- a/libsql-server/tests/cluster/replication.rs +++ b/libsql-server/tests/cluster/replication.rs @@ -215,12 +215,9 @@ fn replica_lazy_creation() { }); sim.client("client", async move { - let db = Database::open_remote_with_connector( - "http://test.replica:8080", - "", - TurmoilConnector, - ) - .unwrap(); + let db = + Database::open_remote_with_connector("http://test.replica:8080", "", TurmoilConnector) + .unwrap(); let conn = db.connect().unwrap(); assert_debug_snapshot!(conn.execute("create table test (x)", ()).await.unwrap_err()); let primary_http = Client::new(); diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index 1d27f935bd..868bf3d1ea 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -368,11 +368,8 @@ fn replica_primary_reset() { }); sim.client("client", async move { - let primary = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let primary = + Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = primary.connect()?; // insert a few valued into the primary @@ -532,12 +529,9 @@ fn replica_no_resync_on_restart() { sim.client("client", async { // seed database { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - ) - .unwrap(); + let db = + Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector) + .unwrap(); let conn = db.connect().unwrap(); conn.execute("create table test (x)", ()).await.unwrap(); for _ in 0..500 { @@ -638,12 +632,8 @@ fn replicate_with_snapshots() { .post("http://primary:9090/v1/namespaces/foo/create", json!({})) .await?; - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - ) - .unwrap(); + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector) + .unwrap(); let conn = db.connect().unwrap(); conn.execute("create table test (x)", ()).await.unwrap(); // insert enough to trigger snapshot creation. diff --git a/libsql-server/tests/hrana/batch.rs b/libsql-server/tests/hrana/batch.rs index 3dcab9db8c..aa03b1f8a6 100644 --- a/libsql-server/tests/hrana/batch.rs +++ b/libsql-server/tests/hrana/batch.rs @@ -48,11 +48,7 @@ fn execute_individual_statements() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table t(x text)", ()).await?; @@ -80,11 +76,7 @@ fn execute_batch() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute_batch( @@ -120,11 +112,7 @@ fn multistatement_query() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; let mut rows = conn .query("select 1 + ?; select 'abc';", params![1]) @@ -147,11 +135,7 @@ fn affected_rows_and_last_rowid() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute( diff --git a/libsql-server/tests/hrana/transaction.rs b/libsql-server/tests/hrana/transaction.rs index 00dc66feb7..2ebb02fef6 100644 --- a/libsql-server/tests/hrana/transaction.rs +++ b/libsql-server/tests/hrana/transaction.rs @@ -10,11 +10,7 @@ fn transaction_commit_and_rollback() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; // initialize tables @@ -58,11 +54,7 @@ fn multiple_concurrent_transactions() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute_batch(r#"create table t(x text);"#).await?; @@ -120,11 +112,7 @@ fn transaction_timeout() { .build(); sim.host("primary", super::make_standalone_server); sim.client("client", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; // initialize tables diff --git a/libsql-server/tests/namespaces/dumps.rs b/libsql-server/tests/namespaces/dumps.rs index 9f9bfded0d..93b7e1676c 100644 --- a/libsql-server/tests/namespaces/dumps.rs +++ b/libsql-server/tests/namespaces/dumps.rs @@ -54,11 +54,8 @@ fn load_namespace_from_dump_from_url() { assert_eq!(resp.status(), 200); assert_snapshot!(resp.body_string().await.unwrap()); - let foo = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo.connect()?; let mut rows = foo_conn.query("select count(*) from test", ()).await?; assert!(matches!( @@ -128,11 +125,8 @@ fn load_namespace_from_dump_from_file() { resp.json::().await.unwrap_or_default() ); - let foo = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo.connect()?; let mut rows = foo_conn.query("select count(*) from test", ()).await?; assert!(matches!( @@ -183,11 +177,8 @@ fn load_namespace_from_no_commit() { ); // namespace doesn't exist - let foo = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo.connect()?; assert!(foo_conn .query("select count(*) from test", ()) @@ -237,11 +228,8 @@ fn load_namespace_from_no_txn() { assert_json_snapshot!(resp.json_value().await.unwrap()); // namespace doesn't exist - let foo = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo.connect()?; assert!(foo_conn .query("select count(*) from test", ()) @@ -270,11 +258,8 @@ fn export_dump() { .await?; assert_eq!(resp.status(), StatusCode::OK); - let foo = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo.connect()?; foo_conn.execute("create table test (x)", ()).await?; foo_conn.execute("insert into test values (42)", ()).await?; diff --git a/libsql-server/tests/namespaces/mod.rs b/libsql-server/tests/namespaces/mod.rs index 4f80857f18..2979991b1a 100644 --- a/libsql-server/tests/namespaces/mod.rs +++ b/libsql-server/tests/namespaces/mod.rs @@ -60,11 +60,8 @@ fn fork_namespace() { .post("http://primary:9090/v1/namespaces/foo/create", json!({})) .await?; - let foo = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo.connect()?; foo_conn.execute("create table test (c)", ()).await?; @@ -74,11 +71,8 @@ fn fork_namespace() { .post("http://primary:9090/v1/namespaces/foo/fork/bar", ()) .await?; - let bar = Database::open_remote_with_connector( - "http://bar.primary:8080", - "", - TurmoilConnector, - )?; + let bar = + Database::open_remote_with_connector("http://bar.primary:8080", "", TurmoilConnector)?; let bar_conn = bar.connect()?; // what's in foo is in bar as well @@ -124,11 +118,8 @@ fn delete_namespace() { .post("http://primary:9090/v1/namespaces/foo/create", json!({})) .await?; - let foo = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo.connect()?; foo_conn.execute("create table test (c)", ()).await?; diff --git a/libsql-server/tests/standalone/attach.rs b/libsql-server/tests/standalone/attach.rs index 517905e19c..36e1b16fce 100644 --- a/libsql-server/tests/standalone/attach.rs +++ b/libsql-server/tests/standalone/attach.rs @@ -36,11 +36,8 @@ fn attach_no_auth() { .await .unwrap(); - let foo_db = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let foo_db = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let foo_conn = foo_db.connect().unwrap(); foo_conn .execute("CREATE TABLE foo_table (x)", ()) @@ -51,11 +48,8 @@ fn attach_no_auth() { .await .unwrap(); - let bar_db = Database::open_remote_with_connector( - "http://bar.primary:8080", - "", - TurmoilConnector, - )?; + let bar_db = + Database::open_remote_with_connector("http://bar.primary:8080", "", TurmoilConnector)?; let bar_conn = bar_db.connect().unwrap(); bar_conn .execute("CREATE TABLE bar_table (x)", ()) diff --git a/libsql-server/tests/standalone/mod.rs b/libsql-server/tests/standalone/mod.rs index a64ff60e15..fc81fdbcc8 100644 --- a/libsql-server/tests/standalone/mod.rs +++ b/libsql-server/tests/standalone/mod.rs @@ -52,11 +52,7 @@ fn basic_query() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -84,11 +80,7 @@ fn basic_metrics() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -135,11 +127,8 @@ fn primary_serializability() { sim.client("writer", { let notify = notify.clone(); async move { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = + Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; conn.execute("insert into test values (12)", ()).await?; @@ -152,11 +141,8 @@ fn primary_serializability() { sim.client("reader", { async move { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = + Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; notify.notified().await; @@ -184,11 +170,7 @@ fn basic_query_fail() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -216,11 +198,7 @@ fn begin_commit() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -258,11 +236,7 @@ fn begin_rollback() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute("create table test (x)", ()).await?; @@ -301,11 +275,7 @@ fn is_autocommit() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; assert!(conn.is_autocommit()); @@ -350,11 +320,7 @@ fn random_rowid() { sim.host("primary", make_standalone_server); sim.client("test", async { - let db = Database::open_remote_with_connector( - "http://primary:8080", - "dummy-token", - TurmoilConnector, - )?; + let db = Database::open_remote_with_connector("http://primary:8080", "", TurmoilConnector)?; let conn = db.connect()?; conn.execute( From 57861f6203a5880e24b758444ac3965b7e0cfd7c Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Thu, 14 Mar 2024 14:18:07 -0700 Subject: [PATCH 31/63] cargo fmt + cleanup --- libsql-server/src/http/user/extract.rs | 2 +- libsql-server/tests/embedded_replica/local.rs | 7 ++----- libsql/examples/replica.rs | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 522c1bb343..b850e76ff3 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -16,7 +16,7 @@ impl FromRequestParts for RequestContext { state: &AppState, ) -> std::result::Result { let namespace = db_factory::namespace_from_headers( - &parts.headers, + &parts.headers, state.disable_default_namespace, state.disable_namespaces, )?; diff --git a/libsql-server/tests/embedded_replica/local.rs b/libsql-server/tests/embedded_replica/local.rs index fbba29c7dc..911f679c4d 100644 --- a/libsql-server/tests/embedded_replica/local.rs +++ b/libsql-server/tests/embedded_replica/local.rs @@ -33,11 +33,8 @@ fn local_sync_with_writes() { let _path = tmp_embedded_path.join("embedded"); - let primary = Database::open_remote_with_connector( - "http://foo.primary:8080", - "", - TurmoilConnector, - )?; + let primary = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; let conn = primary.connect()?; // Do enough writes to ensure that we can force the server to write some snapshots diff --git a/libsql/examples/replica.rs b/libsql/examples/replica.rs index b61e886407..82fca89702 100644 --- a/libsql/examples/replica.rs +++ b/libsql/examples/replica.rs @@ -10,7 +10,7 @@ async fn main() { let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or_else(|_| { println!("Using empty token since LIBSQL_TOKEN was not set"); - "x".to_string() + "".to_string() }); let url = std::env::var("LIBSQL_URL") From 9df60da58b9db4aa97c83154176cb28b029b8eda Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Mon, 18 Mar 2024 15:01:39 -0700 Subject: [PATCH 32/63] namespace passing exammple --- libsql/examples/remote_sync.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/libsql/examples/remote_sync.rs b/libsql/examples/remote_sync.rs index ccc31eb033..c742ac3e57 100644 --- a/libsql/examples/remote_sync.rs +++ b/libsql/examples/remote_sync.rs @@ -32,6 +32,7 @@ async fn main() { let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string()); let db = match Builder::new_remote_replica(db_path, sync_url, auth_token) + .namespace("alphabravo") .build() .await { From 9f97d3ecb1fc15857bd6874a445a8f86c399c9a2 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 22 Mar 2024 08:27:57 -0700 Subject: [PATCH 33/63] added namespace config for the example --- libsql/examples/remote_sync.rs | 39 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/libsql/examples/remote_sync.rs b/libsql/examples/remote_sync.rs index c742ac3e57..d42ccd1f75 100644 --- a/libsql/examples/remote_sync.rs +++ b/libsql/examples/remote_sync.rs @@ -7,41 +7,42 @@ async fn main() { tracing_subscriber::fmt::init(); // The local database path where the data will be stored. - let db_path = match std::env::var("LIBSQL_DB_PATH") { - Ok(path) => path, - Err(_) => { + let db_path = std::env::var("LIBSQL_DB_PATH") + .map_err(|_| { eprintln!( "Please set the LIBSQL_DB_PATH environment variable to set to local database path." - ); - return; - } - }; + ) + }) + .unwrap(); // The remote sync URL to use. - let sync_url = match std::env::var("LIBSQL_SYNC_URL") { - Ok(url) => url, - Err(_) => { + let sync_url = std::env::var("LIBSQL_SYNC_URL") + .map_err(|_| { eprintln!( "Please set the LIBSQL_SYNC_URL environment variable to set to remote sync URL." - ); - return; - } - }; + ) + }) + .unwrap(); + + let namespace = std::env::var("LIBSQL_NAMESPACE").ok(); // The authentication token to use. let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string()); - let db = match Builder::new_remote_replica(db_path, sync_url, auth_token) - .namespace("alphabravo") - .build() - .await - { + let db_builder = if let Some(ns) = namespace { + Builder::new_remote_replica(db_path, sync_url, auth_token).namespace(&ns) + } else { + Builder::new_remote_replica(db_path, sync_url, auth_token) + }; + + let db = match db_builder.build().await { Ok(db) => db, Err(error) => { eprintln!("Error connecting to remote sync server: {}", error); return; } }; + let conn = db.connect().unwrap(); print!("Syncing with remote database..."); From d0a5646fd6137b912f7a42b303cbf5032b5ebc02 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 22 Mar 2024 08:55:56 -0700 Subject: [PATCH 34/63] remove unnecessary dummy token --- libsql-server/tests/embedded_replica/mod.rs | 2 +- testdb | Bin 0 -> 8192 bytes testdb-client_wal_index | Bin 0 -> 32 bytes 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 testdb create mode 100644 testdb-client_wal_index diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index 0316a3cb87..868bf3d1ea 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -933,7 +933,7 @@ fn errors_on_bad_replica() { let db = libsql::Builder::new_remote_replica( path.to_str().unwrap(), "http://foo.primary:8080".to_string(), - "dummy_token".to_string(), + "".to_string(), ) .connector(TurmoilConnector) .build() diff --git a/testdb b/testdb new file mode 100644 index 0000000000000000000000000000000000000000..8c3e17767b1607687bdf51bed3aa2fb40a363d60 GIT binary patch literal 8192 zcmeI#y$ZrG6b0a$AXo}gH*ve8E<*JMtkSKE+QCg)h*GpIv=Ln0eFz`W+1XU82pyco zb4YS;ex`5BY7rz_@is5gNb`icBqdFmiAdJdlxKulTXkQRYX2>gM#Z&bUJ2(yW*`uN z00bZa0SG_<0uX=z1Rwx`zZWP(yXU&%?C2;ysNGSOdK2gQexov7B&uwl$obOuLZ3r# z7Wmw}=Yh>1(dwi*^w;70bXfh(rE??aZWaUr5P$##AOHafKmY;|fB*y_0D+$uu%zr~ LQL4taC0^YFokcSm literal 0 HcmV?d00001 diff --git a/testdb-client_wal_index b/testdb-client_wal_index new file mode 100644 index 0000000000000000000000000000000000000000..38938bb33134e54dac59dc8b95b3b3757cc5d92f GIT binary patch literal 32 bcmZqdT3Ne3t#EdS Date: Fri, 22 Mar 2024 08:57:15 -0700 Subject: [PATCH 35/63] reverting accidental commit --- testdb | Bin 8192 -> 0 bytes testdb-client_wal_index | Bin 32 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 testdb delete mode 100644 testdb-client_wal_index diff --git a/testdb b/testdb deleted file mode 100644 index 8c3e17767b1607687bdf51bed3aa2fb40a363d60..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8192 zcmeI#y$ZrG6b0a$AXo}gH*ve8E<*JMtkSKE+QCg)h*GpIv=Ln0eFz`W+1XU82pyco zb4YS;ex`5BY7rz_@is5gNb`icBqdFmiAdJdlxKulTXkQRYX2>gM#Z&bUJ2(yW*`uN z00bZa0SG_<0uX=z1Rwx`zZWP(yXU&%?C2;ysNGSOdK2gQexov7B&uwl$obOuLZ3r# z7Wmw}=Yh>1(dwi*^w;70bXfh(rE??aZWaUr5P$##AOHafKmY;|fB*y_0D+$uu%zr~ LQL4taC0^YFokcSm diff --git a/testdb-client_wal_index b/testdb-client_wal_index deleted file mode 100644 index 38938bb33134e54dac59dc8b95b3b3757cc5d92f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 32 bcmZqdT3Ne3t#EdS Date: Tue, 26 Mar 2024 17:17:55 -0700 Subject: [PATCH 36/63] wip --- .../src/auth/user_auth_strategies/http_basic.rs | 4 ++++ libsql-server/src/auth/user_auth_strategies/jwt.rs | 5 +++++ libsql-server/src/auth/user_auth_strategies/mod.rs | 4 ++++ libsql-server/src/hrana/ws/session.rs | 13 ++++++++++--- 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index fbb45d0912..3b0d5fe89f 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -28,6 +28,10 @@ impl UserAuthStrategy for HttpBasic { Err(AuthError::BasicRejected) } + + fn required_fields(&self) -> Vec {vec!["authentication".to_string()]} + + } impl HttpBasic { diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index da68e91df0..5c0d7bcbd0 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -34,6 +34,11 @@ impl UserAuthStrategy for Jwt { return validate_jwt(&self.key, &token); } + + fn required_fields(&self) -> Vec {vec!["authentication".to_string()]} + + + } impl Jwt { diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 4f0f2ef786..c28e028f9b 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -67,8 +67,12 @@ impl UserAuthContext { } pub trait UserAuthStrategy: Sync + Send { + + fn required_fields(&self) -> Vec {vec![]} + fn authenticate( &self, context: Result, ) -> Result; + } diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index aef1f63574..8dae403e07 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use anyhow::{anyhow, bail, Result}; use futures::future::BoxFuture; +use s3s::auth; use tokio::sync::{mpsc, oneshot}; use super::super::{batch, cursor, stmt, ProtocolError, Version}; @@ -77,11 +78,17 @@ pub(super) async fn handle_initial_hello( .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - let auth = namespace_jwt_key + let auth_strategy = namespace_jwt_key .map(Jwt::new) .map(Auth::new) - .unwrap_or(server.user_auth_strategy.clone()) - .authenticate(Ok(UserAuthContext::bearer_opt(jwt))) + .unwrap_or(server.user_auth_strategy.clone()); + + let context:UserAuthContext = build_context(jwt, auth_strategy.user_strategy.required_fields()); + + // Ok(UserAuthContext::bearer_opt(jwt)) + + let auth = auth_strategy + .authenticate(Ok(context)) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { From 3e7999a9770c9469f0111cd47e3be43068ce57d0 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 27 Mar 2024 11:04:56 -0700 Subject: [PATCH 37/63] wip --- libsql-server/src/auth/user_auth_strategies/mod.rs | 2 ++ libsql-server/src/hrana/ws/session.rs | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index c28e028f9b..82a49869e4 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -3,6 +3,7 @@ pub mod http_basic; pub mod jwt; pub use disabled::Disabled; +use hashbrown::HashMap; pub use http_basic::HttpBasic; pub use jwt::Jwt; @@ -12,6 +13,7 @@ use super::{AuthError, Authenticated}; pub struct UserAuthContext { scheme: Option, token: Option, + custom_fields: HashMap, String> } impl UserAuthContext { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 8dae403e07..697ed0cef8 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use anyhow::{anyhow, bail, Result}; use futures::future::BoxFuture; -use s3s::auth; use tokio::sync::{mpsc, oneshot}; use super::super::{batch, cursor, stmt, ProtocolError, Version}; @@ -72,6 +71,7 @@ pub(super) async fn handle_initial_hello( jwt: Option, namespace: NamespaceName, ) -> Result { + // todo dupe #auth let namespace_jwt_key = server .namespaces @@ -82,11 +82,9 @@ pub(super) async fn handle_initial_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()); - + let context:UserAuthContext = build_context(jwt, auth_strategy.user_strategy.required_fields()); - // Ok(UserAuthContext::bearer_opt(jwt)) - let auth = auth_strategy .authenticate(Ok(context)) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; @@ -100,6 +98,10 @@ pub(super) async fn handle_initial_hello( }) } +fn build_context(jwt: Option, required_fields: Vec) -> UserAuthContext { + UserAuthContext::bearer_opt(jwt) +} + pub(super) async fn handle_repeated_hello( server: &Server, session: &mut Session, From 844dd540070084082c13a0e44113527befb67bfe Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 27 Mar 2024 21:27:11 -0700 Subject: [PATCH 38/63] custom fields in auth context --- libsql-server/src/auth/parsers.rs | 2 +- .../src/auth/user_auth_strategies/jwt.rs | 1 + .../src/auth/user_auth_strategies/mod.rs | 6 ++++ libsql-server/src/hrana/ws/session.rs | 10 +++++-- libsql-server/src/http/user/extract.rs | 16 ++++------ libsql-server/src/http/user/mod.rs | 30 +++++++++++-------- libsql-server/src/rpc/proxy.rs | 2 +- libsql-server/src/rpc/replica_proxy.rs | 6 ++-- libsql-server/src/rpc/replication_log.rs | 4 +-- 9 files changed, 43 insertions(+), 34 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 643a3fee1b..19a7c9c12b 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -36,7 +36,7 @@ pub fn parse_jwt_key(data: &str) -> Result { } } -pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { +pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap, required_fields: Vec) -> Result { metadata .get(GRPC_AUTH_HEADER) .ok_or(AuthError::AuthHeaderNotFound) diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index 5c0d7bcbd0..6f1a12b61c 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -23,6 +23,7 @@ impl UserAuthStrategy for Jwt { let UserAuthContext { scheme: Some(scheme), token: Some(token), + custom_fields: _ } = ctx else { return Err(AuthError::HttpAuthHeaderInvalid); diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 82a49869e4..ddb2d7e84a 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -29,6 +29,8 @@ impl UserAuthContext { UserAuthContext { scheme: None, token: None, + custom_fields: HashMap::new(), + } } @@ -36,6 +38,7 @@ impl UserAuthContext { UserAuthContext { scheme: Some("Basic".into()), token: Some(creds.into()), + custom_fields: HashMap::new(), } } @@ -43,6 +46,7 @@ impl UserAuthContext { UserAuthContext { scheme: Some("Bearer".into()), token: Some(token.into()), + custom_fields: HashMap::new(), } } @@ -50,6 +54,7 @@ impl UserAuthContext { UserAuthContext { scheme: Some("Bearer".into()), token: token, + custom_fields: HashMap::new(), } } @@ -57,6 +62,7 @@ impl UserAuthContext { UserAuthContext { scheme: Some(scheme.into()), token: Some(token.into()), + custom_fields: HashMap::new(), } } diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 697ed0cef8..76026ca148 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -120,11 +120,15 @@ pub(super) async fn handle_repeated_hello( .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - session.auth = namespace_jwt_key + let auth_strategy = namespace_jwt_key .map(Jwt::new) .map(Auth::new) - .unwrap_or_else(|| server.user_auth_strategy.clone()) - .authenticate(Ok(UserAuthContext::bearer_opt(jwt))) + .unwrap_or(server.user_auth_strategy.clone()); + + let context:UserAuthContext = build_context(jwt, auth_strategy.user_strategy.required_fields()); + + session.auth = auth_strategy + .authenticate(Ok(context)) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index b850e76ff3..60f5be0d57 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -26,21 +26,15 @@ impl FromRequestParts for RequestContext { .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - let context = parts - .headers - .get(hyper::header::AUTHORIZATION) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)); - - let authenticated = namespace_jwt_key + let auth = namespace_jwt_key .map(Jwt::new) .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + let context = super::build_context(&parts.headers, auth.user_strategy.required_fields()); Ok(Self::new( - authenticated, + auth.authenticate(Ok(context))?, namespace, state.namespaces.meta_store().clone(), )) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 3e394fc579..da69bd8099 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -468,22 +468,26 @@ impl FromRequestParts for Authenticated { .with(ns.clone(), |ns| ns.jwt_key()) .await??; - let context = parts - .headers - .get(hyper::header::AUTHORIZATION) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)); - - let authenticated = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; - Ok(authenticated) + let auth = namespace_jwt_key + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + let context = build_context(&parts.headers, auth.user_strategy.required_fields()); + + Ok(auth.authenticate(Ok(context))?) } } +fn build_context(headers: &hyper::HeaderMap, required_fields: Vec) -> UserAuthContext { + headers + .get(hyper::header::AUTHORIZATION) + .ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .and_then(|t| UserAuthContext::from_auth_str(t)) + .unwrap_or(UserAuthContext::empty()) +} + impl FromRef for Auth { fn from_ref(input: &AppState) -> Self { input.user_auth_strategy.clone() diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index d7f96d68a7..92e9c10fca 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -335,7 +335,7 @@ impl ProxyService { }; let auth = if let Some(auth) = auth { - let context = parse_grpc_auth_header(req.metadata()); + let context = parse_grpc_auth_header(req.metadata(), auth.user_strategy.required_fields()); auth.authenticate(context)? } else { Authenticated::from_proxy_grpc_request(req)? diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 6945cfffed..1bff196185 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -45,7 +45,7 @@ impl ReplicaProxyService { let namespace_jwt_key = jwt_result.and_then(|s| s); - let auth_strategy = match namespace_jwt_key { + let auth = match namespace_jwt_key { Ok(Some(key)) => Ok(Auth::new(Jwt::new(key))), Ok(None) | Err(crate::error::Error::NamespaceDoesntExist(_)) => { Ok(self.user_auth_strategy.clone()) @@ -56,8 +56,8 @@ impl ReplicaProxyService { ))), }?; - let auth_context = parse_grpc_auth_header(req.metadata()); - auth_strategy + let auth_context = parse_grpc_auth_header(req.metadata(), auth.user_strategy.required_fields()); + auth .authenticate(auth_context)? .upgrade_grpc_request(req); return Ok(()); diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 5a93b3331e..2d969e68ef 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -94,8 +94,8 @@ impl ReplicationLogService { }; if let Some(auth) = auth { - let user_credential = parse_grpc_auth_header(req.metadata()); - auth.authenticate(user_credential)?; + let context = parse_grpc_auth_header(req.metadata()); + auth.authenticate(context)?; } Ok(()) From 6afb11223b54aee172dbe568432499c9c143d8c1 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Thu, 28 Mar 2024 10:03:08 -0700 Subject: [PATCH 39/63] fixed compilation errors, wip --- libsql-server/src/auth/parsers.rs | 18 +++++++++++++----- .../auth/user_auth_strategies/http_basic.rs | 8 ++++---- .../src/auth/user_auth_strategies/jwt.rs | 9 ++++----- .../src/auth/user_auth_strategies/mod.rs | 9 ++++----- libsql-server/src/hrana/ws/session.rs | 15 ++++++++------- libsql-server/src/http/user/extract.rs | 2 +- libsql-server/src/http/user/mod.rs | 15 +++++++++------ libsql-server/src/rpc/proxy.rs | 3 ++- libsql-server/src/rpc/replica_proxy.rs | 7 +++---- libsql-server/src/rpc/replication_log.rs | 3 ++- 10 files changed, 50 insertions(+), 39 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 19a7c9c12b..0ebb43d386 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -36,7 +36,10 @@ pub fn parse_jwt_key(data: &str) -> Result { } } -pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap, required_fields: Vec) -> Result { +pub(crate) fn parse_grpc_auth_header( + metadata: &MetadataMap, + required_fields: &Vec, +) -> Result { metadata .get(GRPC_AUTH_HEADER) .ok_or(AuthError::AuthHeaderNotFound) @@ -70,6 +73,7 @@ pub fn parse_http_auth_header<'a>( #[cfg(test)] mod tests { use axum::http::HeaderValue; + use hashbrown::HashMap; use hyper::header::AUTHORIZATION; use crate::auth::{parse_http_auth_header, AuthError}; @@ -80,7 +84,8 @@ mod tests { fn parse_grpc_auth_header_returns_valid_context() { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer 123".parse().unwrap()); - let context = parse_grpc_auth_header(&map).unwrap(); + let required_fields = Vec::new(); + let context = parse_grpc_auth_header(&map, &required_fields).unwrap(); assert_eq!(context.scheme().as_ref().unwrap(), "bearer"); assert_eq!(context.token().as_ref().unwrap(), "123"); } @@ -88,7 +93,8 @@ mod tests { #[test] fn parse_grpc_auth_header_error_no_header() { let map = tonic::metadata::MetadataMap::new(); - let result = parse_grpc_auth_header(&map); + let required_fields = Vec::new(); + let result = parse_grpc_auth_header(&map, &required_fields); assert_eq!( result.unwrap_err().to_string(), "Expected authorization header but none given" @@ -99,7 +105,8 @@ mod tests { fn parse_grpc_auth_header_error_non_ascii() { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); - let result = parse_grpc_auth_header(&map); + let required_fields = Vec::new(); + let result = parse_grpc_auth_header(&map, &required_fields); assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header") } @@ -107,7 +114,8 @@ mod tests { fn parse_grpc_auth_header_error_malformed_auth_str() { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer123".parse().unwrap()); - let result = parse_grpc_auth_header(&map); + let required_fields = Vec::new(); + let result = parse_grpc_auth_header(&map, &required_fields); assert_eq!( result.unwrap_err().to_string(), "Auth string does not conform to ' ' form" diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index 3b0d5fe89f..d423eb678b 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -28,10 +28,10 @@ impl UserAuthStrategy for HttpBasic { Err(AuthError::BasicRejected) } - - fn required_fields(&self) -> Vec {vec!["authentication".to_string()]} - - + + fn required_fields(&self) -> Vec { + vec!["authentication".to_string()] + } } impl HttpBasic { diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index 6f1a12b61c..e5633a05eb 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -23,7 +23,7 @@ impl UserAuthStrategy for Jwt { let UserAuthContext { scheme: Some(scheme), token: Some(token), - custom_fields: _ + custom_fields: _, } = ctx else { return Err(AuthError::HttpAuthHeaderInvalid); @@ -36,10 +36,9 @@ impl UserAuthStrategy for Jwt { return validate_jwt(&self.key, &token); } - fn required_fields(&self) -> Vec {vec!["authentication".to_string()]} - - - + fn required_fields(&self) -> Vec { + vec!["authentication".to_string()] + } } impl Jwt { diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index ddb2d7e84a..20254bf5a8 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -13,7 +13,7 @@ use super::{AuthError, Authenticated}; pub struct UserAuthContext { scheme: Option, token: Option, - custom_fields: HashMap, String> + custom_fields: HashMap, String>, } impl UserAuthContext { @@ -30,7 +30,6 @@ impl UserAuthContext { scheme: None, token: None, custom_fields: HashMap::new(), - } } @@ -75,12 +74,12 @@ impl UserAuthContext { } pub trait UserAuthStrategy: Sync + Send { - - fn required_fields(&self) -> Vec {vec![]} + fn required_fields(&self) -> Vec { + vec![] + } fn authenticate( &self, context: Result, ) -> Result; - } diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 76026ca148..f17cda3f3a 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -71,7 +71,6 @@ pub(super) async fn handle_initial_hello( jwt: Option, namespace: NamespaceName, ) -> Result { - // todo dupe #auth let namespace_jwt_key = server .namespaces @@ -82,8 +81,9 @@ pub(super) async fn handle_initial_hello( .map(Jwt::new) .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()); - - let context:UserAuthContext = build_context(jwt, auth_strategy.user_strategy.required_fields()); + + let context: UserAuthContext = + build_context(jwt, &auth_strategy.user_strategy.required_fields()); let auth = auth_strategy .authenticate(Ok(context)) @@ -98,7 +98,7 @@ pub(super) async fn handle_initial_hello( }) } -fn build_context(jwt: Option, required_fields: Vec) -> UserAuthContext { +fn build_context(jwt: Option, required_fields: &Vec) -> UserAuthContext { UserAuthContext::bearer_opt(jwt) } @@ -120,12 +120,13 @@ pub(super) async fn handle_repeated_hello( .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - let auth_strategy = namespace_jwt_key + let auth_strategy = namespace_jwt_key .map(Jwt::new) .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()); - - let context:UserAuthContext = build_context(jwt, auth_strategy.user_strategy.required_fields()); + + let context: UserAuthContext = + build_context(jwt, &auth_strategy.user_strategy.required_fields()); session.auth = auth_strategy .authenticate(Ok(context)) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 60f5be0d57..821a77005b 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -31,7 +31,7 @@ impl FromRequestParts for RequestContext { .map(Auth::new) .unwrap_or_else(|| state.user_auth_strategy.clone()); - let context = super::build_context(&parts.headers, auth.user_strategy.required_fields()); + let context = super::build_context(&parts.headers, &auth.user_strategy.required_fields()); Ok(Self::new( auth.authenticate(Ok(context))?, diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index da69bd8099..2b0a755a7b 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -469,17 +469,20 @@ impl FromRequestParts for Authenticated { .await??; let auth = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()); - - let context = build_context(&parts.headers, auth.user_strategy.required_fields()); + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + let context = build_context(&parts.headers, &auth.user_strategy.required_fields()); Ok(auth.authenticate(Ok(context))?) } } -fn build_context(headers: &hyper::HeaderMap, required_fields: Vec) -> UserAuthContext { +fn build_context( + headers: &hyper::HeaderMap, + required_fields: &Vec, +) -> UserAuthContext { headers .get(hyper::header::AUTHORIZATION) .ok_or(AuthError::AuthHeaderNotFound) diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index 92e9c10fca..aba85c64eb 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -335,7 +335,8 @@ impl ProxyService { }; let auth = if let Some(auth) = auth { - let context = parse_grpc_auth_header(req.metadata(), auth.user_strategy.required_fields()); + let context = + parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); auth.authenticate(context)? } else { Authenticated::from_proxy_grpc_request(req)? diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 1bff196185..47732daba7 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -56,10 +56,9 @@ impl ReplicaProxyService { ))), }?; - let auth_context = parse_grpc_auth_header(req.metadata(), auth.user_strategy.required_fields()); - auth - .authenticate(auth_context)? - .upgrade_grpc_request(req); + let auth_context = + parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); + auth.authenticate(auth_context)?.upgrade_grpc_request(req); return Ok(()); } } diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 2d969e68ef..4dc72190bd 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -94,7 +94,8 @@ impl ReplicationLogService { }; if let Some(auth) = auth { - let context = parse_grpc_auth_header(req.metadata()); + let context = + parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); auth.authenticate(context)?; } From 7f8b0b42eb3460da1584519a3e70fe3a965ea21c Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 29 Mar 2024 10:07:09 -0700 Subject: [PATCH 40/63] refactored auth api with on demand headers --- libsql-server/src/auth/parsers.rs | 25 +++++++++--- .../auth/user_auth_strategies/http_basic.rs | 19 ++++++---- .../src/auth/user_auth_strategies/jwt.rs | 17 ++++----- .../src/auth/user_auth_strategies/mod.rs | 38 ++++--------------- libsql-server/src/hrana/ws/session.rs | 6 ++- libsql-server/src/http/user/extract.rs | 2 +- libsql-server/src/http/user/mod.rs | 13 ++++++- 7 files changed, 64 insertions(+), 56 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 0ebb43d386..d8fa335b07 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -40,11 +40,23 @@ pub(crate) fn parse_grpc_auth_header( metadata: &MetadataMap, required_fields: &Vec, ) -> Result { - metadata + let mut context = metadata .get(GRPC_AUTH_HEADER) .ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) + .and_then(|t| UserAuthContext::from_auth_str(t)); + + if let Ok(ref mut ctx) = context { + for field in required_fields.iter() { + metadata + .get(field) + .map(|header| header.to_str().ok()) + .and_then(|r| r) + .map(|v| ctx.add_field(field.into(), v.into())); + } + } + + context } pub fn parse_http_auth_header<'a>( @@ -73,7 +85,6 @@ pub fn parse_http_auth_header<'a>( #[cfg(test)] mod tests { use axum::http::HeaderValue; - use hashbrown::HashMap; use hyper::header::AUTHORIZATION; use crate::auth::{parse_http_auth_header, AuthError}; @@ -84,10 +95,12 @@ mod tests { fn parse_grpc_auth_header_returns_valid_context() { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer 123".parse().unwrap()); - let required_fields = Vec::new(); + let required_fields = vec!["x-authorization".into()]; let context = parse_grpc_auth_header(&map, &required_fields).unwrap(); - assert_eq!(context.scheme().as_ref().unwrap(), "bearer"); - assert_eq!(context.token().as_ref().unwrap(), "123"); + assert_eq!( + context.custom_fields.get("x-authorization"), + Some(&"bearer 123".to_string()) + ); } #[test] diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index d423eb678b..1e69283016 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -13,15 +13,20 @@ impl UserAuthStrategy for HttpBasic { ) -> Result { tracing::trace!("executing http basic auth"); + let ctx = context?; + let auth_str = None + .or_else(|| ctx.custom_fields.get("authorization")) + .or_else(|| ctx.custom_fields.get("x-authorization")); + + let (_, token) = auth_str + .ok_or(AuthError::AuthHeaderNotFound) + .map(|s| s.split_once(' ').ok_or(AuthError::AuthStringMalformed)) + .and_then(|o| o)?; + // NOTE: this naive comparison may leak information about the `expected_value` // using a timing attack let expected_value = self.credential.trim_end_matches('='); - - let creds_match = match context?.token { - Some(s) => s.contains(expected_value), - None => expected_value.is_empty(), - }; - + let creds_match = token.contains(expected_value); if creds_match { return Ok(Authenticated::FullAccess); } @@ -30,7 +35,7 @@ impl UserAuthStrategy for HttpBasic { } fn required_fields(&self) -> Vec { - vec!["authentication".to_string()] + vec!["authorization".to_string(), "x-authorization".to_string()] } } diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index e5633a05eb..6b73ea74f6 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -19,15 +19,14 @@ impl UserAuthStrategy for Jwt { tracing::trace!("executing jwt auth"); let ctx = context?; - - let UserAuthContext { - scheme: Some(scheme), - token: Some(token), - custom_fields: _, - } = ctx - else { - return Err(AuthError::HttpAuthHeaderInvalid); - }; + let auth_str = None + .or_else(|| ctx.custom_fields.get("authorization")) + .or_else(|| ctx.custom_fields.get("x-authorization")) + .ok_or_else(|| AuthError::AuthHeaderNotFound)?; + + let (scheme, token) = auth_str + .split_once(' ') + .ok_or(AuthError::AuthStringMalformed)?; if !scheme.eq_ignore_ascii_case("bearer") { return Err(AuthError::HttpAuthHeaderUnsupportedScheme); diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 20254bf5a8..e56b6e1e78 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -11,57 +11,31 @@ use super::{AuthError, Authenticated}; #[derive(Debug)] pub struct UserAuthContext { - scheme: Option, - token: Option, - custom_fields: HashMap, String>, + pub custom_fields: HashMap, String>, // todo, add aliases } impl UserAuthContext { - pub fn scheme(&self) -> &Option { - &self.scheme - } - - pub fn token(&self) -> &Option { - &self.token - } - pub fn empty() -> UserAuthContext { UserAuthContext { - scheme: None, - token: None, custom_fields: HashMap::new(), } } pub fn basic(creds: &str) -> UserAuthContext { UserAuthContext { - scheme: Some("Basic".into()), - token: Some(creds.into()), - custom_fields: HashMap::new(), + custom_fields: HashMap::from([("authorization".into(), format!("Basic {creds}"))]), } } pub fn bearer(token: &str) -> UserAuthContext { UserAuthContext { - scheme: Some("Bearer".into()), - token: Some(token.into()), - custom_fields: HashMap::new(), - } - } - - pub fn bearer_opt(token: Option) -> UserAuthContext { - UserAuthContext { - scheme: Some("Bearer".into()), - token: token, - custom_fields: HashMap::new(), + custom_fields: HashMap::from([("authorization".into(), format!("Bearer {token}"))]), } } pub fn new(scheme: &str, token: &str) -> UserAuthContext { UserAuthContext { - scheme: Some(scheme.into()), - token: Some(token.into()), - custom_fields: HashMap::new(), + custom_fields: HashMap::from([("authorization".into(), format!("{scheme} {token}"))]), } } @@ -71,6 +45,10 @@ impl UserAuthContext { .ok_or(AuthError::AuthStringMalformed)?; Ok(UserAuthContext::new(scheme, token)) } + + pub fn add_field(&mut self, key: String, value: String) { + self.custom_fields.insert(key.into(), value.into()); + } } pub trait UserAuthStrategy: Sync + Send { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index f17cda3f3a..bb729a6d94 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -99,7 +99,11 @@ pub(super) async fn handle_initial_hello( } fn build_context(jwt: Option, required_fields: &Vec) -> UserAuthContext { - UserAuthContext::bearer_opt(jwt) + let mut ctx = UserAuthContext::empty(); + if required_fields.contains(&"authorization".into()) && jwt.is_some() { + ctx.add_field("authorization".into(), jwt.unwrap()); + } + ctx } pub(super) async fn handle_repeated_hello( diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 821a77005b..f5f81b822d 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -1,7 +1,7 @@ use axum::extract::FromRequestParts; use crate::{ - auth::{Auth, AuthError, Jwt, UserAuthContext}, + auth::{Auth, Jwt}, connection::RequestContext, }; diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 2b0a755a7b..aa21314301 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -483,12 +483,21 @@ fn build_context( headers: &hyper::HeaderMap, required_fields: &Vec, ) -> UserAuthContext { - headers + let mut ctx = headers .get(hyper::header::AUTHORIZATION) .ok_or(AuthError::AuthHeaderNotFound) .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) .and_then(|t| UserAuthContext::from_auth_str(t)) - .unwrap_or(UserAuthContext::empty()) + .unwrap_or(UserAuthContext::empty()); + + for field in required_fields.iter() { + headers + .get(field) + .map(|h| h.to_str().ok()) + .and_then(|t| t.map(|s| ctx.add_field(field.into(), s.into()))); + } + + ctx } impl FromRef for Auth { From 6aba6e5ab135dde5c95526c94630c1f89687955b Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 29 Mar 2024 10:47:14 -0700 Subject: [PATCH 41/63] removing optionality of UserAuthContext --- libsql-server/src/auth/mod.rs | 5 +- libsql-server/src/auth/parsers.rs | 74 +++++++------------ .../src/auth/user_auth_strategies/disabled.rs | 7 +- .../auth/user_auth_strategies/http_basic.rs | 13 +--- .../src/auth/user_auth_strategies/jwt.rs | 18 ++--- .../src/auth/user_auth_strategies/mod.rs | 5 +- libsql-server/src/hrana/ws/session.rs | 4 +- libsql-server/src/http/user/extract.rs | 2 +- libsql-server/src/http/user/mod.rs | 2 +- 9 files changed, 44 insertions(+), 86 deletions(-) diff --git a/libsql-server/src/auth/mod.rs b/libsql-server/src/auth/mod.rs index 09468f4b3a..26c6dfecc1 100644 --- a/libsql-server/src/auth/mod.rs +++ b/libsql-server/src/auth/mod.rs @@ -27,10 +27,7 @@ impl Auth { } } - pub fn authenticate( - &self, - context: Result, - ) -> Result { + pub fn authenticate(&self, context: UserAuthContext) -> Result { self.user_strategy.authenticate(context) } } diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index d8fa335b07..208a5e3864 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -39,21 +39,19 @@ pub fn parse_jwt_key(data: &str) -> Result { pub(crate) fn parse_grpc_auth_header( metadata: &MetadataMap, required_fields: &Vec, -) -> Result { - let mut context = metadata - .get(GRPC_AUTH_HEADER) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)); - - if let Ok(ref mut ctx) = context { - for field in required_fields.iter() { - metadata - .get(field) - .map(|header| header.to_str().ok()) - .and_then(|r| r) - .map(|v| ctx.add_field(field.into(), v.into())); - } +) -> UserAuthContext { + // let mut context = metadata + // .get(GRPC_AUTH_HEADER) + // .ok_or(AuthError::AuthHeaderNotFound) + // .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + // .and_then(|t| UserAuthContext::from_auth_str(t)); + let mut context = UserAuthContext::empty(); + for field in required_fields.iter() { + metadata + .get(field) + .map(|header| header.to_str().ok()) + .and_then(|r| r) + .map(|v| context.add_field(field.into(), v.into())); } context @@ -94,46 +92,26 @@ mod tests { #[test] fn parse_grpc_auth_header_returns_valid_context() { let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer 123".parse().unwrap()); + map.insert( + crate::auth::constants::GRPC_AUTH_HEADER, + "bearer 123".parse().unwrap(), + ); let required_fields = vec!["x-authorization".into()]; - let context = parse_grpc_auth_header(&map, &required_fields).unwrap(); + let context = parse_grpc_auth_header(&map, &required_fields); assert_eq!( context.custom_fields.get("x-authorization"), Some(&"bearer 123".to_string()) ); } - #[test] - fn parse_grpc_auth_header_error_no_header() { - let map = tonic::metadata::MetadataMap::new(); - let required_fields = Vec::new(); - let result = parse_grpc_auth_header(&map, &required_fields); - assert_eq!( - result.unwrap_err().to_string(), - "Expected authorization header but none given" - ); - } - - #[test] - fn parse_grpc_auth_header_error_non_ascii() { - let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); - let required_fields = Vec::new(); - let result = parse_grpc_auth_header(&map, &required_fields); - assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header") - } - - #[test] - fn parse_grpc_auth_header_error_malformed_auth_str() { - let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer123".parse().unwrap()); - let required_fields = Vec::new(); - let result = parse_grpc_auth_header(&map, &required_fields); - assert_eq!( - result.unwrap_err().to_string(), - "Auth string does not conform to ' ' form" - ) - } + // #[test] TODO rewrite + // fn parse_grpc_auth_header_error_non_ascii() { + // let mut map = tonic::metadata::MetadataMap::new(); + // map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); + // let required_fields = Vec::new(); + // let result = parse_grpc_auth_header(&map, &required_fields); + // assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header") + // } #[test] fn parse_http_auth_header_returns_auth_header_param_when_valid() { diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index b95d52c061..ef9aae9062 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -4,10 +4,7 @@ use crate::auth::{AuthError, Authenticated}; pub struct Disabled {} impl UserAuthStrategy for Disabled { - fn authenticate( - &self, - _context: Result, - ) -> Result { + fn authenticate(&self, _context: UserAuthContext) -> Result { tracing::trace!("executing disabled auth"); Ok(Authenticated::FullAccess) } @@ -26,7 +23,7 @@ mod tests { #[test] fn authenticates() { let strategy = Disabled::new(); - let context = Ok(UserAuthContext::empty()); + let context = UserAuthContext::empty(); assert!(matches!( strategy.authenticate(context).unwrap(), diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index 1e69283016..2310c7821b 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -7,13 +7,8 @@ pub struct HttpBasic { } impl UserAuthStrategy for HttpBasic { - fn authenticate( - &self, - context: Result, - ) -> Result { + fn authenticate(&self, ctx: UserAuthContext) -> Result { tracing::trace!("executing http basic auth"); - - let ctx = context?; let auth_str = None .or_else(|| ctx.custom_fields.get("authorization")) .or_else(|| ctx.custom_fields.get("x-authorization")); @@ -57,7 +52,7 @@ mod tests { #[test] fn authenticates_with_valid_credential() { - let context = Ok(UserAuthContext::basic(CREDENTIAL)); + let context = UserAuthContext::basic(CREDENTIAL); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -68,7 +63,7 @@ mod tests { #[test] fn authenticates_with_valid_trimmed_credential() { let credential = CREDENTIAL.trim_end_matches('='); - let context = Ok(UserAuthContext::basic(credential)); + let context = UserAuthContext::basic(credential); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -78,7 +73,7 @@ mod tests { #[test] fn errors_when_credentials_do_not_match() { - let context = Ok(UserAuthContext::basic("abc")); + let context = UserAuthContext::basic("abc"); assert_eq!( strategy().authenticate(context).unwrap_err(), diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index 6b73ea74f6..6fd504ba88 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -12,13 +12,8 @@ pub struct Jwt { } impl UserAuthStrategy for Jwt { - fn authenticate( - &self, - context: Result, - ) -> Result { + fn authenticate(&self, ctx: UserAuthContext) -> Result { tracing::trace!("executing jwt auth"); - - let ctx = context?; let auth_str = None .or_else(|| ctx.custom_fields.get("authorization")) .or_else(|| ctx.custom_fields.get("x-authorization")) @@ -159,7 +154,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert!(matches!( strategy(dec).authenticate(context).unwrap(), @@ -181,8 +176,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); - + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Legacy(a) = strategy(dec).authenticate(context).unwrap() else { panic!() }; @@ -194,7 +188,7 @@ mod tests { #[test] fn errors_when_jwt_token_invalid() { let (_enc, dec) = key_pair(); - let context = Ok(UserAuthContext::bearer("abc")); + let context = UserAuthContext::bearer("abc"); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -214,7 +208,7 @@ mod tests { let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -236,7 +230,7 @@ mod tests { let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { panic!() diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index e56b6e1e78..22dcfce224 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -56,8 +56,5 @@ pub trait UserAuthStrategy: Sync + Send { vec![] } - fn authenticate( - &self, - context: Result, - ) -> Result; + fn authenticate(&self, context: UserAuthContext) -> Result; } diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index bb729a6d94..7d6b687eb2 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -86,7 +86,7 @@ pub(super) async fn handle_initial_hello( build_context(jwt, &auth_strategy.user_strategy.required_fields()); let auth = auth_strategy - .authenticate(Ok(context)) + .authenticate(context) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { @@ -133,7 +133,7 @@ pub(super) async fn handle_repeated_hello( build_context(jwt, &auth_strategy.user_strategy.required_fields()); session.auth = auth_strategy - .authenticate(Ok(context)) + .authenticate(context) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index f5f81b822d..8ab1a489d5 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -34,7 +34,7 @@ impl FromRequestParts for RequestContext { let context = super::build_context(&parts.headers, &auth.user_strategy.required_fields()); Ok(Self::new( - auth.authenticate(Ok(context))?, + auth.authenticate(context)?, namespace, state.namespaces.meta_store().clone(), )) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index aa21314301..26b93e2de7 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -475,7 +475,7 @@ impl FromRequestParts for Authenticated { let context = build_context(&parts.headers, &auth.user_strategy.required_fields()); - Ok(auth.authenticate(Ok(context))?) + Ok(auth.authenticate(context)?) } } From 1f2de7ef4c43602eff082bb931228bae352d2e5e Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 29 Mar 2024 10:58:29 -0700 Subject: [PATCH 42/63] clean up dead code --- libsql-server/src/auth/parsers.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 208a5e3864..dbae78cf17 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -1,4 +1,4 @@ -use crate::auth::{constants::GRPC_AUTH_HEADER, AuthError}; +use crate::auth::AuthError; use anyhow::{bail, Context as _, Result}; use axum::http::HeaderValue; @@ -40,11 +40,6 @@ pub(crate) fn parse_grpc_auth_header( metadata: &MetadataMap, required_fields: &Vec, ) -> UserAuthContext { - // let mut context = metadata - // .get(GRPC_AUTH_HEADER) - // .ok_or(AuthError::AuthHeaderNotFound) - // .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - // .and_then(|t| UserAuthContext::from_auth_str(t)); let mut context = UserAuthContext::empty(); for field in required_fields.iter() { metadata From fc582ecf6ab41d1cf490c91b952a7b9f31f2de22 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 29 Mar 2024 11:32:26 -0700 Subject: [PATCH 43/63] wip --- libsql-server/src/auth/mod.rs | 10 ++++--- .../src/auth/user_auth_strategies/mod.rs | 4 ++- .../auth/user_auth_strategies/proxy_grpc.rs | 29 +++++++++++++++++++ libsql-server/src/hrana/ws/session.rs | 6 ++-- libsql-server/src/http/user/extract.rs | 2 +- libsql-server/src/http/user/mod.rs | 2 +- libsql-server/src/rpc/proxy.rs | 15 ++++------ libsql-server/src/rpc/replica_proxy.rs | 3 +- libsql-server/src/rpc/replication_log.rs | 3 +- 9 files changed, 49 insertions(+), 25 deletions(-) create mode 100644 libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs diff --git a/libsql-server/src/auth/mod.rs b/libsql-server/src/auth/mod.rs index 26c6dfecc1..052aa0f7e3 100644 --- a/libsql-server/src/auth/mod.rs +++ b/libsql-server/src/auth/mod.rs @@ -13,21 +13,23 @@ pub use authorized::Authorized; pub use errors::AuthError; pub use parsers::{parse_http_auth_header, parse_http_basic_auth_arg, parse_jwt_key}; pub use permission::Permission; -pub use user_auth_strategies::{Disabled, HttpBasic, Jwt, UserAuthContext, UserAuthStrategy}; +pub use user_auth_strategies::{ + Disabled, HttpBasic, Jwt, ProxyGrpc, UserAuthContext, UserAuthStrategy, +}; #[derive(Clone)] pub struct Auth { - pub user_strategy: Arc, + pub strategy: Arc, } impl Auth { pub fn new(user_strategy: impl UserAuthStrategy + Send + Sync + 'static) -> Self { Self { - user_strategy: Arc::new(user_strategy), + strategy: Arc::new(user_strategy), } } pub fn authenticate(&self, context: UserAuthContext) -> Result { - self.user_strategy.authenticate(context) + self.strategy.authenticate(context) } } diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 22dcfce224..119e57ef7e 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -1,17 +1,19 @@ pub mod disabled; pub mod http_basic; pub mod jwt; +pub mod proxy_grpc; pub use disabled::Disabled; use hashbrown::HashMap; pub use http_basic::HttpBasic; pub use jwt::Jwt; +pub use proxy_grpc::ProxyGrpc; use super::{AuthError, Authenticated}; #[derive(Debug)] pub struct UserAuthContext { - pub custom_fields: HashMap, String>, // todo, add aliases + pub custom_fields: HashMap, String>, } impl UserAuthContext { diff --git a/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs b/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs new file mode 100644 index 0000000000..4965fd8bf6 --- /dev/null +++ b/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs @@ -0,0 +1,29 @@ +use crate::auth::{AuthError, Authenticated}; + +use super::{UserAuthContext, UserAuthStrategy}; + +pub struct ProxyGrpc {} + +impl UserAuthStrategy for ProxyGrpc { + fn authenticate(&self, ctx: UserAuthContext) -> Result { + tracing::trace!("executing proxy grpc auth"); + let auth_str = None + .or_else(|| ctx.custom_fields.get("proxy-authorization")) + .or_else(|| ctx.custom_fields.get("x-proxy-authorization")); + } + + fn required_fields(&self) -> Vec { + vec![ + "authorization".to_string(), + "x-proxy-authorization".to_string(), + ] + } +} + +impl ProxyGrpc { + pub fn new() -> Self { + Self {} + } +} + +// todo tests diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 7d6b687eb2..fce0903035 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -82,8 +82,7 @@ pub(super) async fn handle_initial_hello( .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()); - let context: UserAuthContext = - build_context(jwt, &auth_strategy.user_strategy.required_fields()); + let context: UserAuthContext = build_context(jwt, &auth_strategy.strategy.required_fields()); let auth = auth_strategy .authenticate(context) @@ -129,8 +128,7 @@ pub(super) async fn handle_repeated_hello( .map(Auth::new) .unwrap_or(server.user_auth_strategy.clone()); - let context: UserAuthContext = - build_context(jwt, &auth_strategy.user_strategy.required_fields()); + let context: UserAuthContext = build_context(jwt, &auth_strategy.strategy.required_fields()); session.auth = auth_strategy .authenticate(context) diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index 8ab1a489d5..33014efeb0 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -31,7 +31,7 @@ impl FromRequestParts for RequestContext { .map(Auth::new) .unwrap_or_else(|| state.user_auth_strategy.clone()); - let context = super::build_context(&parts.headers, &auth.user_strategy.required_fields()); + let context = super::build_context(&parts.headers, &auth.strategy.required_fields()); Ok(Self::new( auth.authenticate(context)?, diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 26b93e2de7..d5a2c3a9d4 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -473,7 +473,7 @@ impl FromRequestParts for Authenticated { .map(Auth::new) .unwrap_or_else(|| state.user_auth_strategy.clone()); - let context = build_context(&parts.headers, &auth.user_strategy.required_fields()); + let context = build_context(&parts.headers, &auth.strategy.required_fields()); Ok(auth.authenticate(context)?) } diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index aba85c64eb..a82d7e89e8 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -19,7 +19,7 @@ use tokio::time::Duration; use uuid::Uuid; use crate::auth::parsers::parse_grpc_auth_header; -use crate::auth::{Auth, Authenticated, Jwt}; +use crate::auth::{Auth, Jwt, ProxyGrpc}; use crate::connection::{Connection as _, RequestContext}; use crate::database::Connection; use crate::namespace::NamespaceStore; @@ -332,18 +332,13 @@ impl ProxyService { "Error fetching jwt key for a namespace: {}", e )))?, - }; + } + .or_else(ProxyGrpc::new()); - let auth = if let Some(auth) = auth { - let context = - parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); - auth.authenticate(context)? - } else { - Authenticated::from_proxy_grpc_request(req)? - }; + let context = parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); Ok(RequestContext::new( - auth, + auth.authenticate(context)?, namespace, self.namespaces.meta_store().clone(), )) diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 47732daba7..c55961289e 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -56,8 +56,7 @@ impl ReplicaProxyService { ))), }?; - let auth_context = - parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); + let auth_context = parse_grpc_auth_header(req.metadata(), &auth.strategy.required_fields()); auth.authenticate(auth_context)?.upgrade_grpc_request(req); return Ok(()); } diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 4dc72190bd..6126fa1e69 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -94,8 +94,7 @@ impl ReplicationLogService { }; if let Some(auth) = auth { - let context = - parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); + let context = parse_grpc_auth_header(req.metadata(), &auth.strategy.required_fields()); auth.authenticate(context)?; } From 4f49949708ba8452614f4d0eebac856a30f5ef20 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 29 Mar 2024 12:36:01 -0700 Subject: [PATCH 44/63] turned handling proxy grpc into auth strategy --- libsql-server/src/auth/errors.rs | 6 ++++++ libsql-server/src/auth/mod.rs | 4 ++-- libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs | 6 +++++- libsql-server/src/rpc/proxy.rs | 4 ++-- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/libsql-server/src/auth/errors.rs b/libsql-server/src/auth/errors.rs index 5275153315..f1267c0c44 100644 --- a/libsql-server/src/auth/errors.rs +++ b/libsql-server/src/auth/errors.rs @@ -26,6 +26,10 @@ pub enum AuthError { AuthStringMalformed, #[error("Expected authorization header but none given")] AuthHeaderNotFound, + #[error("Expected authorization proxy header but none given")] + AuthProxyHeaderNotFound, + #[error("Failed to parse auth proxy header")] + AuthProxyHeaderInvalid, #[error("Non-ASCII auth header")] AuthHeaderNonAscii, #[error("Authentication failed")] @@ -47,6 +51,8 @@ impl AuthError { Self::JwtImmature => "AUTH_JWT_IMMATURE", Self::AuthStringMalformed => "AUTH_HEADER_MALFORMED", Self::AuthHeaderNotFound => "AUTH_HEADER_NOT_FOUND", + Self::AuthProxyHeaderNotFound => "AUTH_PROXY_HEADER_NOT_FOUND", + Self::AuthProxyHeaderInvalid => "AUTH_PROXY_HEADER_INVALID", Self::AuthHeaderNonAscii => "AUTH_HEADER_MALFORMED", Self::Other => "AUTH_FAILED", } diff --git a/libsql-server/src/auth/mod.rs b/libsql-server/src/auth/mod.rs index 052aa0f7e3..365044e3ce 100644 --- a/libsql-server/src/auth/mod.rs +++ b/libsql-server/src/auth/mod.rs @@ -23,9 +23,9 @@ pub struct Auth { } impl Auth { - pub fn new(user_strategy: impl UserAuthStrategy + Send + Sync + 'static) -> Self { + pub fn new(strategy: impl UserAuthStrategy + Send + Sync + 'static) -> Self { Self { - strategy: Arc::new(user_strategy), + strategy: Arc::new(strategy), } } diff --git a/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs b/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs index 4965fd8bf6..c6c8e39151 100644 --- a/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs +++ b/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs @@ -9,7 +9,11 @@ impl UserAuthStrategy for ProxyGrpc { tracing::trace!("executing proxy grpc auth"); let auth_str = None .or_else(|| ctx.custom_fields.get("proxy-authorization")) - .or_else(|| ctx.custom_fields.get("x-proxy-authorization")); + .or_else(|| ctx.custom_fields.get("x-proxy-authorization")) + .ok_or_else(|| AuthError::AuthProxyHeaderNotFound)?; + + serde_json::from_str::(&auth_str) + .map_err(|_| AuthError::AuthProxyHeaderInvalid) } fn required_fields(&self) -> Vec { diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index a82d7e89e8..c054c475f9 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -333,9 +333,9 @@ impl ProxyService { e )))?, } - .or_else(ProxyGrpc::new()); + .unwrap_or_else(|| Auth::new(ProxyGrpc::new())); - let context = parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()); + let context = parse_grpc_auth_header(req.metadata(), &auth.strategy.required_fields()); Ok(RequestContext::new( auth.authenticate(context)?, From ba204fde1b8d1864c4460c49bd87a4e659240572 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 22 Mar 2024 13:57:10 -0700 Subject: [PATCH 45/63] bottomless: emit restored snapshot for waiters (#1252) This fixes an issue where a db gets restored from bottomless and doesn't get any writes until shutdown. At this point, the current generation is the same as the restored one but the graceful shutdown process expects to wait for that generation to be uploaded which never happens because there are no writes. This change adds a snapshot generation emit call at restore time to allow graceful shutdown to happen when there are no writes without having to checkpoint and upload a new snapshot. --- bottomless/src/replicator.rs | 5 +++++ libsql-server/local-test-envs | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 libsql-server/local-test-envs diff --git a/bottomless/src/replicator.rs b/bottomless/src/replicator.rs index 1a02a4690a..bf4cfa1764 100644 --- a/bottomless/src/replicator.rs +++ b/bottomless/src/replicator.rs @@ -487,6 +487,7 @@ impl Replicator { Err(_) => true, }) .await?; + tracing::debug!("done waiting"); match res.deref() { Ok(_) => Ok(true), Err(e) => Err(anyhow!("Failed snapshot generation {}: {}", generation, e)), @@ -923,6 +924,7 @@ impl Replicator { let _ = self.snapshot_notifier.send(Ok(self.generation().ok())); return Ok(None); } + tracing::debug!("snapshotting db file"); if !self.main_db_exists_and_not_empty().await { let generation = self.generation()?; tracing::debug!( @@ -1561,6 +1563,9 @@ impl Replicator { }; let (action, recovered) = self.restore_from(generation, timestamp).await?; + + let _ = self.snapshot_notifier.send(Ok(Some(generation))); + tracing::info!( "Restoring from generation {generation}: action={action:?}, recovered={recovered}" ); diff --git a/libsql-server/local-test-envs b/libsql-server/local-test-envs new file mode 100644 index 0000000000..2cb07ade19 --- /dev/null +++ b/libsql-server/local-test-envs @@ -0,0 +1,18 @@ +#!/bin/bash + +export LIBSQL_BOTTOMLESS_AWS_ACCESS_KEY_ID=minioadmin +export LIBSQL_BOTTOMLESS_AWS_DEFAULT_REGION=us-east-1 +export LIBSQL_BOTTOMLESS_AWS_SECRET_ACCESS_KEY=minioadmin +export LIBSQL_BOTTOMLESS_BUCKET=turso-dev +export LIBSQL_BOTTOMLESS_DATABASE_ID=5d64e223-21a3-4835-9815-9613216d9859 +export LIBSQL_BOTTOMLESS_ENDPOINT=http://localhost:9000 +export LIBSQL_BOTTOMLESS_VERIFY_CRC=false +export SQLD_BACKUP_META_STORE=true +export SQLD_ENABLE_BOTTOMLESS_REPLICATION=true +export SQLD_META_STORE_ACCESS_KEY_ID=minioadmin +export SQLD_META_STORE_BACKUP_ID=metastore-dev +export SQLD_META_STORE_BACKUP_INTERVAL_S=1 +export SQLD_META_STORE_BUCKET_ENDPOINT=http://localhost:9000 +export SQLD_META_STORE_BUCKET_NAME=turso-dev +export SQLD_META_STORE_REGION=us-east-1 +export SQLD_META_STORE_SECRET_ACCESS=minioadmin From 7226bf9a24ce58f548d996dd1b80c9ea4aa9c39c Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 23 Mar 2024 00:05:07 +0100 Subject: [PATCH 46/63] fix conn upgrade lock (#1244) fix deadlock on read transaction upgrade --- .../src/connection/connection_manager.rs | 287 +++++++++++++----- 1 file changed, 204 insertions(+), 83 deletions(-) diff --git a/libsql-server/src/connection/connection_manager.rs b/libsql-server/src/connection/connection_manager.rs index d5e5537f58..f1669ba26e 100644 --- a/libsql-server/src/connection/connection_manager.rs +++ b/libsql-server/src/connection/connection_manager.rs @@ -9,7 +9,7 @@ use hashbrown::HashMap; use libsql_sys::wal::wrapper::{WrapWal, WrappedWal}; use libsql_sys::wal::{CheckpointMode, Sqlite3Wal, Wal}; use metrics::atomics::AtomicU64; -use parking_lot::Mutex; +use parking_lot::{Mutex, MutexGuard}; use rusqlite::ErrorCode; use super::libsql::Connection; @@ -19,6 +19,13 @@ pub type ConnId = u64; pub type ManagedConnectionWal = WrappedWal; +#[derive(Copy, Clone, Debug)] +struct Slot { + id: ConnId, + started_at: Instant, + state: SlotState, +} + #[derive(Clone)] struct Abort(Arc); @@ -74,13 +81,15 @@ pub struct ConnectionManagerInner { /// When a slot becomes available, the connection allowed to make progress is put here /// the connection currently holding the lock /// bool: acquired - current: Mutex>, + current: Mutex>, /// map of registered connections abort_handle: Mutex>, /// threads waiting to acquire the lock /// todo: limit how many can be push write_queue: crossbeam::deque::Injector<(ConnId, Unparker)>, txn_timeout_duration: Duration, + /// the time we are given to acquire a transaction after we were given a slot + acquire_timeout_duration: Duration, next_conn_id: AtomicU64, sync_token: AtomicU64, } @@ -92,6 +101,7 @@ impl Default for ConnectionManagerInner { abort_handle: Default::default(), write_queue: Default::default(), txn_timeout_duration: TXN_TIMEOUT, + acquire_timeout_duration: Duration::from_millis(15), next_conn_id: Default::default(), sync_token: AtomicU64::new(0), } @@ -133,7 +143,7 @@ impl ManagedConnectionWalWrapper { extended_code: 517, // stale read }); } - if current.map_or(true, |(id, _, _)| id != self.id) && !enqueued { + if current.as_mut().map_or(true, |slot| slot.id != self.id) && !enqueued { self.manager .write_queue .push((self.id, parker.unparker().clone())); @@ -141,11 +151,16 @@ impl ManagedConnectionWalWrapper { tracing::debug!("enqueued"); } match *current { - Some((id, started_at, acquired)) => { + Some(ref mut slot) => { + tracing::debug!("current slot: {slot:?}"); // this is us, the previous connection put us here when it closed the // transaction - if id == self.id { - assert!(!acquired); + if slot.id == self.id { + assert!( + slot.state.is_notified() || slot.state.is_failure(), + "{slot:?}" + ); + slot.state = SlotState::Acquiring; tracing::debug!( line = line!(), "got lock after: {:?}", @@ -154,81 +169,124 @@ impl ManagedConnectionWalWrapper { break; } else { // not us, maybe we need to steal the lock? - drop(current); - if started_at.elapsed() >= self.manager.inner.txn_timeout_duration { - let handle = { - self.manager - .inner - .abort_handle - .lock() - .get(&id) - .unwrap() - .clone() - }; - // the guard must be dropped before rolling back, or end write txn will - // deadlock - tracing::debug!("forcing rollback of {id}"); - handle.abort(); - parker.park(); - tracing::debug!(line = line!(), "unparked"); - } else { - // otherwise we wait for the txn to timeout, or to be unparked by it - let before = Instant::now(); - let deadline = started_at + self.manager.inner.txn_timeout_duration; - parker.park_deadline( - started_at + self.manager.inner.txn_timeout_duration, - ); - tracing::debug!( - line = line!(), - "unparked after: {:?}, before_deadline: {:?}", - before.elapsed(), - Instant::now() < deadline - ); + let since_started = slot.started_at.elapsed(); + let deadline = slot.started_at + self.manager.txn_timeout_duration; + match slot.state { + SlotState::Acquired => { + if since_started >= self.manager.txn_timeout_duration { + let id = slot.id; + drop(current); + let handle = { + self.manager + .inner + .abort_handle + .lock() + .get(&id) + .unwrap() + .clone() + }; + // the guard must be dropped before rolling back, or end write txn will + // deadlock + tracing::debug!("forcing rollback of {id}"); + handle.abort(); + tracing::debug!(line = line!(), "parking"); + parker.park(); + tracing::debug!(line = line!(), "unparked"); + } else { + // otherwise we wait for the txn to timeout, or to be unparked by it + let deadline = + slot.started_at + self.manager.inner.txn_timeout_duration; + drop(current); + tracing::debug!(line = line!(), "parking"); + parker.park_deadline(deadline); + tracing::debug!( + line = line!(), + "before_deadline?: {:?}", + Instant::now() < deadline + ); + } + } + // we may want to limit how long a lock takes to go from notified + // to acquiring + SlotState::Acquiring | SlotState::Notified => { + drop(current); + tracing::debug!(line = line!(), "parking"); + parker.park_deadline(deadline); + tracing::debug!( + line = line!(), + "unparked after before_deadline?: {:?}", + Instant::now() < deadline + ); + } + SlotState::Failure => { + if since_started >= self.manager.inner.acquire_timeout_duration { + // the connection failed to acquire a transaction during the grace + // period. schedule the next transaction + match self.schedule_next(&mut current) { + Some(id) if id == self.id => { + current.as_mut().unwrap().state = SlotState::Acquiring; + break; + } + Some(_) => { + drop(current); + tracing::debug!(line = line!(), "parking"); + parker.park(); + tracing::debug!(line = line!(), "unparked"); + } + None => { + *current = Some(Slot { + id: self.id, + started_at: Instant::now(), + state: SlotState::Acquiring, + }); + break; + } + } + } else { + tracing::trace!("noticed failure from id={}, parking until end of grace period", slot.id); + let deadline = slot.started_at + + self.manager.inner.acquire_timeout_duration; + drop(current); + tracing::debug!(line = line!(), "parking"); + parker.park_deadline(deadline); + tracing::debug!( + line = line!(), + "unparked after before_deadline?: {:?}", + Instant::now() < deadline + ); + } + } } } } - None => { - let next = loop { - match self.manager.write_queue.steal() { - Steal::Empty => break None, - Steal::Success(item) => break Some(item), - Steal::Retry => (), - } - }; - - match next { - Some((id, _)) if id == self.id => { - // this is us! - *current = Some((self.id, Instant::now(), false)); - tracing::debug!("got lock after: {:?}", enqueued_at.elapsed()); - break; - } - Some((id, unpaker)) => { - tracing::debug!(line = line!(), "unparking id={id}"); - *current = Some((id, Instant::now(), false)); - drop(current); - unpaker.unpark(); - parker.park(); - } - None => unreachable!(), + None => match self.schedule_next(&mut current) { + Some(id) if id == self.id => { + current.as_mut().unwrap().state = SlotState::Acquiring; + break; } - } + Some(_) => { + drop(current); + tracing::debug!(line = line!(), "parking"); + parker.park(); + tracing::debug!(line = line!(), "unparked"); + } + None => { + *current = Some(Slot { + id: self.id, + started_at: Instant::now(), + state: SlotState::Acquiring, + }) + } + }, } } Ok(()) } - #[tracing::instrument(skip(self))] - fn release(&self) { - let mut current = self.manager.current.lock(); - let Some((id, started_at, _)) = current.take() else { - unreachable!("no lock to release") - }; - - assert_eq!(id, self.id); - - tracing::debug!("transaction finished after {:?}", started_at.elapsed()); + #[tracing::instrument(skip(self, current))] + #[track_caller] + fn schedule_next(&self, current: &mut MutexGuard>) -> Option { let next = loop { match self.manager.write_queue.steal() { Steal::Empty => break None, @@ -237,16 +295,67 @@ impl ManagedConnectionWalWrapper { } }; - if let Some((id, unparker)) = next { - tracing::debug!(line = line!(), "unparking id={id}"); - *current = Some((id, Instant::now(), false)); - unparker.unpark() - } else { - *current = None; + match next { + Some((id, unpaker)) => { + tracing::debug!(line = line!(), "unparking id={id}"); + **current = Some(Slot { + id, + started_at: Instant::now(), + state: SlotState::Notified, + }); + unpaker.unpark(); + Some(id) + } + None => None, + } + } + + #[tracing::instrument(skip(self))] + #[track_caller] + fn release(&self) { + let mut current = self.manager.current.lock(); + let Some(slot) = current.take() else { + unreachable!("no lock to release") + }; + + assert_eq!(slot.id, self.id); + + tracing::debug!("transaction finished after {:?}", slot.started_at.elapsed()); + match self.schedule_next(&mut current) { + Some(_) => (), + None => { + *current = None; + } } } } +#[derive(Copy, Clone, Debug)] +enum SlotState { + Notified, + Acquiring, + Acquired, + Failure, +} + +impl SlotState { + /// Returns `true` if the slot state is [`Notified`]. + /// + /// [`Notified`]: SlotState::Notified + #[must_use] + fn is_notified(&self) -> bool { + matches!(self, Self::Notified) + } + + /// Returns `true` if the slot state is [`Failure`]. + /// + /// [`Failure`]: SlotState::Failure + #[must_use] + fn is_failure(&self) -> bool { + matches!(self, Self::Failure) + } +} + impl WrapWal for ManagedConnectionWalWrapper { #[tracing::instrument(skip_all, fields(id = self.id))] fn begin_write_txn(&mut self, wrapped: &mut Sqlite3Wal) -> libsql_sys::wal::Result<()> { @@ -256,7 +365,7 @@ impl WrapWal for ManagedConnectionWalWrapper { Ok(_) => { tracing::debug!("transaction acquired"); let mut lock = self.manager.current.lock(); - lock.as_mut().unwrap().2 = true; + lock.as_mut().unwrap().state = SlotState::Acquired; Ok(()) } @@ -266,6 +375,8 @@ impl WrapWal for ManagedConnectionWalWrapper { tracing::debug!("error acquiring lock, releasing: {e}"); self.release(); } else { + let mut lock = self.manager.current.lock(); + lock.as_mut().unwrap().state = SlotState::Failure; tracing::debug!("error acquiring lock: {e}"); } Err(e) @@ -289,6 +400,7 @@ impl WrapWal for ManagedConnectionWalWrapper { ) -> libsql_sys::wal::Result<()> { let before = Instant::now(); self.acquire()?; + self.manager.current.lock().as_mut().unwrap().state = SlotState::Acquired; let mode = if rand::random::() < 0.1 { CheckpointMode::Truncate @@ -336,7 +448,14 @@ impl WrapWal for ManagedConnectionWalWrapper { wrapped.end_read_txn(); { let current = self.manager.current.lock(); - if let Some((id, _, true)) = *current { + // end read will only close the write txn if we actually acquired one, so only release + // if the slot acquire the transaction lock + if let Some(Slot { + id, + state: SlotState::Acquired, + .. + }) = *current + { // releasing read transaction releases the write lock (see wal.c) if id == self.id { drop(current); @@ -368,11 +487,13 @@ impl WrapWal for ManagedConnectionWalWrapper { let before = Instant::now(); let ret = manager.close(wrapped, db, sync_flags, None); { - tracing::debug!(line = line!(), "unparked"); let current = self.manager.current.lock(); - if let Some((id, _, _)) = *current { + if let Some(slot @ Slot { id, .. }) = *current { if id == self.id { - tracing::debug!("connection closed without releasing lock"); + tracing::debug!( + id = self.id, + "connection closed without releasing lock: {slot:?}" + ); drop(current); self.release() } @@ -380,7 +501,7 @@ impl WrapWal for ManagedConnectionWalWrapper { } self.manager.inner.abort_handle.lock().remove(&self.id); - tracing::debug!("closed in {:?}", before.elapsed()); + tracing::debug!(id = self.id, "closed in {:?}", before.elapsed()); ret } } From 14b05cd9ef34b62125c0ef42f55c91e62c6b461e Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 23 Mar 2024 19:26:46 +0100 Subject: [PATCH 47/63] prevent primary to remove itself in case of load error (#1253) --- libsql-server/src/namespace/mod.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 082bfaca1f..263db1b156 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -248,6 +248,8 @@ impl Namespace { restore_option: RestoreOption, resolve_attach_path: ResolveNamespacePathFn, ) -> crate::Result { + let db_path: Arc = config.base_path.join("dbs").join(name.as_str()).into(); + let fresh_namespace = !db_path.try_exists()?; // FIXME: make that truly atomic. explore the idea of using temp directories, and it's implications match Self::try_new_primary( config, @@ -255,17 +257,19 @@ impl Namespace { meta_store_handle, restore_option, resolve_attach_path, + db_path.clone(), ) .await { - Ok(ns) => Ok(ns), - Err(e) => { - let path = config.base_path.join("dbs").join(name.as_str()); - if let Err(e) = tokio::fs::remove_dir_all(path).await { - tracing::error!("failed to clean dirty namespace: {e}"); + Ok(this) => Ok(this), + Err(e) if fresh_namespace => { + tracing::error!("an error occured while deleting creating namespace, cleaning..."); + if let Err(e) = tokio::fs::remove_dir_all(&db_path).await { + tracing::error!("failed to remove dirty namespace directory: {e}") } Err(e) } + Err(e) => Err(e), } } @@ -396,9 +400,9 @@ impl Namespace { meta_store_handle: MetaStoreHandle, restore_option: RestoreOption, resolve_attach_path: ResolveNamespacePathFn, + db_path: Arc, ) -> crate::Result { let mut join_set = JoinSet::new(); - let db_path = ns_config.base_path.join("dbs").join(name.as_str()); tokio::fs::create_dir_all(&db_path).await?; From c2ca36cf2bf083e5208dd91749aead45d8d42636 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Mon, 25 Mar 2024 09:40:11 -0700 Subject: [PATCH 48/63] server: fix interactive txn schema panic (#1250) * server: fix interactive txn schema panic * fix txn check * fix check errors * Apply suggestions from code review --------- Co-authored-by: ad hoc --- libsql-server/src/database/schema.rs | 26 +++++++++++++++++++++----- libsql-server/src/schema/error.rs | 4 ++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/libsql-server/src/database/schema.rs b/libsql-server/src/database/schema.rs index fca46e8bb3..5493167e40 100644 --- a/libsql-server/src/database/schema.rs +++ b/libsql-server/src/database/schema.rs @@ -37,9 +37,19 @@ impl crate::connection::Connection for SchemaConnection { replication_index: Option, ) -> crate::Result { if migration.is_read_only() { - self.connection + let res = self + .connection .execute_program(migration, ctx, builder, replication_index) - .await + .await; + + // If the query was okay, verify if the connection is not in a txn state + if res.is_ok() && self.connection.is_autocommit().await? { + return Err(crate::Error::Migration( + crate::schema::Error::ConnectionInTxnState, + )); + } + + res } else { check_program_auth(&ctx, &migration, &self.config.get())?; let connection = self.connection.clone(); @@ -48,10 +58,14 @@ impl crate::connection::Connection for SchemaConnection { let builder = tokio::task::spawn_blocking({ let migration = migration.clone(); move || { - connection.with_raw(|conn| -> crate::Result<_> { + let res = connection.with_raw(|conn| -> crate::Result<_> { let mut txn = conn .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate) - .unwrap(); + .map_err(|_| { + crate::Error::Migration( + crate::schema::Error::InteractiveTxnNotAllowed, + ) + })?; // TODO: pass proper config let (ret, _) = perform_migration( &mut txn, @@ -62,7 +76,9 @@ impl crate::connection::Connection for SchemaConnection { ); txn.rollback().unwrap(); Ok(ret?) - }) + }); + + res } }) .await diff --git a/libsql-server/src/schema/error.rs b/libsql-server/src/schema/error.rs index dff6dc148a..13f21f3c15 100644 --- a/libsql-server/src/schema/error.rs +++ b/libsql-server/src/schema/error.rs @@ -41,6 +41,10 @@ pub enum Error { MigrationFailure(String), #[error("Error executing migration: {0}")] MigrationExecuteError(Box), + #[error("Interactive transactions are not allowed against a schema")] + InteractiveTxnNotAllowed, + #[error("Connection left in transaction state")] + ConnectionInTxnState, } impl ResponseError for Error {} From ade4fb7883fe0c291dd3388fc8e524482b91a841 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Mon, 25 Mar 2024 18:45:26 +0200 Subject: [PATCH 49/63] libsql-ffi: Fix sqlite3mc build output directory (#1234) Don't build inside the source tree because `cargo clean` won't clean up after it... --- libsql-ffi/build.rs | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/libsql-ffi/build.rs b/libsql-ffi/build.rs index 04a45e6e54..7e2e263faa 100644 --- a/libsql-ffi/build.rs +++ b/libsql-ffi/build.rs @@ -23,9 +23,7 @@ fn main() { println!("cargo:rerun-if-changed={BUNDLED_DIR}/src/sqlite3.c"); if cfg!(feature = "multiple-ciphers") { - println!( - "cargo:rerun-if-changed={BUNDLED_DIR}/SQLite3MultipleCiphers/build/libsqlite3mc_static.a" - ); + println!("cargo:rerun-if-changed={out_dir}/sqlite3mc/libsqlite3mc_static.a"); } if std::env::var("LIBSQL_DEV").is_ok() { @@ -255,7 +253,7 @@ pub fn build_bundled(out_dir: &str, out_path: &Path) { } fn copy_multiple_ciphers(out_dir: &str, out_path: &Path) { - let dylib = format!("{BUNDLED_DIR}/SQLite3MultipleCiphers/build/libsqlite3mc_static.a"); + let dylib = format!("{out_dir}/sqlite3mc/libsqlite3mc_static.a"); if !Path::new(&dylib).exists() { build_multiple_ciphers(out_path); } @@ -289,11 +287,14 @@ fn build_multiple_ciphers(out_path: &Path) { ) .unwrap(); - let bundled_dir = fs::canonicalize(BUNDLED_DIR).unwrap(); + let bundled_dir = fs::canonicalize(BUNDLED_DIR) + .unwrap() + .join("SQLite3MultipleCiphers"); - let build_dir = bundled_dir.join("SQLite3MultipleCiphers").join("build"); - let _ = fs::remove_dir_all(build_dir.clone()); - fs::create_dir_all(build_dir.clone()).unwrap(); + let out_dir = env::var("OUT_DIR").unwrap(); + let sqlite3mc_build_dir = fs::canonicalize(out_dir.clone()).unwrap().join("sqlite3mc"); + let _ = fs::remove_dir_all(sqlite3mc_build_dir.clone()); + fs::create_dir_all(sqlite3mc_build_dir.clone()).unwrap(); let mut cmake_opts: Vec<&str> = vec![]; @@ -304,7 +305,7 @@ fn build_multiple_ciphers(out_path: &Path) { let cross_cxx_var_name = format!("CXX_{}", cargo_build_target.replace("-", "_")); let cross_cxx = env::var(&cross_cxx_var_name).ok(); - let toolchain_path = build_dir.join("toolchain.cmake"); + let toolchain_path = sqlite3mc_build_dir.join("toolchain.cmake"); let cmake_toolchain_opt = format!("-DCMAKE_TOOLCHAIN_FILE=toolchain.cmake"); let mut toolchain_file = OpenOptions::new() @@ -345,9 +346,9 @@ fn build_multiple_ciphers(out_path: &Path) { } let mut cmake = Command::new("cmake"); - cmake.current_dir("bundled/SQLite3MultipleCiphers/build"); + cmake.current_dir(sqlite3mc_build_dir.clone()); cmake.args(cmake_opts.clone()); - cmake.arg(".."); + cmake.arg(bundled_dir.clone()); if cfg!(feature = "wasmtime-bindings") { cmake.arg("-DLIBSQL_ENABLE_WASM_RUNTIME=1"); } @@ -362,17 +363,17 @@ fn build_multiple_ciphers(out_path: &Path) { } let mut make = Command::new("cmake"); - make.current_dir("bundled/SQLite3MultipleCiphers/build"); + make.current_dir(sqlite3mc_build_dir.clone()); make.args(&["--build", "."]); make.args(&["--config", "Release"]); if !make.status().unwrap().success() { panic!("Failed to run make"); } // The `msbuild` tool puts the output in a different place so let's move it. - if Path::exists(&build_dir.join("Release/sqlite3mc_static.lib")) { + if Path::exists(&sqlite3mc_build_dir.join("Release/sqlite3mc_static.lib")) { fs::rename( - build_dir.join("Release/sqlite3mc_static.lib"), - build_dir.join("libsqlite3mc_static.a"), + sqlite3mc_build_dir.join("Release/sqlite3mc_static.lib"), + sqlite3mc_build_dir.join("libsqlite3mc_static.a"), ) .unwrap(); } From 7e029413b5cdcd32ac1d05bc295afeeda571b00b Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Mon, 25 Mar 2024 11:00:46 -0700 Subject: [PATCH 50/63] server: add shutdown timeout (#1258) * server: add shutdown timeout * add config cli/env var for timeout * wire timeout --- libsql-server/src/lib.rs | 29 ++++++++++++++++++++++++---- libsql-server/src/main.rs | 8 ++++++++ libsql-server/src/test/bottomless.rs | 4 +--- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index f3569f7842..3eabaaf2b7 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -104,6 +104,7 @@ pub struct Server Default for Server { @@ -124,6 +125,7 @@ impl Default for Server { max_active_namespaces: 100, meta_store_config: Default::default(), max_concurrent_connections: 128, + shutdown_timeout: Duration::from_secs(30), } } } @@ -503,6 +505,7 @@ where )); } + let shutdown_timeout = self.shutdown_timeout.clone(); let shutdown = self.shutdown.clone(); // setup user-facing rpc services match db_kind { @@ -567,10 +570,28 @@ where tokio::select! { _ = shutdown.notified() => { - join_set.shutdown().await; - service_shutdown.notify_waiters(); - namespace_store.shutdown().await?; - tracing::info!("sqld was shutdown gracefully. Bye!"); + let shutdown = async { + join_set.shutdown().await; + service_shutdown.notify_waiters(); + namespace_store.shutdown().await?; + + Ok::<_, crate::Error>(()) + }; + + match tokio::time::timeout(shutdown_timeout, shutdown).await { + Ok(Ok(())) => { + tracing::info!("sqld was shutdown gracefully. Bye!"); + } + Ok(Err(e)) => { + tracing::error!("failed to shutdown gracefully: {}", e); + std::process::exit(1); + }, + Err(_) => { + tracing::error!("shutdown timeout hit, forcefully shutting down"); + std::process::exit(1); + }, + + } } Some(res) = join_set.join_next() => { res??; diff --git a/libsql-server/src/main.rs b/libsql-server/src/main.rs index 39e5b5b5b5..e8260e7d17 100644 --- a/libsql-server/src/main.rs +++ b/libsql-server/src/main.rs @@ -242,6 +242,10 @@ struct Cli { /// empty on startup #[clap(long, env = "SQLD_ALLOW_METASTORE_RECOVERY")] allow_metastore_recovery: bool, + + /// Shutdown timeout duration in seconds, defaults to 30 seconds. + #[clap(long, env = "SQLD_SHUTDOWN_TIMEOUT")] + shutdown_timeout: Option, } #[derive(clap::Subcommand, Debug)] @@ -647,6 +651,10 @@ async fn build_server(config: &Cli) -> anyhow::Result { max_active_namespaces: config.max_active_namespaces, meta_store_config, max_concurrent_connections: config.max_concurrent_connections, + shutdown_timeout: config + .shutdown_timeout + .map(Duration::from_secs) + .unwrap_or(Duration::from_secs(30)), }) } diff --git a/libsql-server/src/test/bottomless.rs b/libsql-server/src/test/bottomless.rs index 5d28b37559..f8562a991c 100644 --- a/libsql-server/src/test/bottomless.rs +++ b/libsql-server/src/test/bottomless.rs @@ -114,9 +114,7 @@ async fn configure_server( initial_idle_shutdown_timeout: None, rpc_server_config: None, rpc_client_config: None, - shutdown: Default::default(), - meta_store_config: Default::default(), - max_concurrent_connections: 128, + ..Default::default() } } From 56b1221f9c87873b90c0ac70a38f8ac918c55cee Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Mon, 25 Mar 2024 11:41:30 -0700 Subject: [PATCH 51/63] server: release v0.24.4 (#1259) --- Cargo.lock | 2 +- libsql-server/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1b6a3dcda2..8d536f0170 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2779,7 +2779,7 @@ dependencies = [ [[package]] name = "libsql-server" -version = "0.24.3" +version = "0.24.4" dependencies = [ "aes", "anyhow", diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index c3e5336f9b..1167a6078e 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql-server" -version = "0.24.3" +version = "0.24.4" edition = "2021" default-run = "sqld" From b23231cedce108d4e5aad79367ba04339d00a48d Mon Sep 17 00:00:00 2001 From: Glauber Costa Date: Mon, 25 Mar 2024 15:41:56 -0400 Subject: [PATCH 52/63] server: allow explain queries without bind parameters (#1256) allow explain queries without bind parameters In SQLite, it is invalid to pass a query with a ? or named parameter without a bind variable, unless you are just trying to explain the command. We always fail those queries, but we shouldn't. --- libsql-server/src/query.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/libsql-server/src/query.rs b/libsql-server/src/query.rs index c6d1c76186..72096226cc 100644 --- a/libsql-server/src/query.rs +++ b/libsql-server/src/query.rs @@ -142,9 +142,17 @@ impl Params { if let Some(value) = maybe_value { stmt.raw_bind_parameter(index, value)?; } else if let Some(name) = param_name { - return Err(anyhow!("value for parameter {} not found", name)); + if stmt.is_explain() > 0 { + return Ok(()); + } else { + return Err(anyhow!("value for parameter {} not found", name)); + } } else { - return Err(anyhow!("value for parameter {} not found", index)); + if stmt.is_explain() > 0 { + return Ok(()); + } else { + return Err(anyhow!("value for parameter {} not found", index)); + } } } } From c731a3a4b6f045a0df9c51d30cd802b54210dc34 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Mon, 25 Mar 2024 16:07:55 -0700 Subject: [PATCH 53/63] add libsql-hrana crate (#1260) --- Cargo.lock | 17 +++++++++++++---- Cargo.toml | 3 ++- libsql-hrana/Cargo.toml | 15 +++++++++++++++ .../src/hrana/mod.rs => libsql-hrana/src/lib.rs | 0 .../src/hrana => libsql-hrana/src}/proto.rs | 0 .../src/hrana => libsql-hrana/src}/protobuf.rs | 0 libsql-server/Cargo.toml | 3 ++- libsql-server/src/hrana/http/mod.rs | 2 +- libsql-server/src/hrana/http/request.rs | 2 +- libsql-server/src/hrana/mod.rs | 2 +- libsql-sys/Cargo.toml | 13 +------------ libsql-sys/src/lib.rs | 2 -- libsql/Cargo.toml | 5 +++-- libsql/src/hrana/mod.rs | 4 ++-- libsql/src/hrana/stream.rs | 2 +- libsql/src/hrana/transaction.rs | 2 +- libsql/src/wasm/mod.rs | 2 +- 17 files changed, 44 insertions(+), 30 deletions(-) create mode 100644 libsql-hrana/Cargo.toml rename libsql-sys/src/hrana/mod.rs => libsql-hrana/src/lib.rs (100%) rename {libsql-sys/src/hrana => libsql-hrana/src}/proto.rs (100%) rename {libsql-sys/src/hrana => libsql-hrana/src}/protobuf.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 8d536f0170..bd55cd2b67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2705,6 +2705,7 @@ dependencies = [ "http", "hyper", "hyper-rustls 0.25.0", + "libsql-hrana", "libsql-sqlite3-parser", "libsql-sys", "libsql_replication", @@ -2753,6 +2754,17 @@ dependencies = [ "libsql-wasmtime-bindings", ] +[[package]] +name = "libsql-hrana" +version = "0.1.0" +dependencies = [ + "base64", + "bytes", + "prost", + "serde", + "serde_json", +] + [[package]] name = "libsql-rusqlite" version = "0.30.0" @@ -2821,6 +2833,7 @@ dependencies = [ "jsonwebtoken", "libsql", "libsql-client", + "libsql-hrana", "libsql-rusqlite", "libsql-sqlite3-parser", "libsql-sys", @@ -2891,14 +2904,10 @@ dependencies = [ name = "libsql-sys" version = "0.4.0" dependencies = [ - "base64", "bytes", "libsql-ffi", "libsql-rusqlite", "once_cell", - "prost", - "serde", - "serde_json", "tracing", "zerocopy", ] diff --git a/Cargo.toml b/Cargo.toml index 772d0c19f6..6ef87f13f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,11 +10,12 @@ members = [ "bottomless-cli", "libsql-replication", "libsql-ffi", + "libsql-hrana", "vendored/rusqlite", "vendored/sqlite3-parser", - "xtask", + "xtask", "libsql-hrana", ] exclude = [ diff --git a/libsql-hrana/Cargo.toml b/libsql-hrana/Cargo.toml new file mode 100644 index 0000000000..f7401d5038 --- /dev/null +++ b/libsql-hrana/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "libsql-hrana" +version = "0.1.0" +edition = "2021" +license = "MIT" + +[dependencies] +serde = { version = "1.0", features = ["derive", "rc"] } +prost = { version = "0.12" } +base64 = { version = "0.21" } +bytes = "1" + +[dev-dependencies] +serde_json = "1.0" + diff --git a/libsql-sys/src/hrana/mod.rs b/libsql-hrana/src/lib.rs similarity index 100% rename from libsql-sys/src/hrana/mod.rs rename to libsql-hrana/src/lib.rs diff --git a/libsql-sys/src/hrana/proto.rs b/libsql-hrana/src/proto.rs similarity index 100% rename from libsql-sys/src/hrana/proto.rs rename to libsql-hrana/src/proto.rs diff --git a/libsql-sys/src/hrana/protobuf.rs b/libsql-hrana/src/protobuf.rs similarity index 100% rename from libsql-sys/src/hrana/protobuf.rs rename to libsql-hrana/src/protobuf.rs diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 1167a6078e..7de7d65fe6 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -60,7 +60,8 @@ serde_json = { version = "1.0.91", features = ["preserve_order"] } md-5 = "0.10" sha2 = "0.10" sha256 = "1.1.3" -libsql-sys = { path = "../libsql-sys", features = ["wal", "hrana"], default-features = false } +libsql-sys = { path = "../libsql-sys", features = ["wal"], default-features = false } +libsql-hrana = { path = "../libsql-hrana" } sqlite3-parser = { package = "libsql-sqlite3-parser", path = "../vendored/sqlite3-parser", version = "0.11.0", default-features = false, features = [ "YYNOERRORRECOVERY" ] } tempfile = "3.7.0" thiserror = "1.0.38" diff --git a/libsql-server/src/hrana/http/mod.rs b/libsql-server/src/hrana/http/mod.rs index 938270feb2..c336a3e1f0 100644 --- a/libsql-server/src/hrana/http/mod.rs +++ b/libsql-server/src/hrana/http/mod.rs @@ -1,7 +1,7 @@ use anyhow::{Context, Result}; use bytes::Bytes; use futures::stream::Stream; -use libsql_sys::hrana::proto; +use libsql_hrana::proto; use parking_lot::Mutex; use serde::{de::DeserializeOwned, Serialize}; use std::pin::Pin; diff --git a/libsql-server/src/hrana/http/request.rs b/libsql-server/src/hrana/http/request.rs index e444c5fdf5..d0f9d1a9e4 100644 --- a/libsql-server/src/hrana/http/request.rs +++ b/libsql-server/src/hrana/http/request.rs @@ -4,7 +4,7 @@ use bytesize::ByteSize; use super::super::{batch, stmt, ProtocolError, Version}; use super::stream; use crate::connection::{Connection, RequestContext}; -use libsql_sys::hrana::proto; +use libsql_hrana::proto; const MAX_SQL_COUNT: usize = 50; const MAX_STORED_SQL_SIZE: ByteSize = ByteSize::kb(5); diff --git a/libsql-server/src/hrana/mod.rs b/libsql-server/src/hrana/mod.rs index 8b5aeefdde..9023f085d1 100644 --- a/libsql-server/src/hrana/mod.rs +++ b/libsql-server/src/hrana/mod.rs @@ -6,7 +6,7 @@ pub mod http; mod result_builder; pub mod stmt; pub mod ws; -pub use libsql_sys::hrana::proto; +pub use libsql_hrana::proto; #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] pub enum Version { diff --git a/libsql-sys/Cargo.toml b/libsql-sys/Cargo.toml index 3e77836639..8c778aee6d 100644 --- a/libsql-sys/Cargo.toml +++ b/libsql-sys/Cargo.toml @@ -17,12 +17,6 @@ once_cell = "1.18.0" rusqlite = { workspace = true, features = ["trace"], optional = true } tracing = "0.1.37" zerocopy = { version = "0.7.28", features = ["derive"] } -serde = { version = "1.0", features = ["derive", "rc"], optional = true } -prost = { version = "0.12", optional = true } -base64 = { version = "0.21", optional = true } - -[dev-dependencies] -serde_json = "1.0" [features] default = ["api"] @@ -32,9 +26,4 @@ rusqlite = ["dep:rusqlite"] wasmtime-bindings = ["libsql-ffi/wasmtime-bindings"] unix-excl-vfs = [] encryption = ["libsql-ffi/multiple-ciphers"] -serde = ["dep:serde"] -hrana = [ - "serde", - "dep:prost", - "dep:base64" -] + diff --git a/libsql-sys/src/lib.rs b/libsql-sys/src/lib.rs index 839fa5aded..bf2e6ec27f 100644 --- a/libsql-sys/src/lib.rs +++ b/libsql-sys/src/lib.rs @@ -63,8 +63,6 @@ pub mod ffi { #[cfg(feature = "api")] pub mod connection; pub mod error; -#[cfg(feature = "hrana")] -pub mod hrana; #[cfg(feature = "api")] pub mod statement; #[cfg(feature = "api")] diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index 07e7f565ae..bd098aa0a4 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -11,7 +11,8 @@ tracing = { version = "0.1.37", default-features = false } thiserror = "1.0.40" futures = { version = "0.3.28", optional = true } -libsql-sys = { version = "0.4", path = "../libsql-sys", features = ["hrana"], optional = true } +libsql-sys = { version = "0.4", path = "../libsql-sys", optional = true } +libsql-hrana = { version = "0.1", path = "../libsql-hrana", optional = true } tokio = { version = "1.29.1", features = ["sync"], optional = true } tokio-util = { version = "0.7", features = ["io-util", "codec"], optional = true } parking_lot = { version = "0.12.1", optional = true } @@ -101,7 +102,7 @@ hrana = [ "dep:tokio", "dep:tokio-util", "dep:bytes", - "libsql-sys", + "dep:libsql-hrana", ] serde = ["dep:serde"] remote = [ diff --git a/libsql/src/hrana/mod.rs b/libsql/src/hrana/mod.rs index 3f330f1c1f..1bf6e6cc57 100644 --- a/libsql/src/hrana/mod.rs +++ b/libsql/src/hrana/mod.rs @@ -16,8 +16,8 @@ use crate::parser::StmtKind; use crate::{params::Params, ValueType}; use bytes::Bytes; use futures::{Stream, StreamExt}; -pub use libsql_sys::hrana::proto; -use libsql_sys::hrana::proto::{Batch, BatchResult, Col, Stmt}; +pub use libsql_hrana::proto; +use libsql_hrana::proto::{Batch, BatchResult, Col, Stmt}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; diff --git a/libsql/src/hrana/stream.rs b/libsql/src/hrana/stream.rs index fcc668ecf5..cbf61d53fd 100644 --- a/libsql/src/hrana/stream.rs +++ b/libsql/src/hrana/stream.rs @@ -3,7 +3,7 @@ use crate::hrana::proto::{Batch, BatchResult, DescribeResult, Stmt, StmtResult}; use crate::hrana::{CursorResponseError, HranaError, HttpSend, Result}; use bytes::{Bytes, BytesMut}; use futures::Stream; -use libsql_sys::hrana::proto::{ +use libsql_hrana::proto::{ BatchStreamReq, CloseSqlStreamReq, CloseStreamReq, CloseStreamResp, DescribeStreamReq, GetAutocommitStreamReq, PipelineReqBody, PipelineRespBody, SequenceStreamReq, StoreSqlStreamReq, StreamRequest, StreamResponse, StreamResult, diff --git a/libsql/src/hrana/transaction.rs b/libsql/src/hrana/transaction.rs index 562aa92c4c..447f29ab43 100644 --- a/libsql/src/hrana/transaction.rs +++ b/libsql/src/hrana/transaction.rs @@ -3,7 +3,7 @@ use crate::hrana::stream::HranaStream; use crate::hrana::{HttpSend, Result}; use crate::parser::StmtKind; use crate::TransactionBehavior; -use libsql_sys::hrana::proto::{ExecuteStreamReq, StreamRequest}; +use libsql_hrana::proto::{ExecuteStreamReq, StreamRequest}; #[derive(Debug, Clone)] pub(crate) struct HttpTransaction diff --git a/libsql/src/wasm/mod.rs b/libsql/src/wasm/mod.rs index ee81f32cbc..2ae5a5ec91 100644 --- a/libsql/src/wasm/mod.rs +++ b/libsql/src/wasm/mod.rs @@ -36,7 +36,7 @@ use crate::{ params::IntoParams, TransactionBehavior, }; -use libsql_sys::hrana::proto::{Batch, Stmt}; +use libsql_hrana::proto::{Batch, Stmt}; pub use crate::wasm::rows::Rows; From e81ab33000b192fe85bd90e121fcf495002fe83b Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Tue, 26 Mar 2024 07:01:18 -0700 Subject: [PATCH 54/63] add windows ci (#1263) --- .github/workflows/rust.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index fd89101d34..1423c29a02 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -134,6 +134,29 @@ jobs: run: rm -rf libsql-ffi/bundled/SQLite3MultipleCiphers/build - name: embedded replica encryption tests run: cargo test -F test-encryption --package libsql-server --test tests embedded_replica + windows: + runs-on: windows-latest + name: Windows checks + steps: + - uses: hecrj/setup-rust-action@v1 + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: recursive + - name: Set up cargo cache + uses: actions/cache@v3 + continue-on-error: false + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + - name: check libsql remote + run: cargo check -p libsql --no-default-features -F remote # test-rust-wasm: # runs-on: ubuntu-latest From 70d43242f01929eba393190390463ae608e3e388 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Tue, 26 Mar 2024 08:43:54 -0700 Subject: [PATCH 55/63] libsql: prepare v0.3.1 release (#1264) --- Cargo.lock | 2 +- libsql/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bd55cd2b67..ad6ea1d1a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2690,7 +2690,7 @@ dependencies = [ [[package]] name = "libsql" -version = "0.3.1" +version = "0.3.2" dependencies = [ "anyhow", "async-stream", diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index bd098aa0a4..5e4ff2feb3 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libsql" -version = "0.3.1" +version = "0.3.2" edition = "2021" description = "libSQL library: the main gateway for interacting with the database" repository = "https://github.com/tursodatabase/libsql" From 6cd1e00456aebd37e89ee1fd7399e9ee7015527b Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Tue, 26 Mar 2024 09:41:47 -0700 Subject: [PATCH 56/63] libsql-hrana: add description (#1265) --- libsql-hrana/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/libsql-hrana/Cargo.toml b/libsql-hrana/Cargo.toml index f7401d5038..3181335f64 100644 --- a/libsql-hrana/Cargo.toml +++ b/libsql-hrana/Cargo.toml @@ -3,6 +3,7 @@ name = "libsql-hrana" version = "0.1.0" edition = "2021" license = "MIT" +description = "hrana protocol for libsql" [dependencies] serde = { version = "1.0", features = ["derive", "rc"] } From 0730b4fac27764640af0bf9832696f712855d966 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Tue, 26 Mar 2024 19:30:06 -0700 Subject: [PATCH 57/63] deduplicated handling of hrana hello and repreated hello. Code is fully equivalent to the previous form. No change in logic. --- libsql-server/src/hrana/ws/conn.rs | 12 +++-- libsql-server/src/hrana/ws/session.rs | 68 ++++++++------------------- 2 files changed, 29 insertions(+), 51 deletions(-) diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index e34e00e6bf..5197b2dbce 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -208,12 +208,18 @@ async fn handle_client_msg(conn: &mut Conn, client_msg: proto::ClientMsg) -> Res async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { let hello_res = match conn.session.as_mut() { None => { - session::handle_initial_hello(&conn.server, conn.version, jwt, conn.namespace.clone()) + conn.session = session::handle_hello(&conn.server, jwt, conn.namespace.clone()) .await - .map(|session| conn.session = Some(session)) + .map(|auth| session::Session::new(auth, conn.version)) + .map(|s| Some(s))?; + Ok(()) } Some(session) => { - session::handle_repeated_hello(&conn.server, session, jwt, conn.namespace.clone()).await + if session.version < Version::Hrana2 { + bail!(ProtocolError::NotSupported {what: "Repeated hello message", min_version: Version::Hrana2,}) + } + session.auth = session::handle_hello(&conn.server, jwt, conn.namespace.clone()).await?; + Ok(()) } }; diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index fce0903035..dae21cf7e4 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -15,13 +15,26 @@ use crate::namespace::NamespaceName; /// Session-level state of an authenticated Hrana connection. pub struct Session { - auth: Authenticated, - version: Version, + pub auth: Authenticated, + pub version: Version, streams: HashMap, sqls: HashMap, cursors: HashMap, } +impl Session { + pub fn new(auth: Authenticated, version: Version) -> Self { + Self { + auth, + version, + streams: HashMap::new(), + sqls: HashMap::new(), + cursors: HashMap::new(), + } + } +} + + struct StreamHandle { job_tx: mpsc::Sender, cursor_id: Option, @@ -65,13 +78,12 @@ pub enum ResponseError { Batch(batch::BatchError), } -pub(super) async fn handle_initial_hello( + +pub(super) async fn handle_hello( server: &Server, - version: Version, jwt: Option, namespace: NamespaceName, -) -> Result { - // todo dupe #auth +) -> Result { let namespace_jwt_key = server .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) @@ -84,17 +96,9 @@ pub(super) async fn handle_initial_hello( let context: UserAuthContext = build_context(jwt, &auth_strategy.strategy.required_fields()); - let auth = auth_strategy + auth_strategy .authenticate(context) - .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; - - Ok(Session { - auth, - version, - streams: HashMap::new(), - sqls: HashMap::new(), - cursors: HashMap::new(), - }) + .map_err(|err| anyhow!(ResponseError::Auth { source: err })) } fn build_context(jwt: Option, required_fields: &Vec) -> UserAuthContext { @@ -105,38 +109,6 @@ fn build_context(jwt: Option, required_fields: &Vec) -> UserAuth ctx } -pub(super) async fn handle_repeated_hello( - server: &Server, - session: &mut Session, - jwt: Option, - namespace: NamespaceName, -) -> Result<()> { - if session.version < Version::Hrana2 { - bail!(ProtocolError::NotSupported { - what: "Repeated hello message", - min_version: Version::Hrana2, - }) - } - // todo dupe #auth - let namespace_jwt_key = server - .namespaces - .with(namespace.clone(), |ns| ns.jwt_key()) - .await??; - - let auth_strategy = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or(server.user_auth_strategy.clone()); - - let context: UserAuthContext = build_context(jwt, &auth_strategy.strategy.required_fields()); - - session.auth = auth_strategy - .authenticate(context) - .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; - - Ok(()) -} - pub(super) async fn handle_request( server: &Server, session: &mut Session, From 90aee49a3b26534389b5ce08d4afee23b02a714c Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 27 Mar 2024 13:09:10 -0700 Subject: [PATCH 58/63] fixed early return --- libsql-server/src/hrana/ws/conn.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index 5197b2dbce..048c69c606 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -208,18 +208,18 @@ async fn handle_client_msg(conn: &mut Conn, client_msg: proto::ClientMsg) -> Res async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { let hello_res = match conn.session.as_mut() { None => { - conn.session = session::handle_hello(&conn.server, jwt, conn.namespace.clone()) + session::handle_hello(&conn.server, jwt, conn.namespace.clone()) .await .map(|auth| session::Session::new(auth, conn.version)) - .map(|s| Some(s))?; - Ok(()) + .map(|s| {conn.session = Some(s)}) } Some(session) => { if session.version < Version::Hrana2 { bail!(ProtocolError::NotSupported {what: "Repeated hello message", min_version: Version::Hrana2,}) } - session.auth = session::handle_hello(&conn.server, jwt, conn.namespace.clone()).await?; - Ok(()) + session::handle_hello(&conn.server, jwt, conn.namespace.clone()) + .await + .map(|a| {session.auth = a}) } }; From 57052b51cc4626327b612c22c3de1fe2ef051de4 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 27 Mar 2024 13:36:07 -0700 Subject: [PATCH 59/63] lazy unwrapping --- libsql-server/src/hrana/ws/session.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index dae21cf7e4..15c89c987a 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -92,7 +92,7 @@ pub(super) async fn handle_hello( let auth_strategy = namespace_jwt_key .map(Jwt::new) .map(Auth::new) - .unwrap_or(server.user_auth_strategy.clone()); + .unwrap_or_else(|| server.user_auth_strategy.clone()); let context: UserAuthContext = build_context(jwt, &auth_strategy.strategy.required_fields()); From 8a1b675b3ba57a520d325325668cda154c4a31d5 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 27 Mar 2024 18:15:41 -0700 Subject: [PATCH 60/63] made session fields private again --- libsql-server/src/hrana/ws/conn.rs | 22 ++++++++-------------- libsql-server/src/hrana/ws/session.rs | 15 ++++++++++++--- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index 048c69c606..acba5a342f 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::borrow::{BorrowMut, Cow}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; @@ -206,21 +206,15 @@ async fn handle_client_msg(conn: &mut Conn, client_msg: proto::ClientMsg) -> Res } async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { + let auth = session::handle_hello(&conn.server, jwt, conn.namespace.clone()).await; + let hello_res = match conn.session.as_mut() { - None => { - session::handle_hello(&conn.server, jwt, conn.namespace.clone()) - .await + None => auth .map(|auth| session::Session::new(auth, conn.version)) - .map(|s| {conn.session = Some(s)}) - } - Some(session) => { - if session.version < Version::Hrana2 { - bail!(ProtocolError::NotSupported {what: "Repeated hello message", min_version: Version::Hrana2,}) - } - session::handle_hello(&conn.server, jwt, conn.namespace.clone()) - .await - .map(|a| {session.auth = a}) - } + .map(|s| {conn.session = Some(s)}), + Some(session) => auth + .map(|a| {session.update_auth(a)}) + .and_then(|op| op) }; match hello_res { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 15c89c987a..d18eab1002 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, bail, Error, Result}; use futures::future::BoxFuture; use tokio::sync::{mpsc, oneshot}; @@ -15,8 +15,8 @@ use crate::namespace::NamespaceName; /// Session-level state of an authenticated Hrana connection. pub struct Session { - pub auth: Authenticated, - pub version: Version, + auth: Authenticated, + version: Version, streams: HashMap, sqls: HashMap, cursors: HashMap, @@ -32,6 +32,15 @@ impl Session { cursors: HashMap::new(), } } + + pub fn update_auth(&mut self, auth: Authenticated) -> Result<(), Error>{ + if self.version < Version::Hrana2 { + bail!(ProtocolError::NotSupported {what: "Repeated hello message", min_version: Version::Hrana2,}) + } + self.auth = auth; + Ok(()) + } + } From dabcfd14b52ca635658166a9579e86f764c0125b Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Wed, 27 Mar 2024 18:18:05 -0700 Subject: [PATCH 61/63] fmt --- libsql-server/src/hrana/ws/conn.rs | 10 ++++------ libsql-server/src/hrana/ws/session.rs | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index acba5a342f..40e903147e 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -1,4 +1,4 @@ -use std::borrow::{BorrowMut, Cow}; +use std::borrow::Cow; use std::future::Future; use std::pin::Pin; use std::sync::Arc; @@ -210,11 +210,9 @@ async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result let hello_res = match conn.session.as_mut() { None => auth - .map(|auth| session::Session::new(auth, conn.version)) - .map(|s| {conn.session = Some(s)}), - Some(session) => auth - .map(|a| {session.update_auth(a)}) - .and_then(|op| op) + .map(|auth| session::Session::new(auth, conn.version)) + .map(|s| conn.session = Some(s)), + Some(session) => auth.map(|a| session.update_auth(a)).and_then(|op| op), }; match hello_res { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index d18eab1002..4cf4d59407 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -24,26 +24,27 @@ pub struct Session { impl Session { pub fn new(auth: Authenticated, version: Version) -> Self { - Self { - auth, - version, - streams: HashMap::new(), - sqls: HashMap::new(), + Self { + auth, + version, + streams: HashMap::new(), + sqls: HashMap::new(), cursors: HashMap::new(), } } - pub fn update_auth(&mut self, auth: Authenticated) -> Result<(), Error>{ + pub fn update_auth(&mut self, auth: Authenticated) -> Result<(), Error> { if self.version < Version::Hrana2 { - bail!(ProtocolError::NotSupported {what: "Repeated hello message", min_version: Version::Hrana2,}) + bail!(ProtocolError::NotSupported { + what: "Repeated hello message", + min_version: Version::Hrana2, + }) } self.auth = auth; Ok(()) } - } - struct StreamHandle { job_tx: mpsc::Sender, cursor_id: Option, @@ -87,7 +88,6 @@ pub enum ResponseError { Batch(batch::BatchError), } - pub(super) async fn handle_hello( server: &Server, jwt: Option, From 9938645a8aa58f8633f8d499b2899574d97567bb Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Thu, 28 Mar 2024 09:33:54 -0700 Subject: [PATCH 62/63] cleaned up nesting in conn --- libsql-server/src/hrana/ws/conn.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index 40e903147e..94e47423fc 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -208,12 +208,15 @@ async fn handle_client_msg(conn: &mut Conn, client_msg: proto::ClientMsg) -> Res async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { let auth = session::handle_hello(&conn.server, jwt, conn.namespace.clone()).await; - let hello_res = match conn.session.as_mut() { - None => auth - .map(|auth| session::Session::new(auth, conn.version)) - .map(|s| conn.session = Some(s)), - Some(session) => auth.map(|a| session.update_auth(a)).and_then(|op| op), - }; + let hello_res = auth.map(|a| { + if let Some(sess) = conn.session.as_mut(){ + sess.update_auth(a) + } else { + conn.session = Some(session::Session::new(a, conn.version)); + Ok(()) + } + }).and_then(|o|o); + match hello_res { Ok(_) => { From 632a6406f9bdf76e243ea5f016a71a3bd2a4bcf2 Mon Sep 17 00:00:00 2001 From: julian warszawski Date: Fri, 29 Mar 2024 10:47:36 -0700 Subject: [PATCH 63/63] fmt --- libsql-server/src/hrana/ws/conn.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index 94e47423fc..a5581cbc71 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -208,15 +208,16 @@ async fn handle_client_msg(conn: &mut Conn, client_msg: proto::ClientMsg) -> Res async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { let auth = session::handle_hello(&conn.server, jwt, conn.namespace.clone()).await; - let hello_res = auth.map(|a| { - if let Some(sess) = conn.session.as_mut(){ - sess.update_auth(a) - } else { - conn.session = Some(session::Session::new(a, conn.version)); - Ok(()) - } - }).and_then(|o|o); - + let hello_res = auth + .map(|a| { + if let Some(sess) = conn.session.as_mut() { + sess.update_auth(a) + } else { + conn.session = Some(session::Session::new(a, conn.version)); + Ok(()) + } + }) + .and_then(|o| o); match hello_res { Ok(_) => {