From 95eb6f1ae8fcbec7fc97b55a5a8798d875bc04a9 Mon Sep 17 00:00:00 2001 From: Jonathan Hoyland Date: Mon, 18 Nov 2024 14:25:07 +0000 Subject: [PATCH] Renumber errors --- crates/dapf/src/functions/helper.rs | 5 +- crates/daphne/src/messages/mod.rs | 262 +++++++++++++++++++++++--- crates/daphne/src/roles/helper.rs | 6 +- crates/daphne/src/roles/leader/mod.rs | 5 +- crates/daphne/src/roles/mod.rs | 8 +- 5 files changed, 252 insertions(+), 34 deletions(-) diff --git a/crates/dapf/src/functions/helper.rs b/crates/dapf/src/functions/helper.rs index 1392458c6..1d69c2a67 100644 --- a/crates/dapf/src/functions/helper.rs +++ b/crates/dapf/src/functions/helper.rs @@ -12,7 +12,7 @@ use daphne::{ DapVersion, }; use daphne_service_utils::{bearer_token::BearerToken, http_headers}; -use prio::codec::{Decode as _, ParameterizedEncode as _}; +use prio::codec::{ParameterizedDecode as _, ParameterizedEncode as _}; use reqwest::header; use url::Url; @@ -58,7 +58,8 @@ impl HttpClient { } else if !resp.status().is_success() { Err(response_to_anyhow(resp).await).context("while running an AggregationJobInitReq") } else { - AggregationJobResp::get_decoded( + AggregationJobResp::get_decoded_with_param( + &version, &resp .bytes() .await diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index ab2f5aa0a..ae765b683 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -526,19 +526,25 @@ pub struct Transition { pub var: TransitionVar, } -impl Encode for Transition { - fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { +impl ParameterizedEncode for Transition { + fn encode_with_param( + &self, + version: &DapVersion, + bytes: &mut Vec, + ) -> Result<(), CodecError> { self.report_id.encode(bytes)?; - self.var.encode(bytes)?; + self.var.encode_with_param(version, bytes)?; Ok(()) } } - -impl Decode for Transition { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { +impl ParameterizedDecode for Transition { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { Ok(Self { report_id: ReportId::decode(bytes)?, - var: TransitionVar::decode(bytes)?, + var: TransitionVar::decode_with_param(version, bytes)?, }) } } @@ -551,8 +557,12 @@ pub enum TransitionVar { Failed(TransitionFailure), } -impl Encode for TransitionVar { - fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { +impl ParameterizedEncode for TransitionVar { + fn encode_with_param( + &self, + version: &DapVersion, + bytes: &mut Vec, + ) -> Result<(), CodecError> { match self { TransitionVar::Continued(vdaf_message) => { 0_u8.encode(bytes)?; @@ -560,18 +570,23 @@ impl Encode for TransitionVar { } TransitionVar::Failed(err) => { 2_u8.encode(bytes)?; - err.encode(bytes)?; + err.encode_with_param(version, bytes)?; } }; Ok(()) } } -impl Decode for TransitionVar { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { +impl ParameterizedDecode for TransitionVar { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { match u8::decode(bytes)? { 0 => Ok(Self::Continued(decode_u32_bytes(bytes)?)), - 2 => Ok(Self::Failed(TransitionFailure::decode(bytes)?)), + 2 => Ok(Self::Failed(TransitionFailure::decode_with_param( + version, bytes, + )?)), _ => Err(CodecError::UnexpectedValue), } } @@ -582,6 +597,22 @@ impl Decode for TransitionVar { #[serde(rename_all = "snake_case")] #[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] pub enum TransitionFailure { + Reserved, + BatchCollected, + ReportReplayed, + ReportDropped, + HpkeUnknownConfigId, + HpkeDecryptError, + VdafPrepError, + BatchSaturated, + TaskExpired, + InvalidMessage, + ReportTooEarly, + TaskNotStarted, +} + +#[derive(Clone, Copy)] +enum TransitionFailureDraft09 { BatchCollected = 0, ReportReplayed = 1, ReportDropped = 2, @@ -594,7 +625,22 @@ pub enum TransitionFailure { ReportTooEarly = 9, } -impl TryFrom for TransitionFailure { +#[derive(Clone, Copy)] +enum TransitionFailureLatest { + Reserved = 0, + BatchCollected = 1, + ReportReplayed = 2, + ReportDropped = 3, + HpkeUnknownConfigId = 4, + HpkeDecryptError = 5, + VdafPrepError = 6, + TaskExpired = 7, + InvalidMessage = 8, + ReportTooEarly = 9, + TaskNotStarted = 10, +} + +impl TryFrom for TransitionFailureDraft09 { type Error = CodecError; fn try_from(v: u8) -> Result { @@ -614,21 +660,175 @@ impl TryFrom for TransitionFailure { } } -impl Encode for TransitionFailure { +impl TryFrom for TransitionFailureLatest { + type Error = CodecError; + + fn try_from(v: u8) -> Result { + match v { + b if b == Self::Reserved as u8 => Ok(Self::Reserved), + b if b == Self::BatchCollected as u8 => Ok(Self::BatchCollected), + b if b == Self::ReportReplayed as u8 => Ok(Self::ReportReplayed), + b if b == Self::ReportDropped as u8 => Ok(Self::ReportDropped), + b if b == Self::HpkeUnknownConfigId as u8 => Ok(Self::HpkeUnknownConfigId), + b if b == Self::HpkeDecryptError as u8 => Ok(Self::HpkeDecryptError), + b if b == Self::VdafPrepError as u8 => Ok(Self::VdafPrepError), + b if b == Self::TaskExpired as u8 => Ok(Self::TaskExpired), + b if b == Self::InvalidMessage as u8 => Ok(Self::InvalidMessage), + b if b == Self::ReportTooEarly as u8 => Ok(Self::ReportTooEarly), + b if b == Self::TaskNotStarted as u8 => Ok(Self::TaskNotStarted), + _ => Err(CodecError::UnexpectedValue), + } + } +} + +impl TryFrom for TransitionFailure { + type Error = CodecError; + + fn try_from(v: TransitionFailureDraft09) -> Result { + match v { + TransitionFailureDraft09::BatchCollected => Ok(TransitionFailure::BatchCollected), + TransitionFailureDraft09::ReportReplayed => Ok(TransitionFailure::ReportReplayed), + TransitionFailureDraft09::ReportDropped => Ok(TransitionFailure::ReportDropped), + TransitionFailureDraft09::HpkeUnknownConfigId => { + Ok(TransitionFailure::HpkeUnknownConfigId) + } + TransitionFailureDraft09::HpkeDecryptError => Ok(TransitionFailure::HpkeDecryptError), + TransitionFailureDraft09::VdafPrepError => Ok(TransitionFailure::VdafPrepError), + TransitionFailureDraft09::BatchSaturated => Ok(TransitionFailure::BatchSaturated), + TransitionFailureDraft09::TaskExpired => Ok(TransitionFailure::TaskExpired), + TransitionFailureDraft09::InvalidMessage => Ok(TransitionFailure::InvalidMessage), + TransitionFailureDraft09::ReportTooEarly => Ok(TransitionFailure::ReportTooEarly), + } + } +} + +impl TryFrom<&TransitionFailure> for TransitionFailureDraft09 { + type Error = CodecError; + + fn try_from(v: &TransitionFailure) -> Result { + match v { + TransitionFailure::BatchCollected => Ok(TransitionFailureDraft09::BatchCollected), + TransitionFailure::ReportReplayed => Ok(TransitionFailureDraft09::ReportReplayed), + TransitionFailure::ReportDropped => Ok(TransitionFailureDraft09::ReportDropped), + TransitionFailure::HpkeUnknownConfigId => { + Ok(TransitionFailureDraft09::HpkeUnknownConfigId) + } + TransitionFailure::HpkeDecryptError => Ok(TransitionFailureDraft09::HpkeDecryptError), + TransitionFailure::VdafPrepError => Ok(TransitionFailureDraft09::VdafPrepError), + TransitionFailure::BatchSaturated => Ok(TransitionFailureDraft09::BatchSaturated), + TransitionFailure::TaskExpired => Ok(TransitionFailureDraft09::TaskExpired), + TransitionFailure::InvalidMessage => Ok(TransitionFailureDraft09::InvalidMessage), + TransitionFailure::ReportTooEarly => Ok(TransitionFailureDraft09::ReportTooEarly), + _ => Err(CodecError::UnexpectedValue), + } + } +} + +impl TryFrom for TransitionFailure { + type Error = CodecError; + + fn try_from(v: TransitionFailureLatest) -> Result { + match v { + TransitionFailureLatest::Reserved => Ok(TransitionFailure::Reserved), + TransitionFailureLatest::BatchCollected => Ok(TransitionFailure::BatchCollected), + TransitionFailureLatest::ReportReplayed => Ok(TransitionFailure::ReportReplayed), + TransitionFailureLatest::ReportDropped => Ok(TransitionFailure::ReportDropped), + TransitionFailureLatest::HpkeUnknownConfigId => { + Ok(TransitionFailure::HpkeUnknownConfigId) + } + TransitionFailureLatest::HpkeDecryptError => Ok(TransitionFailure::HpkeDecryptError), + TransitionFailureLatest::VdafPrepError => Ok(TransitionFailure::VdafPrepError), + TransitionFailureLatest::TaskExpired => Ok(TransitionFailure::TaskExpired), + TransitionFailureLatest::InvalidMessage => Ok(TransitionFailure::InvalidMessage), + TransitionFailureLatest::ReportTooEarly => Ok(TransitionFailure::ReportTooEarly), + TransitionFailureLatest::TaskNotStarted => Ok(TransitionFailure::TaskNotStarted), + } + } +} + +#[expect(clippy::match_wildcard_for_single_variants)] +impl TryFrom<&TransitionFailure> for TransitionFailureLatest { + type Error = CodecError; + + fn try_from(v: &TransitionFailure) -> Result { + match v { + TransitionFailure::Reserved => Ok(TransitionFailureLatest::Reserved), + TransitionFailure::BatchCollected => Ok(TransitionFailureLatest::BatchCollected), + TransitionFailure::ReportReplayed => Ok(TransitionFailureLatest::ReportReplayed), + TransitionFailure::ReportDropped => Ok(TransitionFailureLatest::ReportDropped), + TransitionFailure::HpkeUnknownConfigId => { + Ok(TransitionFailureLatest::HpkeUnknownConfigId) + } + TransitionFailure::HpkeDecryptError => Ok(TransitionFailureLatest::HpkeDecryptError), + TransitionFailure::VdafPrepError => Ok(TransitionFailureLatest::VdafPrepError), + TransitionFailure::TaskExpired => Ok(TransitionFailureLatest::TaskExpired), + TransitionFailure::InvalidMessage => Ok(TransitionFailureLatest::InvalidMessage), + TransitionFailure::ReportTooEarly => Ok(TransitionFailureLatest::ReportTooEarly), + TransitionFailure::TaskNotStarted => Ok(TransitionFailureLatest::TaskNotStarted), + _ => Err(CodecError::UnexpectedValue), + } + } +} + +impl Encode for TransitionFailureDraft09 { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + (*self as u8).encode(bytes) + } +} + +impl Encode for TransitionFailureLatest { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { (*self as u8).encode(bytes) } } -impl Decode for TransitionFailure { +impl ParameterizedEncode for TransitionFailure { + fn encode_with_param( + &self, + version: &DapVersion, + bytes: &mut Vec, + ) -> Result<(), CodecError> { + match version { + DapVersion::Draft09 => { + TransitionFailureDraft09::try_from(self)?.encode(bytes)?; + Ok(()) + } + DapVersion::Latest => { + TransitionFailureLatest::try_from(self)?.encode(bytes)?; + Ok(()) + } + } + } +} + +impl Decode for TransitionFailureDraft09 { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { u8::decode(bytes)?.try_into() } } +impl Decode for TransitionFailureLatest { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + u8::decode(bytes)?.try_into() + } +} + +impl ParameterizedDecode for TransitionFailure { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + match version { + DapVersion::Draft09 => TransitionFailureDraft09::decode(bytes)?.try_into(), + DapVersion::Latest => TransitionFailureLatest::decode(bytes)?.try_into(), + } + } +} + impl std::fmt::Display for TransitionFailure { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { + Self::Reserved => write!(f, "reserved"), Self::BatchCollected => write!(f, "batch_collected"), Self::ReportReplayed => write!(f, "report_replayed"), Self::ReportDropped => write!(f, "report_dropped"), @@ -639,6 +839,7 @@ impl std::fmt::Display for TransitionFailure { Self::TaskExpired => write!(f, "task_expired"), Self::InvalidMessage => write!(f, "invalid_message"), Self::ReportTooEarly => write!(f, "report_too_early"), + Self::TaskNotStarted => write!(f, "task_not_started"), } } } @@ -649,16 +850,23 @@ pub struct AggregationJobResp { pub transitions: Vec, } -impl Encode for AggregationJobResp { - fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { - encode_u32_items(bytes, &(), &self.transitions) +impl ParameterizedEncode for AggregationJobResp { + fn encode_with_param( + &self, + version: &DapVersion, + bytes: &mut Vec, + ) -> Result<(), CodecError> { + encode_u32_items(bytes, version, &self.transitions) } } -impl Decode for AggregationJobResp { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { +impl ParameterizedDecode for AggregationJobResp { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { Ok(Self { - transitions: decode_u32_items(&(), bytes)?, + transitions: decode_u32_items(version, bytes)?, }) } } @@ -1212,7 +1420,7 @@ mod test { use crate::test_versions; use hpke_rs::HpkePublicKey; - use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode}; + use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode}; use rand::prelude::*; fn read_report(version: DapVersion) { @@ -1389,9 +1597,13 @@ mod test { }, ], }; - println!("want {:?}", want.get_encoded().unwrap()); + println!( + "want {:?}", + want.get_encoded_with_param(&DapVersion::Latest).unwrap() + ); - let got = AggregationJobResp::get_decoded(TEST_DATA).unwrap(); + let got = + AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, TEST_DATA).unwrap(); assert_eq!(got, want); } diff --git a/crates/daphne/src/roles/helper.rs b/crates/daphne/src/roles/helper.rs index 98f7d8a6f..3ba9e0a3d 100644 --- a/crates/daphne/src/roles/helper.rs +++ b/crates/daphne/src/roles/helper.rs @@ -4,7 +4,7 @@ use std::{collections::HashMap, sync::Once}; use async_trait::async_trait; -use prio::codec::{Encode, ParameterizedDecode}; +use prio::codec::{Encode, ParameterizedDecode, ParameterizedEncode}; use super::{check_batch, resolve_task_config, DapAggregator}; use crate::{ @@ -85,7 +85,9 @@ pub async fn handle_agg_job_init_req( Ok(DapResponse { version, media_type: DapMediaType::AggregationJobResp, - payload: agg_job_resp.get_encoded().map_err(DapError::encoding)?, + payload: agg_job_resp + .get_encoded_with_param(&version) + .map_err(DapError::encoding)?, }) } diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index 8144f2c01..13624386c 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -336,8 +336,9 @@ async fn run_agg_job( }, ) .await?; - let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload) - .map_err(|e| DapAbort::from_codec_error(e, *task_id))?; + let agg_job_resp = + AggregationJobResp::get_decoded_with_param(&task_config.version, &resp.payload) + .map_err(|e| DapAbort::from_codec_error(e, *task_id))?; // Handle AggregationJobResp. let agg_span = diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index a3b130d11..b31ec9d15 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -149,7 +149,7 @@ mod test { DapTaskConfig, DapTaskParameters, DapVersion, }; use assert_matches::assert_matches; - use prio::codec::{Decode, Encode}; + use prio::codec::{Encode, ParameterizedDecode}; #[cfg(feature = "experimental")] use prio::{idpf::IdpfInput, vdaf::poplar1::Poplar1AggregationParam}; use rand::{thread_rng, Rng}; @@ -726,7 +726,8 @@ mod test { .await; // Get AggregationJobResp and then extract the transition data from inside. - let agg_job_resp = AggregationJobResp::get_decoded( + let agg_job_resp = AggregationJobResp::get_decoded_with_param( + &version, &helper::handle_agg_job_init_req(&*t.helper, req, Default::default()) .await .unwrap() @@ -754,7 +755,8 @@ mod test { .await; // Get AggregationJobResp and then extract the transition data from inside. - let agg_job_resp = AggregationJobResp::get_decoded( + let agg_job_resp = AggregationJobResp::get_decoded_with_param( + &version, &helper::handle_agg_job_init_req(&*t.helper, req, Default::default()) .await .unwrap()