Skip to content

Commit

Permalink
Upgrade daphne-server to http 1
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Nov 22, 2024
1 parent 95eb6f1 commit 9572f3e
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 438 deletions.
434 changes: 80 additions & 354 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 0 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ hpke-rs = "0.2.0"
hpke-rs-crypto = "0.2.0"
hpke-rs-rust-crypto = "0.2.0"
http = "1"
http-body-util = "0.1"
hyper = "0.14.29"
itertools = "0.12.1"
mappable-rc = "0.1.1"
matchit = "0.7.3"
p256 = { version = "0.13.2", features = ["ecdsa-core", "ecdsa", "pem"] }
Expand All @@ -71,7 +68,6 @@ rayon = "1.10.0"
rcgen = "0.12.1"
regex = "1.10.5"
reqwest = { version = "0.12.5", default-features = false, features = ["rustls-tls-native-roots"] }
reqwest-wasm = "0.11.16"
ring = "0.17.8"
rustls = "0.23.10"
rustls-native-certs = "0.7"
Expand All @@ -89,7 +85,6 @@ tracing = "0.1.40"
tracing-core = "0.1.32"
tracing-subscriber = "0.3.18"
url = { version = "2.5.2", features = ["serde"] }
wasm-streams = "0.4"
webpki = "0.22.4"
worker = { version = "0.3.3", features = ["http"] }
x509-parser = "0.15.1"
Expand Down
10 changes: 6 additions & 4 deletions crates/daphne-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ description = "Workers backend for Daphne"
all-features = true

[dependencies]
axum = "0.6.0" # held back to use http 0.2
daphne = { path = "../daphne" }
daphne-service-utils = { path = "../daphne-service-utils", features = ["durable_requests"] }
either.workspace = true
futures.workspace = true
hex.workspace = true
http = "0.2" # held back to use http 0.2
hyper.workspace = true
http.workspace = true
mappable-rc.workspace = true
p256.workspace = true
prio.workspace = true
Expand All @@ -37,8 +35,12 @@ tower.workspace = true
tracing.workspace = true
url.workspace = true

[dependencies.axum]
workspace = true
features = ["query", "json", "tokio", "http1", "http2"]

[dependencies.reqwest]
version = "0.11" # held back to use http 0.2
workspace = true
default-features = false
features = ["rustls-tls-native-roots", "json"]

Expand Down
15 changes: 10 additions & 5 deletions crates/daphne-server/examples/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use daphne_server::{
};
use daphne_service_utils::DapRole;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tracing_subscriber::EnvFilter;
use url::Url;

Expand Down Expand Up @@ -120,11 +121,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
.init();

// hand the router to axum for it to run
let serve = axum::Server::bind(&std::net::SocketAddr::new(
"0.0.0.0".parse().unwrap(),
config.port,
))
.serve(router.into_make_service());
let serve = axum::serve(
TcpListener::bind(std::net::SocketAddr::new(
"0.0.0.0".parse().unwrap(),
config.port,
))
.await
.unwrap(),
router,
);

let ctrl_c = tokio::signal::ctrl_c();

Expand Down
6 changes: 1 addition & 5 deletions crates/daphne-server/src/router/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use std::sync::Arc;

use axum::{
body::HttpBody,
extract::{Path, Query, State},
response::{AppendHeaders, IntoResponse},
routing::get,
Expand All @@ -22,12 +21,9 @@ use serde::Deserialize;

use super::{AxumDapResponse, DaphneService};

pub fn add_aggregator_routes<A, B>(router: super::Router<A, B>) -> super::Router<A, B>
pub fn add_aggregator_routes<A>(router: super::Router<A>) -> super::Router<A>
where
A: DapAggregator + DaphneService + Send + Sync + 'static,
B: Send + HttpBody + 'static,
B::Data: Send,
B::Error: Send + Sync,
{
router.route("/:version/hpke_config", get(hpke_config))
}
Expand Down
53 changes: 24 additions & 29 deletions crates/daphne-server/src/router/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use std::io::Cursor;

use axum::{
async_trait,
body::{Bytes, HttpBody},
extract::{FromRequest, FromRequestParts, Path},
body::Bytes,
extract::{FromRequest, FromRequestParts, Path, Request},
};
use daphne::{
constants::DapMediaType,
Expand All @@ -19,7 +19,7 @@ use daphne::{
DapError, DapRequest, DapRequestMeta, DapVersion,
};
use daphne_service_utils::{bearer_token::BearerToken, http_headers};
use http::{header::CONTENT_TYPE, HeaderMap, Request};
use http::{header::CONTENT_TYPE, HeaderMap};
use prio::codec::ParameterizedDecode;
use serde::Deserialize;

Expand Down Expand Up @@ -124,17 +124,15 @@ mod resource_parsers {
pub(super) struct UnauthenticatedDapRequestExtractor<P, R>(pub DapRequest<P, R>);

#[async_trait]
impl<S, B, P, R> FromRequest<S, B, P> for UnauthenticatedDapRequestExtractor<P, R>
impl<S, P, R> FromRequest<S, P> for UnauthenticatedDapRequestExtractor<P, R>
where
P: DecodeFromDapHttpBody,
S: DaphneService + Send + Sync,
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
R: resource_parsers::Resource,
{
type Rejection = AxumDapResponse;

async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct PathParams<R: resource_parsers::Resource> {
Expand Down Expand Up @@ -177,7 +175,7 @@ where

// TODO(mendess): this allocates needlessly, if prio supported some kind of
// AsyncParameterizedDecode we could avoid this allocation
let payload = hyper::body::to_bytes(body).await;
let payload = axum::body::to_bytes(body, usize::MAX).await;

let Ok(payload) = payload else {
return Err(AxumDapResponse::new_error(
Expand Down Expand Up @@ -244,18 +242,16 @@ pub(super) struct DapRequestExtractor<const SENDER: dap_sender::DapSender, P, R
);

#[async_trait]
impl<const SENDER: dap_sender::DapSender, S, B, P, R> FromRequest<S, B, P>
impl<const SENDER: dap_sender::DapSender, S, P, R> FromRequest<S, P>
for DapRequestExtractor<SENDER, P, R>
where
P: DecodeFromDapHttpBody + Send + Sync,
S: DaphneService + Send + Sync,
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
R: resource_parsers::Resource,
{
type Rejection = AxumDapResponse;

async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let bearer_token = extract_header_as_str(req.headers(), http_headers::DAP_AUTH_TOKEN)
.map(BearerToken::from);
let cf_tls_client_auth =
Expand Down Expand Up @@ -336,9 +332,8 @@ mod test {
};

use axum::{
body::{Body, HttpBody},
extract::State,
http::{header::CONTENT_TYPE, Request, StatusCode},
body::Body,
extract::{Request, State},
response::IntoResponse,
routing::get,
Router,
Expand All @@ -365,6 +360,7 @@ mod test {
use crate::metrics::{DaphnePromServiceMetrics, DaphneServiceMetrics};

use super::{dap_sender::FROM_LEADER, resource_parsers};
use http::{header, StatusCode};

const BEARER_TOKEN: &str = "test-token";

Expand All @@ -382,11 +378,8 @@ mod test {
/// - `/:version/:task_id/parse-mandatory-fields` uses the [`UnauthenticatedDapRequestExtractor`]
/// - `/:version/:agg_job_id/parse-agg-job-id` uses the [`UnauthenticatedDapRequestExtractor`]
/// - `/:version/:collect_job_id/parse-collect-job-id` uses the [`UnauthenticatedDapRequestExtractor`]
async fn test<B, R>(req: Request<B>) -> Result<DapRequest<R>, StatusCode>
async fn test<R>(req: Request) -> Result<DapRequest<R>, StatusCode>
where
B: Send + Sync + 'static + HttpBody,
B::Data: Send,
B::Error: Send + Sync + std::error::Error,
R: resource_parsers::Resource + 'static,
{
type Channel<R> = Sender<DapRequest<R>>;
Expand Down Expand Up @@ -472,7 +465,9 @@ mod test {
// get the request sent through the channel in the handler
StatusCode::OK => Ok(rx.recv().now_or_never().unwrap().unwrap()),
code => {
let payload = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let payload = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
eprintln!(
"body was: {}",
String::from_utf8_lossy(&payload).into_owned()
Expand All @@ -488,14 +483,14 @@ mod test {

async fn parse_mandatory_fields(version: DapVersion) {
let task_id = mk_task_id();
let req = test::<_, resource::None>(
let req = test::<resource::None>(
Request::builder()
.uri(format!(
"/{version}/{}/parse-mandatory-fields",
task_id.to_base64url()
))
.header(
CONTENT_TYPE,
header::CONTENT_TYPE,
DapMediaType::AggregateShareReq
.as_str_for_version(version)
.unwrap(),
Expand All @@ -517,7 +512,7 @@ mod test {
let task_id = mk_task_id();
let agg_job_id = AggregationJobId(thread_rng().gen());

let req = test::<_, resource::AggregationJobId>(
let req = test::<resource::AggregationJobId>(
Request::builder()
.uri(format!(
"/{version}/{}/{}/parse-agg-job-id",
Expand All @@ -541,7 +536,7 @@ mod test {
let task_id = mk_task_id();
let collect_job_id = CollectionJobId(thread_rng().gen());

let req = test::<_, resource::CollectionJobId>(
let req = test::<resource::CollectionJobId>(
Request::builder()
.uri(format!(
"/{version}/{}/{}/parse-collect-job-id",
Expand All @@ -562,7 +557,7 @@ mod test {
async_test_versions! { parse_collect_job_id }

async fn incorrect_bearer_tokens_are_rejected(version: DapVersion) {
let status_code = test::<_, resource::None>(
let status_code = test::<resource::None>(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(http_headers::DAP_AUTH_TOKEN, "something incorrect")
Expand All @@ -578,7 +573,7 @@ mod test {
async_test_versions! { incorrect_bearer_tokens_are_rejected }

async fn missing_auth_is_rejected(version: DapVersion) {
let status_code = test::<_, resource::None>(
let status_code = test::<resource::None>(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.body(Body::empty())
Expand Down Expand Up @@ -614,7 +609,7 @@ mod test {
},
};

let req = test::<_, resource::None>(
let req = test::<resource::None>(
Request::builder()
.uri(format!(
"/{version}/{}/auth",
Expand All @@ -640,7 +635,7 @@ mod test {
async_test_versions! { mtls_auth_is_enough }

async fn incorrect_bearer_tokens_are_rejected_even_with_mtls_auth(version: DapVersion) {
let code = test::<_, resource::None>(
let code = test::<resource::None>(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(http_headers::DAP_AUTH_TOKEN, "something incorrect")
Expand All @@ -657,7 +652,7 @@ mod test {
async_test_versions! { incorrect_bearer_tokens_are_rejected_even_with_mtls_auth }

async fn invalid_mtls_auth_is_rejected_despite_correct_bearer_token(version: DapVersion) {
let code = test::<_, resource::None>(
let code = test::<resource::None>(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
Expand Down
8 changes: 1 addition & 7 deletions crates/daphne-server/src/router/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use std::sync::Arc;

use axum::{
body::HttpBody,
extract::State,
routing::{post, put},
};
Expand All @@ -20,12 +19,7 @@ use super::{
extractor::dap_sender::FROM_LEADER, AxumDapResponse, DapRequestExtractor, DaphneService,
};

pub(super) fn add_helper_routes<B>(router: super::Router<App, B>) -> super::Router<App, B>
where
B: Send + HttpBody + 'static,
B::Data: Send,
B::Error: Send + Sync,
{
pub(super) fn add_helper_routes(router: super::Router<App>) -> super::Router<App> {
router
.route(
"/:version/tasks/:task_id/aggregation_jobs/:agg_job_id",
Expand Down
14 changes: 5 additions & 9 deletions crates/daphne-server/src/router/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
use std::sync::Arc;

use axum::{
body::HttpBody,
extract::{Path, State},
http::{Request, StatusCode},
extract::{Path, Request, State},
http::StatusCode,
middleware::{from_fn, Next},
response::{IntoResponse, Response},
routing::{get, post, put},
Expand All @@ -33,9 +32,9 @@ struct PathVersion {
presented_version: DapVersion,
}

fn require_version<B: Send + 'static>(
fn require_version(
expected_version: DapVersion,
) -> impl Copy + Fn(Path<PathVersion>, Request<B>, Next<B>) -> BoxFuture<'static, Response> {
) -> impl Copy + Fn(Path<PathVersion>, Request, Next) -> BoxFuture<'static, Response> {
move |Path(PathVersion { presented_version }), req, next| {
async move {
if presented_version != expected_version {
Expand All @@ -47,12 +46,9 @@ fn require_version<B: Send + 'static>(
}
}

pub(super) fn add_leader_routes<A, B>(router: super::Router<A, B>) -> super::Router<A, B>
pub(super) fn add_leader_routes<A>(router: super::Router<A>) -> super::Router<A>
where
A: DapLeader + DaphneService + Send + Sync + 'static,
B: Send + HttpBody + 'static,
B::Data: Send,
B::Error: Send + Sync,
{
router
.route(
Expand Down
Loading

0 comments on commit 9572f3e

Please sign in to comment.