From f3d1c5aa67e77a467dab5d4ce28e880754cae917 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Wed, 7 Aug 2024 10:12:21 -0700
Subject: [PATCH 1/2] Add `validate_record` API to upgraded contexts

This API serves multiple purposes:

## We could never make the existing API work and kept discovering bugs in it

Previously, the validation process was separate, detached from the
actual protocol, and validated everything at once. We saw multiple
occurrences where reveal was called before sharings were validated.

## The previous approach did not integrate smoothly with vectorization

Vectorization processes data in chunks. Going back and forth between
different chunk sizes (one for data, one for validation) has proven
challenging, and the code written for it was hard to read.

## Validate record API

The core of this proposal is a `validate_record` API on
`UpgradedContext` that takes a `record_id` and blocks execution until
this record (and the others in the same batch) has been validated. FWIW
this is exactly how the ZKP validator works today, so this API brings
MAC and ZKP validation closer together.

In addition to introducing this API, this change updates all uses of
MAC validators and contexts to use it.

The pros include:

* `validate_record` must now be called on a per-record basis, making
  protocols easier to review: one can see that the validate call is
  there, right before the reveal.
* Reveal can gain the special power to abort if the record being
  revealed hasn't been validated yet. Because `UpgradedContext` can now
  keep track of what has been validated, we can add this functionality
  later.
* No chunk conversion is required on the protocol side. Protocols can
  be written simply, without magic conversions.
* Validation can now be done in batches, transparently to the code
  calling `validate_record`, and it integrates smoothly with `seq_join`
  (no need for a special `validated_seq_join`).

Downsides:

* Tracking total records is difficult. We still don't have a good
  solution for protocols that need a different total records count per
  step, created inside `UpgradedContext`.
* `record_id` appears more and more in the API. Traits like
  `ShareKnownValue` must be updated to support it.
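For illustration, here is a minimal sketch of the intended call pattern,
adapted from the updated multiplication tests below (`COUNT`, `inputs`,
and the protocol body are illustrative, not part of this change):

    // Total records must be set before creating the validator.
    let v = ctx.set_total_records(COUNT).validator();
    let m_ctx = v.context();
    let results = m_ctx
        .try_join(inputs.into_iter().enumerate().map(|(i, (a, b))| {
            let ctx = m_ctx.clone();
            async move {
                let record_id = RecordId::from(i);
                // Upgrade the inputs to MAC-protected sharings for this record.
                let (a, b) = (a, b).upgrade(ctx.clone(), record_id).await?;
                let ab = a.multiply(&b, ctx.clone(), record_id).await?;
                // Blocks until the whole batch that contains `record_id` has
                // been validated. Only after this point is it safe to reveal.
                ctx.validate_record(record_id).await?;
                Ok::<_, Error>(ab)
            }
        }))
        .await?;

Because validation is just another per-record await, batching stays
transparent to the protocol and a plain `seq_join`/`try_join` is enough
to drive it.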
--- ipa-core/src/protocol/basics/mul/malicious.rs | 20 +- ipa-core/src/protocol/basics/reshare.rs | 26 +- ipa-core/src/protocol/basics/reveal.rs | 40 ++- .../src/protocol/basics/share_known_value.rs | 42 +-- ipa-core/src/protocol/boolean/or.rs | 18 +- ipa-core/src/protocol/context/batcher.rs | 128 ++++++++ ipa-core/src/protocol/context/malicious.rs | 211 +++++-------- ipa-core/src/protocol/context/mod.rs | 41 ++- ipa-core/src/protocol/context/semi_honest.rs | 4 + ipa-core/src/protocol/context/validator.rs | 297 +++++++++++------- ipa-core/src/test_fixture/world.rs | 77 +++-- 11 files changed, 527 insertions(+), 377 deletions(-) create mode 100644 ipa-core/src/protocol/context/batcher.rs diff --git a/ipa-core/src/protocol/basics/mul/malicious.rs b/ipa-core/src/protocol/basics/mul/malicious.rs index 2cdba03a7..e55d855d6 100644 --- a/ipa-core/src/protocol/basics/mul/malicious.rs +++ b/ipa-core/src/protocol/basics/mul/malicious.rs @@ -130,7 +130,7 @@ where mod test { use crate::{ ff::Fp31, - protocol::{basics::SecureMul, context::Context, RecordId}, + protocol::basics::SecureMul, rand::{thread_rng, Rng}, test_fixture::{Reconstruct, Runner, TestWorld}, }; @@ -143,14 +143,16 @@ mod test { let a = rng.gen::(); let b = rng.gen::(); - let res = world - .upgraded_malicious((a, b), |ctx, (a, b)| async move { - a.multiply(&b, ctx.set_total_records(1), RecordId::from(0)) - .await - .unwrap() - }) - .await; + let res = + world + .upgraded_malicious( + vec![(a, b)].into_iter(), + |ctx, record_id, (a, b)| async move { + a.multiply(&b, ctx, record_id).await.unwrap() + }, + ) + .await; - assert_eq!(a * b, res.reconstruct()); + assert_eq!(a * b, res.reconstruct()[0]); } } diff --git a/ipa-core/src/protocol/basics/reshare.rs b/ipa-core/src/protocol/basics/reshare.rs index 20aac3e3b..cb33a8146 100644 --- a/ipa-core/src/protocol/basics/reshare.rs +++ b/ipa-core/src/protocol/basics/reshare.rs @@ -208,7 +208,9 @@ mod tests { helpers::{in_memory_config::MaliciousHelper, Role}, protocol::{ basics::Reshare, - context::{upgrade::Upgradable, Context, UpgradableContext, Validator}, + context::{ + upgrade::Upgradable, Context, UpgradableContext, UpgradedContext, Validator, + }, RecordId, }, rand::{thread_rng, Rng}, @@ -229,15 +231,15 @@ mod tests { for &role in Role::all() { let secret = thread_rng().gen::(); let new_shares = world - .upgraded_malicious(secret, |ctx, share| async move { - share - .reshare(ctx.set_total_records(1), RecordId::from(0), role) - .await - .unwrap() - }) + .upgraded_malicious( + vec![secret].into_iter(), + |ctx, record_id, share| async move { + share.reshare(ctx, record_id, role).await.unwrap() + }, + ) .await; - assert_eq!(secret, new_shares.reconstruct()); + assert_eq!(secret, new_shares.reconstruct()[0]); } } @@ -299,16 +301,16 @@ mod tests { world .malicious(a, |ctx, a| async move { - let v = ctx.validator(); - let m_ctx = v.context().set_total_records(1); + let v = ctx.set_total_records(1).validator(); + let m_ctx = v.context(); let m_a = a.upgrade(m_ctx.clone(), RecordId::FIRST).await.unwrap(); - let m_reshared_a = m_a + let _ = m_a .reshare(m_ctx.narrow(STEP), RecordId::FIRST, to_helper) .await .unwrap(); - match v.validate(m_reshared_a).await { + match m_ctx.validate_record(RecordId::FIRST).await { Ok(result) => panic!("Got a result {result:?}"), Err(err) => { assert!(matches!(err, Error::MaliciousSecurityCheckFailed)); diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 99f55db21..f8b71cdff 100644 --- 
a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -399,7 +399,10 @@ mod tests { }, protocol::{ basics::{partial_reveal, reveal, Reveal}, - context::{upgrade::Upgradable, Context, UpgradableContext, Validator}, + context::{ + upgrade::Upgradable, validator::BatchValidator, Context, UpgradableContext, + Validator, + }, RecordId, }, rand::{thread_rng, Rng}, @@ -501,9 +504,12 @@ mod tests { let mut rng = thread_rng(); let world = TestWorld::default(); - let sh_ctx = world.malicious_contexts(); + let sh_ctx = world + .malicious_contexts() + .each_ref() + .map(|c| c.set_total_records(1)); let v = sh_ctx.map(UpgradableContext::validator); - let m_ctx = v.each_ref().map(|v| v.context().set_total_records(1)); + let m_ctx = v.each_ref().map(BatchValidator::context); let record_id = RecordId::from(0); let input: TestField = rng.gen(); @@ -537,9 +543,12 @@ mod tests { let world = TestWorld::default(); for &excluded in Role::all() { - let sh_ctx = world.malicious_contexts(); + let sh_ctx = world + .malicious_contexts() + .each_ref() + .map(|c| c.set_total_records(1)); let v = sh_ctx.map(UpgradableContext::validator); - let m_ctx = v.each_ref().map(|v| v.context().set_total_records(1)); + let m_ctx = v.each_ref().map(BatchValidator::context); let record_id = RecordId::from(0); let input: TestField = rng.gen(); @@ -579,7 +588,6 @@ mod tests { F: Field, S: SecretSharing + Reveal>::Array>, { - let ctx = ctx.set_total_records(1); let my_role = ctx.role(); let ctx = ctx.narrow(MALICIOUS_REVEAL_STEP); @@ -620,7 +628,10 @@ mod tests { let world = TestWorld::new_with(config); let input: Fp31 = rng.gen(); world - .upgraded_malicious(input, |ctx, share| do_malicious_reveal(ctx, partial, share)) + .upgraded_malicious( + vec![input].into_iter(), + |ctx, _record_id: RecordId, share| do_malicious_reveal(ctx, partial, share), + ) .await; }); } @@ -637,7 +648,10 @@ mod tests { let world = TestWorld::new_with(config); let input: Fp31 = rng.gen(); world - .upgraded_malicious(input, |ctx, share| do_malicious_reveal(ctx, partial, share)) + .upgraded_malicious( + vec![input].into_iter(), + |ctx, _record_id: RecordId, share| do_malicious_reveal(ctx, partial, share), + ) .await; }); } @@ -653,8 +667,12 @@ mod tests { let world = TestWorld::new_with(config); let input: Boolean = rng.gen(); + // ZKP malicious does not set the total records as `upgraded_malicious` + // something to think about how to bring them closer together. 
world - .dzkp_malicious(input, |ctx, share| do_malicious_reveal(ctx, partial, share)) + .dzkp_malicious(input, |ctx, share| { + do_malicious_reveal(ctx.set_total_records(1), partial, share) + }) .await; }); } @@ -671,7 +689,9 @@ mod tests { let world = TestWorld::new_with(config); let input: Boolean = rng.gen(); world - .dzkp_malicious(input, |ctx, share| do_malicious_reveal(ctx, partial, share)) + .dzkp_malicious(input, |ctx, share| { + do_malicious_reveal(ctx.set_total_records(1), partial, share) + }) .await; }); } diff --git a/ipa-core/src/protocol/basics/share_known_value.rs b/ipa-core/src/protocol/basics/share_known_value.rs index 00cd9ea46..052356e4f 100644 --- a/ipa-core/src/protocol/basics/share_known_value.rs +++ b/ipa-core/src/protocol/basics/share_known_value.rs @@ -1,12 +1,8 @@ use crate::{ helpers::Role, - protocol::context::{Context, UpgradedMaliciousContext}, + protocol::context::Context, secret_sharing::{ - replicated::{ - malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, - semi_honest::AdditiveShare as Replicated, - ReplicatedSecretSharing, - }, + replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, SharedValue, }, }; @@ -15,6 +11,11 @@ use crate::{ /// /// The context is only used to determine the helper role. It is not used for communication or PRSS, /// and it is not necessary to use a uniquely narrowed context. +/// +/// As of Aug 2024, this interface does not work for MAC malicious sharings as they +/// were defined before. Sharing known value requires `r` and it varies from one +/// record id to another. If we need to update this, [`Self::share_known_value`] needs +/// to have record id parameter. pub trait ShareKnownValue { fn share_known_value(ctx: &C, value: V) -> Self; } @@ -29,14 +30,6 @@ impl ShareKnownValue for Replicated { } } -impl<'a, F: ExtendableField> ShareKnownValue, F> - for MaliciousReplicated -{ - fn share_known_value(ctx: &UpgradedMaliciousContext<'a, F>, value: F) -> Self { - ctx.share_known_value(value) - } -} - #[cfg(all(test, unit_test))] mod tests { use rand::Rng; @@ -44,10 +37,7 @@ mod tests { use super::ShareKnownValue; use crate::{ ff::Fp31, - secret_sharing::replicated::{ - malicious::AdditiveShare as MaliciousReplicated, - semi_honest::AdditiveShare as Replicated, - }, + secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, test_fixture::{Reconstruct, Runner, TestWorld}, }; @@ -66,20 +56,4 @@ mod tests { .reconstruct(); assert_eq!(result, a); } - - #[tokio::test] - pub async fn malicious_share_known_values() { - let world = TestWorld::default(); - - let mut rng = rand::thread_rng(); - let a = rng.gen::(); - - let result = world - .upgraded_malicious((), |ctx, ()| async move { - MaliciousReplicated::::share_known_value(&ctx, a) - }) - .await - .reconstruct(); - assert_eq!(result, a); - } } diff --git a/ipa-core/src/protocol/boolean/or.rs b/ipa-core/src/protocol/boolean/or.rs index 2a350210e..c8aa611c9 100644 --- a/ipa-core/src/protocol/boolean/or.rs +++ b/ipa-core/src/protocol/boolean/or.rs @@ -93,20 +93,16 @@ mod tests { .await .reconstruct(); let m_result = world - .upgraded_malicious((a, b), |ctx, (a_share, b_share)| async move { - or( - ctx.set_total_records(1), - RecordId::from(0_u32), - &a_share, - &b_share, - ) - .await - .unwrap() - }) + .upgraded_malicious( + vec![(a, b)].into_iter(), + |ctx, record_id, (a_share, b_share)| async move { + or(ctx, record_id, &a_share, &b_share).await.unwrap() + }, + ) .await .reconstruct(); - assert_eq!(result, m_result); + 
assert_eq!(result, m_result[0]);

         result
     }

diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs
new file mode 100644
index 000000000..9e3c2b694
--- /dev/null
+++ b/ipa-core/src/protocol/context/batcher.rs
@@ -0,0 +1,128 @@
+use std::{
+    collections::VecDeque,
+    sync::atomic::{AtomicUsize, Ordering},
+};
+
+use tokio::sync::Notify;
+
+use crate::{
+    protocol::RecordId,
+    sync::{Arc, Mutex},
+};
+
+pub enum Either<L, R> {
+    Left(L),
+    Right(R),
+}
+
+impl<L, R> Either<L, R> {
+    fn left(value: L) -> Self {
+        Self::Left(value)
+    }
+
+    fn right(value: R) -> Self {
+        Self::Right(value)
+    }
+}
+
+#[derive(Debug)]
+pub struct BatchState<B> {
+    pub(super) batch: B,
+    pub(super) notify: Arc<Notify>,
+    records_per_batch: usize,
+    records: AtomicUsize,
+}
+
+pub(super) struct Batcher<'a, B> {
+    batches: VecDeque<BatchState<B>>,
+    first_batch: usize,
+    records_per_batch: usize,
+    total_records: usize,
+    batch_constructor: Box<dyn Fn(usize) -> B + Send + 'a>,
+}
+
+impl<'a, B> Batcher<'a, B> {
+    pub fn new(
+        records_per_batch: usize,
+        total_records: usize,
+        batch_constructor: Box<dyn Fn(usize) -> B + Send + 'a>,
+    ) -> Arc<Mutex<Self>> {
+        Arc::new(Mutex::new(Self {
+            batches: VecDeque::new(),
+            first_batch: 0,
+            records_per_batch,
+            total_records,
+            batch_constructor,
+        }))
+    }
+
+    fn batch_offset(&self, record_id: RecordId) -> usize {
+        let batch_idx = usize::from(record_id) / self.records_per_batch;
+        let Some(batch_offset) = batch_idx.checked_sub(self.first_batch) else {
+            panic!(
+                "Batches should be processed in order. Attempting to retrieve batch {batch_idx}. \
+                 The oldest active batch is batch {}.",
+                self.first_batch,
+            )
+        };
+        batch_offset
+    }
+
+    fn get_batch_by_offset(&mut self, batch_offset: usize) -> &mut BatchState<B> {
+        if self.batches.len() <= batch_offset {
+            self.batches.reserve(batch_offset - self.batches.len() + 1);
+            while self.batches.len() <= batch_offset {
+                let state = BatchState {
+                    batch: (self.batch_constructor)(self.first_batch + batch_offset),
+                    notify: Arc::new(Notify::new()),
+                    records_per_batch: self.records_per_batch,
+                    records: AtomicUsize::new(0),
+                };
+                self.batches.push_back(state);
+            }
+        }
+
+        &mut self.batches[batch_offset]
+    }
+
+    pub fn get_batch(&mut self, record_id: RecordId) -> &mut BatchState<B> {
+        self.get_batch_by_offset(self.batch_offset(record_id))
+    }
+
+    pub fn validate_record(
+        &mut self,
+        record_id: RecordId,
+    ) -> Either<(usize, BatchState<B>), Arc<Notify>> {
+        tracing::trace!("validate record {record_id}");
+        let batch_offset = self.batch_offset(record_id);
+        let is_last = self.is_last(record_id);
+        let batch = self.get_batch_by_offset(batch_offset);
+        let prev_records = batch.records.fetch_add(1, Ordering::Relaxed);
+        if prev_records == batch.records_per_batch - 1 || is_last {
+            // I am not sure if this is okay, or if we need to tolerate batch validation requests
+            // arriving out of order. (If we do, I think we would still want to actually fulfill
+            // the validations in order.)
+            assert_eq!(
+                batch_offset,
+                0,
+                "Batches should be processed in order.
\ + Batch {idx} is ready for validation, but the first batch is {first}.", + idx = self.first_batch + batch_offset, + first = self.first_batch, + ); + tracing::info!( + "batch {} is ready for validation", + self.first_batch + batch_offset + ); + let batch = self.batches.pop_front().unwrap(); + self.first_batch += 1; + Either::left((self.first_batch + batch_offset, batch)) + } else { + Either::right(Arc::clone(&batch.notify)) + } + } + + fn is_last(&self, record_id: RecordId) -> bool { + self.total_records - 1 == usize::from(record_id) + } +} diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 43d52b5da..685961b45 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -9,19 +9,18 @@ use ipa_step::{Step, StepNarrow}; use crate::{ error::Error, - helpers::{ChannelId, Gateway, MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords}, + helpers::{Gateway, MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords}, protocol::{ - basics::{ - mul::{semi_honest_multiply, step::MaliciousMultiplyStep::RandomnessForValidation}, - ShareKnownValue, - }, + basics::mul::{semi_honest_multiply, step::MaliciousMultiplyStep::RandomnessForValidation}, context::{ + batcher::{Batcher, Either}, dzkp_malicious::DZKPUpgraded, dzkp_validator::{DZKPBatch, MaliciousDZKPValidator}, prss::InstrumentedIndexedSharedRandomness, step::UpgradeStep, upgrade::Upgradable, - validator::{Malicious as Validator, MaliciousAccumulator}, + validator, + validator::BatchValidator, Base, Context as ContextTrait, InstrumentedSequentialSharedRandomness, SpecialAccessToUpgradedContext, UpgradableContext, UpgradedContext, }, @@ -59,24 +58,6 @@ impl<'a> Context<'a> { } } - /// Upgrade this context to malicious using MACs. - /// `malicious_step` is the step that will be used for malicious protocol execution. - /// `upgrade_step` is the step that will be used for upgrading inputs - /// from `replicated::semi_honest::AdditiveShare` to `replicated::malicious::AdditiveShare`. - /// `accumulator` and `r_share` come from a `MaliciousValidator`. - #[must_use] - pub fn upgrade( - self, - malicious_step: &S, - accumulator: MaliciousAccumulator, - r_share: Replicated, - ) -> Upgraded<'a, F> - where - Gate: StepNarrow, - { - Upgraded::new(&self.inner, malicious_step, accumulator, r_share) - } - /// Upgrade this context to malicious using DZKPs /// `malicious_step` is the step that will be used for malicious protocol execution. /// `DZKPBatch` comes from a `MaliciousDZKPValidator`. @@ -153,10 +134,10 @@ impl<'a> super::Context for Context<'a> { } impl<'a> UpgradableContext for Context<'a> { - type Validator = Validator<'a, F>; + type Validator = BatchValidator<'a, F>; fn validator(self) -> Self::Validator { - Validator::new(self) + BatchValidator::new(self) } type DZKPValidator = MaliciousDZKPValidator<'a>; @@ -178,58 +159,26 @@ impl Debug for Context<'_> { } } +use crate::sync::{Mutex, Weak}; + +pub(super) type MacBatch<'a, F> = Mutex>>; + /// Represents protocol context in malicious setting, i.e. secure against one active adversary /// in 3 party MPC ring. #[derive(Clone)] pub struct Upgraded<'a, F: ExtendableField> { - /// TODO (alex): Arc is required here because of the `TestWorld` structure. 
Real world - /// may operate with raw references and be more efficient - inner: Arc>, - gate: Gate, - total_records: TotalRecords, + batch: Weak>, + base_ctx: Context<'a>, } impl<'a, F: ExtendableField> Upgraded<'a, F> { - pub(super) fn new( - source: &Base<'a>, - malicious_step: &S, - acc: MaliciousAccumulator, - r_share: Replicated, - ) -> Self - where - Gate: StepNarrow, - { + pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { Self { - inner: UpgradedInner::new(source, acc, r_share), - gate: source.gate().narrow(malicious_step), - total_records: TotalRecords::Unspecified, + batch: Arc::downgrade(batch), + base_ctx: ctx, } } - // TODO: it can be made more efficient by impersonating malicious context as semi-honest - // it does not work as of today because of https://github.com/rust-lang/rust/issues/20400 - // while it is possible to define a struct that wraps a reference to malicious context - // and implement `Context` trait for it, implementing SecureMul and Reveal for Context - // is not. - // For the same reason, it is not possible to implement Context> - // for `MaliciousContext`. Deep clone is the only option. - fn as_base(&self) -> Base<'a> { - Base::new_complete( - self.inner.prss, - self.inner.gateway, - self.gate.clone(), - self.total_records, - NotSharded, - ) - } - - pub fn share_known_value(&self, value: F) -> MaliciousReplicated { - MaliciousReplicated::new( - Replicated::share_known_value(&self.clone().base_context(), value), - &self.inner.r_share * value.to_extended(), - ) - } - /// Take a secret sharing and add it to the running MAC that this context maintains (if any). pub fn accumulate_macs( self, @@ -239,30 +188,71 @@ impl<'a, F: ExtendableField> Upgraded<'a, F> { F: ExtendableFieldSimd, Replicated: FromPrss, { - self.inner - .accumulator - .accumulate_macs(&self.prss(), record_id, share); + self.with_batch(record_id, |v| { + v.accumulator + .accumulate_macs(&self.prss(), record_id, share); + }); } - /// It is intentionally not public, allows access to it only from within - /// this module - fn r_share(&self) -> &Replicated { - &self.inner.r_share + #[cfg(any(test, feature = "test-fixture"))] + #[must_use] + pub fn r(&self, record_id: RecordId) -> Replicated { + self.r_share(record_id) + } + + fn r_share(&self, record_id: RecordId) -> Replicated { + // its unfortunate, but carrying references across mutex boundaries is not possible + self.with_batch(record_id, |v| v.r_share().clone()) + } + + fn with_batch) -> T, T>( + &self, + record_id: RecordId, + action: C, + ) -> T { + let batcher = self.batch.upgrade().expect("Validator is active"); + + let mut batch = batcher.lock().unwrap(); + let state = batch.get_batch(record_id); + (action)(&mut state.batch) } } #[async_trait] impl<'a, F: ExtendableField> UpgradedContext for Upgraded<'a, F> { type Field = F; + + async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> { + let r = { + self.batch + .upgrade() + .expect("Validation batch is active") + .lock() + .unwrap() + .validate_record(record_id) + }; + match r { + Either::Left((_, batch)) => { + // TODO: fix naming (batch.batch) + batch.batch.validate().await?; + batch.notify.notify_waiters(); + Ok(()) + } + Either::Right(notify) => { + notify.notified().await; + Ok(()) + } + } + } } impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { fn role(&self) -> Role { - self.inner.gateway.role() + self.base_ctx.role() } fn gate(&self) -> &Gate { - &self.gate + self.base_ctx.gate() } fn narrow(&self, step: &S) -> Self @@ -270,28 +260,24 @@ 
impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { Gate: StepNarrow, { Self { - inner: Arc::clone(&self.inner), - gate: self.gate.narrow(step), - total_records: self.total_records, + base_ctx: self.base_ctx.narrow(step), + ..self.clone() } } fn set_total_records>(&self, total_records: T) -> Self { Self { - inner: Arc::clone(&self.inner), - gate: self.gate.clone(), - total_records: self.total_records.overwrite(total_records), + base_ctx: self.base_ctx.set_total_records(total_records), + ..self.clone() } } fn total_records(&self) -> TotalRecords { - self.total_records + self.base_ctx.total_records() } fn prss(&self) -> InstrumentedIndexedSharedRandomness<'_> { - let prss = self.inner.prss.indexed(self.gate()); - - InstrumentedIndexedSharedRandomness::new(prss, &self.gate, self.role()) + self.base_ctx.prss() } fn prss_rng( @@ -300,29 +286,21 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { InstrumentedSequentialSharedRandomness<'_>, InstrumentedSequentialSharedRandomness<'_>, ) { - let (left, right) = self.inner.prss.sequential(self.gate()); - ( - InstrumentedSequentialSharedRandomness::new(left, self.gate(), self.role()), - InstrumentedSequentialSharedRandomness::new(right, self.gate(), self.role()), - ) + self.base_ctx.prss_rng() } fn send_channel(&self, role: Role) -> SendingEnd { - self.inner - .gateway - .get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records) + self.base_ctx.send_channel(role) } fn recv_channel(&self, role: Role) -> MpcReceivingEnd { - self.inner - .gateway - .get_mpc_receiver(&ChannelId::new(role, self.gate.clone())) + self.base_ctx.recv_channel(role) } } impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { fn active_work(&self) -> NonZeroUsize { - self.inner.gateway.config().active_work() + self.base_ctx.active_work() } } @@ -334,7 +312,7 @@ impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext for Upgraded<'a, type Base = Base<'a>; fn base_context(self) -> Self::Base { - self.as_base() + self.base_ctx.inner } } @@ -343,31 +321,6 @@ impl Debug for Upgraded<'_, F> { write!(f, "MaliciousContext<{:?}>", type_name::()) } } -struct UpgradedInner<'a, F: ExtendableField> { - prss: &'a PrssEndpoint, - gateway: &'a Gateway, - accumulator: MaliciousAccumulator, - r_share: Replicated, -} - -impl<'a, F: ExtendableField> UpgradedInner<'a, F> { - fn new( - base_context: &Base<'a>, - accumulator: MaliciousAccumulator, - r_share: Replicated, - ) -> Arc { - Arc::new(UpgradedInner { - prss: base_context.inner.prss, - gateway: base_context.inner.gateway, - accumulator, - r_share, - }) - } - - fn accumulator(&self) -> &MaliciousAccumulator { - &self.accumulator - } -} /// Upgrading a semi-honest replicated share using malicious context produces /// a MAC-secured share with the same vectorization factor. 
@@ -399,14 +352,12 @@ where // let induced_share = self.induced(); // expand r to match the vectorization factor of induced share - let r = ctx.r_share().expand(); + let r = ctx.r_share(record_id).expand(); - let rx = semi_honest_multiply(ctx.as_base(), record_id, &induced_share, &r).await?; - let m = MaliciousReplicated::new(self, rx); let narrowed = ctx.narrow(&RandomnessForValidation); - let prss = narrowed.prss(); - let accumulator = narrowed.inner.accumulator(); - accumulator.accumulate_macs(&prss, record_id, &m); + let rx = semi_honest_multiply(ctx.base_context(), record_id, &induced_share, &r).await?; + let m = MaliciousReplicated::new(self, rx); + narrowed.accumulate_macs(record_id, &m); Ok(m) } diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index e07e63b07..ec0cc1fb1 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -8,6 +8,7 @@ pub mod semi_honest; pub mod step; pub mod upgrade; +mod batcher; /// Validators are not used in IPA v3 yet. Once we make use of MAC-based validation, /// this flag can be removed #[allow(dead_code)] @@ -114,6 +115,15 @@ pub trait UpgradableContext: Context { #[async_trait] pub trait UpgradedContext: Context { type Field: ExtendableField; + + /// This method blocks until `record_id` has been validated. Validation happens + /// in batches, this method will block each individual future until + /// the whole batch is validated. The code written this way is more concise + /// and easier to read + /// + /// Future improvement will combine this with [`Reveal`] to access + /// the value after validation. + async fn validate_record(&self, record_id: RecordId) -> Result<(), Error>; } pub trait SpecialAccessToUpgradedContext: UpgradedContext { @@ -539,23 +549,24 @@ mod tests { } /// Toy protocol to execute PRSS generation and send/receive logic - async fn toy_protocol(ctx: C, index: usize, share: &S) -> Replicated + async fn toy_protocol(ctx: C, index: I, share: &S) -> Replicated where F: Field + U128Conversions, Standard: Distribution, C: Context, S: ReplicatedLeftValue, + I: Into, { let ctx = ctx.narrow("metrics"); let (left_peer, right_peer) = ( ctx.role().peer(Direction::Left), ctx.role().peer(Direction::Right), ); - let record_id = RecordId::from(index); + let record_id = index.into(); let (l, r) = ctx.prss().generate_fields(record_id); let (seq_l, seq_r) = { - let ctx = ctx.narrow(&format!("seq-prss-{index}")); + let ctx = ctx.narrow(&format!("seq-prss-{record_id}")); let (mut left_rng, mut right_rng) = ctx.prss_rng(); // exercise both methods of `RngCore` trait @@ -643,7 +654,6 @@ mod tests { async fn malicious_metrics() { let world = TestWorld::new_with(TestWorldConfig::default().enable_metrics()); let input = vec![Fp31::truncate_from(0u128), Fp31::truncate_from(1u128)]; - let input_len = input.len(); let field_size = ::Size::USIZE; let metrics_step = world .gate() @@ -652,14 +662,8 @@ mod tests { .narrow("metrics"); let _result = world - .upgraded_malicious(input.clone().into_iter(), |ctx, a| async move { - let ctx = ctx.set_total_records(input_len); - join_all( - a.iter() - .enumerate() - .map(|(i, share)| toy_protocol(ctx.clone(), i, share)), - ) - .await; + .upgraded_malicious(input.clone().into_iter(), |ctx, record_id, a| async move { + let _ = toy_protocol(ctx.clone(), record_id, &a).await; a }) @@ -724,18 +728,11 @@ mod tests { world .malicious(input.into_iter(), |ctx, shares| async move { // upgrade shares two times using different contexts - let v = 
ctx.validator(); + let v = ctx.set_total_records(1).validator(); let ctx = v.context().narrow("step1"); - shares - .clone() - .upgrade(ctx.set_total_records(1), RecordId::FIRST) - .await - .unwrap(); + shares.clone().upgrade(ctx, RecordId::FIRST).await.unwrap(); let ctx = v.context().narrow("step2"); - shares - .upgrade(ctx.set_total_records(1), RecordId::FIRST) - .await - .unwrap(); + shares.upgrade(ctx, RecordId::FIRST).await.unwrap(); }) .await; } diff --git a/ipa-core/src/protocol/context/semi_honest.rs b/ipa-core/src/protocol/context/semi_honest.rs index 7a6b78c12..1be359879 100644 --- a/ipa-core/src/protocol/context/semi_honest.rs +++ b/ipa-core/src/protocol/context/semi_honest.rs @@ -265,6 +265,10 @@ impl<'a, B: ShardBinding, F: ExtendableField> SeqJoin for Upgraded<'a, B, F> { #[async_trait] impl<'a, B: ShardBinding, F: ExtendableField> UpgradedContext for Upgraded<'a, B, F> { type Field = F; + + async fn validate_record(&self, _record_id: RecordId) -> Result<(), Error> { + Ok(()) + } } impl<'a, B: ShardBinding, F: ExtendableField> SpecialAccessToUpgradedContext diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index f3846fa82..10405f21b 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -4,9 +4,6 @@ use std::{ marker::PhantomData, }; -use async_trait::async_trait; -use typenum::Const; - use crate::{ error::Error, ff::Field, @@ -14,6 +11,8 @@ use crate::{ protocol::{ basics::{check_zero::malicious_check_zero, malicious_reveal}, context::{ + batcher::Batcher, + malicious::MacBatch, step::{MaliciousProtocolStep as Step, ValidateStep}, Base, Context, MaliciousContext, UpgradedContext, UpgradedMaliciousContext, UpgradedSemiHonestContext, @@ -24,24 +23,22 @@ use crate::{ secret_sharing::{ replicated::{ malicious::{ - AdditiveShare as MaliciousReplicated, DowngradeMalicious, ExtendableField, - ExtendableFieldSimd, + AdditiveShare as MaliciousReplicated, ExtendableField, ExtendableFieldSimd, }, semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing, }, FieldSimd, SharedValue, }, + seq_join::SeqJoin, sharding::ShardBinding, - sync::{Arc, Mutex, Weak}, + sync::Arc, }; -#[async_trait] pub trait Validator { type Context: UpgradedContext; fn context(&self) -> Self::Context; - async fn validate(self, values: D) -> Result; } pub struct SemiHonest<'a, B: ShardBinding, F: ExtendableField> { @@ -58,18 +55,12 @@ impl<'a, B: ShardBinding, F: ExtendableField> SemiHonest<'a, B, F> { } } -#[async_trait] impl<'a, B: ShardBinding, F: ExtendableField> Validator for SemiHonest<'a, B, F> { type Context = UpgradedSemiHonestContext<'a, B, F>; fn context(&self) -> Self::Context { self.context.clone() } - - async fn validate(self, values: D) -> Result { - use crate::secret_sharing::replicated::malicious::ThisCodeIsAuthorizedToDowngradeFromMalicious; - Ok(values.downgrade().await.access_without_downgrade()) - } } impl Debug for SemiHonest<'_, B, F> { @@ -136,10 +127,14 @@ impl AccumulatorState { #[derive(Clone, Debug)] pub struct MaliciousAccumulator { - inner: Weak>>, + inner: AccumulatorState, } impl MaliciousAccumulator { + pub(super) fn u_and_w(&self) -> (F::ExtendedField, F::ExtendedField) { + (self.inner.u, self.inner.w) + } + fn compute_dot_product_contribution( a: &Replicated, b: &Replicated, @@ -160,7 +155,7 @@ impl MaliciousAccumulator { /// ## Panics /// Will panic if the mutex is poisoned pub fn accumulate_macs( - &self, + &mut self, prss: &I, record_id: RecordId, input: 
&MaliciousReplicated, @@ -191,35 +186,57 @@ impl MaliciousAccumulator { let w_contribution = Self::compute_dot_product_contribution(&random_constant, &induced_share); - let arc_mutex = self.inner.upgrade().unwrap(); - // LOCK BEGIN - let mut accumulator_state = arc_mutex.lock().unwrap(); + self.inner.u += u_contribution; + self.inner.w += w_contribution; + } +} + +/// Validates the upgraded shares in batches, similarly to +/// ZKP validator. It keeps a unique context per batch that carries +/// the `r` value and accumulator. All multiplications that occur +/// in that context, will use the associated `r` value. +/// +/// When batch is validated, `r` is revealed and can never be +/// used again. In fact, it gets out of scope after successful validation +/// so no code can get access to it. +pub struct BatchValidator<'a, F: ExtendableField> { + batches_ref: Arc>, + protocol_ctx: MaliciousContext<'a>, +} + +impl<'a, F: ExtendableField> BatchValidator<'a, F> { + /// Create a new validator for malicious context. + /// + /// ## Panics + /// If total records is not set. + #[must_use] + pub fn new(ctx: MaliciousContext<'a>) -> Self { + assert!( + ctx.total_records().is_specified(), + "Total records must be specified before creating the validator" + ); + let total_records = ctx.total_records().count().unwrap(); + let records_per_batch = ctx.active_work().get().min(total_records); - accumulator_state.u += u_contribution; - accumulator_state.w += w_contribution; - // LOCK END + Self { + protocol_ctx: ctx.narrow(&Step::MaliciousProtocol), + batches_ref: Batcher::new( + records_per_batch, + total_records, + Box::new(move |batch_index| Malicious::new(ctx.clone(), batch_index)), + ), + } } } pub struct Malicious<'a, F: ExtendableField> { r_share: Replicated, - u_and_w: Arc>>, - protocol_ctx: UpgradedMaliciousContext<'a, F>, + pub(super) accumulator: MaliciousAccumulator, validate_ctx: Base<'a>, + offset: usize, } -#[async_trait] -impl<'a, F> Validator for Malicious<'a, F> -where - F: ExtendableField, -{ - type Context = UpgradedMaliciousContext<'a, F>; - - /// Get a copy of the context that can be used for malicious protocol execution. - fn context(&self) -> Self::Context { - self.protocol_ctx.clone() - } - +impl Malicious<'_, F> { /// ## Errors /// If the two information theoretic MACs are not equal (after multiplying by `r`), this indicates that one of the parties /// must have launched an additive attack. At this point the honest parties should abort the protocol. This method throws an @@ -229,7 +246,7 @@ where /// ## Panics /// Will panic if the mutex is poisoned #[tracing::instrument(name = "validate", skip_all, fields(gate = %self.validate_ctx.gate().as_ref()))] - async fn validate(self, values: D) -> Result { + pub(crate) async fn validate(self) -> Result<(), Error> { // send our `u_i+1` value to the helper on the right let (u_share, w_share) = self.propagate_u_and_w().await?; @@ -237,52 +254,77 @@ where let narrow_ctx = self .validate_ctx .narrow(&ValidateStep::RevealR) - .set_total_records(TotalRecords::ONE); + .set_total_records(TotalRecords::Indeterminate); let r = ::ExtendedField::from_array( - &malicious_reveal(narrow_ctx, RecordId::FIRST, None, &self.r_share) - .await? - .expect("full reveal should always return a value"), + &malicious_reveal( + narrow_ctx, + Self::reveal_check_zero_record(self.offset), + None, + &self.r_share, + ) + .await? 
+ .expect("full reveal should always return a value"), ); let t = u_share - &(w_share * r); let check_zero_ctx = self .validate_ctx .narrow(&ValidateStep::CheckZero) - .set_total_records(TotalRecords::ONE); - let is_valid = malicious_check_zero(check_zero_ctx, RecordId::FIRST, &t).await?; + .set_total_records(TotalRecords::Indeterminate); + let is_valid = malicious_check_zero( + check_zero_ctx, + Self::reveal_check_zero_record(self.offset), + &t, + ) + .await?; if is_valid { // Yes, we're allowed to downgrade here. - use crate::secret_sharing::replicated::malicious::ThisCodeIsAuthorizedToDowngradeFromMalicious; - Ok(values.downgrade().await.access_without_downgrade()) + + Ok(()) } else { Err(Error::MaliciousSecurityCheckFailed) } } } +impl<'a, F> Validator for BatchValidator<'a, F> +where + F: ExtendableField, +{ + type Context = UpgradedMaliciousContext<'a, F>; + + fn context(&self) -> Self::Context { + UpgradedMaliciousContext::new(&self.batches_ref, self.protocol_ctx.clone()) + } +} + impl<'a, F: ExtendableField> Malicious<'a, F> { #[must_use] #[allow(clippy::needless_pass_by_value)] - pub fn new(ctx: MaliciousContext<'a>) -> Self { + pub fn new(ctx: MaliciousContext<'a>, offset: usize) -> Self { + // Each invocation requires 3 calls to PRSS to generate the state. + // Validation occurs in batches and `offset` indicates which batch + // we're in right now. + const TOTAL_CALLS_TO_PRSS: usize = 3; + // Use the current step in the context for initialization. - let r_share: Replicated = ctx.prss().generate(RecordId::FIRST); + let r_share: Replicated = ctx + .prss() + .generate(Self::r_share_record(offset, TOTAL_CALLS_TO_PRSS)); let prss = ctx.prss(); - let u: F::ExtendedField = prss.zero(RecordId::FIRST + 1); - let w: F::ExtendedField = prss.zero(RecordId::FIRST + 2); + let u: F::ExtendedField = prss.zero(Self::u_record(offset, TOTAL_CALLS_TO_PRSS)); + let w: F::ExtendedField = prss.zero(Self::w_record(offset, TOTAL_CALLS_TO_PRSS)); let state = AccumulatorState::new(u, w); - let u_and_w = Arc::new(Mutex::new(state)); - let accumulator = MaliciousAccumulator:: { - inner: Arc::downgrade(&u_and_w), - }; + let accumulator = MaliciousAccumulator:: { inner: state }; let validate_ctx = ctx.narrow(&Step::Validate).validator_context(); - let protocol_ctx = ctx.upgrade(&Step::MaliciousProtocol, accumulator, r_share.clone()); + Self { r_share, - u_and_w, - protocol_ctx, + accumulator, validate_ctx, + offset, } } @@ -295,31 +337,47 @@ impl<'a, F: ExtendableField> Malicious<'a, F> { &self, ) -> Result<(Replicated, Replicated), Error> { use futures::future::try_join; + const TOTAL_SEND: usize = 2; let propagate_ctx = self .validate_ctx .narrow(&ValidateStep::PropagateUAndW) - .set_total_records(Const::<2>); + .set_total_records(TotalRecords::Indeterminate); let helper_right = propagate_ctx.send_channel(propagate_ctx.role().peer(Direction::Right)); let helper_left = propagate_ctx.recv_channel(propagate_ctx.role().peer(Direction::Left)); - let (u_local, w_local) = { - let state = self.u_and_w.lock().unwrap(); - (state.u, state.w) - }; + let (u_local, w_local) = self.accumulator.u_and_w(); + let (u_record, w_record) = ( + Self::u_record(self.offset, TOTAL_SEND), + Self::w_record(self.offset, TOTAL_SEND), + ); + try_join( - helper_right.send(RecordId::FIRST, u_local), - helper_right.send(RecordId::FIRST + 1, w_local), - ) - .await?; - let (u_left, w_left): (F::ExtendedField, F::ExtendedField) = try_join( - helper_left.receive(RecordId::FIRST), - helper_left.receive(RecordId::FIRST + 1), + 
helper_right.send(u_record, u_local), + helper_right.send(w_record, w_local), ) .await?; + let (u_left, w_left): (F::ExtendedField, F::ExtendedField) = + try_join(helper_left.receive(u_record), helper_left.receive(w_record)).await?; let u_share = Replicated::new(u_left, u_local); let w_share = Replicated::new(w_left, w_local); Ok((u_share, w_share)) } + + fn u_record(offset: usize, total: usize) -> RecordId { + RecordId::from(total * offset) + } + + fn w_record(offset: usize, total: usize) -> RecordId { + RecordId::from(total * offset + 1) + } + + fn r_share_record(offset: usize, total: usize) -> RecordId { + RecordId::from(total * offset + 2) + } + + fn reveal_check_zero_record(offset: usize) -> RecordId { + RecordId::from(offset) + } } impl Debug for Malicious<'_, F> { @@ -338,7 +396,10 @@ mod tests { helpers::Role, protocol::{ basics::SecureMul, - context::{upgrade::Upgradable, validator::Validator, Context, UpgradableContext}, + context::{ + upgrade::Upgradable, validator::Validator, Context, UpgradableContext, + UpgradedContext, + }, RecordId, }, rand::{thread_rng, Rng}, @@ -381,22 +442,21 @@ mod tests { let futures = zip(context, zip(a_shares, b_shares)).map(|(ctx, (a_share, b_share))| async move { - let v = ctx.validator(); + let v = ctx.set_total_records(1).validator(); let m_ctx = v.context(); let (a_malicious, b_malicious) = (a_share, b_share) - .upgrade(m_ctx.set_total_records(1), RecordId::FIRST) + .upgrade(m_ctx.clone(), RecordId::FIRST) .await .unwrap(); let m_result = a_malicious - .multiply(&b_malicious, m_ctx.set_total_records(1), RecordId::from(0)) + .multiply(&b_malicious, m_ctx.clone(), RecordId::from(0)) .await?; // Save some cloned values so that we can check them. - let r_share = v.r_share().clone(); - let result = v.validate(m_result.clone()).await?; - assert_eq!(&result, m_result.x().access_without_downgrade()); + let r_share = m_ctx.r(RecordId::FIRST); + m_ctx.validate_record(RecordId::FIRST).await?; Ok::<_, Error>((m_result, r_share)) }); @@ -426,12 +486,12 @@ mod tests { let result = world .malicious(a, |ctx, a| async move { + let ctx = ctx.set_total_records(1); let v = ctx.validator(); - let m = a - .upgrade(v.context().set_total_records(1), RecordId::FIRST) - .await - .unwrap(); - v.validate(m).await.unwrap() + let m = a.upgrade(v.context(), RecordId::FIRST).await.unwrap(); + v.context().validate_record(RecordId::FIRST).await.unwrap(); + + m.access_without_downgrade() }) .await; assert_eq!(a, result.reconstruct()); @@ -453,12 +513,10 @@ mod tests { } else { a }; + let ctx = ctx.set_total_records(1); let v = ctx.validator(); - let m = a - .upgrade(v.context().set_total_records(1), RecordId::FIRST) - .await - .unwrap(); - match v.validate(m).await { + let _ = a.upgrade(v.context(), RecordId::FIRST).await.unwrap(); + match v.context().validate_record(RecordId::FIRST).await { Ok(result) => panic!("Got a result {result:?}"), Err(err) => assert!(matches!(err, Error::MaliciousSecurityCheckFailed)), } @@ -511,64 +569,59 @@ mod tests { .into_iter() .zip([h1_shares, h2_shares, h3_shares]) .map(|(ctx, input_shares)| async move { + let ctx = ctx.set_total_records(COUNT - 1); let v = ctx.validator(); let m_ctx = v.context(); - let m_input = input_shares - .upgrade(m_ctx.set_total_records(1), RecordId::FIRST) - .await - .unwrap(); - let m_results = m_ctx .try_join( zip( - repeat(m_ctx.set_total_records(COUNT - 1)).enumerate(), - zip(m_input.iter(), m_input.iter().skip(1)), + repeat(m_ctx.clone()).enumerate(), + zip(input_shares.iter(), input_shares.iter().skip(1)), ) - 
.map( - |((i, ctx), (a_malicious, b_malicious))| async move { - a_malicious - .multiply(b_malicious, ctx, RecordId::from(i)) - .await - }, - ), + .map(|((i, ctx), (a, b))| async move { + let record_id = RecordId::from(i); + let (a_malicious, b_malicious) = (a.clone(), b.clone()) + .upgrade(ctx.clone(), record_id) + .await?; + let m_result = a_malicious + .multiply(&b_malicious, ctx.clone(), RecordId::from(i)) + .await; + + let r_share = ctx.r(RecordId::from(i)); + ctx.validate_record(record_id).await?; + + Ok::<_, Error>((m_result?, r_share)) + }), ) .await?; - let r_share = v.r_share().clone(); - let results = v.validate(m_results.clone()).await?; - assert_eq!( - results.iter().collect::>(), - m_results - .iter() - .map(|x| x.x().access_without_downgrade()) - .collect::>() - ); - Ok::<_, Error>((m_results, r_share)) + Ok::<_, Error>(m_results) }); let processed_outputs = join3v(futures).await; - let r = [ - &processed_outputs[0].1, - &processed_outputs[1].1, - &processed_outputs[2].1, - ] - .reconstruct(); - for i in 0..99 { let x1 = original_inputs[i]; let x2 = original_inputs[i + 1]; + let x1_times_x2 = [ - processed_outputs[0].0[i].x().access_without_downgrade(), - processed_outputs[1].0[i].x().access_without_downgrade(), - processed_outputs[2].0[i].x().access_without_downgrade(), + processed_outputs[0][i].0.x().access_without_downgrade(), + processed_outputs[1][i].0.x().access_without_downgrade(), + processed_outputs[2][i].0.x().access_without_downgrade(), ] .reconstruct(); let r_times_x1_times_x2 = [ - processed_outputs[0].0[i].rx(), - processed_outputs[1].0[i].rx(), - processed_outputs[2].0[i].rx(), + processed_outputs[0][i].0.rx(), + processed_outputs[1][i].0.rx(), + processed_outputs[2][i].0.rx(), + ] + .reconstruct(); + + let r = [ + processed_outputs[0][i].1.clone(), + processed_outputs[1][i].1.clone(), + processed_outputs[2][i].1.clone(), ] .reconstruct(); diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 113f639ba..2a6b479cb 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -1,6 +1,8 @@ // We have quite a bit of code that is only used when descriptive-gate is enabled. 
#![allow(dead_code)] -use std::{array::from_fn, borrow::Borrow, fmt::Debug, io::stdout, iter::zip, marker::PhantomData}; +use std::{ + array::from_fn, borrow::Borrow, fmt::Debug, io::stdout, iter, iter::zip, marker::PhantomData, +}; use async_trait::async_trait; use futures::{future::join_all, stream::FuturesOrdered, Future, StreamExt}; @@ -21,14 +23,16 @@ use crate::{ context::{ dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, DZKPUpgradedMaliciousContext, MaliciousContext, SemiHonestContext, - ShardedSemiHonestContext, UpgradableContext, UpgradedMaliciousContext, + ShardedSemiHonestContext, UpgradableContext, UpgradedContext, UpgradedMaliciousContext, UpgradedSemiHonestContext, Validator, }, prss::Endpoint as PrssEndpoint, Gate, QueryId, RecordId, }, secret_sharing::{ - replicated::malicious::{DowngradeMalicious, ExtendableField}, + replicated::malicious::{ + DowngradeMalicious, ExtendableField, ThisCodeIsAuthorizedToDowngradeFromMalicious, + }, IntoShares, }, sharding::{NotSharded, ShardBinding, ShardIndex, Sharded}, @@ -406,14 +410,14 @@ pub trait Runner { &'a self, input: I, helper_fn: H, - ) -> [O; 3] + ) -> [Vec; 3] where F: ExtendableField, - I: IntoShares + Send + 'static, + I: IntoShares> + Send + 'static, A: Send + 'static + Upgradable, Output = M>, O: Send + Debug, M: Send + 'static, - H: Fn(UpgradedMaliciousContext<'a, F>, M) -> R + Send + Sync, + H: Fn(UpgradedMaliciousContext<'a, F>, RecordId, M) -> R + Send + Sync, R: Future + Send, P: DowngradeMalicious + Clone + Send + Debug, [P; 3]: ValidateMalicious, @@ -506,14 +510,14 @@ impl Runner> &'a self, _input: I, _helper_fn: H, - ) -> [O; 3] + ) -> [Vec; 3] where F: ExtendableField, - I: IntoShares + Send + 'static, + I: IntoShares> + Send + 'static, A: Send + 'static + Upgradable, Output = M>, O: Send + Debug, M: Send + 'static, - H: Fn(UpgradedMaliciousContext<'a, F>, M) -> R + Send + Sync, + H: Fn(UpgradedMaliciousContext<'a, F>, RecordId, M) -> R + Send + Sync, R: Future + Send, P: DowngradeMalicious + Clone + Send + Debug, [P; 3]: ValidateMalicious, @@ -599,35 +603,51 @@ impl Runner for TestWorld { &'a self, input: I, helper_fn: H, - ) -> [O; 3] + ) -> [Vec; 3] where F: ExtendableField, - I: IntoShares + Send + 'static, + I: IntoShares> + Send + 'static, A: Send + 'static + Upgradable, Output = M>, O: Send + Debug, M: Send + 'static, - H: Fn(UpgradedMaliciousContext<'a, F>, M) -> R + Send + Sync, + H: Fn(UpgradedMaliciousContext<'a, F>, RecordId, M) -> R + Send + Sync, R: Future + Send, P: DowngradeMalicious + Clone + Send + Debug, [P; 3]: ValidateMalicious, Standard: Distribution, { + // Closure is Copy, so we don't need to fight rustc convincing it + // that it is ok to use `helper_fn` in `malicious` closure. 
+ #[allow(clippy::redundant_closure)] + let helper_fn = |ctx, record_id, m_share| helper_fn(ctx, record_id, m_share); + let (m_results, r_shares, output) = split_array_of_tuples( - self.malicious(input, |ctx, share| async { - let v = ctx.validator(); + self.malicious(input, |ctx, shares| async move { + let ctx = ctx.set_total_records( + TotalRecords::specified(shares.len()).expect("Non-empty input"), + ); + let v = ctx.validator::(); let m_ctx = v.context(); - let m_share = share - .upgrade( - m_ctx.set_total_records(TotalRecords::specified(1).unwrap()), - RecordId::FIRST, - ) - .await - .unwrap(); - let m_result = helper_fn(m_ctx, m_share).await; - let m_result_clone = m_result.clone(); - let r_share = v.r_share().clone(); - let output = v.validate(m_result_clone).await.unwrap(); - (m_result, r_share, output) + let r_share = m_ctx.clone().r(RecordId::FIRST).clone(); + let m_shares: Vec<_> = + join_all(zip(shares, iter::repeat(m_ctx.clone())).enumerate().map( + |(i, (share, m_ctx))| async move { + let record_id = RecordId::from(i); + let m_share = share.upgrade(m_ctx.clone(), record_id).await.unwrap(); + let m_result = helper_fn(m_ctx.clone(), record_id, m_share).await; + m_ctx.validate_record(record_id).await.unwrap(); + + ( + m_result.clone(), + m_result.downgrade().await.access_without_downgrade(), + ) + }, + )) + .await; + + let (m_results, outputs): (Vec<_>, Vec<_>) = m_shares.into_iter().unzip(); + + (m_results, r_share, outputs) }) .await, ); @@ -635,7 +655,10 @@ impl Runner for TestWorld { // Sanity check that rx = r * x at the output (it should not be possible // for this to fail if the distributed validation protocol passed). let r = r_shares.reconstruct(); - m_results.validate(r); + let [h1_r, h2_r, h3_r] = m_results; + for (h1, (h2, h3)) in zip(h1_r, zip(h2_r, h3_r)) { + [h1, h2, h3].validate(r); + } output } From 322e3772d1ab39e7fe24a25017def0c4a32ce1ed Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sun, 11 Aug 2024 22:54:49 -0700 Subject: [PATCH 2/2] Feedback --- ipa-core/src/protocol/context/malicious.rs | 5 ++++- ipa-core/src/protocol/context/validator.rs | 16 +++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 685961b45..564b7e1d6 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -194,14 +194,17 @@ impl<'a, F: ExtendableField> Upgraded<'a, F> { }); } + /// `TestWorld` malicious methods require access to r share to perform validation. + /// This method allows such access only in non-prod code. #[cfg(any(test, feature = "test-fixture"))] #[must_use] pub fn r(&self, record_id: RecordId) -> Replicated { self.r_share(record_id) } + /// It is intentionally not public, allows access to it only from within + /// this module fn r_share(&self, record_id: RecordId) -> Replicated { - // its unfortunate, but carrying references across mutex boundaries is not possible self.with_batch(record_id, |v| v.r_share().clone()) } diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index 10405f21b..6e5a4de6f 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -211,11 +211,12 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { /// If total records is not set. 
#[must_use]
     pub fn new(ctx: MaliciousContext<'a>) -> Self {
-        assert!(
-            ctx.total_records().is_specified(),
-            "Total records must be specified before creating the validator"
-        );
-        let total_records = ctx.total_records().count().unwrap();
+        let Some(total_records) = ctx.total_records().count() else {
+            panic!("Total records must be specified before creating the validator");
+        };
+
+        // TODO: Right now we set the batch work to be equal to active_work,
+        // but it does not need to be. We can make this configurable if needed.
         let records_per_batch = ctx.active_work().get().min(total_records);

         Self {
             protocol_ctx: ctx.narrow(&Step::MaliciousProtocol),
             batches_ref: Batcher::new(
                 records_per_batch,
                 total_records,
                 Box::new(move |batch_index| Malicious::new(ctx.clone(), batch_index)),
             ),
         }
     }
 }
@@ -254,6 +255,11 @@ impl<F: ExtendableField> Malicious<'_, F> {
         let narrow_ctx = self
             .validate_ctx
             .narrow(&ValidateStep::RevealR)
+            // TODO: propagate_u_and_w, RevealR, and CheckZero all use an indeterminate record
+            // count to communicate data right away. We could make this better if we had support
+            // from the compact gate infrastructure to override the batch size per step. All of
+            // the steps above require the batch size to be set to 1, but we know the total
+            // number of records sent through these channels (total_records / batch_size).
             .set_total_records(TotalRecords::Indeterminate);
         let r = <F as ExtendableField>::ExtendedField::from_array(
             &malicious_reveal(