diff --git a/common/src/air_instance.rs b/common/src/air_instance.rs index 51e1178a..257a209b 100644 --- a/common/src/air_instance.rs +++ b/common/src/air_instance.rs @@ -33,6 +33,7 @@ pub struct AirInstance { pub air_segment_id: Option, pub air_instance_id: Option, pub idx: Option, + pub global_idx: Option, pub buffer: Vec, pub airgroup_values: Vec, pub airvalues: Vec, @@ -58,6 +59,7 @@ impl AirInstance { air_segment_id, air_instance_id: None, idx: None, + global_idx: None, buffer, airgroup_values: vec![F::zero(); get_n_airgroupvals_c(ps.p_setup.p_stark_info) as usize * 3], airvalues: vec![F::zero(); get_n_airvals_c(ps.p_setup.p_stark_info) as usize * 3], diff --git a/common/src/air_instances_repository.rs b/common/src/air_instances_repository.rs index ef0e1895..15e1a616 100644 --- a/common/src/air_instances_repository.rs +++ b/common/src/air_instances_repository.rs @@ -23,13 +23,18 @@ impl AirInstancesRepository { } } - pub fn add_air_instance(&self, mut air_instance: AirInstance) { + pub fn add_air_instance(&self, mut air_instance: AirInstance, global_idx: Option) { let mut air_instances = self.air_instances.write().unwrap(); let n_air_instances = air_instances.len(); let mut air_instances_counts = self.air_instances_counts.write().unwrap(); let instance_id = air_instances_counts.entry((air_instance.airgroup_id, air_instance.air_id)).or_insert(0); air_instance.set_air_instance_id(*instance_id, n_air_instances); + if global_idx.is_some() { + air_instance.global_idx = global_idx; + } else { + air_instance.global_idx = Some(n_air_instances); + } *instance_id += 1; air_instances.push(air_instance); } @@ -38,17 +43,11 @@ impl AirInstancesRepository { let air_instances = self.air_instances.read().unwrap(); let mut indices = Vec::new(); - #[cfg(feature = "distributed")] - let mut segment_ids = Vec::new(); for (index, air_instance) in air_instances.iter().enumerate() { if air_instance.airgroup_id == airgroup_id { indices.push(index); - #[cfg(feature = "distributed")] - segment_ids.push(air_instance.air_segment_id.unwrap_or(0)); } } - #[cfg(feature = "distributed")] - indices.sort_by(|a, b| segment_ids[*a].cmp(&segment_ids[*b])); indices } diff --git a/common/src/distribution_ctx.rs b/common/src/distribution_ctx.rs index 7164caaf..0aba3006 100644 --- a/common/src/distribution_ctx.rs +++ b/common/src/distribution_ctx.rs @@ -1,9 +1,13 @@ +use std::collections::HashMap; +use std::collections::BTreeMap; #[cfg(feature = "distributed")] -use mpi::collective::CommunicatorCollectives; -#[cfg(feature = "distributed")] -use mpi::traits::Communicator; +use mpi::traits::*; #[cfg(feature = "distributed")] use mpi::environment::Universe; +#[cfg(feature = "distributed")] +use mpi::collective::CommunicatorCollectives; +#[cfg(feature = "distributed")] +use mpi::datatype::PartitionMut; /// Represents the context of distributed computing pub struct DistributionCtx { @@ -13,9 +17,18 @@ pub struct DistributionCtx { pub universe: Universe, #[cfg(feature = "distributed")] pub world: mpi::topology::SimpleCommunicator, - pub n_instances: i32, + pub n_instances: usize, pub my_instances: Vec, - pub instances: Vec<(usize, usize)>, + pub instances: Vec<(usize, usize)>, //group_id, air_id + pub instances_owner: Vec<(usize, usize)>, //owner_rank, owner_instance_idx + pub owners_count: Vec, + pub owners_weight: Vec, + #[cfg(feature = "distributed")] + pub roots_gatherv_count: Vec, + #[cfg(feature = "distributed")] + pub roots_gatherv_displ: Vec, + pub my_groups: Vec>, + pub my_air_groups: Vec>, } impl DistributionCtx { @@ -24,19 +37,39 @@ impl DistributionCtx { { let (universe, _threading) = mpi::initialize_with_threading(mpi::Threading::Multiple).unwrap(); let world = universe.world(); + let rank = world.rank(); + let n_processes = world.size(); DistributionCtx { - rank: world.rank(), - n_processes: world.size(), + rank, + n_processes, universe, world, n_instances: 0, my_instances: Vec::new(), instances: Vec::new(), + instances_owner: Vec::new(), + owners_count: vec![0; n_processes as usize], + owners_weight: vec![0; n_processes as usize], + roots_gatherv_count: vec![0; n_processes as usize], + roots_gatherv_displ: vec![0; n_processes as usize], + my_groups: Vec::new(), + my_air_groups: Vec::new(), } } #[cfg(not(feature = "distributed"))] { - DistributionCtx { rank: 0, n_processes: 1, n_instances: 0, my_instances: Vec::new(), instances: Vec::new() } + DistributionCtx { + rank: 0, + n_processes: 1, + n_instances: 0, + my_instances: Vec::new(), + instances: Vec::new(), + instances_owner: Vec::new(), + owners_count: vec![0; 1], + owners_weight: vec![0; 1], + my_groups: Vec::new(), + my_air_groups: Vec::new(), + } } } @@ -48,11 +81,6 @@ impl DistributionCtx { } } - #[inline] - pub fn is_master(&self) -> bool { - self.rank == 0 - } - #[inline] pub fn is_distributed(&self) -> bool { self.n_processes > 1 @@ -60,16 +88,178 @@ impl DistributionCtx { #[inline] pub fn is_my_instance(&self, instance_idx: usize) -> bool { - instance_idx % self.n_processes as usize == self.rank as usize + self.owner(instance_idx) == self.rank as usize + } + + #[inline] + pub fn owner(&self, instance_idx: usize) -> usize { + self.instances_owner[instance_idx].0 } #[inline] - pub fn add_instance(&mut self, airgroup_id: usize, air_id: usize, instance_idx: usize, _size: usize) { + pub fn add_instance(&mut self, airgroup_id: usize, air_id: usize, weight: usize) -> (bool, usize) { + let mut is_mine = false; + let owner = self.n_instances % self.n_processes as usize; + self.instances.push((airgroup_id, air_id)); + self.instances_owner.push((owner, self.owners_count[owner] as usize)); + self.owners_count[owner] += 1; + self.owners_weight[owner] += weight as u64; + + if owner == self.rank as usize { + self.my_instances.push(self.n_instances); + is_mine = true; + } self.n_instances += 1; - if self.is_my_instance(instance_idx) { - self.my_instances.push(instance_idx); + (is_mine, self.n_instances - 1) + } + + pub fn close(&mut self) { + let mut group_indices: BTreeMap> = BTreeMap::new(); + + // Calculate the partial sums of owners_count + #[cfg(feature = "distributed")] + { + let mut total_instances = 0; + for i in 0..self.n_processes as usize { + self.roots_gatherv_displ[i] = total_instances; + self.roots_gatherv_count[i] = self.owners_count[i] * 4; + total_instances += self.roots_gatherv_count[i]; + } + } + + // Populate the HashMap based on group_id and buffer positions + for (idx, &(group_id, _)) in self.instances.iter().enumerate() { + #[cfg(feature = "distributed")] + let pos_buffer = + self.roots_gatherv_displ[self.instances_owner[idx].0] as usize + self.instances_owner[idx].1 * 4; + #[cfg(not(feature = "distributed"))] + let pos_buffer = idx * 4; + group_indices.entry(group_id).or_default().push(pos_buffer); + } + + // Flatten the HashMap into a single vector for my_groups + for (_, indices) in group_indices { + self.my_groups.push(indices); + } + + // Create my eval groups + let mut my_air_groups_indices: HashMap<(usize, usize), Vec> = HashMap::new(); + for (loc_idx, glob_idx) in self.my_instances.iter().enumerate() { + let instance_idx = self.instances[*glob_idx]; + my_air_groups_indices.entry(instance_idx).or_default().push(loc_idx); + } + + // Flatten the HashMap into a single vector for my_air_groups + for (_, indices) in my_air_groups_indices { + self.my_air_groups.push(indices); + } + } + + pub fn distribute_roots(&self, roots: Vec) -> Vec { + #[cfg(feature = "distributed")] + { + let mut all_roots: Vec = vec![0; 4 * self.n_instances]; + let counts = &self.roots_gatherv_count; + let displs = &self.roots_gatherv_displ; + + let mut partitioned_all_roots = PartitionMut::new(&mut all_roots, counts.as_slice(), displs.as_slice()); + + self.world.all_gather_varcount_into(&roots, &mut partitioned_all_roots); + + all_roots + } + #[cfg(not(feature = "distributed"))] + { + roots + } + } + + pub fn distribute_multiplicity(&self, _multiplicity: &mut [u64], _owner: usize) { + #[cfg(feature = "distributed")] + { + //assert that I can operate with u32 + assert!(_multiplicity.len() < std::u32::MAX as usize); + + if _owner != self.rank as usize { + //pack multiplicities in a sparce vector + let mut packed_multiplicity = Vec::new(); + packed_multiplicity.push(0 as u32); //this will be the counter + for (idx, &m) in _multiplicity.iter().enumerate() { + if m != 0 { + assert!(m < std::u32::MAX as u64); + packed_multiplicity.push(idx as u32); + packed_multiplicity.push(m as u32); + packed_multiplicity[0] += 2; + } + } + self.world.process_at_rank(_owner as i32).send(&packed_multiplicity[..]); + } else { + let mut packed_multiplicity: Vec = vec![0; _multiplicity.len() * 2 + 1]; + for i in 0..self.n_processes { + if i != _owner as i32 { + self.world.process_at_rank(i).receive_into(&mut packed_multiplicity); + for j in (1..packed_multiplicity[0]).step_by(2) { + let idx = packed_multiplicity[j as usize] as usize; + let m = packed_multiplicity[j as usize + 1] as u64; + _multiplicity[idx] += m; + } + } + } + } + } + } + + pub fn distribute_multiplicities(&self, _multiplicities: &mut [Vec], _owner: usize) { + #[cfg(feature = "distributed")] + { + // Ensure that each multiplicity vector can be operated with u32 + let mut buff_size = 0; + for multiplicity in _multiplicities.iter() { + assert!(multiplicity.len() < std::u32::MAX as usize); + buff_size += multiplicity.len() + 1; + } + + let n_columns = _multiplicities.len(); + if _owner != self.rank as usize { + // Pack multiplicities in a sparse vector + let mut packed_multiplicities = vec![0u32; n_columns]; + for (col_idx, multiplicity) in _multiplicities.iter().enumerate() { + for (idx, &m) in multiplicity.iter().enumerate() { + if m != 0 { + assert!(m < std::u32::MAX as u64); + packed_multiplicities[col_idx] += 1; + packed_multiplicities.push(idx as u32); + packed_multiplicities.push(m as u32); + } + } + } + self.world.process_at_rank(_owner as i32).send(&packed_multiplicities[..]); + } else { + let mut packed_multiplicities: Vec = vec![0; buff_size * 2]; + for i in 0..self.n_processes { + if i != _owner as i32 { + self.world.process_at_rank(i).receive_into(&mut packed_multiplicities); + + // Read counters + let mut counters = vec![0usize; n_columns]; + for col_idx in 0..n_columns { + counters[col_idx] = packed_multiplicities[col_idx] as usize; + } + + // Unpack multiplicities + let mut idx = n_columns; + for col_idx in 0..n_columns { + for _ in 0..counters[col_idx] { + let row_idx = packed_multiplicities[idx] as usize; + let m = packed_multiplicities[idx + 1] as u64; + _multiplicities[col_idx][row_idx] += m; + idx += 2; + } + } + } + } + } } - self.instances.push((airgroup_id, air_id)); } } diff --git a/examples/fibonacci-square/src/fibonacci.rs b/examples/fibonacci-square/src/fibonacci.rs index 863823e7..33e9a559 100644 --- a/examples/fibonacci-square/src/fibonacci.rs +++ b/examples/fibonacci-square/src/fibonacci.rs @@ -82,7 +82,11 @@ impl FibonacciSquare { air_instance.set_airvalue(&sctx, "FibonacciSquare.fibo2", F::from_canonical_u64(2)); air_instance.set_airvalue_ext(&sctx, "FibonacciSquare.fibo3", vec![F::from_canonical_u64(5); 3]); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(FIBONACCI_SQUARE_AIRGROUP_ID, FIBONACCI_SQUARE_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } Ok(b) } diff --git a/examples/fibonacci-square/src/module.rs b/examples/fibonacci-square/src/module.rs index 25be501f..933b1e71 100644 --- a/examples/fibonacci-square/src/module.rs +++ b/examples/fibonacci-square/src/module.rs @@ -79,7 +79,11 @@ impl Module } let air_instance = AirInstance::new(sctx.clone(), MODULE_AIRGROUP_ID, MODULE_AIR_IDS[0], Some(0), buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + + let (is_myne, gid) = ectx.dctx.write().unwrap().add_instance(MODULE_AIRGROUP_ID, MODULE_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } self.std_lib.unregister_predecessor(pctx, None); } diff --git a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs index 48d47c99..f080d0f5 100644 --- a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs +++ b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs @@ -94,25 +94,66 @@ impl SpecifiedRanges { pub fn drain_inputs(&self) { let mut inputs = self.inputs.lock().unwrap(); let drained_inputs = inputs.drain(..).collect(); + let pctx = self.wcm.get_pctx(); + let sctx = self.wcm.get_sctx(); + let ectx = self.wcm.get_ectx(); // Perform the last update self.update_multiplicity(drained_inputs); - // Set the multiplicity columns as done - let hints = self.hints.lock().unwrap(); + let mut dctx: std::sync::RwLockWriteGuard<'_, proofman_common::DistributionCtx> = ectx.dctx.write().unwrap(); - let air_instance_repo = &self.wcm.get_pctx().air_instance_repo; - let air_instance_id = air_instance_repo.find_air_instances(self.airgroup_id, self.air_id)[0]; - let mut air_instance_rw = air_instance_repo.air_instances.write().unwrap(); - let air_instance = &mut air_instance_rw[air_instance_id]; + let (is_myne, global_idx) = dctx.add_instance(self.airgroup_id, self.air_id, 1); - let mul_columns = &*self.mul_columns.lock().unwrap(); + let mut multiplicities = self + .mul_columns + .lock() + .unwrap() + .iter() + .map(|column| match column { + HintFieldValue::Column(values) => { + values.iter().map(|x| x.as_canonical_biguint().to_u64().unwrap()).collect::>() + } + _ => panic!("Multiplicities must be columns"), + }) + .collect::>>(); + let owner = dctx.owner(global_idx); + + dctx.distribute_multiplicities(&mut multiplicities, owner); + + if is_myne { + // Set the multiplicity columns as done + let hints = self.hints.lock().unwrap(); + + let air_instance_repo = &self.wcm.get_pctx().air_instance_repo; + let instance: Vec = air_instance_repo.find_air_instances(self.airgroup_id, self.air_id); + let air_instance_id = if !instance.is_empty() { + air_instance_repo.find_air_instances(self.airgroup_id, self.air_id)[0] + } else { + // create instance + let (buffer_size, _) = + ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, self.airgroup_id, self.air_id).unwrap(); + let buffer: Vec = create_buffer_fast(buffer_size as usize); + let air_instance = AirInstance::new(sctx.clone(), self.airgroup_id, self.air_id, None, buffer); + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + pctx.air_instance_repo.air_instances.read().unwrap().len() - 1 + }; + let mut air_instance_rw = air_instance_repo.air_instances.write().unwrap(); + let air_instance = &mut air_instance_rw[air_instance_id]; + + let mul_columns_2 = multiplicities + .iter() + .map(|multiplicities| { + HintFieldValue::Column(multiplicities.iter().map(|x| F::from_canonical_u64(*x)).collect::>()) + }) + .collect::>>(); + + for (index, hint) in hints[1..].iter().enumerate() { + set_hint_field(&self.wcm.get_sctx(), air_instance, *hint, "reference", &mul_columns_2[index]); + } - for (index, hint) in hints[1..].iter().enumerate() { - set_hint_field(&self.wcm.get_sctx(), air_instance, *hint, "reference", &mul_columns[index]); + log::trace!("{}: ··· Drained inputs for AIR '{}'", Self::MY_NAME, "SpecifiedRanges"); } - - log::trace!("{}: ··· Drained inputs for AIR '{}'", Self::MY_NAME, "SpecifiedRanges"); } fn update_multiplicity(&self, drained_inputs: Vec<(Range, F, F)>) { @@ -289,7 +330,8 @@ impl WitnessComponent for SpecifiedRanges { *self.num_rows.lock().unwrap() = num_rows.as_canonical_biguint().to_usize().unwrap(); - pctx.air_instance_repo.add_air_instance(air_instance); + //pctx.air_instance_repo.add_air_instance(air_instance); + // note: there is room for simplification here } fn calculate_witness( diff --git a/pil2-components/lib/std/rs/src/range_check/u16air.rs b/pil2-components/lib/std/rs/src/range_check/u16air.rs index a8622173..a45b5eb9 100644 --- a/pil2-components/lib/std/rs/src/range_check/u16air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u16air.rs @@ -87,20 +87,58 @@ impl U16Air { pub fn drain_inputs(&self) { let mut inputs = self.inputs.lock().unwrap(); let drained_inputs = inputs.drain(..).collect(); + let pctx = self.wcm.get_pctx(); + let sctx = self.wcm.get_sctx(); + let ectx = self.wcm.get_ectx(); // Perform the last update self.update_multiplicity(drained_inputs); - let air_instance_repo = &self.wcm.get_pctx().air_instance_repo; - let air_instance_id = air_instance_repo.find_air_instances(self.airgroup_id, self.air_id)[0]; + let mut dctx: std::sync::RwLockWriteGuard<'_, proofman_common::DistributionCtx> = ectx.dctx.write().unwrap(); - let mut air_instance_rw = air_instance_repo.air_instances.write().unwrap(); - let air_instance = &mut air_instance_rw[air_instance_id]; - - let mul_column = &*self.mul_column.lock().unwrap(); - set_hint_field(&self.wcm.get_sctx(), air_instance, self.hint.load(Ordering::Acquire), "reference", mul_column); + let (is_myne, global_idx) = dctx.add_instance(self.airgroup_id, self.air_id, 1); + let mut multiplicity = match &*self.mul_column.lock().unwrap() { + HintFieldValue::Column(values) => { + values.iter().map(|x| x.as_canonical_biguint().to_u64().unwrap()).collect::>() + } + _ => panic!("Multiplicities must be a column"), + }; - log::trace!("{}: ··· Drained inputs for AIR '{}'", Self::MY_NAME, "U16Air"); + let owner = dctx.owner(global_idx); + dctx.distribute_multiplicity(&mut multiplicity, owner); + + if is_myne { + let air_instance_repo = &self.wcm.get_pctx().air_instance_repo; + let instance: Vec = air_instance_repo.find_air_instances(self.airgroup_id, self.air_id); + let air_instance_id = if !instance.is_empty() { + air_instance_repo.find_air_instances(self.airgroup_id, self.air_id)[0] + } else { + // create instance + let (buffer_size, _) = + ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, self.airgroup_id, self.air_id).unwrap(); + let buffer: Vec = create_buffer_fast(buffer_size as usize); + let air_instance = AirInstance::new(sctx.clone(), self.airgroup_id, self.air_id, None, buffer); + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + pctx.air_instance_repo.air_instances.read().unwrap().len() - 1 + }; + + let mut air_instance_rw = air_instance_repo.air_instances.write().unwrap(); + let air_instance = &mut air_instance_rw[air_instance_id]; + + // copy multiplicitis back to mul_column + let mul_column_2 = + HintFieldValue::Column(multiplicity.iter().map(|x| F::from_canonical_u64(*x)).collect::>()); + + set_hint_field( + &self.wcm.get_sctx(), + air_instance, + self.hint.load(Ordering::Acquire), + "reference", + &mul_column_2, + ); + + log::trace!("{}: ··· Drained inputs for AIR '{}'", Self::MY_NAME, "U16Air"); + } } fn update_multiplicity(&self, drained_inputs: Vec<(F, F)>) { @@ -145,7 +183,8 @@ impl WitnessComponent for U16Air { HintFieldOptions::dest_with_zeros(), ); - pctx.air_instance_repo.add_air_instance(air_instance); + //pctx.air_instance_repo.add_air_instance(air_instance); + //note: there is room for simplification here } fn calculate_witness( diff --git a/pil2-components/lib/std/rs/src/range_check/u8air.rs b/pil2-components/lib/std/rs/src/range_check/u8air.rs index ab49dbe6..105d2f23 100644 --- a/pil2-components/lib/std/rs/src/range_check/u8air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u8air.rs @@ -86,20 +86,58 @@ impl U8Air { pub fn drain_inputs(&self) { let mut inputs = self.inputs.lock().unwrap(); let drained_inputs = inputs.drain(..).collect(); + let pctx = self.wcm.get_pctx(); + let sctx = self.wcm.get_sctx(); + let ectx = self.wcm.get_ectx(); // Perform the last update self.update_multiplicity(drained_inputs); - let air_instance_repo = &self.wcm.get_pctx().air_instance_repo; - let air_instance_id = air_instance_repo.find_air_instances(self.airgroup_id, self.air_id)[0]; + let mut dctx: std::sync::RwLockWriteGuard<'_, proofman_common::DistributionCtx> = ectx.dctx.write().unwrap(); - let mut air_instance_rw = air_instance_repo.air_instances.write().unwrap(); - let air_instance = &mut air_instance_rw[air_instance_id]; - - let mul_column = &*self.mul_column.lock().unwrap(); - set_hint_field(&self.wcm.get_sctx(), air_instance, self.hint.load(Ordering::Acquire), "reference", mul_column); + let (is_myne, global_idx) = dctx.add_instance(self.airgroup_id, self.air_id, 1); + let mut multiplicity = match &*self.mul_column.lock().unwrap() { + HintFieldValue::Column(values) => { + values.iter().map(|x| x.as_canonical_biguint().to_u64().unwrap()).collect::>() + } + _ => panic!("Multiplicities must be a column"), + }; - log::trace!("{}: ··· Drained inputs for AIR '{}'", Self::MY_NAME, "U8Air"); + let owner = dctx.owner(global_idx); + dctx.distribute_multiplicity(&mut multiplicity, owner); + + if is_myne { + let air_instance_repo = &self.wcm.get_pctx().air_instance_repo; + let instance: Vec = air_instance_repo.find_air_instances(self.airgroup_id, self.air_id); + let air_instance_id = if !instance.is_empty() { + air_instance_repo.find_air_instances(self.airgroup_id, self.air_id)[0] + } else { + // create instance + let (buffer_size, _) = + ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, self.airgroup_id, self.air_id).unwrap(); + let buffer: Vec = create_buffer_fast(buffer_size as usize); + let air_instance = AirInstance::new(sctx.clone(), self.airgroup_id, self.air_id, None, buffer); + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + pctx.air_instance_repo.air_instances.read().unwrap().len() - 1 + }; + + let mut air_instance_rw = air_instance_repo.air_instances.write().unwrap(); + let air_instance = &mut air_instance_rw[air_instance_id]; + + // copy multiplicitis back to mul_column + let mul_column_2 = + HintFieldValue::Column(multiplicity.iter().map(|x| F::from_canonical_u64(*x)).collect::>()); + + set_hint_field( + &self.wcm.get_sctx(), + air_instance, + self.hint.load(Ordering::Acquire), + "reference", + &mul_column_2, + ); + + log::trace!("{}: ··· Drained inputs for AIR '{}'", Self::MY_NAME, "U8Air"); + } } fn update_multiplicity(&self, drained_inputs: Vec<(F, F)>) { @@ -143,7 +181,8 @@ impl WitnessComponent for U8Air { HintFieldOptions::dest_with_zeros(), ); - pctx.air_instance_repo.add_air_instance(air_instance); + //pctx.air_instance_repo.add_air_instance(air_instance); + //note: there is room for simplification heres } fn calculate_witness( diff --git a/pil2-components/test/simple/rs/src/simple_left.rs b/pil2-components/test/simple/rs/src/simple_left.rs index b18be75f..9a84756b 100644 --- a/pil2-components/test/simple/rs/src/simple_left.rs +++ b/pil2-components/test/simple/rs/src/simple_left.rs @@ -33,7 +33,12 @@ where let buffer = vec![F::zero(); buffer_size as usize]; let air_instance = AirInstance::new(sctx.clone(), SIMPLE_AIRGROUP_ID, SIMPLE_LEFT_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(SIMPLE_AIRGROUP_ID, SIMPLE_LEFT_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/simple/rs/src/simple_right.rs b/pil2-components/test/simple/rs/src/simple_right.rs index 0bd2a666..d89aafad 100644 --- a/pil2-components/test/simple/rs/src/simple_right.rs +++ b/pil2-components/test/simple/rs/src/simple_right.rs @@ -33,7 +33,11 @@ where let buffer = vec![F::zero(); buffer_size as usize]; let air_instance = AirInstance::new(sctx.clone(), SIMPLE_AIRGROUP_ID, SIMPLE_RIGHT_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(SIMPLE_AIRGROUP_ID, SIMPLE_RIGHT_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/connection/rs/src/connection1.rs b/pil2-components/test/std/connection/rs/src/connection1.rs index d68a3818..80c06f4c 100644 --- a/pil2-components/test/std/connection/rs/src/connection1.rs +++ b/pil2-components/test/std/connection/rs/src/connection1.rs @@ -38,7 +38,11 @@ where let air_instance = AirInstance::new(sctx.clone(), CONNECTION_AIRGROUP_ID, CONNECTION_1_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(CONNECTION_AIRGROUP_ID, CONNECTION_1_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/connection/rs/src/connection2.rs b/pil2-components/test/std/connection/rs/src/connection2.rs index e063fee5..4933e3cc 100644 --- a/pil2-components/test/std/connection/rs/src/connection2.rs +++ b/pil2-components/test/std/connection/rs/src/connection2.rs @@ -38,7 +38,11 @@ where let air_instance = AirInstance::new(sctx.clone(), CONNECTION_AIRGROUP_ID, CONNECTION_2_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(CONNECTION_AIRGROUP_ID, CONNECTION_2_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/connection/rs/src/connection_new.rs b/pil2-components/test/std/connection/rs/src/connection_new.rs index 5130bb1a..4e30a4b6 100644 --- a/pil2-components/test/std/connection/rs/src/connection_new.rs +++ b/pil2-components/test/std/connection/rs/src/connection_new.rs @@ -38,7 +38,11 @@ where let air_instance = AirInstance::new(sctx.clone(), CONNECTION_AIRGROUP_ID, CONNECTION_NEW_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(CONNECTION_AIRGROUP_ID, CONNECTION_NEW_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/lookup/rs/src/lookup0.rs b/pil2-components/test/std/lookup/rs/src/lookup0.rs index bcb2f5c8..03032bde 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup0.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup0.rs @@ -35,7 +35,11 @@ where let air_instance = AirInstance::new(sctx.clone(), LOOKUP_AIRGROUP_ID, LOOKUP_0_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(LOOKUP_AIRGROUP_ID, LOOKUP_0_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/lookup/rs/src/lookup1.rs b/pil2-components/test/std/lookup/rs/src/lookup1.rs index 7a53bc0d..9ca0f901 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup1.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup1.rs @@ -35,7 +35,11 @@ where let air_instance = AirInstance::new(sctx.clone(), LOOKUP_AIRGROUP_ID, LOOKUP_1_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(LOOKUP_AIRGROUP_ID, LOOKUP_1_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/lookup/rs/src/lookup2_12.rs b/pil2-components/test/std/lookup/rs/src/lookup2_12.rs index 52a779c9..565f8c06 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup2_12.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup2_12.rs @@ -35,7 +35,11 @@ where let air_instance = AirInstance::new(sctx.clone(), LOOKUP_AIRGROUP_ID, LOOKUP_2_12_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(LOOKUP_AIRGROUP_ID, LOOKUP_2_12_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/lookup/rs/src/lookup2_13.rs b/pil2-components/test/std/lookup/rs/src/lookup2_13.rs index 00911fc7..ff59dc63 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup2_13.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup2_13.rs @@ -35,7 +35,11 @@ where let air_instance = AirInstance::new(sctx.clone(), LOOKUP_AIRGROUP_ID, LOOKUP_2_13_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(LOOKUP_AIRGROUP_ID, LOOKUP_2_13_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/lookup/rs/src/lookup2_15.rs b/pil2-components/test/std/lookup/rs/src/lookup2_15.rs index 183c1b7d..d9b5f5c6 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup2_15.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup2_15.rs @@ -35,7 +35,11 @@ where let air_instance = AirInstance::new(sctx.clone(), LOOKUP_AIRGROUP_ID, LOOKUP_2_15_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(LOOKUP_AIRGROUP_ID, LOOKUP_2_15_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/lookup/rs/src/lookup3.rs b/pil2-components/test/std/lookup/rs/src/lookup3.rs index 6865fc82..c68f2332 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup3.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup3.rs @@ -31,7 +31,11 @@ impl Lookup3 { let air_instance = AirInstance::new(sctx.clone(), LOOKUP_AIRGROUP_ID, LOOKUP_3_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(LOOKUP_AIRGROUP_ID, LOOKUP_3_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/permutation/rs/src/permutation1_6.rs b/pil2-components/test/std/permutation/rs/src/permutation1_6.rs index c35e1539..a59397aa 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation1_6.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation1_6.rs @@ -38,13 +38,21 @@ where let air_instance = AirInstance::new(sctx.clone(), PERMUTATION_AIRGROUP_ID, PERMUTATION_1_6_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(PERMUTATION_AIRGROUP_ID, PERMUTATION_1_6_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } let buffer = vec![F::zero(); buffer_size as usize]; let air_instance = AirInstance::new(sctx.clone(), PERMUTATION_AIRGROUP_ID, PERMUTATION_1_6_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(PERMUTATION_AIRGROUP_ID, PERMUTATION_1_6_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/permutation/rs/src/permutation1_7.rs b/pil2-components/test/std/permutation/rs/src/permutation1_7.rs index 61934b74..7dfe97c0 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation1_7.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation1_7.rs @@ -38,7 +38,11 @@ where let air_instance = AirInstance::new(sctx.clone(), PERMUTATION_AIRGROUP_ID, PERMUTATION_1_7_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(PERMUTATION_AIRGROUP_ID, PERMUTATION_1_7_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/permutation/rs/src/permutation1_8.rs b/pil2-components/test/std/permutation/rs/src/permutation1_8.rs index 8ded471a..6dff0018 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation1_8.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation1_8.rs @@ -38,7 +38,11 @@ where let air_instance = AirInstance::new(sctx.clone(), PERMUTATION_AIRGROUP_ID, PERMUTATION_1_8_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(PERMUTATION_AIRGROUP_ID, PERMUTATION_1_8_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/permutation/rs/src/permutation2.rs b/pil2-components/test/std/permutation/rs/src/permutation2.rs index 19e753e7..1ac9f372 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation2.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation2.rs @@ -34,7 +34,11 @@ impl Permutation2 { let air_instance = AirInstance::new(sctx.clone(), PERMUTATION_AIRGROUP_ID, PERMUTATION_2_6_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(PERMUTATION_AIRGROUP_ID, PERMUTATION_2_6_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs b/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs index ce434237..08e1fd9a 100644 --- a/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs +++ b/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs @@ -52,7 +52,11 @@ where None, buffer, ); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(MULTI_RANGE_CHECK_1_AIRGROUP_ID, MULTI_RANGE_CHECK_1_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs b/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs index 7ab0b3c8..e9a44065 100644 --- a/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs +++ b/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs @@ -52,7 +52,11 @@ where None, buffer, ); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(MULTI_RANGE_CHECK_2_AIRGROUP_ID, MULTI_RANGE_CHECK_2_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/range_check1.rs b/pil2-components/test/std/range_check/rs/src/range_check1.rs index 30a9c778..4986f271 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check1.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check1.rs @@ -43,7 +43,11 @@ where let air_instance = AirInstance::new(sctx.clone(), RANGE_CHECK_1_AIRGROUP_ID, RANGE_CHECK_1_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(RANGE_CHECK_1_AIRGROUP_ID, RANGE_CHECK_1_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/range_check2.rs b/pil2-components/test/std/range_check/rs/src/range_check2.rs index e20783c8..280a69e9 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check2.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check2.rs @@ -43,7 +43,11 @@ where let air_instance = AirInstance::new(sctx.clone(), RANGE_CHECK_2_AIRGROUP_ID, RANGE_CHECK_2_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(RANGE_CHECK_2_AIRGROUP_ID, RANGE_CHECK_2_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/range_check3.rs b/pil2-components/test/std/range_check/rs/src/range_check3.rs index 737e6781..40d61512 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check3.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check3.rs @@ -43,7 +43,11 @@ where let air_instance = AirInstance::new(sctx.clone(), RANGE_CHECK_3_AIRGROUP_ID, RANGE_CHECK_3_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(RANGE_CHECK_3_AIRGROUP_ID, RANGE_CHECK_3_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/range_check4.rs b/pil2-components/test/std/range_check/rs/src/range_check4.rs index c4c8b590..1ce05fde 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check4.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check4.rs @@ -44,7 +44,11 @@ where let air_instance = AirInstance::new(sctx.clone(), RANGE_CHECK_4_AIRGROUP_ID, RANGE_CHECK_4_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(RANGE_CHECK_4_AIRGROUP_ID, RANGE_CHECK_4_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs b/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs index 952d8ca3..e5ab9d4c 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs @@ -52,7 +52,11 @@ where None, buffer, ); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(RANGE_CHECK_DYNAMIC_1_AIRGROUP_ID, RANGE_CHECK_DYNAMIC_1_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs b/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs index c317fde2..e94d3868 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs @@ -53,7 +53,11 @@ where None, buffer, ); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(RANGE_CHECK_DYNAMIC_2_AIRGROUP_ID, RANGE_CHECK_DYNAMIC_2_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/pil2-components/test/std/range_check/rs/src/range_check_mix.rs b/pil2-components/test/std/range_check/rs/src/range_check_mix.rs index c0fa5341..7494c0f2 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check_mix.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check_mix.rs @@ -48,7 +48,11 @@ where let air_instance = AirInstance::new(sctx.clone(), RANGE_CHECK_MIX_AIRGROUP_ID, RANGE_CHECK_MIX_AIR_IDS[0], None, buffer); - pctx.air_instance_repo.add_air_instance(air_instance); + let (is_myne, gid) = + ectx.dctx.write().unwrap().add_instance(RANGE_CHECK_MIX_AIRGROUP_ID, RANGE_CHECK_MIX_AIR_IDS[0], 1); + if is_myne { + pctx.air_instance_repo.add_air_instance(air_instance, Some(gid)); + } } } diff --git a/proofman/src/proofman.rs b/proofman/src/proofman.rs index cae9e38e..92b357f6 100644 --- a/proofman/src/proofman.rs +++ b/proofman/src/proofman.rs @@ -3,6 +3,7 @@ use log::{info, trace}; use p3_field::Field; use stark::{StarkBufferAllocator, StarkProver}; use proofman_starks_lib_c::{save_challenges_c, save_publics_c}; +use core::panic; use std::fs; use std::error::Error; use std::mem::MaybeUninit; @@ -23,9 +24,6 @@ use std::os::raw::c_void; use proofman_util::{timer_start_info, timer_start_debug, timer_stop_and_log_info, timer_stop_and_log_debug}; -#[cfg(feature = "distributed")] -use mpi::collective::CommunicatorCollectives; - pub struct ProofMan { _phantom: std::marker::PhantomData, } @@ -53,7 +51,6 @@ impl ProofMan { &output_dir_path, options.verify_constraints, )?; - let buffer_allocator: Arc = Arc::new(StarkBufferAllocator::new(proving_key_path.clone())); let ectx = ExecutionCtx::builder() .with_rom_path(rom_path) @@ -61,7 +58,6 @@ impl ProofMan { .with_verbose_mode(options.verbose_mode) .build(); let ectx = Arc::new(ectx); - // Load the witness computation dynamic library let library = unsafe { Library::new(&witness_lib_path)? }; @@ -76,12 +72,15 @@ impl ProofMan { Self::initialize_witness(&mut witness_lib, pctx.clone(), ectx.clone(), sctx.clone()); witness_lib.calculate_witness(1, pctx.clone(), ectx.clone(), sctx.clone()); - if ectx.dctx.read().unwrap().is_master() { + let mut dctx = ectx.dctx.write().unwrap(); + dctx.close(); + if dctx.rank == 0 { Self::print_summary(pctx.clone()); } + drop(dctx); let mut provers: Vec>> = Vec::new(); - let n_provers: usize = Self::initialize_provers(sctx.clone(), &mut provers, pctx.clone(), ectx.clone()); + Self::initialize_provers(sctx.clone(), &mut provers, pctx.clone(), ectx.clone()); if provers.is_empty() { return Err("No instances found".into()); @@ -93,7 +92,6 @@ impl ProofMan { let num_commit_stages = pctx.global_info.n_challenges.len() as u32; for stage in 1..=num_commit_stages { Self::get_challenges(stage, &mut provers, pctx.clone(), &transcript); - if stage != 1 { witness_lib.calculate_witness(stage, pctx.clone(), ectx.clone(), sctx.clone()); } @@ -112,7 +110,6 @@ impl ProofMan { ectx.clone(), &mut transcript, options.verify_constraints, - n_provers, ); } } @@ -135,11 +132,10 @@ impl ProofMan { ectx.clone(), &mut transcript, false, - n_provers, ); // Compute openings - Self::opening_stages(&mut provers, pctx.clone(), sctx.clone(), ectx.clone(), &mut transcript, n_provers); + Self::opening_stages(&mut provers, pctx.clone(), ectx.clone(), &mut transcript); //Generate proves_out let proves_out = Self::finalize_proof(&mut provers, pctx.clone(), output_dir_path.to_string_lossy().as_ref()); @@ -235,18 +231,10 @@ impl ProofMan { provers: &mut Vec>>, pctx: Arc>, _ectx: Arc, - ) -> usize { + ) { timer_start_debug!(INITIALIZE_PROVERS); - let mut cont = 0; info!("{}: Initializing provers", Self::MY_NAME); for air_instance in pctx.air_instance_repo.air_instances.read().unwrap().iter() { - cont += 1; - #[cfg(feature = "distributed")] - let segment_idx = air_instance.air_segment_id.unwrap_or(0); // Only for main proof - #[cfg(feature = "distributed")] - if segment_idx as i32 % _ectx.dctx.read().unwrap().n_processes != _ectx.dctx.read().unwrap().rank { - continue; - } let air_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; log::debug!("{}: Initializing prover for air instance {}", Self::MY_NAME, air_name); let prover = Box::new(StarkProver::new( @@ -276,7 +264,6 @@ impl ProofMan { *pctx.buff_helper.buff_helper.write().unwrap() = buff_helper; timer_stop_and_log_debug!(INITIALIZE_PROVERS); - cont } pub fn calculate_stage(stage: u32, provers: &mut [Box>], proof_ctx: Arc>) { @@ -353,7 +340,6 @@ impl ProofMan { ectx: Arc, transcript: &mut FFITranscript, verify_constraints: bool, - n_provers: usize, ) { if stage == 1 { let public_inputs_guard = pctx.public_inputs.inputs.read().unwrap(); @@ -363,88 +349,43 @@ impl ProofMan { } let dctx = ectx.dctx.read().unwrap(); - let is_distributed = dctx.is_distributed(); - let n_processes = dctx.n_processes; - drop(dctx); - if !is_distributed { - info!("{}: Calculating challenges", Self::MY_NAME); - let airgroups = pctx.global_info.air_groups.clone(); - for (airgroup_id, _airgroup) in airgroups.iter().enumerate() { - if verify_constraints { - let dummy_elements = [F::zero(), F::one(), F::two(), F::neg_one()]; - transcript.add_elements(dummy_elements.as_ptr() as *mut c_void, 4); - } else { - let airgroup_instances = pctx.air_instance_repo.find_airgroup_instances(airgroup_id); - - if !airgroup_instances.is_empty() { - let mut values = Vec::new(); - for prover_idx in airgroup_instances.iter() { - let value = provers[*prover_idx].get_transcript_values(stage as u64, pctx.clone()); - values.push(value); - } - if !values.is_empty() { - let value = Self::hash_b_tree(&*provers[airgroup_instances[0]], values.clone()); - transcript.add_elements(value.as_ptr() as *mut c_void, value.len()); - } - } - } + // calculate my roots + let mut roots: Vec = vec![0; 4 * provers.len()]; + for (i, prover) in provers.iter_mut().enumerate() { + // Important we need the roots in u64 in order to distribute them + let values = prover.get_transcript_values_u64(stage as u64, pctx.clone()); + if values.is_empty() { + panic!("No transcript values found for prover {}", i); } - } else { - let size = n_processes; - // max number of roots - let max_roots = (n_provers as i32 + size - 1) / size; - - // calculate my roots - let mut roots: Vec = vec![0; 4 * max_roots as usize]; - for (i, prover) in provers.iter_mut().enumerate() { - //prover.get_root(stage as u64, pctx.clone(), &mut roots[i * 4..(i + 1) * 4]); - let values = prover.get_transcript_values_u64(stage as u64, pctx.clone()); - if values.is_empty() { - panic!("No transcript values found for prover {}", i); + roots[i * 4..(i + 1) * 4].copy_from_slice(&values) + } + // get all roots + let all_roots = dctx.distribute_roots(roots); + + // add challenges to transcript in order + for group_idxs in dctx.my_groups.iter() { + if verify_constraints { + let dummy_elements = [F::zero(), F::one(), F::two(), F::neg_one()]; + transcript.add_elements(dummy_elements.as_ptr() as *mut c_void, 4); + } else { + let mut values = Vec::new(); + for idx in group_idxs.iter() { + let value = vec![ + F::from_wrapped_u64(all_roots[*idx]), + F::from_wrapped_u64(all_roots[*idx + 1]), + F::from_wrapped_u64(all_roots[*idx + 2]), + F::from_wrapped_u64(all_roots[*idx + 3]), + ]; + values.push(value); } - roots[i * 4..(i + 1) * 4].copy_from_slice(&values) - } - - // Use all ghater - let all_roots: Vec = vec![0; 4 * max_roots as usize * size as usize]; - #[cfg(feature = "distributed")] - ectx.dctx.read().unwrap().world.all_gather_into(&roots, &mut all_roots); - - // add challenges to transcript - let airgroups = pctx.global_info.air_groups.clone(); - for (airgroup_id, _airgroup) in airgroups.iter().enumerate() { - if verify_constraints { - let dummy_elements = [F::zero(), F::one(), F::two(), F::neg_one()]; - transcript.add_elements(dummy_elements.as_ptr() as *mut c_void, 4); - } else { - let airgroup_instances = pctx.air_instance_repo.find_airgroup_instances(airgroup_id); - if !airgroup_instances.is_empty() { - let mut values: Vec> = Vec::new(); - for air_idx in airgroup_instances.iter() { - let mut value = Vec::new(); - let air_instance = &pctx.air_instance_repo.air_instances.read().unwrap()[*air_idx]; - let segment_idx = air_instance.air_segment_id.unwrap_or(0); // Only for main proof - let root_rank = segment_idx % size as usize; - let root_idx = segment_idx / size as usize; - let root_ptr = &all_roots[root_rank * 4 * max_roots as usize + root_idx * 4 - ..root_rank * 4 * max_roots as usize + root_idx * 4 + 4]; - - value.push(F::from_wrapped_u64(root_ptr[0])); - value.push(F::from_wrapped_u64(root_ptr[1])); - value.push(F::from_wrapped_u64(root_ptr[2])); - value.push(F::from_wrapped_u64(root_ptr[3])); - - values.push(value); - } - if !values.is_empty() { - let value = Self::hash_b_tree(&*provers[0], values); - transcript.add_elements(value.as_ptr() as *mut c_void, value.len()); - } - } + if !values.is_empty() { + let value = Self::hash_b_tree(&*provers[0], values); + transcript.add_elements(value.as_ptr() as *mut c_void, value.len()); } } } + drop(dctx); } fn get_challenges( @@ -459,95 +400,35 @@ impl ProofMan { pub fn opening_stages( provers: &mut [Box>], pctx: Arc>, - sctx: Arc, ectx: Arc, transcript: &mut FFITranscript, - n_provers: usize, ) { let num_commit_stages = pctx.global_info.n_challenges.len() as u32; let dctx = ectx.dctx.read().unwrap(); - let size = dctx.n_processes; - let rank = dctx.rank; - let is_distributed = dctx.is_distributed(); - drop(dctx); // Calculate evals - Self::get_challenges(num_commit_stages + 2, provers, pctx.clone(), transcript); timer_start_debug!(CALCULATING_EVALS); - info!("{}: Calculating evals", Self::MY_NAME); - for (airgroup_id, airgroup) in sctx.get_setup_airs().iter().enumerate() { - for air_id in airgroup.iter() { - let air_instances_idx: Vec = pctx.air_instance_repo.find_air_instances(airgroup_id, *air_id); - if !air_instances_idx.is_empty() { - if is_distributed { - let mut is_first = true; - for idx in air_instances_idx { - let segment_idx = - &pctx.air_instance_repo.air_instances.read().unwrap()[idx].air_segment_id.unwrap(); - if *segment_idx as i32 % size == rank { - let loc_idx = segment_idx / size as usize; - if is_first { - provers[loc_idx].calculate_lev(pctx.clone()); - is_first = false; - } - provers[loc_idx].opening_stage(1, pctx.clone()); - } - } - } else { - provers[air_instances_idx[0]].calculate_lev(pctx.clone()); - for idx in air_instances_idx { - provers[idx].opening_stage(1, pctx.clone()); - } - } - } + for group_idx in dctx.my_air_groups.iter() { + provers[group_idx[0]].calculate_lev(pctx.clone()); + for idx in group_idx.iter() { + provers[*idx].opening_stage(1, pctx.clone()); } } timer_stop_and_log_debug!(CALCULATING_EVALS); - Self::calculate_challenges( - num_commit_stages + 2, - provers, - pctx.clone(), - ectx.clone(), - transcript, - false, - n_provers, - ); + Self::calculate_challenges(num_commit_stages + 2, provers, pctx.clone(), ectx.clone(), transcript, false); // Calculate fri polynomial Self::get_challenges(pctx.global_info.n_challenges.len() as u32 + 3, provers, pctx.clone(), transcript); info!("{}: Calculating FRI Polynomials", Self::MY_NAME); timer_start_debug!(CALCULATING_FRI_POLINOMIAL); - - let is_distributed = ectx.dctx.read().unwrap().is_distributed(); - for (airgroup_id, airgroup) in sctx.get_setup_airs().iter().enumerate() { - for air_id in airgroup.iter() { - let air_instances_idx: Vec = pctx.air_instance_repo.find_air_instances(airgroup_id, *air_id); - if !air_instances_idx.is_empty() { - if is_distributed { - let mut is_first = true; - for idx in air_instances_idx { - let segment_idx = - &pctx.air_instance_repo.air_instances.read().unwrap()[idx].air_segment_id.unwrap(); - if *segment_idx as i32 % size == rank { - let loc_idx = segment_idx / size as usize; - if is_first { - provers[loc_idx].calculate_xdivxsub(pctx.clone()); - is_first = false; - } - provers[loc_idx].opening_stage(2, pctx.clone()); - } - } - } else { - provers[air_instances_idx[0]].calculate_xdivxsub(pctx.clone()); - - for idx in air_instances_idx { - provers[idx].opening_stage(2, pctx.clone()); - } - } - } + for group_idx in dctx.my_air_groups.iter() { + provers[group_idx[0]].calculate_xdivxsub(pctx.clone()); + for idx in group_idx.iter() { + provers[*idx].opening_stage(2, pctx.clone()); } } timer_stop_and_log_debug!(CALCULATING_FRI_POLINOMIAL); + drop(dctx); let global_steps_fri: Vec = pctx.global_info.steps_fri.iter().map(|step| step.n_bits).collect(); let num_opening_stages = global_steps_fri.len() as u32; @@ -588,7 +469,6 @@ impl ProofMan { ectx.clone(), transcript, false, - n_provers, ); } timer_stop_and_log_debug!(CALCULATING_FRI_STEP); diff --git a/provers/stark/src/stark_prover.rs b/provers/stark/src/stark_prover.rs index 1d681289..67c8effd 100644 --- a/provers/stark/src/stark_prover.rs +++ b/provers/stark/src/stark_prover.rs @@ -1,3 +1,4 @@ +use core::panic; use std::error::Error; use std::fs::File; use std::io::Read; @@ -390,11 +391,11 @@ impl Prover for StarkProver { } fn get_transcript_values_u64(&self, stage: u64, proof_ctx: Arc>) -> Vec { - let p_stark = self.p_stark; + let p_stark: *mut std::ffi::c_void = self.p_stark; let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; - let mut value: Vec = vec![Goldilocks::zero(); self.n_field_elements]; + let mut value = vec![Goldilocks::zero(); self.n_field_elements]; if stage <= (Self::num_stages(self) + 1) as u64 { let (n_airvals_stage, indexes): (usize, Vec) = self .stark_info @@ -566,17 +567,7 @@ impl Prover for StarkProver { } fn get_zkin_proof(&self, proof_ctx: Arc>, output_dir: &str) -> *mut c_void { - #[cfg(not(feature = "distributed"))] - let idx = self.prover_idx; - #[cfg(feature = "distributed")] - let idx; - #[cfg(feature = "distributed")] - { - let segment_id: &usize = - &proof_ctx.air_instance_repo.air_instances.read().unwrap()[self.prover_idx].air_segment_id.unwrap(); - idx = *segment_id; - } - + let gidx = proof_ctx.air_instance_repo.air_instances.read().unwrap()[self.prover_idx].global_idx.unwrap(); let public_inputs_guard = proof_ctx.public_inputs.inputs.read().unwrap(); let public_inputs = (*public_inputs_guard).as_ptr() as *mut c_void; @@ -587,7 +578,7 @@ impl Prover for StarkProver { let global_info_file: &str = global_info_path.to_str().unwrap(); fri_proof_get_zkinproof_c( - idx as u64, + gidx as u64, self.p_proof.unwrap(), public_inputs, challenges,