Skip to content

Commit

Permalink
Renumber errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jhoyla committed Nov 19, 2024
1 parent 87148b6 commit 95eb6f1
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 34 deletions.
5 changes: 3 additions & 2 deletions crates/dapf/src/functions/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use daphne::{
DapVersion,
};
use daphne_service_utils::{bearer_token::BearerToken, http_headers};
use prio::codec::{Decode as _, ParameterizedEncode as _};
use prio::codec::{ParameterizedDecode as _, ParameterizedEncode as _};
use reqwest::header;
use url::Url;

Expand Down Expand Up @@ -58,7 +58,8 @@ impl HttpClient {
} else if !resp.status().is_success() {
Err(response_to_anyhow(resp).await).context("while running an AggregationJobInitReq")
} else {
AggregationJobResp::get_decoded(
AggregationJobResp::get_decoded_with_param(
&version,
&resp
.bytes()
.await
Expand Down
262 changes: 237 additions & 25 deletions crates/daphne/src/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,19 +526,25 @@ pub struct Transition {
pub var: TransitionVar,
}

impl Encode for Transition {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
impl ParameterizedEncode<DapVersion> for Transition {
fn encode_with_param(
&self,
version: &DapVersion,
bytes: &mut Vec<u8>,
) -> Result<(), CodecError> {
self.report_id.encode(bytes)?;
self.var.encode(bytes)?;
self.var.encode_with_param(version, bytes)?;
Ok(())
}
}

impl Decode for Transition {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
impl ParameterizedDecode<DapVersion> for Transition {
fn decode_with_param(
version: &DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
Ok(Self {
report_id: ReportId::decode(bytes)?,
var: TransitionVar::decode(bytes)?,
var: TransitionVar::decode_with_param(version, bytes)?,
})
}
}
Expand All @@ -551,27 +557,36 @@ pub enum TransitionVar {
Failed(TransitionFailure),
}

impl Encode for TransitionVar {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
impl ParameterizedEncode<DapVersion> for TransitionVar {
fn encode_with_param(
&self,
version: &DapVersion,
bytes: &mut Vec<u8>,
) -> Result<(), CodecError> {
match self {
TransitionVar::Continued(vdaf_message) => {
0_u8.encode(bytes)?;
encode_u32_bytes(bytes, vdaf_message)?;
}
TransitionVar::Failed(err) => {
2_u8.encode(bytes)?;
err.encode(bytes)?;
err.encode_with_param(version, bytes)?;
}
};
Ok(())
}
}

impl Decode for TransitionVar {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
impl ParameterizedDecode<DapVersion> for TransitionVar {
fn decode_with_param(
version: &DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
match u8::decode(bytes)? {
0 => Ok(Self::Continued(decode_u32_bytes(bytes)?)),
2 => Ok(Self::Failed(TransitionFailure::decode(bytes)?)),
2 => Ok(Self::Failed(TransitionFailure::decode_with_param(
version, bytes,
)?)),
_ => Err(CodecError::UnexpectedValue),
}
}
Expand All @@ -582,6 +597,22 @@ impl Decode for TransitionVar {
#[serde(rename_all = "snake_case")]
#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))]
pub enum TransitionFailure {
Reserved,
BatchCollected,
ReportReplayed,
ReportDropped,
HpkeUnknownConfigId,
HpkeDecryptError,
VdafPrepError,
BatchSaturated,
TaskExpired,
InvalidMessage,
ReportTooEarly,
TaskNotStarted,
}

#[derive(Clone, Copy)]
enum TransitionFailureDraft09 {
BatchCollected = 0,
ReportReplayed = 1,
ReportDropped = 2,
Expand All @@ -594,7 +625,22 @@ pub enum TransitionFailure {
ReportTooEarly = 9,
}

impl TryFrom<u8> for TransitionFailure {
#[derive(Clone, Copy)]
enum TransitionFailureLatest {
Reserved = 0,
BatchCollected = 1,
ReportReplayed = 2,
ReportDropped = 3,
HpkeUnknownConfigId = 4,
HpkeDecryptError = 5,
VdafPrepError = 6,
TaskExpired = 7,
InvalidMessage = 8,
ReportTooEarly = 9,
TaskNotStarted = 10,
}

impl TryFrom<u8> for TransitionFailureDraft09 {
type Error = CodecError;

fn try_from(v: u8) -> Result<Self, Self::Error> {
Expand All @@ -614,21 +660,175 @@ impl TryFrom<u8> for TransitionFailure {
}
}

impl Encode for TransitionFailure {
impl TryFrom<u8> for TransitionFailureLatest {
type Error = CodecError;

fn try_from(v: u8) -> Result<Self, Self::Error> {
match v {
b if b == Self::Reserved as u8 => Ok(Self::Reserved),
b if b == Self::BatchCollected as u8 => Ok(Self::BatchCollected),
b if b == Self::ReportReplayed as u8 => Ok(Self::ReportReplayed),
b if b == Self::ReportDropped as u8 => Ok(Self::ReportDropped),
b if b == Self::HpkeUnknownConfigId as u8 => Ok(Self::HpkeUnknownConfigId),
b if b == Self::HpkeDecryptError as u8 => Ok(Self::HpkeDecryptError),
b if b == Self::VdafPrepError as u8 => Ok(Self::VdafPrepError),
b if b == Self::TaskExpired as u8 => Ok(Self::TaskExpired),
b if b == Self::InvalidMessage as u8 => Ok(Self::InvalidMessage),
b if b == Self::ReportTooEarly as u8 => Ok(Self::ReportTooEarly),
b if b == Self::TaskNotStarted as u8 => Ok(Self::TaskNotStarted),
_ => Err(CodecError::UnexpectedValue),
}
}
}

impl TryFrom<TransitionFailureDraft09> for TransitionFailure {
type Error = CodecError;

fn try_from(v: TransitionFailureDraft09) -> Result<Self, Self::Error> {
match v {
TransitionFailureDraft09::BatchCollected => Ok(TransitionFailure::BatchCollected),
TransitionFailureDraft09::ReportReplayed => Ok(TransitionFailure::ReportReplayed),
TransitionFailureDraft09::ReportDropped => Ok(TransitionFailure::ReportDropped),
TransitionFailureDraft09::HpkeUnknownConfigId => {
Ok(TransitionFailure::HpkeUnknownConfigId)
}
TransitionFailureDraft09::HpkeDecryptError => Ok(TransitionFailure::HpkeDecryptError),
TransitionFailureDraft09::VdafPrepError => Ok(TransitionFailure::VdafPrepError),
TransitionFailureDraft09::BatchSaturated => Ok(TransitionFailure::BatchSaturated),
TransitionFailureDraft09::TaskExpired => Ok(TransitionFailure::TaskExpired),
TransitionFailureDraft09::InvalidMessage => Ok(TransitionFailure::InvalidMessage),
TransitionFailureDraft09::ReportTooEarly => Ok(TransitionFailure::ReportTooEarly),
}
}
}

impl TryFrom<&TransitionFailure> for TransitionFailureDraft09 {
type Error = CodecError;

fn try_from(v: &TransitionFailure) -> Result<Self, Self::Error> {
match v {
TransitionFailure::BatchCollected => Ok(TransitionFailureDraft09::BatchCollected),
TransitionFailure::ReportReplayed => Ok(TransitionFailureDraft09::ReportReplayed),
TransitionFailure::ReportDropped => Ok(TransitionFailureDraft09::ReportDropped),
TransitionFailure::HpkeUnknownConfigId => {
Ok(TransitionFailureDraft09::HpkeUnknownConfigId)
}
TransitionFailure::HpkeDecryptError => Ok(TransitionFailureDraft09::HpkeDecryptError),
TransitionFailure::VdafPrepError => Ok(TransitionFailureDraft09::VdafPrepError),
TransitionFailure::BatchSaturated => Ok(TransitionFailureDraft09::BatchSaturated),
TransitionFailure::TaskExpired => Ok(TransitionFailureDraft09::TaskExpired),
TransitionFailure::InvalidMessage => Ok(TransitionFailureDraft09::InvalidMessage),
TransitionFailure::ReportTooEarly => Ok(TransitionFailureDraft09::ReportTooEarly),
_ => Err(CodecError::UnexpectedValue),
}
}
}

impl TryFrom<TransitionFailureLatest> for TransitionFailure {
type Error = CodecError;

fn try_from(v: TransitionFailureLatest) -> Result<Self, Self::Error> {
match v {
TransitionFailureLatest::Reserved => Ok(TransitionFailure::Reserved),
TransitionFailureLatest::BatchCollected => Ok(TransitionFailure::BatchCollected),
TransitionFailureLatest::ReportReplayed => Ok(TransitionFailure::ReportReplayed),
TransitionFailureLatest::ReportDropped => Ok(TransitionFailure::ReportDropped),
TransitionFailureLatest::HpkeUnknownConfigId => {
Ok(TransitionFailure::HpkeUnknownConfigId)
}
TransitionFailureLatest::HpkeDecryptError => Ok(TransitionFailure::HpkeDecryptError),
TransitionFailureLatest::VdafPrepError => Ok(TransitionFailure::VdafPrepError),
TransitionFailureLatest::TaskExpired => Ok(TransitionFailure::TaskExpired),
TransitionFailureLatest::InvalidMessage => Ok(TransitionFailure::InvalidMessage),
TransitionFailureLatest::ReportTooEarly => Ok(TransitionFailure::ReportTooEarly),
TransitionFailureLatest::TaskNotStarted => Ok(TransitionFailure::TaskNotStarted),
}
}
}

#[expect(clippy::match_wildcard_for_single_variants)]
impl TryFrom<&TransitionFailure> for TransitionFailureLatest {
type Error = CodecError;

fn try_from(v: &TransitionFailure) -> Result<Self, Self::Error> {
match v {
TransitionFailure::Reserved => Ok(TransitionFailureLatest::Reserved),
TransitionFailure::BatchCollected => Ok(TransitionFailureLatest::BatchCollected),
TransitionFailure::ReportReplayed => Ok(TransitionFailureLatest::ReportReplayed),
TransitionFailure::ReportDropped => Ok(TransitionFailureLatest::ReportDropped),
TransitionFailure::HpkeUnknownConfigId => {
Ok(TransitionFailureLatest::HpkeUnknownConfigId)
}
TransitionFailure::HpkeDecryptError => Ok(TransitionFailureLatest::HpkeDecryptError),
TransitionFailure::VdafPrepError => Ok(TransitionFailureLatest::VdafPrepError),
TransitionFailure::TaskExpired => Ok(TransitionFailureLatest::TaskExpired),
TransitionFailure::InvalidMessage => Ok(TransitionFailureLatest::InvalidMessage),
TransitionFailure::ReportTooEarly => Ok(TransitionFailureLatest::ReportTooEarly),
TransitionFailure::TaskNotStarted => Ok(TransitionFailureLatest::TaskNotStarted),
_ => Err(CodecError::UnexpectedValue),
}
}
}

impl Encode for TransitionFailureDraft09 {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
(*self as u8).encode(bytes)
}
}

impl Encode for TransitionFailureLatest {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
(*self as u8).encode(bytes)
}
}

impl Decode for TransitionFailure {
impl ParameterizedEncode<DapVersion> for TransitionFailure {
fn encode_with_param(
&self,
version: &DapVersion,
bytes: &mut Vec<u8>,
) -> Result<(), CodecError> {
match version {
DapVersion::Draft09 => {
TransitionFailureDraft09::try_from(self)?.encode(bytes)?;
Ok(())
}
DapVersion::Latest => {
TransitionFailureLatest::try_from(self)?.encode(bytes)?;
Ok(())
}
}
}
}

impl Decode for TransitionFailureDraft09 {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
u8::decode(bytes)?.try_into()
}
}

impl Decode for TransitionFailureLatest {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
u8::decode(bytes)?.try_into()
}
}

impl ParameterizedDecode<DapVersion> for TransitionFailure {
fn decode_with_param(
version: &DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
match version {
DapVersion::Draft09 => TransitionFailureDraft09::decode(bytes)?.try_into(),
DapVersion::Latest => TransitionFailureLatest::decode(bytes)?.try_into(),
}
}
}

impl std::fmt::Display for TransitionFailure {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Reserved => write!(f, "reserved"),
Self::BatchCollected => write!(f, "batch_collected"),
Self::ReportReplayed => write!(f, "report_replayed"),
Self::ReportDropped => write!(f, "report_dropped"),
Expand All @@ -639,6 +839,7 @@ impl std::fmt::Display for TransitionFailure {
Self::TaskExpired => write!(f, "task_expired"),
Self::InvalidMessage => write!(f, "invalid_message"),
Self::ReportTooEarly => write!(f, "report_too_early"),
Self::TaskNotStarted => write!(f, "task_not_started"),
}
}
}
Expand All @@ -649,16 +850,23 @@ pub struct AggregationJobResp {
pub transitions: Vec<Transition>,
}

impl Encode for AggregationJobResp {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
encode_u32_items(bytes, &(), &self.transitions)
impl ParameterizedEncode<DapVersion> for AggregationJobResp {
fn encode_with_param(
&self,
version: &DapVersion,
bytes: &mut Vec<u8>,
) -> Result<(), CodecError> {
encode_u32_items(bytes, version, &self.transitions)
}
}

impl Decode for AggregationJobResp {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
impl ParameterizedDecode<DapVersion> for AggregationJobResp {
fn decode_with_param(
version: &DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
Ok(Self {
transitions: decode_u32_items(&(), bytes)?,
transitions: decode_u32_items(version, bytes)?,
})
}
}
Expand Down Expand Up @@ -1212,7 +1420,7 @@ mod test {

use crate::test_versions;
use hpke_rs::HpkePublicKey;
use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode};
use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
use rand::prelude::*;

fn read_report(version: DapVersion) {
Expand Down Expand Up @@ -1389,9 +1597,13 @@ mod test {
},
],
};
println!("want {:?}", want.get_encoded().unwrap());
println!(
"want {:?}",
want.get_encoded_with_param(&DapVersion::Latest).unwrap()
);

let got = AggregationJobResp::get_decoded(TEST_DATA).unwrap();
let got =
AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, TEST_DATA).unwrap();
assert_eq!(got, want);
}

Expand Down
Loading

0 comments on commit 95eb6f1

Please sign in to comment.