diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index f54d339c..a12f9739 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -18,18 +18,14 @@ pub mod load_testing; -use crate::{ - deduce_dap_version_from_url, response_to_anyhow, test_durations::TestDurations, HttpClient, -}; +use crate::{deduce_dap_version_from_url, functions, test_durations::TestDurations, HttpClient}; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use daphne::{ - constants::DapMediaType, - error::aborts::ProblemDetails, hpke::{HpkeConfig, HpkeKemId, HpkeReceiverConfig}, messages::{ - self, AggregateShareReq, AggregationJobId, AggregationJobResp, Base64Encode, BatchId, - BatchSelector, PartialBatchSelector, TaskId, + self, AggregateShareReq, AggregationJobId, Base64Encode, BatchId, BatchSelector, + PartialBatchSelector, TaskId, }, metrics::DaphneMetrics, roles::DapReportInitializer, @@ -39,9 +35,8 @@ use daphne::{ DapQueryConfig, DapTaskConfig, DapTaskParameters, DapVersion, EarlyReportStateConsumed, EarlyReportStateInitialized, ReplayProtection, }; -use daphne_service_utils::{bearer_token::BearerToken, http_headers}; +use daphne_service_utils::bearer_token::BearerToken; use futures::{future::OptionFuture, StreamExt, TryStreamExt}; -use prio::codec::{Decode, ParameterizedEncode}; use prometheus::{Encoder, HistogramVec, IntCounterVec, IntGaugeVec, TextEncoder}; use rand::{rngs, Rng}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; @@ -509,62 +504,30 @@ impl Test { .context("producing agg job init request")?; // Send AggregationJobInitReq. - let headers = construct_request_headers( - DapMediaType::AggregationJobInitReq.as_str_for_version(task_config.version), - taskprov_advertisement.as_deref(), - &self.bearer_token, - ) - .context("constructing request headers for AggregationJobInitReq")?; - let url = self.helper_url.join(&format!( - "tasks/{}/aggregation_jobs/{}", - task_id.to_base64url(), - agg_job_id.to_base64url() - ))?; - // wait for all agg jobs to be ready to fire. info!("Reports generated, waiting for other tasks..."); let _guard = load_control.wait().await; info!("Starting AggregationJobInitReq"); let start = Instant::now(); - let resp = send( - self.http_client - .put(url) - .body( - agg_job_init_req - .get_encoded_with_param(&task_config.version) - .unwrap(), - ) - .headers(headers), - ) - .await?; + let agg_job_resp = self + .http_client + .submit_aggregation_job_init_req( + self.helper_url.join(&format!( + "tasks/{}/aggregation_jobs/{}", + task_id.to_base64url(), + agg_job_id.to_base64url() + ))?, + agg_job_init_req, + task_config.version, + functions::helper::Options { + taskprov_advertisement: taskprov_advertisement.as_deref(), + bearer_token: self.bearer_token.as_ref(), + }, + ) + .await?; let duration = start.elapsed(); info!("Finished AggregationJobInitReq in {duration:#?}"); - if resp.status() == 400 { - let text = resp.text().await?; - let problem_details: ProblemDetails = - serde_json::from_str(&text).with_context(|| { - format!("400 Bad Request: failed to parse problem details document: {text:?}") - })?; - return Err(anyhow!("400 Bad Request: {problem_details:?}")); - } else if resp.status() == 500 { - return Err(anyhow::anyhow!( - "500 Internal Server Error: {}", - resp.text().await? - )); - } else if !resp.status().is_success() { - return Err(response_to_anyhow(resp).await) - .context("while running an AggregateInitReq"); - } - - // Handle AggregationJobResp.. - let agg_job_resp = AggregationJobResp::get_decoded( - &resp - .bytes() - .await - .context("transfering bytes from the AggregateInitReq")?, - ) - .with_context(|| "failed to parse response to AggregateInitReq from Helper")?; let agg_share_span = task_config.consume_agg_job_resp( task_id, agg_job_state, @@ -603,43 +566,21 @@ impl Test { // Send AggregateShareReq. info!("Starting AggregationJobInitReq"); let start = Instant::now(); - let headers = construct_request_headers( - DapMediaType::AggregateShareReq.as_str_for_version(version), - taskprov_advertisement, - &self.bearer_token, - )?; - let url = self.helper_url.join(&format!( - "tasks/{}/aggregate_shares", - task_id.to_base64url() - ))?; - let resp = send( - self.http_client - .post(url) - .body(agg_share_req.get_encoded_with_param(&version).unwrap()) - .headers(headers), - ) - .await?; - let duration = start.elapsed(); - info!("Finished AggregateShareReq in {duration:#?}"); - if resp.status() == 400 { - let problem_details: ProblemDetails = serde_json::from_slice( - &resp - .bytes() - .await - .context("transfering bytes for AggregateShareReq")?, + self.http_client + .get_aggregate_share( + self.helper_url.join(&format!( + "tasks/{}/aggregate_shares", + task_id.to_base64url() + ))?, + agg_share_req, + version, + functions::helper::Options { + taskprov_advertisement, + bearer_token: self.bearer_token.as_ref(), + }, ) - .with_context(|| "400 Bad Request: failed to parse problem details document")?; - return Err(anyhow!("400 Bad Request: {problem_details:?}")); - } else if resp.status() == 500 { - return Err(anyhow::anyhow!( - "500 Internal Server Error: {}", - resp.text().await? - )); - } else if !resp.status().is_success() { - return Err(response_to_anyhow(resp).await) - .context("while running an AggregateInitReq"); - } - Ok(duration) + .await?; + Ok(start.elapsed()) } pub async fn test_helper(&self, opt: &TestOptions) -> Result { @@ -794,60 +735,6 @@ impl DapReportInitializer for Test { } } -fn construct_request_headers<'a, M, T, B>( - media_type: M, - taskprov: T, - bearer_token: B, -) -> Result -where - M: Into>, - T: Into>, - B: Into>, -{ - let mut headers = reqwest::header::HeaderMap::new(); - if let Some(media_type) = media_type.into() { - headers.insert( - reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_str(media_type)?, - ); - } - if let Some(taskprov) = taskprov.into() { - headers.insert( - reqwest::header::HeaderName::from_static(http_headers::DAP_TASKPROV), - reqwest::header::HeaderValue::from_str(taskprov)?, - ); - } - if let Some(token) = bearer_token.into() { - headers.insert( - reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN), - reqwest::header::HeaderValue::from_str(token.as_ref())?, - ); - } - Ok(headers) -} - -async fn send(req: reqwest::RequestBuilder) -> reqwest::Result { - for i in 0..4 { - let resp = req.try_clone().unwrap().send().await; - match &resp { - Ok(r) if r.status() != reqwest::StatusCode::BAD_GATEWAY => { - return resp; - } - Ok(r) if r.status().is_client_error() => { - return resp; - } - Ok(_) => {} - Err(e) => { - tracing::error!("request failed: {e:?}"); - } - } - if i == 3 { - return resp; - } - } - unreachable!() -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Now(u64); pub fn now() -> Now { diff --git a/crates/dapf/src/cli_parsers.rs b/crates/dapf/src/cli_parsers.rs index 62550fa1..f364545f 100644 --- a/crates/dapf/src/cli_parsers.rs +++ b/crates/dapf/src/cli_parsers.rs @@ -1,7 +1,8 @@ -//! Human friendly parsers for common types of parameters to DAP functions. // Copyright (c) 2024 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause +//! Human friendly parsers for common types of parameters to DAP functions. + use std::{ fmt, io::{self, IsTerminal as _}, @@ -51,7 +52,7 @@ impl ValueEnum for DefaultVdafConfigs { } impl DefaultVdafConfigs { - fn into_vdaf(self) -> VdafConfig { + fn into_vdaf_config(self) -> VdafConfig { match self { Self::Prio2Dimension99k => VdafConfig::Prio2 { dimension: 99_992 }, Self::Prio3NumProofs2 => { @@ -115,9 +116,9 @@ impl FromStr for CliVdafConfig { } impl CliVdafConfig { - pub fn into_vdaf(self) -> VdafConfig { + pub fn into_vdaf_config(self) -> VdafConfig { match self { - Self::Default(d) => d.into_vdaf(), + Self::Default(d) => d.into_vdaf_config(), Self::Custom(v) => v, } } diff --git a/crates/dapf/src/functions/helper.rs b/crates/dapf/src/functions/helper.rs new file mode 100644 index 00000000..d3228574 --- /dev/null +++ b/crates/dapf/src/functions/helper.rs @@ -0,0 +1,137 @@ +// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use anyhow::{anyhow, Context as _}; +use daphne::{ + constants::DapMediaType, + error::aborts::ProblemDetails, + messages::{AggregateShareReq, AggregationJobInitReq, AggregationJobResp}, + DapVersion, +}; +use daphne_service_utils::{bearer_token::BearerToken, http_headers}; +use prio::codec::{Decode as _, ParameterizedEncode as _}; +use reqwest::header; +use url::Url; + +use crate::{response_to_anyhow, HttpClient}; + +impl HttpClient { + pub async fn submit_aggregation_job_init_req( + &self, + url: Url, + agg_job_init_req: AggregationJobInitReq, + version: DapVersion, + opts: Options<'_>, + ) -> anyhow::Result { + let resp = self + .put(url) + .body(agg_job_init_req.get_encoded_with_param(&version).unwrap()) + .headers(construct_request_headers( + DapMediaType::AggregationJobInitReq + .as_str_for_version(version) + .with_context(|| { + format!("AggregationJobInitReq media type is not defined for {version}") + })?, + opts, + )?) + .send() + .await + .context("sending AggregationJobInitReq")?; + if resp.status() == 400 { + let text = resp.text().await?; + let problem_details: ProblemDetails = + serde_json::from_str(&text).with_context(|| { + format!("400 Bad Request: failed to parse problem details document: {text:?}") + })?; + Err(anyhow!("400 Bad Request: {problem_details:?}")) + } else if resp.status() == 500 { + Err(anyhow::anyhow!( + "500 Internal Server Error: {}", + resp.text().await? + )) + } else if !resp.status().is_success() { + Err(response_to_anyhow(resp).await).context("while running an AggregationJobInitReq") + } else { + AggregationJobResp::get_decoded( + &resp + .bytes() + .await + .context("transfering bytes from the AggregateInitReq")?, + ) + .with_context(|| "failed to parse response to AggregateInitReq from Helper") + } + } + + pub async fn get_aggregate_share( + &self, + url: Url, + agg_share_req: AggregateShareReq, + version: DapVersion, + opts: Options<'_>, + ) -> anyhow::Result<()> { + let resp = self + .post(url) + .body(agg_share_req.get_encoded_with_param(&version).unwrap()) + .headers(construct_request_headers( + DapMediaType::AggregateShareReq + .as_str_for_version(version) + .with_context(|| { + format!("AggregateShareReq media type is not defined for {version}") + })?, + opts, + )?) + .send() + .await + .context("sending AggregateShareReq")?; + if resp.status() == 400 { + let problem_details: ProblemDetails = serde_json::from_slice( + &resp + .bytes() + .await + .context("transfering bytes for AggregateShareReq")?, + ) + .with_context(|| "400 Bad Request: failed to parse problem details document")?; + Err(anyhow!("400 Bad Request: {problem_details:?}")) + } else if resp.status() == 500 { + Err(anyhow!("500 Internal Server Error: {}", resp.text().await?)) + } else if !resp.status().is_success() { + Err(response_to_anyhow(resp).await).context("while running an AggregateShareReq") + } else { + Ok(()) + } + } +} + +#[derive(Default, Debug)] +pub struct Options<'s> { + pub taskprov_advertisement: Option<&'s str>, + pub bearer_token: Option<&'s BearerToken>, +} + +fn construct_request_headers( + media_type: &str, + options: Options<'_>, +) -> Result { + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_str(media_type)?, + ); + let Options { + taskprov_advertisement, + bearer_token, + } = options; + if let Some(taskprov) = taskprov_advertisement { + headers.insert( + const { header::HeaderName::from_static(http_headers::DAP_TASKPROV) }, + header::HeaderValue::from_str(taskprov)?, + ); + } + if let Some(token) = bearer_token { + headers.insert( + const { header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN) }, + header::HeaderValue::from_str(token.as_str())?, + ); + } + Ok(headers) +} diff --git a/crates/dapf/src/functions/mod.rs b/crates/dapf/src/functions/mod.rs index 9de7056a..a8caa120 100644 --- a/crates/dapf/src/functions/mod.rs +++ b/crates/dapf/src/functions/mod.rs @@ -1,6 +1,8 @@ -//! The various DAP functions dapf can perform. // Copyright (c) 2024 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause +//! The various DAP functions dapf can perform. + +pub mod helper; pub mod hpke; pub mod test_routes; diff --git a/crates/dapf/src/main.rs b/crates/dapf/src/main.rs index d81f7f1c..a81c22e6 100644 --- a/crates/dapf/src/main.rs +++ b/crates/dapf/src/main.rs @@ -418,7 +418,7 @@ async fn handle_leader_actions( let version = deduce_dap_version_from_url(&leader_url)?; // Generate a report for the measurement. let report = vdaf_config - .into_vdaf() + .into_vdaf_config() .produce_report( &[leader_hpke_config, helper_hpke_config], now, @@ -543,15 +543,17 @@ async fn handle_leader_actions( })?; let version = deduce_dap_version_from_url(&uri)?; let collect_resp = Collection::get_decoded_with_param(&version, &resp.bytes().await?)?; - let agg_res = vdaf_config.into_vdaf().consume_encrypted_agg_shares( - receiver, - &task_id.into(), - &batch_selector, - collect_resp.report_count, - &DapAggregationParam::Empty, - collect_resp.encrypted_agg_shares.to_vec(), - version, - )?; + let agg_res = vdaf_config + .into_vdaf_config() + .consume_encrypted_agg_shares( + receiver, + &task_id.into(), + &batch_selector, + collect_resp.report_count, + &DapAggregationParam::Empty, + collect_resp.encrypted_agg_shares.to_vec(), + version, + )?; print!("{}", serde_json::to_string(&agg_res)?); Ok(()) @@ -578,7 +580,7 @@ async fn handle_helper_actions( let t = dapf::acceptance::Test::from_env( helper_url, - vdaf_config.into_vdaf(), + vdaf_config.into_vdaf_config(), hpke_signing_certificate_path, http_client, load_control, @@ -620,7 +622,7 @@ async fn handle_helper_actions( } => { load_testing::execute_single_combination_from_env( helper_url, - vdaf_config.into_vdaf(), + vdaf_config.into_vdaf_config(), reports_per_batch, reports_per_agg_job, http_client, @@ -832,7 +834,7 @@ async fn handle_decode_actions(action: DecodeAction) -> anyhow::Result<()> { } }; let agg_shares = vdaf_config - .into_vdaf() + .into_vdaf_config() .consume_encrypted_agg_shares( &hpke_config, &task_id.into(), @@ -900,7 +902,7 @@ async fn handle_test_routes(action: TestAction, http_client: HttpClient) -> anyh expires_in_seconds: task_expiration, } => { let vdaf = use_or_request_from_user_or_default(vdaf, CliVdafConfig::default, "vdaf")? - .into_vdaf(); + .into_vdaf_config(); let vdaf_verify_key = encode_base64url(vdaf.gen_verify_key()); let CliDapQueryConfig(query) = use_or_request_from_user_or_default( query,