Make active_work match records_per_batch #1316

Merged: 7 commits, Sep 27, 2024
Changes from 5 commits
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/batcher.rs
@@ -97,6 +97,10 @@ impl<'a, B> Batcher<'a, B> {
self.total_records = self.total_records.overwrite(total_records.into());
}

pub fn records_per_batch(&self) -> usize {
self.records_per_batch
}

fn batch_offset(&self, record_id: RecordId) -> usize {
let batch_index = usize::from(record_id) / self.records_per_batch;
batch_index
24 changes: 23 additions & 1 deletion ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -29,16 +29,38 @@
pub struct DZKPUpgraded<'a> {
validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
active_work: NonZeroUsize,
}

impl<'a> DZKPUpgraded<'a> {
pub(super) fn new(
validator_inner: &Arc<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
) -> Self {
let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch();
let active_work = if records_per_batch == 1 {
// If records_per_batch is 1, let active_work be anything. This only happens
// in tests; there shouldn't be a risk of deadlocks with one record per
// batch; and UnorderedReceiver capacity (which is set from active_work)
// must be at least two.
base_ctx.active_work()
} else {
// Adjust active_work to match records_per_batch. If it is less, we will
// certainly stall, since every record in the batch remains incomplete until
// the batch is validated. It is possible that it can be larger, but making
// it the same seems safer for now.
let active_work = NonZeroUsize::new(records_per_batch).unwrap();
tracing::debug!(
"Changed active_work from {} to {} to match batch size",
base_ctx.active_work().get(),
active_work,
);
active_work
};
Self {
validator_inner: Arc::downgrade(validator_inner),
base_ctx,
active_work,
}
}

@@ -130,7 +152,7 @@

impl<'a> SeqJoin for DZKPUpgraded<'a> {
fn active_work(&self) -> NonZeroUsize {
self.base_ctx.active_work()
self.active_work
}
}
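The comment in `DZKPUpgraded::new` above argues that `active_work` smaller than `records_per_batch` must stall, because every record keeps its slot in the flow-control window until the whole batch validates. A toy model of that argument (illustrative only; the names and numbers are assumptions, not the real `Batcher`/`UnorderedReceiver` types):

// Toy model, not the real code: at most `active_work` records may be outstanding
// at once, and outstanding records are only released when a full batch of
// `records_per_batch` of them has been issued and validated.
fn can_finish(total: usize, active_work: usize, records_per_batch: usize) -> bool {
    let mut outstanding = 0;
    for _ in 0..total {
        if outstanding == active_work {
            if outstanding >= records_per_batch {
                // A full batch is in flight: it validates and frees its slots.
                outstanding -= records_per_batch;
            } else {
                // Window is full but the batch is still incomplete: deadlock.
                return false;
            }
        }
        outstanding += 1;
    }
    true
}

fn main() {
    assert!(!can_finish(64, 16, 32)); // active_work < records_per_batch stalls
    assert!(can_finish(64, 32, 32)); // matching them, as this PR does, avoids that
}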

248 changes: 230 additions & 18 deletions ipa-core/src/protocol/context/dzkp_validator.rs
@@ -825,35 +825,158 @@ mod tests {
};

use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec};
use futures::{StreamExt, TryStreamExt};
use futures::{stream, StreamExt, TryStreamExt};
use futures_util::stream::iter;
use proptest::{prop_compose, proptest, sample::select};
use rand::{thread_rng, Rng};
use proptest::{
prelude::{Just, Strategy},
prop_compose, prop_oneof, proptest,
test_runner::Config as ProptestConfig,
};
use rand::{distributions::Standard, prelude::Distribution};

use crate::{
error::Error,
ff::{boolean::Boolean, Fp61BitPrime},
ff::{
boolean::Boolean,
boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8},
Fp61BitPrime,
},
protocol::{
basics::SecureMul,
basics::{select, BooleanArrayMul, SecureMul},
context::{
dzkp_field::{DZKPCompatibleField, BLOCK_SIZE},
dzkp_validator::{
Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE,
},
Context, UpgradableContext, TEST_DZKP_STEPS,
Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext,
UpgradableContext, TEST_DZKP_STEPS,
},
Gate, RecordId,
},
rand::{thread_rng, Rng},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue,
Vectorizable,
},
seq_join::{seq_join, SeqJoin},
seq_join::seq_join,
sharding::NotSharded,
test_fixture::{join3v, Reconstruct, Runner, TestWorld},
};

async fn test_select_semi_honest<V>()
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedSemiHonestContext<'a, NotSharded>>,
Standard: Distribution<V>,
{
let world = TestWorld::default();
Review comment (Collaborator): I just realized that we probably want to use the same seed for rng and test world (it is supported via TestWorldConfig struct). That way we can make it reproducible if it ever fails.

Reply (Collaborator, author): I made an issue for this. It's definitely worth doing in general, but it doesn't seem all that important for this particular test case, where the input values should be unrelated to the behavior of the test.
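A sketch of the reviewer's suggestion (the TestWorldConfig field and constructor names here are assumptions, not verified against the ipa-core API): derive both the test world and the local rng from one seed so a failing case can be replayed.

// Hypothetical seeding (names are assumptions): would replace TestWorld::default()
// and thread_rng() so a failure can be reproduced from the recorded seed.
use rand::{rngs::StdRng, SeedableRng};
use crate::test_fixture::{TestWorld, TestWorldConfig};

let seed = 0xdead_beef_u64; // log or vary this; reuse it to replay a failure
let world = TestWorld::new_with(TestWorldConfig { seed, ..Default::default() });
let mut rng = StdRng::seed_from_u64(seed);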

let context = world.contexts();
let mut rng = thread_rng();

let bit = rng.gen::<Boolean>();
let a = rng.gen::<V>();
let b = rng.gen::<V>();

let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);

let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map(
|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
let sh_ctx = v.context();

let result = select(
sh_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;

v.validate().await?;

Ok::<_, Error>(result)
},
);

let [ab0, ab1, ab2] = join3v(futures).await;

let ab = [ab0, ab1, ab2].reconstruct();

assert_eq!(ab, if bit.into() { a } else { b });
}

#[tokio::test]
async fn dzkp_malicious() {
async fn select_semi_honest() {
test_select_semi_honest::<BA3>().await;
test_select_semi_honest::<BA8>().await;
Review comment (Collaborator): Is it worth testing it for weird types like BA3 and BA7 as well?

Reply (Collaborator, author): Yes, I think so. I had BA4 but took it out because it's not boolean_vector! (not for any good reason other than we haven't needed it). But I think it is worth having a less-than-one-byte case, and maybe even adding a new BA type so we can cover the between-one-and-two-bytes case.

test_select_semi_honest::<BA16>().await;
test_select_semi_honest::<BA20>().await;
test_select_semi_honest::<BA32>().await;
test_select_semi_honest::<BA64>().await;
test_select_semi_honest::<BA256>().await;
}

async fn test_select_malicious<V>()
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
Standard: Distribution<V>,
{
let world = TestWorld::default();
let context = world.malicious_contexts();
let mut rng = thread_rng();

let bit = rng.gen::<Boolean>();
let a = rng.gen::<V>();
let b = rng.gen::<V>();

let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);

let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map(
|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
let m_ctx = v.context();

let result = select(
m_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;

v.validate().await?;

Ok::<_, Error>(result)
},
);

let [ab0, ab1, ab2] = join3v(futures).await;

let ab = [ab0, ab1, ab2].reconstruct();

assert_eq!(ab, if bit.into() { a } else { b });
}

#[tokio::test]
async fn select_malicious() {
test_select_malicious::<BA3>().await;
test_select_malicious::<BA8>().await;
test_select_malicious::<BA16>().await;
test_select_malicious::<BA20>().await;
test_select_malicious::<BA32>().await;
test_select_malicious::<BA64>().await;
test_select_malicious::<BA256>().await;
}

#[tokio::test]
async fn two_multiplies_malicious() {
const COUNT: usize = 32;
let mut rng = thread_rng();

@@ -913,9 +1036,54 @@
}
}

/// Similar to `test_select_malicious`, but operating on vectors
async fn multi_select_malicious<V>(count: usize, max_multiplications_per_gate: usize)
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
Standard: Distribution<V>,
{
let mut rng = thread_rng();

let bit: Vec<Boolean> = repeat_with(|| rng.gen::<Boolean>()).take(count).collect();
let a: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();
let b: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();

let [ab0, ab1, ab2]: [Vec<Replicated<V>>; 3] = TestWorld::default()
.malicious(
zip(bit.clone(), zip(a.clone(), b.clone())),
|ctx, inputs| async move {
let v = ctx
.set_total_records(count)
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

v.validated_seq_join(stream::iter(inputs).enumerate().map(
|(i, (bit_share, (a_share, b_share)))| {
let m_ctx = m_ctx.clone();
async move {
select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share)
.await
}
},
))
.try_collect()
.await
},
)
.await
.map(Result::unwrap);

let ab: Vec<V> = [ab0, ab1, ab2].reconstruct();

for i in 0..count {
assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] });
}
}

/// test for testing `validated_seq_join`
/// similar to `complex_circuit` in `validator.rs`
async fn complex_circuit_dzkp(
/// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment)
async fn chained_multiplies_dzkp(
count: usize,
max_multiplications_per_gate: usize,
) -> Result<(), Error> {
@@ -945,7 +1113,7 @@
.map(|(ctx, input_shares)| async move {
let v = ctx
.set_total_records(count - 1)
.dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get());
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

let m_results = v
@@ -1021,19 +1189,63 @@
Ok(())
}

fn record_count_strategy() -> impl Strategy<Value = usize> {
// The chained_multiplies test has count - 1 records, so 1 is not a valid input size.
// It is for multi_select though.
prop_oneof![2usize..=512, (1u32..=9).prop_map(|i| 1usize << i)]
}

fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy<Value = usize> {
let max_max_mults = record_count.min(128);
prop_oneof![
1usize..=max_max_mults,
(0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i)
]
}

prop_compose! {
fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) {
(1usize<<log_count, 1usize<<log_multiplication_amount)
fn batching()
(record_count in record_count_strategy())
(record_count in Just(record_count), max_mults in max_multiplications_per_gate_strategy(record_count))
-> (usize, usize)
{
(record_count, max_mults)
}
}

proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){
let future = async {
let _ = complex_circuit_dzkp(count, multiplication_amount).await;
};
tokio::runtime::Runtime::new().unwrap().block_on(future);
fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) {
println!("record_count {record_count} batch {max_multiplications_per_gate}");
if record_count / max_multiplications_per_gate >= 192 {
// TODO: #1269, or even if we don't fix that, don't hardcode the limit.
println!("skipping config because batch count exceeds limit of 192");
}
// This condition is correct only for active_work = 16 and record size of 1 byte.
else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 {
// TODO: #1300, read_size | batch_size.
// Note: for active work < 2048, read size matches active work.

// Besides read_size | batch_size, there is also a constraint
// something like active_work > read_size + batch_size - 1.
println!("skipping config due to read_size vs. batch_size constraints");
} else {
tokio::runtime::Runtime::new().unwrap().block_on(async {
chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap();
/*
multi_select_malicious::<BA3>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA8>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA16>(record_count, max_multiplications_per_gate).await;
*/
multi_select_malicious::<BA20>(record_count, max_multiplications_per_gate).await;
/*
multi_select_malicious::<BA32>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA64>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA256>(record_count, max_multiplications_per_gate).await;
*/
});
}
}
}
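The divisibility filter above is hardcoded for active_work = 16 and 1-byte records; a hypothetical helper (illustrative, not part of the PR) states the intended read_size | batch_size constraint more generally:

// Hypothetical generalization of the skip condition in batching_proptest:
// a configuration is exercised only if the batch size is 1 or a multiple of
// the read size (assumed equal to active_work for small active_work).
fn config_is_supported(batch_size: usize, read_size: usize) -> bool {
    batch_size == 1 || batch_size % read_size == 0
}

#[test]
fn matches_hardcoded_filter() {
    assert!(config_is_supported(1, 16));
    assert!(config_is_supported(48, 16));
    assert!(!config_is_supported(24, 16)); // such configs are skipped above
}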

10 changes: 10 additions & 0 deletions ipa-core/src/protocol/context/malicious.rs
@@ -178,6 +178,16 @@

impl<'a, F: ExtendableField> Upgraded<'a, F> {
pub(super) fn new(batch: &Arc<MacBatcher<'a, F>>, ctx: Context<'a>) -> Self {
// The DZKP malicious context adjusts active_work to match records_per_batch.
// The MAC validator currently configures the batcher with records_per_batch =
// active_work. If the latter behavior changes, this code may need to be
// updated.
let records_per_batch = batch.lock().unwrap().records_per_batch();
Review comment (Collaborator): would it be better to assert this to make sure we don't miss the misalignment between MAC and ZKP?

Reply (Collaborator, author): Sure, I'm fine with that.

let active_work = ctx.active_work().get();
assert_eq!(
records_per_batch, active_work,
"Expect MAC validation batch size ({records_per_batch}) to match active work ({active_work})",

Check warning on line 189 in ipa-core/src/protocol/context/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/malicious.rs#L189

Added line #L189 was not covered by tests
);
Self {
batch: Arc::downgrade(batch),
base_ctx: ctx,
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/context/validator.rs
@@ -217,7 +217,7 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> {

// TODO: Right now we set the batch work to be equal to active_work,
// but it does not need to be. We can make this configurable if needed.
let records_per_batch = ctx.active_work().get().min(total_records.get());
let records_per_batch = ctx.active_work().get();
Author comment (Collaborator): Reducing this when it was larger than total_records may have been necessary with an earlier version of the batcher, but the current version should take care of that internally, so I removed it here. No relation to the rest of these changes though.

Self {
protocol_ctx: ctx.narrow(&Step::MaliciousProtocol),
4 changes: 1 addition & 3 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
@@ -510,9 +510,7 @@ where
protocol: &Step::Attribute,
validate: &Step::AttributeValidate,
},
// The size of a single batch should not exceed the active work limit,
// otherwise it will stall
std::cmp::min(sh_ctx.active_work().get(), chunk_size),
chunk_size,
);
dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;