Wait to mark reports as aggregated until just before committing #374

Merged · 1 commit · Jul 21, 2023
19 changes: 6 additions & 13 deletions daphne/src/lib.rs
@@ -310,20 +310,19 @@ impl DapTaskConfig {
         self.quantized_time_lower_bound(time) + self.time_precision
     }
 
-    /// Compute the "batch span" of a set of output shares and, for each bucket in the span,
-    /// aggregate the output shares into an aggregate share.
+    /// Compute the "batch span" of a set of output shares.
     pub fn batch_span_for_out_shares<'a>(
         &self,
         part_batch_sel: &'a PartialBatchSelector,
         out_shares: Vec<DapOutputShare>,
-    ) -> Result<HashMap<DapBatchBucket<'a>, DapAggregateShare>, DapError> {
+    ) -> Result<HashMap<DapBatchBucket<'a>, Vec<DapOutputShare>>, DapError> {
         if !self.query.is_valid_part_batch_sel(part_batch_sel) {
             return Err(fatal_error!(
                 err = "partial batch selector not compatible with task",
             ));
         }
 
-        let mut span: HashMap<DapBatchBucket<'a>, DapAggregateShare> = HashMap::new();
+        let mut span: HashMap<DapBatchBucket<'a>, Vec<DapOutputShare>> = HashMap::new();
         for out_share in out_shares.into_iter() {
             let bucket = match part_batch_sel {
                 PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
@@ -334,14 +333,7 @@ impl DapTaskConfig {
                 }
             };
 
-            let agg_share = span.entry(bucket).or_default();
-            agg_share.merge(DapAggregateShare {
-                report_count: 1,
-                min_time: out_share.time,
-                max_time: out_share.time,
-                checksum: out_share.checksum,
-                data: Some(out_share.data),
-            })?;

@cjpatton (Contributor, Author) commented on Jul 21, 2023:

    We can't aggregate the output shares yet because we still may need to reject in case of replay.

+            span.entry(bucket).or_default().push(out_share);
         }
 
         Ok(span)
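
To illustrate the comment above: `batch_span_for_out_shares` now only groups output shares by batch bucket, and aggregation is deferred until after replay checks. A minimal, self-contained sketch of this group-then-aggregate pattern, using simplified stand-in types rather than the daphne API:

```rust
use std::collections::HashMap;

// Stand-ins for DapBatchBucket and DapOutputShare, simplified for illustration.
type Bucket = u64;

struct OutShare {
    bucket: Bucket,
    value: u64,
}

/// Step 1: group output shares per bucket without aggregating them yet,
/// mirroring the new return type `HashMap<DapBatchBucket, Vec<DapOutputShare>>`.
fn batch_span(out_shares: Vec<OutShare>) -> HashMap<Bucket, Vec<OutShare>> {
    let mut span: HashMap<Bucket, Vec<OutShare>> = HashMap::new();
    for out_share in out_shares {
        span.entry(out_share.bucket).or_default().push(out_share);
    }
    span
}

fn main() {
    let span = batch_span(vec![
        OutShare { bucket: 0, value: 1 },
        OutShare { bucket: 0, value: 2 },
        OutShare { bucket: 3, value: 5 },
    ]);

    // Step 2 happens later, after replayed reports have been filtered out:
    // aggregate each bucket's shares.
    for (bucket, shares) in &span {
        let total: u64 = shares.iter().map(|s| s.value).sum();
        println!("bucket {bucket}: {} reports, aggregate {total}", shares.len());
    }
}
```
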
@@ -515,7 +507,8 @@ impl DapHelperState
 #[derive(Debug)]
 /// An output share produced by an Aggregator for a single report.
 pub struct DapOutputShare {
-    pub(crate) time: u64, // Value from the report
+    pub report_id: ReportId, // Value from report
+    pub time: u64, // Value from the report
     pub(crate) checksum: [u8; 32],
     pub(crate) data: VdafAggregateShare,
 }
14 changes: 9 additions & 5 deletions daphne/src/roles/aggregator.rs
@@ -1,7 +1,7 @@
 // Copyright (c) 2023 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use std::borrow::Cow;
+use std::{borrow::Cow, collections::HashSet};
 
 use async_trait::async_trait;
 use prio::codec::Encode;
@@ -12,8 +12,8 @@ use crate::{
     error::DapAbort,
     hpke::{HpkeConfig, HpkeDecrypter},
     messages::{
-        decode_base64url, BatchId, BatchSelector, HpkeConfigList, PartialBatchSelector, TaskId,
-        Time,
+        decode_base64url, BatchId, BatchSelector, HpkeConfigList, PartialBatchSelector, ReportId,
+        TaskId, Time,
     },
     metrics::{DaphneMetrics, DaphneRequestType},
     vdaf::{EarlyReportStateConsumed, EarlyReportStateInitialized},
@@ -102,13 +102,17 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
     /// (resp. Helper) in response to a CollectReq (resp. AggregateShareReq) for fixed-size tasks.
     async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result<bool, DapError>;
 
-    /// Store a set of output shares.
+    /// Store a set of output shares and mark the corresponding reports as aggregated. Any reports
+    /// that were already aggregated are not committed.
+    ///
+    /// TODO spec: Ensure the spec allows rejecting due to replay at this stage.
     async fn put_out_shares(
         &self,
         task_id: &TaskId,
+        task_config: &DapTaskConfig,
         part_batch_sel: &PartialBatchSelector,
         out_shares: Vec<DapOutputShare>,
-    ) -> Result<(), DapError>;
+    ) -> Result<HashSet<ReportId>, DapError>;
 
     /// Fetch the aggregate share for the given batch.
     async fn get_agg_share(
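
To make the new `put_out_shares` contract concrete: the aggregator commits shares for reports it has not seen before and returns the IDs of reports that were already aggregated, so the caller can reject them as replays. A minimal, self-contained sketch of that idea with stand-in types (not the daphne API; the real method is async, fallible, and also merges shares into per-bucket aggregates):

```rust
use std::collections::HashSet;

type ReportId = u64; // stand-in for daphne's ReportId

struct OutShare {
    report_id: ReportId,
}

/// Commit unseen reports and return the replayed ones, rather than failing the
/// whole request when a duplicate is found.
fn put_out_shares(
    processed: &mut HashSet<ReportId>,
    out_shares: Vec<OutShare>,
) -> HashSet<ReportId> {
    let mut replayed = HashSet::new();
    for out_share in out_shares {
        if processed.insert(out_share.report_id) {
            // New report: the real implementation merges its output share into
            // the aggregate store at this point.
        } else {
            // Seen before, possibly by a concurrent aggregation job: a replay.
            replayed.insert(out_share.report_id);
        }
    }
    replayed
}

fn main() {
    let mut processed = HashSet::from([7]); // report 7 was aggregated by an earlier job
    let replayed = put_out_shares(
        &mut processed,
        vec![OutShare { report_id: 7 }, OutShare { report_id: 8 }],
    );
    assert_eq!(replayed, HashSet::from([7])); // only the duplicate is reported back
}
```
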
21 changes: 18 additions & 3 deletions daphne/src/roles/helper.rs
@@ -16,7 +16,8 @@ use crate::{
     fatal_error,
     messages::{
         constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobContinueReq,
-        AggregationJobInitReq, PartialBatchSelector, ReportMetadata, TaskId,
+        AggregationJobInitReq, PartialBatchSelector, ReportMetadata, TaskId, TransitionFailure,
+        TransitionVar,
     },
     metrics::DaphneRequestType,
     DapError, DapHelperState, DapHelperTransition, DapRequest, DapResource, DapResponse,
@@ -245,10 +246,24 @@ pub trait DapHelper<S>: DapAggregator<S> {
             DapHelperTransition::Continue(..) => {
                 return Err(fatal_error!(err = "unexpected transition (continued)").into());
             }
-            DapHelperTransition::Finish(out_shares, agg_job_resp) => {
+            DapHelperTransition::Finish(out_shares, mut agg_job_resp) => {
                 let out_shares_count = u64::try_from(out_shares.len()).unwrap();
-                self.put_out_shares(task_id, &part_batch_sel, out_shares)
+                let replayed = self
+                    .put_out_shares(task_id, task_config, &part_batch_sel, out_shares)
                     .await?;
+
+                // If there are multiple aggregation jobs in flight that contain the
+                // same report, then we may need to reject the report at this late stage.
+                if !replayed.is_empty() {
+                    for transition in agg_job_resp.transitions.iter_mut() {
+                        if replayed.contains(&transition.report_id) {
+                            let failure = TransitionFailure::ReportReplayed;
+                            transition.var = TransitionVar::Failed(failure);
+                            metrics.report_inc_by(&format!("rejected_{failure}"), 1);
+                        }
+                    }
+                }
+
                 (agg_job_resp, out_shares_count)
             }
         };
8 changes: 7 additions & 1 deletion daphne/src/roles/leader.rs
@@ -392,7 +392,13 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
             .vdaf
             .handle_final_agg_job_resp(uncommited, agg_job_resp, &metrics)?;
         let out_shares_count = out_shares.len() as u64;
-        self.put_out_shares(task_id, part_batch_sel, out_shares)
+
+        // At this point we're committed to aggregating the reports: if we do detect a report was
+        // replayed at this stage, then we may end up with a batch mismatch. However, this should
+        // only happen if there are multiple aggregation jobs in-flight that include the same
+        // report.
+        let _ = self
+            .put_out_shares(task_id, task_config, part_batch_sel, out_shares)
             .await?;
 
         metrics.report_inc_by("aggregated", out_shares_count);
46 changes: 27 additions & 19 deletions daphne/src/testing.rs
@@ -384,14 +384,6 @@ impl DapReportInitializer for MockAggregator {
             {
                 early_fails.insert(metadata.id.clone(), transition_failure);
             };
-
-            // Mark report processed.
-            let mut guard = self
-                .report_store
-                .lock()
-                .expect("report_store: failed to lock");
-            let report_store = guard.entry(task_id.clone()).or_default();
-            report_store.processed.insert(metadata.id.clone());
         }
     }
 
@@ -520,25 +512,41 @@ impl DapAggregator<BearerToken> for MockAggregator {
     async fn put_out_shares(
         &self,
         task_id: &TaskId,
+        task_config: &DapTaskConfig,
         part_batch_sel: &PartialBatchSelector,
         out_shares: Vec<DapOutputShare>,
-    ) -> Result<(), DapError> {
-        let task_config = self
-            .get_task_config_for(Cow::Borrowed(task_id))
-            .await?
-            .ok_or_else(|| fatal_error!(err = "task not found"))?;
+    ) -> Result<HashSet<ReportId>, DapError> {
+        let mut report_store_guard = self
+            .report_store
+            .lock()
+            .expect("report_store: failed to lock");
+        let report_store = report_store_guard.entry(task_id.clone()).or_default();
+        let mut agg_store_guard = self.agg_store.lock().expect("agg_store: failed to lock");
+        let agg_store = agg_store_guard.entry(task_id.clone()).or_default();
 
-        let mut guard = self.agg_store.lock().expect("agg_store: failed to lock");
-        let agg_store = guard.entry(task_id.clone()).or_default();
-        for (bucket, agg_share_delta) in task_config
+        let mut replayed = HashSet::new();
+        for (bucket, out_shares) in task_config
             .batch_span_for_out_shares(part_batch_sel, out_shares)?
             .into_iter()
         {
-            let inner_agg_store = agg_store.entry(bucket.to_owned_bucket()).or_default();
-            inner_agg_store.agg_share.merge(agg_share_delta)?;
+            for out_share in out_shares.into_iter() {
+                if !report_store.processed.contains(&out_share.report_id) {
+                    // Mark report processed.
+                    report_store.processed.insert(out_share.report_id.clone());
+
+                    // Add to aggregate share.
+                    agg_store
+                        .entry(bucket.to_owned_bucket())
+                        .or_default()
+                        .agg_share
+                        .merge(DapAggregateShare::try_from_out_shares([out_share])?)?;
+                } else {
+                    replayed.insert(out_share.report_id);
+                }
+            }
         }
 
-        Ok(())
+        Ok(replayed)
     }
 
     async fn get_agg_share(
2 changes: 2 additions & 0 deletions daphne/src/vdaf/mod.rs
@@ -913,6 +913,7 @@ impl VdafConfig {
 
             states.push((
                 DapOutputShare {
+                    report_id: leader_report_id.clone(),
                     time: leader_time,
                     checksum: checksum.as_ref().try_into().unwrap(),
                     data,
@@ -1063,6 +1064,7 @@
             );
 
             out_shares.push(DapOutputShare {
+                report_id: helper_report_id.clone(),
                 time: helper_time,
                 checksum: checksum.as_ref().try_into().unwrap(),
                 data,
15 changes: 6 additions & 9 deletions daphne_worker/src/config.rs
@@ -26,7 +26,7 @@ use daphne::{
     fatal_error,
     hpke::{HpkeConfig, HpkeReceiverConfig},
     messages::{
-        decode_base64url_vec, AggregationJobId, BatchId, CollectionJobId, ReportMetadata, TaskId,
+        decode_base64url_vec, AggregationJobId, BatchId, CollectionJobId, ReportId, TaskId, Time,
     },
     DapError, DapGlobalConfig, DapQueryConfig, DapRequest, DapResource, DapResponse, DapTaskConfig,
     DapVersion, Prio3Config, VdafConfig,
@@ -345,17 +345,14 @@ impl DaphneWorkerConfig {
         &self,
         task_config: &DapTaskConfig,
         task_id_hex: &str,
-        metadata: &ReportMetadata,
+        report_id: &ReportId,
+        report_time: Time,
     ) -> String {
         let mut shard_seed = [0; 8];
-        PrgSha3::seed_stream(
-            &self.report_shard_key,
-            b"report shard",
-            metadata.id.as_ref(),
-        )
-        .fill(&mut shard_seed);
+        PrgSha3::seed_stream(&self.report_shard_key, b"report shard", report_id.as_ref())
+            .fill(&mut shard_seed);
         let shard = u64::from_be_bytes(shard_seed) % self.report_shard_count;
-        let epoch = metadata.time - (metadata.time % self.global.report_storage_epoch_duration);
+        let epoch = report_time - (report_time % self.global.report_storage_epoch_duration);
         durable_name_report_store(&task_config.version, task_id_hex, epoch, shard)
     }
 }
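
For intuition on the lines above: the Durable Object name is derived from a shard (a keyed PRG over the report ID, reduced modulo `report_shard_count`) and an epoch (the report time rounded down to a multiple of the storage epoch duration). A small sketch of that arithmetic with assumed example values; a plain hash stands in for the keyed `PrgSha3` stream used by the real code:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn main() {
    // Assumed example configuration, not daphne_worker defaults.
    let report_shard_count: u64 = 64;
    let report_storage_epoch_duration: u64 = 3_600; // seconds

    let report_id: [u8; 16] = [0x0b; 16]; // stand-in for a ReportId
    let report_time: u64 = 1_690_000_123;

    // Shard: the real code seeds PrgSha3 with report_shard_key and the report ID;
    // an unkeyed hash is used here purely for illustration.
    let mut hasher = DefaultHasher::new();
    report_id.hash(&mut hasher);
    let shard = hasher.finish() % report_shard_count;

    // Epoch: report time rounded down to a multiple of the epoch duration.
    let epoch = report_time - (report_time % report_storage_epoch_duration);
    assert_eq!(epoch, 1_689_998_400);

    println!("report maps to shard {shard}, epoch {epoch}");
}
```
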