diff --git a/src/ff/galois_field.rs b/src/ff/galois_field.rs index f8a99c3f5..0c3814d27 100644 --- a/src/ff/galois_field.rs +++ b/src/ff/galois_field.rs @@ -5,7 +5,7 @@ use std::{ use bitvec::prelude::{bitarr, BitArr, Lsb0}; use generic_array::GenericArray; -use typenum::{Unsigned, U1, U3, U4, U5}; +use typenum::{Unsigned, U1, U2, U3, U4, U5}; use crate::{ ff::{Field, Serializable}, @@ -25,6 +25,7 @@ pub trait GaloisField: // Bit store type definitions type U8_1 = BitArr!(for 8, in u8, Lsb0); +type U8_2 = BitArr!(for 9, in u8, Lsb0); type U8_3 = BitArr!(for 24, in u8, Lsb0); type U8_4 = BitArr!(for 32, in u8, Lsb0); type U8_5 = BitArr!(for 40, in u8, Lsb0); @@ -33,6 +34,10 @@ impl Block for U8_1 { type Size = U1; } +impl Block for U8_2 { + type Size = U2; +} + impl Block for U8_3 { type Size = U3; } @@ -575,6 +580,16 @@ bit_array_impl!( 0b1_0001_1011_u128 ); +bit_array_impl!( + bit_array_9, + Gf9Bit, + U8_2, + 9, + bitarr!(const u8, Lsb0; 1, 0, 0, 0, 0, 0, 0, 0, 0), + // x^9 + x^4 + x^3 + x + 1 + 0b10_0001_1011_u128 +); + bit_array_impl!( bit_array_5, Gf5Bit, diff --git a/src/ff/mod.rs b/src/ff/mod.rs index e324a4a72..5f18196c7 100644 --- a/src/ff/mod.rs +++ b/src/ff/mod.rs @@ -9,7 +9,9 @@ mod prime_field; use std::ops::{Add, AddAssign, Sub, SubAssign}; pub use field::{Field, FieldType}; -pub use galois_field::{GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf5Bit, Gf8Bit}; +pub use galois_field::{ + GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf5Bit, Gf8Bit, Gf9Bit, +}; use generic_array::{ArrayLength, GenericArray}; #[cfg(any(test, feature = "weak-field"))] pub use prime_field::Fp31; diff --git a/src/protocol/prf_sharding/bucket.rs b/src/protocol/prf_sharding/bucket.rs index 6349f91a6..d2ad77a11 100644 --- a/src/protocol/prf_sharding/bucket.rs +++ b/src/protocol/prf_sharding/bucket.rs @@ -1,17 +1,42 @@ use embed_doc_image::embed_doc_image; +use ipa_macros::Step; use crate::{ error::Error, ff::{GaloisField, PrimeField, Serializable}, protocol::{ - basics::SecureMul, context::UpgradedContext, prf_sharding::BinaryTreeDepthStep, - step::BitOpStep, RecordId, + basics::SecureMul, context::UpgradedContext, prf_sharding::BinaryTreeDepthStep, RecordId, }, secret_sharing::{ replicated::malicious::ExtendableField, BitDecomposed, Linear as LinearSecretSharing, }, }; +#[derive(Step)] +pub enum BucketStep { + #[dynamic(256)] + Bit(usize), +} + +impl TryFrom for BucketStep { + type Error = String; + + fn try_from(v: u32) -> Result { + let val = usize::try_from(v); + let val = match val { + Ok(val) => Self::Bit(val), + Err(error) => panic!("{error:?}"), + }; + Ok(val) + } +} + +impl From for BucketStep { + fn from(v: usize) -> Self { + Self::Bit(v) + } +} + #[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")] /// This function moves a single value to a correct bucket using tree aggregation approach /// @@ -53,8 +78,8 @@ where BK::BITS ); assert!( - breakdown_count <= 128, - "Our step implementation (BitOpStep) cannot go past 64" + breakdown_count <= 512, + "Our step implementation (BucketStep) cannot go past 256" ); let mut row_contribution = vec![value; breakdown_count]; @@ -69,7 +94,7 @@ where let mut futures = Vec::with_capacity(breakdown_count / step); for (i, tree_index) in (0..breakdown_count).step_by(step).enumerate() { - let bit_c = depth_c.narrow(&BitOpStep::from(i)); + let bit_c = depth_c.narrow(&BucketStep::from(i)); if robust || tree_index + span < breakdown_count { futures.push(row_contribution[tree_index].multiply(bit_of_bdkey, bit_c, record_id)); @@ -96,7 +121,7 @@ pub mod tests { use rand::thread_rng; use crate::{ - ff::{Field, Fp32BitPrime, Gf5Bit, Gf8Bit}, + ff::{Field, Fp32BitPrime, Gf8Bit, Gf9Bit}, protocol::{ context::{Context, UpgradableContext, Validator}, prf_sharding::bucket::move_single_value_to_bucket, @@ -108,12 +133,12 @@ pub mod tests { test_fixture::{get_bits, Reconstruct, Runner, TestWorld}, }; - const MAX_BREAKDOWN_COUNT: usize = 1 << Gf5Bit::BITS; + const MAX_BREAKDOWN_COUNT: usize = 256; const VALUE: u32 = 10; async fn move_to_bucket(count: usize, breakdown_key: usize, robust: bool) -> Vec { let breakdown_key_bits = - get_bits::(breakdown_key.try_into().unwrap(), Gf5Bit::BITS); + get_bits::(breakdown_key.try_into().unwrap(), Gf8Bit::BITS); let value = Fp32BitPrime::truncate_from(VALUE); TestWorld::default() @@ -122,7 +147,7 @@ pub mod tests { |ctx, (breakdown_key_share, value_share)| async move { let validator = ctx.validator(); let ctx = validator.context(); - move_single_value_to_bucket::( + move_single_value_to_bucket::( ctx.set_total_records(1), RecordId::from(0), breakdown_key_share, @@ -207,7 +232,7 @@ pub mod tests { #[should_panic] fn move_out_of_range_too_many_buckets_steps() { run(move || async move { - let breakdown_key_bits = get_bits::(0, Gf8Bit::BITS); + let breakdown_key_bits = get_bits::(0, Gf9Bit::BITS); let value = Fp32BitPrime::truncate_from(VALUE); _ = TestWorld::default() @@ -216,12 +241,12 @@ pub mod tests { |ctx, (breakdown_key_share, value_share)| async move { let validator = ctx.validator(); let ctx = validator.context(); - move_single_value_to_bucket::( + move_single_value_to_bucket::( ctx.set_total_records(1), RecordId::from(0), breakdown_key_share, value_share, - 129, + 513, false, ) .await diff --git a/src/protocol/prf_sharding/mod.rs b/src/protocol/prf_sharding/mod.rs index 74505c09a..a0e7909fa 100644 --- a/src/protocol/prf_sharding/mod.rs +++ b/src/protocol/prf_sharding/mod.rs @@ -759,6 +759,13 @@ where F: PrimeField + ExtendableField, { let num_records = user_level_attributions.len(); + + // in case no attributable conversion is found, return 0. + // as anyways the helpers know that no attributions resulted. + if num_records == 0 { + return Ok(vec![S::ZERO; 1 << BK::BITS]); + } + let (bk_vec, tv_vec): (Vec<_>, Vec<_>) = user_level_attributions .into_iter() .map(|row| { @@ -1244,4 +1251,27 @@ pub mod tests { assert_eq!(result, &expected); }); } + + #[test] + fn semi_honest_aggregation_empty_input() { + run(|| async move { + let world = TestWorld::default(); + + let records: Vec = vec![]; + + let expected = [0_u128; 32]; + + let result: Vec<_> = world + .semi_honest(records.into_iter(), |ctx, input_rows| async move { + let validator = ctx.validator(); + let ctx = validator.context(); + do_aggregation::<_, Gf5Bit, Gf3Bit, Fp32BitPrime, _>(ctx, input_rows) + .await + .unwrap() + }) + .await + .reconstruct(); + assert_eq!(result, &expected); + }); + } }