From e0e1fbccf332b31e778aa92ab77c153ab560a637 Mon Sep 17 00:00:00 2001 From: Hannah Davis Date: Wed, 8 May 2024 18:54:17 +0000 Subject: [PATCH] Create Mastic module with client implementation Implements client functionality for the Mastic protocol for weighted heavy-hitters and attribute-based metrics. --- src/flp.rs | 2 +- src/flp/szk.rs | 793 ++++++++++++++++++++++++++++++++++++++------- src/vdaf.rs | 9 + src/vdaf/mastic.rs | 539 ++++++++++++++++++++++++++++++ src/vidpf.rs | 145 ++++++++- 5 files changed, 1360 insertions(+), 128 deletions(-) create mode 100644 src/vdaf/mastic.rs diff --git a/src/flp.rs b/src/flp.rs index 62308bf89..707d7333e 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -57,7 +57,7 @@ use std::convert::TryFrom; use std::fmt::Debug; pub mod gadgets; -#[cfg(all(feature = "experimental", test))] +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] pub mod szk; pub mod types; diff --git a/src/flp/szk.rs b/src/flp/szk.rs index a5256269e..ef504204f 100644 --- a/src/flp/szk.rs +++ b/src/flp/szk.rs @@ -12,17 +12,22 @@ //! following a strategy similar to [`Prio3`](crate::vdaf::prio3::Prio3). use crate::{ - codec::{CodecError, Encode}, - field::{FftFriendlyFieldElement, FieldElement}, + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{decode_fieldvec, encode_fieldvec, FieldElement}, flp::{FlpError, Type}, prng::{Prng, PrngError}, vdaf::xof::{IntoFieldVec, Seed, Xof, XofTurboShake128}, }; -use std::{borrow::Cow, marker::PhantomData}; +#[cfg(test)] +use std::borrow::Cow; +use std::ops::BitAnd; +use std::{io::Cursor, marker::PhantomData}; +use subtle::{Choice, ConstantTimeEq}; // Domain separation tags const DST_PROVE_RANDOMNESS: u16 = 0; const DST_PROOF_SHARE: u16 = 1; +#[allow(dead_code)] const DST_QUERY_RANDOMNESS: u16 = 2; const DST_JOINT_RAND_SEED: u16 = 3; const DST_JOINT_RAND_PART: u16 = 4; @@ -57,51 +62,257 @@ pub enum SzkError { /// Contains an FLP proof share, and if joint randomness is needed, the blind /// used to derive it and the other party's joint randomness part. -#[derive(Clone)] -pub enum SzkProofShare { - /// Leader's proof share is uncompressed. The first Seed is a blind, second - /// is a joint randomness part. +#[derive(Debug, Clone)] +pub enum SzkProofShare { + /// Leader's proof share is uncompressed. Leader { + /// Share of an FLP proof, as a vector of Field elements. uncompressed_proof_share: Vec, - leader_blind_and_helper_joint_rand_part: Option<(Seed, Seed)>, + /// Set only if joint randomness is needed. The first Seed is a blind, second + /// is the helper's joint randomness part. + leader_blind_and_helper_joint_rand_part_opt: Option<(Seed, Seed)>, }, /// The Helper uses one seed for both its compressed proof share and as the blind for its joint /// randomness. Helper { + /// The Seed that acts both as the compressed proof share and, optionally, as the blind. proof_share_seed_and_blind: Seed, - leader_joint_rand_part: Option>, + /// The leader's joint randomness part, if needed. 
+ leader_joint_rand_part_opt: Option>, }, } +impl PartialEq for SzkProofShare { + fn eq(&self, other: &SzkProofShare) -> bool { + bool::from(self.ct_eq(other)) + } +} + +impl ConstantTimeEq for SzkProofShare { + fn ct_eq(&self, other: &SzkProofShare) -> Choice { + match (self, other) { + ( + SzkProofShare::Leader { + uncompressed_proof_share: s_proof, + leader_blind_and_helper_joint_rand_part_opt: s_blind, + }, + SzkProofShare::Leader { + uncompressed_proof_share: o_proof, + leader_blind_and_helper_joint_rand_part_opt: o_blind, + }, + ) => s_proof[..] + .ct_eq(&o_proof[..]) + .bitand(option_tuple_ct_eq(s_blind, o_blind)), + ( + SzkProofShare::Helper { + proof_share_seed_and_blind: s_seed, + leader_joint_rand_part_opt: s_rand, + }, + SzkProofShare::Helper { + proof_share_seed_and_blind: o_seed, + leader_joint_rand_part_opt: o_rand, + }, + ) => s_seed.ct_eq(o_seed).bitand(option_ct_eq(s_rand, o_rand)), + _ => Choice::from(0), + } + } +} + +impl Encode for SzkProofShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + match self { + SzkProofShare::Leader { + uncompressed_proof_share, + leader_blind_and_helper_joint_rand_part_opt, + } => ( + encode_fieldvec(uncompressed_proof_share, bytes)?, + if let Some((blind, helper_joint_rand_part)) = + leader_blind_and_helper_joint_rand_part_opt + { + blind.encode(bytes)?; + helper_joint_rand_part.encode(bytes)?; + }, + ), + SzkProofShare::Helper { + proof_share_seed_and_blind, + leader_joint_rand_part_opt, + } => ( + proof_share_seed_and_blind.encode(bytes)?, + if let Some(leader_joint_rand_part) = leader_joint_rand_part_opt { + leader_joint_rand_part.encode(bytes)?; + }, + ), + }; + Ok(()) + } + + fn encoded_len(&self) -> Option { + match self { + SzkProofShare::Leader { + uncompressed_proof_share, + leader_blind_and_helper_joint_rand_part_opt, + } => Some( + uncompressed_proof_share.len() * F::ENCODED_SIZE + + if let Some((blind, helper_joint_rand_part)) = + leader_blind_and_helper_joint_rand_part_opt + { + blind.encoded_len()? + helper_joint_rand_part.encoded_len()? + } else { + 0 + }, + ), + SzkProofShare::Helper { + proof_share_seed_and_blind, + leader_joint_rand_part_opt, + } => Some( + proof_share_seed_and_blind.encoded_len()? + + if let Some(leader_joint_rand_part) = leader_joint_rand_part_opt { + leader_joint_rand_part.encoded_len()? + } else { + 0 + }, + ), + } + } +} + +impl ParameterizedDecode<(bool, usize, bool)> + for SzkProofShare +{ + fn decode_with_param( + (is_leader, proof_len, requires_joint_rand): &(bool, usize, bool), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + if *is_leader { + Ok(SzkProofShare::Leader { + uncompressed_proof_share: decode_fieldvec::(*proof_len, bytes)?, + leader_blind_and_helper_joint_rand_part_opt: if *requires_joint_rand { + Some(( + Seed::::decode(bytes)?, + Seed::::decode(bytes)?, + )) + } else { + None + }, + }) + } else { + Ok(SzkProofShare::Helper { + proof_share_seed_and_blind: Seed::::decode(bytes)?, + leader_joint_rand_part_opt: if *requires_joint_rand { + Some(Seed::::decode(bytes)?) + } else { + None + }, + }) + } + } +} + /// A tuple containing the state and messages produced by an SZK query. 
-#[derive(Clone)] -pub(crate) struct SzkQueryShare { - joint_rand_part: Option>, - verifier: SzkVerifier, +#[cfg(test)] +#[derive(Clone, Debug)] +pub struct SzkQueryShare { + joint_rand_part_opt: Option>, + flp_verifier: Vec, } /// The state that needs to be stored by an Szk verifier between query() and decide() -pub(crate) struct SzkQueryState { - joint_rand_seed: Option>, +pub type SzkQueryState = Option>; + +#[cfg(test)] +impl SzkQueryShare { + pub(crate) fn merge_verifiers( + mut leader_share: SzkQueryShare, + helper_share: SzkQueryShare, + ) -> SzkVerifier { + for (x, y) in leader_share + .flp_verifier + .iter_mut() + .zip(helper_share.flp_verifier) + { + *x += y; + } + SzkVerifier { + flp_verifier: leader_share.flp_verifier, + leader_joint_rand_part_opt: leader_share.joint_rand_part_opt, + helper_joint_rand_part_opt: helper_share.joint_rand_part_opt, + } + } } /// Verifier type for the SZK proof. -pub type SzkVerifier = Vec; +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SzkVerifier { + flp_verifier: Vec, + leader_joint_rand_part_opt: Option>, + helper_joint_rand_part_opt: Option>, +} + +impl Encode for SzkVerifier { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + encode_fieldvec(&self.flp_verifier, bytes)?; + if let Some(ref part) = self.leader_joint_rand_part_opt { + part.encode(bytes)? + }; + if let Some(ref part) = self.helper_joint_rand_part_opt { + part.encode(bytes)? + }; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some( + self.flp_verifier.len() * F::ENCODED_SIZE + + match self.leader_joint_rand_part_opt { + Some(ref part) => part.encoded_len()?, + None => 0, + } + + match self.helper_joint_rand_part_opt { + Some(ref part) => part.encoded_len()?, + None => 0, + }, + ) + } +} + +impl ParameterizedDecode<(bool, usize)> + for SzkVerifier +{ + fn decode_with_param( + (requires_joint_rand, verifier_len): &(bool, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + if *requires_joint_rand { + Ok(SzkVerifier { + flp_verifier: decode_fieldvec(*verifier_len, bytes)?, + leader_joint_rand_part_opt: Some(Seed::::decode(bytes)?), + helper_joint_rand_part_opt: Some(Seed::::decode(bytes)?), + }) + } else { + Ok(SzkVerifier { + flp_verifier: decode_fieldvec(*verifier_len, bytes)?, + leader_joint_rand_part_opt: None, + helper_joint_rand_part_opt: None, + }) + } + } +} /// Main struct encapsulating the shared zero-knowledge functionality. The type /// T is the underlying FLP proof system. P is the XOF used to derive all random /// coins (it should be indifferentiable from a random oracle for security.) +#[derive(Clone, Debug)] pub struct Szk where T: Type, P: Xof, { - typ: T, + /// The Type representing the specific FLP system used to prove validity of an input. + pub(crate) typ: T, algorithm_id: u32, phantom: PhantomData

, } -#[cfg(test)] impl Szk { /// Create an instance of [`Szk`] using [`XofTurboShake128`]. pub fn new_turboshake128(typ: T, algorithm_id: u32) -> Self { @@ -207,6 +418,7 @@ where .collect() } + #[cfg(test)] fn derive_query_rand(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec { let mut xof = P::init( verify_key, @@ -221,7 +433,15 @@ where self.typ.joint_rand_len() > 0 } - fn prove( + /// Used by a client to prove validity (according to an FLP system) of an input + /// that is both shared between the leader and helper + /// and encoded as a measurement. Has a precondition that leader_input_share + /// \+ helper_input_share = encoded_measurement. + /// leader_seed_opt should be set only if the underlying FLP system requires + /// joint randomness. + /// In this case, the helper uses the same seed to derive its proof share and + /// joint randomness. + pub(crate) fn prove( &self, leader_input_share: &[T::Field], helper_input_share: &[T::Field], @@ -236,8 +456,8 @@ where // leader its blinding seed and the helper's joint randomness part, and // pass the helper the leader's joint randomness part. (The seed used to // derive the helper's proof share is reused as the helper's blind.) - let (leader_blind_and_helper_joint_rand_part, leader_joint_rand_part, joint_rand) = - if let Some(leader_seed) = leader_seed_opt.clone() { + let (leader_blind_and_helper_joint_rand_part_opt, leader_joint_rand_part_opt, joint_rand) = + if let Some(leader_seed) = leader_seed_opt { let leader_joint_rand_part = self.derive_joint_rand_part(&leader_seed, leader_input_share, nonce)?; let helper_joint_rand_part = @@ -269,16 +489,17 @@ where // Construct the output messages. let leader_proof_share = SzkProofShare::Leader { uncompressed_proof_share: leader_proof_share, - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, }; let helper_proof_share = SzkProofShare::Helper { - proof_share_seed_and_blind: helper_seed.clone(), - leader_joint_rand_part, + proof_share_seed_and_blind: helper_seed, + leader_joint_rand_part_opt, }; Ok([leader_proof_share, helper_proof_share]) } - fn query( + #[cfg(test)] + pub(crate) fn query( &self, input_share: &[T::Field], proof_share: SzkProofShare, @@ -301,8 +522,8 @@ where let ((joint_rand_seed, joint_rand), host_joint_rand_part) = match proof_share { SzkProofShare::Leader { uncompressed_proof_share: _, - leader_blind_and_helper_joint_rand_part, - } => match leader_blind_and_helper_joint_rand_part { + leader_blind_and_helper_joint_rand_part_opt, + } => match leader_blind_and_helper_joint_rand_part_opt { Some((seed, helper_joint_rand_part)) => { match self.derive_joint_rand_part(&seed, input_share, nonce) { Ok(leader_joint_rand_part) => ( @@ -323,8 +544,8 @@ where }, SzkProofShare::Helper { proof_share_seed_and_blind, - leader_joint_rand_part, - } => match leader_joint_rand_part { + leader_joint_rand_part_opt, + } => match leader_joint_rand_part_opt { Some(leader_joint_rand_part) => match self.derive_joint_rand_part( &proof_share_seed_and_blind, input_share, @@ -363,33 +584,31 @@ where )?; Ok(( SzkQueryShare { - joint_rand_part, - verifier: verifier_share, + joint_rand_part_opt: joint_rand_part, + flp_verifier: verifier_share, }, - SzkQueryState { joint_rand_seed }, + joint_rand_seed, )) } /// Returns true if the verifier message indicates that the input from which /// it was generated is valid. 
- fn decide( + pub fn decide( &self, - verifier: &[T::Field], - leader_joint_rand_part_opt: Option>, - helper_joint_rand_part_opt: Option>, - joint_rand_seed_opt: Option>, + verifier: SzkVerifier, + query_state: SzkQueryState, ) -> Result { // Check if underlying FLP proof validates - let check_flp_proof = self.typ.decide(verifier)?; + let check_flp_proof = self.typ.decide(&verifier.flp_verifier)?; if !check_flp_proof { return Ok(false); } // Check that joint randomness was properly derived from both // aggregators' parts match ( - joint_rand_seed_opt, - leader_joint_rand_part_opt, - helper_joint_rand_part_opt, + query_state, + verifier.leader_joint_rand_part_opt, + verifier.helper_joint_rand_part_opt, ) { (Some(joint_rand_seed), Some(leader_joint_rand_part), Some(helper_joint_rand_part)) => { let expected_joint_rand_seed = @@ -404,12 +623,45 @@ where } } +#[inline] +fn option_ct_eq(left: &Option, right: &Option) -> Choice +where + T: ConstantTimeEq + Sized, +{ + match (left, right) { + (Some(left), Some(right)) => left.ct_eq(right), + (None, None) => Choice::from(1), + _ => Choice::from(0), + } +} + +// This function determines equality between two optional, constant-time comparable tuples. It +// short-circuits on the existence (but not contents) of the values -- a timing side-channel may +// reveal whether the values match on Some or None. +#[inline] +fn option_tuple_ct_eq(left: &Option<(T, T)>, right: &Option<(T, T)>) -> Choice +where + T: ConstantTimeEq + Sized, +{ + match (left, right) { + (Some((left_0, left_1)), Some((right_0, right_1))) => { + left_0.ct_eq(right_0).bitand(left_1.ct_eq(right_1)) + } + (None, None) => Choice::from(1), + _ => Choice::from(0), + } +} + +#[cfg(test)] mod tests { use super::*; - use crate::field::Field128 as TestField; - use crate::field::{random_vector, FieldElementWithInteger}; - use crate::flp::types::{Count, Sum}; - use crate::flp::Type; + use crate::{ + field::Field128, + field::{random_vector, FieldElementWithInteger}, + flp::gadgets::{Mul, ParallelSum}, + flp::types::{Count, Sum, SumVec}, + flp::Type, + }; use rand::{thread_rng, Rng}; fn generic_szk_test(typ: T, encoded_measurement: &[T::Field], valid: bool) { @@ -454,21 +706,8 @@ mod tests { .query(&helper_input_share, h_proof_share, &verify_key, &nonce) .unwrap(); - let mut verifier = l_query_share.clone().verifier; - - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y; - } - let h_jr_part = h_query_share.clone().joint_rand_part; - let h_jr_seed = h_query_state.joint_rand_seed; - let l_jr_part = l_query_share.joint_rand_part; - let l_jr_seed = l_query_state.joint_rand_seed; - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - l_jr_seed.clone(), - ) { + let verifier = SzkQueryShare::merge_verifiers(l_query_share.clone(), h_query_share.clone()); + if let Ok(leader_decision) = szk_typ.decide(verifier.clone(), l_query_state.clone()) { assert_eq!( leader_decision, valid, "Leader incorrectly determined validity", @@ -476,12 +715,7 @@ mod tests { } else { panic!("Leader failed during decision"); }; - if let Ok(helper_decision) = szk_typ.decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - h_jr_seed.clone(), - ) { + if let Ok(helper_decision) = szk_typ.decide(verifier.clone(), h_query_state.clone()) { assert_eq!( helper_decision, valid, "Helper incorrectly determined validity", @@ -493,33 +727,22 @@ mod tests { //test mutated jr seed if szk_typ.has_joint_rand() { let joint_rand_seed_opt = 
Some(Seed::<16>::generate().unwrap()); - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - joint_rand_seed_opt, - ) { + if let Ok(leader_decision) = szk_typ.decide(verifier, joint_rand_seed_opt.clone()) { assert!(!leader_decision, "Leader accepted wrong jr seed"); }; }; - //test mutated verifier - let mut verifier = l_query_share.verifier; - - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y + T::Field::from( + // test mutated verifier + let mut mutated_query_share = l_query_share.clone(); + for x in mutated_query_share.flp_verifier.iter_mut() { + *x += T::Field::from( ::Integer::try_from(7).unwrap(), ); } - let leader_decision = szk_typ - .decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - l_jr_seed.clone(), - ) - .unwrap(); + let verifier = SzkQueryShare::merge_verifiers(mutated_query_share, h_query_share.clone()); + + let leader_decision = szk_typ.decide(verifier, l_query_state.clone()).unwrap(); assert!(!leader_decision, "Leader validated after proof mutation"); // test mutated input share @@ -529,31 +752,21 @@ mod tests { let (mutated_query_share, mutated_query_state) = szk_typ .query(&mutated_input, l_proof_share.clone(), &verify_key, &nonce) .unwrap(); - let mut verifier = mutated_query_share.verifier; - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y; - } + let verifier = SzkQueryShare::merge_verifiers(mutated_query_share, h_query_share.clone()); - let mutated_jr_seed = mutated_query_state.joint_rand_seed; - let mutated_jr_part = mutated_query_share.joint_rand_part; - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - mutated_jr_part.clone(), - h_jr_part.clone(), - mutated_jr_seed, - ) { + if let Ok(leader_decision) = szk_typ.decide(verifier, mutated_query_state) { assert!(!leader_decision, "Leader validated after input mutation"); }; // test mutated proof share - let (mut mutated_proof, leader_blind_and_helper_joint_rand_part) = match l_proof_share { + let (mut mutated_proof, leader_blind_and_helper_joint_rand_part_opt) = match l_proof_share { SzkProofShare::Leader { uncompressed_proof_share, - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, } => ( uncompressed_proof_share.clone(), - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, ), _ => (vec![], None), }; @@ -561,7 +774,7 @@ mod tests { T::Field::from(::Integer::try_from(23).unwrap()); let mutated_proof_share = SzkProofShare::Leader { uncompressed_proof_share: mutated_proof, - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, }; let (l_query_share, l_query_state) = szk_typ .query( @@ -571,38 +784,384 @@ mod tests { &nonce, ) .unwrap(); - let mut verifier = l_query_share.verifier; + let verifier = SzkQueryShare::merge_verifiers(l_query_share, h_query_share.clone()); - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y; - } - - let mutated_jr_seed = l_query_state.joint_rand_seed; - let mutated_jr_part = l_query_share.joint_rand_part; - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - mutated_jr_part.clone(), - h_jr_part.clone(), - mutated_jr_seed, - ) { + if let Ok(leader_decision) = szk_typ.decide(verifier, l_query_state) { assert!(!leader_decision, "Leader validated after proof mutation"); }; } + #[test] + fn test_sum_proof_share_encode() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sum = 
Sum::::new(5).unwrap(); + let encoded_measurement = sum.encode_measurement(&9).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + assert_eq!( + l_proof_share.encoded_len().unwrap(), + l_proof_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_sumvec_proof_share_encode() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + assert_eq!( + l_proof_share.encoded_len().unwrap(), + l_proof_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_count_proof_share_encode() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let count = Count::::new(); + let encoded_measurement = count.encode_measurement(&true).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + assert_eq!( + l_proof_share.encoded_len().unwrap(), + l_proof_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_sum_leader_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sum = Sum::::new(5).unwrap(); + let encoded_measurement = sum.encode_measurement(&9).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut 
leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + true, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = l_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(l_proof_share, decoded_proof_share); + } + + #[test] + fn test_sum_helper_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sum = Sum::::new(5).unwrap(); + let encoded_measurement = sum.encode_measurement(&9).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [_, h_proof_share] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + false, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = h_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(h_proof_share, decoded_proof_share); + } + + #[test] + fn test_count_leader_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let count = Count::::new(); + let encoded_measurement = count.encode_measurement(&true).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = None; + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + true, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = l_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(l_proof_share, decoded_proof_share); + } + + #[test] + fn test_count_helper_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let count = Count::::new(); + let encoded_measurement = count.encode_measurement(&true).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let 
prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = None; + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [_, h_proof_share] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + false, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = h_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(h_proof_share, decoded_proof_share); + } + + #[test] + fn test_sumvec_leader_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + true, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = l_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(l_proof_share, decoded_proof_share); + } + + #[test] + fn test_sumvec_helper_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [_, h_proof_share] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + false, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = h_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(h_proof_share, decoded_proof_share); + } 
+ #[test] fn test_sum() { - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(5).unwrap(); - let five = TestField::from(5); + let five = Field128::from(5); let nine = sum.encode_measurement(&9).unwrap(); let bad_encoding = &vec![five; sum.input_len()]; generic_szk_test(sum.clone(), &nine, true); generic_szk_test(sum, bad_encoding, false); } + #[test] + fn test_sumvec() { + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + + let five = Field128::from(5); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let bad_encoding = &vec![five; sumvec.input_len()]; + generic_szk_test(sumvec.clone(), &encoded_measurement, true); + generic_szk_test(sumvec, bad_encoding, false); + } + #[test] fn test_count() { - let count = Count::::new(); + let count = Count::::new(); let encoded_true = count.encode_measurement(&true).unwrap(); generic_szk_test(count, &encoded_true, true); } diff --git a/src/vdaf.rs b/src/vdaf.rs index e5f4e14c5..8dc8b2fe1 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -8,6 +8,8 @@ #[cfg(feature = "experimental")] use crate::dp::DifferentialPrivacyStrategy; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +use crate::flp::szk::SzkError; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] use crate::idpf::IdpfError; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] use crate::vidpf::VidpfError; @@ -46,6 +48,11 @@ pub enum VdafError { #[error("flp error: {0}")] Flp(#[from] FlpError), + /// SZK error. + #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] + #[error("Szk error: {0}")] + Szk(#[from] SzkError), + /// PRNG error. #[error("prng error: {0}")] Prng(#[from] PrngError), @@ -740,6 +747,8 @@ mod tests { #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] pub mod dummy; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +pub mod mastic; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[cfg_attr( docsrs, doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs new file mode 100644 index 000000000..7b8d63424 --- /dev/null +++ b/src/vdaf/mastic.rs @@ -0,0 +1,539 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of Mastic as specified in [[draft-mouris-cfrg-mastic-01]]. +//! +//! [draft-mouris-cfrg-mastic-01]: https://www.ietf.org/archive/id/draft-mouris-cfrg-mastic-01.html + +use crate::{ + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{decode_fieldvec, FieldElement}, + flp::{ + szk::{Szk, SzkProofShare}, + Type, + }, + vdaf::{ + poplar1::Poplar1AggregationParam, + xof::{Seed, Xof}, + AggregateShare, Client, OutputShare, Vdaf, VdafError, + }, + vidpf::{ + Vidpf, VidpfError, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, VidpfWeight, + }, +}; + +use std::fmt::Debug; +use std::io::{Cursor, Read}; +use std::ops::BitAnd; +use subtle::{Choice, ConstantTimeEq}; + +/// The main struct implementing the Mastic VDAF. +/// Composed of a shared zero knowledge proof system and a verifiable incremental +/// distributed point function. +#[derive(Clone, Debug)] +pub struct Mastic +where + T: Type, + P: Xof, +{ + algorithm_id: u32, + szk: Szk, + pub(crate) vidpf: Vidpf, 16>, + /// The length of the private attribute associated with any input. + pub(crate) bits: usize, +} + +impl Mastic +where + T: Type, + P: Xof, +{ + /// Creates a new instance of Mastic, with a specific attribute length and weight type. 
+ pub fn new( + algorithm_id: u32, + szk: Szk, + vidpf: Vidpf, 16>, + bits: usize, + ) -> Self { + Self { + algorithm_id, + szk, + vidpf, + bits, + } + } +} + +/// Mastic aggregation parameter. +/// +/// This includes the VIDPF tree level under evaluation and a set of prefixes to evaluate at that level. +#[derive(Clone, Debug)] +pub struct MasticAggregationParam { + /// aggregation parameter inherited from [`Poplar1`]: contains the level (attribute length) and a vector of attribute prefixes (IdpfInputs) + level_and_prefixes: Poplar1AggregationParam, + /// Flag indicating whether the VIDPF weight needs to be validated using SZK. + /// This flag must be set the first time any report is aggregated; however this may happen at any level of the tree. + require_check_flag: bool, +} + +impl Encode for MasticAggregationParam { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + self.level_and_prefixes.encode(bytes)?; + let require_check = if self.require_check_flag { 1u8 } else { 0u8 }; + require_check.encode(bytes)?; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some(self.level_and_prefixes.encoded_len()? + 1usize) + } +} + +impl Decode for MasticAggregationParam { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + let level_and_prefixes = Poplar1AggregationParam::decode(bytes)?; + let require_check = u8::decode(bytes)?; + let require_check_flag = require_check != 0; + Ok(Self { + level_and_prefixes, + require_check_flag, + }) + } +} + +/// Mastic public share. +/// +/// Contains broadcast information shared between parties to support VIDPF correctness. +pub type MasticPublicShare = VidpfPublicShare; + +impl ParameterizedDecode> + for MasticPublicShare> +where + T: Type, + P: Xof, +{ + fn decode_with_param( + mastic: &Mastic, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + MasticPublicShare::>::decode_with_param( + &(mastic.bits, mastic.vidpf.weight_parameter), + bytes, + ) + } +} + +/// Mastic input share +/// +/// Message sent by the [`Client`] to each Aggregator during the Sharding phase. +#[derive(Clone, Debug)] +pub struct MasticInputShare { + /// VIDPF key share. + vidpf_key: VidpfKey, + + /// The proof share. + proof_share: SzkProofShare, +} + +impl Encode for MasticInputShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + bytes.extend_from_slice(&self.vidpf_key.value[..]); + self.proof_share.encode(bytes)?; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some(16 + self.proof_share.encoded_len()?) 
+ } +} + +impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Mastic, usize)> + for MasticInputShare +where + T: Type, + P: Xof, +{ + fn decode_with_param( + (mastic, agg_id): &(&'a Mastic, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + if *agg_id > 1 { + return Err(CodecError::UnexpectedValue); + } + let mut value = [0; 16]; + bytes.read_exact(&mut value)?; + let vidpf_key = VidpfKey::new( + if *agg_id == 0 { + VidpfServerId::S0 + } else { + VidpfServerId::S1 + }, + value, + ); + + let proof_share = SzkProofShare::::decode_with_param( + &( + *agg_id == 0, + mastic.szk.typ.proof_len(), + mastic.szk.typ.joint_rand_len() != 0, + ), + bytes, + )?; + Ok(Self { + vidpf_key, + proof_share, + }) + } +} + +#[cfg(test)] +impl PartialEq for MasticInputShare { + fn eq(&self, other: &MasticInputShare) -> bool { + self.ct_eq(other).into() + } +} + +impl ConstantTimeEq for MasticInputShare { + fn ct_eq(&self, other: &MasticInputShare) -> Choice { + self.vidpf_key + .ct_eq(&other.vidpf_key) + .bitand(self.proof_share.ct_eq(&other.proof_share)) + } +} + +/// Mastic output share. +/// +/// Contains a flattened vector of VIDPF outputs: one for each prefix. +pub type MasticOutputShare = OutputShare; + +/// Mastic aggregate share. +/// +/// Contains a flattened vector of VIDPF outputs to be aggregated by Mastic aggregators +pub type MasticAggregateShare = AggregateShare; + +impl<'a, T, P, const SEED_SIZE: usize> + ParameterizedDecode<(&'a Mastic, &'a MasticAggregationParam)> + for MasticAggregateShare +where + T: Type, + P: Xof, +{ + fn decode_with_param( + decoding_parameter: &(&Mastic, &MasticAggregationParam), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let (mastic, agg_param) = decoding_parameter; + let l = mastic + .vidpf + .weight_parameter + .checked_mul(agg_param.level_and_prefixes.prefixes().len()) + .ok_or_else(|| CodecError::Other("multiplication overflow".into()))?; + let result = decode_fieldvec(l, bytes)?; + Ok(AggregateShare(result)) + } +} + +impl<'a, T, P, const SEED_SIZE: usize> + ParameterizedDecode<(&'a Mastic, &'a MasticAggregationParam)> + for MasticOutputShare +where + T: Type, + P: Xof, +{ + fn decode_with_param( + decoding_parameter: &(&Mastic, &MasticAggregationParam), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let (mastic, agg_param) = decoding_parameter; + let l = mastic + .vidpf + .weight_parameter + .checked_mul(agg_param.level_and_prefixes.prefixes().len()) + .ok_or_else(|| CodecError::Other("multiplication overflow".into()))?; + let result = decode_fieldvec(l, bytes)?; + Ok(OutputShare(result)) + } +} + +impl Vdaf for Mastic +where + T: Type, + P: Xof, +{ + type Measurement = (VidpfInput, T::Measurement); + type AggregateResult = T::AggregateResult; + type AggregationParam = MasticAggregationParam; + type PublicShare = MasticPublicShare>; + type InputShare = MasticInputShare; + type OutputShare = MasticOutputShare; + type AggregateShare = MasticAggregateShare; + + fn algorithm_id(&self) -> u32 { + self.algorithm_id + } + + fn num_aggregators(&self) -> usize { + 2 + } +} + +impl Mastic +where + T: Type, + P: Xof, +{ + fn shard_with_random( + &self, + measurement_attribute: &VidpfInput, + measurement_weight: &VidpfWeight, + nonce: &[u8; 16], + vidpf_keys: [VidpfKey; 2], + szk_random: [Seed; 2], + joint_random_opt: Option>, + ) -> Result<(::PublicShare, Vec<::InputShare>), VdafError> { + // Compute the measurement shares for each aggregator by generating VIDPF + // keys for the measurement and evaluating each of them. 
+ let public_share = self.vidpf.gen_with_keys( + &vidpf_keys, + measurement_attribute, + measurement_weight, + nonce, + )?; + + let leader_measurement_share = + self.vidpf.eval_root(&vidpf_keys[0], &public_share, nonce)?; + let helper_measurement_share = + self.vidpf.eval_root(&vidpf_keys[1], &public_share, nonce)?; + + let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( + leader_measurement_share.as_ref(), + helper_measurement_share.as_ref(), + measurement_weight.as_ref(), + szk_random, + joint_random_opt, + nonce, + )?; + let [leader_vidpf_key, helper_vidpf_key] = vidpf_keys; + let leader_share = MasticInputShare:: { + vidpf_key: leader_vidpf_key, + proof_share: leader_szk_proof_share, + }; + let helper_share = MasticInputShare:: { + vidpf_key: helper_vidpf_key, + proof_share: helper_szk_proof_share, + }; + Ok((public_share, vec![leader_share, helper_share])) + } + + fn encode_measurement( + &self, + measurement: &T::Measurement, + ) -> Result, VdafError> { + Ok(VidpfWeight::::from( + self.szk.typ.encode_measurement(measurement)?, + )) + } +} + +impl Client<16> for Mastic +where + T: Type, + P: Xof, +{ + fn shard( + &self, + (attribute, weight): &(VidpfInput, T::Measurement), + nonce: &[u8; 16], + ) -> Result<(Self::PublicShare, Vec), VdafError> { + if attribute.len() != self.bits { + return Err(VdafError::Vidpf(VidpfError::InvalidAttributeLength)); + } + + let vidpf_keys = [ + VidpfKey::gen(VidpfServerId::S0)?, + VidpfKey::gen(VidpfServerId::S1)?, + ]; + let joint_random_opt = if self.szk.has_joint_rand() { + Some(Seed::::generate()?) + } else { + None + }; + let szk_random = [ + Seed::::generate()?, + Seed::::generate()?, + ]; + + let encoded_measurement = self.encode_measurement(weight)?; + if encoded_measurement.as_ref().len() != self.vidpf.weight_parameter { + return Err(VdafError::Uncategorized( + "encoded_measurement is the wrong length".to_string(), + )); + } + self.shard_with_random( + attribute, + &encoded_measurement, + nonce, + vidpf_keys, + szk_random, + joint_random_opt, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::Field128; + use crate::flp::gadgets::{Mul, ParallelSum}; + use crate::flp::types::{Count, Sum, SumVec}; + use rand::{thread_rng, Rng}; + + const TEST_NONCE_SIZE: usize = 16; + + #[test] + fn test_mastic_shard_sum() { + let algorithm_id = 6; + let sum_typ = Sum::::new(5).unwrap(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); + let (_public, _input_shares) = mastic.shard(&(first_input, 24u128), &nonce).unwrap(); + } + + #[test] + fn test_input_share_encode_sum() { + let algorithm_id = 6; + let sum_typ = Sum::::new(5).unwrap(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); + let (_, input_shares) = mastic.shard(&(first_input, 26u128), &nonce).unwrap(); + let [leader_input_share, helper_input_share] = 
[&input_shares[0], &input_shares[1]]; + + assert_eq!( + leader_input_share.encoded_len().unwrap(), + leader_input_share.get_encoded().unwrap().len() + ); + assert_eq!( + helper_input_share.encoded_len().unwrap(), + helper_input_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_mastic_shard_count() { + let algorithm_id = 6; + let count = Count::::new(); + let szk = Szk::new_turboshake128(count, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(1); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (_public, _input_shares) = mastic.shard(&(first_input, true), &nonce).unwrap(); + } + + #[test] + fn test_mastic_shard_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (_public, _input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + } + + #[test] + fn test_input_share_encode_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let leader_input_share = &input_shares[0]; + let helper_input_share = &input_shares[1]; + + assert_eq!( + leader_input_share.encoded_len().unwrap(), + leader_input_share.get_encoded().unwrap().len() + ); + assert_eq!( + helper_input_share.encoded_len().unwrap(), + helper_input_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_input_share_roundtrip_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let leader_input_share = &input_shares[0]; + let helper_input_share = &input_shares[1]; + + let encoded_input_share = leader_input_share.get_encoded().unwrap(); + let decoded_leader_input_share = + MasticInputShare::get_decoded_with_param(&(&mastic, 0), &encoded_input_share[..]) + .unwrap(); + assert_eq!(leader_input_share, &decoded_leader_input_share); + let encoded_input_share = 
helper_input_share.get_encoded().unwrap(); + let decoded_helper_input_share = + MasticInputShare::get_decoded_with_param(&(&mastic, 1), &encoded_input_share[..]) + .unwrap(); + assert_eq!(helper_input_share, &decoded_helper_input_share); + } +} diff --git a/src/vidpf.rs b/src/vidpf.rs index c8ba5db22..3ec8d1347 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -16,8 +16,9 @@ use core::{ use bitvec::field::BitField; use rand_core::RngCore; +use std::fmt::Debug; use std::io::Cursor; -use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable}; +use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; use crate::{ codec::{CodecError, Encode, ParameterizedDecode}, @@ -45,6 +46,11 @@ pub enum VidpfError { #[error("level index out of bounds")] IndexLevel, + /// Error when input attribute has too few or many bits to be a path in an initialized + /// VIDPF tree. + #[error("invalid attribute length")] + InvalidAttributeLength, + /// Error when weight's length mismatches the length in weight's parameter. #[error("invalid weight length")] InvalidWeightLength, @@ -58,12 +64,13 @@ pub enum VidpfError { pub type VidpfInput = IdpfInput; /// Represents the codomain of an incremental point function. -pub trait VidpfValue: IdpfValue + Clone {} +pub trait VidpfValue: IdpfValue + Clone + Debug + PartialEq + ConstantTimeEq {} +#[derive(Clone, Debug)] /// A VIDPF instance. pub struct Vidpf { /// Any parameters required to instantiate a weight value. - weight_parameter: W::ValueParameter, + pub(crate) weight_parameter: W::ValueParameter, } impl Vidpf { @@ -108,7 +115,7 @@ impl Vidpf { /// [`Vidpf::gen_with_keys`] works as the [`Vidpf::gen`] method, except that two different /// keys must be provided. - fn gen_with_keys( + pub(crate) fn gen_with_keys( &self, keys: &[VidpfKey; 2], input: &VidpfInput, @@ -206,6 +213,9 @@ impl Vidpf { let mut share = W::zero(&self.weight_parameter); let n = input.len(); + if n > public.cw.len() { + return Err(VidpfError::InvalidAttributeLength); + } for level in 0..n { (state, share) = self.eval_next(key.id, public, input, level, &state, nonce)?; } @@ -266,6 +276,20 @@ impl Vidpf { Ok((next_state, y)) } + pub(crate) fn eval_root( + &self, + key: &VidpfKey, + public_share: &VidpfPublicShare, + nonce: &[u8; NONCE_SIZE], + ) -> Result { + Ok(self + .eval(key, public_share, &VidpfInput::from_bools(&[false]), nonce)? + .share + + self + .eval(key, public_share, &VidpfInput::from_bools(&[true]), nonce)? + .share) + } + fn prg(seed: &VidpfSeed, nonce: &[u8]) -> VidpfPrgOutput { let mut rng = XofFixedKeyAes128::seed_stream(&Seed(*seed), VidpfDomainSepTag::PRG, nonce); @@ -339,6 +363,8 @@ impl Vidpf { } } +/// Vidpf domain separation tag +/// /// Contains the domain separation tags for invoking different oracles. struct VidpfDomainSepTag; impl VidpfDomainSepTag { @@ -348,10 +374,13 @@ impl VidpfDomainSepTag { const NODE_PROOF_ADJUST: &'static [u8] = b"NodeProofAdjust"; } +#[derive(Clone, Debug)] +/// Vidpf key +/// /// Private key of an aggregation server. 
pub struct VidpfKey { id: VidpfServerId, - value: [u8; 16], + pub(crate) value: [u8; 16], } impl VidpfKey { @@ -364,10 +393,32 @@ impl VidpfKey { getrandom::getrandom(&mut value)?; Ok(Self { id, value }) } + + pub(crate) fn new(id: VidpfServerId, value: [u8; 16]) -> Self { + Self { id, value } + } +} + +impl ConstantTimeEq for VidpfKey { + fn ct_eq(&self, other: &VidpfKey) -> Choice { + if self.id != other.id { + Choice::from(0) + } else { + self.value.ct_eq(&other.value) + } + } } +impl PartialEq for VidpfKey { + fn eq(&self, other: &VidpfKey) -> bool { + bool::from(self.ct_eq(other)) + } +} + +/// Vidpf server ID +/// /// Identifies the two aggregation servers. -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub(crate) enum VidpfServerId { /// S0 is the first server. S0, @@ -384,8 +435,10 @@ impl From for Choice { } } -/// Adjusts values of shares during the VIDPF evaluation. -#[derive(Debug)] +/// Vidpf correction word +/// +/// Adjusts values of shares during the VIDPF evaluation. +#[derive(Clone, Debug)] struct VidpfCorrectionWord { seed: VidpfSeed, left_control_bit: Choice, @@ -393,13 +446,73 @@ struct VidpfCorrectionWord { weight: W, } +impl ConstantTimeEq for VidpfCorrectionWord { + fn ct_eq(&self, other: &Self) -> Choice { + self.seed.ct_eq(&other.seed) + & self.left_control_bit.ct_eq(&other.left_control_bit) + & self.right_control_bit.ct_eq(&other.right_control_bit) + & self.weight.ct_eq(&other.weight) + } +} + +impl PartialEq for VidpfCorrectionWord +where + W: ConstantTimeEq, +{ + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Encode for VidpfCorrectionWord { + fn encode(&self, _bytes: &mut Vec) -> Result<(), CodecError> { + todo!(); + } + + fn encoded_len(&self) -> Option { + todo!(); + } +} + +impl ParameterizedDecode for VidpfCorrectionWord { + fn decode_with_param( + _decoding_parameter: &W::ValueParameter, + _bytes: &mut Cursor<&[u8]>, + ) -> Result { + todo!(); + } +} + +/// Vidpf public share +/// /// Common public information used by aggregation servers. -#[derive(Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct VidpfPublicShare { cw: Vec>, cs: Vec, } +impl Encode for VidpfPublicShare { + fn encode(&self, _bytes: &mut Vec) -> Result<(), CodecError> { + todo!() + } + + fn encoded_len(&self) -> Option { + todo!() + } +} + +impl ParameterizedDecode<(usize, W::ValueParameter)> for VidpfPublicShare { + fn decode_with_param( + (_bits, _weight_parameter): &(usize, W::ValueParameter), + _bytes: &mut Cursor<&[u8]>, + ) -> Result { + todo!() + } +} + +/// Vidpf evaluation state +/// /// Contains the values produced during input evaluation at a given level. pub struct VidpfEvalState { seed: VidpfSeed, @@ -454,7 +567,7 @@ struct VidpfPrgOutput { /// Represents an array of field elements that implements the [`VidpfValue`] trait. #[derive(Debug, PartialEq, Eq, Clone)] -pub struct VidpfWeight(Vec); +pub struct VidpfWeight(pub(crate) Vec); impl From> for VidpfWeight { fn from(value: Vec) -> Self { @@ -462,6 +575,12 @@ impl From> for VidpfWeight { } } +impl AsRef<[F]> for VidpfWeight { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + impl VidpfValue for VidpfWeight {} impl IdpfValue for VidpfWeight { @@ -549,6 +668,12 @@ impl Sub for VidpfWeight { } } +impl ConstantTimeEq for VidpfWeight { + fn ct_eq(&self, other: &Self) -> Choice { + self.0[..].ct_eq(&other.0[..]) + } +} + impl Encode for VidpfWeight { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { for e in &self.0 {
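
The sketch below is not part of the patch; it illustrates how a client might drive the new Mastic sharding API end to end. It mirrors the patch's own test_mastic_shard_sum test, and it assumes the crate is built with the "experimental" and "crypto-dependencies" features, is consumed under its usual `prio` package name, and that the constructors it calls (Szk::new_turboshake128, Vidpf::new, Mastic::new) are publicly reachable from outside the crate at this stage of the work.

// Illustrative sketch only (not patch content): shard one weighted report with Mastic.
// Assumes the feature flags and public visibility noted above; values mirror
// the in-crate test `test_mastic_shard_sum`.
use prio::{
    field::Field128,
    flp::{szk::Szk, types::Sum},
    vdaf::{mastic::Mastic, Client},
    vidpf::{Vidpf, VidpfInput, VidpfWeight},
};

fn main() {
    // The FLP type bounds each weight to 5 bits; the algorithm ID is arbitrary here.
    let algorithm_id = 6;
    let szk = Szk::new_turboshake128(Sum::<Field128>::new(5).unwrap(), algorithm_id);

    // The VIDPF weight parameter (5) matches the Sum type's encoded input length.
    let vidpf = Vidpf::<VidpfWeight<Field128>, 16>::new(5);

    // Mastic over a 32-bit private attribute with a Sum-typed weight.
    let mastic = Mastic::new(algorithm_id, szk, vidpf, 32);

    // The measurement is a (attribute, weight) pair; a real client must use a
    // fresh random nonce rather than zeros.
    let attribute = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]);
    let nonce = [0u8; 16];

    let (public_share, input_shares) = mastic
        .shard(&(attribute, 24u128), &nonce)
        .expect("sharding failed");

    // One VIDPF key share plus SZK proof share per aggregator, and a shared
    // public share carrying the VIDPF correction words.
    assert_eq!(input_shares.len(), 2);
    let _ = public_share;
}

Since this patch implements only the Client side of Mastic, the sketch stops after sharding; preparation and aggregation of the resulting shares are outside its scope.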