Merge pull request #88 from 0xPolygonHermez/distributed_WC_merge
Distributed prover
rickb80 authored Oct 28, 2024
2 parents 3e90c1f + 236c49f commit 272603d
Showing 34 changed files with 556 additions and 265 deletions.
2 changes: 2 additions & 0 deletions common/src/air_instance.rs
@@ -33,6 +33,7 @@ pub struct AirInstance<F> {
pub air_segment_id: Option<usize>,
pub air_instance_id: Option<usize>,
pub idx: Option<usize>,
pub global_idx: Option<usize>,
pub buffer: Vec<F>,
pub airgroup_values: Vec<F>,
pub airvalues: Vec<F>,
@@ -58,6 +59,7 @@ impl<F: Field> AirInstance<F> {
air_segment_id,
air_instance_id: None,
idx: None,
global_idx: None,
buffer,
airgroup_values: vec![F::zero(); get_n_airgroupvals_c(ps.p_setup.p_stark_info) as usize * 3],
airvalues: vec![F::zero(); get_n_airvals_c(ps.p_setup.p_stark_info) as usize * 3],
13 changes: 6 additions & 7 deletions common/src/air_instances_repository.rs
@@ -23,13 +23,18 @@ impl<F: Field> AirInstancesRepository<F> {
}
}

pub fn add_air_instance(&self, mut air_instance: AirInstance<F>) {
pub fn add_air_instance(&self, mut air_instance: AirInstance<F>, global_idx: Option<usize>) {
let mut air_instances = self.air_instances.write().unwrap();
let n_air_instances = air_instances.len();

let mut air_instances_counts = self.air_instances_counts.write().unwrap();
let instance_id = air_instances_counts.entry((air_instance.airgroup_id, air_instance.air_id)).or_insert(0);
air_instance.set_air_instance_id(*instance_id, n_air_instances);
if global_idx.is_some() {
air_instance.global_idx = global_idx;
} else {
air_instance.global_idx = Some(n_air_instances);
}
*instance_id += 1;
air_instances.push(air_instance);
}
@@ -38,17 +43,11 @@ impl<F: Field> AirInstancesRepository<F> {
let air_instances = self.air_instances.read().unwrap();

let mut indices = Vec::new();
#[cfg(feature = "distributed")]
let mut segment_ids = Vec::new();
for (index, air_instance) in air_instances.iter().enumerate() {
if air_instance.airgroup_id == airgroup_id {
indices.push(index);
#[cfg(feature = "distributed")]
segment_ids.push(air_instance.air_segment_id.unwrap_or(0));
}
}
#[cfg(feature = "distributed")]
indices.sort_by(|a, b| segment_ids[*a].cmp(&segment_ids[*b]));
indices
}

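The repository now takes a caller-supplied global slot into account. Below is a minimal sketch of that bookkeeping with a stripped-down stand-in type (`Instance` and `Repository` are illustrative names; the real `AirInstancesRepository` also keeps per-`(airgroup_id, air_id)` counters behind `RwLock`s): when the caller already knows the instance's global index from the distribution context it passes `Some(global_idx)`, otherwise the repository falls back to the local position, as in a single-process run.

```rust
// Hypothetical, simplified mirror of add_air_instance to show the global_idx
// fallback; only the logic visible in the diff above is modelled.
struct Instance {
    airgroup_id: usize,
    air_id: usize,
    idx: Option<usize>,
    global_idx: Option<usize>,
}

#[derive(Default)]
struct Repository {
    instances: Vec<Instance>,
}

impl Repository {
    fn add_air_instance(&mut self, mut inst: Instance, global_idx: Option<usize>) {
        let local_pos = self.instances.len();
        inst.idx = Some(local_pos);
        // Prefer the globally assigned slot (distributed run); fall back to
        // the local position (single-process run).
        inst.global_idx = global_idx.or(Some(local_pos));
        self.instances.push(inst);
    }
}

fn main() {
    let mut repo = Repository::default();
    repo.add_air_instance(Instance { airgroup_id: 0, air_id: 0, idx: None, global_idx: None }, Some(7));
    repo.add_air_instance(Instance { airgroup_id: 0, air_id: 1, idx: None, global_idx: None }, None);
    assert_eq!(repo.instances[0].global_idx, Some(7)); // caller-provided global slot
    assert_eq!(repo.instances[1].global_idx, Some(1)); // falls back to local position
    assert_eq!(repo.instances[1].air_id, 1);
}
```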
226 changes: 208 additions & 18 deletions common/src/distribution_ctx.rs
@@ -1,9 +1,13 @@
use std::collections::HashMap;
use std::collections::BTreeMap;
#[cfg(feature = "distributed")]
use mpi::collective::CommunicatorCollectives;
#[cfg(feature = "distributed")]
use mpi::traits::Communicator;
use mpi::traits::*;
#[cfg(feature = "distributed")]
use mpi::environment::Universe;
#[cfg(feature = "distributed")]
use mpi::collective::CommunicatorCollectives;
#[cfg(feature = "distributed")]
use mpi::datatype::PartitionMut;

/// Represents the context of distributed computing
pub struct DistributionCtx {
@@ -13,9 +17,18 @@ pub struct DistributionCtx {
pub universe: Universe,
#[cfg(feature = "distributed")]
pub world: mpi::topology::SimpleCommunicator,
pub n_instances: i32,
pub n_instances: usize,
pub my_instances: Vec<usize>,
pub instances: Vec<(usize, usize)>,
pub instances: Vec<(usize, usize)>, //group_id, air_id
pub instances_owner: Vec<(usize, usize)>, //owner_rank, owner_instance_idx
pub owners_count: Vec<i32>,
pub owners_weight: Vec<u64>,
#[cfg(feature = "distributed")]
pub roots_gatherv_count: Vec<i32>,
#[cfg(feature = "distributed")]
pub roots_gatherv_displ: Vec<i32>,
pub my_groups: Vec<Vec<usize>>,
pub my_air_groups: Vec<Vec<usize>>,
}

impl DistributionCtx {
@@ -24,19 +37,39 @@ impl DistributionCtx {
{
let (universe, _threading) = mpi::initialize_with_threading(mpi::Threading::Multiple).unwrap();
let world = universe.world();
let rank = world.rank();
let n_processes = world.size();
DistributionCtx {
rank: world.rank(),
n_processes: world.size(),
rank,
n_processes,
universe,
world,
n_instances: 0,
my_instances: Vec::new(),
instances: Vec::new(),
instances_owner: Vec::new(),
owners_count: vec![0; n_processes as usize],
owners_weight: vec![0; n_processes as usize],
roots_gatherv_count: vec![0; n_processes as usize],
roots_gatherv_displ: vec![0; n_processes as usize],
my_groups: Vec::new(),
my_air_groups: Vec::new(),
}
}
#[cfg(not(feature = "distributed"))]
{
DistributionCtx { rank: 0, n_processes: 1, n_instances: 0, my_instances: Vec::new(), instances: Vec::new() }
DistributionCtx {
rank: 0,
n_processes: 1,
n_instances: 0,
my_instances: Vec::new(),
instances: Vec::new(),
instances_owner: Vec::new(),
owners_count: vec![0; 1],
owners_weight: vec![0; 1],
my_groups: Vec::new(),
my_air_groups: Vec::new(),
}
}
}

@@ -48,28 +81,185 @@ impl DistributionCtx {
}
}

#[inline]
pub fn is_master(&self) -> bool {
self.rank == 0
}

#[inline]
pub fn is_distributed(&self) -> bool {
self.n_processes > 1
}

#[inline]
pub fn is_my_instance(&self, instance_idx: usize) -> bool {
instance_idx % self.n_processes as usize == self.rank as usize
self.owner(instance_idx) == self.rank as usize
}

#[inline]
pub fn owner(&self, instance_idx: usize) -> usize {
self.instances_owner[instance_idx].0
}

#[inline]
pub fn add_instance(&mut self, airgroup_id: usize, air_id: usize, instance_idx: usize, _size: usize) {
pub fn add_instance(&mut self, airgroup_id: usize, air_id: usize, weight: usize) -> (bool, usize) {
let mut is_mine = false;
let owner = self.n_instances % self.n_processes as usize;
self.instances.push((airgroup_id, air_id));
self.instances_owner.push((owner, self.owners_count[owner] as usize));
self.owners_count[owner] += 1;
self.owners_weight[owner] += weight as u64;

if owner == self.rank as usize {
self.my_instances.push(self.n_instances);
is_mine = true;
}
self.n_instances += 1;
if self.is_my_instance(instance_idx) {
self.my_instances.push(instance_idx);
(is_mine, self.n_instances - 1)
}
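The new `add_instance` assigns ownership round-robin: the owner of the n-th registered instance is rank `n % n_processes`, and per-owner counts and cumulative weights are recorded along the way. A stand-alone sketch of just that assignment (no MPI; `Ctx` is a stripped-down stand-in for `DistributionCtx`):

```rust
// Stand-alone sketch of the round-robin assignment in add_instance: the
// owner of the n-th registered instance is rank n % n_processes.
struct Ctx {
    rank: usize,
    n_processes: usize,
    n_instances: usize,
    my_instances: Vec<usize>,
    instances: Vec<(usize, usize)>,       // (airgroup_id, air_id)
    instances_owner: Vec<(usize, usize)>, // (owner_rank, owner_local_idx)
    owners_count: Vec<i32>,
    owners_weight: Vec<u64>,
}

impl Ctx {
    fn add_instance(&mut self, airgroup_id: usize, air_id: usize, weight: usize) -> (bool, usize) {
        let owner = self.n_instances % self.n_processes;
        self.instances.push((airgroup_id, air_id));
        self.instances_owner.push((owner, self.owners_count[owner] as usize));
        self.owners_count[owner] += 1;
        self.owners_weight[owner] += weight as u64;
        let is_mine = owner == self.rank;
        if is_mine {
            self.my_instances.push(self.n_instances);
        }
        self.n_instances += 1;
        (is_mine, self.n_instances - 1)
    }
}

fn main() {
    // Rank 1 of 3 registers six instances; it owns the global slots 1 and 4.
    let mut ctx = Ctx {
        rank: 1,
        n_processes: 3,
        n_instances: 0,
        my_instances: vec![],
        instances: vec![],
        instances_owner: vec![],
        owners_count: vec![0; 3],
        owners_weight: vec![0; 3],
    };
    for i in 0..6usize {
        let (mine, gid) = ctx.add_instance(0, 0, 1);
        assert_eq!(gid, i);
        assert_eq!(mine, i % 3 == 1);
    }
    assert_eq!(ctx.my_instances, vec![1, 4]);
    assert_eq!(ctx.instances_owner[4], (1, 1)); // rank 1's second local instance
    assert_eq!(ctx.owners_count, vec![2, 2, 2]);
    assert_eq!(ctx.owners_weight, vec![2, 2, 2]);
    assert_eq!(ctx.instances.len(), 6);
}
```

Because every rank is expected to register instances in the same order, all ranks can agree on the global numbering without any communication.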

pub fn close(&mut self) {
let mut group_indices: BTreeMap<usize, Vec<usize>> = BTreeMap::new();

// Calculate the partial sums of owners_count
#[cfg(feature = "distributed")]
{
let mut total_instances = 0;
for i in 0..self.n_processes as usize {
self.roots_gatherv_displ[i] = total_instances;
self.roots_gatherv_count[i] = self.owners_count[i] * 4;
total_instances += self.roots_gatherv_count[i];
}
}

// Populate the BTreeMap based on group_id and buffer positions
for (idx, &(group_id, _)) in self.instances.iter().enumerate() {
#[cfg(feature = "distributed")]
let pos_buffer =
self.roots_gatherv_displ[self.instances_owner[idx].0] as usize + self.instances_owner[idx].1 * 4;
#[cfg(not(feature = "distributed"))]
let pos_buffer = idx * 4;
group_indices.entry(group_id).or_default().push(pos_buffer);
}

// Flatten the BTreeMap into a single vector for my_groups
for (_, indices) in group_indices {
self.my_groups.push(indices);
}

// Create my eval groups
let mut my_air_groups_indices: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
for (loc_idx, glob_idx) in self.my_instances.iter().enumerate() {
let instance_idx = self.instances[*glob_idx];
my_air_groups_indices.entry(instance_idx).or_default().push(loc_idx);
}

// Flatten the HashMap into a single vector for my_air_groups
for (_, indices) in my_air_groups_indices {
self.my_air_groups.push(indices);
}
}
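`close()` precomputes the counts and displacements later consumed by `distribute_roots`. A small worked example of that layout follows (the helper name `gatherv_layout` is hypothetical; the `* 4` is taken from the diff, and reading it as a four-word root per instance is an assumption).

```rust
// Worked example of the gatherv layout precomputed by close(): each owned
// instance contributes 4 u64 words to the gathered root buffer.
fn gatherv_layout(owners_count: &[i32]) -> (Vec<i32>, Vec<i32>) {
    let mut counts = vec![0; owners_count.len()];
    let mut displs = vec![0; owners_count.len()];
    let mut running = 0;
    for (i, &c) in owners_count.iter().enumerate() {
        displs[i] = running; // where rank i's contribution starts, in words
        counts[i] = c * 4;   // 4 words per instance owned by rank i
        running += counts[i];
    }
    (counts, displs)
}

fn main() {
    // Three ranks owning 3, 2 and 4 instances respectively:
    let (counts, displs) = gatherv_layout(&[3, 2, 4]);
    assert_eq!(counts, vec![12, 8, 16]);
    assert_eq!(displs, vec![0, 12, 20]);
}
```

The gathered buffer produced by `distribute_roots` is therefore ordered by owner rank (and within a rank by local instance index) rather than by global instance index, which is why `close()` records each instance's position as `roots_gatherv_displ[owner] + owner_local_idx * 4`.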

pub fn distribute_roots(&self, roots: Vec<u64>) -> Vec<u64> {
#[cfg(feature = "distributed")]
{
let mut all_roots: Vec<u64> = vec![0; 4 * self.n_instances];
let counts = &self.roots_gatherv_count;
let displs = &self.roots_gatherv_displ;

let mut partitioned_all_roots = PartitionMut::new(&mut all_roots, counts.as_slice(), displs.as_slice());

self.world.all_gather_varcount_into(&roots, &mut partitioned_all_roots);

all_roots
}
#[cfg(not(feature = "distributed"))]
{
roots
}
}

pub fn distribute_multiplicity(&self, _multiplicity: &mut [u64], _owner: usize) {
#[cfg(feature = "distributed")]
{
//assert that I can operate with u32
assert!(_multiplicity.len() < std::u32::MAX as usize);

if _owner != self.rank as usize {
//pack multiplicities in a sparse vector
let mut packed_multiplicity = Vec::new();
packed_multiplicity.push(0 as u32); //this will be the counter
for (idx, &m) in _multiplicity.iter().enumerate() {
if m != 0 {
assert!(m < std::u32::MAX as u64);
packed_multiplicity.push(idx as u32);
packed_multiplicity.push(m as u32);
packed_multiplicity[0] += 2;
}
}
self.world.process_at_rank(_owner as i32).send(&packed_multiplicity[..]);
} else {
let mut packed_multiplicity: Vec<u32> = vec![0; _multiplicity.len() * 2 + 1];
for i in 0..self.n_processes {
if i != _owner as i32 {
self.world.process_at_rank(i).receive_into(&mut packed_multiplicity);
for j in (1..packed_multiplicity[0]).step_by(2) {
let idx = packed_multiplicity[j as usize] as usize;
let m = packed_multiplicity[j as usize + 1] as u64;
_multiplicity[idx] += m;
}
}
}
}
}
}
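`distribute_multiplicity` ships each non-owner's column as a sparse list: slot 0 holds the number of `u32` entries that follow, and each non-zero row contributes an `(index, value)` pair. A pack/unpack round trip without the MPI transport (`pack` and `unpack_into` are illustrative helper names, not part of the crate):

```rust
// Pack a multiplicity column into the sparse wire format used above:
// [n_entries, idx0, val0, idx1, val1, ...], all u32.
fn pack(multiplicity: &[u64]) -> Vec<u32> {
    let mut packed = vec![0u32]; // slot 0 = number of u32 entries that follow
    for (idx, &m) in multiplicity.iter().enumerate() {
        if m != 0 {
            packed.push(idx as u32);
            packed.push(m as u32);
            packed[0] += 2;
        }
    }
    packed
}

// Accumulate a packed column into the owner's dense vector.
fn unpack_into(packed: &[u32], multiplicity: &mut [u64]) {
    for j in (1..packed[0] as usize).step_by(2) {
        multiplicity[packed[j] as usize] += packed[j + 1] as u64;
    }
}

fn main() {
    // A non-owner rank packs its local counts; the owner adds them to its own.
    let local = vec![0u64, 3, 0, 0, 1];
    let packed = pack(&local);
    assert_eq!(packed, vec![4, 1, 3, 4, 1]);

    let mut owner_view = vec![2u64, 1, 0, 0, 0];
    unpack_into(&packed, &mut owner_view);
    assert_eq!(owner_view, vec![2, 4, 0, 0, 1]);
}
```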

pub fn distribute_multiplicities(&self, _multiplicities: &mut [Vec<u64>], _owner: usize) {
#[cfg(feature = "distributed")]
{
// Ensure that each multiplicity vector can be operated with u32
let mut buff_size = 0;
for multiplicity in _multiplicities.iter() {
assert!(multiplicity.len() < std::u32::MAX as usize);
buff_size += multiplicity.len() + 1;
}

let n_columns = _multiplicities.len();
if _owner != self.rank as usize {
// Pack multiplicities in a sparse vector
let mut packed_multiplicities = vec![0u32; n_columns];
for (col_idx, multiplicity) in _multiplicities.iter().enumerate() {
for (idx, &m) in multiplicity.iter().enumerate() {
if m != 0 {
assert!(m < std::u32::MAX as u64);
packed_multiplicities[col_idx] += 1;
packed_multiplicities.push(idx as u32);
packed_multiplicities.push(m as u32);
}
}
}
self.world.process_at_rank(_owner as i32).send(&packed_multiplicities[..]);
} else {
let mut packed_multiplicities: Vec<u32> = vec![0; buff_size * 2];
for i in 0..self.n_processes {
if i != _owner as i32 {
self.world.process_at_rank(i).receive_into(&mut packed_multiplicities);

// Read counters
let mut counters = vec![0usize; n_columns];
for col_idx in 0..n_columns {
counters[col_idx] = packed_multiplicities[col_idx] as usize;
}

// Unpack multiplicities
let mut idx = n_columns;
for col_idx in 0..n_columns {
for _ in 0..counters[col_idx] {
let row_idx = packed_multiplicities[idx] as usize;
let m = packed_multiplicities[idx + 1] as u64;
_multiplicities[col_idx][row_idx] += m;
idx += 2;
}
}
}
}
}
}
self.instances.push((airgroup_id, air_id));
}
}

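`distribute_multiplicities` applies the same idea column-wise: the first `n_columns` entries of the packed buffer are per-column pair counters, and the `(row, value)` pairs follow, grouped by column. A sketch of the encode/decode pair, again with the MPI send/receive left out (`pack_columns` and `unpack_columns_into` are illustrative names):

```rust
// Sparse wire format for several multiplicity columns:
// [cnt_0, ..., cnt_{n-1}, pairs of column 0, pairs of column 1, ...]
fn pack_columns(columns: &[Vec<u64>]) -> Vec<u32> {
    let n = columns.len();
    let mut packed = vec![0u32; n]; // per-column pair counters
    for (col, column) in columns.iter().enumerate() {
        for (row, &m) in column.iter().enumerate() {
            if m != 0 {
                packed[col] += 1;
                packed.push(row as u32);
                packed.push(m as u32);
            }
        }
    }
    packed
}

fn unpack_columns_into(packed: &[u32], columns: &mut [Vec<u64>]) {
    let n = columns.len();
    let mut pos = n; // pairs start right after the counter header
    for col in 0..n {
        for _ in 0..packed[col] {
            let row = packed[pos] as usize;
            columns[col][row] += packed[pos + 1] as u64;
            pos += 2;
        }
    }
}

fn main() {
    let cols = vec![vec![0u64, 5, 0], vec![7u64, 0, 0]];
    let packed = pack_columns(&cols);
    assert_eq!(packed, vec![1, 1, 1, 5, 0, 7]); // one pair per column
    let mut acc = vec![vec![0u64; 3], vec![0u64; 3]];
    unpack_columns_into(&packed, &mut acc);
    assert_eq!(acc, cols);
}
```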
6 changes: 5 additions & 1 deletion examples/fibonacci-square/src/fibonacci.rs
@@ -82,7 +82,11 @@ impl<F: PrimeField + Copy> FibonacciSquare<F> {
air_instance.set_airvalue(&sctx, "FibonacciSquare.fibo2", F::from_canonical_u64(2));
air_instance.set_airvalue_ext(&sctx, "FibonacciSquare.fibo3", vec![F::from_canonical_u64(5); 3]);

pctx.air_instance_repo.add_air_instance(air_instance);
let (is_myne, gid) =
ectx.dctx.write().unwrap().add_instance(FIBONACCI_SQUARE_AIRGROUP_ID, FIBONACCI_SQUARE_AIR_IDS[0], 1);
if is_myne {
pctx.air_instance_repo.add_air_instance(air_instance, Some(gid));
}

Ok(b)
}
6 changes: 5 additions & 1 deletion examples/fibonacci-square/src/module.rs
@@ -79,7 +79,11 @@ impl<F: PrimeField + AbstractField + Clone + Copy + Default + 'static> Module<F>
}

let air_instance = AirInstance::new(sctx.clone(), MODULE_AIRGROUP_ID, MODULE_AIR_IDS[0], Some(0), buffer);
pctx.air_instance_repo.add_air_instance(air_instance);

let (is_myne, gid) = ectx.dctx.write().unwrap().add_instance(MODULE_AIRGROUP_ID, MODULE_AIR_IDS[0], 1);
if is_myne {
pctx.air_instance_repo.add_air_instance(air_instance, Some(gid));
}

self.std_lib.unregister_predecessor(pctx, None);
}
(Remaining changed files not shown.)