Skip to content

Commit

Permalink
Poll collect endpoint with GET.
Browse files Browse the repository at this point in the history
  • Loading branch information
jhoyla committed Nov 14, 2024
1 parent 2da37f6 commit 87148b6
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 38 deletions.
82 changes: 45 additions & 37 deletions crates/daphne-server/src/router/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use std::sync::Arc;

use axum::{
body::HttpBody,
extract::State,
http::StatusCode,
extract::{Path, State},
http::{Request, StatusCode},
middleware::{from_fn, Next},
response::{IntoResponse, Response},
routing::{post, put},
routing::{get, post, put},
};
use daphne::{
constants::DapMediaType,
Expand All @@ -23,6 +24,28 @@ use super::{
extractor::dap_sender::FROM_COLLECTOR, AxumDapResponse, DapRequestExtractor, DaphneService,
UnauthenticatedDapRequestExtractor,
};
use futures::{future::BoxFuture, FutureExt};
use serde::Deserialize;

#[derive(Deserialize)]
struct PathVersion {
#[serde(rename = "version")]
presented_version: DapVersion,
}

fn require_version<B: Send + 'static>(
expected_version: DapVersion,
) -> impl Copy + Fn(Path<PathVersion>, Request<B>, Next<B>) -> BoxFuture<'static, Response> {
move |Path(PathVersion { presented_version }), req, next| {
async move {
if presented_version != expected_version {
return StatusCode::METHOD_NOT_ALLOWED.into_response();
}
next.run(req).await
}
.boxed()
}
}

pub(super) fn add_leader_routes<A, B>(router: super::Router<A, B>) -> super::Router<A, B>
where
Expand All @@ -32,11 +55,25 @@ where
B::Error: Send + Sync,
{
router
.route("/:version/tasks/:task_id/reports", put(upload_draft09))
.route("/:version/tasks/:task_id/reports", post(upload_draft13))
.route(
"/:version/tasks/:task_id/reports",
put(upload).layer(from_fn(require_version(DapVersion::Draft09))),
)
.route(
"/:version/tasks/:task_id/reports",
post(upload).layer(from_fn(require_version(DapVersion::Latest))),
)
.route(
"/:version/tasks/:task_id/collection_jobs/:collect_job_id",
put(start_collection_job).post(collect),
put(start_collection_job),
)
.route(
"/:version/tasks/:task_id/collection_jobs/:collect_job_id",
post(poll_collect).layer(from_fn(require_version(DapVersion::Draft09))),
)
.route(
"/:version/tasks/:task_id/collection_jobs/:collect_job_id",
get(poll_collect).layer(from_fn(require_version(DapVersion::Latest))),
)
}

Expand All @@ -47,7 +84,7 @@ where
version = ?req.version,
)
)]
async fn upload_draft13<A>(
async fn upload<A>(
State(app): State<Arc<A>>,
UnauthenticatedDapRequestExtractor(req): UnauthenticatedDapRequestExtractor<
messages::Report,
Expand All @@ -57,41 +94,12 @@ async fn upload_draft13<A>(
where
A: DapLeader + DaphneService + Send + Sync,
{
if req.version == DapVersion::Draft09 {
return (
StatusCode::METHOD_NOT_ALLOWED,
format!("route not implemented for version {}", req.version),
)
.into_response();
}
match leader::handle_upload_req(&*app, req).await {
Ok(()) => StatusCode::OK.into_response(),
Err(e) => AxumDapResponse::new_error(e, app.server_metrics()).into_response(),
}
}

async fn upload_draft09<A>(
State(app): State<Arc<A>>,
UnauthenticatedDapRequestExtractor(req): UnauthenticatedDapRequestExtractor<
messages::Report,
resource::None,
>,
) -> Response
where
A: DapLeader + DaphneService + Send + Sync,
{
if req.version != DapVersion::Draft09 {
return (
StatusCode::METHOD_NOT_ALLOWED,
format!("route not implemented for version {}", req.version),
)
.into_response();
}
match leader::handle_upload_req(&*app, req).await {
Ok(()) => StatusCode::OK.into_response(),
Err(e) => AxumDapResponse::new_error(e, app.server_metrics()).into_response(),
}
}
#[tracing::instrument(
skip_all,
fields(
Expand Down Expand Up @@ -123,7 +131,7 @@ where
version = ?req.version,
)
)]
async fn collect<A>(
async fn poll_collect<A>(
State(app): State<Arc<A>>,
DapRequestExtractor(req): DapRequestExtractor<FROM_COLLECTOR, (), resource::CollectionJobId>,
) -> Response
Expand Down
80 changes: 80 additions & 0 deletions crates/daphne-server/tests/e2e/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,86 @@ async fn leader_collect_abort_unknown_request(version: DapVersion) {

async_test_versions! { leader_collect_abort_unknown_request }

async fn leader_collect_back_compat(version: DapVersion) {
let t = TestRunner::default_with_version(version).await;
let batch_interval = t.batch_interval();

let client = t.http_client();
let hpke_config_list = t.get_hpke_configs(version, client).await.unwrap();
let path = t.upload_path();
let method = match version {
DapVersion::Draft09 => &Method::PUT,
DapVersion::Latest => &Method::POST,
};
let expected_status = 405;

// The reports are uploaded in the background.
let mut rng = thread_rng();
let mut time_min = u64::MAX;
let mut time_max = 0u64;
for _ in 0..t.task_config.min_batch_size {
let now = rng.gen_range(TestRunner::report_interval(&batch_interval));
time_min = min(time_min, now);
time_max = max(time_max, now);
t.leader_request_expect_ok(
client,
&path,
method,
DapMediaType::Report,
None,
t.task_config
.vdaf
.produce_report(
&hpke_config_list,
now,
&t.task_id,
DapMeasurement::U64(1),
version,
)
.unwrap()
.get_encoded_with_param(&version)
.unwrap(),
)
.await
.unwrap();
}

// Get the collect URI.
let agg_param = DapAggregationParam::Empty;
let collect_req = CollectionReq {
query: Query::TimeInterval { batch_interval },
agg_param: agg_param.get_encoded().unwrap(),
};
let collect_uri = t
.leader_post_collect(
client,
collect_req.get_encoded_with_param(&t.version).unwrap(),
)
.await
.unwrap();
println!("collect_uri: {collect_uri}");

let builder = match version {
DapVersion::Draft09 => client.get(collect_uri.as_str()),
DapVersion::Latest => client.post(collect_uri.as_str()),
};
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
reqwest::header::HeaderValue::from_str(&t.collector_bearer_token)
.expect("couldn't parse bearer token"),
);
let resp = builder
.headers(headers)
.send()
.await
.expect("failed to get a response");

assert_eq!(resp.status(), expected_status);
}

async_test_versions! { leader_collect_back_compat }

async fn leader_collect_accept_global_config_max_batch_duration(version: DapVersion) {
let t = TestRunner::default_with_version(version).await;
let client = t.http_client();
Expand Down
5 changes: 4 additions & 1 deletion crates/daphne-server/tests/e2e/test_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,10 @@ impl TestRunner {
url: &Url,
token: &str,
) -> anyhow::Result<reqwest::Response> {
let builder = client.post(url.as_str());
let builder = match self.version {
DapVersion::Draft09 => client.post(url.as_str()),
DapVersion::Latest => client.get(url.as_str()),
};
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
Expand Down

0 comments on commit 87148b6

Please sign in to comment.