Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make active_work match records_per_batch #1316

Merged
merged 7 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion ipa-core/src/protocol/context/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ impl<'a, B> Batcher<'a, B> {
self.total_records = self.total_records.overwrite(total_records.into());
}

/// Returns the `records_per_batch` value stored in this batcher.
pub fn records_per_batch(&self) -> usize {
self.records_per_batch
}

fn batch_offset(&self, record_id: RecordId) -> usize {
let batch_index = usize::from(record_id) / self.records_per_batch;
batch_index
Expand All @@ -110,7 +114,7 @@ impl<'a, B> Batcher<'a, B> {
while self.batches.len() <= batch_offset {
let (validation_result, _) = watch::channel::<bool>(false);
let state = BatchState {
batch: (self.batch_constructor)(self.first_batch + batch_offset),
batch: (self.batch_constructor)(self.first_batch + self.batches.len()),
validation_result,
pending_count: 0,
pending_records: bitvec![0; self.records_per_batch],
Expand Down Expand Up @@ -292,6 +296,23 @@ mod tests {
);
}

#[test]
fn makes_batches_out_of_order() {
    // Regression test. When a batch `j` was requested out of order and the batcher had
    // to create batches `i..j` to fill the gap in its deque, it used to hand index `j`
    // to the batch constructor for every one of them, rather than the proper sequence
    // of indices `i..=j`.

    let shared = Batcher::new(1, 2, Box::new(std::convert::identity));
    let mut guard = shared.lock().unwrap();

    // Touch the later record first so the earlier batch is created as gap filler.
    guard.get_batch(RecordId::from(1));
    guard.get_batch(RecordId::from(0));

    // The constructor is the identity function, so each batch holds the index it got.
    let first = guard.get_batch(RecordId::from(0)).batch;
    assert_eq!(first, 0);
    let second = guard.get_batch(RecordId::from(1)).batch;
    assert_eq!(second, 1);
}

#[tokio::test]
async fn validates_batches() {
let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));
Expand Down
24 changes: 23 additions & 1 deletion ipa-core/src/protocol/context/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,38 @@
pub struct DZKPUpgraded<'a> {
/// Non-owning handle to the validator state; `new` produces it with `Arc::downgrade`.
validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
/// The malicious context this value was built from.
base_ctx: MaliciousContext<'a>,
/// Value reported by `SeqJoin::active_work` for this context. `new` sets it to the
/// batcher's `records_per_batch`, except when that is 1, in which case the base
/// context's `active_work` is kept.
active_work: NonZeroUsize,
}

impl<'a> DZKPUpgraded<'a> {
pub(super) fn new(
validator_inner: &Arc<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
) -> Self {
let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch();
let active_work = if records_per_batch == 1 {
// If records_per_batch is 1, let active_work be anything. This only happens
// in tests; there shouldn't be a risk of deadlocks with one record per
// batch; and UnorderedReceiver capacity (which is set from active_work)
// must be at least two.
base_ctx.active_work()
} else {
// Adjust active_work to match records_per_batch. If it is less, we will
// certainly stall, since every record in the batch remains incomplete until
// the batch is validated. It is possible that it can be larger, but making
// it the same seems safer for now.
let active_work = NonZeroUsize::new(records_per_batch).unwrap();
tracing::debug!(
"Changed active_work from {} to {} to match batch size",
base_ctx.active_work().get(),

Check warning on line 55 in ipa-core/src/protocol/context/dzkp_malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/dzkp_malicious.rs#L54-L55

Added lines #L54 - L55 were not covered by tests
active_work,
);
active_work
};
Self {
validator_inner: Arc::downgrade(validator_inner),
base_ctx,
active_work,
}
}

Expand Down Expand Up @@ -130,7 +152,7 @@

impl<'a> SeqJoin for DZKPUpgraded<'a> {
fn active_work(&self) -> NonZeroUsize {
self.base_ctx.active_work()
self.active_work
}
}

Expand Down
248 changes: 230 additions & 18 deletions ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -825,35 +825,158 @@ mod tests {
};

use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec};
use futures::{StreamExt, TryStreamExt};
use futures::{stream, StreamExt, TryStreamExt};
use futures_util::stream::iter;
use proptest::{prop_compose, proptest, sample::select};
use rand::{thread_rng, Rng};
use proptest::{
prelude::{Just, Strategy},
prop_compose, prop_oneof, proptest,
test_runner::Config as ProptestConfig,
};
use rand::{distributions::Standard, prelude::Distribution};

use crate::{
error::Error,
ff::{boolean::Boolean, Fp61BitPrime},
ff::{
boolean::Boolean,
boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8},
Fp61BitPrime,
},
protocol::{
basics::SecureMul,
basics::{select, BooleanArrayMul, SecureMul},
context::{
dzkp_field::{DZKPCompatibleField, BLOCK_SIZE},
dzkp_validator::{
Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE,
},
Context, UpgradableContext, TEST_DZKP_STEPS,
Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext,
UpgradableContext, TEST_DZKP_STEPS,
},
Gate, RecordId,
},
rand::{thread_rng, Rng},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue,
Vectorizable,
},
seq_join::{seq_join, SeqJoin},
seq_join::seq_join,
sharding::NotSharded,
test_fixture::{join3v, Reconstruct, Runner, TestWorld},
};

async fn test_select_semi_honest<V>()
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedSemiHonestContext<'a, NotSharded>>,
Standard: Distribution<V>,
{
let world = TestWorld::default();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized that we probably want to use the same seed for the rng and the test world (this is supported via the TestWorldConfig struct). That way we can reproduce the failure if the test ever fails.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made an issue for this. It's definitely worth doing in general, but it doesn't seem all that important for this particular test case, where the input values should be unrelated to the behavior of the test.

let context = world.contexts();
let mut rng = thread_rng();

let bit = rng.gen::<Boolean>();
let a = rng.gen::<V>();
let b = rng.gen::<V>();

let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);

let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map(
|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
let sh_ctx = v.context();

let result = select(
sh_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;

v.validate().await?;

Ok::<_, Error>(result)
},
);

let [ab0, ab1, ab2] = join3v(futures).await;

let ab = [ab0, ab1, ab2].reconstruct();

assert_eq!(ab, if bit.into() { a } else { b });
}

#[tokio::test]
async fn dzkp_malicious() {
async fn select_semi_honest() {
test_select_semi_honest::<BA3>().await;
test_select_semi_honest::<BA8>().await;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth testing it for weird types like BA3 and BA7 as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think so. I had BA4 but took it out because it's not boolean_vector! (not for any good reason other than we haven't needed it). But I think it is worth having a less-than-one-byte case, and maybe even adding a new BA type so we can cover the between-one-and-two-bytes case.

test_select_semi_honest::<BA16>().await;
test_select_semi_honest::<BA20>().await;
test_select_semi_honest::<BA32>().await;
test_select_semi_honest::<BA64>().await;
test_select_semi_honest::<BA256>().await;
}

async fn test_select_malicious<V>()
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
Standard: Distribution<V>,
{
let world = TestWorld::default();
let context = world.malicious_contexts();
let mut rng = thread_rng();

let bit = rng.gen::<Boolean>();
let a = rng.gen::<V>();
let b = rng.gen::<V>();

let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);

let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map(
|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
let m_ctx = v.context();

let result = select(
m_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;

v.validate().await?;

Ok::<_, Error>(result)
},
);

let [ab0, ab1, ab2] = join3v(futures).await;

let ab = [ab0, ab1, ab2].reconstruct();

assert_eq!(ab, if bit.into() { a } else { b });
}

// Runs `test_select_malicious` once for each boolean array type listed below,
// from `BA3` up to `BA256`.
#[tokio::test]
async fn select_malicious() {
test_select_malicious::<BA3>().await;
test_select_malicious::<BA8>().await;
test_select_malicious::<BA16>().await;
test_select_malicious::<BA20>().await;
test_select_malicious::<BA32>().await;
test_select_malicious::<BA64>().await;
test_select_malicious::<BA256>().await;
}

#[tokio::test]
async fn two_multiplies_malicious() {
const COUNT: usize = 32;
let mut rng = thread_rng();

Expand Down Expand Up @@ -913,9 +1036,54 @@ mod tests {
}
}

/// Similar to `test_select_malicious`, but operating on vectors
///
/// Generates `count` random (bit, a, b) triples, runs one `select` per triple through
/// `validated_seq_join`, and asserts that every reconstructed output equals `a[i]` when
/// `bit[i]` is set and `b[i]` otherwise. `max_multiplications_per_gate` is forwarded
/// unchanged as the second argument of `dzkp_validator`.
async fn multi_select_malicious<V>(count: usize, max_multiplications_per_gate: usize)
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
Standard: Distribution<V>,
{
let mut rng = thread_rng();

// Plaintext inputs, `count` of each.
let bit: Vec<Boolean> = repeat_with(|| rng.gen::<Boolean>()).take(count).collect();
let a: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();
let b: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();

// The inputs are cloned because the originals are needed for the final comparison.
let [ab0, ab1, ab2]: [Vec<Replicated<V>>; 3] = TestWorld::default()
.malicious(
zip(bit.clone(), zip(a.clone(), b.clone())),
|ctx, inputs| async move {
let v = ctx
.set_total_records(count)
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

// One `select` per input triple; the record id is the triple's position.
v.validated_seq_join(stream::iter(inputs).enumerate().map(
|(i, (bit_share, (a_share, b_share)))| {
let m_ctx = m_ctx.clone();
async move {
select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share)
.await
}
},
))
.try_collect()
.await
},
)
.await
.map(Result::unwrap);

let ab: Vec<V> = [ab0, ab1, ab2].reconstruct();

// Compare each reconstructed output with the plaintext choice.
for i in 0..count {
assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] });
}
}

/// test for testing `validated_seq_join`
/// similar to `complex_circuit` in `validator.rs`
async fn complex_circuit_dzkp(
/// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment)
async fn chained_multiplies_dzkp(
count: usize,
max_multiplications_per_gate: usize,
) -> Result<(), Error> {
Expand Down Expand Up @@ -945,7 +1113,7 @@ mod tests {
.map(|(ctx, input_shares)| async move {
let v = ctx
.set_total_records(count - 1)
.dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get());
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

let m_results = v
Expand Down Expand Up @@ -1021,19 +1189,63 @@ mod tests {
Ok(())
}

/// Strategy for the number of input records: either any value in `2..=512`, or a
/// power of two from `2` to `512`.
fn record_count_strategy() -> impl Strategy<Value = usize> {
    // A size of 1 is excluded because the chained_multiplies test processes
    // count - 1 records, so 1 is not a valid input size for it. (It would be a
    // valid size for multi_select.)
    let any_size = 2usize..=512;
    let power_of_two = (1u32..=9).prop_map(|exp| 1usize << exp);
    prop_oneof![any_size, power_of_two]
}

/// Strategy for `max_multiplications_per_gate`, capped at the smaller of
/// `record_count` and 128: either any value from 1 up to the cap, or a power of two
/// no larger than the cap.
fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy<Value = usize> {
    let cap = record_count.min(128);
    let any_value = 1usize..=cap;
    // `ilog2` gives the largest exponent whose power of two still fits under the cap.
    let power_of_two = (0u32..=cap.ilog2()).prop_map(|exp| 1usize << exp);
    prop_oneof![any_value, power_of_two]
}

prop_compose! {
fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) {
(1usize<<log_count, 1usize<<log_multiplication_amount)
fn batching()
(record_count in record_count_strategy())
(record_count in Just(record_count), max_mults in max_multiplications_per_gate_strategy(record_count))
-> (usize, usize)
{
(record_count, max_mults)
}
}

proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){
let future = async {
let _ = complex_circuit_dzkp(count, multiplication_amount).await;
};
tokio::runtime::Runtime::new().unwrap().block_on(future);
fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) {
println!("record_count {record_count} batch {max_multiplications_per_gate}");
if record_count / max_multiplications_per_gate >= 192 {
// TODO: #1269, or even if we don't fix that, don't hardcode the limit.
println!("skipping config because batch count exceeds limit of 192");
}
// This condition is correct only for active_work = 16 and record size of 1 byte.
else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 {
// TODO: #1300, read_size | batch_size.
// Note: for active work < 2048, read size matches active work.

// Besides read_size | batch_size, there is also a constraint
// something like active_work > read_size + batch_size - 1.
println!("skipping config due to read_size vs. batch_size constraints");
} else {
tokio::runtime::Runtime::new().unwrap().block_on(async {
chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap();
/*
multi_select_malicious::<BA3>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA8>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA16>(record_count, max_multiplications_per_gate).await;
*/
multi_select_malicious::<BA20>(record_count, max_multiplications_per_gate).await;
/*
multi_select_malicious::<BA32>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA64>(record_count, max_multiplications_per_gate).await;
multi_select_malicious::<BA256>(record_count, max_multiplications_per_gate).await;
*/
});
}
}
}

Expand Down
Loading