-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make active_work match records_per_batch #1316
Changes from 5 commits
b628bf2
a4c6f03
1a6e388
706dcbe
c5c0ccf
966160f
d2512f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -825,35 +825,158 @@ mod tests { | |
}; | ||
|
||
use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; | ||
use futures::{StreamExt, TryStreamExt}; | ||
use futures::{stream, StreamExt, TryStreamExt}; | ||
use futures_util::stream::iter; | ||
use proptest::{prop_compose, proptest, sample::select}; | ||
use rand::{thread_rng, Rng}; | ||
use proptest::{ | ||
prelude::{Just, Strategy}, | ||
prop_compose, prop_oneof, proptest, | ||
test_runner::Config as ProptestConfig, | ||
}; | ||
use rand::{distributions::Standard, prelude::Distribution}; | ||
|
||
use crate::{ | ||
error::Error, | ||
ff::{boolean::Boolean, Fp61BitPrime}, | ||
ff::{ | ||
boolean::Boolean, | ||
boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, | ||
Fp61BitPrime, | ||
}, | ||
protocol::{ | ||
basics::SecureMul, | ||
basics::{select, BooleanArrayMul, SecureMul}, | ||
context::{ | ||
dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, | ||
dzkp_validator::{ | ||
Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, | ||
}, | ||
Context, UpgradableContext, TEST_DZKP_STEPS, | ||
Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, | ||
UpgradableContext, TEST_DZKP_STEPS, | ||
}, | ||
Gate, RecordId, | ||
}, | ||
rand::{thread_rng, Rng}, | ||
secret_sharing::{ | ||
replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, | ||
Vectorizable, | ||
}, | ||
seq_join::{seq_join, SeqJoin}, | ||
seq_join::seq_join, | ||
sharding::NotSharded, | ||
test_fixture::{join3v, Reconstruct, Runner, TestWorld}, | ||
}; | ||
|
||
async fn test_select_semi_honest<V>() | ||
where | ||
V: BooleanArray, | ||
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedSemiHonestContext<'a, NotSharded>>, | ||
Standard: Distribution<V>, | ||
{ | ||
let world = TestWorld::default(); | ||
let context = world.contexts(); | ||
let mut rng = thread_rng(); | ||
|
||
let bit = rng.gen::<Boolean>(); | ||
let a = rng.gen::<V>(); | ||
let b = rng.gen::<V>(); | ||
|
||
let bit_shares = bit.share_with(&mut rng); | ||
let a_shares = a.share_with(&mut rng); | ||
let b_shares = b.share_with(&mut rng); | ||
|
||
let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( | ||
|(ctx, (bit_share, (a_share, b_share)))| async move { | ||
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); | ||
let sh_ctx = v.context(); | ||
|
||
let result = select( | ||
sh_ctx.set_total_records(1), | ||
RecordId::from(0), | ||
&bit_share, | ||
&a_share, | ||
&b_share, | ||
) | ||
.await?; | ||
|
||
v.validate().await?; | ||
|
||
Ok::<_, Error>(result) | ||
}, | ||
); | ||
|
||
let [ab0, ab1, ab2] = join3v(futures).await; | ||
|
||
let ab = [ab0, ab1, ab2].reconstruct(); | ||
|
||
assert_eq!(ab, if bit.into() { a } else { b }); | ||
} | ||
|
||
#[tokio::test] | ||
async fn dzkp_malicious() { | ||
async fn select_semi_honest() { | ||
test_select_semi_honest::<BA3>().await; | ||
test_select_semi_honest::<BA8>().await; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it worth testing it for for weird types like BA3 and BA7 as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think so. I had BA4 but took it out because it's not |
||
test_select_semi_honest::<BA16>().await; | ||
test_select_semi_honest::<BA20>().await; | ||
test_select_semi_honest::<BA32>().await; | ||
test_select_semi_honest::<BA64>().await; | ||
test_select_semi_honest::<BA256>().await; | ||
} | ||
|
||
async fn test_select_malicious<V>() | ||
where | ||
V: BooleanArray, | ||
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>, | ||
Standard: Distribution<V>, | ||
{ | ||
let world = TestWorld::default(); | ||
let context = world.malicious_contexts(); | ||
let mut rng = thread_rng(); | ||
|
||
let bit = rng.gen::<Boolean>(); | ||
let a = rng.gen::<V>(); | ||
let b = rng.gen::<V>(); | ||
|
||
let bit_shares = bit.share_with(&mut rng); | ||
let a_shares = a.share_with(&mut rng); | ||
let b_shares = b.share_with(&mut rng); | ||
|
||
let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( | ||
|(ctx, (bit_share, (a_share, b_share)))| async move { | ||
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); | ||
let m_ctx = v.context(); | ||
|
||
let result = select( | ||
m_ctx.set_total_records(1), | ||
RecordId::from(0), | ||
&bit_share, | ||
&a_share, | ||
&b_share, | ||
) | ||
.await?; | ||
|
||
v.validate().await?; | ||
|
||
Ok::<_, Error>(result) | ||
}, | ||
); | ||
|
||
let [ab0, ab1, ab2] = join3v(futures).await; | ||
|
||
let ab = [ab0, ab1, ab2].reconstruct(); | ||
|
||
assert_eq!(ab, if bit.into() { a } else { b }); | ||
} | ||
|
||
#[tokio::test] | ||
async fn select_malicious() { | ||
test_select_malicious::<BA3>().await; | ||
test_select_malicious::<BA8>().await; | ||
test_select_malicious::<BA16>().await; | ||
test_select_malicious::<BA20>().await; | ||
test_select_malicious::<BA32>().await; | ||
test_select_malicious::<BA64>().await; | ||
test_select_malicious::<BA256>().await; | ||
} | ||
|
||
#[tokio::test] | ||
async fn two_multiplies_malicious() { | ||
const COUNT: usize = 32; | ||
let mut rng = thread_rng(); | ||
|
||
|
@@ -913,9 +1036,54 @@ mod tests { | |
} | ||
} | ||
|
||
/// Similar to `test_select_malicious`, but operating on vectors | ||
async fn multi_select_malicious<V>(count: usize, max_multiplications_per_gate: usize) | ||
where | ||
V: BooleanArray, | ||
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>, | ||
Standard: Distribution<V>, | ||
{ | ||
let mut rng = thread_rng(); | ||
|
||
let bit: Vec<Boolean> = repeat_with(|| rng.gen::<Boolean>()).take(count).collect(); | ||
let a: Vec<V> = repeat_with(|| rng.gen()).take(count).collect(); | ||
let b: Vec<V> = repeat_with(|| rng.gen()).take(count).collect(); | ||
|
||
let [ab0, ab1, ab2]: [Vec<Replicated<V>>; 3] = TestWorld::default() | ||
.malicious( | ||
zip(bit.clone(), zip(a.clone(), b.clone())), | ||
|ctx, inputs| async move { | ||
let v = ctx | ||
.set_total_records(count) | ||
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); | ||
let m_ctx = v.context(); | ||
|
||
v.validated_seq_join(stream::iter(inputs).enumerate().map( | ||
|(i, (bit_share, (a_share, b_share)))| { | ||
let m_ctx = m_ctx.clone(); | ||
async move { | ||
select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) | ||
.await | ||
} | ||
}, | ||
)) | ||
.try_collect() | ||
.await | ||
}, | ||
) | ||
.await | ||
.map(Result::unwrap); | ||
|
||
let ab: Vec<V> = [ab0, ab1, ab2].reconstruct(); | ||
|
||
for i in 0..count { | ||
assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] }); | ||
} | ||
} | ||
|
||
/// test for testing `validated_seq_join` | ||
/// similar to `complex_circuit` in `validator.rs` | ||
async fn complex_circuit_dzkp( | ||
/// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) | ||
async fn chained_multiplies_dzkp( | ||
count: usize, | ||
max_multiplications_per_gate: usize, | ||
) -> Result<(), Error> { | ||
|
@@ -945,7 +1113,7 @@ mod tests { | |
.map(|(ctx, input_shares)| async move { | ||
let v = ctx | ||
.set_total_records(count - 1) | ||
.dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get()); | ||
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); | ||
let m_ctx = v.context(); | ||
|
||
let m_results = v | ||
|
@@ -1021,19 +1189,63 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
fn record_count_strategy() -> impl Strategy<Value = usize> { | ||
// The chained_multiplies test has count - 1 records, so 1 is not a valid input size. | ||
// It is for multi_select though. | ||
prop_oneof![2usize..=512, (1u32..=9).prop_map(|i| 1usize << i)] | ||
} | ||
|
||
fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy<Value = usize> { | ||
let max_max_mults = record_count.min(128); | ||
prop_oneof![ | ||
1usize..=max_max_mults, | ||
(0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i) | ||
] | ||
} | ||
|
||
prop_compose! { | ||
fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { | ||
(1usize<<log_count, 1usize<<log_multiplication_amount) | ||
fn batching() | ||
(record_count in record_count_strategy()) | ||
(record_count in Just(record_count), max_mults in max_multiplications_per_gate_strategy(record_count)) | ||
-> (usize, usize) | ||
{ | ||
(record_count, max_mults) | ||
} | ||
} | ||
|
||
proptest! { | ||
#![proptest_config(ProptestConfig::with_cases(50))] | ||
#[test] | ||
fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){ | ||
let future = async { | ||
let _ = complex_circuit_dzkp(count, multiplication_amount).await; | ||
}; | ||
tokio::runtime::Runtime::new().unwrap().block_on(future); | ||
fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) { | ||
println!("record_count {record_count} batch {max_multiplications_per_gate}"); | ||
if record_count / max_multiplications_per_gate >= 192 { | ||
// TODO: #1269, or even if we don't fix that, don't hardcode the limit. | ||
println!("skipping config because batch count exceeds limit of 192"); | ||
} | ||
// This condition is correct only for active_work = 16 and record size of 1 byte. | ||
else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { | ||
// TODO: #1300, read_size | batch_size. | ||
// Note: for active work < 2048, read size matches active work. | ||
|
||
// Besides read_size | batch_size, there is also a constraint | ||
// something like active_work > read_size + batch_size - 1. | ||
println!("skipping config due to read_size vs. batch_size constraints"); | ||
} else { | ||
tokio::runtime::Runtime::new().unwrap().block_on(async { | ||
chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); | ||
/* | ||
multi_select_malicious::<BA3>(record_count, max_multiplications_per_gate).await; | ||
multi_select_malicious::<BA8>(record_count, max_multiplications_per_gate).await; | ||
multi_select_malicious::<BA16>(record_count, max_multiplications_per_gate).await; | ||
*/ | ||
multi_select_malicious::<BA20>(record_count, max_multiplications_per_gate).await; | ||
/* | ||
multi_select_malicious::<BA32>(record_count, max_multiplications_per_gate).await; | ||
multi_select_malicious::<BA64>(record_count, max_multiplications_per_gate).await; | ||
multi_select_malicious::<BA256>(record_count, max_multiplications_per_gate).await; | ||
*/ | ||
}); | ||
} | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -178,6 +178,16 @@ | |
|
||
impl<'a, F: ExtendableField> Upgraded<'a, F> { | ||
pub(super) fn new(batch: &Arc<MacBatcher<'a, F>>, ctx: Context<'a>) -> Self { | ||
// The DZKP malicious context adjusts active_work to match records_per_batch. | ||
// The MAC validator currently configures the batcher with records_per_batch = | ||
// active_work. If the latter behavior changes, this code may need to be | ||
// updated. | ||
let records_per_batch = batch.lock().unwrap().records_per_batch(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it be better to assert this to make sure we don't miss the misalignment between MAC and ZKP? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I'm fine with that. |
||
let active_work = ctx.active_work().get(); | ||
assert_eq!( | ||
records_per_batch, active_work, | ||
"Expect MAC validation batch size ({records_per_batch}) to match active work ({active_work})", | ||
); | ||
Self { | ||
batch: Arc::downgrade(batch), | ||
base_ctx: ctx, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -217,7 +217,7 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { | |
|
||
// TODO: Right now we set the batch work to be equal to active_work, | ||
// but it does not need to be. We can make this configurable if needed. | ||
let records_per_batch = ctx.active_work().get().min(total_records.get()); | ||
let records_per_batch = ctx.active_work().get(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reducing if larger than |
||
|
||
Self { | ||
protocol_ctx: ctx.narrow(&Step::MaliciousProtocol), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realized that we probably want to use the same seed for rng and test world (it is supported via
TestWorldConfig
struct). That way we can make it reproducible if it ever failsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made an issue for this. It's definitely worth doing in general, but it doesn't seem all that important for this particular test case, where the input values should be unrelated to the behavior of the test.