Skip to content

Commit

Permalink
Removing OnceCell from setup (#91)
Browse files Browse the repository at this point in the history
* Removing OnceCell
* Modifying initialize setup to work with distributed
  • Loading branch information
RogerTaule authored Nov 6, 2024
1 parent 34427c1 commit e949af1
Show file tree
Hide file tree
Showing 80 changed files with 1,346 additions and 1,145 deletions.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ members = [
"transcript",
"util",
"pil2-components/lib/std/rs",
# "pil2-components/test/std/range_check/rs",
# "pil2-components/test/std/lookup/rs",
# "pil2-components/test/std/connection/rs",
# "pil2-components/test/std/permutation/rs",
# "pil2-components/test/simple/rs",
#"pil2-components/test/std/range_check/rs",
#"pil2-components/test/std/lookup/rs",
#"pil2-components/test/std/connection/rs",
#"pil2-components/test/std/permutation/rs",
#"pil2-components/test/simple/rs",
# whoever re-enables this, it has to work out of
# the box with `cargo check --workspace` or CI will
# break and dev experience will be bad since repo
Expand Down
22 changes: 12 additions & 10 deletions common/src/air_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub struct StepsParams {
pub airvalues: *mut c_void,
pub evals: *mut c_void,
pub xdivxsub: *mut c_void,
pub p_const_pols: *mut c_void,
pub p_const_tree: *mut c_void,
}

impl From<&StepsParams> for *mut c_void {
Expand Down Expand Up @@ -45,13 +47,13 @@ pub struct AirInstance<F> {

impl<F: Field> AirInstance<F> {
pub fn new(
setup_ctx: Arc<SetupCtx>,
setup_ctx: Arc<SetupCtx<F>>,
airgroup_id: usize,
air_id: usize,
air_segment_id: Option<usize>,
buffer: Vec<F>,
) -> Self {
let ps = setup_ctx.get_partial_setup(airgroup_id, air_id).expect("REASON");
let ps = setup_ctx.get_setup(airgroup_id, air_id);

AirInstance {
airgroup_id,
Expand All @@ -74,8 +76,8 @@ impl<F: Field> AirInstance<F> {
self.buffer.as_ptr() as *mut u8
}

pub fn set_airvalue(&mut self, setup_ctx: &SetupCtx, name: &str, value: F) {
let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON");
pub fn set_airvalue(&mut self, setup_ctx: &SetupCtx<F>, name: &str, value: F) {
let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id);

let id = get_airval_id_by_name_c(ps.p_setup.p_stark_info, name);
if id == -1 {
Expand All @@ -86,8 +88,8 @@ impl<F: Field> AirInstance<F> {
self.set_airvalue_calculated(id as usize);
}

pub fn set_airvalue_ext(&mut self, setup_ctx: &SetupCtx, name: &str, value: Vec<F>) {
let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON");
pub fn set_airvalue_ext(&mut self, setup_ctx: &SetupCtx<F>, name: &str, value: Vec<F>) {
let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id);

let id = get_airval_id_by_name_c(ps.p_setup.p_stark_info, name);
if id == -1 {
Expand All @@ -105,8 +107,8 @@ impl<F: Field> AirInstance<F> {
self.set_airvalue_calculated(id as usize);
}

pub fn set_airgroupvalue(&mut self, setup_ctx: &SetupCtx, name: &str, value: F) {
let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON");
pub fn set_airgroupvalue(&mut self, setup_ctx: &SetupCtx<F>, name: &str, value: F) {
let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id);

let id = get_airgroupval_id_by_name_c(ps.p_setup.p_stark_info, name);
if id == -1 {
Expand All @@ -117,8 +119,8 @@ impl<F: Field> AirInstance<F> {
self.set_airgroupvalue_calculated(id as usize);
}

pub fn set_airgroupvalue_ext(&mut self, setup_ctx: &SetupCtx, name: &str, value: Vec<F>) {
let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON");
pub fn set_airgroupvalue_ext(&mut self, setup_ctx: &SetupCtx<F>, name: &str, value: Vec<F>) {
let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id);

let id = get_airgroupval_id_by_name_c(ps.p_setup.p_stark_info, name);
if id == -1 {
Expand Down
4 changes: 2 additions & 2 deletions common/src/buffer_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use std::error::Error;

use crate::SetupCtx;

pub trait BufferAllocator: Send + Sync {
pub trait BufferAllocator<F>: Send + Sync {
// Returns the size of the buffer and the offsets for each stage
fn get_buffer_info(
&self,
sctx: &SetupCtx,
sctx: &SetupCtx<F>,
airgroup_id: usize,
air_id: usize,
) -> Result<(u64, Vec<u64>), Box<dyn Error>>;
Expand Down
22 changes: 11 additions & 11 deletions common/src/execution_ctx.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
use std::{path::PathBuf, sync::Arc};
use crate::{BufferAllocator, VerboseMode, DistributionCtx};
use crate::{BufferAllocator, DistributionCtx, VerboseMode};
use std::sync::RwLock;
#[allow(dead_code)]
/// Represents the context when executing a witness computer plugin
pub struct ExecutionCtx {
pub struct ExecutionCtx<F> {
pub rom_path: Option<PathBuf>,
/// If true, the plugin must generate the public outputs
pub public_output: bool,
pub buffer_allocator: Arc<dyn BufferAllocator>,
pub buffer_allocator: Arc<dyn BufferAllocator<F>>,
pub verbose_mode: VerboseMode,
pub dctx: RwLock<DistributionCtx>,
}

impl ExecutionCtx {
pub fn builder() -> ExecutionCtxBuilder {
impl<F> ExecutionCtx<F> {
pub fn builder() -> ExecutionCtxBuilder<F> {
ExecutionCtxBuilder::new()
}
}

pub struct ExecutionCtxBuilder {
pub struct ExecutionCtxBuilder<F> {
rom_path: Option<PathBuf>,
public_output: bool,
buffer_allocator: Option<Arc<dyn BufferAllocator>>,
buffer_allocator: Option<Arc<dyn BufferAllocator<F>>>,
verbose_mode: VerboseMode,
}

impl Default for ExecutionCtxBuilder {
impl<F> Default for ExecutionCtxBuilder<F> {
fn default() -> Self {
Self::new()
}
}

impl ExecutionCtxBuilder {
impl<F> ExecutionCtxBuilder<F> {
pub fn new() -> Self {
ExecutionCtxBuilder {
rom_path: None,
Expand All @@ -46,7 +46,7 @@ impl ExecutionCtxBuilder {
self
}

pub fn with_buffer_allocator(mut self, buffer_allocator: Arc<dyn BufferAllocator>) -> Self {
pub fn with_buffer_allocator(mut self, buffer_allocator: Arc<dyn BufferAllocator<F>>) -> Self {
self.buffer_allocator = Some(buffer_allocator);
self
}
Expand All @@ -56,7 +56,7 @@ impl ExecutionCtxBuilder {
self
}

pub fn build(self) -> ExecutionCtx {
pub fn build(self) -> ExecutionCtx<F> {
if self.buffer_allocator.is_none() {
panic!("Buffer allocator is required");
}
Expand Down
16 changes: 12 additions & 4 deletions common/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use std::os::raw::c_void;
use std::os::raw::c_char;
use std::sync::Arc;

use p3_field::Field;
use transcript::FFITranscript;

use crate::ProofCtx;
use crate::SetupCtx;

#[derive(Debug, PartialEq)]
pub enum ProverStatus {
Expand Down Expand Up @@ -54,16 +56,22 @@ pub struct ConstraintsResults {
pub constraints_info: *mut ConstraintInfo,
}

pub trait Prover<F> {
pub trait Prover<F: Field> {
fn build(&mut self, proof_ctx: Arc<ProofCtx<F>>);
fn free(&mut self);
fn new_transcript(&self) -> FFITranscript;
fn num_stages(&self) -> u32;
fn get_challenges(&self, stage_id: u32, proof_ctx: Arc<ProofCtx<F>>, transcript: &FFITranscript);
fn calculate_stage(&mut self, stage_id: u32, proof_ctx: Arc<ProofCtx<F>>);
fn calculate_stage(&mut self, stage_id: u32, setup_ctx: Arc<SetupCtx<F>>, proof_ctx: Arc<ProofCtx<F>>);
fn commit_stage(&mut self, stage_id: u32, proof_ctx: Arc<ProofCtx<F>>) -> ProverStatus;
fn calculate_xdivxsub(&mut self, proof_ctx: Arc<ProofCtx<F>>);
fn calculate_lev(&mut self, proof_ctx: Arc<ProofCtx<F>>);
fn opening_stage(&mut self, opening_id: u32, proof_ctx: Arc<ProofCtx<F>>) -> ProverStatus;
fn opening_stage(
&mut self,
opening_id: u32,
setup_ctx: Arc<SetupCtx<F>>,
proof_ctx: Arc<ProofCtx<F>>,
) -> ProverStatus;

fn get_buff_helper_size(&self) -> usize;
fn get_proof(&self) -> *mut c_void;
Expand All @@ -73,7 +81,7 @@ pub trait Prover<F> {
fn get_transcript_values(&self, stage: u64, proof_ctx: Arc<ProofCtx<F>>) -> Vec<F>;
fn get_transcript_values_u64(&self, stage: u64, proof_ctx: Arc<ProofCtx<F>>) -> Vec<u64>;
fn calculate_hash(&self, values: Vec<F>) -> Vec<F>;
fn verify_constraints(&self, proof_ctx: Arc<ProofCtx<F>>) -> Vec<ConstraintInfo>;
fn verify_constraints(&self, setup_ctx: Arc<SetupCtx<F>>, proof_ctx: Arc<ProofCtx<F>>) -> Vec<ConstraintInfo>;

fn get_proof_challenges(&self, global_steps: Vec<usize>, global_challenges: Vec<F>) -> Vec<F>;
}
126 changes: 82 additions & 44 deletions common/src/setup.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use std::mem::MaybeUninit;
use std::os::raw::c_void;
use std::path::PathBuf;
use std::sync::RwLock;

use proofman_starks_lib_c::{const_pols_new_c, const_pols_with_tree_new_c, expressions_bin_new_c, stark_info_new_c};
use proofman_starks_lib_c::{
get_const_tree_size_c, get_const_size_c, prover_helpers_new_c, expressions_bin_new_c, stark_info_new_c,
load_const_tree_c, load_const_pols_c, calculate_const_tree_c, stark_info_free_c, expressions_bin_free_c,
prover_helpers_free_c,
};

use crate::GlobalInfo;
use crate::ProofType;
Expand All @@ -11,7 +17,7 @@ use crate::ProofType;
pub struct SetupC {
pub p_stark_info: *mut c_void,
pub p_expressions_bin: *mut c_void,
pub p_const_pols: *mut c_void,
pub p_prover_helpers: *mut c_void,
}

unsafe impl Send for SetupC {}
Expand All @@ -23,16 +29,29 @@ impl From<&SetupC> for *mut c_void {
}
}

#[derive(Debug)]
pub struct ConstPols<F> {
pub const_pols: RwLock<Vec<MaybeUninit<F>>>,
}

impl<F> Default for ConstPols<F> {
fn default() -> Self {
Self { const_pols: RwLock::new(Vec::new()) }
}
}

/// Air instance context for managing air instances (traces)
#[derive(Debug, Clone)]
#[derive(Debug)]
#[allow(dead_code)]
pub struct Setup {
pub struct Setup<F> {
pub airgroup_id: usize,
pub air_id: usize,
pub p_setup: SetupC,
pub const_pols: ConstPols<F>,
pub const_tree: ConstPols<F>,
}

impl Setup {
impl<F> Setup<F> {
const MY_NAME: &'static str = "Setup";

pub fn new(global_info: &GlobalInfo, airgroup_id: usize, air_id: usize, setup_type: &ProofType) -> Self {
Expand All @@ -43,62 +62,81 @@ impl Setup {

let stark_info_path = setup_path.display().to_string() + ".starkinfo.json";
let expressions_bin_path = setup_path.display().to_string() + ".bin";
let const_pols_path = setup_path.display().to_string() + ".const";
let const_pols_tree_path = setup_path.display().to_string() + ".consttree";

let p_stark_info = stark_info_new_c(stark_info_path.as_str());
let p_expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false);

let p_const_pols = match PathBuf::from(&const_pols_tree_path).exists() {
true => const_pols_with_tree_new_c(const_pols_path.as_str(), const_pols_tree_path.as_str(), p_stark_info),
false => const_pols_new_c(const_pols_path.as_str(), p_stark_info, true),
};

Self { air_id, airgroup_id, p_setup: SetupC { p_stark_info, p_expressions_bin, p_const_pols } }
}
let (p_stark_info, p_expressions_bin, p_prover_helpers) =
if setup_type == &ProofType::Compressor && !global_info.get_air_has_compressor(airgroup_id, air_id) {
// If the condition is met, use None for each pointer
(std::ptr::null_mut(), std::ptr::null_mut(), std::ptr::null_mut())
} else {
// Otherwise, initialize the pointers with their respective values
let stark_info = stark_info_new_c(stark_info_path.as_str());
let expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false);
let prover_helpers = prover_helpers_new_c(stark_info);

pub fn new_partial(global_info: &GlobalInfo, airgroup_id: usize, air_id: usize, setup_type: &ProofType) -> Self {
let setup_path = global_info.get_air_setup_path(airgroup_id, air_id, setup_type);

let air_name = &global_info.airs[airgroup_id][air_id].name;
log::debug!("{} : ··· Loading setup for AIR {}", Self::MY_NAME, air_name);

let stark_info_path = setup_path.display().to_string() + ".starkinfo.json";
let expressions_bin_path = setup_path.display().to_string() + ".bin";
let p_stark_info = stark_info_new_c(stark_info_path.as_str());
let p_expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false);
(stark_info, expressions_bin, prover_helpers)
};

Self {
air_id,
airgroup_id,
p_setup: SetupC { p_stark_info, p_expressions_bin, p_const_pols: std::ptr::null_mut() },
p_setup: SetupC { p_stark_info, p_expressions_bin, p_prover_helpers },
const_pols: ConstPols::default(),
const_tree: ConstPols::default(),
}
}

pub fn load_const_pols(&mut self, global_info: &GlobalInfo, setup_type: &ProofType) {
if !self.p_setup.p_const_pols.is_null() {
return;
}
assert!(!self.p_setup.p_stark_info.is_null());
assert!(!self.p_setup.p_expressions_bin.is_null());
pub fn free(&self) {
stark_info_free_c(self.p_setup.p_stark_info);
expressions_bin_free_c(self.p_setup.p_expressions_bin);
prover_helpers_free_c(self.p_setup.p_prover_helpers);
}

let setup_path = global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type);
pub fn load_const_pols(&self, global_info: &GlobalInfo, setup_type: &ProofType) {
let setup_path = match setup_type {
ProofType::Final => global_info.get_final_setup_path(),
_ => global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type),
};

let air_name = &global_info.airs[self.airgroup_id][self.air_id].name;
log::debug!("{} : ··· Loading const pols for AIR {} of type {:?}", Self::MY_NAME, air_name, setup_type);

let const_pols_path = setup_path.display().to_string() + ".const";
let const_pols_tree_path = setup_path.display().to_string() + ".consttree";

let p_const_pols = match PathBuf::from(&const_pols_tree_path).exists() {
true => const_pols_with_tree_new_c(
const_pols_path.as_str(),
const_pols_tree_path.as_str(),
self.p_setup.p_stark_info,
),
false => const_pols_new_c(const_pols_path.as_str(), self.p_setup.p_stark_info, true),
let p_stark_info = self.p_setup.p_stark_info;

let const_size = get_const_size_c(p_stark_info) as usize;
let const_pols: Vec<MaybeUninit<F>> = Vec::with_capacity(const_size);

let p_const_pols_address = const_pols.as_ptr() as *mut c_void;
load_const_pols_c(p_const_pols_address, const_pols_path.as_str(), const_size as u64);
*self.const_pols.const_pols.write().unwrap() = const_pols;
}

pub fn load_const_pols_tree(&self, global_info: &GlobalInfo, setup_type: &ProofType, save_file: bool) {
let setup_path = match setup_type {
ProofType::Final => global_info.get_final_setup_path(),
_ => global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type),
};

self.p_setup.p_const_pols = p_const_pols;
let air_name = &global_info.airs[self.airgroup_id][self.air_id].name;
log::debug!("{} : ··· Loading const tree for AIR {}", Self::MY_NAME, air_name);

let const_pols_tree_path = setup_path.display().to_string() + ".consttree";

let p_stark_info = self.p_setup.p_stark_info;

let const_tree_size = get_const_tree_size_c(p_stark_info) as usize;
let const_tree: Vec<MaybeUninit<F>> = Vec::with_capacity(const_tree_size);

let p_const_tree_address = const_tree.as_ptr() as *mut c_void;
if PathBuf::from(&const_pols_tree_path).exists() {
load_const_tree_c(p_const_tree_address, const_pols_tree_path.as_str(), const_tree_size as u64);
} else {
let const_pols = self.const_pols.const_pols.read().unwrap();
let p_const_pols_address = (*const_pols).as_ptr() as *mut c_void;
let tree_filename = if save_file { const_pols_tree_path.as_str() } else { "" };
calculate_const_tree_c(p_stark_info, p_const_pols_address, p_const_tree_address, tree_filename);
};
*self.const_tree.const_pols.write().unwrap() = const_tree;
}
}
Loading

0 comments on commit e949af1

Please sign in to comment.