diff --git a/daphne/Cargo.toml b/daphne/Cargo.toml index 668a14686..254c4e33c 100644 --- a/daphne/Cargo.toml +++ b/daphne/Cargo.toml @@ -22,8 +22,8 @@ name = "aggregation" harness = false [dependencies] -assert_matches.workspace = true async-trait.workspace = true +assert_matches = { workspace = true, optional = true } base64.workspace = true futures.workspace = true hex.workspace = true @@ -41,7 +41,11 @@ tracing.workspace = true url.workspace = true [dev-dependencies] +assert_matches.workspace = true criterion.workspace = true matchit.workspace = true paste.workspace = true tokio.workspace = true + +[features] +test-utils = ["dep:assert_matches"] diff --git a/daphne/src/lib.rs b/daphne/src/lib.rs index 967a97fcf..8e875d0b3 100644 --- a/daphne/src/lib.rs +++ b/daphne/src/lib.rs @@ -48,6 +48,7 @@ pub mod messages; pub mod metrics; pub mod roles; pub mod taskprov; +#[cfg(any(test, feature = "test-utils"))] pub mod testing; pub mod vdaf; @@ -158,11 +159,9 @@ pub struct DapGlobalConfig { /// receiver config. pub supported_hpke_kems: Vec, - /// Is the taskprov extension allowed? - pub allow_taskprov: bool, - - /// Which taskprov draft should be used? - pub taskprov_version: TaskprovVersion, + /// Is the taskprov extension allowed and which taskprov draft should be used? + #[serde(default)] + pub taskprov_version: Option, } impl DapGlobalConfig { diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index bff55117a..03db9dcc0 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -39,7 +39,7 @@ const EXTENSION_TASKPROV: u16 = 0xff00; macro_rules! id_struct { ($sname:ident, $len:expr, $doc:expr) => { #[doc=$doc] - #[derive(Clone, Debug, Default, Deserialize, Hash, PartialEq, Eq, Serialize)] + #[derive(Clone, Default, Deserialize, Hash, PartialEq, Eq, Serialize)] pub struct $sname(#[serde(with = "hex")] pub [u8; $len]); impl $sname { @@ -84,6 +84,12 @@ macro_rules! id_struct { write!(f, "{}", self.to_hex()) } } + + impl fmt::Debug for $sname { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}({})", ::std::stringify!($sname), self.to_hex()) + } + } }; } @@ -1171,9 +1177,8 @@ pub fn decode_base64url_vec>(input: T) -> Option> { mod test { use super::*; - use crate::{test_version, test_versions}; + use crate::test_versions; use hpke_rs::HpkePublicKey; - use paste::paste; use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode}; use rand::prelude::*; diff --git a/daphne/src/roles/helper.rs b/daphne/src/roles/helper.rs index 9fc6a52d4..3d8f46dbf 100644 --- a/daphne/src/roles/helper.rs +++ b/daphne/src/roles/helper.rs @@ -16,10 +16,10 @@ use crate::{ fatal_error, messages::{ constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobContinueReq, - AggregationJobInitReq, PartialBatchSelector, ReportMetadata, TaskId, TransitionFailure, - TransitionVar, + AggregationJobInitReq, Draft02AggregationJobId, PartialBatchSelector, TaskId, + TransitionFailure, TransitionVar, }, - metrics::DaphneRequestType, + metrics::{ContextualizedDaphneMetrics, DaphneRequestType}, DapError, DapHelperState, DapHelperTransition, DapRequest, DapResource, DapResponse, DapTaskConfig, DapVersion, MetaAggregationJobId, }; @@ -44,254 +44,236 @@ pub trait DapHelper: DapAggregator { agg_job_id: &MetaAggregationJobId, ) -> Result, DapError>; - /// Handle a request pertaining to an aggregation job. 
- async fn handle_agg_job_req(&self, req: &DapRequest) -> Result { - let metrics = self.metrics().with_host(req.host()); - let task_id = req.task_id()?; + async fn handle_agg_job_init_req<'req>( + &self, + req: &'req DapRequest, + metrics: ContextualizedDaphneMetrics<'req>, + task_id: &TaskId, + ) -> Result { + let agg_job_init_req = + AggregationJobInitReq::get_decoded_with_param(&req.version, &req.payload) + .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; + + metrics.agg_job_observe_batch_size(agg_job_init_req.report_shares.len()); + + // taskprov: Resolve the task config to use for the request. We also need to ensure + // that all of the reports include the task config in the report extensions. (See + // section 6 of draft-wang-ppm-dap-taskprov-02.) + if let Some(taskprov_version) = self.get_global_config().taskprov_version { + let using_taskprov = agg_job_init_req + .report_shares + .iter() + .filter(|share| share.report_metadata.is_taskprov(taskprov_version, task_id)) + .count(); + + let first_metadata = match using_taskprov { + 0 => None, + c if c == agg_job_init_req.report_shares.len() => { + // All the extensions use taskprov and look ok, so compute first_metadata. + // Note this will always be Some(). + agg_job_init_req + .report_shares + .first() + .map(|report_share| &report_share.report_metadata) + } + _ => { + // It's not all taskprov or no taskprov, so it's an error. + return Err(DapAbort::UnrecognizedMessage { + detail: "some reports include the taskprov extensions and some do not" + .to_string(), + task_id: Some(task_id.clone()), + }); + } + }; + resolve_taskprov(self, task_id, req, first_metadata, taskprov_version).await?; + } - // Check whether the DAP version indicated by the sender is supported. - if req.version == DapVersion::Unknown { - return Err(DapAbort::version_unknown()); + let wrapped_task_config = self + .get_task_config_for(Cow::Borrowed(task_id)) + .await? + .ok_or(DapAbort::UnrecognizedTask)?; + let task_config = wrapped_task_config.as_ref(); + + if let Some(reason) = self.unauthorized_reason(task_config, req).await? { + error!("aborted unauthorized collect request: {reason}"); + return Err(DapAbort::UnauthorizedRequest { + detail: reason, + task_id: task_id.clone(), + }); } - match req.media_type { - DapMediaType::AggregationJobInitReq => { - let agg_job_init_req = - AggregationJobInitReq::get_decoded_with_param(&req.version, &req.payload) - .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; - - metrics.agg_job_observe_batch_size(agg_job_init_req.report_shares.len()); - - // taskprov: Resolve the task config to use for the request. We also need to ensure - // that all of the reports include the task config in the report extensions. (See - // section 6 of draft-wang-ppm-dap-taskprov-02.) - let mut first_metadata: Option<&ReportMetadata> = None; - let global_config = self.get_global_config(); - if global_config.allow_taskprov { - let using_taskprov = agg_job_init_req - .report_shares - .iter() - .filter(|share| { - share - .report_metadata - .is_taskprov(global_config.taskprov_version, task_id) - }) - .count(); - - if using_taskprov == agg_job_init_req.report_shares.len() { - // All the extensions use taskprov and look ok, so compute first_metadata. - // Note this will always be Some(). - first_metadata = agg_job_init_req - .report_shares - .first() - .map(|report_share| &report_share.report_metadata); - } else if using_taskprov != 0 { - // It's not all taskprov or no taskprov, so it's an error. 
- return Err(DapAbort::UnrecognizedMessage { - detail: "some reports include the taskprov extensions and some do not" - .to_string(), - task_id: Some(task_id.clone()), - }); - } - } - resolve_taskprov(self, task_id, req, first_metadata).await?; + let agg_job_id = resolve_agg_job_id(req, agg_job_init_req.draft02_agg_job_id.as_ref())?; - let wrapped_task_config = self - .get_task_config_for(Cow::Borrowed(task_id)) + // Check whether the DAP version in the request matches the task config. + if task_config.version != req.version { + return Err(DapAbort::version_mismatch(req.version, task_config.version)); + } + + // Ensure we know which batch the request pertains to. + check_part_batch( + task_id, + task_config, + &agg_job_init_req.part_batch_sel, + &agg_job_init_req.agg_param, + )?; + + let transition = task_config + .vdaf + .handle_agg_job_init_req( + self, + self, + task_id, + task_config, + &agg_job_init_req, + &metrics, + ) + .map_err(DapError::Abort) + .await?; + + let agg_job_resp = match transition { + DapHelperTransition::Continue(state, agg_job_resp) => { + if !self + .put_helper_state_if_not_exists(task_id, &agg_job_id, &state) .await? - .ok_or(DapAbort::UnrecognizedTask)?; - let task_config = wrapped_task_config.as_ref(); - - if let Some(reason) = self.unauthorized_reason(task_config, req).await? { - error!("aborted unauthorized collect request: {reason}"); - return Err(DapAbort::UnauthorizedRequest { - detail: reason, - task_id: task_id.clone(), - }); + { + // TODO spec: Consider an explicit abort for this case. + return Err(DapAbort::BadRequest( + "unexpected message for aggregation job (already exists)".into(), + )); } + agg_job_resp + } + DapHelperTransition::Finish(..) => { + return Err(fatal_error!(err = "unexpected transition (finished)").into()); + } + }; - // draft02 compatibility: In draft02, the aggregation job ID is parsed from the - // HTTP request payload; in the latest draft, the aggregation job ID is parsed from - // the request path. - let agg_job_id = match ( - req.version, - &req.resource, - &agg_job_init_req.draft02_agg_job_id, - ) { - (DapVersion::Draft02, DapResource::Undefined, Some(ref agg_job_id)) => { - MetaAggregationJobId::Draft02(Cow::Borrowed(agg_job_id)) - } - (DapVersion::Draft05, DapResource::AggregationJob(ref agg_job_id), None) => { - MetaAggregationJobId::Draft05(Cow::Borrowed(agg_job_id)) - } - (DapVersion::Draft05, DapResource::Undefined, None) => { - return Err(DapAbort::BadRequest("undefined resource".into())); - } - _ => unreachable!("unhandled resource {:?}", req.resource), - }; + self.audit_log().on_aggregation_job( + req.host(), + task_id, + task_config, + agg_job_init_req.report_shares.len() as u64, + AggregationJobAuditAction::Init, + ); - // Check whether the DAP version in the request matches the task config. 
- if task_config.version != req.version { - return Err(DapAbort::version_mismatch(req.version, task_config.version)); - } + metrics.agg_job_started_inc(); + metrics.inbound_req_inc(DaphneRequestType::Aggregate); + Ok(DapResponse { + version: req.version, + media_type: DapMediaType::AggregationJobResp, + payload: agg_job_resp.get_encoded(), + }) + } + + async fn handle_agg_job_cont_req<'req>( + &self, + req: &'req DapRequest, + metrics: ContextualizedDaphneMetrics<'req>, + task_id: &TaskId, + ) -> Result { + if let Some(taskprov_version) = self.get_global_config().taskprov_version { + resolve_taskprov(self, task_id, req, None, taskprov_version).await?; + } + let wrapped_task_config = self + .get_task_config_for(Cow::Borrowed(task_id)) + .await? + .ok_or(DapAbort::UnrecognizedTask)?; + let task_config = wrapped_task_config.as_ref(); - // Ensure we know which batch the request pertains to. - check_part_batch( - task_id, - task_config, - &agg_job_init_req.part_batch_sel, - &agg_job_init_req.agg_param, - )?; - - let transition = task_config - .vdaf - .handle_agg_job_init_req( - self, - self, - task_id, - task_config, - &agg_job_init_req, - &metrics, - ) - .map_err(DapError::Abort) + if let Some(reason) = self.unauthorized_reason(task_config, req).await? { + error!("aborted unauthorized collect request: {reason}"); + return Err(DapAbort::UnauthorizedRequest { + detail: reason, + task_id: task_id.clone(), + }); + } + + // Check whether the DAP version in the request matches the task config. + if task_config.version != req.version { + return Err(DapAbort::version_mismatch(req.version, task_config.version)); + } + + let agg_job_cont_req = + AggregationJobContinueReq::get_decoded_with_param(&req.version, &req.payload) + .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; + + let agg_job_id = resolve_agg_job_id(req, agg_job_cont_req.draft02_agg_job_id.as_ref())?; + + let state = self.get_helper_state(task_id, &agg_job_id).await?.ok_or( + DapAbort::UnrecognizedAggregationJob { + task_id: task_id.clone(), + agg_job_id_base64url: agg_job_id.to_base64url(), + }, + )?; + let part_batch_sel = state.part_batch_sel.clone(); + let transition = task_config.vdaf.handle_agg_job_cont_req( + task_id, + &agg_job_id, + state, + &agg_job_cont_req, + &metrics, + )?; + + let (agg_job_resp, out_shares_count) = match transition { + DapHelperTransition::Continue(..) => { + return Err(fatal_error!(err = "unexpected transition (continued)").into()); + } + DapHelperTransition::Finish(out_shares, mut agg_job_resp) => { + let out_shares_count = u64::try_from(out_shares.len()).unwrap(); + let replayed = self + .put_out_shares(task_id, task_config, &part_batch_sel, out_shares) .await?; - let agg_job_resp = match transition { - DapHelperTransition::Continue(state, agg_job_resp) => { - if !self - .put_helper_state_if_not_exists(task_id, &agg_job_id, &state) - .await? - { - // TODO spec: Consider an explicit abort for this case. - return Err(DapAbort::BadRequest( - "unexpected message for aggregation job (already exists)".into(), - )); + // If there are multiple aggregation jobs in flight that contain the + // same report, then we may need to reject the report at this late stage. 
+ if !replayed.is_empty() { + for transition in agg_job_resp.transitions.iter_mut() { + if replayed.contains(&transition.report_id) { + let failure = TransitionFailure::ReportReplayed; + transition.var = TransitionVar::Failed(failure); + metrics.report_inc_by(&format!("rejected_{failure}"), 1); } - agg_job_resp - } - DapHelperTransition::Finish(..) => { - return Err(fatal_error!(err = "unexpected transition (finished)").into()); } - }; - - self.audit_log().on_aggregation_job( - req.host(), - task_id, - task_config, - agg_job_init_req.report_shares.len() as u64, - AggregationJobAuditAction::Init, - ); - - metrics.agg_job_started_inc(); - metrics.inbound_req_inc(DaphneRequestType::Aggregate); - Ok(DapResponse { - version: req.version, - media_type: DapMediaType::AggregationJobResp, - payload: agg_job_resp.get_encoded(), - }) - } - DapMediaType::AggregationJobContinueReq => { - resolve_taskprov(self, task_id, req, None).await?; - let wrapped_task_config = self - .get_task_config_for(Cow::Borrowed(task_id)) - .await? - .ok_or(DapAbort::UnrecognizedTask)?; - let task_config = wrapped_task_config.as_ref(); - - if let Some(reason) = self.unauthorized_reason(task_config, req).await? { - error!("aborted unauthorized collect request: {reason}"); - return Err(DapAbort::UnauthorizedRequest { - detail: reason, - task_id: task_id.clone(), - }); } - // Check whether the DAP version in the request matches the task config. - if task_config.version != req.version { - return Err(DapAbort::version_mismatch(req.version, task_config.version)); - } + (agg_job_resp, out_shares_count) + } + }; - let agg_job_cont_req = - AggregationJobContinueReq::get_decoded_with_param(&req.version, &req.payload) - .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; - - // draft02 compatibility: In draft02, the aggregation job ID is parsed from the - // HTTP request payload; in the latest, the aggregation job ID is parsed from the - // request path. - let agg_job_id = match ( - req.version, - &req.resource, - &agg_job_cont_req.draft02_agg_job_id, - ) { - (DapVersion::Draft02, DapResource::Undefined, Some(ref agg_job_id)) => { - MetaAggregationJobId::Draft02(Cow::Borrowed(agg_job_id)) - } - (DapVersion::Draft05, DapResource::AggregationJob(ref agg_job_id), None) => { - MetaAggregationJobId::Draft05(Cow::Borrowed(agg_job_id)) - } - (DapVersion::Draft05, DapResource::Undefined, None) => { - return Err(DapAbort::BadRequest("undefined resource".into())); - } - _ => unreachable!("unhandled resource {:?}", req.resource), - }; - - let state = self.get_helper_state(task_id, &agg_job_id).await?.ok_or( - DapAbort::UnrecognizedAggregationJob { - task_id: task_id.clone(), - agg_job_id_base64url: agg_job_id.to_base64url(), - }, - )?; - let part_batch_sel = state.part_batch_sel.clone(); - let transition = task_config.vdaf.handle_agg_job_cont_req( - task_id, - &agg_job_id, - state, - &agg_job_cont_req, - &metrics, - )?; - - let (agg_job_resp, out_shares_count) = match transition { - DapHelperTransition::Continue(..) => { - return Err(fatal_error!(err = "unexpected transition (continued)").into()); - } - DapHelperTransition::Finish(out_shares, mut agg_job_resp) => { - let out_shares_count = u64::try_from(out_shares.len()).unwrap(); - let replayed = self - .put_out_shares(task_id, task_config, &part_batch_sel, out_shares) - .await?; - - // If there are multiple aggregation jobs in flight that contain the - // same report, then we may need to reject the report at this late stage. 
- if !replayed.is_empty() { - for transition in agg_job_resp.transitions.iter_mut() { - if replayed.contains(&transition.report_id) { - let failure = TransitionFailure::ReportReplayed; - transition.var = TransitionVar::Failed(failure); - metrics.report_inc_by(&format!("rejected_{failure}"), 1); - } - } - } + self.audit_log().on_aggregation_job( + req.host(), + task_id, + task_config, + out_shares_count, + AggregationJobAuditAction::Continue, + ); - (agg_job_resp, out_shares_count) - } - }; - - self.audit_log().on_aggregation_job( - req.host(), - task_id, - task_config, - out_shares_count, - AggregationJobAuditAction::Continue, - ); - - metrics.report_inc_by("aggregated", out_shares_count); - metrics.agg_job_completed_inc(); - metrics.inbound_req_inc(DaphneRequestType::Aggregate); - Ok(DapResponse { - version: req.version, - media_type: DapMediaType::agg_job_cont_resp_for_version(task_config.version), - payload: agg_job_resp.get_encoded(), - }) + metrics.report_inc_by("aggregated", out_shares_count); + metrics.agg_job_completed_inc(); + metrics.inbound_req_inc(DaphneRequestType::Aggregate); + Ok(DapResponse { + version: req.version, + media_type: DapMediaType::agg_job_cont_resp_for_version(task_config.version), + payload: agg_job_resp.get_encoded(), + }) + } + + /// Handle a request pertaining to an aggregation job. + async fn handle_agg_job_req(&self, req: &DapRequest) -> Result { + let metrics = self.metrics().with_host(req.host()); + let task_id = req.task_id()?; + + // Check whether the DAP version indicated by the sender is supported. + if req.version == DapVersion::Unknown { + return Err(DapAbort::version_unknown()); + } + + match req.media_type { + DapMediaType::AggregationJobInitReq => { + self.handle_agg_job_init_req(req, metrics, task_id).await + } + DapMediaType::AggregationJobContinueReq => { + self.handle_agg_job_cont_req(req, metrics, task_id).await } //TODO spec: Specify this behavior. _ => Err(DapAbort::BadRequest("unexpected media type".into())), @@ -312,7 +294,9 @@ pub trait DapHelper: DapAggregator { check_request_content_type(req, DapMediaType::AggregateShareReq)?; - resolve_taskprov(self, task_id, req, None).await?; + if let Some(taskprov_version) = self.get_global_config().taskprov_version { + resolve_taskprov(self, task_id, req, None, taskprov_version).await?; + } let wrapped_task_config = self .get_task_config_for(Cow::Borrowed(req.task_id()?)) @@ -431,3 +415,24 @@ fn check_part_batch( Ok(()) } + +fn resolve_agg_job_id<'id, S>( + req: &'id DapRequest, + draft02_agg_job_id: Option<&'id Draft02AggregationJobId>, +) -> Result, DapAbort> { + // draft02 compatibility: In draft02, the aggregation job ID is parsed from the + // HTTP request payload; in the latest, the aggregation job ID is parsed from the + // request path. 
+ match (req.version, &req.resource, &draft02_agg_job_id) { + (DapVersion::Draft02, DapResource::Undefined, Some(agg_job_id)) => { + Ok(MetaAggregationJobId::Draft02(Cow::Borrowed(agg_job_id))) + } + (DapVersion::Draft05, DapResource::AggregationJob(ref agg_job_id), None) => { + Ok(MetaAggregationJobId::Draft05(Cow::Borrowed(agg_job_id))) + } + (DapVersion::Draft05, DapResource::Undefined, None) => { + Err(DapAbort::BadRequest("undefined resource".into())) + } + _ => unreachable!("unhandled resource {:?}", req.resource), + } +} diff --git a/daphne/src/roles/leader.rs b/daphne/src/roles/leader.rs index 823ef982b..4232f95c6 100644 --- a/daphne/src/roles/leader.rs +++ b/daphne/src/roles/leader.rs @@ -162,7 +162,16 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; debug!("report id is {}", report.report_metadata.id); - resolve_taskprov(self, task_id, req, Some(&report.report_metadata)).await?; + if let Some(taskprov_version) = self.get_global_config().taskprov_version { + resolve_taskprov( + self, + task_id, + req, + Some(&report.report_metadata), + taskprov_version, + ) + .await?; + } let task_config = self .get_task_config_for(Cow::Borrowed(task_id)) .await? @@ -228,7 +237,9 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { check_request_content_type(req, DapMediaType::CollectReq)?; - resolve_taskprov(self, task_id, req, None).await?; + if let Some(taskprov_version) = self.get_global_config().taskprov_version { + resolve_taskprov(self, task_id, req, None, taskprov_version).await?; + } let wrapped_task_config = self .get_task_config_for(Cow::Borrowed(req.task_id()?)) diff --git a/daphne/src/roles/mod.rs b/daphne/src/roles/mod.rs index a3a2033b7..b5fffa7b2 100644 --- a/daphne/src/roles/mod.rs +++ b/daphne/src/roles/mod.rs @@ -10,7 +10,8 @@ mod leader; use crate::{ constants::DapMediaType, messages::{BatchSelector, ReportMetadata, TaskId, Time, TransitionFailure}, - taskprov, DapAbort, DapError, DapQueryConfig, DapRequest, DapTaskConfig, + taskprov::{self, TaskprovVersion}, + DapAbort, DapError, DapQueryConfig, DapRequest, DapTaskConfig, }; use std::borrow::Cow; use tracing::warn; @@ -146,6 +147,7 @@ async fn resolve_taskprov( task_id: &TaskId, req: &DapRequest, report_metadata_advertisement: Option<&ReportMetadata>, + taskprov_version: TaskprovVersion, ) -> Result<(), DapError> { if agg .get_task_config_for(Cow::Borrowed(task_id)) @@ -156,12 +158,6 @@ async fn resolve_taskprov( return Ok(()); } - let global_config = agg.get_global_config(); - if !global_config.allow_taskprov { - // Taskprov is disabled, so nothing to do. 
- return Ok(()); - } - let Some(vdaf_verify_key_init) = agg.taskprov_vdaf_verify_key_init() else { warn!("Taskprov disabled due to missing VDAF verification key initializer."); return Ok(()); @@ -174,7 +170,7 @@ async fn resolve_taskprov( let Some(task_config) = taskprov::resolve_advertised_task_config( req, - global_config.taskprov_version, + taskprov_version, vdaf_verify_key_init, collector_hpke_config, task_id, @@ -212,13 +208,9 @@ mod test { Extension, Interval, PartialBatchSelector, Query, Report, ReportId, ReportMetadata, ReportShare, TaskId, Time, Transition, TransitionFailure, TransitionVar, }, - metrics::DaphneMetrics, taskprov::TaskprovVersion, - test_version, test_versions, - testing::{ - AggStore, DapBatchBucketOwned, MockAggregator, MockAggregatorReportSelector, - MockAuditLog, - }, + test_versions, + testing::{AggStore, DapBatchBucketOwned, MockAggregator, MockAggregatorReportSelector}, vdaf::VdafVerifyKey, DapAbort, DapAggregateShare, DapCollectJob, DapGlobalConfig, DapMeasurement, DapQueryConfig, DapRequest, DapResource, DapTaskConfig, DapVersion, MetaAggregationJobId, @@ -226,16 +218,9 @@ mod test { }; use assert_matches::assert_matches; use matchit::Router; - use paste::paste; use prio::codec::{Decode, ParameterizedEncode}; use rand::{thread_rng, Rng}; - use std::{ - borrow::Cow, - collections::HashMap, - sync::{Arc, Mutex}, - time::SystemTime, - vec, - }; + use std::{borrow::Cow, collections::HashMap, sync::Arc, time::SystemTime, vec}; use url::Url; macro_rules! get_reports { @@ -250,10 +235,9 @@ mod test { }}; } - struct Test { + pub(super) struct TestData { now: Time, - leader: Arc, - helper: Arc, + global_config: DapGlobalConfig, collector_token: BearerToken, taskprov_collector_token: BearerToken, time_interval_task_id: TaskId, @@ -261,10 +245,15 @@ mod test { expired_task_id: TaskId, version: DapVersion, prometheus_registry: prometheus::Registry, + tasks: HashMap, + leader_token: BearerToken, + collector_hpke_receiver_config: HpkeReceiverConfig, + taskprov_vdaf_verify_key_init: [u8; 32], + taskprov_leader_token: BearerToken, } - impl Test { - fn new(version: DapVersion) -> Self { + impl TestData { + pub fn new(version: DapVersion) -> Self { let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() @@ -280,8 +269,7 @@ mod test { min_batch_interval_start: 259200, max_batch_interval_end: 259200, supported_hpke_kems: vec![HpkeKemId::X25519HkdfSha256], - allow_taskprov: true, - taskprov_version: TaskprovVersion::Draft02, + taskprov_version: Some(TaskprovVersion::Draft02), }; // Task Parameters that the Leader and Helper must agree on. 
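Aside on the `DapGlobalConfig` change above (a sketch, not part of the patch): the new optional field folds the old `allow_taskprov` flag into `taskprov_version`, so omitting the field now means "taskprov disabled". The snippet below assumes `TaskprovVersion` keeps serde, `PartialEq`, and `Debug` derives and that `serde_json` is available as a dev-dependency; the `"v02"` spelling comes from the `#[serde(rename = "v02")]` attribute in taskprov.rs.

use daphne::taskprov::TaskprovVersion;

#[test]
fn taskprov_version_wire_format() {
    // Draft02 is spelled "v02" on the wire; with #[serde(default)] on
    // DapGlobalConfig::taskprov_version, an omitted field deserializes to None.
    assert_eq!(
        serde_json::from_str::<Option<TaskprovVersion>>(r#""v02""#).unwrap(),
        Some(TaskprovVersion::Draft02)
    );
    assert_eq!(
        serde_json::from_str::<Option<TaskprovVersion>>("null").unwrap(),
        None
    );
}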
@@ -357,58 +345,9 @@ mod test { let prometheus_registry = prometheus::Registry::new(); - let helper_hpke_receiver_config_list = global_config - .gen_hpke_receiver_config_list(rng.gen()) - .collect::, _>>() - .expect("failed to generate HPKE receiver config"); - let helper = Arc::new(MockAggregator { - global_config: global_config.clone(), - tasks: Arc::new(Mutex::new(tasks.clone())), - leader_token: leader_token.clone(), - collector_token: None, - hpke_receiver_config_list: helper_hpke_receiver_config_list, - report_store: Arc::new(Mutex::new(HashMap::new())), - leader_state_store: Arc::new(Mutex::new(HashMap::new())), - helper_state_store: Arc::new(Mutex::new(HashMap::new())), - agg_store: Arc::new(Mutex::new(HashMap::new())), - collector_hpke_config: collector_hpke_receiver_config.config.clone(), - taskprov_vdaf_verify_key_init, - taskprov_leader_token: taskprov_leader_token.clone(), - taskprov_collector_token: None, - metrics: DaphneMetrics::register(&prometheus_registry, Some("test_helper")) - .unwrap(), - audit_log: MockAuditLog::default(), - peer: None, - }); - - let leader_hpke_receiver_config_list = global_config - .gen_hpke_receiver_config_list(rng.gen()) - .collect::, _>>() - .expect("failed to generate HPKE receiver config"); - let leader = Arc::new(MockAggregator { - global_config, - tasks: Arc::new(Mutex::new(tasks.clone())), - hpke_receiver_config_list: leader_hpke_receiver_config_list, - leader_token, - collector_token: Some(collector_token.clone()), - report_store: Arc::new(Mutex::new(HashMap::new())), - leader_state_store: Arc::new(Mutex::new(HashMap::new())), - helper_state_store: Arc::new(Mutex::new(HashMap::new())), - agg_store: Arc::new(Mutex::new(HashMap::new())), - collector_hpke_config: collector_hpke_receiver_config.config, - taskprov_vdaf_verify_key_init, - taskprov_leader_token, - taskprov_collector_token: Some(taskprov_collector_token.clone()), - metrics: DaphneMetrics::register(&prometheus_registry, Some("test_leader")) - .unwrap(), - audit_log: MockAuditLog::default(), - peer: Some(Arc::clone(&helper)), - }); - Self { now, - leader, - helper, + global_config, collector_token, taskprov_collector_token, time_interval_task_id, @@ -416,8 +355,82 @@ mod test { expired_task_id, version, prometheus_registry, + tasks, + leader_token, + taskprov_leader_token, + collector_hpke_receiver_config, + taskprov_vdaf_verify_key_init, + } + } + + pub fn new_helper(&self) -> Arc { + Arc::new(MockAggregator::new_helper( + self.tasks.clone(), + self.global_config + .gen_hpke_receiver_config_list(thread_rng().gen()) + .collect::, _>>() + .expect("failed to generate HPKE receiver config"), + self.global_config.clone(), + self.leader_token.clone(), + self.collector_hpke_receiver_config.config.clone(), + &self.prometheus_registry, + self.taskprov_vdaf_verify_key_init, + self.taskprov_leader_token.clone(), + )) + } + + pub fn with_leader(self, helper: Arc) -> Test { + let leader = Arc::new(MockAggregator::new_leader( + self.tasks, + self.global_config + .gen_hpke_receiver_config_list(thread_rng().gen()) + .collect::, _>>() + .expect("failed to generate HPKE receiver config"), + self.global_config, + self.leader_token, + self.collector_token.clone(), + self.collector_hpke_receiver_config.config.clone(), + &self.prometheus_registry, + self.taskprov_vdaf_verify_key_init, + self.taskprov_leader_token, + self.taskprov_collector_token.clone(), + Arc::clone(&helper), + )); + + Test { + now: self.now, + leader, + helper, + collector_token: self.collector_token, + taskprov_collector_token: 
self.taskprov_collector_token, + time_interval_task_id: self.time_interval_task_id, + fixed_size_task_id: self.fixed_size_task_id, + expired_task_id: self.expired_task_id, + version: self.version, + prometheus_registry: self.prometheus_registry, } } + } + + pub(super) struct Test { + now: Time, + leader: Arc, + helper: Arc, + collector_token: BearerToken, + taskprov_collector_token: BearerToken, + time_interval_task_id: TaskId, + fixed_size_task_id: TaskId, + expired_task_id: TaskId, + version: DapVersion, + prometheus_registry: prometheus::Registry, + } + + impl Test { + pub fn new(version: DapVersion) -> Self { + let data = TestData::new(version); + let helper = data.new_helper(); + data.with_leader(helper) + } async fn gen_test_upload_req( &self, @@ -1879,12 +1892,11 @@ mod test { var: taskprov::VdafTypeVar::Prio3Aes128Count, }, } - .get_encoded_with_param(&t.helper.global_config.taskprov_version); + .get_encoded_with_param(&t.helper.global_config.taskprov_version.unwrap()); let taskprov_id = crate::taskprov::compute_task_id( - t.helper.global_config.taskprov_version, + t.helper.global_config.taskprov_version.unwrap(), &taskprov_ext_payload, - ) - .unwrap(); + ); // Client: Send upload request to Leader. let hpke_config_list = [ diff --git a/daphne/src/taskprov.rs b/daphne/src/taskprov.rs index dedc98c99..0bd2034ff 100644 --- a/daphne/src/taskprov.rs +++ b/daphne/src/taskprov.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause use crate::{ - fatal_error, hpke::HpkeConfig, messages::{ decode_base64url_vec, @@ -27,9 +26,6 @@ use url::Url; pub enum TaskprovVersion { #[serde(rename = "v02")] Draft02, - - #[serde(other)] - Unknown, } /// SHA-256 of "dap-taskprov" @@ -48,12 +44,9 @@ fn compute_task_id_draft02(serialized: &[u8]) -> TaskId { } /// Compute the task id of a serialized task config. -pub fn compute_task_id(version: TaskprovVersion, serialized: &[u8]) -> Result { +pub fn compute_task_id(version: TaskprovVersion, serialized: &[u8]) -> TaskId { match version { - TaskprovVersion::Draft02 => Ok(compute_task_id_draft02(serialized)), - TaskprovVersion::Unknown => Err(fatal_error!( - err = "attempted to resolve taskprov task with unknown version", - )), + TaskprovVersion::Draft02 => compute_task_id_draft02(serialized), } } @@ -70,7 +63,6 @@ pub(crate) fn extract_prk_from_verify_key_init( // time, so we compute it once. let value = match version { TaskprovVersion::Draft02 => &TASK_PROV_SALT_DRAFT02, - _ => panic!("unimplemented taskprov version"), }; Salt::new(HKDF_SHA256, value).extract(verify_key_init) } @@ -202,7 +194,7 @@ fn get_taskprov_task_config( return Ok(None); }; - if compute_task_id(taskprov_version, taskprov_data.as_ref())? != *task_id { + if compute_task_id(taskprov_version, taskprov_data.as_ref()) != *task_id { // Return unrecognizedTask following section 5.1 of the taskprov draft. return Err(DapAbort::UnrecognizedTask); } @@ -302,15 +294,8 @@ impl DapTaskConfig { impl ReportMetadata { /// Does this metatdata have a taskprov extension and does it match the specified id? pub fn is_taskprov(&self, version: TaskprovVersion, task_id: &TaskId) -> bool { - // Don't check for taskprov usage if we don't know the version. 
- if matches!(version, TaskprovVersion::Unknown) { - return false; - } - return self.extensions.iter().any(|x| match x { - Extension::Taskprov { payload } => { - *task_id == compute_task_id(version, payload).unwrap() - } + Extension::Taskprov { payload } => *task_id == compute_task_id(version, payload), _ => false, }); } @@ -393,7 +378,7 @@ mod test { let taskprov_task_config_data = taskprov_task_config.get_encoded_with_param(&taskprov_version); let taskprov_task_config_base64url = encode_base64url(&taskprov_task_config_data); - let task_id = compute_task_id(taskprov_version, &taskprov_task_config_data).unwrap(); + let task_id = compute_task_id(taskprov_version, &taskprov_task_config_data); let collector_hpke_config = HpkeReceiverConfig::gen(1, HpkeKemId::X25519HkdfSha256) .unwrap() .config; diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs index e52ce8ed3..b9cadc0f6 100644 --- a/daphne/src/testing.rs +++ b/daphne/src/testing.rs @@ -41,1475 +41,1534 @@ use std::{ }; use url::Url; -#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub(crate) enum MetaAggregationJobIdOwned { - Draft02(Draft02AggregationJobId), - Draft05(AggregationJobId), -} - -impl From<&MetaAggregationJobId<'_>> for MetaAggregationJobIdOwned { - fn from(agg_job_id: &MetaAggregationJobId<'_>) -> Self { - match agg_job_id { - MetaAggregationJobId::Draft02(agg_job_id) => { - Self::Draft02(agg_job_id.clone().into_owned()) - } - MetaAggregationJobId::Draft05(agg_job_id) => { - Self::Draft05(agg_job_id.clone().into_owned()) - } - } - } -} - -#[derive(Eq, Hash, PartialEq)] -pub(crate) enum DapBatchBucketOwned { - FixedSize { batch_id: BatchId }, - TimeInterval { batch_window: Time }, -} - -impl From for PartialBatchSelector { - fn from(bucket: DapBatchBucketOwned) -> Self { - match bucket { - DapBatchBucketOwned::FixedSize { batch_id } => Self::FixedSizeByBatchId { batch_id }, - DapBatchBucketOwned::TimeInterval { .. } => Self::TimeInterval, - } - } -} - -impl<'a> DapBatchBucket<'a> { - // TODO(cjpatton) Figure out how to use `ToOwned` properly. The lifetime parameter causes - // confusion for the compiler for implementing `Borrow`. The goal is to avoid cloning the - // bucket each time we need to check if it exists in the set. - pub(crate) fn to_owned_bucket(&self) -> DapBatchBucketOwned { - match self { - Self::FixedSize { batch_id } => DapBatchBucketOwned::FixedSize { - batch_id: (*batch_id).clone(), - }, - Self::TimeInterval { batch_window } => DapBatchBucketOwned::TimeInterval { - batch_window: *batch_window, - }, - } - } -} +/// Scaffolding for testing the aggregation flow. 
+pub struct AggregationJobTest { + // task parameters + pub(crate) task_id: TaskId, + pub(crate) task_config: DapTaskConfig, + pub(crate) leader_hpke_receiver_config: HpkeReceiverConfig, + pub(crate) helper_hpke_receiver_config: HpkeReceiverConfig, + pub(crate) client_hpke_config_list: Vec, + pub(crate) collector_hpke_receiver_config: HpkeReceiverConfig, -pub(crate) struct MockAggregatorReportSelector(pub(crate) TaskId); + // aggregation job ID + pub(crate) agg_job_id: MetaAggregationJobId<'static>, -#[derive(Default)] -pub(crate) struct MockAuditLog(AtomicU32); + // the current time + pub(crate) now: Time, -impl MockAuditLog { - #[cfg(test)] - pub(crate) fn invocations(&self) -> u32 { - self.0.load(Ordering::Relaxed) - } + // operational parameters + #[allow(dead_code)] + pub(crate) prometheus_registry: prometheus::Registry, + pub(crate) leader_metrics: DaphneMetrics, + pub(crate) helper_metrics: DaphneMetrics, + pub(crate) leader_reports_processed: Arc>>, + pub(crate) helper_reports_processed: Arc>>, } -impl AuditLog for MockAuditLog { - fn on_aggregation_job( +// NOTE(cjpatton) This implementation of the report initializer is not feature complete. Since +// [`AggrregationJobTest`], is only used to test the aggregation flow, features that are not +// directly relevant to the tests aren't implemented. +#[async_trait(?Send)] +impl DapReportInitializer for AggregationJobTest { + async fn initialize_reports<'req>( &self, - _host: &str, + is_leader: bool, _task_id: &TaskId, - _task_config: &DapTaskConfig, - _report_count: u64, - _action: AggregationJobAuditAction, - ) { - self.0.fetch_add(1, Ordering::Relaxed); - } -} - -pub(crate) struct MockAggregator { - pub(crate) global_config: DapGlobalConfig, - pub(crate) tasks: Arc>>, - pub(crate) hpke_receiver_config_list: Vec, - pub(crate) leader_token: BearerToken, - pub(crate) collector_token: Option, // Not set by Helper - pub(crate) report_store: Arc>>, - pub(crate) leader_state_store: Arc>>, - pub(crate) helper_state_store: Arc>>, - pub(crate) agg_store: Arc>>>, - pub(crate) collector_hpke_config: HpkeConfig, - pub(crate) metrics: DaphneMetrics, - pub(crate) audit_log: MockAuditLog, - - // taskprov - pub(crate) taskprov_vdaf_verify_key_init: [u8; 32], - pub(crate) taskprov_leader_token: BearerToken, - pub(crate) taskprov_collector_token: Option, // Not set by Helper + task_config: &DapTaskConfig, + _part_batch_sel: &PartialBatchSelector, + consumed_reports: Vec>, + ) -> Result>, DapError> { + let mut reports_processed = if is_leader { + self.leader_reports_processed.lock().unwrap() + } else { + self.helper_reports_processed.lock().unwrap() + }; - // Leader: Reference to peer. Used to simulate HTTP requests from Leader to Helper, i.e., - // implement `DapLeader::send_http_post()` for `MockAggregator`. Not set by the Helper. - pub(crate) peer: Option>, + Ok(consumed_reports + .into_iter() + .map(|consumed| { + if reports_processed.contains(&consumed.metadata().id) { + Ok(EarlyReportStateInitialized::Rejected { + metadata: Cow::Owned(consumed.metadata().clone()), + failure: TransitionFailure::ReportReplayed, + }) + } else { + reports_processed.insert(consumed.metadata().id.clone()); + EarlyReportStateInitialized::initialize( + is_leader, + &task_config.vdaf_verify_key, + &task_config.vdaf, + consumed, + ) + } + }) + .collect::, _>>()?) 
+ } } -impl MockAggregator { - /// Conducts checks on a received report to see whether: - /// 1) the report falls into a batch that has been already collected, or - /// 2) the report has been submitted by the client in the past. - async fn check_report_early_fail( - &self, - task_id: &TaskId, - bucket: &DapBatchBucketOwned, - metadata: &ReportMetadata, - ) -> Option { - // Check AggStateStore to see whether the report is part of a batch that has already - // been collected. - let mut guard = self.agg_store.lock().expect("agg_store: failed to lock"); - let agg_store = guard.entry(task_id.clone()).or_default(); - if matches!(agg_store.get(bucket), Some(inner_agg_store) if inner_agg_store.collected) { - return Some(TransitionFailure::BatchCollected); - } +impl AggregationJobTest { + /// Create an aggregation job test with the given VDAF config, HPKE KEM algorithm, DAP protocol + /// version. The KEM algorithm is used to generate an HPKE config for each party. + pub fn new(vdaf: &VdafConfig, kem_id: HpkeKemId, version: DapVersion) -> Self { + let mut rng = thread_rng(); + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + let task_id = TaskId(rng.gen()); + let agg_job_id = MetaAggregationJobId::gen_for_version(&version); + let vdaf_verify_key = vdaf.gen_verify_key(); + let leader_hpke_receiver_config = HpkeReceiverConfig::gen(rng.gen(), kem_id).unwrap(); + let helper_hpke_receiver_config = HpkeReceiverConfig::gen(rng.gen(), kem_id).unwrap(); + let collector_hpke_receiver_config = HpkeReceiverConfig::gen(rng.gen(), kem_id).unwrap(); + let leader_hpke_config = leader_hpke_receiver_config.clone().config; + let helper_hpke_config = helper_hpke_receiver_config.clone().config; + let collector_hpke_config = collector_hpke_receiver_config.clone().config; + let prometheus_registry = prometheus::Registry::new(); + let leader_metrics = + DaphneMetrics::register(&prometheus_registry, Some("test_leader")).unwrap(); + let helper_metrics = + DaphneMetrics::register(&prometheus_registry, Some("test_helper")).unwrap(); - // Check whether the same report has been submitted in the past. - let mut guard = self - .report_store - .lock() - .expect("report_store: failed to lock"); - let report_store = guard.entry(task_id.clone()).or_default(); - if report_store.processed.contains(&metadata.id) { - return Some(TransitionFailure::ReportReplayed); + Self { + now, + task_id, + agg_job_id, + leader_hpke_receiver_config, + helper_hpke_receiver_config, + client_hpke_config_list: vec![leader_hpke_config, helper_hpke_config], + collector_hpke_receiver_config, + task_config: DapTaskConfig { + version, + leader_url: Url::parse("http://leader.com").unwrap(), + helper_url: Url::parse("https://helper.org").unwrap(), + time_precision: 500, + expiration: now + 500, + min_batch_size: 10, + query: DapQueryConfig::TimeInterval, + vdaf: vdaf.clone(), + vdaf_verify_key, + collector_hpke_config, + taskprov: false, + }, + prometheus_registry, + leader_metrics, + helper_metrics, + leader_reports_processed: Default::default(), + helper_reports_processed: Default::default(), } - - None } - fn get_hpke_receiver_config_for(&self, hpke_config_id: u8) -> Option<&HpkeReceiverConfig> { - self.hpke_receiver_config_list - .iter() - .find(|&hpke_receiver_config| hpke_config_id == hpke_receiver_config.config.id) + /// For each measurement, generate a report for the given task. + /// + /// Panics if a measurement is incompatible with the given VDAF. 
+ pub fn produce_reports(&self, measurements: Vec) -> Vec { + let mut reports = Vec::with_capacity(measurements.len()); + + for measurement in measurements.into_iter() { + reports.push( + self.task_config + .vdaf + .produce_report( + &self.client_hpke_config_list, + self.now, + &self.task_id, + measurement, + self.task_config.version, + ) + .unwrap(), + ); + } + reports } - /// Assign the report to a bucket. + /// Leader: Produce AggregationJobInitReq. /// - /// TODO(cjpatton) Figure out if we can avoid returning and owned thing here. - async fn assign_report_to_bucket( + /// Panics if the Leader aborts. + pub async fn produce_agg_job_init_req( &self, - report: &Report, - task_id: &TaskId, - ) -> Option { - let mut rng = thread_rng(); - let task_config = self - .get_task_config_for(Cow::Borrowed(task_id)) + reports: Vec, + ) -> DapLeaderTransition { + let metrics = self + .leader_metrics + .with_host(self.task_config.leader_url.host_str().unwrap()); + self.task_config + .vdaf + .produce_agg_job_init_req( + &self.leader_hpke_receiver_config, + self, + &self.task_id, + &self.task_config, + &self.agg_job_id, + &PartialBatchSelector::TimeInterval, + reports, + &metrics, + ) .await .unwrap() - .expect("tasks: unrecognized task"); + } - match task_config.query { - // For fixed-size queries, the bucket corresponds to a single batch. - DapQueryConfig::FixedSize { .. } => { - let mut guard = self - .leader_state_store - .lock() - .expect("leader_state_store: failed to lock"); - let leader_state_store = guard.entry(task_id.clone()).or_default(); - - // Assign the report to the first unsaturated batch. - for (batch_id, report_count) in leader_state_store.batch_queue.iter_mut() { - if *report_count < task_config.min_batch_size { - *report_count += 1; - return Some(DapBatchBucketOwned::FixedSize { - batch_id: batch_id.clone(), - }); - } - } - - // No unsaturated batch exists, so create a new batch. - let batch_id = BatchId(rng.gen()); - leader_state_store - .batch_queue - .push_back((batch_id.clone(), 1)); - Some(DapBatchBucketOwned::FixedSize { batch_id }) - } - - // For time-interval queries, the bucket is the batch window computed by truncating the - // report timestamp. - DapQueryConfig::TimeInterval => Some(DapBatchBucketOwned::TimeInterval { - batch_window: task_config.quantized_time_lower_bound(report.report_metadata.time), - }), - } - } - - /// Return the ID of the batch currently being filled with reports. Panics unless the task is - /// configured for fixed-size queries. - pub(crate) fn current_batch_id( + /// Helper: Handle AggregationJobInitReq, produce first AggregationJobResp. + /// + /// Panics if the Helper aborts. + pub async fn handle_agg_job_init_req( &self, - task_id: &TaskId, - task_config: &DapTaskConfig, - ) -> Option { - // Calling current_batch() is only well-defined for fixed-size tasks. - assert_matches!(task_config.query, DapQueryConfig::FixedSize { .. 
}); - - let guard = self - .leader_state_store - .lock() - .expect("leader_state_store: failed to lock"); - let leader_state_store = guard - .get(task_id) - .expect("leader_state_store: unrecognized task"); - - leader_state_store - .batch_queue - .front() - .cloned() // TODO(cjpatton) Avoid clone by returning MutexGuard - .map(|(batch_id, _report_count)| batch_id) - } - - pub(crate) async fn unchecked_get_task_config(&self, task_id: &TaskId) -> DapTaskConfig { - self.get_task_config_for(Cow::Borrowed(task_id)) + agg_job_init_req: &AggregationJobInitReq, + ) -> DapHelperTransition { + let metrics = self + .helper_metrics + .with_host(self.task_config.helper_url.host_str().unwrap()); + self.task_config + .vdaf + .handle_agg_job_init_req( + &self.helper_hpke_receiver_config, + self, + &self.task_id, + &self.task_config, + agg_job_init_req, + &metrics, + ) .await - .expect("encountered unexpected error") - .expect("missing task config") + .unwrap() } -} - -#[async_trait(?Send)] -impl BearerTokenProvider for MockAggregator { - type WrappedBearerToken<'a> = &'a BearerToken; - async fn get_leader_bearer_token_for<'s>( - &'s self, - _task_id: &'s TaskId, - task_config: &DapTaskConfig, - ) -> Result>, DapError> { - if task_config.taskprov { - Ok(Some(&self.taskprov_leader_token)) - } else { - Ok(Some(&self.leader_token)) - } + /// Leader: Handle first AggregationJobResp, produce AggregationJobContinueReq. + /// + /// Panics if the Leader aborts. + pub fn handle_agg_job_resp( + &self, + leader_state: DapLeaderState, + agg_job_resp: AggregationJobResp, + ) -> DapLeaderTransition { + let metrics = self + .leader_metrics + .with_host(self.task_config.leader_url.host_str().unwrap()); + self.task_config + .vdaf + .handle_agg_job_resp( + &self.task_id, + &self.agg_job_id, + leader_state, + agg_job_resp, + self.task_config.version, + &metrics, + ) + .unwrap() } - async fn get_collector_bearer_token_for<'s>( - &'s self, - _task_id: &'s TaskId, - task_config: &DapTaskConfig, - ) -> Result>, DapError> { - if task_config.taskprov { - Ok(Some(self.taskprov_collector_token.as_ref().expect( - "MockAggregator not configured with taskprov collector token", - ))) - } else { - Ok(Some(self.collector_token.as_ref().expect( - "MockAggregator not configured with collector token", - ))) - } + /// Like [`handle_agg_job_resp`] but expect the Leader to abort. + pub fn handle_agg_job_resp_expect_err( + &self, + leader_state: DapLeaderState, + agg_job_resp: AggregationJobResp, + ) -> DapAbort { + let metrics = self + .leader_metrics + .with_host(self.task_config.leader_url.host_str().unwrap()); + self.task_config + .vdaf + .handle_agg_job_resp( + &self.task_id, + &self.agg_job_id, + leader_state, + agg_job_resp, + self.task_config.version, + &metrics, + ) + .expect_err("handle_agg_job_resp() succeeded; expected failure") } -} - -#[async_trait(?Send)] -impl HpkeDecrypter for MockAggregator { - type WrappedHpkeConfig<'a> = &'a HpkeConfig; - - async fn get_hpke_config_for<'s>( - &'s self, - _version: DapVersion, - task_id: Option<&TaskId>, - ) -> Result, DapError> { - if self.hpke_receiver_config_list.is_empty() { - return Err(fatal_error!(err = "empty HPKE receiver config list")); - } - - // Aggregators MAY abort if the HPKE config request does not specify a task ID. While not - // required for MockAggregator, we simulate this behavior for testing purposes. - // - // TODO(cjpatton) To make this clearer, have MockAggregator store a map from task IDs to - // HPKE receiver configs. 
- if task_id.is_none() { - return Err(DapError::Abort(DapAbort::MissingTaskId)); - } - // Always advertise the first HPKE config in the list. - Ok(&self.hpke_receiver_config_list[0].config) + /// Helper: Handle AggregationJobContinueReq, produce second AggregationJobResp. + /// + /// Panics if the Helper aborts. + pub fn handle_agg_job_cont_req( + &self, + helper_state: DapHelperState, + agg_job_cont_req: &AggregationJobContinueReq, + ) -> DapHelperTransition { + let metrics = self + .helper_metrics + .with_host(self.task_config.helper_url.host_str().unwrap()); + self.task_config + .vdaf + .handle_agg_job_cont_req( + &self.task_id, + &self.agg_job_id, + helper_state, + agg_job_cont_req, + &metrics, + ) + .unwrap() } - async fn can_hpke_decrypt(&self, _task_id: &TaskId, config_id: u8) -> Result { - Ok(self.get_hpke_receiver_config_for(config_id).is_some()) + /// Like [`handle_agg_job_cont_req`] but expect the Helper to abort. + pub fn handle_agg_job_cont_req_expect_err( + &self, + helper_state: DapHelperState, + agg_job_cont_req: &AggregationJobContinueReq, + ) -> DapAbort { + let metrics = self + .helper_metrics + .with_host(self.task_config.helper_url.host_str().unwrap()); + self.task_config + .vdaf + .handle_agg_job_cont_req( + &self.task_id, + &self.agg_job_id, + helper_state, + agg_job_cont_req, + &metrics, + ) + .expect_err("handle_agg_job_cont_req() succeeded; expected failure") } - async fn hpke_decrypt( + /// Leader: Handle the last AggregationJobResp. + /// + /// Panics if the Leader aborts. + pub fn handle_final_agg_job_resp( &self, - _task_id: &TaskId, - info: &[u8], - aad: &[u8], - ciphertext: &HpkeCiphertext, - ) -> Result, DapError> { - if let Some(hpke_receiver_config) = self.get_hpke_receiver_config_for(ciphertext.config_id) - { - Ok(hpke_receiver_config.decrypt(info, aad, &ciphertext.enc, &ciphertext.payload)?) - } else { - Err(DapError::Transition(TransitionFailure::HpkeUnknownConfigId)) - } + leader_uncommitted: DapLeaderUncommitted, + agg_job_resp: AggregationJobResp, + ) -> Vec { + let metrics = self + .leader_metrics + .with_host(self.task_config.leader_url.host_str().unwrap()); + self.task_config + .vdaf + .handle_final_agg_job_resp(leader_uncommitted, agg_job_resp, &metrics) + .unwrap() } -} -#[async_trait(?Send)] -impl DapAuthorizedSender for MockAggregator { - async fn authorize( + /// Produce the Leader's encrypted aggregate share. + pub fn produce_leader_encrypted_agg_share( &self, - task_id: &TaskId, - task_config: &DapTaskConfig, - media_type: &DapMediaType, - _payload: &[u8], - ) -> Result { - Ok(self - .authorize_with_bearer_token(task_id, task_config, media_type) - .await? - .clone()) + batch_selector: &BatchSelector, + agg_share: &DapAggregateShare, + ) -> HpkeCiphertext { + self.task_config + .vdaf + .produce_leader_encrypted_agg_share( + &self.task_config.collector_hpke_config, + &self.task_id, + batch_selector, + agg_share, + self.task_config.version, + ) + .unwrap() } -} -#[async_trait(?Send)] -impl DapReportInitializer for MockAggregator { - async fn initialize_reports<'req>( + /// Produce the Helper's encrypted aggregate share. 
+ pub fn produce_helper_encrypted_agg_share( &self, - is_leader: bool, - task_id: &TaskId, - task_config: &DapTaskConfig, - part_batch_sel: &PartialBatchSelector, - consumed_reports: Vec>, - ) -> Result>, DapError> { - let span = task_config.batch_span_for_meta( - part_batch_sel, - consumed_reports.iter().filter(|report| report.is_ready()), - )?; - - let mut early_fails = HashMap::new(); - for (bucket, reports_consumed_per_bucket) in span.iter() { - for metadata in reports_consumed_per_bucket - .iter() - .map(|report| report.metadata()) - { - // Check whether Report has been collected or replayed. - if let Some(transition_failure) = self - .check_report_early_fail(task_id, &bucket.to_owned_bucket(), metadata) - .await - { - early_fails.insert(metadata.id.clone(), transition_failure); - }; - } - } - - Ok(consumed_reports - .into_iter() - .map(|consumed| { - if let Some(failure) = early_fails.get(&consumed.metadata().id) { - Ok(EarlyReportStateInitialized::Rejected { - metadata: Cow::Owned(consumed.metadata().clone()), - failure: *failure, - }) - } else { - EarlyReportStateInitialized::initialize( - is_leader, - &task_config.vdaf_verify_key, - &task_config.vdaf, - consumed, - ) - } - }) - .collect::, _>>()?) + batch_selector: &BatchSelector, + agg_share: &DapAggregateShare, + ) -> HpkeCiphertext { + self.task_config + .vdaf + .produce_helper_encrypted_agg_share( + &self.task_config.collector_hpke_config, + &self.task_id, + batch_selector, + agg_share, + self.task_config.version, + ) + .unwrap() } -} - -#[async_trait(?Send)] -impl DapAggregator for MockAggregator { - // The lifetimes on the traits ensure that we can return a reference to a task config stored by - // the DapAggregator. (See DaphneWorkerConfig for an example.) For simplicity, MockAggregator - // clones the task config as needed. - type WrappedDapTaskConfig<'a> = DapTaskConfig; - async fn unauthorized_reason( + /// Collector: Consume the aggregate shares. + pub async fn consume_encrypted_agg_shares( &self, - task_config: &DapTaskConfig, - req: &DapRequest, - ) -> Result, DapError> { - self.bearer_token_authorized(task_config, req).await + batch_selector: &BatchSelector, + report_count: u64, + enc_agg_shares: Vec, + ) -> DapAggregateResult { + self.task_config + .vdaf + .consume_encrypted_agg_shares( + &self.collector_hpke_receiver_config, + &self.task_id, + batch_selector, + report_count, + enc_agg_shares, + self.task_config.version, + ) + .await + .unwrap() } - fn get_global_config(&self) -> &DapGlobalConfig { - &self.global_config - } + /// Generate a set of reports, aggregate them, and unshard the result. 
+ pub async fn roundtrip(&mut self, measurements: Vec) -> DapAggregateResult { + let batch_selector = BatchSelector::TimeInterval { + batch_interval: Interval { + start: self.now, + duration: 3600, + }, + }; - fn taskprov_vdaf_verify_key_init(&self) -> Option<&[u8; 32]> { - Some(&self.taskprov_vdaf_verify_key_init) - } + // Clients: Shard + let reports = self.produce_reports(measurements); - fn taskprov_collector_hpke_config(&self) -> Option<&HpkeConfig> { - Some(&self.collector_hpke_config) - } + // Aggregators: Preparation + let DapLeaderTransition::Continue(leader_state, agg_job_init_req) = + self.produce_agg_job_init_req(reports).await + else { + panic!("unexpected transition"); + }; + let DapHelperTransition::Continue(helper_state, agg_job_resp) = + self.handle_agg_job_init_req(&agg_job_init_req).await + else { + panic!("unexpected transition"); + }; + let got = DapHelperState::get_decoded(&self.task_config.vdaf, &helper_state.get_encoded()) + .expect("failed to decode helper state"); + assert_eq!(got, helper_state); - fn taskprov_opt_out_reason( - &self, - _task_config: &DapTaskConfig, - ) -> Result, DapError> { - // Always opt-in. - Ok(None) - } + let DapLeaderTransition::Uncommitted(uncommitted, agg_cont) = + self.handle_agg_job_resp(leader_state, agg_job_resp) + else { + panic!("unexpected transition"); + }; + let DapHelperTransition::Finish(helper_out_shares, agg_job_resp) = + self.handle_agg_job_cont_req(helper_state, &agg_cont) + else { + panic!("unexpected transition"); + }; + let leader_out_shares = self.handle_final_agg_job_resp(uncommitted, agg_job_resp); + let report_count = u64::try_from(leader_out_shares.len()).unwrap(); - async fn taskprov_put( - &self, - req: &DapRequest, - task_config: DapTaskConfig, - ) -> Result<(), DapError> { - let task_id = req.task_id().map_err(DapError::Abort)?; - let mut tasks = self.tasks.lock().expect("tasks: lock failed"); - tasks.deref_mut().insert(task_id.clone(), task_config); - Ok(()) - } + // Leader: Aggregation + let leader_agg_share = DapAggregateShare::try_from_out_shares(leader_out_shares).unwrap(); + let leader_encrypted_agg_share = + self.produce_leader_encrypted_agg_share(&batch_selector, &leader_agg_share); - async fn get_task_config_for<'s>( - &self, - task_id: Cow<'s, TaskId>, - ) -> Result>, DapError> { - let tasks = self.tasks.lock().expect("tasks: lock failed"); - Ok(tasks.get(task_id.as_ref()).cloned()) - } + // Helper: Aggregation + let helper_agg_share = DapAggregateShare::try_from_out_shares(helper_out_shares).unwrap(); + let helper_encrypted_agg_share = + self.produce_helper_encrypted_agg_share(&batch_selector, &helper_agg_share); - fn get_current_time(&self) -> Time { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs() + // Collector: Unshard + self.consume_encrypted_agg_shares( + &batch_selector, + report_count, + vec![leader_encrypted_agg_share, helper_encrypted_agg_share], + ) + .await } +} - async fn is_batch_overlapping( - &self, - task_id: &TaskId, - batch_sel: &BatchSelector, - ) -> Result { - let task_config = self - .get_task_config_for(Cow::Borrowed(task_id)) - .await - .unwrap() - .expect("tasks: unrecognized task"); - let guard = self.agg_store.lock().expect("agg_store: failed to lock"); - let agg_store = if let Some(agg_store) = guard.get(task_id) { - agg_store - } else { - return Ok(false); - }; - - for bucket in task_config.batch_span_for_sel(batch_sel)? 
{ - if let Some(inner_agg_store) = agg_store.get(&bucket.to_owned_bucket()) { - if inner_agg_store.collected { - return Ok(true); - } +// These are declarative macros which let us generate a test point for +// each DapVersion given a test which takes a version parameter. +// +// E.g. currently +// +// async_test_versions! { something } +// +// would generate async tests named +// +// something_draft02 +// +// and +// +// something_draft05 +// +// that called something(version) with the appropriate version. +// +// We use the "paste" crate to get a macro that can paste tokens and also +// fiddle case. +#[macro_export] +macro_rules! test_version { + ($fname:ident, $version:ident) => { + ::paste::paste! { + #[test] + fn [<$fname _ $version:lower>]() { + $fname ($crate::DapVersion::$version); } } + }; +} - Ok(false) - } - - async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result { - let guard = self.agg_store.lock().expect("agg_store: failed to lock"); - if let Some(agg_store) = guard.get(task_id) { - Ok(agg_store - .get(&DapBatchBucketOwned::FixedSize { - batch_id: batch_id.clone(), - }) - .is_some()) - } else { - Ok(false) - } - } - - async fn put_out_shares( - &self, - task_id: &TaskId, - task_config: &DapTaskConfig, - part_batch_sel: &PartialBatchSelector, - out_shares: Vec, - ) -> Result, DapError> { - let mut report_store_guard = self - .report_store - .lock() - .expect("report_store: failed to lock"); - let report_store = report_store_guard.entry(task_id.clone()).or_default(); - let mut agg_store_guard = self.agg_store.lock().expect("agg_store: failed to lock"); - let agg_store = agg_store_guard.entry(task_id.clone()).or_default(); - - let mut replayed = HashSet::new(); - for (bucket, out_shares) in task_config - .batch_span_for_out_shares(part_batch_sel, out_shares)? - .into_iter() - { - for out_share in out_shares.into_iter() { - if !report_store.processed.contains(&out_share.report_id) { - // Mark report processed. - report_store.processed.insert(out_share.report_id.clone()); +#[macro_export] +macro_rules! test_versions { + ($($fname:ident),*) => { + $( + $crate::test_version! { $fname, Draft02 } + $crate::test_version! { $fname, Draft05 } + )* + }; +} - // Add to aggregate share. - agg_store - .entry(bucket.to_owned_bucket()) - .or_default() - .agg_share - .merge(DapAggregateShare::try_from_out_shares([out_share])?)?; - } else { - replayed.insert(out_share.report_id); - } +#[macro_export] +macro_rules! async_test_version { + ($fname:ident, $version:ident) => { + ::paste::paste! { + #[tokio::test] + async fn [<$fname _ $version:lower>]() { + $fname ($crate::DapVersion::$version) . await; } } + }; +} - Ok(replayed) - } +#[macro_export] +macro_rules! async_test_versions { + ($($fname:ident),*) => { + $( + $crate::async_test_version! { $fname, Draft02 } + $crate::async_test_version! { $fname, Draft05 } + )* + }; +} - async fn get_agg_share( - &self, - task_id: &TaskId, - batch_sel: &BatchSelector, - ) -> Result { - let task_config = self - .get_task_config_for(Cow::Borrowed(task_id)) - .await - .unwrap() - .expect("tasks: unrecognized task"); - let mut guard = self.agg_store.lock().expect("agg_store: failed to lock"); - let agg_store = guard.entry(task_id.clone()).or_default(); +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub(crate) enum MetaAggregationJobIdOwned { + Draft02(Draft02AggregationJobId), + Draft05(AggregationJobId), +} - // Fetch aggregate shares. 
- let mut agg_share = DapAggregateShare::default(); - for bucket in task_config.batch_span_for_sel(batch_sel)? { - if let Some(inner_agg_store) = agg_store.get(&bucket.to_owned_bucket()) { - if inner_agg_store.collected { - return Err(DapError::Abort(DapAbort::batch_overlap(task_id, batch_sel))); - } else { - agg_share.merge(inner_agg_store.agg_share.clone())?; - } +impl From<&MetaAggregationJobId<'_>> for MetaAggregationJobIdOwned { + fn from(agg_job_id: &MetaAggregationJobId<'_>) -> Self { + match agg_job_id { + MetaAggregationJobId::Draft02(agg_job_id) => { + Self::Draft02(agg_job_id.clone().into_owned()) + } + MetaAggregationJobId::Draft05(agg_job_id) => { + Self::Draft05(agg_job_id.clone().into_owned()) } } - - Ok(agg_share) } +} - async fn mark_collected( - &self, - task_id: &TaskId, - batch_sel: &BatchSelector, - ) -> Result<(), DapError> { - let task_config = self.unchecked_get_task_config(task_id).await; - let mut guard = self.agg_store.lock().expect("agg_store: failed to lock"); - let agg_store = guard.entry(task_id.clone()).or_default(); +#[derive(Eq, Hash, PartialEq)] +pub enum DapBatchBucketOwned { + FixedSize { batch_id: BatchId }, + TimeInterval { batch_window: Time }, +} - for bucket in task_config.batch_span_for_sel(batch_sel)? { - if let Some(inner_agg_store) = agg_store.get_mut(&bucket.to_owned_bucket()) { - inner_agg_store.collected = true; - } +impl From for PartialBatchSelector { + fn from(bucket: DapBatchBucketOwned) -> Self { + match bucket { + DapBatchBucketOwned::FixedSize { batch_id } => Self::FixedSizeByBatchId { batch_id }, + DapBatchBucketOwned::TimeInterval { .. } => Self::TimeInterval, } - - Ok(()) } +} - async fn current_batch(&self, task_id: &TaskId) -> std::result::Result { - let task_config = self.unchecked_get_task_config(task_id).await; - if let Some(id) = self.current_batch_id(task_id, &task_config) { - Ok(id) - } else { - Err(DapError::Abort(DapAbort::BadRequest( - "unknown version".to_string(), - ))) +impl<'a> DapBatchBucket<'a> { + // TODO(cjpatton) Figure out how to use `ToOwned` properly. The lifetime parameter causes + // confusion for the compiler for implementing `Borrow`. The goal is to avoid cloning the + // bucket each time we need to check if it exists in the set. 
+ pub(crate) fn to_owned_bucket(&self) -> DapBatchBucketOwned { + match self { + Self::FixedSize { batch_id } => DapBatchBucketOwned::FixedSize { + batch_id: (*batch_id).clone(), + }, + Self::TimeInterval { batch_window } => DapBatchBucketOwned::TimeInterval { + batch_window: *batch_window, + }, } } +} - fn metrics(&self) -> &DaphneMetrics { - &self.metrics - } +pub struct MockAggregatorReportSelector(pub(crate) TaskId); - fn audit_log(&self) -> &dyn AuditLog { - &self.audit_log +#[derive(Default)] +pub struct MockAuditLog(AtomicU32); + +impl MockAuditLog { + #[allow(dead_code)] + pub(crate) fn invocations(&self) -> u32 { + self.0.load(Ordering::Relaxed) } } -#[async_trait(?Send)] -impl DapHelper for MockAggregator { - async fn put_helper_state_if_not_exists( +impl AuditLog for MockAuditLog { + fn on_aggregation_job( &self, - task_id: &TaskId, - agg_job_id: &MetaAggregationJobId, - helper_state: &DapHelperState, - ) -> Result { - let helper_state_info = HelperStateInfo { - task_id: task_id.clone(), - agg_job_id_owned: agg_job_id.into(), - }; + _host: &str, + _task_id: &TaskId, + _task_config: &DapTaskConfig, + _report_count: u64, + _action: AggregationJobAuditAction, + ) { + self.0.fetch_add(1, Ordering::Relaxed); + } +} - let mut helper_state_store_mutex_guard = self - .helper_state_store - .lock() - .map_err(|e| fatal_error!(err = ?e))?; +pub struct MockAggregator { + pub global_config: DapGlobalConfig, + pub tasks: Arc>>, + pub hpke_receiver_config_list: Vec, + pub leader_token: BearerToken, + pub collector_token: Option, // Not set by Helper + pub report_store: Arc>>, + pub leader_state_store: Arc>>, + pub helper_state_store: Arc>>, + pub agg_store: Arc>>>, + pub collector_hpke_config: HpkeConfig, + pub metrics: DaphneMetrics, + pub audit_log: MockAuditLog, - let helper_state_store = helper_state_store_mutex_guard.deref_mut(); + // taskprov + pub taskprov_vdaf_verify_key_init: [u8; 32], + pub taskprov_leader_token: BearerToken, + pub taskprov_collector_token: Option, // Not set by Helper - if helper_state_store.contains_key(&helper_state_info) { - return Ok(false); - } + // Leader: Reference to peer. Used to simulate HTTP requests from Leader to Helper, i.e., + // implement `DapLeader::send_http_post()` for `MockAggregator`. Not set by the Helper. + pub peer: Option>, +} - // NOTE: This code is only correct for VDAFs with exactly one round of preparation. - // For VDAFs with more rounds, the helper state blob will need to be updated here. 
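// A small illustration of the owned-bucket conversions defined above (sketch only;
// the batch ID and batch window values are arbitrary, and `rand::prelude::*` is
// assumed to be in scope as elsewhere in this module).
#[test]
fn bucket_to_partial_batch_selector_sketch() {
    let mut rng = thread_rng();

    let fixed: PartialBatchSelector = DapBatchBucketOwned::FixedSize {
        batch_id: BatchId(rng.gen()),
    }
    .into();
    assert!(matches!(fixed, PartialBatchSelector::FixedSizeByBatchId { .. }));

    // The batch window is dropped: the time-interval selector carries no parameters.
    let time: PartialBatchSelector = DapBatchBucketOwned::TimeInterval { batch_window: 0 }.into();
    assert!(matches!(time, PartialBatchSelector::TimeInterval));
}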
- helper_state_store.insert(helper_state_info, helper_state.clone()); +impl MockAggregator { + #[allow(clippy::too_many_arguments)] + pub fn new_helper( + tasks: impl IntoIterator, + hpke_receiver_config_list: impl IntoIterator, + global_config: DapGlobalConfig, + leader_token: BearerToken, + collector_hpke_config: HpkeConfig, + registry: &prometheus::Registry, + taskprov_vdaf_verify_key_init: [u8; 32], + taskprov_leader_token: BearerToken, + ) -> Self { + Self { + global_config, + tasks: Arc::new(Mutex::new(tasks.into_iter().collect())), + hpke_receiver_config_list: hpke_receiver_config_list.into_iter().collect(), + leader_token, + collector_token: None, + report_store: Default::default(), + leader_state_store: Default::default(), + helper_state_store: Default::default(), + agg_store: Default::default(), + collector_hpke_config, + metrics: DaphneMetrics::register(registry, Some("test_helper")).unwrap(), + audit_log: MockAuditLog::default(), + taskprov_vdaf_verify_key_init, + taskprov_leader_token, + taskprov_collector_token: None, + peer: None, + } + } - Ok(true) + #[allow(clippy::too_many_arguments)] + pub fn new_leader( + tasks: impl IntoIterator, + hpke_receiver_config_list: impl IntoIterator, + global_config: DapGlobalConfig, + leader_token: BearerToken, + collector_token: impl Into>, + collector_hpke_config: HpkeConfig, + registry: &prometheus::Registry, + taskprov_vdaf_verify_key_init: [u8; 32], + taskprov_leader_token: BearerToken, + taskprov_collector_token: impl Into>, + peer: impl Into>>, + ) -> Self { + Self { + global_config, + tasks: Arc::new(Mutex::new(tasks.into_iter().collect())), + hpke_receiver_config_list: hpke_receiver_config_list.into_iter().collect(), + leader_token, + collector_token: collector_token.into(), + report_store: Default::default(), + leader_state_store: Default::default(), + helper_state_store: Default::default(), + agg_store: Default::default(), + collector_hpke_config, + metrics: DaphneMetrics::register(registry, Some("test_leader")).unwrap(), + audit_log: MockAuditLog::default(), + taskprov_vdaf_verify_key_init, + taskprov_leader_token, + taskprov_collector_token: taskprov_collector_token.into(), + peer: peer.into(), + } } - async fn get_helper_state( + /// Conducts checks on a received report to see whether: + /// 1) the report falls into a batch that has been already collected, or + /// 2) the report has been submitted by the client in the past. + async fn check_report_early_fail( &self, task_id: &TaskId, - agg_job_id: &MetaAggregationJobId, - ) -> Result, DapError> { - let helper_state_info = HelperStateInfo { - task_id: task_id.clone(), - agg_job_id_owned: agg_job_id.into(), - }; + bucket: &DapBatchBucketOwned, + metadata: &ReportMetadata, + ) -> Option { + // Check AggStateStore to see whether the report is part of a batch that has already + // been collected. + let mut guard = self.agg_store.lock().expect("agg_store: failed to lock"); + let agg_store = guard.entry(task_id.clone()).or_default(); + if matches!(agg_store.get(bucket), Some(inner_agg_store) if inner_agg_store.collected) { + return Some(TransitionFailure::BatchCollected); + } - let mut helper_state_store_mutex_guard = self - .helper_state_store + // Check whether the same report has been submitted in the past. + let mut guard = self + .report_store .lock() - .map_err(|e| fatal_error!(err = ?e))?; - - let helper_state_store = helper_state_store_mutex_guard.deref_mut(); - - // NOTE: This code is only correct for VDAFs with exactly one round of preparation. 
- // For VDAFs with more rounds, the helper state blob will need to be updated here. - if helper_state_store.contains_key(&helper_state_info) { - let helper_state = helper_state_store.remove(&helper_state_info); - - return Ok(helper_state); + .expect("report_store: failed to lock"); + let report_store = guard.entry(task_id.clone()).or_default(); + if report_store.processed.contains(&metadata.id) { + return Some(TransitionFailure::ReportReplayed); } - Ok(None) + None } -} -#[async_trait(?Send)] -impl DapLeader for MockAggregator { - type ReportSelector = MockAggregatorReportSelector; + fn get_hpke_receiver_config_for(&self, hpke_config_id: u8) -> Option<&HpkeReceiverConfig> { + self.hpke_receiver_config_list + .iter() + .find(|&hpke_receiver_config| hpke_config_id == hpke_receiver_config.config.id) + } - async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> { - let bucket = self - .assign_report_to_bucket(report, task_id) - .await - .expect("could not determine batch for report"); - - // Check whether Report has been collected or replayed. - if let Some(transition_failure) = self - .check_report_early_fail(task_id, &bucket, &report.report_metadata) - .await - { - return Err(DapError::Transition(transition_failure)); - }; - - // Store Report for future processing. - let mut guard = self - .report_store - .lock() - .expect("report_store: failed to lock"); - let queue = guard - .get_mut(task_id) - .expect("report_store: unrecognized task") - .pending - .entry(bucket) - .or_default(); - queue.push_back(report.clone()); - Ok(()) - } - - async fn get_reports( + /// Assign the report to a bucket. + /// + /// TODO(cjpatton) Figure out if we can avoid returning and owned thing here. + async fn assign_report_to_bucket( &self, - report_sel: &MockAggregatorReportSelector, - ) -> Result>>, DapError> { - let task_id = &report_sel.0; - let task_config = self.unchecked_get_task_config(task_id).await; - let mut guard = self - .report_store - .lock() - .expect("report_store: failed to lock"); - let report_store = guard.entry(task_id.clone()).or_default(); + report: &Report, + task_id: &TaskId, + ) -> Option { + let mut rng = thread_rng(); + let task_config = self + .get_task_config_for(Cow::Borrowed(task_id)) + .await + .unwrap() + .expect("tasks: unrecognized task"); - // For the task indicated by the report selector, choose a single report to aggregate. match task_config.query { - DapQueryConfig::TimeInterval { .. } => { - // Aggregate reports in any order. - let mut reports = Vec::new(); - for (_bucket, queue) in report_store.pending.iter_mut() { - if !queue.is_empty() { - reports.append(&mut queue.drain(..1).collect()); - break; - } - } - return Ok(HashMap::from([( - task_id.clone(), - HashMap::from([(PartialBatchSelector::TimeInterval, reports)]), - )])); - } + // For fixed-size queries, the bucket corresponds to a single batch. DapQueryConfig::FixedSize { .. } => { - // Drain the batch that is being filled. + let mut guard = self + .leader_state_store + .lock() + .expect("leader_state_store: failed to lock"); + let leader_state_store = guard.entry(task_id.clone()).or_default(); - let bucket = if let Some(batch_id) = self.current_batch_id(task_id, &task_config) { - DapBatchBucketOwned::FixedSize { batch_id } - } else { - return Ok(HashMap::default()); - }; + // Assign the report to the first unsaturated batch. 
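+                // Each batch_queue entry is (batch ID, number of reports assigned so far).
+                // New batches are pushed to the back of the queue, so this scan fills
+                // batches in creation order until each one reaches min_batch_size.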
+ for (batch_id, report_count) in leader_state_store.batch_queue.iter_mut() { + if *report_count < task_config.min_batch_size { + *report_count += 1; + return Some(DapBatchBucketOwned::FixedSize { + batch_id: batch_id.clone(), + }); + } + } - let queue = report_store - .pending - .get_mut(&bucket) - .expect("report_store: unknown bucket"); - let reports = queue.drain(..1).collect(); - return Ok(HashMap::from([( - task_id.clone(), - HashMap::from([(bucket.into(), reports)]), - )])); + // No unsaturated batch exists, so create a new batch. + let batch_id = BatchId(rng.gen()); + leader_state_store + .batch_queue + .push_back((batch_id.clone(), 1)); + Some(DapBatchBucketOwned::FixedSize { batch_id }) } + + // For time-interval queries, the bucket is the batch window computed by truncating the + // report timestamp. + DapQueryConfig::TimeInterval => Some(DapBatchBucketOwned::TimeInterval { + batch_window: task_config.quantized_time_lower_bound(report.report_metadata.time), + }), } } - // Called after receiving a CollectReq from Collector. - async fn init_collect_job( + /// Return the ID of the batch currently being filled with reports. Panics unless the task is + /// configured for fixed-size queries. + pub(crate) fn current_batch_id( &self, task_id: &TaskId, - collect_job_id: &Option, - collect_req: &CollectionReq, - ) -> Result { - let mut rng = thread_rng(); - let task_config = self - .get_task_config_for(Cow::Borrowed(task_id)) - .await? - .ok_or_else(|| fatal_error!(err = "task not found"))?; + task_config: &DapTaskConfig, + ) -> Option { + // Calling current_batch() is only well-defined for fixed-size tasks. + assert_matches!(task_config.query, DapQueryConfig::FixedSize { .. }); - let mut leader_state_store_mutex_guard = self + let guard = self .leader_state_store .lock() - .map_err(|e| fatal_error!(err = ?e))?; - let leader_state_store = leader_state_store_mutex_guard.deref_mut(); - - // Construct a new Collect URI for this CollectReq. - let collect_id = collect_job_id - .as_ref() - .map_or_else(|| CollectionJobId(rng.gen()), |cid| cid.clone()); - let collect_uri = task_config - .leader_url - .join(&format!( - "collect/task/{}/req/{}", - task_id.to_base64url(), - collect_id.to_base64url(), - )) - .map_err(|e| fatal_error!(err = ?e))?; + .expect("leader_state_store: failed to lock"); + let leader_state_store = guard + .get(task_id) + .expect("leader_state_store: unrecognized task"); - // Store Collect ID and CollectReq into LeaderState. - let leader_state = leader_state_store.entry(task_id.clone()).or_default(); - leader_state.collect_ids.push_back(collect_id.clone()); - let collect_job_state = CollectJobState::Pending(collect_req.clone()); - leader_state - .collect_jobs - .insert(collect_id, collect_job_state); + leader_state_store + .batch_queue + .front() + .cloned() // TODO(cjpatton) Avoid clone by returning MutexGuard + .map(|(batch_id, _report_count)| batch_id) + } - Ok(collect_uri) + pub(crate) async fn unchecked_get_task_config(&self, task_id: &TaskId) -> DapTaskConfig { + self.get_task_config_for(Cow::Borrowed(task_id)) + .await + .expect("encountered unexpected error") + .expect("missing task config") } +} - // Called to retrieve completed CollectResp at the request of Collector. 
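// A rough sketch of the batch-window computation used by assign_report_to_bucket for
// time-interval queries, assuming quantized_time_lower_bound truncates the report
// timestamp down to a multiple of the task's time_precision (the real helper is
// defined elsewhere in this crate).
fn batch_window_sketch(report_time: u64, time_precision: u64) -> u64 {
    report_time - (report_time % time_precision)
}

#[test]
fn batch_window_sketch_example() {
    // With a 500-second precision, a report at time 1_637_359_231 falls into the
    // window starting at 1_637_359_000.
    assert_eq!(batch_window_sketch(1_637_359_231, 500), 1_637_359_000);
}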
- async fn poll_collect_job( - &self, - task_id: &TaskId, - collect_id: &CollectionJobId, - ) -> Result { - let mut leader_state_store_mutex_guard = self - .leader_state_store - .lock() - .map_err(|e| fatal_error!(err = ?e))?; - let leader_state_store = leader_state_store_mutex_guard.deref_mut(); +#[async_trait(?Send)] +impl BearerTokenProvider for MockAggregator { + type WrappedBearerToken<'a> = &'a BearerToken; - let leader_state = leader_state_store - .get(task_id) - .ok_or_else(|| fatal_error!(err = "collect job not found for task_id", %task_id))?; - if let Some(collect_job_state) = leader_state.collect_jobs.get(collect_id) { - match collect_job_state { - CollectJobState::Pending(_) => Ok(DapCollectJob::Pending), - CollectJobState::Processed(resp) => Ok(DapCollectJob::Done(resp.clone())), - } + async fn get_leader_bearer_token_for<'s>( + &'s self, + _task_id: &'s TaskId, + task_config: &DapTaskConfig, + ) -> Result>, DapError> { + if task_config.taskprov { + Ok(Some(&self.taskprov_leader_token)) } else { - Ok(DapCollectJob::Unknown) + Ok(Some(&self.leader_token)) } } - // Called to retrieve pending CollectReq. - async fn get_pending_collect_jobs( - &self, - ) -> Result, DapError> { - let mut leader_state_store_mutex_guard = self - .leader_state_store - .lock() - .map_err(|e| fatal_error!(err = ?e))?; - let leader_state_store = leader_state_store_mutex_guard.deref_mut(); - - let mut res = Vec::new(); - for (task_id, leader_state) in leader_state_store.iter() { - // Iterate over collect IDs and copy them and their associated requests to the response. - for collect_id in leader_state.collect_ids.iter() { - if let CollectJobState::Pending(collect_req) = - leader_state.collect_jobs.get(collect_id).unwrap() - { - res.push((task_id.clone(), collect_id.clone(), collect_req.clone())); - } - } + async fn get_collector_bearer_token_for<'s>( + &'s self, + _task_id: &'s TaskId, + task_config: &DapTaskConfig, + ) -> Result>, DapError> { + if task_config.taskprov { + Ok(Some(self.taskprov_collector_token.as_ref().expect( + "MockAggregator not configured with taskprov collector token", + ))) + } else { + Ok(Some(self.collector_token.as_ref().expect( + "MockAggregator not configured with collector token", + ))) } - Ok(res) } +} - async fn finish_collect_job( - &self, - task_id: &TaskId, - collect_id: &CollectionJobId, - collect_resp: &Collection, - ) -> Result<(), DapError> { - let mut leader_state_store_mutex_guard = self - .leader_state_store - .lock() - .map_err(|e| fatal_error!(err = ?e))?; - let leader_state_store = leader_state_store_mutex_guard.deref_mut(); - - let leader_state = leader_state_store - .get_mut(task_id) - .ok_or_else(|| fatal_error!(err = "collect job not found for task_id", %task_id))?; - let collect_job = leader_state - .collect_jobs - .get_mut(collect_id) - .ok_or_else(|| fatal_error!(err = "collect job not found for collect_id", %task_id))?; +#[async_trait(?Send)] +impl HpkeDecrypter for MockAggregator { + type WrappedHpkeConfig<'a> = &'a HpkeConfig; - // Remove the batch from the batch queue. 
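// Standalone sketch of the selection rule implemented by the BearerTokenProvider
// methods above: taskprov-provisioned tasks are authorized with the dedicated
// taskprov bearer token, all other tasks with the statically configured one.
// The function and its string tokens are hypothetical stand-ins.
fn select_bearer_token<'a>(task_uses_taskprov: bool, configured: &'a str, taskprov: &'a str) -> &'a str {
    if task_uses_taskprov {
        taskprov
    } else {
        configured
    }
}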
- if let PartialBatchSelector::FixedSizeByBatchId { ref batch_id } = - collect_resp.part_batch_sel - { - leader_state - .batch_queue - .retain(|(id, _report_count)| id != batch_id); + async fn get_hpke_config_for<'s>( + &'s self, + _version: DapVersion, + task_id: Option<&TaskId>, + ) -> Result, DapError> { + if self.hpke_receiver_config_list.is_empty() { + return Err(fatal_error!(err = "empty HPKE receiver config list")); } - match collect_job { - CollectJobState::Pending(_) => { - // Mark collect job as Processed. - *collect_job = CollectJobState::Processed(collect_resp.clone()); - - // Remove collect ID from queue. - let index = leader_state - .collect_ids - .iter() - .position(|r| r == collect_id) - .unwrap(); - leader_state.collect_ids.remove(index); - - Ok(()) - } - CollectJobState::Processed(_) => { - Err(fatal_error!(err = "tried to overwrite collect response")) - } + // Aggregators MAY abort if the HPKE config request does not specify a task ID. While not + // required for MockAggregator, we simulate this behavior for testing purposes. + // + // TODO(cjpatton) To make this clearer, have MockAggregator store a map from task IDs to + // HPKE receiver configs. + if task_id.is_none() { + return Err(DapError::Abort(DapAbort::MissingTaskId)); } + + // Always advertise the first HPKE config in the list. + Ok(&self.hpke_receiver_config_list[0].config) } - async fn send_http_post(&self, req: DapRequest) -> Result { - match req.media_type { - DapMediaType::AggregationJobInitReq | DapMediaType::AggregationJobContinueReq => { - Ok(self - .peer - .as_ref() - .expect("peer not configured") - .handle_agg_job_req(&req) - .await - .expect("peer aborted unexpectedly")) - } - DapMediaType::AggregateShareReq => Ok(self - .peer - .as_ref() - .expect("peer not configured") - .handle_agg_share_req(&req) - .await - .expect("peer aborted unexpectedly")), - _ => unreachable!("unhandled media type: {:?}", req.media_type), - } + async fn can_hpke_decrypt(&self, _task_id: &TaskId, config_id: u8) -> Result { + Ok(self.get_hpke_receiver_config_for(config_id).is_some()) } - async fn send_http_put(&self, req: DapRequest) -> Result { - if req.media_type == DapMediaType::AggregationJobInitReq { - Ok(self - .peer - .as_ref() - .expect("peer not configured") - .handle_agg_job_req(&req) - .await - .expect("peer aborted unexpectedly")) + async fn hpke_decrypt( + &self, + _task_id: &TaskId, + info: &[u8], + aad: &[u8], + ciphertext: &HpkeCiphertext, + ) -> Result, DapError> { + if let Some(hpke_receiver_config) = self.get_hpke_receiver_config_for(ciphertext.config_id) + { + Ok(hpke_receiver_config.decrypt(info, aad, &ciphertext.enc, &ciphertext.payload)?) } else { - unreachable!("unhandled media type: {:?}", req.media_type) + Err(DapError::Transition(TransitionFailure::HpkeUnknownConfigId)) } } } -/// Information associated to a certain helper state for a given task ID and aggregate job ID. -#[derive(Clone, Eq, Hash, PartialEq, Deserialize, Serialize)] -pub(crate) struct HelperStateInfo { - task_id: TaskId, - agg_job_id_owned: MetaAggregationJobIdOwned, +#[async_trait(?Send)] +impl DapAuthorizedSender for MockAggregator { + async fn authorize( + &self, + task_id: &TaskId, + task_config: &DapTaskConfig, + media_type: &DapMediaType, + _payload: &[u8], + ) -> Result { + Ok(self + .authorize_with_bearer_token(task_id, task_config, media_type) + .await? + .clone()) + } } -/// Stores the reports received from Clients. 
-#[derive(Default)] -pub(crate) struct ReportStore { - pub(crate) pending: HashMap>, - pub(crate) processed: HashSet, -} +#[async_trait(?Send)] +impl DapReportInitializer for MockAggregator { + async fn initialize_reports<'req>( + &self, + is_leader: bool, + task_id: &TaskId, + task_config: &DapTaskConfig, + part_batch_sel: &PartialBatchSelector, + consumed_reports: Vec>, + ) -> Result>, DapError> { + let span = task_config.batch_span_for_meta( + part_batch_sel, + consumed_reports.iter().filter(|report| report.is_ready()), + )?; -/// Stores the state of the collect job. -pub(crate) enum CollectJobState { - Pending(CollectionReq), - Processed(Collection), -} + let mut early_fails = HashMap::new(); + for (bucket, reports_consumed_per_bucket) in span.iter() { + for metadata in reports_consumed_per_bucket + .iter() + .map(|report| report.metadata()) + { + // Check whether Report has been collected or replayed. + if let Some(transition_failure) = self + .check_report_early_fail(task_id, &bucket.to_owned_bucket(), metadata) + .await + { + early_fails.insert(metadata.id.clone(), transition_failure); + }; + } + } -/// LeaderState keeps track of the following: -/// * Collect IDs in their order of arrival. -/// * The state of the collect job associated to the Collect ID. -#[derive(Default)] -pub(crate) struct LeaderState { - collect_ids: VecDeque, - collect_jobs: HashMap, - batch_queue: VecDeque<(BatchId, u64)>, // Batch ID, batch size + Ok(consumed_reports + .into_iter() + .map(|consumed| { + if let Some(failure) = early_fails.get(&consumed.metadata().id) { + Ok(EarlyReportStateInitialized::Rejected { + metadata: Cow::Owned(consumed.metadata().clone()), + failure: *failure, + }) + } else { + EarlyReportStateInitialized::initialize( + is_leader, + &task_config.vdaf_verify_key, + &task_config.vdaf, + consumed, + ) + } + }) + .collect::, _>>()?) + } } -/// AggStore keeps track of the following: -/// * Aggregate share -/// * Whether this aggregate share has been collected -#[derive(Default)] -pub(crate) struct AggStore { - pub(crate) agg_share: DapAggregateShare, - pub(crate) collected: bool, -} +#[async_trait(?Send)] +impl DapAggregator for MockAggregator { + // The lifetimes on the traits ensure that we can return a reference to a task config stored by + // the DapAggregator. (See DaphneWorkerConfig for an example.) For simplicity, MockAggregator + // clones the task config as needed. + type WrappedDapTaskConfig<'a> = DapTaskConfig; -// These are declarative macros which let us generate a test point for -// each DapVersion given a test which takes a version parameter. -// -// E.g. currently -// -// async_test_versions! { something } -// -// would generate async tests named -// -// something_draft02 -// -// and -// -// something_draft05 -// -// that called something(version) with the appropriate version. -// -// We use the "paste" crate to get a macro that can paste tokens and also -// fiddle case. + async fn unauthorized_reason( + &self, + task_config: &DapTaskConfig, + req: &DapRequest, + ) -> Result, DapError> { + self.bearer_token_authorized(task_config, req).await + } -#[macro_export] -macro_rules! test_version { - ($fname:ident, $version:ident) => { - paste! { - #[test] - fn [<$fname _ $version:lower>]() { - $fname (DapVersion::$version); - } - } - }; -} + fn get_global_config(&self) -> &DapGlobalConfig { + &self.global_config + } -#[macro_export] -macro_rules! test_versions { - ($($fname:ident),*) => { - $( - test_version! { $fname, Draft02 } - test_version! 
{ $fname, Draft05 } - )* - }; -} + fn taskprov_vdaf_verify_key_init(&self) -> Option<&[u8; 32]> { + Some(&self.taskprov_vdaf_verify_key_init) + } -#[macro_export] -macro_rules! async_test_version { - ($fname:ident, $version:ident) => { - paste! { - #[tokio::test] - async fn [<$fname _ $version:lower>]() { - $fname (DapVersion::$version) . await; - } - } - }; -} + fn taskprov_collector_hpke_config(&self) -> Option<&HpkeConfig> { + Some(&self.collector_hpke_config) + } -#[macro_export] -macro_rules! async_test_versions { - ($($fname:ident),*) => { - $( - async_test_version! { $fname, Draft02 } - async_test_version! { $fname, Draft05 } - )* - }; -} + fn taskprov_opt_out_reason( + &self, + _task_config: &DapTaskConfig, + ) -> Result, DapError> { + // Always opt-in. + Ok(None) + } -/// Helper macro used by `assert_metrics_include`. -// -// TODO(cjpatton) Figure out how to bake this into `asssert_metrics_include` so that users don't -// have to import both macros. -#[cfg(test)] -#[macro_export] -macro_rules! assert_metrics_include_auxiliary_function { - ($set:expr, $k:tt: $v:expr,) => {{ - let line = format!("{} {}", $k, $v); - $set.insert(line); - }}; + async fn taskprov_put( + &self, + req: &DapRequest, + task_config: DapTaskConfig, + ) -> Result<(), DapError> { + let task_id = req.task_id().map_err(DapError::Abort)?; + let mut tasks = self.tasks.lock().expect("tasks: lock failed"); + tasks.deref_mut().insert(task_id.clone(), task_config); + Ok(()) + } - ($set:expr, $k:tt: $v:expr, $($ks:tt: $vs:expr),+,) => {{ - let line = format!("{} {}", $k, $v); - $set.insert(line); - assert_metrics_include_auxiliary_function!($set, $($ks: $vs),+,) - }} -} + async fn get_task_config_for<'s>( + &self, + task_id: Cow<'s, TaskId>, + ) -> Result>, DapError> { + let tasks = self.tasks.lock().expect("tasks: lock failed"); + Ok(tasks.get(task_id.as_ref()).cloned()) + } -/// Gather metrics from a registry and assert that a list of metrics are present and have the -/// correct value. For example: -/// ``` -/// let registry = prometheus::Registry::new(); -/// -/// // ... Register a metric called "report_counter" and use it. -/// -/// assert_metrics_include!(t.helper_prometheus_registry, { -/// r#"report_counter{status="aggregated"}"#: 23, -/// }); -/// ``` -#[cfg(test)] -#[macro_export] -macro_rules! assert_metrics_include { - ($registry:expr, {$($ks:tt: $vs:expr),+,}) => {{ - use prometheus::{Encoder, TextEncoder}; - use std::collections::HashSet; + fn get_current_time(&self) -> Time { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + } - let mut want: HashSet = HashSet::new(); - assert_metrics_include_auxiliary_function!(&mut want, $($ks: $vs),+,); + async fn is_batch_overlapping( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result { + let task_config = self + .get_task_config_for(Cow::Borrowed(task_id)) + .await + .unwrap() + .expect("tasks: unrecognized task"); + let guard = self.agg_store.lock().expect("agg_store: failed to lock"); + let agg_store = if let Some(agg_store) = guard.get(task_id) { + agg_store + } else { + return Ok(false); + }; - // Encode the metrics and iterate over each line. For each line, if the line appears in the - // list of expected output lines, then remove it. 
- let mut got_buf = Vec::new(); - let encoder = TextEncoder::new(); - encoder.encode(&$registry.gather(), &mut got_buf).unwrap(); - let got_str = String::from_utf8(got_buf).unwrap(); - for line in got_str.split('\n') { - want.remove(line); + for bucket in task_config.batch_span_for_sel(batch_sel)? { + if let Some(inner_agg_store) = agg_store.get(&bucket.to_owned_bucket()) { + if inner_agg_store.collected { + return Ok(true); + } + } } - // The metrics contain the expected lines if the the set is now empty. - if !want.is_empty() { - panic!("unexpected metrics: got:\n{}\nmust contain:\n{}\n", - got_str, want.into_iter().collect::>().join("\n")); + Ok(false) + } + + async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result { + let guard = self.agg_store.lock().expect("agg_store: failed to lock"); + if let Some(agg_store) = guard.get(task_id) { + Ok(agg_store + .get(&DapBatchBucketOwned::FixedSize { + batch_id: batch_id.clone(), + }) + .is_some()) + } else { + Ok(false) } - }} -} + } -/// Scaffolding for testing the aggregation flow. -pub struct AggregationJobTest { - // task parameters - pub(crate) task_id: TaskId, - pub(crate) task_config: DapTaskConfig, - pub(crate) leader_hpke_receiver_config: HpkeReceiverConfig, - pub(crate) helper_hpke_receiver_config: HpkeReceiverConfig, - pub(crate) client_hpke_config_list: Vec, - pub(crate) collector_hpke_receiver_config: HpkeReceiverConfig, + async fn put_out_shares( + &self, + task_id: &TaskId, + task_config: &DapTaskConfig, + part_batch_sel: &PartialBatchSelector, + out_shares: Vec, + ) -> Result, DapError> { + let mut report_store_guard = self + .report_store + .lock() + .expect("report_store: failed to lock"); + let report_store = report_store_guard.entry(task_id.clone()).or_default(); + let mut agg_store_guard = self.agg_store.lock().expect("agg_store: failed to lock"); + let agg_store = agg_store_guard.entry(task_id.clone()).or_default(); - // aggregation job ID - pub(crate) agg_job_id: MetaAggregationJobId<'static>, + let mut replayed = HashSet::new(); + for (bucket, out_shares) in task_config + .batch_span_for_out_shares(part_batch_sel, out_shares)? + .into_iter() + { + for out_share in out_shares.into_iter() { + if !report_store.processed.contains(&out_share.report_id) { + // Mark report processed. + report_store.processed.insert(out_share.report_id.clone()); - // the current time - pub(crate) now: Time, + // Add to aggregate share. + agg_store + .entry(bucket.to_owned_bucket()) + .or_default() + .agg_share + .merge(DapAggregateShare::try_from_out_shares([out_share])?)?; + } else { + replayed.insert(out_share.report_id); + } + } + } - // operational parameters - #[allow(dead_code)] - pub(crate) prometheus_registry: prometheus::Registry, - pub(crate) leader_metrics: DaphneMetrics, - pub(crate) helper_metrics: DaphneMetrics, - pub(crate) leader_reports_processed: Arc>>, - pub(crate) helper_reports_processed: Arc>>, -} + Ok(replayed) + } -// NOTE(cjpatton) This implementation of the report initializer is not feature complete. Since -// [`AggrregationJobTest`], is only used to test the aggregation flow, features that are not -// directly relevant to the tests aren't implemented. 
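// A pared-down sketch of the replay bookkeeping in put_out_shares above: an output
// share is merged only the first time its report ID is seen, and previously seen IDs
// are returned to the caller as replays. The u64 report IDs and "shares" are
// simplified stand-ins for the real Daphne types.
use std::collections::HashSet;

fn merge_once(processed: &mut HashSet<u64>, out_shares: Vec<(u64, u64)>) -> (u64, HashSet<u64>) {
    let mut aggregate = 0u64;
    let mut replayed = HashSet::new();
    for (report_id, share) in out_shares {
        if processed.insert(report_id) {
            aggregate += share; // stand-in for DapAggregateShare::merge
        } else {
            replayed.insert(report_id);
        }
    }
    (aggregate, replayed)
}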
-#[async_trait(?Send)] -impl DapReportInitializer for AggregationJobTest { - async fn initialize_reports<'req>( + async fn get_agg_share( &self, - is_leader: bool, - _task_id: &TaskId, - task_config: &DapTaskConfig, - _part_batch_sel: &PartialBatchSelector, - consumed_reports: Vec>, - ) -> Result>, DapError> { - let mut reports_processed = if is_leader { - self.leader_reports_processed.lock().unwrap() - } else { - self.helper_reports_processed.lock().unwrap() - }; + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result { + let task_config = self + .get_task_config_for(Cow::Borrowed(task_id)) + .await + .unwrap() + .expect("tasks: unrecognized task"); + let mut guard = self.agg_store.lock().expect("agg_store: failed to lock"); + let agg_store = guard.entry(task_id.clone()).or_default(); - Ok(consumed_reports - .into_iter() - .map(|consumed| { - if reports_processed.contains(&consumed.metadata().id) { - Ok(EarlyReportStateInitialized::Rejected { - metadata: Cow::Owned(consumed.metadata().clone()), - failure: TransitionFailure::ReportReplayed, - }) + // Fetch aggregate shares. + let mut agg_share = DapAggregateShare::default(); + for bucket in task_config.batch_span_for_sel(batch_sel)? { + if let Some(inner_agg_store) = agg_store.get(&bucket.to_owned_bucket()) { + if inner_agg_store.collected { + return Err(DapError::Abort(DapAbort::batch_overlap(task_id, batch_sel))); } else { - reports_processed.insert(consumed.metadata().id.clone()); - EarlyReportStateInitialized::initialize( - is_leader, - &task_config.vdaf_verify_key, - &task_config.vdaf, - consumed, - ) + agg_share.merge(inner_agg_store.agg_share.clone())?; } - }) - .collect::, _>>()?) + } + } + + Ok(agg_share) } -} -impl AggregationJobTest { - /// Create an aggregation job test with the given VDAF config, HPKE KEM algorithm, DAP protocol - /// version. The KEM algorithm is used to generate an HPKE config for each party. 
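// Sketch of the collect-time rule in get_agg_share above: the aggregate share for a
// batch selector is the merge of every bucket in its span, and the request is
// rejected if any of those buckets was already collected. Simplified stand-in types.
fn collect_span(buckets: &[(u64 /* share */, bool /* collected */)]) -> Result<u64, &'static str> {
    let mut aggregate = 0u64;
    for (share, collected) in buckets {
        if *collected {
            return Err("batch overlaps a previously collected batch");
        }
        aggregate += share;
    }
    Ok(aggregate)
}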
- pub fn new(vdaf: &VdafConfig, kem_id: HpkeKemId, version: DapVersion) -> Self { - let mut rng = thread_rng(); - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs(); - let task_id = TaskId(rng.gen()); - let agg_job_id = MetaAggregationJobId::gen_for_version(&version); - let vdaf_verify_key = vdaf.gen_verify_key(); - let leader_hpke_receiver_config = HpkeReceiverConfig::gen(rng.gen(), kem_id).unwrap(); - let helper_hpke_receiver_config = HpkeReceiverConfig::gen(rng.gen(), kem_id).unwrap(); - let collector_hpke_receiver_config = HpkeReceiverConfig::gen(rng.gen(), kem_id).unwrap(); - let leader_hpke_config = leader_hpke_receiver_config.clone().config; - let helper_hpke_config = helper_hpke_receiver_config.clone().config; - let collector_hpke_config = collector_hpke_receiver_config.clone().config; - let prometheus_registry = prometheus::Registry::new(); - let leader_metrics = - DaphneMetrics::register(&prometheus_registry, Some("test_leader")).unwrap(); - let helper_metrics = - DaphneMetrics::register(&prometheus_registry, Some("test_helper")).unwrap(); + async fn mark_collected( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<(), DapError> { + let task_config = self.unchecked_get_task_config(task_id).await; + let mut guard = self.agg_store.lock().expect("agg_store: failed to lock"); + let agg_store = guard.entry(task_id.clone()).or_default(); - Self { - now, - task_id, - agg_job_id, - leader_hpke_receiver_config, - helper_hpke_receiver_config, - client_hpke_config_list: vec![leader_hpke_config, helper_hpke_config], - collector_hpke_receiver_config, - task_config: DapTaskConfig { - version, - leader_url: Url::parse("http://leader.com").unwrap(), - helper_url: Url::parse("https://helper.org").unwrap(), - time_precision: 500, - expiration: now + 500, - min_batch_size: 10, - query: DapQueryConfig::TimeInterval, - vdaf: vdaf.clone(), - vdaf_verify_key, - collector_hpke_config, - taskprov: false, - }, - prometheus_registry, - leader_metrics, - helper_metrics, - leader_reports_processed: Default::default(), - helper_reports_processed: Default::default(), + for bucket in task_config.batch_span_for_sel(batch_sel)? { + if let Some(inner_agg_store) = agg_store.get_mut(&bucket.to_owned_bucket()) { + inner_agg_store.collected = true; + } } - } - /// For each measurement, generate a report for the given task. - /// - /// Panics if a measurement is incompatible with the given VDAF. - pub fn produce_reports(&self, measurements: Vec) -> Vec { - let mut reports = Vec::with_capacity(measurements.len()); + Ok(()) + } - for measurement in measurements.into_iter() { - reports.push( - self.task_config - .vdaf - .produce_report( - &self.client_hpke_config_list, - self.now, - &self.task_id, - measurement, - self.task_config.version, - ) - .unwrap(), - ); + async fn current_batch(&self, task_id: &TaskId) -> std::result::Result { + let task_config = self.unchecked_get_task_config(task_id).await; + if let Some(id) = self.current_batch_id(task_id, &task_config) { + Ok(id) + } else { + Err(DapError::Abort(DapAbort::BadRequest( + "unknown version".to_string(), + ))) } - reports } - /// Leader: Produce AggregationJobInitReq. - /// - /// Panics if the Leader aborts. 
- pub async fn produce_agg_job_init_req( - &self, - reports: Vec, - ) -> DapLeaderTransition { - let metrics = self - .leader_metrics - .with_host(self.task_config.leader_url.host_str().unwrap()); - self.task_config - .vdaf - .produce_agg_job_init_req( - &self.leader_hpke_receiver_config, - self, - &self.task_id, - &self.task_config, - &self.agg_job_id, - &PartialBatchSelector::TimeInterval, - reports, - &metrics, - ) - .await - .unwrap() + fn metrics(&self) -> &DaphneMetrics { + &self.metrics } - /// Helper: Handle AggregationJobInitReq, produce first AggregationJobResp. - /// - /// Panics if the Helper aborts. - pub async fn handle_agg_job_init_req( - &self, - agg_job_init_req: &AggregationJobInitReq, - ) -> DapHelperTransition { - let metrics = self - .helper_metrics - .with_host(self.task_config.helper_url.host_str().unwrap()); - self.task_config - .vdaf - .handle_agg_job_init_req( - &self.helper_hpke_receiver_config, - self, - &self.task_id, - &self.task_config, - agg_job_init_req, - &metrics, - ) - .await - .unwrap() + fn audit_log(&self) -> &dyn AuditLog { + &self.audit_log } +} - /// Leader: Handle first AggregationJobResp, produce AggregationJobContinueReq. - /// - /// Panics if the Leader aborts. - pub fn handle_agg_job_resp( +#[async_trait(?Send)] +impl DapHelper for MockAggregator { + async fn put_helper_state_if_not_exists( &self, - leader_state: DapLeaderState, - agg_job_resp: AggregationJobResp, - ) -> DapLeaderTransition { - let metrics = self - .leader_metrics - .with_host(self.task_config.leader_url.host_str().unwrap()); - self.task_config - .vdaf - .handle_agg_job_resp( - &self.task_id, - &self.agg_job_id, - leader_state, - agg_job_resp, - self.task_config.version, - &metrics, - ) - .unwrap() - } + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, + helper_state: &DapHelperState, + ) -> Result { + let helper_state_info = HelperStateInfo { + task_id: task_id.clone(), + agg_job_id_owned: agg_job_id.into(), + }; - /// Like [`handle_agg_job_resp`] but expect the Leader to abort. - pub fn handle_agg_job_resp_expect_err( - &self, - leader_state: DapLeaderState, - agg_job_resp: AggregationJobResp, - ) -> DapAbort { - let metrics = self - .leader_metrics - .with_host(self.task_config.leader_url.host_str().unwrap()); - self.task_config - .vdaf - .handle_agg_job_resp( - &self.task_id, - &self.agg_job_id, - leader_state, - agg_job_resp, - self.task_config.version, - &metrics, - ) - .expect_err("handle_agg_job_resp() succeeded; expected failure") - } + let mut helper_state_store_mutex_guard = self + .helper_state_store + .lock() + .map_err(|e| fatal_error!(err = ?e))?; - /// Helper: Handle AggregationJobContinueReq, produce second AggregationJobResp. - /// - /// Panics if the Helper aborts. - pub fn handle_agg_job_cont_req( - &self, - helper_state: DapHelperState, - agg_job_cont_req: &AggregationJobContinueReq, - ) -> DapHelperTransition { - let metrics = self - .helper_metrics - .with_host(self.task_config.helper_url.host_str().unwrap()); - self.task_config - .vdaf - .handle_agg_job_cont_req( - &self.task_id, - &self.agg_job_id, - helper_state, - agg_job_cont_req, - &metrics, - ) - .unwrap() - } + let helper_state_store = helper_state_store_mutex_guard.deref_mut(); - /// Like [`handle_agg_job_cont_req`] but expect the Helper to abort. 
- pub fn handle_agg_job_cont_req_expect_err( - &self, - helper_state: DapHelperState, - agg_job_cont_req: &AggregationJobContinueReq, - ) -> DapAbort { - let metrics = self - .helper_metrics - .with_host(self.task_config.helper_url.host_str().unwrap()); - self.task_config - .vdaf - .handle_agg_job_cont_req( - &self.task_id, - &self.agg_job_id, - helper_state, - agg_job_cont_req, - &metrics, - ) - .expect_err("handle_agg_job_cont_req() succeeded; expected failure") - } + if helper_state_store.contains_key(&helper_state_info) { + return Ok(false); + } - /// Leader: Handle the last AggregationJobResp. - /// - /// Panics if the Leader aborts. - pub fn handle_final_agg_job_resp( - &self, - leader_uncommitted: DapLeaderUncommitted, - agg_job_resp: AggregationJobResp, - ) -> Vec { - let metrics = self - .leader_metrics - .with_host(self.task_config.leader_url.host_str().unwrap()); - self.task_config - .vdaf - .handle_final_agg_job_resp(leader_uncommitted, agg_job_resp, &metrics) - .unwrap() - } + // NOTE: This code is only correct for VDAFs with exactly one round of preparation. + // For VDAFs with more rounds, the helper state blob will need to be updated here. + helper_state_store.insert(helper_state_info, helper_state.clone()); - /// Produce the Leader's encrypted aggregate share. - pub fn produce_leader_encrypted_agg_share( - &self, - batch_selector: &BatchSelector, - agg_share: &DapAggregateShare, - ) -> HpkeCiphertext { - self.task_config - .vdaf - .produce_leader_encrypted_agg_share( - &self.task_config.collector_hpke_config, - &self.task_id, - batch_selector, - agg_share, - self.task_config.version, - ) - .unwrap() + Ok(true) } - /// Produce the Helper's encrypted aggregate share. - pub fn produce_helper_encrypted_agg_share( + async fn get_helper_state( &self, - batch_selector: &BatchSelector, - agg_share: &DapAggregateShare, - ) -> HpkeCiphertext { - self.task_config - .vdaf - .produce_helper_encrypted_agg_share( - &self.task_config.collector_hpke_config, - &self.task_id, - batch_selector, - agg_share, - self.task_config.version, - ) - .unwrap() + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, + ) -> Result, DapError> { + let helper_state_info = HelperStateInfo { + task_id: task_id.clone(), + agg_job_id_owned: agg_job_id.into(), + }; + + let mut helper_state_store_mutex_guard = self + .helper_state_store + .lock() + .map_err(|e| fatal_error!(err = ?e))?; + + let helper_state_store = helper_state_store_mutex_guard.deref_mut(); + + // NOTE: This code is only correct for VDAFs with exactly one round of preparation. + // For VDAFs with more rounds, the helper state blob will need to be updated here. + if helper_state_store.contains_key(&helper_state_info) { + let helper_state = helper_state_store.remove(&helper_state_info); + + return Ok(helper_state); + } + + Ok(None) } +} - /// Collector: Consume the aggregate shares. 
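// Sketch of the helper-state store contract implemented above: a state blob is
// written at most once per (task ID, aggregation job ID) and is consumed by the
// first read. A HashMap keyed by a plain string stands in for the real store,
// which is keyed by HelperStateInfo.
use std::collections::HashMap;

fn put_state_if_not_exists(store: &mut HashMap<String, Vec<u8>>, key: &str, state: Vec<u8>) -> bool {
    if store.contains_key(key) {
        return false;
    }
    store.insert(key.to_owned(), state);
    true
}

fn take_state(store: &mut HashMap<String, Vec<u8>>, key: &str) -> Option<Vec<u8>> {
    // Removing on read mirrors get_helper_state, which deletes the entry it returns.
    store.remove(key)
}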
- pub async fn consume_encrypted_agg_shares( - &self, - batch_selector: &BatchSelector, - report_count: u64, - enc_agg_shares: Vec, - ) -> DapAggregateResult { - self.task_config - .vdaf - .consume_encrypted_agg_shares( - &self.collector_hpke_receiver_config, - &self.task_id, - batch_selector, - report_count, - enc_agg_shares, - self.task_config.version, - ) +#[async_trait(?Send)] +impl DapLeader for MockAggregator { + type ReportSelector = MockAggregatorReportSelector; + + async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> { + let bucket = self + .assign_report_to_bucket(report, task_id) .await - .unwrap() - } + .expect("could not determine batch for report"); - /// Generate a set of reports, aggregate them, and unshard the result. - pub async fn roundtrip(&mut self, measurements: Vec) -> DapAggregateResult { - let batch_selector = BatchSelector::TimeInterval { - batch_interval: Interval { - start: self.now, - duration: 3600, - }, + // Check whether Report has been collected or replayed. + if let Some(transition_failure) = self + .check_report_early_fail(task_id, &bucket, &report.report_metadata) + .await + { + return Err(DapError::Transition(transition_failure)); }; - // Clients: Shard - let reports = self.produce_reports(measurements); + // Store Report for future processing. + let mut guard = self + .report_store + .lock() + .expect("report_store: failed to lock"); + let queue = guard + .get_mut(task_id) + .expect("report_store: unrecognized task") + .pending + .entry(bucket) + .or_default(); + queue.push_back(report.clone()); + Ok(()) + } - // Aggregators: Preparation - let DapLeaderTransition::Continue(leader_state, agg_job_init_req) = - self.produce_agg_job_init_req(reports).await - else { - panic!("unexpected transition"); - }; - let DapHelperTransition::Continue(helper_state, agg_job_resp) = - self.handle_agg_job_init_req(&agg_job_init_req).await - else { - panic!("unexpected transition"); - }; - let got = DapHelperState::get_decoded(&self.task_config.vdaf, &helper_state.get_encoded()) - .expect("failed to decode helper state"); - assert_eq!(got, helper_state); + async fn get_reports( + &self, + report_sel: &MockAggregatorReportSelector, + ) -> Result>>, DapError> { + let task_id = &report_sel.0; + let task_config = self.unchecked_get_task_config(task_id).await; + let mut guard = self + .report_store + .lock() + .expect("report_store: failed to lock"); + let report_store = guard.entry(task_id.clone()).or_default(); - let DapLeaderTransition::Uncommitted(uncommitted, agg_cont) = - self.handle_agg_job_resp(leader_state, agg_job_resp) - else { - panic!("unexpected transition"); - }; - let DapHelperTransition::Finish(helper_out_shares, agg_job_resp) = - self.handle_agg_job_cont_req(helper_state, &agg_cont) - else { - panic!("unexpected transition"); - }; - let leader_out_shares = self.handle_final_agg_job_resp(uncommitted, agg_job_resp); - let report_count = u64::try_from(leader_out_shares.len()).unwrap(); + // For the task indicated by the report selector, choose a single report to aggregate. + match task_config.query { + DapQueryConfig::TimeInterval { .. } => { + // Aggregate reports in any order. + let mut reports = Vec::new(); + for (_bucket, queue) in report_store.pending.iter_mut() { + if !queue.is_empty() { + reports.append(&mut queue.drain(..1).collect()); + break; + } + } + return Ok(HashMap::from([( + task_id.clone(), + HashMap::from([(PartialBatchSelector::TimeInterval, reports)]), + )])); + } + DapQueryConfig::FixedSize { .. 
} => { + // Drain the batch that is being filled. - // Leader: Aggregation - let leader_agg_share = DapAggregateShare::try_from_out_shares(leader_out_shares).unwrap(); - let leader_encrypted_agg_share = - self.produce_leader_encrypted_agg_share(&batch_selector, &leader_agg_share); + let bucket = if let Some(batch_id) = self.current_batch_id(task_id, &task_config) { + DapBatchBucketOwned::FixedSize { batch_id } + } else { + return Ok(HashMap::default()); + }; - // Helper: Aggregation - let helper_agg_share = DapAggregateShare::try_from_out_shares(helper_out_shares).unwrap(); - let helper_encrypted_agg_share = - self.produce_helper_encrypted_agg_share(&batch_selector, &helper_agg_share); + let queue = report_store + .pending + .get_mut(&bucket) + .expect("report_store: unknown bucket"); + let reports = queue.drain(..1).collect(); + return Ok(HashMap::from([( + task_id.clone(), + HashMap::from([(bucket.into(), reports)]), + )])); + } + } + } - // Collector: Unshard - self.consume_encrypted_agg_shares( - &batch_selector, - report_count, - vec![leader_encrypted_agg_share, helper_encrypted_agg_share], - ) - .await + // Called after receiving a CollectReq from Collector. + async fn init_collect_job( + &self, + task_id: &TaskId, + collect_job_id: &Option, + collect_req: &CollectionReq, + ) -> Result { + let mut rng = thread_rng(); + let task_config = self + .get_task_config_for(Cow::Borrowed(task_id)) + .await? + .ok_or_else(|| fatal_error!(err = "task not found"))?; + + let mut leader_state_store_mutex_guard = self + .leader_state_store + .lock() + .map_err(|e| fatal_error!(err = ?e))?; + let leader_state_store = leader_state_store_mutex_guard.deref_mut(); + + // Construct a new Collect URI for this CollectReq. + let collect_id = collect_job_id + .as_ref() + .map_or_else(|| CollectionJobId(rng.gen()), |cid| cid.clone()); + let collect_uri = task_config + .leader_url + .join(&format!( + "collect/task/{}/req/{}", + task_id.to_base64url(), + collect_id.to_base64url(), + )) + .map_err(|e| fatal_error!(err = ?e))?; + + // Store Collect ID and CollectReq into LeaderState. + let leader_state = leader_state_store.entry(task_id.clone()).or_default(); + leader_state.collect_ids.push_back(collect_id.clone()); + let collect_job_state = CollectJobState::Pending(collect_req.clone()); + leader_state + .collect_jobs + .insert(collect_id, collect_job_state); + + Ok(collect_uri) + } + + // Called to retrieve completed CollectResp at the request of Collector. + async fn poll_collect_job( + &self, + task_id: &TaskId, + collect_id: &CollectionJobId, + ) -> Result { + let mut leader_state_store_mutex_guard = self + .leader_state_store + .lock() + .map_err(|e| fatal_error!(err = ?e))?; + let leader_state_store = leader_state_store_mutex_guard.deref_mut(); + + let leader_state = leader_state_store + .get(task_id) + .ok_or_else(|| fatal_error!(err = "collect job not found for task_id", %task_id))?; + if let Some(collect_job_state) = leader_state.collect_jobs.get(collect_id) { + match collect_job_state { + CollectJobState::Pending(_) => Ok(DapCollectJob::Pending), + CollectJobState::Processed(resp) => Ok(DapCollectJob::Done(resp.clone())), + } + } else { + Ok(DapCollectJob::Unknown) + } + } + + // Called to retrieve pending CollectReq. 
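+    // Walks every task's LeaderState and returns the (task ID, collect job ID, request)
+    // triples that are still Pending, preserving each task's arrival order.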
+ async fn get_pending_collect_jobs( + &self, + ) -> Result, DapError> { + let mut leader_state_store_mutex_guard = self + .leader_state_store + .lock() + .map_err(|e| fatal_error!(err = ?e))?; + let leader_state_store = leader_state_store_mutex_guard.deref_mut(); + + let mut res = Vec::new(); + for (task_id, leader_state) in leader_state_store.iter() { + // Iterate over collect IDs and copy them and their associated requests to the response. + for collect_id in leader_state.collect_ids.iter() { + if let CollectJobState::Pending(collect_req) = + leader_state.collect_jobs.get(collect_id).unwrap() + { + res.push((task_id.clone(), collect_id.clone(), collect_req.clone())); + } + } + } + Ok(res) + } + + async fn finish_collect_job( + &self, + task_id: &TaskId, + collect_id: &CollectionJobId, + collect_resp: &Collection, + ) -> Result<(), DapError> { + let mut leader_state_store_mutex_guard = self + .leader_state_store + .lock() + .map_err(|e| fatal_error!(err = ?e))?; + let leader_state_store = leader_state_store_mutex_guard.deref_mut(); + + let leader_state = leader_state_store + .get_mut(task_id) + .ok_or_else(|| fatal_error!(err = "collect job not found for task_id", %task_id))?; + let collect_job = leader_state + .collect_jobs + .get_mut(collect_id) + .ok_or_else(|| fatal_error!(err = "collect job not found for collect_id", %task_id))?; + + // Remove the batch from the batch queue. + if let PartialBatchSelector::FixedSizeByBatchId { ref batch_id } = + collect_resp.part_batch_sel + { + leader_state + .batch_queue + .retain(|(id, _report_count)| id != batch_id); + } + + match collect_job { + CollectJobState::Pending(_) => { + // Mark collect job as Processed. + *collect_job = CollectJobState::Processed(collect_resp.clone()); + + // Remove collect ID from queue. + let index = leader_state + .collect_ids + .iter() + .position(|r| r == collect_id) + .unwrap(); + leader_state.collect_ids.remove(index); + + Ok(()) + } + CollectJobState::Processed(_) => { + Err(fatal_error!(err = "tried to overwrite collect response")) + } + } } + + async fn send_http_post(&self, req: DapRequest) -> Result { + match req.media_type { + DapMediaType::AggregationJobInitReq | DapMediaType::AggregationJobContinueReq => { + Ok(self + .peer + .as_ref() + .expect("peer not configured") + .handle_agg_job_req(&req) + .await + .expect("peer aborted unexpectedly")) + } + DapMediaType::AggregateShareReq => Ok(self + .peer + .as_ref() + .expect("peer not configured") + .handle_agg_share_req(&req) + .await + .expect("peer aborted unexpectedly")), + _ => unreachable!("unhandled media type: {:?}", req.media_type), + } + } + + async fn send_http_put(&self, req: DapRequest) -> Result { + if req.media_type == DapMediaType::AggregationJobInitReq { + Ok(self + .peer + .as_ref() + .expect("peer not configured") + .handle_agg_job_req(&req) + .await + .expect("peer aborted unexpectedly")) + } else { + unreachable!("unhandled media type: {:?}", req.media_type) + } + } +} + +/// Information associated to a certain helper state for a given task ID and aggregate job ID. +#[derive(Clone, Eq, Hash, PartialEq, Deserialize, Serialize)] +pub struct HelperStateInfo { + task_id: TaskId, + agg_job_id_owned: MetaAggregationJobIdOwned, +} + +/// Stores the reports received from Clients. +#[derive(Default)] +pub struct ReportStore { + pub(crate) pending: HashMap>, + pub(crate) processed: HashSet, +} + +/// Stores the state of the collect job. 
+pub enum CollectJobState { + Pending(CollectionReq), + Processed(Collection), +} + +/// LeaderState keeps track of the following: +/// * Collect IDs in their order of arrival. +/// * The state of the collect job associated to the Collect ID. +#[derive(Default)] +pub struct LeaderState { + collect_ids: VecDeque, + collect_jobs: HashMap, + batch_queue: VecDeque<(BatchId, u64)>, // Batch ID, batch size +} + +/// AggStore keeps track of the following: +/// * Aggregate share +/// * Whether this aggregate share has been collected +#[derive(Default)] +pub struct AggStore { + pub(crate) agg_share: DapAggregateShare, + pub(crate) collected: bool, +} + +/// Helper macro used by `assert_metrics_include`. +#[macro_export] +macro_rules! assert_metrics_include_auxiliary_function { + ($set:expr, $k:tt: $v:expr,) => {{ + let line = format!("{} {}", $k, $v); + $set.insert(line); + }}; + + ($set:expr, $k:tt: $v:expr, $($ks:tt: $vs:expr),+,) => {{ + let line = format!("{} {}", $k, $v); + $set.insert(line); + assert_metrics_include_auxiliary_function!($set, $($ks: $vs),+,) + }} +} + +/// Gather metrics from a registry and assert that a list of metrics are present and have the +/// correct value. For example: +/// ```ignore +/// let registry = prometheus::Registry::new(); +/// +/// // ... Register a metric called "report_counter" and use it. +/// +/// assert_metrics_include!(t.helper_prometheus_registry, { +/// r#"report_counter{status="aggregated"}"#: 23, +/// }); +/// ``` +#[macro_export] +macro_rules! assert_metrics_include { + ($registry:expr, {$($ks:tt: $vs:expr),+,}) => {{ + use prometheus::{Encoder, TextEncoder}; + use std::collections::HashSet; + + let mut want: HashSet = HashSet::new(); + $crate::assert_metrics_include_auxiliary_function!(&mut want, $($ks: $vs),+,); + + // Encode the metrics and iterate over each line. For each line, if the line appears in the + // list of expected output lines, then remove it. + let mut got_buf = Vec::new(); + let encoder = TextEncoder::new(); + encoder.encode(&$registry.gather(), &mut got_buf).unwrap(); + let got_str = String::from_utf8(got_buf).unwrap(); + for line in got_str.split('\n') { + want.remove(line); + } + + // The metrics contain the expected lines if the the set is now empty. 
+ if !want.is_empty() { + panic!("unexpected metrics: got:\n{}\nmust contain:\n{}\n", + got_str, want.into_iter().collect::>().join("\n")); + } + }} } diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs index 681775c91..1278a9a36 100644 --- a/daphne/src/vdaf/mod.rs +++ b/daphne/src/vdaf/mod.rs @@ -1299,15 +1299,14 @@ fn produce_encrypted_agg_share( #[cfg(test)] mod test { use crate::{ - assert_metrics_include, assert_metrics_include_auxiliary_function, async_test_version, - async_test_versions, + assert_metrics_include, async_test_versions, error::DapAbort, hpke::{HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId}, messages::{ AggregationJobInitReq, BatchSelector, Interval, PartialBatchSelector, Report, ReportId, ReportShare, Transition, TransitionFailure, TransitionVar, }, - test_version, test_versions, + test_versions, testing::AggregationJobTest, DapAggregateResult, DapAggregateShare, DapError, DapHelperState, DapHelperTransition, DapLeaderState, DapLeaderTransition, DapLeaderUncommitted, DapMeasurement, DapOutputShare, @@ -1315,7 +1314,6 @@ mod test { }; use assert_matches::assert_matches; use hpke_rs::HpkePublicKey; - use paste::paste; use prio::{ codec::Encode, field::Field64, diff --git a/daphne/src/vdaf/prio2.rs b/daphne/src/vdaf/prio2.rs index e84f02527..25977c5e6 100644 --- a/daphne/src/vdaf/prio2.rs +++ b/daphne/src/vdaf/prio2.rs @@ -138,10 +138,9 @@ pub(crate) fn prio2_unshard>>( #[cfg(test)] mod test { use crate::{ - async_test_version, async_test_versions, hpke::HpkeKemId, testing::AggregationJobTest, - DapAggregateResult, DapMeasurement, DapVersion, VdafConfig, + async_test_versions, hpke::HpkeKemId, testing::AggregationJobTest, DapAggregateResult, + DapMeasurement, DapVersion, VdafConfig, }; - use paste::paste; async fn roundtrip(version: DapVersion) { let mut t = AggregationJobTest::new( diff --git a/daphne_worker/Cargo.toml b/daphne_worker/Cargo.toml index fefe45cd8..308970bc5 100644 --- a/daphne_worker/Cargo.toml +++ b/daphne_worker/Cargo.toml @@ -41,4 +41,5 @@ worker.workspace = true bincode = "1.3.3" [dev-dependencies] +daphne = { path = "../daphne", features = ["test-utils"] } paste.workspace = true diff --git a/daphne_worker/src/config.rs b/daphne_worker/src/config.rs index fb3085ee5..042c15071 100644 --- a/daphne_worker/src/config.rs +++ b/daphne_worker/src/config.rs @@ -202,7 +202,7 @@ impl DaphneWorkerConfig { trace!("DAP deployment override applied: {deployment:?}"); } - let taskprov = if global.allow_taskprov { + let taskprov = if global.taskprov_version.is_some() { let hpke_collector_config = serde_json::from_str( env.var("DAP_TASKPROV_HPKE_COLLECTOR_CONFIG")? 
                    .to_string()
diff --git a/daphne_worker/src/durable/mod.rs b/daphne_worker/src/durable/mod.rs
index dccb9c5a7..1aec25c60 100644
--- a/daphne_worker/src/durable/mod.rs
+++ b/daphne_worker/src/durable/mod.rs
@@ -673,16 +673,14 @@ fn create_span_from_request(req: &Request) -> tracing::Span {
 #[cfg(test)]
 mod test {
-    use super::{
         durable_name_agg_store, durable_name_queue, durable_name_report_store,
         reports_pending::PendingReport,
     };
     use daphne::{
         messages::{BatchId, Report, ReportId, ReportMetadata, TaskId},
-        test_version, test_versions, DapBatchBucket, DapVersion,
+        test_versions, DapBatchBucket, DapVersion,
     };
-    use paste::paste;
     use prio::codec::{ParameterizedDecode, ParameterizedEncode};
     use rand::prelude::*;
diff --git a/daphne_worker/src/roles/mod.rs b/daphne_worker/src/roles/mod.rs
index 0ab70ffcb..9844caf00 100644
--- a/daphne_worker/src/roles/mod.rs
+++ b/daphne_worker/src/roles/mod.rs
@@ -96,7 +96,7 @@ impl<'srv> BearerTokenProvider for DaphneWorker<'srv> {
         task_config: &DapTaskConfig,
     ) -> std::result::Result>, DapError> {
         if let Some(ref taskprov_config) = self.config().taskprov {
-            if self.get_global_config().allow_taskprov && task_config.taskprov {
+            if self.get_global_config().taskprov_version.is_some() && task_config.taskprov {
                 return Ok(Some(BearerTokenKvPair::new(
                     task_id,
                     taskprov_config.leader_auth.as_ref(),
@@ -115,7 +115,7 @@ impl<'srv> BearerTokenProvider for DaphneWorker<'srv> {
         task_config: &DapTaskConfig,
     ) -> std::result::Result>, DapError> {
         if let Some(ref taskprov_config) = self.config().taskprov {
-            if self.get_global_config().allow_taskprov && task_config.taskprov {
+            if self.get_global_config().taskprov_version.is_some() && task_config.taskprov {
                 return Ok(Some(BearerTokenKvPair::new(
                     task_id,
                     taskprov_config
diff --git a/daphne_worker_test/tests/e2e/e2e.rs b/daphne_worker_test/tests/e2e/e2e.rs
index d08ad24ae..df396b3e6 100644
--- a/daphne_worker_test/tests/e2e/e2e.rs
+++ b/daphne_worker_test/tests/e2e/e2e.rs
@@ -4,7 +4,7 @@
 //! End-to-end tests for daphne.
 use super::test_runner::{TestRunner, MIN_BATCH_SIZE, TIME_PRECISION};
 use daphne::{
-    async_test_versions,
+    async_test_version, async_test_versions,
     constants::DapMediaType,
     messages::{
         taskprov::{
@@ -17,27 +17,12 @@ use daphne::{
     DapAggregateResult, DapMeasurement, DapTaskConfig, DapVersion,
 };
 use daphne_worker::DaphneWorkerReportSelector;
-use paste::paste;
 use prio::codec::{ParameterizedDecode, ParameterizedEncode};
 use rand::prelude::*;
 use serde::Deserialize;
 use serde_json::json;
 use std::cmp::{max, min};
 
-// Redefine async_test_version locally because we want a
-// cfg_attr as well.
-macro_rules! async_test_version {
-    ($fname:ident, $version:ident) => {
-        paste! {
-            #[tokio::test]
-            #[cfg_attr(not(feature = "test_e2e"), ignore)]
-            async fn [<$fname _ $version:lower>]() {
-                $fname(DapVersion::$version).await;
-            }
-        }
-    };
-}
-
 #[derive(Deserialize)]
 struct InternalTestEndpointForTaskResult {
     status: String,
@@ -394,7 +379,7 @@ async fn leader_upload_taskprov() {
         },
     };
     let payload = taskprov_task_config.get_encoded_with_param(&TaskprovVersion::Draft02);
-    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload).unwrap();
+    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload);
     let extensions = vec![Extension::Taskprov { payload }];
     let report = t
         .task_config
@@ -420,7 +405,7 @@ async fn leader_upload_taskprov() {
     let payload = taskprov_task_config.get_encoded_with_param(&TaskprovVersion::Draft02);
     let mut bad_payload = payload.clone();
     bad_payload[0] = u8::wrapping_add(bad_payload[0], 1);
-    let task_id = compute_task_id(TaskprovVersion::Draft02, &bad_payload).unwrap();
+    let task_id = compute_task_id(TaskprovVersion::Draft02, &bad_payload);
     let extensions = vec![Extension::Taskprov { payload }];
     let report = t
         .task_config
@@ -447,7 +432,7 @@ async fn leader_upload_taskprov() {
 
     // Generate and upload a report with two copies of the taskprov extension
     let payload = taskprov_task_config.get_encoded_with_param(&TaskprovVersion::Draft02);
-    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload).unwrap();
+    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload);
     let extensions = vec![
         Extension::Taskprov {
             payload: payload.clone(),
@@ -501,7 +486,7 @@ async fn leader_upload_taskprov() {
         },
     };
     let payload = taskprov_task_config.get_encoded_with_param(&TaskprovVersion::Draft02);
-    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload).unwrap();
+    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload);
     let extensions = vec![Extension::Taskprov { payload }];
     let report = t
         .task_config
@@ -1384,7 +1369,7 @@ async fn leader_collect_taskprov_ok(version: DapVersion) {
         },
     };
     let payload = taskprov_task_config.get_encoded_with_param(&TaskprovVersion::Draft02);
-    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload).unwrap();
+    let task_id = compute_task_id(TaskprovVersion::Draft02, &payload);
     let task_config = DapTaskConfig::try_from_taskprov(
         version,
         TaskprovVersion::Draft02,
diff --git a/daphne_worker_test/tests/e2e/main.rs b/daphne_worker_test/tests/e2e/main.rs
index e53bd5dde..99ecc99bd 100644
--- a/daphne_worker_test/tests/e2e/main.rs
+++ b/daphne_worker_test/tests/e2e/main.rs
@@ -1,2 +1,4 @@
+#[cfg(feature = "test_e2e")]
 mod e2e;
+#[cfg(feature = "test_e2e")]
 mod test_runner;
diff --git a/daphne_worker_test/tests/e2e/test_runner.rs b/daphne_worker_test/tests/e2e/test_runner.rs
index 9d00e5255..212dc022a 100644
--- a/daphne_worker_test/tests/e2e/test_runner.rs
+++ b/daphne_worker_test/tests/e2e/test_runner.rs
@@ -119,8 +119,7 @@ impl TestRunner {
             min_batch_interval_start: 259200,
             max_batch_interval_end: 259200,
             supported_hpke_kems: vec![HpkeKemId::X25519HkdfSha256],
-            allow_taskprov: true,
-            taskprov_version: TaskprovVersion::Draft02,
+            taskprov_version: Some(TaskprovVersion::Draft02),
         };
         let taskprov_vdaf_verify_key_init =
             hex::decode("b029a72fa327931a5cb643dcadcaafa098fcbfac07d990cb9e7c9a8675fafb18")
diff --git a/daphne_worker_test/wrangler.toml b/daphne_worker_test/wrangler.toml
index 7fde97f44..ed17bc046 100644
--- a/daphne_worker_test/wrangler.toml
+++ b/daphne_worker_test/wrangler.toml
@@ -33,7 +33,6 @@ DAP_GLOBAL_CONFIG = """{
     "min_batch_interval_start": 259200,
     "max_batch_interval_end": 259200,
     "supported_hpke_kems": ["x25519_hkdf_sha256"],
-    "allow_taskprov": true,
     "taskprov_version": "v02"
 }"""
DAP_PROCESSED_ALARM_SAFETY_INTERVAL = "300" diff --git a/docker/wrangler.toml b/docker/wrangler.toml index 74da5e624..e1c4d99ad 100644 --- a/docker/wrangler.toml +++ b/docker/wrangler.toml @@ -33,7 +33,6 @@ DAP_GLOBAL_CONFIG = """{ "min_batch_interval_start": 259200, "max_batch_interval_end": 259200, "supported_hpke_kems": ["x25519_hkdf_sha256"], - "allow_taskprov": true, "taskprov_version": "v02" }""" DAP_PROCESSED_ALARM_SAFETY_INTERVAL = "300"
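
A minimal sketch of the deserialization behavior the config changes above rely on. The types here (`GlobalConfigSketch`, a trimmed-down `TaskprovVersion`) are hypothetical stand-ins, not the real definitions in the daphne crate: a `DAP_GLOBAL_CONFIG` JSON blob that omits `taskprov_version` yields `None` (taskprov disabled), while `"taskprov_version": "v02"` yields `Some(Draft02)`.

```rust
use serde::Deserialize;

// Hypothetical stand-ins for this sketch only; the real types live in the daphne crate.
#[derive(Debug, Deserialize, PartialEq)]
enum TaskprovVersion {
    #[serde(rename = "v02")]
    Draft02,
}

#[derive(Debug, Deserialize)]
struct GlobalConfigSketch {
    // Optional: leaving the field out of the JSON config disables taskprov.
    #[serde(default)]
    taskprov_version: Option<TaskprovVersion>,
}

fn main() {
    // Field present: taskprov is enabled and pinned to draft-02.
    let enabled: GlobalConfigSketch =
        serde_json::from_str(r#"{ "taskprov_version": "v02" }"#).unwrap();
    assert_eq!(enabled.taskprov_version, Some(TaskprovVersion::Draft02));

    // Field absent: equivalent to the old `"allow_taskprov": false`.
    let disabled: GlobalConfigSketch = serde_json::from_str("{}").unwrap();
    assert!(disabled.taskprov_version.is_none());
}
```

With the field folded into an `Option`, call sites replace the old boolean check with `global.taskprov_version.is_some()`, as in the config.rs and roles/mod.rs hunks above.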