Wait to mark reports as aggregated until just before committing
Loads that require a lot of CPU time can trigger the Workers runtime to
reset itself, causing a 500 error. This is most likely to happen while
processing an AggregationJobInitReq, since that involves the expensive
VDAF prep initialization step. Some reports may already have been marked
as aggregated before the reset; if the peer retries the request, those
reports will get rejected as replays.

This problem is common enough that we need to make this request
idempotent. As a first step, wait to mark the reports as aggregated
until just before we commit to aggregating them. In particular, instead
of doing this in `DapReportInitializer::initialize_reports()`, wait
until `DapAggregator::put_out_shares()`.
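A minimal, self-contained Rust model of the intended behavior (std-only; the
types and the `commit` function below are illustrative stand-ins, not the
Daphne API):

```rust
use std::collections::HashSet;

/// Illustrative stand-in for Daphne's report IDs.
type ReportId = [u8; 16];

/// Illustrative stand-in for an output share tied to a single report.
struct OutShare {
    report_id: ReportId,
}

/// Model of the new contract: reports are marked as aggregated only at
/// commit time, and any report that was already marked is skipped and
/// returned to the caller so it can be rejected as a replay.
fn commit(processed: &mut HashSet<ReportId>, out_shares: Vec<OutShare>) -> HashSet<ReportId> {
    let mut replayed = HashSet::new();
    for out_share in out_shares {
        if processed.insert(out_share.report_id) {
            // First time this report is seen: its share is committed
            // (the actual aggregation is elided in this model).
        } else {
            replayed.insert(out_share.report_id);
        }
    }
    replayed
}

fn main() {
    let mut processed = HashSet::new();
    let shares = || vec![OutShare { report_id: [1; 16] }, OutShare { report_id: [2; 16] }];

    // First attempt: everything commits, nothing is replayed.
    assert!(commit(&mut processed, shares()).is_empty());

    // A retried job (e.g. after a Workers runtime reset) replays the same
    // reports; they are rejected rather than double-counted.
    assert_eq!(commit(&mut processed, shares()).len(), 2);
}
```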
cjpatton committed Jul 21, 2023
1 parent a86b8a5 commit 01b1bf9
Showing 9 changed files with 199 additions and 115 deletions.
19 changes: 6 additions & 13 deletions daphne/src/lib.rs
@@ -310,20 +310,19 @@ impl DapTaskConfig {
self.quantized_time_lower_bound(time) + self.time_precision
}

/// Compute the "batch span" of a set of output shares and, for each buckent in the span,
/// aggregate the output shares into an aggregate share.
/// Compute the "batch span" of a set of output shares.
pub fn batch_span_for_out_shares<'a>(
&self,
part_batch_sel: &'a PartialBatchSelector,
out_shares: Vec<DapOutputShare>,
) -> Result<HashMap<DapBatchBucket<'a>, DapAggregateShare>, DapError> {
) -> Result<HashMap<DapBatchBucket<'a>, Vec<DapOutputShare>>, DapError> {
if !self.query.is_valid_part_batch_sel(part_batch_sel) {
return Err(fatal_error!(
err = "partial batch selector not compatible with task",
));
}

let mut span: HashMap<DapBatchBucket<'a>, DapAggregateShare> = HashMap::new();
let mut span: HashMap<DapBatchBucket<'a>, Vec<DapOutputShare>> = HashMap::new();
for out_share in out_shares.into_iter() {
let bucket = match part_batch_sel {
PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
@@ -334,14 +333,7 @@ impl DapTaskConfig {
}
};

let agg_share = span.entry(bucket).or_default();
agg_share.merge(DapAggregateShare {
report_count: 1,
min_time: out_share.time,
max_time: out_share.time,
checksum: out_share.checksum,
data: Some(out_share.data),
})?;
span.entry(bucket).or_default().push(out_share);
}

Ok(span)
@@ -515,7 +507,8 @@ impl DapHelperState
#[derive(Debug)]
/// An output share produced by an Aggregator for a single report.
pub struct DapOutputShare {
pub(crate) time: u64, // Value from the report
pub report_id: ReportId, // Value from report
pub time: u64, // Value from the report
pub(crate) checksum: [u8; 32],
pub(crate) data: VdafAggregateShare,
}
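With `batch_span_for_out_shares()` now returning the raw output shares per
bucket rather than pre-merged aggregate shares, callers fold the shares into
`DapAggregateShare`s themselves. A hedged sketch of that pattern (import paths
are assumptions; `DapAggregateShare::try_from_out_shares()` is taken from the
updated MockAggregator further down):

```rust
use std::collections::HashMap;

use daphne::{
    messages::PartialBatchSelector, DapAggregateShare, DapBatchBucket, DapError, DapOutputShare,
    DapTaskConfig,
};

/// Group output shares by batch bucket, then merge each bucket's shares into
/// a single aggregate-share delta.
fn agg_shares_for_span<'a>(
    task_config: &DapTaskConfig,
    part_batch_sel: &'a PartialBatchSelector,
    out_shares: Vec<DapOutputShare>,
) -> Result<HashMap<DapBatchBucket<'a>, DapAggregateShare>, DapError> {
    let mut agg_shares = HashMap::new();
    for (bucket, bucket_out_shares) in task_config
        .batch_span_for_out_shares(part_batch_sel, out_shares)?
        .into_iter()
    {
        let agg_share: &mut DapAggregateShare = agg_shares.entry(bucket).or_default();
        for out_share in bucket_out_shares.into_iter() {
            // This is the point where a real implementation would also consult
            // its replay state before committing the share.
            agg_share.merge(DapAggregateShare::try_from_out_shares([out_share])?)?;
        }
    }
    Ok(agg_shares)
}
```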
14 changes: 9 additions & 5 deletions daphne/src/roles/aggregator.rs
@@ -1,7 +1,7 @@
// Copyright (c) 2023 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use std::borrow::Cow;
use std::{borrow::Cow, collections::HashSet};

use async_trait::async_trait;
use prio::codec::Encode;
@@ -12,8 +12,8 @@ use crate::{
error::DapAbort,
hpke::{HpkeConfig, HpkeDecrypter},
messages::{
decode_base64url, BatchId, BatchSelector, HpkeConfigList, PartialBatchSelector, TaskId,
Time,
decode_base64url, BatchId, BatchSelector, HpkeConfigList, PartialBatchSelector, ReportId,
TaskId, Time,
},
metrics::{DaphneMetrics, DaphneRequestType},
vdaf::{EarlyReportStateConsumed, EarlyReportStateInitialized},
@@ -102,13 +102,17 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
/// (resp. Helper) in response to a CollectReq (resp. AggregateShareReq) for fixed-size tasks.
async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result<bool, DapError>;

/// Store a set of output shares.
/// Store a set of output shares and mark the corresponding reports as aggregated. Any reports
/// that were already aggregated are not committed.
///
/// TODO spec: Ensure the spec allows rejecting due to replay at this stage.
async fn put_out_shares(
&self,
task_id: &TaskId,
task_config: &DapTaskConfig,
part_batch_sel: &PartialBatchSelector,
out_shares: Vec<DapOutputShare>,
) -> Result<(), DapError>;
) -> Result<HashSet<ReportId>, DapError>;

/// Fetch the aggregate share for the given batch.
async fn get_agg_share(
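A hedged sketch of the caller side of the new contract (the helper name
`commit_and_count` is hypothetical, and the trait bounds and import paths are
assumptions): only the reports that do not come back in the replayed set
actually contributed to an aggregate share in this call.

```rust
use std::collections::HashSet;

use daphne::{
    messages::{PartialBatchSelector, ReportId, TaskId},
    roles::DapAggregator,
    DapError, DapOutputShare, DapTaskConfig,
};

/// Commit a batch of output shares and report how many were actually
/// aggregated, alongside the IDs that were rejected as replays.
async fn commit_and_count<S, A: DapAggregator<S>>(
    agg: &A,
    task_id: &TaskId,
    task_config: &DapTaskConfig,
    part_batch_sel: &PartialBatchSelector,
    out_shares: Vec<DapOutputShare>,
) -> Result<(u64, HashSet<ReportId>), DapError> {
    let total = u64::try_from(out_shares.len()).unwrap();
    let replayed = agg
        .put_out_shares(task_id, task_config, part_batch_sel, out_shares)
        .await?;
    // Replayed reports were skipped by the aggregator, so they do not count
    // toward the committed total.
    let committed = total - u64::try_from(replayed.len()).unwrap();
    Ok((committed, replayed))
}
```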
19 changes: 16 additions & 3 deletions daphne/src/roles/helper.rs
@@ -16,7 +16,8 @@ use crate::{
fatal_error,
messages::{
constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobContinueReq,
AggregationJobInitReq, PartialBatchSelector, ReportMetadata, TaskId,
AggregationJobInitReq, PartialBatchSelector, ReportMetadata, TaskId, TransitionFailure,
TransitionVar,
},
metrics::DaphneRequestType,
DapError, DapHelperState, DapHelperTransition, DapRequest, DapResource, DapResponse,
@@ -245,10 +246,22 @@ pub trait DapHelper<S>: DapAggregator<S> {
DapHelperTransition::Continue(..) => {
return Err(fatal_error!(err = "unexpected transition (continued)").into());
}
DapHelperTransition::Finish(out_shares, agg_job_resp) => {
DapHelperTransition::Finish(out_shares, mut agg_job_resp) => {
let out_shares_count = u64::try_from(out_shares.len()).unwrap();
self.put_out_shares(task_id, &part_batch_sel, out_shares)
let replayed = self
.put_out_shares(task_id, task_config, &part_batch_sel, out_shares)
.await?;

// If there are multiple aggregation jobs in flight that contain the
// same report, then we may need to reject the report at this late stage.
for transition in agg_job_resp.transitions.iter_mut() {
if replayed.contains(&transition.report_id) {
let failure = TransitionFailure::ReportReplayed;
transition.var = TransitionVar::Failed(failure);
metrics.report_inc_by(&format!("rejected_{failure}"), 1);
}
}

(agg_job_resp, out_shares_count)
}
};
8 changes: 7 additions & 1 deletion daphne/src/roles/leader.rs
@@ -392,7 +392,13 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
.vdaf
.handle_final_agg_job_resp(uncommited, agg_job_resp, &metrics)?;
let out_shares_count = out_shares.len() as u64;
self.put_out_shares(task_id, part_batch_sel, out_shares)

// At this point we're committed to aggregating the reports: if we do detect a report was
// replayed at this stage, then we may end up with a batch mismatch. However, this should
// only happen if there are multiple aggregation jobs in-flight that include the same
// report.
let _ = self
.put_out_shares(task_id, task_config, part_batch_sel, out_shares)
.await?;

metrics.report_inc_by("aggregated", out_shares_count);
46 changes: 27 additions & 19 deletions daphne/src/testing.rs
@@ -384,14 +384,6 @@ impl DapReportInitializer for MockAggregator {
{
early_fails.insert(metadata.id.clone(), transition_failure);
};

// Mark report processed.
let mut guard = self
.report_store
.lock()
.expect("report_store: failed to lock");
let report_store = guard.entry(task_id.clone()).or_default();
report_store.processed.insert(metadata.id.clone());
}
}

@@ -520,25 +512,41 @@ impl DapAggregator<BearerToken> for MockAggregator {
async fn put_out_shares(
&self,
task_id: &TaskId,
task_config: &DapTaskConfig,
part_batch_sel: &PartialBatchSelector,
out_shares: Vec<DapOutputShare>,
) -> Result<(), DapError> {
let task_config = self
.get_task_config_for(Cow::Borrowed(task_id))
.await?
.ok_or_else(|| fatal_error!(err = "task not found"))?;
) -> Result<HashSet<ReportId>, DapError> {
let mut report_store_guard = self
.report_store
.lock()
.expect("report_store: failed to lock");
let report_store = report_store_guard.entry(task_id.clone()).or_default();
let mut agg_store_guard = self.agg_store.lock().expect("agg_store: failed to lock");
let agg_store = agg_store_guard.entry(task_id.clone()).or_default();

let mut guard = self.agg_store.lock().expect("agg_store: failed to lock");
let agg_store = guard.entry(task_id.clone()).or_default();
for (bucket, agg_share_delta) in task_config
let mut replayed = HashSet::new();
for (bucket, out_shares) in task_config
.batch_span_for_out_shares(part_batch_sel, out_shares)?
.into_iter()
{
let inner_agg_store = agg_store.entry(bucket.to_owned_bucket()).or_default();
inner_agg_store.agg_share.merge(agg_share_delta)?;
for out_share in out_shares.into_iter() {
if !report_store.processed.contains(&out_share.report_id) {
// Mark report processed.
report_store.processed.insert(out_share.report_id.clone());

// Add to aggregate share.
agg_store
.entry(bucket.to_owned_bucket())
.or_default()
.agg_share
.merge(DapAggregateShare::try_from_out_shares([out_share])?)?;
} else {
replayed.insert(out_share.report_id);
}
}
}

Ok(())
Ok(replayed)
}

async fn get_agg_share(
2 changes: 2 additions & 0 deletions daphne/src/vdaf/mod.rs
@@ -913,6 +913,7 @@ impl VdafConfig {

states.push((
DapOutputShare {
report_id: leader_report_id.clone(),
time: leader_time,
checksum: checksum.as_ref().try_into().unwrap(),
data,
@@ -1063,6 +1064,7 @@ impl VdafConfig {
);

out_shares.push(DapOutputShare {
report_id: helper_report_id.clone(),
time: helper_time,
checksum: checksum.as_ref().try_into().unwrap(),
data,
15 changes: 6 additions & 9 deletions daphne_worker/src/config.rs
@@ -26,7 +26,7 @@ use daphne::{
fatal_error,
hpke::{HpkeConfig, HpkeReceiverConfig},
messages::{
decode_base64url_vec, AggregationJobId, BatchId, CollectionJobId, ReportMetadata, TaskId,
decode_base64url_vec, AggregationJobId, BatchId, CollectionJobId, ReportId, TaskId, Time,
},
DapError, DapGlobalConfig, DapQueryConfig, DapRequest, DapResource, DapResponse, DapTaskConfig,
DapVersion, Prio3Config, VdafConfig,
@@ -345,17 +345,14 @@ impl DaphneWorkerConfig {
&self,
task_config: &DapTaskConfig,
task_id_hex: &str,
metadata: &ReportMetadata,
report_id: &ReportId,
report_time: Time,
) -> String {
let mut shard_seed = [0; 8];
PrgSha3::seed_stream(
&self.report_shard_key,
b"report shard",
metadata.id.as_ref(),
)
.fill(&mut shard_seed);
PrgSha3::seed_stream(&self.report_shard_key, b"report shard", report_id.as_ref())
.fill(&mut shard_seed);
let shard = u64::from_be_bytes(shard_seed) % self.report_shard_count;
let epoch = metadata.time - (metadata.time % self.global.report_storage_epoch_duration);
let epoch = report_time - (report_time % self.global.report_storage_epoch_duration);
durable_name_report_store(&task_config.version, task_id_hex, epoch, shard)
}
}
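For reference, the sharding scheme this helper implements: reduce a
pseudorandom 8-byte seed (derived from the report ID) modulo the shard count,
and truncate the report time down to a storage epoch. A self-contained sketch
of just the arithmetic; the function name and constants are illustrative, not
Daphne's defaults:

```rust
/// Sketch of the shard/epoch arithmetic used to name a report-store
/// instance. `shard_seed` stands in for the PrgSha3-derived bytes.
fn report_store_coords(shard_seed: [u8; 8], report_time: u64) -> (u64, u64) {
    const REPORT_SHARD_COUNT: u64 = 64; // illustrative
    const REPORT_STORAGE_EPOCH_DURATION: u64 = 604_800; // one week, in seconds (illustrative)

    let shard = u64::from_be_bytes(shard_seed) % REPORT_SHARD_COUNT;
    let epoch = report_time - (report_time % REPORT_STORAGE_EPOCH_DURATION);
    (shard, epoch)
}

fn main() {
    // The same report always maps to the same (shard, epoch) pair, so a
    // retried request touches the same storage instance.
    let (shard, epoch) = report_store_coords([7; 8], 1_690_000_000);
    println!("shard = {shard}, epoch = {epoch}");
}
```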