From e949af131790252f08c12f973ac79e048f517667 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roger=20Taul=C3=A9=20Buxadera?= <55488871+RogerTaule@users.noreply.github.com> Date: Wed, 6 Nov 2024 12:27:25 +0100 Subject: [PATCH] Removing OnceCell from setup (#91) * Removing OnceCell * Modifying initialize setup to work with distributed --- Cargo.toml | 10 +- common/src/air_instance.rs | 22 +- common/src/buffer_allocator.rs | 4 +- common/src/execution_ctx.rs | 22 +- common/src/prover.rs | 16 +- common/src/setup.rs | 126 ++++--- common/src/setup_ctx.rs | 141 ++++---- examples/fibonacci-square/src/fibonacci.rs | 10 +- .../fibonacci-square/src/fibonacci_lib.rs | 18 +- examples/fibonacci-square/src/module.rs | 8 +- hints/src/hints.rs | 113 +++--- pil2-components/lib/std/rs/src/decider.rs | 2 +- .../rs/src/range_check/specified_ranges.rs | 10 +- .../std/rs/src/range_check/std_range_check.rs | 12 +- .../lib/std/rs/src/range_check/u16air.rs | 10 +- .../lib/std/rs/src/range_check/u8air.rs | 10 +- pil2-components/lib/std/rs/src/std_prod.rs | 12 +- pil2-components/lib/std/rs/src/std_sum.rs | 12 +- .../test/simple/rs/src/simple_left.rs | 6 +- .../test/simple/rs/src/simple_lib.rs | 20 +- .../test/simple/rs/src/simple_right.rs | 6 +- .../test/std/connection/rs/src/connection1.rs | 6 +- .../test/std/connection/rs/src/connection2.rs | 6 +- .../std/connection/rs/src/connection_lib.rs | 22 +- .../std/connection/rs/src/connection_new.rs | 6 +- .../test/std/lookup/rs/src/lookup0.rs | 6 +- .../test/std/lookup/rs/src/lookup1.rs | 6 +- .../test/std/lookup/rs/src/lookup2_12.rs | 6 +- .../test/std/lookup/rs/src/lookup2_13.rs | 6 +- .../test/std/lookup/rs/src/lookup2_15.rs | 6 +- .../test/std/lookup/rs/src/lookup3.rs | 6 +- .../test/std/lookup/rs/src/lookup_lib.rs | 20 +- .../std/permutation/rs/src/permutation1_6.rs | 6 +- .../std/permutation/rs/src/permutation1_7.rs | 6 +- .../std/permutation/rs/src/permutation1_8.rs | 6 +- .../std/permutation/rs/src/permutation2.rs | 6 +- .../std/permutation/rs/src/permutation_lib.rs | 20 +- .../range_check/rs/src/multi_range_check1.rs | 6 +- .../range_check/rs/src/multi_range_check2.rs | 6 +- .../std/range_check/rs/src/range_check1.rs | 6 +- .../std/range_check/rs/src/range_check2.rs | 6 +- .../std/range_check/rs/src/range_check3.rs | 6 +- .../std/range_check/rs/src/range_check4.rs | 6 +- .../rs/src/range_check_dynamic1.rs | 6 +- .../rs/src/range_check_dynamic2.rs | 6 +- .../std/range_check/rs/src/range_check_lib.rs | 20 +- .../std/range_check/rs/src/range_check_mix.rs | 6 +- pil2-stark/lib/include/starks_lib.h | 27 +- pil2-stark/src/api/starks_api.cpp | 105 +++--- pil2-stark/src/api/starks_api.hpp | 27 +- pil2-stark/src/bctree/build_const_tree.cpp | 16 +- pil2-stark/src/starkpil/const_pols.hpp | 327 ++---------------- pil2-stark/src/starkpil/expressions_avx.hpp | 8 +- .../src/starkpil/expressions_avx512.hpp | 8 +- pil2-stark/src/starkpil/expressions_pack.hpp | 8 +- .../src/starkpil/gen_recursive_proof.hpp | 13 +- pil2-stark/src/starkpil/hints.hpp | 11 +- .../starkpil/merkleTree/merkleTreeBN128.cpp | 9 + .../starkpil/merkleTree/merkleTreeBN128.hpp | 2 + .../src/starkpil/merkleTree/merkleTreeGL.cpp | 16 + .../src/starkpil/merkleTree/merkleTreeGL.hpp | 2 + pil2-stark/src/starkpil/proof2zkinStark.cpp | 64 ++-- pil2-stark/src/starkpil/proof2zkinStark.hpp | 4 +- pil2-stark/src/starkpil/setup_ctx.hpp | 164 ++++++++- pil2-stark/src/starkpil/starks.cpp | 32 +- pil2-stark/src/starkpil/starks.hpp | 10 +- pil2-stark/src/starkpil/steps.hpp | 2 + pil2-stark/src/utils/utils.cpp | 72 ++++ pil2-stark/src/utils/utils.hpp | 4 + proofman/src/constraints.rs | 14 +- proofman/src/global_constraints.rs | 10 +- proofman/src/proofman.rs | 139 ++++++-- proofman/src/recursion.rs | 93 +++-- proofman/src/witness_component.rs | 6 +- proofman/src/witness_executor.rs | 2 +- proofman/src/witness_library.rs | 16 +- proofman/src/witness_manager.rs | 20 +- provers/stark/src/stark_prover.rs | 126 ++++--- provers/starks-lib-c/bindings_starks.rs | 105 +++--- provers/starks-lib-c/src/ffi_starks.rs | 225 +++++------- 80 files changed, 1346 insertions(+), 1145 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 02a3dd7c..3c798dc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,11 +11,11 @@ members = [ "transcript", "util", "pil2-components/lib/std/rs", - # "pil2-components/test/std/range_check/rs", - # "pil2-components/test/std/lookup/rs", - # "pil2-components/test/std/connection/rs", - # "pil2-components/test/std/permutation/rs", - # "pil2-components/test/simple/rs", + #"pil2-components/test/std/range_check/rs", + #"pil2-components/test/std/lookup/rs", + #"pil2-components/test/std/connection/rs", + #"pil2-components/test/std/permutation/rs", + #"pil2-components/test/simple/rs", # whoever re-enables this, it has to work out of # the box with `cargo check --workspace` or CI will # break and dev experience will be bad since repo diff --git a/common/src/air_instance.rs b/common/src/air_instance.rs index 257a209b..d0b119d4 100644 --- a/common/src/air_instance.rs +++ b/common/src/air_instance.rs @@ -16,6 +16,8 @@ pub struct StepsParams { pub airvalues: *mut c_void, pub evals: *mut c_void, pub xdivxsub: *mut c_void, + pub p_const_pols: *mut c_void, + pub p_const_tree: *mut c_void, } impl From<&StepsParams> for *mut c_void { @@ -45,13 +47,13 @@ pub struct AirInstance { impl AirInstance { pub fn new( - setup_ctx: Arc, + setup_ctx: Arc>, airgroup_id: usize, air_id: usize, air_segment_id: Option, buffer: Vec, ) -> Self { - let ps = setup_ctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let ps = setup_ctx.get_setup(airgroup_id, air_id); AirInstance { airgroup_id, @@ -74,8 +76,8 @@ impl AirInstance { self.buffer.as_ptr() as *mut u8 } - pub fn set_airvalue(&mut self, setup_ctx: &SetupCtx, name: &str, value: F) { - let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON"); + pub fn set_airvalue(&mut self, setup_ctx: &SetupCtx, name: &str, value: F) { + let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id); let id = get_airval_id_by_name_c(ps.p_setup.p_stark_info, name); if id == -1 { @@ -86,8 +88,8 @@ impl AirInstance { self.set_airvalue_calculated(id as usize); } - pub fn set_airvalue_ext(&mut self, setup_ctx: &SetupCtx, name: &str, value: Vec) { - let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON"); + pub fn set_airvalue_ext(&mut self, setup_ctx: &SetupCtx, name: &str, value: Vec) { + let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id); let id = get_airval_id_by_name_c(ps.p_setup.p_stark_info, name); if id == -1 { @@ -105,8 +107,8 @@ impl AirInstance { self.set_airvalue_calculated(id as usize); } - pub fn set_airgroupvalue(&mut self, setup_ctx: &SetupCtx, name: &str, value: F) { - let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON"); + pub fn set_airgroupvalue(&mut self, setup_ctx: &SetupCtx, name: &str, value: F) { + let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id); let id = get_airgroupval_id_by_name_c(ps.p_setup.p_stark_info, name); if id == -1 { @@ -117,8 +119,8 @@ impl AirInstance { self.set_airgroupvalue_calculated(id as usize); } - pub fn set_airgroupvalue_ext(&mut self, setup_ctx: &SetupCtx, name: &str, value: Vec) { - let ps = setup_ctx.get_partial_setup(self.airgroup_id, self.air_id).expect("REASON"); + pub fn set_airgroupvalue_ext(&mut self, setup_ctx: &SetupCtx, name: &str, value: Vec) { + let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id); let id = get_airgroupval_id_by_name_c(ps.p_setup.p_stark_info, name); if id == -1 { diff --git a/common/src/buffer_allocator.rs b/common/src/buffer_allocator.rs index 2a0f9d16..feb905ec 100644 --- a/common/src/buffer_allocator.rs +++ b/common/src/buffer_allocator.rs @@ -2,11 +2,11 @@ use std::error::Error; use crate::SetupCtx; -pub trait BufferAllocator: Send + Sync { +pub trait BufferAllocator: Send + Sync { // Returns the size of the buffer and the offsets for each stage fn get_buffer_info( &self, - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, ) -> Result<(u64, Vec), Box>; diff --git a/common/src/execution_ctx.rs b/common/src/execution_ctx.rs index 0446ccdd..6e23812b 100644 --- a/common/src/execution_ctx.rs +++ b/common/src/execution_ctx.rs @@ -1,37 +1,37 @@ use std::{path::PathBuf, sync::Arc}; -use crate::{BufferAllocator, VerboseMode, DistributionCtx}; +use crate::{BufferAllocator, DistributionCtx, VerboseMode}; use std::sync::RwLock; #[allow(dead_code)] /// Represents the context when executing a witness computer plugin -pub struct ExecutionCtx { +pub struct ExecutionCtx { pub rom_path: Option, /// If true, the plugin must generate the public outputs pub public_output: bool, - pub buffer_allocator: Arc, + pub buffer_allocator: Arc>, pub verbose_mode: VerboseMode, pub dctx: RwLock, } -impl ExecutionCtx { - pub fn builder() -> ExecutionCtxBuilder { +impl ExecutionCtx { + pub fn builder() -> ExecutionCtxBuilder { ExecutionCtxBuilder::new() } } -pub struct ExecutionCtxBuilder { +pub struct ExecutionCtxBuilder { rom_path: Option, public_output: bool, - buffer_allocator: Option>, + buffer_allocator: Option>>, verbose_mode: VerboseMode, } -impl Default for ExecutionCtxBuilder { +impl Default for ExecutionCtxBuilder { fn default() -> Self { Self::new() } } -impl ExecutionCtxBuilder { +impl ExecutionCtxBuilder { pub fn new() -> Self { ExecutionCtxBuilder { rom_path: None, @@ -46,7 +46,7 @@ impl ExecutionCtxBuilder { self } - pub fn with_buffer_allocator(mut self, buffer_allocator: Arc) -> Self { + pub fn with_buffer_allocator(mut self, buffer_allocator: Arc>) -> Self { self.buffer_allocator = Some(buffer_allocator); self } @@ -56,7 +56,7 @@ impl ExecutionCtxBuilder { self } - pub fn build(self) -> ExecutionCtx { + pub fn build(self) -> ExecutionCtx { if self.buffer_allocator.is_none() { panic!("Buffer allocator is required"); } diff --git a/common/src/prover.rs b/common/src/prover.rs index 8294e3d1..869aeb76 100644 --- a/common/src/prover.rs +++ b/common/src/prover.rs @@ -2,9 +2,11 @@ use std::os::raw::c_void; use std::os::raw::c_char; use std::sync::Arc; +use p3_field::Field; use transcript::FFITranscript; use crate::ProofCtx; +use crate::SetupCtx; #[derive(Debug, PartialEq)] pub enum ProverStatus { @@ -54,16 +56,22 @@ pub struct ConstraintsResults { pub constraints_info: *mut ConstraintInfo, } -pub trait Prover { +pub trait Prover { fn build(&mut self, proof_ctx: Arc>); + fn free(&mut self); fn new_transcript(&self) -> FFITranscript; fn num_stages(&self) -> u32; fn get_challenges(&self, stage_id: u32, proof_ctx: Arc>, transcript: &FFITranscript); - fn calculate_stage(&mut self, stage_id: u32, proof_ctx: Arc>); + fn calculate_stage(&mut self, stage_id: u32, setup_ctx: Arc>, proof_ctx: Arc>); fn commit_stage(&mut self, stage_id: u32, proof_ctx: Arc>) -> ProverStatus; fn calculate_xdivxsub(&mut self, proof_ctx: Arc>); fn calculate_lev(&mut self, proof_ctx: Arc>); - fn opening_stage(&mut self, opening_id: u32, proof_ctx: Arc>) -> ProverStatus; + fn opening_stage( + &mut self, + opening_id: u32, + setup_ctx: Arc>, + proof_ctx: Arc>, + ) -> ProverStatus; fn get_buff_helper_size(&self) -> usize; fn get_proof(&self) -> *mut c_void; @@ -73,7 +81,7 @@ pub trait Prover { fn get_transcript_values(&self, stage: u64, proof_ctx: Arc>) -> Vec; fn get_transcript_values_u64(&self, stage: u64, proof_ctx: Arc>) -> Vec; fn calculate_hash(&self, values: Vec) -> Vec; - fn verify_constraints(&self, proof_ctx: Arc>) -> Vec; + fn verify_constraints(&self, setup_ctx: Arc>, proof_ctx: Arc>) -> Vec; fn get_proof_challenges(&self, global_steps: Vec, global_challenges: Vec) -> Vec; } diff --git a/common/src/setup.rs b/common/src/setup.rs index dedcc5e5..993dd717 100644 --- a/common/src/setup.rs +++ b/common/src/setup.rs @@ -1,7 +1,13 @@ +use std::mem::MaybeUninit; use std::os::raw::c_void; use std::path::PathBuf; +use std::sync::RwLock; -use proofman_starks_lib_c::{const_pols_new_c, const_pols_with_tree_new_c, expressions_bin_new_c, stark_info_new_c}; +use proofman_starks_lib_c::{ + get_const_tree_size_c, get_const_size_c, prover_helpers_new_c, expressions_bin_new_c, stark_info_new_c, + load_const_tree_c, load_const_pols_c, calculate_const_tree_c, stark_info_free_c, expressions_bin_free_c, + prover_helpers_free_c, +}; use crate::GlobalInfo; use crate::ProofType; @@ -11,7 +17,7 @@ use crate::ProofType; pub struct SetupC { pub p_stark_info: *mut c_void, pub p_expressions_bin: *mut c_void, - pub p_const_pols: *mut c_void, + pub p_prover_helpers: *mut c_void, } unsafe impl Send for SetupC {} @@ -23,16 +29,29 @@ impl From<&SetupC> for *mut c_void { } } +#[derive(Debug)] +pub struct ConstPols { + pub const_pols: RwLock>>, +} + +impl Default for ConstPols { + fn default() -> Self { + Self { const_pols: RwLock::new(Vec::new()) } + } +} + /// Air instance context for managing air instances (traces) -#[derive(Debug, Clone)] +#[derive(Debug)] #[allow(dead_code)] -pub struct Setup { +pub struct Setup { pub airgroup_id: usize, pub air_id: usize, pub p_setup: SetupC, + pub const_pols: ConstPols, + pub const_tree: ConstPols, } -impl Setup { +impl Setup { const MY_NAME: &'static str = "Setup"; pub fn new(global_info: &GlobalInfo, airgroup_id: usize, air_id: usize, setup_type: &ProofType) -> Self { @@ -43,62 +62,81 @@ impl Setup { let stark_info_path = setup_path.display().to_string() + ".starkinfo.json"; let expressions_bin_path = setup_path.display().to_string() + ".bin"; - let const_pols_path = setup_path.display().to_string() + ".const"; - let const_pols_tree_path = setup_path.display().to_string() + ".consttree"; - - let p_stark_info = stark_info_new_c(stark_info_path.as_str()); - let p_expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false); - let p_const_pols = match PathBuf::from(&const_pols_tree_path).exists() { - true => const_pols_with_tree_new_c(const_pols_path.as_str(), const_pols_tree_path.as_str(), p_stark_info), - false => const_pols_new_c(const_pols_path.as_str(), p_stark_info, true), - }; - - Self { air_id, airgroup_id, p_setup: SetupC { p_stark_info, p_expressions_bin, p_const_pols } } - } + let (p_stark_info, p_expressions_bin, p_prover_helpers) = + if setup_type == &ProofType::Compressor && !global_info.get_air_has_compressor(airgroup_id, air_id) { + // If the condition is met, use None for each pointer + (std::ptr::null_mut(), std::ptr::null_mut(), std::ptr::null_mut()) + } else { + // Otherwise, initialize the pointers with their respective values + let stark_info = stark_info_new_c(stark_info_path.as_str()); + let expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false); + let prover_helpers = prover_helpers_new_c(stark_info); - pub fn new_partial(global_info: &GlobalInfo, airgroup_id: usize, air_id: usize, setup_type: &ProofType) -> Self { - let setup_path = global_info.get_air_setup_path(airgroup_id, air_id, setup_type); - - let air_name = &global_info.airs[airgroup_id][air_id].name; - log::debug!("{} : ··· Loading setup for AIR {}", Self::MY_NAME, air_name); - - let stark_info_path = setup_path.display().to_string() + ".starkinfo.json"; - let expressions_bin_path = setup_path.display().to_string() + ".bin"; - let p_stark_info = stark_info_new_c(stark_info_path.as_str()); - let p_expressions_bin = expressions_bin_new_c(expressions_bin_path.as_str(), false); + (stark_info, expressions_bin, prover_helpers) + }; Self { air_id, airgroup_id, - p_setup: SetupC { p_stark_info, p_expressions_bin, p_const_pols: std::ptr::null_mut() }, + p_setup: SetupC { p_stark_info, p_expressions_bin, p_prover_helpers }, + const_pols: ConstPols::default(), + const_tree: ConstPols::default(), } } - pub fn load_const_pols(&mut self, global_info: &GlobalInfo, setup_type: &ProofType) { - if !self.p_setup.p_const_pols.is_null() { - return; - } - assert!(!self.p_setup.p_stark_info.is_null()); - assert!(!self.p_setup.p_expressions_bin.is_null()); + pub fn free(&self) { + stark_info_free_c(self.p_setup.p_stark_info); + expressions_bin_free_c(self.p_setup.p_expressions_bin); + prover_helpers_free_c(self.p_setup.p_prover_helpers); + } - let setup_path = global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type); + pub fn load_const_pols(&self, global_info: &GlobalInfo, setup_type: &ProofType) { + let setup_path = match setup_type { + ProofType::Final => global_info.get_final_setup_path(), + _ => global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type), + }; let air_name = &global_info.airs[self.airgroup_id][self.air_id].name; log::debug!("{} : ··· Loading const pols for AIR {} of type {:?}", Self::MY_NAME, air_name, setup_type); let const_pols_path = setup_path.display().to_string() + ".const"; - let const_pols_tree_path = setup_path.display().to_string() + ".consttree"; - let p_const_pols = match PathBuf::from(&const_pols_tree_path).exists() { - true => const_pols_with_tree_new_c( - const_pols_path.as_str(), - const_pols_tree_path.as_str(), - self.p_setup.p_stark_info, - ), - false => const_pols_new_c(const_pols_path.as_str(), self.p_setup.p_stark_info, true), + let p_stark_info = self.p_setup.p_stark_info; + + let const_size = get_const_size_c(p_stark_info) as usize; + let const_pols: Vec> = Vec::with_capacity(const_size); + + let p_const_pols_address = const_pols.as_ptr() as *mut c_void; + load_const_pols_c(p_const_pols_address, const_pols_path.as_str(), const_size as u64); + *self.const_pols.const_pols.write().unwrap() = const_pols; + } + + pub fn load_const_pols_tree(&self, global_info: &GlobalInfo, setup_type: &ProofType, save_file: bool) { + let setup_path = match setup_type { + ProofType::Final => global_info.get_final_setup_path(), + _ => global_info.get_air_setup_path(self.airgroup_id, self.air_id, setup_type), }; - self.p_setup.p_const_pols = p_const_pols; + let air_name = &global_info.airs[self.airgroup_id][self.air_id].name; + log::debug!("{} : ··· Loading const tree for AIR {}", Self::MY_NAME, air_name); + + let const_pols_tree_path = setup_path.display().to_string() + ".consttree"; + + let p_stark_info = self.p_setup.p_stark_info; + + let const_tree_size = get_const_tree_size_c(p_stark_info) as usize; + let const_tree: Vec> = Vec::with_capacity(const_tree_size); + + let p_const_tree_address = const_tree.as_ptr() as *mut c_void; + if PathBuf::from(&const_pols_tree_path).exists() { + load_const_tree_c(p_const_tree_address, const_pols_tree_path.as_str(), const_tree_size as u64); + } else { + let const_pols = self.const_pols.const_pols.read().unwrap(); + let p_const_pols_address = (*const_pols).as_ptr() as *mut c_void; + let tree_filename = if save_file { const_pols_tree_path.as_str() } else { "" }; + calculate_const_tree_c(p_stark_info, p_const_pols_address, p_const_tree_address, tree_filename); + }; + *self.const_tree.const_pols.write().unwrap() = const_tree; } } diff --git a/common/src/setup_ctx.rs b/common/src/setup_ctx.rs index 6b6a8d2e..02410744 100644 --- a/common/src/setup_ctx.rs +++ b/common/src/setup_ctx.rs @@ -1,30 +1,61 @@ -use std::cell::OnceCell; use std::collections::HashMap; use std::ffi::c_void; +use std::sync::Arc; use proofman_starks_lib_c::expressions_bin_new_c; +use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use crate::GlobalInfo; use crate::Setup; use crate::ProofType; +pub struct SetupsVadcop { + pub sctx: Arc>, + pub sctx_compressor: Option>>, + pub sctx_recursive1: Option>>, + pub sctx_recursive2: Option>>, + pub sctx_final: Option>>, +} + +impl SetupsVadcop { + pub fn new(global_info: &GlobalInfo, aggregation: bool) -> Self { + if aggregation { + SetupsVadcop { + sctx: Arc::new(SetupCtx::new(global_info, &ProofType::Basic)), + sctx_compressor: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Compressor))), + sctx_recursive1: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Recursive1))), + sctx_recursive2: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Recursive2))), + sctx_final: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Final))), + } + } else { + SetupsVadcop { + sctx: Arc::new(SetupCtx::new(global_info, &ProofType::Basic)), + sctx_compressor: None, + sctx_recursive1: None, + sctx_recursive2: None, + sctx_final: None, + } + } + } +} + #[derive(Debug)] -pub struct SetupRepository { +pub struct SetupRepository { // We store the setup in two stages: a partial setup in the first cell and a full setup in the second cell. // This allows for loading only the partial setup when constant polynomials are not needed, improving performance. // In C++, same SetupCtx structure is used to store either the partial or full setup for each instance. // A full setup can be loaded in one or two steps: partial first, then full (which includes constant polynomial data). // Since the setup is referenced immutably in the repository, we use OnceCell for both the partial and full setups. - setups: HashMap<(usize, usize), (OnceCell, OnceCell)>, // (partial setup, full setup) - setup_airs: Vec>, + setups: HashMap<(usize, usize), Setup>, global_bin: Option<*mut c_void>, } -unsafe impl Send for SetupRepository {} -unsafe impl Sync for SetupRepository {} +unsafe impl Send for SetupRepository {} +unsafe impl Sync for SetupRepository {} -impl SetupRepository { +impl SetupRepository { pub fn new(global_info: &GlobalInfo, setup_type: &ProofType) -> Self { + timer_start_debug!(INITIALIZE_SETUPS); let mut setups = HashMap::new(); let global_bin = match setup_type == &ProofType::Basic { @@ -36,41 +67,35 @@ impl SetupRepository { false => None, }; - // Initialize Hashmao for each airgroup_id, air_id - let setup_airs = match setup_type != &ProofType::Final { - true => global_info - .airs - .iter() - .enumerate() - .map(|(airgroup_id, air_group)| { - let mut air_group_setups = Vec::new(); - air_group.iter().enumerate().for_each(|(air_id, _)| { - setups.insert((airgroup_id, air_id), (OnceCell::new(), OnceCell::new())); - air_group_setups.push(air_id); - }); - air_group_setups - }) - .collect::>>(), - false => { - let mut air_group_setups: Vec> = Vec::new(); - setups.insert((0, 0), (OnceCell::new(), OnceCell::new())); - air_group_setups.push(vec![0]); - air_group_setups + // Initialize Hashmap for each airgroup_id, air_id + if setup_type != &ProofType::Final { + for (airgroup_id, air_group) in global_info.airs.iter().enumerate() { + for (air_id, _) in air_group.iter().enumerate() { + setups.insert((airgroup_id, air_id), Setup::new(global_info, airgroup_id, air_id, setup_type)); + } } - }; + } else { + setups.insert((0, 0), Setup::new(global_info, 0, 0, setup_type)); + } + + timer_stop_and_log_debug!(INITIALIZE_SETUPS); + + Self { setups, global_bin } + } - Self { setups, setup_airs, global_bin } + pub fn free(&self) { + // TODO } } /// Air instance context for managing air instances (traces) #[allow(dead_code)] -pub struct SetupCtx { +pub struct SetupCtx { global_info: GlobalInfo, - setup_repository: SetupRepository, + setup_repository: SetupRepository, setup_type: ProofType, } -impl SetupCtx { +impl SetupCtx { pub fn new(global_info: &GlobalInfo, setup_type: &ProofType) -> Self { SetupCtx { setup_repository: SetupRepository::new(global_info, setup_type), @@ -79,52 +104,18 @@ impl SetupCtx { } } - pub fn get_setup(&self, airgroup_id: usize, air_id: usize) -> Result<&Setup, String> { - let setup = self - .setup_repository - .setups - .get(&(airgroup_id, air_id)) - .ok_or_else(|| format!("Setup not found for airgroup_id: {}, Air_id: {}", airgroup_id, air_id))?; - - if let Some(setup_ref) = setup.1.get() { - Ok(setup_ref) - } else if let Some(setup_ref) = setup.0.get() { - let mut new_setup = setup_ref.clone(); - new_setup.load_const_pols(&self.global_info, &self.setup_type); - setup.1.set(new_setup).unwrap(); - - Ok(setup.1.get().unwrap()) - } else { - let new_setup = Setup::new(&self.global_info, airgroup_id, air_id, &self.setup_type); - setup.1.set(new_setup).unwrap(); - - Ok(setup.1.get().unwrap()) - } - } - - pub fn get_partial_setup(&self, airgroup_id: usize, air_id: usize) -> Result<&Setup, String> { - let setup = self - .setup_repository - .setups - .get(&(airgroup_id, air_id)) - .ok_or_else(|| format!("Setup not found for airgroup_id: {}, Air_id: {}", airgroup_id, air_id))?; - - if setup.0.get().is_some() { - Ok(setup.0.get().unwrap()) - } else if setup.1.get().is_some() { - Ok(setup.1.get().unwrap()) - } else { - let new_setup = Setup::new_partial(&self.global_info, airgroup_id, air_id, &self.setup_type); - setup.0.set(new_setup).unwrap(); - - Ok(setup.0.get().unwrap()) + pub fn get_setup(&self, airgroup_id: usize, air_id: usize) -> &Setup { + match self.setup_repository.setups.get(&(airgroup_id, air_id)) { + Some(setup) => setup, + None => { + // Handle the error case as needed + log::error!("Setup not found for airgroup_id: {}, air_id: {}", airgroup_id, air_id); + // You might want to return a default value or panic + panic!("Setup not found"); // or return a default value if applicable + } } } - pub fn get_setup_airs(&self) -> Vec> { - self.setup_repository.setup_airs.clone() - } - pub fn get_global_bin(&self) -> *mut c_void { self.setup_repository.global_bin.unwrap() } diff --git a/examples/fibonacci-square/src/fibonacci.rs b/examples/fibonacci-square/src/fibonacci.rs index 34dba2c2..15aefe6b 100644 --- a/examples/fibonacci-square/src/fibonacci.rs +++ b/examples/fibonacci-square/src/fibonacci.rs @@ -22,7 +22,7 @@ impl FibonacciSquare { fibonacci } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // TODO: We should create the instance here and fill the trace in calculate witness!!! if let Err(e) = Self::calculate_trace(self, FIBONACCI_SQUARE_AIRGROUP_ID, FIBONACCI_SQUARE_AIR_IDS[0], pctx, ectx, sctx) @@ -36,8 +36,8 @@ impl FibonacciSquare { airgroup_id: usize, air_id: usize, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) -> Result> { log::debug!("{} ··· Starting witness computation stage {}", Self::MY_NAME, 1); @@ -101,8 +101,8 @@ impl WitnessComponent for FibonacciSquare { _stage: u32, _air_instance_id: Option, _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, + _ectx: Arc>, + _sctx: Arc>, ) { } } diff --git a/examples/fibonacci-square/src/fibonacci_lib.rs b/examples/fibonacci-square/src/fibonacci_lib.rs index 0fe6f90d..da92c149 100644 --- a/examples/fibonacci-square/src/fibonacci_lib.rs +++ b/examples/fibonacci-square/src/fibonacci_lib.rs @@ -1,7 +1,7 @@ use std::io::Read; use std::{fs::File, sync::Arc}; -use proofman_common::{initialize_logger, ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; use proofman::{WitnessLibrary, WitnessManager}; use pil_std_lib::Std; use p3_field::PrimeField; @@ -28,7 +28,7 @@ impl FibonacciWitness { } impl WitnessLibrary for FibonacciWitness { - fn start_proof(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let wcm = Arc::new(WitnessManager::new(pctx.clone(), ectx.clone(), sctx.clone())); let std_lib = Std::new(wcm.clone()); @@ -67,12 +67,18 @@ impl WitnessLibrary for FibonacciWitness { self.wcm.as_ref().unwrap().end_proof(); } - fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { self.fibonacci.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); self.module.as_ref().unwrap().execute(pctx, ectx, sctx); } - fn calculate_witness(&mut self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn calculate_witness( + &mut self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ) { self.wcm.as_ref().unwrap().calculate_witness(stage, pctx, ectx, sctx); } @@ -83,11 +89,9 @@ impl WitnessLibrary for FibonacciWitness { #[no_mangle] pub extern "Rust" fn init_library( - ectx: Arc, + _: Option, public_inputs_path: Option, ) -> Result>, Box> { - initialize_logger(ectx.verbose_mode); - let fibonacci_witness = FibonacciWitness::new(public_inputs_path); Ok(Box::new(fibonacci_witness)) } diff --git a/examples/fibonacci-square/src/module.rs b/examples/fibonacci-square/src/module.rs index 03918f19..5c8118e5 100644 --- a/examples/fibonacci-square/src/module.rs +++ b/examples/fibonacci-square/src/module.rs @@ -37,11 +37,11 @@ impl Module x_mod } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { self.calculate_trace(pctx, ectx, sctx); } - fn calculate_trace(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn calculate_trace(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { log::debug!("{} ··· Starting witness computation stage {}", Self::MY_NAME, 1); let pi: FibonacciSquarePublics = pctx.public_inputs.inputs.read().unwrap().as_slice().into(); @@ -100,8 +100,8 @@ impl WitnessComponent for Module { _stage: u32, _air_instance_id: Option, _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, + _ectx: Arc>, + _sctx: Arc>, ) { } } diff --git a/hints/src/hints.rs b/hints/src/hints.rs index bd889c89..f078365d 100644 --- a/hints/src/hints.rs +++ b/hints/src/hints.rs @@ -666,7 +666,7 @@ pub fn get_hint_ids_by_name(p_expressions_bin: *mut c_void, name: &str) -> Vec( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, proof_ctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, @@ -676,11 +676,14 @@ pub fn mul_hint_fields( hint_field_name2: &str, options2: HintFieldOptions, ) -> u64 { - let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: public_inputs_ptr, @@ -689,6 +692,8 @@ pub fn mul_hint_fields( airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; mul_hint_fields_c( @@ -704,7 +709,7 @@ pub fn mul_hint_fields( } pub fn acc_hint_field( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, proof_ctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, @@ -712,11 +717,14 @@ pub fn acc_hint_field( hint_field_airgroupvalue: &str, hint_field_name: &str, ) -> (u64, u64) { - let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: public_inputs_ptr, @@ -725,6 +733,8 @@ pub fn acc_hint_field( airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; let raw_ptr = acc_hint_field_c( @@ -745,7 +755,7 @@ pub fn acc_hint_field( #[allow(clippy::too_many_arguments)] pub fn acc_mul_hint_fields( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, proof_ctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, @@ -756,11 +766,14 @@ pub fn acc_mul_hint_fields( options1: HintFieldOptions, options2: HintFieldOptions, ) -> (u64, u64) { - let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: public_inputs_ptr, @@ -769,6 +782,8 @@ pub fn acc_mul_hint_fields( airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; let raw_ptr = acc_mul_hint_fields_c( @@ -791,22 +806,21 @@ pub fn acc_mul_hint_fields( } pub fn get_hint_field( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, proof_ctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, ) -> HintFieldValue { - let setup = if options.dest { - setup_ctx.get_partial_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON") - } else { - setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON") - }; + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: public_inputs_ptr, @@ -815,6 +829,8 @@ pub fn get_hint_field( airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; let raw_ptr = get_hint_field_c( @@ -836,22 +852,21 @@ pub fn get_hint_field( } pub fn get_hint_field_a( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, proof_ctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, ) -> HintFieldValuesVec { - let setup = if options.dest { - setup_ctx.get_partial_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON") - } else { - setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON") - }; + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: public_inputs_ptr, @@ -860,6 +875,8 @@ pub fn get_hint_field_a( airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; let raw_ptr = get_hint_field_c( @@ -887,22 +904,21 @@ pub fn get_hint_field_a( } pub fn get_hint_field_m( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, proof_ctx: &ProofCtx, air_instance: &mut AirInstance, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, ) -> HintFieldValues { - let setup = if options.dest { - setup_ctx.get_partial_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON") - } else { - setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON") - }; + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: public_inputs_ptr, @@ -911,6 +927,8 @@ pub fn get_hint_field_m( airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; let raw_ptr = get_hint_field_c( @@ -943,14 +961,14 @@ pub fn get_hint_field_m( } pub fn get_hint_field_constant( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, ) -> HintFieldValue { - let setup = setup_ctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let setup = setup_ctx.get_setup(airgroup_id, air_id); let steps_params = StepsParams { buffer: std::ptr::null_mut(), @@ -960,6 +978,8 @@ pub fn get_hint_field_constant( airvalues: std::ptr::null_mut(), evals: std::ptr::null_mut(), xdivxsub: std::ptr::null_mut(), + p_const_pols: std::ptr::null_mut(), + p_const_tree: std::ptr::null_mut(), }; let raw_ptr = get_hint_field_c( @@ -981,14 +1001,14 @@ pub fn get_hint_field_constant( } pub fn get_hint_field_constant_a( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, ) -> Vec> { - let setup = setup_ctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let setup = setup_ctx.get_setup(airgroup_id, air_id); let steps_params = StepsParams { buffer: std::ptr::null_mut(), @@ -998,6 +1018,8 @@ pub fn get_hint_field_constant_a( airvalues: std::ptr::null_mut(), evals: std::ptr::null_mut(), xdivxsub: std::ptr::null_mut(), + p_const_pols: std::ptr::null_mut(), + p_const_tree: std::ptr::null_mut(), }; let raw_ptr = get_hint_field_c( @@ -1025,14 +1047,14 @@ pub fn get_hint_field_constant_a( } pub fn get_hint_field_constant_m( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, airgroup_id: usize, air_id: usize, hint_id: usize, hint_field_name: &str, options: HintFieldOptions, ) -> HintFieldValues { - let setup = setup_ctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let setup = setup_ctx.get_setup(airgroup_id, air_id); let steps_params = StepsParams { buffer: std::ptr::null_mut(), @@ -1042,6 +1064,8 @@ pub fn get_hint_field_constant_m( airvalues: std::ptr::null_mut(), evals: std::ptr::null_mut(), xdivxsub: std::ptr::null_mut(), + p_const_pols: std::ptr::null_mut(), + p_const_tree: std::ptr::null_mut(), }; let raw_ptr = get_hint_field_c( @@ -1077,7 +1101,7 @@ pub fn get_hint_field_constant_m( } pub fn set_hint_field( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, air_instance: &mut AirInstance, hint_id: u64, hint_field_name: &str, @@ -1091,9 +1115,11 @@ pub fn set_hint_field( airvalues: std::ptr::null_mut(), evals: std::ptr::null_mut(), xdivxsub: std::ptr::null_mut(), + p_const_pols: std::ptr::null_mut(), + p_const_tree: std::ptr::null_mut(), }; - let setup = setup_ctx.get_partial_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let values_ptr: *mut c_void = match values { HintFieldValue::Column(vec) => vec.as_ptr() as *mut c_void, @@ -1107,7 +1133,7 @@ pub fn set_hint_field( } pub fn set_hint_field_val( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, air_instance: &mut AirInstance, hint_id: u64, hint_field_name: &str, @@ -1121,9 +1147,11 @@ pub fn set_hint_field_val( airvalues: air_instance.airvalues.as_mut_ptr() as *mut c_void, evals: std::ptr::null_mut(), xdivxsub: std::ptr::null_mut(), + p_const_pols: std::ptr::null_mut(), + p_const_tree: std::ptr::null_mut(), }; - let setup = setup_ctx.get_partial_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let mut value_array: Vec = Vec::new(); @@ -1146,13 +1174,13 @@ pub fn set_hint_field_val( } pub fn print_expression( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, air_instance: &mut AirInstance, expr: &HintFieldValue, first_print_value: u64, last_print_value: u64, ) { - let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); match expr { HintFieldValue::Column(vec) => { @@ -1183,8 +1211,8 @@ pub fn print_expression( } } -pub fn print_row(setup_ctx: &SetupCtx, air_instance: &AirInstance, stage: u64, row: u64) { - let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); +pub fn print_row(setup_ctx: &SetupCtx, air_instance: &AirInstance, stage: u64, row: u64) { + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let buffer = air_instance.get_buffer_ptr() as *mut c_void; @@ -1192,7 +1220,7 @@ pub fn print_row(setup_ctx: &SetupCtx, air_instance: &AirInstance, } pub fn print_by_name( - setup_ctx: &SetupCtx, + setup_ctx: &SetupCtx, proof_ctx: Arc>, air_instance: &AirInstance, name: &str, @@ -1200,10 +1228,13 @@ pub fn print_by_name( first_print_value: u64, last_print_value: u64, ) -> Option> { - let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("REASON"); + let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: public_inputs_ptr, @@ -1212,6 +1243,8 @@ pub fn print_by_name( airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: std::ptr::null_mut(), xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; let mut lengths_vec = lengths.unwrap_or_default(); diff --git a/pil2-components/lib/std/rs/src/decider.rs b/pil2-components/lib/std/rs/src/decider.rs index 7b5bd441..2fbd15fe 100644 --- a/pil2-components/lib/std/rs/src/decider.rs +++ b/pil2-components/lib/std/rs/src/decider.rs @@ -3,5 +3,5 @@ use std::sync::Arc; use proofman_common::{ProofCtx, SetupCtx}; pub trait Decider { - fn decide(&self, sctx: Arc, pctx: Arc>); + fn decide(&self, sctx: Arc>, pctx: Arc>); } diff --git a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs index f080d0f5..60fa15f2 100644 --- a/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs +++ b/pil2-components/lib/std/rs/src/range_check/specified_ranges.rs @@ -182,11 +182,9 @@ impl SpecifiedRanges { } impl WitnessComponent for SpecifiedRanges { - fn start_proof(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Obtain info from the mul hints - let setup = sctx.get_partial_setup(self.airgroup_id, self.air_id).unwrap_or_else(|_| { - panic!("Setup not found for airgroup_id: {}, air_id: {}", self.airgroup_id, self.air_id) - }); + let setup = sctx.get_setup(self.airgroup_id, self.air_id); let specified_hints = get_hint_ids_by_name(setup.p_setup.p_expressions_bin, "specified_ranges"); let mut hints_guard = self.hints.lock().unwrap(); let mut ranges_guard = self.ranges.lock().unwrap(); @@ -339,8 +337,8 @@ impl WitnessComponent for SpecifiedRanges { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, + _ectx: Arc>, + _sctx: Arc>, ) { } } diff --git a/pil2-components/lib/std/rs/src/range_check/std_range_check.rs b/pil2-components/lib/std/rs/src/range_check/std_range_check.rs index 015a29bf..6a881d86 100644 --- a/pil2-components/lib/std/rs/src/range_check/std_range_check.rs +++ b/pil2-components/lib/std/rs/src/range_check/std_range_check.rs @@ -47,7 +47,7 @@ pub struct StdRangeCheck { } impl Decider for StdRangeCheck { - fn decide(&self, sctx: Arc, pctx: Arc>) { + fn decide(&self, sctx: Arc>, pctx: Arc>) { // Scan the pilout for airs that have rc-related hints let air_groups = pctx.pilout.air_groups(); @@ -56,7 +56,7 @@ impl Decider for StdRangeCheck { airs.iter().for_each(|air| { let airgroup_id = air.airgroup_id; let air_id = air.air_id; - let setup = sctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let setup = sctx.get_setup(airgroup_id, air_id); // Obtain info from the range hints let rc_hints = get_hint_ids_by_name(setup.p_setup.p_expressions_bin, "range_def"); @@ -92,7 +92,7 @@ impl StdRangeCheck { std_range_check } - fn register_range(&self, sctx: Arc, airgroup_id: usize, air_id: usize, hint: u64) { + fn register_range(&self, sctx: Arc>, airgroup_id: usize, air_id: usize, hint: u64) { let predefined = get_hint_field_constant::( &sctx, airgroup_id, @@ -267,7 +267,7 @@ impl StdRangeCheck { } impl WitnessComponent for StdRangeCheck { - fn start_proof(&self, pctx: Arc>, _ectx: Arc, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, _ectx: Arc>, sctx: Arc>) { self.decide(sctx, pctx); } @@ -276,8 +276,8 @@ impl WitnessComponent for StdRangeCheck { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, + _ectx: Arc>, + _sctx: Arc>, ) { // Nothing to do } diff --git a/pil2-components/lib/std/rs/src/range_check/u16air.rs b/pil2-components/lib/std/rs/src/range_check/u16air.rs index a45b5eb9..b580ad80 100644 --- a/pil2-components/lib/std/rs/src/range_check/u16air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u16air.rs @@ -155,11 +155,9 @@ impl U16Air { } impl WitnessComponent for U16Air { - fn start_proof(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Obtain info from the mul hints - let setup = sctx.get_partial_setup(self.airgroup_id, self.air_id).unwrap_or_else(|_| { - panic!("Setup not found for airgroup_id: {}, air_id: {}", self.airgroup_id, self.air_id) - }); + let setup = sctx.get_setup(self.airgroup_id, self.air_id); let u16air_hints = get_hint_ids_by_name(setup.p_setup.p_expressions_bin, "u16air"); if !u16air_hints.is_empty() { self.hint.store(u16air_hints[0], Ordering::Release); @@ -192,8 +190,8 @@ impl WitnessComponent for U16Air { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, + _ectx: Arc>, + _sctx: Arc>, ) { } } diff --git a/pil2-components/lib/std/rs/src/range_check/u8air.rs b/pil2-components/lib/std/rs/src/range_check/u8air.rs index 105d2f23..30e2bf7f 100644 --- a/pil2-components/lib/std/rs/src/range_check/u8air.rs +++ b/pil2-components/lib/std/rs/src/range_check/u8air.rs @@ -153,11 +153,9 @@ impl U8Air { } impl WitnessComponent for U8Air { - fn start_proof(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Obtain info from the mul hints - let setup = sctx.get_partial_setup(self.airgroup_id, self.air_id).unwrap_or_else(|_| { - panic!("Setup not found for airgroup_id: {}, air_id: {}", self.airgroup_id, self.air_id) - }); + let setup = sctx.get_setup(self.airgroup_id, self.air_id); let u8air_hints = get_hint_ids_by_name(setup.p_setup.p_expressions_bin, "u8air"); if !u8air_hints.is_empty() { self.hint.store(u8air_hints[0], Ordering::Release); @@ -190,8 +188,8 @@ impl WitnessComponent for U8Air { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, + _ectx: Arc>, + _sctx: Arc>, ) { } } diff --git a/pil2-components/lib/std/rs/src/std_prod.rs b/pil2-components/lib/std/rs/src/std_prod.rs index 605c6c99..a899a47c 100644 --- a/pil2-components/lib/std/rs/src/std_prod.rs +++ b/pil2-components/lib/std/rs/src/std_prod.rs @@ -37,14 +37,14 @@ struct BusValue { type DebugData = Mutex>, BusValue>>>; // opid -> val -> BusValue impl Decider for StdProd { - fn decide(&self, sctx: Arc, pctx: Arc>) { + fn decide(&self, sctx: Arc>, pctx: Arc>) { // Scan the pilout for airs that have prod-related hints for airgroup in pctx.pilout.air_groups() { for air in airgroup.airs() { let airgroup_id = air.airgroup_id; let air_id = air.air_id; - let setup = sctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let setup = sctx.get_setup(airgroup_id, air_id); let p_expressions_bin = setup.p_setup.p_expressions_bin; let gprod_hints = get_hint_ids_by_name(p_expressions_bin, "gprod_col"); @@ -76,7 +76,7 @@ impl StdProd { fn debug( &self, pctx: &ProofCtx, - sctx: &SetupCtx, + sctx: &SetupCtx, air_instance: &mut AirInstance, num_rows: usize, debug_hints_data: Vec, @@ -177,7 +177,7 @@ impl StdProd { } impl WitnessComponent for StdProd { - fn start_proof(&self, pctx: Arc>, _ectx: Arc, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, _ectx: Arc>, sctx: Arc>) { self.decide(sctx, pctx); } @@ -186,8 +186,8 @@ impl WitnessComponent for StdProd { stage: u32, _air_instance: Option, pctx: Arc>, - _ectx: Arc, - sctx: Arc, + _ectx: Arc>, + sctx: Arc>, ) { if stage == 2 { let prod_airs = self.prod_airs.lock().unwrap(); diff --git a/pil2-components/lib/std/rs/src/std_sum.rs b/pil2-components/lib/std/rs/src/std_sum.rs index e4d40a66..04159e4c 100644 --- a/pil2-components/lib/std/rs/src/std_sum.rs +++ b/pil2-components/lib/std/rs/src/std_sum.rs @@ -39,7 +39,7 @@ struct BusValue { type DebugData = Mutex>, BusValue>>>; // opid -> val -> BusValue impl Decider for StdSum { - fn decide(&self, sctx: Arc, pctx: Arc>) { + fn decide(&self, sctx: Arc>, pctx: Arc>) { // Scan the pilout for airs that have sum-related hints let air_groups = pctx.pilout.air_groups(); let mut sum_airs_guard = self.sum_airs.lock().unwrap(); @@ -49,7 +49,7 @@ impl Decider for StdSum { let airgroup_id = air.airgroup_id; let air_id = air.air_id; - let setup = sctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let setup = sctx.get_setup(airgroup_id, air_id); let p_expressions_bin = setup.p_setup.p_expressions_bin; let im_hints = get_hint_ids_by_name(p_expressions_bin, "im_col"); @@ -82,7 +82,7 @@ impl StdSum { fn debug( &self, pctx: &ProofCtx, - sctx: &SetupCtx, + sctx: &SetupCtx, air_instance: &mut AirInstance, num_rows: usize, debug_hints_data: Vec, @@ -189,7 +189,7 @@ impl StdSum { } impl WitnessComponent for StdSum { - fn start_proof(&self, pctx: Arc>, _ectx: Arc, sctx: Arc) { + fn start_proof(&self, pctx: Arc>, _ectx: Arc>, sctx: Arc>) { self.decide(sctx, pctx); } @@ -198,8 +198,8 @@ impl WitnessComponent for StdSum { stage: u32, _air_instance: Option, pctx: Arc>, - _ectx: Arc, - sctx: Arc, + _ectx: Arc>, + sctx: Arc>, ) { if stage == 2 { let sum_airs = self.sum_airs.lock().unwrap(); diff --git a/pil2-components/test/simple/rs/src/simple_left.rs b/pil2-components/test/simple/rs/src/simple_left.rs index fb5261a8..4d8e9850 100644 --- a/pil2-components/test/simple/rs/src/simple_left.rs +++ b/pil2-components/test/simple/rs/src/simple_left.rs @@ -26,7 +26,7 @@ where simple_left } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, SIMPLE_AIRGROUP_ID, SIMPLE_LEFT_AIR_IDS[0]).unwrap(); @@ -50,8 +50,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/simple/rs/src/simple_lib.rs b/pil2-components/test/simple/rs/src/simple_lib.rs index e29662d6..8ac49751 100644 --- a/pil2-components/test/simple/rs/src/simple_lib.rs +++ b/pil2-components/test/simple/rs/src/simple_lib.rs @@ -2,7 +2,7 @@ use std::{error::Error, path::PathBuf, sync::Arc}; use pil_std_lib::Std; use proofman::{WitnessLibrary, WitnessManager}; -use proofman_common::{initialize_logger, ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; use p3_field::PrimeField; use p3_goldilocks::Goldilocks; @@ -34,7 +34,7 @@ where Self { wcm: None, simple_left: None, simple_right: None, std_lib: None } } - pub fn initialize(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn initialize(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let wcm = Arc::new(WitnessManager::new(pctx, ectx, sctx)); let std_lib = Std::new(wcm.clone()); @@ -52,7 +52,7 @@ impl WitnessLibrary for SimpleWitness where Standard: Distribution, { - fn start_proof(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { self.initialize(pctx.clone(), ectx.clone(), sctx.clone()); self.wcm.as_ref().unwrap().start_proof(pctx, ectx, sctx); @@ -62,13 +62,19 @@ where self.wcm.as_ref().unwrap().end_proof(); } - fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Execute those components that need to be executed self.simple_left.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); self.simple_right.as_ref().unwrap().execute(pctx, ectx, sctx); } - fn calculate_witness(&mut self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn calculate_witness( + &mut self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ) { self.wcm.as_ref().unwrap().calculate_witness(stage, pctx, ectx, sctx); } @@ -79,11 +85,9 @@ where #[no_mangle] pub extern "Rust" fn init_library( - ectx: Arc, + _: Option, _: Option, ) -> Result>, Box> { - initialize_logger(ectx.verbose_mode); - let simple_witness = SimpleWitness::new(); Ok(Box::new(simple_witness)) } diff --git a/pil2-components/test/simple/rs/src/simple_right.rs b/pil2-components/test/simple/rs/src/simple_right.rs index 1103043e..308c5c7e 100644 --- a/pil2-components/test/simple/rs/src/simple_right.rs +++ b/pil2-components/test/simple/rs/src/simple_right.rs @@ -26,7 +26,7 @@ where simple_right } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, SIMPLE_AIRGROUP_ID, SIMPLE_RIGHT_AIR_IDS[0]).unwrap(); @@ -49,8 +49,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let air_instances_vec = &mut pctx.air_instance_repo.air_instances.write().unwrap(); let air_instance = &mut air_instances_vec[air_instance_id.unwrap()]; diff --git a/pil2-components/test/std/connection/rs/src/connection1.rs b/pil2-components/test/std/connection/rs/src/connection1.rs index c2ca7ef5..5743e985 100644 --- a/pil2-components/test/std/connection/rs/src/connection1.rs +++ b/pil2-components/test/std/connection/rs/src/connection1.rs @@ -26,7 +26,7 @@ where connection1 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx .buffer_allocator @@ -55,8 +55,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/connection/rs/src/connection2.rs b/pil2-components/test/std/connection/rs/src/connection2.rs index f72b6ec5..a794c29b 100644 --- a/pil2-components/test/std/connection/rs/src/connection2.rs +++ b/pil2-components/test/std/connection/rs/src/connection2.rs @@ -26,7 +26,7 @@ where connection2 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx .buffer_allocator @@ -55,8 +55,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/connection/rs/src/connection_lib.rs b/pil2-components/test/std/connection/rs/src/connection_lib.rs index 47f0a511..0499075c 100644 --- a/pil2-components/test/std/connection/rs/src/connection_lib.rs +++ b/pil2-components/test/std/connection/rs/src/connection_lib.rs @@ -2,7 +2,7 @@ use std::{error::Error, path::PathBuf, sync::Arc}; use pil_std_lib::Std; use proofman::{WitnessLibrary, WitnessManager}; -use proofman_common::{initialize_logger, ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; use p3_field::PrimeField; use p3_goldilocks::Goldilocks; @@ -35,7 +35,7 @@ where ConnectionWitness { wcm: None, connection1: None, connection2: None, connection_new: None, std_lib: None } } - pub fn initialize(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn initialize(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let wcm = Arc::new(WitnessManager::new(pctx, ectx, sctx)); let std_lib = Std::new(wcm.clone()); @@ -55,7 +55,7 @@ impl WitnessLibrary for ConnectionWitness where Standard: Distribution, { - fn start_proof(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { self.initialize(pctx.clone(), ectx.clone(), sctx.clone()); self.wcm.as_ref().unwrap().start_proof(pctx, ectx, sctx); @@ -65,14 +65,20 @@ where self.wcm.as_ref().unwrap().end_proof(); } - fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Execute those components that need to be executed self.connection1.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); self.connection2.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); self.connection_new.as_ref().unwrap().execute(pctx, ectx, sctx); } - fn calculate_witness(&mut self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn calculate_witness( + &mut self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ) { self.wcm.as_ref().unwrap().calculate_witness(stage, pctx, ectx, sctx); } @@ -83,12 +89,10 @@ where #[no_mangle] pub extern "Rust" fn init_library( - ectx: Arc, + _: Option, _: Option, ) -> Result>, Box> { - initialize_logger(ectx.verbose_mode); - - let connection_witness = ConnectionWitness::new(); + let connection_witness: ConnectionWitness = ConnectionWitness::new(); Ok(Box::new(connection_witness)) } diff --git a/pil2-components/test/std/connection/rs/src/connection_new.rs b/pil2-components/test/std/connection/rs/src/connection_new.rs index 8aabb3aa..3c84d344 100644 --- a/pil2-components/test/std/connection/rs/src/connection_new.rs +++ b/pil2-components/test/std/connection/rs/src/connection_new.rs @@ -26,7 +26,7 @@ where connection_new } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx .buffer_allocator @@ -55,8 +55,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/lookup/rs/src/lookup0.rs b/pil2-components/test/std/lookup/rs/src/lookup0.rs index dade77bf..850f24fd 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup0.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup0.rs @@ -26,7 +26,7 @@ where lookup0 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, LOOKUP_AIRGROUP_ID, LOOKUP_0_AIR_IDS[0]).unwrap(); @@ -51,8 +51,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/lookup/rs/src/lookup1.rs b/pil2-components/test/std/lookup/rs/src/lookup1.rs index 9ce735ac..a9698776 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup1.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup1.rs @@ -26,7 +26,7 @@ where lookup1 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, LOOKUP_AIRGROUP_ID, LOOKUP_1_AIR_IDS[0]).unwrap(); @@ -51,8 +51,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/lookup/rs/src/lookup2_12.rs b/pil2-components/test/std/lookup/rs/src/lookup2_12.rs index 92596ebc..7e36cf28 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup2_12.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup2_12.rs @@ -26,7 +26,7 @@ where lookup2_12 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, LOOKUP_AIRGROUP_ID, LOOKUP_2_12_AIR_IDS[0]).unwrap(); @@ -51,8 +51,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/lookup/rs/src/lookup2_13.rs b/pil2-components/test/std/lookup/rs/src/lookup2_13.rs index ef96b4d5..417c09d8 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup2_13.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup2_13.rs @@ -26,7 +26,7 @@ where lookup2_13 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, LOOKUP_AIRGROUP_ID, LOOKUP_2_13_AIR_IDS[0]).unwrap(); @@ -51,8 +51,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/lookup/rs/src/lookup2_15.rs b/pil2-components/test/std/lookup/rs/src/lookup2_15.rs index c1df19bf..345fffca 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup2_15.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup2_15.rs @@ -26,7 +26,7 @@ where lookup2_15 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, LOOKUP_AIRGROUP_ID, LOOKUP_2_15_AIR_IDS[0]).unwrap(); @@ -51,8 +51,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/lookup/rs/src/lookup3.rs b/pil2-components/test/std/lookup/rs/src/lookup3.rs index e8826a6f..0e34859b 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup3.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup3.rs @@ -22,7 +22,7 @@ impl Lookup3 { lookup3 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx.buffer_allocator.as_ref().get_buffer_info(&sctx, LOOKUP_AIRGROUP_ID, LOOKUP_3_AIR_IDS[0]).unwrap(); @@ -44,8 +44,8 @@ impl WitnessComponent for Lookup3 { stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let air_instances_vec = &mut pctx.air_instance_repo.air_instances.write().unwrap(); let air_instance = &mut air_instances_vec[air_instance_id.unwrap()]; diff --git a/pil2-components/test/std/lookup/rs/src/lookup_lib.rs b/pil2-components/test/std/lookup/rs/src/lookup_lib.rs index 2d92a580..57e6176d 100644 --- a/pil2-components/test/std/lookup/rs/src/lookup_lib.rs +++ b/pil2-components/test/std/lookup/rs/src/lookup_lib.rs @@ -2,7 +2,7 @@ use std::{error::Error, path::PathBuf, sync::Arc}; use pil_std_lib::Std; use proofman::{WitnessLibrary, WitnessManager}; -use proofman_common::{initialize_logger, ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; use p3_field::PrimeField; use p3_goldilocks::Goldilocks; @@ -47,7 +47,7 @@ where } } - pub fn initialize(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn initialize(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let wcm = Arc::new(WitnessManager::new(pctx, ectx, sctx)); let std_lib = Std::new(wcm.clone()); @@ -73,7 +73,7 @@ impl WitnessLibrary for LookupWitness where Standard: Distribution, { - fn start_proof(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { self.initialize(pctx.clone(), ectx.clone(), sctx.clone()); self.wcm.as_ref().unwrap().start_proof(pctx, ectx, sctx); @@ -83,7 +83,7 @@ where self.wcm.as_ref().unwrap().end_proof(); } - fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Execute those components that need to be executed self.lookup0.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); self.lookup1.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); @@ -93,7 +93,13 @@ where self.lookup3.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); } - fn calculate_witness(&mut self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn calculate_witness( + &mut self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ) { self.wcm.as_ref().unwrap().calculate_witness(stage, pctx, ectx, sctx); } @@ -104,11 +110,9 @@ where #[no_mangle] pub extern "Rust" fn init_library( - ectx: Arc, + _: Option, _: Option, ) -> Result>, Box> { - initialize_logger(ectx.verbose_mode); - let lookup_witness = LookupWitness::new(); Ok(Box::new(lookup_witness)) } diff --git a/pil2-components/test/std/permutation/rs/src/permutation1_6.rs b/pil2-components/test/std/permutation/rs/src/permutation1_6.rs index fcfc4da1..1f01445d 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation1_6.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation1_6.rs @@ -26,7 +26,7 @@ where permutation1_6 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Add two instances of this air, so that 2**6 + 2**6 = 2**7 to fit with permutation2 let (buffer_size, _) = ectx .buffer_allocator @@ -65,8 +65,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/permutation/rs/src/permutation1_7.rs b/pil2-components/test/std/permutation/rs/src/permutation1_7.rs index f5b49665..92f3e250 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation1_7.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation1_7.rs @@ -26,7 +26,7 @@ where permutation1_7 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx .buffer_allocator @@ -55,8 +55,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/permutation/rs/src/permutation1_8.rs b/pil2-components/test/std/permutation/rs/src/permutation1_8.rs index 585aa5f8..983af11f 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation1_8.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation1_8.rs @@ -26,7 +26,7 @@ where permutation1_8 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx .buffer_allocator @@ -55,8 +55,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/permutation/rs/src/permutation2.rs b/pil2-components/test/std/permutation/rs/src/permutation2.rs index 2e5c86b0..f60374a3 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation2.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation2.rs @@ -22,7 +22,7 @@ impl Permutation2 { permutation2 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of each air let (buffer_size, _) = ectx .buffer_allocator @@ -48,8 +48,8 @@ impl WitnessComponent for Permutation2 { stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let air_instances_vec = &mut pctx.air_instance_repo.air_instances.write().unwrap(); let air_instance = &mut air_instances_vec[air_instance_id.unwrap()]; diff --git a/pil2-components/test/std/permutation/rs/src/permutation_lib.rs b/pil2-components/test/std/permutation/rs/src/permutation_lib.rs index 2e12bd44..dea67507 100644 --- a/pil2-components/test/std/permutation/rs/src/permutation_lib.rs +++ b/pil2-components/test/std/permutation/rs/src/permutation_lib.rs @@ -2,7 +2,7 @@ use std::{error::Error, path::PathBuf, sync::Arc}; use pil_std_lib::Std; use proofman::{WitnessLibrary, WitnessManager}; -use proofman_common::{initialize_logger, ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; use p3_field::PrimeField; use p3_goldilocks::Goldilocks; @@ -43,7 +43,7 @@ where } } - pub fn initialize(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn initialize(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let wcm = Arc::new(WitnessManager::new(pctx, ectx, sctx)); let std_lib = Std::new(wcm.clone()); @@ -65,7 +65,7 @@ impl WitnessLibrary for PermutationWitness where Standard: Distribution, { - fn start_proof(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { self.initialize(pctx.clone(), ectx.clone(), sctx.clone()); self.wcm.as_ref().unwrap().start_proof(pctx, ectx, sctx); @@ -75,7 +75,7 @@ where self.wcm.as_ref().unwrap().end_proof(); } - fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Execute those components that need to be executed self.permutation1_6.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); self.permutation1_7.as_ref().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); @@ -83,7 +83,13 @@ where self.permutation2.as_ref().unwrap().execute(pctx, ectx, sctx); } - fn calculate_witness(&mut self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn calculate_witness( + &mut self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ) { self.wcm.as_ref().unwrap().calculate_witness(stage, pctx, ectx, sctx); } @@ -94,11 +100,9 @@ where #[no_mangle] pub extern "Rust" fn init_library( - ectx: Arc, + _: Option, _: Option, ) -> Result>, Box> { - initialize_logger(ectx.verbose_mode); - let permutation_witness = PermutationWitness::new(); Ok(Box::new(permutation_witness)) } diff --git a/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs b/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs index bf6d8833..4a5528c6 100644 --- a/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs +++ b/pil2-components/test/std/range_check/rs/src/multi_range_check1.rs @@ -35,7 +35,7 @@ where multi_range_check1 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -69,8 +69,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs b/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs index cdc9227b..49117a46 100644 --- a/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs +++ b/pil2-components/test/std/range_check/rs/src/multi_range_check2.rs @@ -35,7 +35,7 @@ where multi_range_check2 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -69,8 +69,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/range_check1.rs b/pil2-components/test/std/range_check/rs/src/range_check1.rs index 70582a34..3caeca37 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check1.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check1.rs @@ -31,7 +31,7 @@ where range_check1 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -60,8 +60,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/range_check2.rs b/pil2-components/test/std/range_check/rs/src/range_check2.rs index 037fc07e..04b33976 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check2.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check2.rs @@ -31,7 +31,7 @@ where range_check1 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -60,8 +60,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/range_check3.rs b/pil2-components/test/std/range_check/rs/src/range_check3.rs index 7d7c36f0..4cac5a9f 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check3.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check3.rs @@ -31,7 +31,7 @@ where range_check1 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -60,8 +60,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/range_check4.rs b/pil2-components/test/std/range_check/rs/src/range_check4.rs index bc2d7fb2..bd4eeb6f 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check4.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check4.rs @@ -32,7 +32,7 @@ where range_check4 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -61,8 +61,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs b/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs index c6b33bb4..bd74a7dd 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check_dynamic1.rs @@ -35,7 +35,7 @@ where range_check_dynamic1 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -72,8 +72,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs b/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs index 87e38049..4678c1ff 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check_dynamic2.rs @@ -36,7 +36,7 @@ where range_check_dynamic2 } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -73,8 +73,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-components/test/std/range_check/rs/src/range_check_lib.rs b/pil2-components/test/std/range_check/rs/src/range_check_lib.rs index 3171bb44..341e532e 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check_lib.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check_lib.rs @@ -2,7 +2,7 @@ use std::{cell::OnceCell, error::Error, path::PathBuf, sync::Arc}; use pil_std_lib::Std; use proofman::{WitnessLibrary, WitnessManager}; -use proofman_common::{initialize_logger, ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; use p3_field::PrimeField; use p3_goldilocks::Goldilocks; @@ -56,7 +56,7 @@ where } } - fn initialize(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn initialize(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { let wcm = Arc::new(WitnessManager::new(pctx, ectx, sctx)); let std_lib = Std::new(wcm.clone()); @@ -88,7 +88,7 @@ impl WitnessLibrary for RangeCheckWitness where Standard: Distribution, { - fn start_proof(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn start_proof(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { self.initialize(pctx.clone(), ectx.clone(), sctx.clone()); self.wcm.get().unwrap().start_proof(pctx, ectx, sctx); @@ -98,7 +98,7 @@ where self.wcm.get().unwrap().end_proof(); } - fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // Execute those components that need to be executed self.range_check1.get().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); self.range_check2.get().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); @@ -111,7 +111,13 @@ where self.range_check_mix.get().unwrap().execute(pctx.clone(), ectx.clone(), sctx.clone()); } - fn calculate_witness(&mut self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc) { + fn calculate_witness( + &mut self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ) { self.wcm.get().unwrap().calculate_witness(stage, pctx, ectx, sctx); } @@ -122,11 +128,9 @@ where #[no_mangle] pub extern "Rust" fn init_library( - ectx: Arc, + _: Option, _: Option, ) -> Result>, Box> { - initialize_logger(ectx.verbose_mode); - let range_check_witness = RangeCheckWitness::new(); Ok(Box::new(range_check_witness)) } diff --git a/pil2-components/test/std/range_check/rs/src/range_check_mix.rs b/pil2-components/test/std/range_check/rs/src/range_check_mix.rs index 055a0029..42708038 100644 --- a/pil2-components/test/std/range_check/rs/src/range_check_mix.rs +++ b/pil2-components/test/std/range_check/rs/src/range_check_mix.rs @@ -36,7 +36,7 @@ where range_check_mix } - pub fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { // For simplicity, add a single instance of the air let (buffer_size, _) = ectx .buffer_allocator @@ -65,8 +65,8 @@ where stage: u32, air_instance_id: Option, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { let mut rng = rand::thread_rng(); diff --git a/pil2-stark/lib/include/starks_lib.h b/pil2-stark/lib/include/starks_lib.h index d4a39cd0..8afdf5d6 100644 --- a/pil2-stark/lib/include/starks_lib.h +++ b/pil2-stark/lib/include/starks_lib.h @@ -20,9 +20,7 @@ // SetupCtx // ======================================================================================== - void *setup_ctx_new(void* p_stark_info, void* p_expression_bin, void* p_const_pols); void *get_hint_ids_by_name(void *p_expression_bin, char* hintName); - void setup_ctx_free(void *pSetupCtx); // Stark Info // ======================================================================================== @@ -38,13 +36,18 @@ int64_t get_airgroupvalue_id_by_name(void *pStarkInfo, char* airValueName); void stark_info_free(void *pStarkInfo); + // Prover Helpers + // ======================================================================================== + void *prover_helpers_new(void *pStarkInfo); + void prover_helpers_free(void *pProverHelpers); + // Const Pols // ======================================================================================== - void *const_pols_new(char* filename, void *pStarkInfo, bool calculate_tree); - void *const_pols_with_tree_new(char* filename, char* treeFilename, void *pStarkInfo); - void load_const_tree(void *pConstPols, void *pStarkInfo, char *treeFilename); - void calculate_const_tree(void *pConstPols, void *pStarkInfo); - void const_pols_free(void *pConstPols); + void load_const_tree(void *pConstTree, char *treeFilename, uint64_t constTreeSize); + void load_const_pols(void *pConstPols, char *constFilename, uint64_t constSize); + uint64_t get_const_tree_size(void *pStarkInfo); + uint64_t get_const_size(void *pStarkInfo); + void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTree, char *treeFilename); // Expressions Bin // ======================================================================================== @@ -61,13 +64,13 @@ // Starks // ======================================================================================== - void *starks_new(void *pSetupCtx); + void *starks_new(void *pSetupCtx, void *pConstTree); void starks_free(void *pStarks); void treesGL_get_root(void *pStarks, uint64_t index, void *root); void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub); - void *get_fri_pol(void *pSetupCtx, void *buffer); + void *get_fri_pol(void *pStarkInfo, void *buffer); void calculate_fri_polynomial(void *pStarks, void* stepsParams); void calculate_quotient_polynomial(void *pStarks, void* stepsParams); @@ -76,11 +79,10 @@ void commit_stage(void *pStarks, uint32_t elementType, uint64_t step, void *buffer, void *pProof, void *pBuffHelper); void compute_lev(void *pStarks, void *xiChallenge, void* LEv); - void compute_evals(void *pStarks, void *buffer, void *LEv, void *evals, void *pProof); + void compute_evals(void *pStarks, void *params, void *LEv, void *pProof); void calculate_hash(void *pStarks, void *pHhash, void *pBuffer, uint64_t nElements); - void set_const_tree(void *pStarks, void *pConstPols); // MerkleTree // ================================================================================= @@ -123,9 +125,8 @@ // Recursive proof // ================================================================================= - void *gen_recursive_proof(void *pSetupCtx, void* pAddress, void* pPublicInputs, char *proof_file); + void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, void* pAddress, void *pConstPols, void *pConstTree, void* pPublicInputs, char *proof_file); void *get_zkin_ptr(char *zkin_file); - void *public2zkin(void *pZkin, void* pPublics, char* globalInfoFile, uint64_t airgroupId, bool isAggregated); void *add_recursive2_verkey(void *pZkin, char* recursive2VerKeyFilename); void *join_zkin_recursive2(char* globalInfoFile, uint64_t airgroupId, void* pPublics, void* pChallenges, void *zkin1, void *zkin2, void *starkInfoRecursive2); void *join_zkin_final(void* pPublics, void *pProofValues, void* pChallenges, char* globalInfoFile, void **zkinRecursive2, void **starkInfoRecursive2); diff --git a/pil2-stark/src/api/starks_api.cpp b/pil2-stark/src/api/starks_api.cpp index 167fdcc1..e29fcfa2 100644 --- a/pil2-stark/src/api/starks_api.cpp +++ b/pil2-stark/src/api/starks_api.cpp @@ -7,6 +7,7 @@ #include "gen_recursive_proof.hpp" #include "logger.hpp" #include +#include "setup_ctx.hpp" #include using json = nlohmann::json; @@ -137,11 +138,6 @@ void fri_proof_free(void *pFriProof) // SetupCtx // ======================================================================================== -void *setup_ctx_new(void* p_stark_info, void* p_expression_bin, void* p_const_pols) { - SetupCtx *setupCtx = new SetupCtx(*(StarkInfo*)p_stark_info, *(ExpressionsBin*)p_expression_bin, *(ConstPols *)p_const_pols); - return setupCtx; -} - void* get_hint_ids_by_name(void *p_expression_bin, char* hintName) { ExpressionsBin *expressionsBin = (ExpressionsBin*)p_expression_bin; @@ -150,11 +146,6 @@ void* get_hint_ids_by_name(void *p_expression_bin, char* hintName) return new VecU64Result(hintIds); } -void setup_ctx_free(void *pSetupCtx) { - SetupCtx *setupCtx = (SetupCtx *)pSetupCtx; - delete setupCtx; -} - // StarkInfo // ======================================================================================== void *stark_info_new(char *filename) @@ -218,38 +209,47 @@ void stark_info_free(void *pStarkInfo) delete starkInfo; } -// Const Pols +// Prover Helpers // ======================================================================================== -void *const_pols_new(char* filename, void *pStarkInfo, bool calculate_tree) -{ - auto const_pols = new ConstPols(*(StarkInfo *)pStarkInfo, filename, calculate_tree); - - return const_pols; +void *prover_helpers_new(void *pStarkInfo) { + auto prover_helpers = new ProverHelpers(*(StarkInfo *)pStarkInfo); + return prover_helpers; } -void *const_pols_with_tree_new(char* filename, char* treeFilename, void *pStarkInfo) -{ - auto const_pols = new ConstPols(*(StarkInfo *)pStarkInfo, filename, treeFilename); +void prover_helpers_free(void *pProverHelpers) { + auto proverHelpers = (ProverHelpers *)pProverHelpers; + delete proverHelpers; +}; - return const_pols; -} +// Const Pols +// ======================================================================================== +void load_const_tree(void *pConstTree, char *treeFilename, uint64_t constTreeSize) { + ConstTree constTree; + constTree.loadConstTree((Goldilocks::Element *)pConstTree, treeFilename, constTreeSize * sizeof(Goldilocks::Element)); +}; -void load_const_tree(void *pConstPols, void *pStarkInfo, char *treeFilename) { - ConstPols *constPols = (ConstPols *)pConstPols; - constPols->loadConstTree(*(StarkInfo *)pStarkInfo, treeFilename); -} +void load_const_pols(void *pConstPols, char *constFilename, uint64_t constSize) { + ConstTree constTree; + constTree.loadConstPols((Goldilocks::Element *)pConstPols, constFilename, constSize * sizeof(Goldilocks::Element)); +}; -void calculate_const_tree(void *pConstPols, void *pStarkInfo) { - ConstPols *constPols = (ConstPols *)pConstPols; - constPols->calculateConstTree(*(StarkInfo *)pStarkInfo); -} +uint64_t get_const_tree_size(void *pStarkInfo) { + ConstTree constTree; + return constTree.getConstTreeSizeGL(*(StarkInfo *)pStarkInfo); +}; -void const_pols_free(void *pConstPols) -{ - auto constPols = (ConstPols *)pConstPols; - delete constPols; +uint64_t get_const_size(void *pStarkInfo) { + auto starkInfo = *(StarkInfo *)pStarkInfo; + uint64_t N = 1 << starkInfo.starkStruct.nBits; + return N * starkInfo.nConstants; } + +void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTreeAddress, char *treeFilename) { + ConstTree constTree; + constTree.calculateConstTree(*(StarkInfo *)pStarkInfo, (Goldilocks::Element *)pConstPolsAddress, (Goldilocks::Element *)pConstTreeAddress, treeFilename); +}; + // Expressions Bin // ======================================================================================== void *expressions_bin_new(char* filename, bool global) @@ -294,9 +294,9 @@ uint64_t set_hint_field(void *pSetupCtx, void* params, void *values, uint64_t hi // Starks // ======================================================================================== -void *starks_new(void *pSetupCtx) +void *starks_new(void *pSetupCtx, void* pConstTree) { - return new Starks(*(SetupCtx *)pSetupCtx); + return new Starks(*(SetupCtx *)pSetupCtx, (Goldilocks::Element*) pConstTree); } void starks_free(void *pStarks) @@ -350,10 +350,10 @@ void compute_lev(void *pStarks, void *xiChallenge, void* LEv) { starks->computeLEv((Goldilocks::Element *)xiChallenge, (Goldilocks::Element *)LEv); } -void compute_evals(void *pStarks, void *buffer, void *LEv, void *evals, void *pProof) +void compute_evals(void *pStarks, void *params, void *LEv, void *pProof) { Starks *starks = (Starks *)pStarks; - starks->computeEvals((Goldilocks::Element *)buffer, (Goldilocks::Element *)LEv, (Goldilocks::Element *)evals, *(FRIProof *)pProof); + starks->computeEvals(*(StepsParams *)params, (Goldilocks::Element *)LEv, *(FRIProof *)pProof); } void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub) @@ -362,12 +362,12 @@ void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub) starks->calculateXDivXSub((Goldilocks::Element *)xiChallenge, (Goldilocks::Element *)xDivXSub); } -void *get_fri_pol(void *pSetupCtx, void *buffer) +void *get_fri_pol(void *pStarkInfo, void *buffer) { - SetupCtx setupCtx = *(SetupCtx *)pSetupCtx; + StarkInfo starkInfo = *(StarkInfo *)pStarkInfo; auto pols = (Goldilocks::Element *)buffer; - return &pols[setupCtx.starkInfo.mapOffsets[std::make_pair("f", true)]]; + return &pols[starkInfo.mapOffsets[std::make_pair("f", true)]]; } void calculate_hash(void *pStarks, void *pHhash, void *pBuffer, uint64_t nElements) @@ -376,12 +376,6 @@ void calculate_hash(void *pStarks, void *pHhash, void *pBuffer, uint64_t nElemen starks->calculateHash((Goldilocks::Element *)pHhash, (Goldilocks::Element *)pBuffer, nElements); } -void set_const_tree(void *pStarks, void *pConstPols) -{ - Starks *starks = (Starks *)pStarks; - starks->setConstTree(*(ConstPols *)pConstPols); -} - // MerkleTree // ================================================================================= void *merkle_tree_new(uint64_t height, uint64_t width, uint64_t arity, bool custom) { @@ -528,8 +522,11 @@ void print_row(void *pSetupCtx, void *buffer, uint64_t stage, uint64_t row) { // Recursive proof // ================================================================================= -void *gen_recursive_proof(void *pSetupCtx, void* pAddress, void* pPublicInputs, char* proof_file) { - return genRecursiveProof(*(SetupCtx *)pSetupCtx, (Goldilocks::Element *)pAddress, (Goldilocks::Element *)pPublicInputs, string(proof_file)); +void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, void* pAddress, void *pConstPols, void *pConstTree, void* pPublicInputs, char* proof_file) { + json globalInfo; + file2json(globalInfoFile, globalInfo); + + return genRecursiveProof(*(SetupCtx *)pSetupCtx, globalInfo, airgroupId, (Goldilocks::Element *)pAddress, (Goldilocks::Element *)pConstPols, (Goldilocks::Element *)pConstTree, (Goldilocks::Element *)pPublicInputs, string(proof_file)); } void *get_zkin_ptr(char *zkin_file) { @@ -539,14 +536,6 @@ void *get_zkin_ptr(char *zkin_file) { return (void *) new nlohmann::ordered_json(zkin); } -void *public2zkin(void *pZkin, void* pPublics, char* globalInfoFile, uint64_t airgroupId, bool isAggregated) { - json globalInfo; - file2json(globalInfoFile, globalInfo); - - nlohmann::ordered_json zkin = *(nlohmann::ordered_json*) pZkin; - return publics2zkin(zkin, (Goldilocks::Element *)pPublics, globalInfo, airgroupId, isAggregated); -} - void *add_recursive2_verkey(void *pZkin, char* recursive2VerKeyFilename) { json recursive2VerkeyJson; file2json(recursive2VerKeyFilename, recursive2VerkeyJson); @@ -557,8 +546,8 @@ void *add_recursive2_verkey(void *pZkin, char* recursive2VerKeyFilename) { recursive2Verkey[i] = Goldilocks::fromU64(recursive2VerkeyJson[i]); } - nlohmann::ordered_json zkin = *(nlohmann::ordered_json*) pZkin; - return addRecursive2VerKey(zkin, recursive2Verkey); + ordered_json zkin = addRecursive2VerKey(*(nlohmann::ordered_json*) pZkin, recursive2Verkey); + return (void *) new nlohmann::ordered_json(zkin); } void *join_zkin_recursive2(char* globalInfoFile, uint64_t airgroupId, void* pPublics, void* pChallenges, void *zkin1, void *zkin2, void *starkInfoRecursive2) { diff --git a/pil2-stark/src/api/starks_api.hpp b/pil2-stark/src/api/starks_api.hpp index d4a39cd0..8afdf5d6 100644 --- a/pil2-stark/src/api/starks_api.hpp +++ b/pil2-stark/src/api/starks_api.hpp @@ -20,9 +20,7 @@ // SetupCtx // ======================================================================================== - void *setup_ctx_new(void* p_stark_info, void* p_expression_bin, void* p_const_pols); void *get_hint_ids_by_name(void *p_expression_bin, char* hintName); - void setup_ctx_free(void *pSetupCtx); // Stark Info // ======================================================================================== @@ -38,13 +36,18 @@ int64_t get_airgroupvalue_id_by_name(void *pStarkInfo, char* airValueName); void stark_info_free(void *pStarkInfo); + // Prover Helpers + // ======================================================================================== + void *prover_helpers_new(void *pStarkInfo); + void prover_helpers_free(void *pProverHelpers); + // Const Pols // ======================================================================================== - void *const_pols_new(char* filename, void *pStarkInfo, bool calculate_tree); - void *const_pols_with_tree_new(char* filename, char* treeFilename, void *pStarkInfo); - void load_const_tree(void *pConstPols, void *pStarkInfo, char *treeFilename); - void calculate_const_tree(void *pConstPols, void *pStarkInfo); - void const_pols_free(void *pConstPols); + void load_const_tree(void *pConstTree, char *treeFilename, uint64_t constTreeSize); + void load_const_pols(void *pConstPols, char *constFilename, uint64_t constSize); + uint64_t get_const_tree_size(void *pStarkInfo); + uint64_t get_const_size(void *pStarkInfo); + void calculate_const_tree(void *pStarkInfo, void *pConstPolsAddress, void *pConstTree, char *treeFilename); // Expressions Bin // ======================================================================================== @@ -61,13 +64,13 @@ // Starks // ======================================================================================== - void *starks_new(void *pSetupCtx); + void *starks_new(void *pSetupCtx, void *pConstTree); void starks_free(void *pStarks); void treesGL_get_root(void *pStarks, uint64_t index, void *root); void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub); - void *get_fri_pol(void *pSetupCtx, void *buffer); + void *get_fri_pol(void *pStarkInfo, void *buffer); void calculate_fri_polynomial(void *pStarks, void* stepsParams); void calculate_quotient_polynomial(void *pStarks, void* stepsParams); @@ -76,11 +79,10 @@ void commit_stage(void *pStarks, uint32_t elementType, uint64_t step, void *buffer, void *pProof, void *pBuffHelper); void compute_lev(void *pStarks, void *xiChallenge, void* LEv); - void compute_evals(void *pStarks, void *buffer, void *LEv, void *evals, void *pProof); + void compute_evals(void *pStarks, void *params, void *LEv, void *pProof); void calculate_hash(void *pStarks, void *pHhash, void *pBuffer, uint64_t nElements); - void set_const_tree(void *pStarks, void *pConstPols); // MerkleTree // ================================================================================= @@ -123,9 +125,8 @@ // Recursive proof // ================================================================================= - void *gen_recursive_proof(void *pSetupCtx, void* pAddress, void* pPublicInputs, char *proof_file); + void *gen_recursive_proof(void *pSetupCtx, char* globalInfoFile, uint64_t airgroupId, void* pAddress, void *pConstPols, void *pConstTree, void* pPublicInputs, char *proof_file); void *get_zkin_ptr(char *zkin_file); - void *public2zkin(void *pZkin, void* pPublics, char* globalInfoFile, uint64_t airgroupId, bool isAggregated); void *add_recursive2_verkey(void *pZkin, char* recursive2VerKeyFilename); void *join_zkin_recursive2(char* globalInfoFile, uint64_t airgroupId, void* pPublics, void* pChallenges, void *zkin1, void *zkin2, void *starkInfoRecursive2); void *join_zkin_final(void* pPublics, void *pProofValues, void* pChallenges, char* globalInfoFile, void **zkinRecursive2, void **starkInfoRecursive2); diff --git a/pil2-stark/src/bctree/build_const_tree.cpp b/pil2-stark/src/bctree/build_const_tree.cpp index c7dc7a8c..2c0d0008 100644 --- a/pil2-stark/src/bctree/build_const_tree.cpp +++ b/pil2-stark/src/bctree/build_const_tree.cpp @@ -63,12 +63,7 @@ void buildConstTree(const string constFile, const string starkInfoFile, const st // ConstTree if(constTreeFile != "") { - ofstream fw(constTreeFile.c_str(), std::fstream::out | std::fstream::binary); - fw.write((const char *)&(nPols), sizeof(uint64_t)); - fw.write((const char *)&(NExtended), sizeof(uint64_t)); - fw.write((const char *)pConstPolsExt, nPols * NExtended * sizeof(Goldilocks::Element)); - fw.write((const char *)mt.nodes, mt.numNodes * sizeof(Goldilocks::Element)); - fw.close(); + mt.writeFile(constTreeFile); } TimerStopAndLog(GENERATING_FILES); @@ -95,12 +90,7 @@ void buildConstTree(const string constFile, const string starkInfoFile, const st // ConstTree if(constTreeFile != "") { - std::ofstream fw(constTreeFile.c_str(), std::fstream::out | std::fstream::binary); - fw.write((const char *)&(mt.width), sizeof(mt.width)); - fw.write((const char *)&(mt.height), sizeof(mt.height)); - fw.write((const char *)mt.source, nPols * NExtended * sizeof(Goldilocks::Element)); - fw.write((const char *)mt.nodes, mt.numNodes * sizeof(RawFr::Element)); - fw.close(); + mt.writeFile(constTreeFile); } TimerStopAndLog(GENERATING_FILES); } else { @@ -110,4 +100,4 @@ void buildConstTree(const string constFile, const string starkInfoFile, const st free(pConstPolsExt); TimerStopAndLog(BUILD_CONST_TREE); -} +} \ No newline at end of file diff --git a/pil2-stark/src/starkpil/const_pols.hpp b/pil2-stark/src/starkpil/const_pols.hpp index 21e7620b..4158bedb 100644 --- a/pil2-stark/src/starkpil/const_pols.hpp +++ b/pil2-stark/src/starkpil/const_pols.hpp @@ -12,164 +12,11 @@ #include "merkleTreeBN128.hpp" #include "merkleTreeGL.hpp" - -class ConstPols -{ +class ConstTree { public: - Goldilocks::Element *pConstPolsAddress = nullptr; - Goldilocks::Element *pConstPolsAddressExtended; - Goldilocks::Element *pConstTreeAddress = nullptr; - Goldilocks::Element *zi = nullptr; - Goldilocks::Element *S = nullptr; - Goldilocks::Element *x = nullptr; - Goldilocks::Element *x_n = nullptr; // Needed for PIL1 compatibility - Goldilocks::Element *x_2ns = nullptr; // Needed for PIL1 compatibility - - ConstPols(StarkInfo& starkInfo, std::string constPolsFile, bool calculateTree = true) { - - loadConstPols(starkInfo, constPolsFile); - - if(calculateTree) { - calculateConstTree(starkInfo); - } - - computeZerofier(starkInfo); - - computeX(starkInfo); - - computeConnectionsX(starkInfo); // Needed for PIL1 compatibility - } - - ConstPols(StarkInfo& starkInfo, std::string constPolsFile, std::string constTreeFile) { - - loadConstPols(starkInfo, constPolsFile); - - loadConstTree(starkInfo, constTreeFile); - - computeZerofier(starkInfo); - - computeX(starkInfo); - - computeConnectionsX(starkInfo); // Needed for PIL1 compatibility - } - - // For verification only - ConstPols(StarkInfo& starkInfo, Goldilocks::Element* z, Goldilocks::Element* constVals) { - pConstPolsAddress = (Goldilocks::Element *)malloc(starkInfo.nConstants * starkInfo.starkStruct.nQueries * sizeof(Goldilocks::Element)); - for(uint64_t i = 0; i < starkInfo.nConstants * starkInfo.starkStruct.nQueries; ++i) { - pConstPolsAddress[i] = constVals[i]; - } - - - zi = new Goldilocks::Element[starkInfo.boundaries.size() * FIELD_EXTENSION]; - - Goldilocks::Element one[3] = {Goldilocks::one(), Goldilocks::zero(), Goldilocks::zero()}; - - Goldilocks::Element xN[3] = {Goldilocks::one(), Goldilocks::zero(), Goldilocks::zero()}; - for(uint64_t i = 0; i < uint64_t(1 << starkInfo.starkStruct.nBits); ++i) { - Goldilocks3::mul((Goldilocks3::Element *)xN, (Goldilocks3::Element *)xN, (Goldilocks3::Element *)z); - } - - Goldilocks::Element zN[3] = { xN[0] - Goldilocks::one(), xN[1], xN[2]}; - Goldilocks::Element zNInv[3]; - Goldilocks3::inv((Goldilocks3::Element *)zNInv, (Goldilocks3::Element *)zN); - std::memcpy(&zi[0], zNInv, FIELD_EXTENSION * sizeof(Goldilocks::Element)); - - for(uint64_t i = 1; i < starkInfo.boundaries.size(); ++i) { - Boundary boundary = starkInfo.boundaries[i]; - if(boundary.name == "firstRow") { - Goldilocks::Element zi_[3]; - Goldilocks3::sub((Goldilocks3::Element &)zi_[0], (Goldilocks3::Element &)z[0], (Goldilocks3::Element &)one[0]); - Goldilocks3::inv((Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zi_); - Goldilocks3::mul((Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zN); - std::memcpy(&zi[i*FIELD_EXTENSION], zi_, FIELD_EXTENSION * sizeof(Goldilocks::Element)); - } else if(boundary.name == "lastRow") { - Goldilocks::Element root = Goldilocks::one(); - for(uint64_t i = 0; i < uint64_t(1 << starkInfo.starkStruct.nBits) - 1; ++i) { - root = root * Goldilocks::w(starkInfo.starkStruct.nBits); - } - Goldilocks::Element zi_[3]; - Goldilocks3::sub((Goldilocks3::Element &)zi_[0], (Goldilocks3::Element &)z[0], (Goldilocks3::Element &)root); - Goldilocks3::inv((Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zi_); - Goldilocks3::mul((Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zN); - std::memcpy(&zi[i*FIELD_EXTENSION], zi_, FIELD_EXTENSION * sizeof(Goldilocks::Element)); - } else if(boundary.name == "everyRow") { - uint64_t nRoots = boundary.offsetMin + boundary.offsetMax; - Goldilocks::Element roots[nRoots]; - Goldilocks::Element zi_[3] = { Goldilocks::one(), Goldilocks::zero(), Goldilocks::zero()}; - for(uint64_t i = 0; i < boundary.offsetMin; ++i) { - roots[i] = Goldilocks::one(); - for(uint64_t j = 0; j < i; ++j) { - roots[i] = roots[i] * Goldilocks::w(starkInfo.starkStruct.nBits); - } - Goldilocks::Element aux[3]; - Goldilocks3::sub((Goldilocks3::Element &)aux[0], (Goldilocks3::Element &)z[0], (Goldilocks3::Element &)roots[i]); - Goldilocks3::mul((Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zi_, (Goldilocks3::Element *)aux); - } + ConstTree () {}; - for(uint64_t i = 0; i < boundary.offsetMax; ++i) { - roots[i + boundary.offsetMin] = Goldilocks::one(); - for(uint64_t j = 0; j < (uint64_t(1 << starkInfo.starkStruct.nBits) - i - 1); ++j) { - roots[i + boundary.offsetMin] = roots[i + boundary.offsetMin] * Goldilocks::w(starkInfo.starkStruct.nBits); - } - Goldilocks::Element aux[3]; - Goldilocks3::sub((Goldilocks3::Element &)aux[0], (Goldilocks3::Element &)z[0], (Goldilocks3::Element &)roots[i + boundary.offsetMin]); - Goldilocks3::mul((Goldilocks3::Element *)zi_, (Goldilocks3::Element *)zi_, (Goldilocks3::Element *)aux); - } - - std::memcpy(&zi[i*FIELD_EXTENSION], zi_, FIELD_EXTENSION * sizeof(Goldilocks::Element)); - } - } - - x_n = new Goldilocks::Element[FIELD_EXTENSION]; - x_n[0] = z[0]; - x_n[1] = z[1]; - x_n[2] = z[2]; - }; - - void calculateConstTree(StarkInfo& starkInfo) { - pConstTreeAddress = (Goldilocks::Element *)malloc(getConstTreeSize(starkInfo)); - if(pConstTreeAddress == NULL) - { - zklog.error("Starks::Starks() failed to allocate pConstTreeAddress"); - exitProcess(); - } - pConstPolsAddressExtended = &pConstTreeAddress[2]; - - uint64_t merkleTreeArity = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? starkInfo.starkStruct.merkleTreeArity : 2; - uint64_t merkleTreeCustom = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? starkInfo.starkStruct.merkleTreeCustom : true; - uint64_t N = 1 << starkInfo.starkStruct.nBits; - uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; - NTT_Goldilocks ntt(N); - ntt.extendPol((Goldilocks::Element *)pConstPolsAddressExtended, (Goldilocks::Element *)pConstPolsAddress, NExtended, N, starkInfo.nConstants); - MerkleTreeGL mt(merkleTreeArity, merkleTreeCustom, NExtended, starkInfo.nConstants, (Goldilocks::Element *)pConstPolsAddressExtended); - mt.merkelize(); - - pConstTreeAddress[0] = Goldilocks::fromU64(starkInfo.nConstants); - pConstTreeAddress[1] = Goldilocks::fromU64(NExtended); - memcpy(&pConstTreeAddress[2 + starkInfo.nConstants * NExtended], mt.nodes, mt.numNodes * sizeof(Goldilocks::Element)); - } - - void loadConstTree(StarkInfo& starkInfo, std::string constTreeFile) { - uint64_t constTreeSizeBytes = getConstTreeSize(starkInfo); - - pConstTreeAddress = (Goldilocks::Element *)loadFileParallel(constTreeFile, constTreeSizeBytes); - - pConstPolsAddressExtended = &pConstTreeAddress[2]; - } - - void loadConstPols(StarkInfo& starkInfo, std::string constPolsFile) { - // Allocate an area of memory, mapped to file, to read all the constant polynomials, - // and create them using the allocated address - - uint64_t N = 1 << starkInfo.starkStruct.nBits; - uint64_t constPolsSize = starkInfo.nConstants * sizeof(Goldilocks::Element) * N; - - pConstPolsAddress = (Goldilocks::Element *)loadFileParallel(constPolsFile, constPolsSize); - } - - uint64_t getConstTreeSize(StarkInfo& starkInfo) - { + uint64_t getNumNodes(StarkInfo& starkInfo) { uint64_t merkleTreeArity = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? starkInfo.starkStruct.merkleTreeArity : 2; uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; uint n_tmp = NExtended; @@ -190,160 +37,52 @@ class ConstPols } } - uint64_t elementSize = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? sizeof(RawFr::Element) : sizeof(Goldilocks::Element); - uint64_t numElements = NExtended * starkInfo.nConstants * sizeof(Goldilocks::Element); - uint64_t nFieldElements = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? 1 : HASH_SIZE; - uint64_t total = numElements + acc * nFieldElements * elementSize; - if(starkInfo.starkStruct.verificationHashType == std::string("BN128")) { - total += 16; // HEADER - } else { - total += merkleTreeArity * elementSize; - } - return total; - - }; + return acc; + } - void computeZerofier(StarkInfo& starkInfo) { - uint64_t N = 1 << starkInfo.starkStruct.nBits; + uint64_t getConstTreeSizeBytesBN128B(StarkInfo& starkInfo) + { uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; - zi = new Goldilocks::Element[starkInfo.boundaries.size() * NExtended]; - - for(uint64_t i = 0; i < starkInfo.boundaries.size(); ++i) { - Boundary boundary = starkInfo.boundaries[i]; - if(boundary.name == "everyRow") { - buildZHInv(starkInfo); - } else if(boundary.name == "firstRow") { - buildOneRowZerofierInv(starkInfo, i, 0); - } else if(boundary.name == "lastRow") { - buildOneRowZerofierInv(starkInfo, i, N); - } else if(boundary.name == "everyRow") { - buildFrameZerofierInv(starkInfo, i, boundary.offsetMin, boundary.offsetMax); - } - } + uint64_t acc = getNumNodes(starkInfo); + return 16 + (NExtended * starkInfo.nConstants) * sizeof(Goldilocks::Element) + acc * sizeof(RawFr::Element); } - void computeConnectionsX(StarkInfo& starkInfo) { - uint64_t N = 1 << starkInfo.starkStruct.nBits; + uint64_t getConstTreeSizeGL(StarkInfo& starkInfo) + { uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; - x_n = new Goldilocks::Element[N]; - Goldilocks::Element xx = Goldilocks::one(); - for (uint64_t i = 0; i < N; i++) - { - x_n[i] = xx; - Goldilocks::mul(xx, xx, Goldilocks::w(starkInfo.starkStruct.nBits)); - } - xx = Goldilocks::shift(); - x_2ns = new Goldilocks::Element[NExtended]; - for (uint64_t i = 0; i < NExtended; i++) - { - x_2ns[i] = xx; - Goldilocks::mul(xx, xx, Goldilocks::w(starkInfo.starkStruct.nBitsExt)); - } + uint64_t acc = getNumNodes(starkInfo); + return 2 + (NExtended * starkInfo.nConstants) + acc * HASH_SIZE; } - void computeX(StarkInfo& starkInfo) { + Goldilocks::Element* calculateConstTree(StarkInfo& starkInfo, Goldilocks::Element *pConstPolsAddress, Goldilocks::Element *treeAddress, std::string constTreeFile) { + uint64_t merkleTreeArity = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? starkInfo.starkStruct.merkleTreeArity : 2; + uint64_t merkleTreeCustom = starkInfo.starkStruct.verificationHashType == std::string("BN128") ? starkInfo.starkStruct.merkleTreeCustom : true; uint64_t N = 1 << starkInfo.starkStruct.nBits; - uint64_t extendBits = starkInfo.starkStruct.nBitsExt - starkInfo.starkStruct.nBits; - x = new Goldilocks::Element[N << extendBits]; - x[0] = Goldilocks::shift(); - for (uint64_t k = 1; k < (N << extendBits); k++) - { - x[k] = x[k - 1] * Goldilocks::w(starkInfo.starkStruct.nBits + extendBits); - } - - S = new Goldilocks::Element[starkInfo.qDeg]; - Goldilocks::Element shiftIn = Goldilocks::exp(Goldilocks::inv(Goldilocks::shift()), N); - S[0] = Goldilocks::one(); - for(uint64_t i = 1; i < starkInfo.qDeg; i++) { - S[i] = Goldilocks::mul(S[i - 1], shiftIn); - } - } - - void buildZHInv(StarkInfo& starkInfo) - { - uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; - uint64_t extendBits = starkInfo.starkStruct.nBitsExt - starkInfo.starkStruct.nBits; - uint64_t extend = (1 << extendBits); - - Goldilocks::Element w = Goldilocks::one(); - Goldilocks::Element sn = Goldilocks::shift(); - - for (uint64_t i = 0; i < starkInfo.starkStruct.nBits; i++) Goldilocks::square(sn, sn); - - for (uint64_t i=0; i constPolsUsed(setupCtx.starkInfo.constPolsMap.size(), false); std::vector cmPolsUsed(setupCtx.starkInfo.cmPolsMap.size(), false); @@ -107,12 +107,12 @@ class ExpressionsAvx : public ExpressionsCtx { if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.cExpId)) { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT[j] = setupCtx.constPols.x_2ns[row + j]; + bufferT[j] = setupCtx.proverHelpers.x_2ns[row + j]; } Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); for(uint64_t d = 0; d < setupCtx.starkInfo.boundaries.size(); ++d) { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT[j] = setupCtx.constPols.zi[row + j + d*domainSize]; + bufferT[j] = setupCtx.proverHelpers.zi[row + j + d*domainSize]; } Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + 1 + d], &bufferT[0]); } @@ -127,7 +127,7 @@ class ExpressionsAvx : public ExpressionsCtx { } } else { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT[j] = setupCtx.constPols.x_n[row + j]; + bufferT[j] = setupCtx.proverHelpers.x_n[row + j]; } Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); } diff --git a/pil2-stark/src/starkpil/expressions_avx512.hpp b/pil2-stark/src/starkpil/expressions_avx512.hpp index fe53391b..2ea3bf7d 100644 --- a/pil2-stark/src/starkpil/expressions_avx512.hpp +++ b/pil2-stark/src/starkpil/expressions_avx512.hpp @@ -52,7 +52,7 @@ class ExpressionsAvx512 : public ExpressionsCtx { nextStrides[i] = opening * extend; } - Goldilocks::Element *constPols = domainExtended ? setupCtx.constPols.pConstPolsAddressExtended : setupCtx.constPols.pConstPolsAddress; + Goldilocks::Element *constPols = domainExtended ? ¶ms.pConstPolsExtendedTreeAddress[2] : params.pConstPolsAddress; std::vector constPolsUsed(setupCtx.starkInfo.constPolsMap.size(), false); std::vector cmPolsUsed(setupCtx.starkInfo.cmPolsMap.size(), false); @@ -107,12 +107,12 @@ class ExpressionsAvx512 : public ExpressionsCtx { if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.cExpId)) { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT[j] = setupCtx.constPols.x_2ns[row + j]; + bufferT[j] = setupCtx.proverHelpers.x_2ns[row + j]; } Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); for(uint64_t d = 0; d < setupCtx.starkInfo.boundaries.size(); ++d) { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT[j] = setupCtx.constPols.zi[row + j + d*domainSize]; + bufferT[j] = setupCtx.proverHelpers.zi[row + j + d*domainSize]; } Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + 1 + d], &bufferT[0]); } @@ -127,7 +127,7 @@ class ExpressionsAvx512 : public ExpressionsCtx { } } else { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT[j] = setupCtx.constPols.x_n[row + j]; + bufferT[j] = setupCtx.proverHelpers.x_n[row + j]; } Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); } diff --git a/pil2-stark/src/starkpil/expressions_pack.hpp b/pil2-stark/src/starkpil/expressions_pack.hpp index b80c1352..2535b2b7 100644 --- a/pil2-stark/src/starkpil/expressions_pack.hpp +++ b/pil2-stark/src/starkpil/expressions_pack.hpp @@ -50,7 +50,7 @@ class ExpressionsPack : public ExpressionsCtx { nextStrides[i] = opening * extend; } - Goldilocks::Element *constPols = domainExtended ? setupCtx.constPols.pConstPolsAddressExtended : setupCtx.constPols.pConstPolsAddress; + Goldilocks::Element *constPols = domainExtended ? ¶ms.pConstPolsExtendedTreeAddress[2] : params.pConstPolsAddress; std::vector constPolsUsed(setupCtx.starkInfo.constPolsMap.size(), false); std::vector cmPolsUsed(setupCtx.starkInfo.cmPolsMap.size(), false); @@ -102,11 +102,11 @@ class ExpressionsPack : public ExpressionsCtx { if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.cExpId)) { for(uint64_t d = 0; d < setupCtx.starkInfo.boundaries.size(); ++d) { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + d + 1)*nrowsPack + j] = setupCtx.constPols.zi[row + j + d*domainSize]; + bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + d + 1)*nrowsPack + j] = setupCtx.proverHelpers.zi[row + j + d*domainSize]; } } for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings])*nrowsPack + j] = setupCtx.constPols.x_2ns[row + j]; + bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings])*nrowsPack + j] = setupCtx.proverHelpers.x_2ns[row + j]; } } else if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.friExpId)) { for(uint64_t d = 0; d < setupCtx.starkInfo.openingPoints.size(); ++d) { @@ -118,7 +118,7 @@ class ExpressionsPack : public ExpressionsCtx { } } else { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings])*nrowsPack + j] = setupCtx.constPols.x[row + j]; + bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings])*nrowsPack + j] = setupCtx.proverHelpers.x[row + j]; } } } diff --git a/pil2-stark/src/starkpil/gen_recursive_proof.hpp b/pil2-stark/src/starkpil/gen_recursive_proof.hpp index f10a72de..f2915a54 100644 --- a/pil2-stark/src/starkpil/gen_recursive_proof.hpp +++ b/pil2-stark/src/starkpil/gen_recursive_proof.hpp @@ -1,14 +1,14 @@ #include "starks.hpp" template -void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldilocks::Element *publicInputs, std::string proofFile) { +void *genRecursiveProof(SetupCtx& setupCtx, json& globalInfo, uint64_t airgroupId, Goldilocks::Element *pAddress, Goldilocks::Element *pConstPols, Goldilocks::Element *pConstTree, Goldilocks::Element *publicInputs, std::string proofFile) { TimerStart(STARK_PROOF); FRIProof proof(setupCtx.starkInfo); using TranscriptType = std::conditional_t::value, TranscriptGL, TranscriptBN128>; - Starks starks(setupCtx); + Starks starks(setupCtx, pConstTree); #ifdef __AVX512__ ExpressionsAvx512 expressionsCtx(setupCtx); @@ -36,6 +36,8 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi airgroupValues : airgroupValues, evals : evals, xDivXSub : nullptr, + pConstPolsAddress: pConstPols, + pConstPolsExtendedTreeAddress: pConstTree, }; for (uint64_t i = 0; i < setupCtx.starkInfo.mapSectionsN["cm1"]; ++i) @@ -51,7 +53,6 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi ElementType verkey[nFieldElements]; starks.treesGL[setupCtx.starkInfo.nStages + 1]->getRoot(verkey); starks.addTranscript(transcript, &verkey[0], nFieldElements); - if(setupCtx.starkInfo.nPublics > 0) { if(!setupCtx.starkInfo.starkStruct.hashCommits) { starks.addTranscriptGL(transcript, &publicInputs[0], setupCtx.starkInfo.nPublics); @@ -171,7 +172,7 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi } } - starks.calculateQuotientPolynomial(params); + expressionsCtx.calculateExpression(params, ¶ms.pols[setupCtx.starkInfo.mapOffsets[std::make_pair("q", true)]], setupCtx.starkInfo.cExpId); for(uint64_t i = 0; i < setupCtx.starkInfo.cmPolsMap.size(); i++) { if(setupCtx.starkInfo.cmPolsMap[i].stage == setupCtx.starkInfo.nStages + 1) { @@ -199,7 +200,7 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi Goldilocks::Element* LEv = &pAddress[setupCtx.starkInfo.mapOffsets[make_pair("LEv", true)]]; starks.computeLEv(xiChallenge, LEv); - starks.computeEvals(pAddress,LEv, evals, proof); + starks.computeEvals(params ,LEv, proof); if(!setupCtx.starkInfo.starkStruct.hashCommits) { starks.addTranscriptGL(transcript, evals, setupCtx.starkInfo.evMap.size() * FIELD_EXTENSION); @@ -290,5 +291,7 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi TimerStopAndLog(STARK_PROOF); + zkin = publics2zkin(zkin, publicInputs, globalInfo, airgroupId); + return (void *) new nlohmann::ordered_json(zkin); } diff --git a/pil2-stark/src/starkpil/hints.hpp b/pil2-stark/src/starkpil/hints.hpp index 2dc2016d..957745ff 100644 --- a/pil2-stark/src/starkpil/hints.hpp +++ b/pil2-stark/src/starkpil/hints.hpp @@ -44,8 +44,7 @@ void getPolynomial(SetupCtx& setupCtx, Goldilocks::Element *buffer, Goldilocks:: uint64_t nCols = setupCtx.starkInfo.mapSectionsN[stage]; uint64_t offset = setupCtx.starkInfo.mapOffsets[std::make_pair(stage, domainExtended)]; offset += polInfo.stagePos; - Goldilocks::Element *pols = committed ? buffer : domainExtended ? setupCtx.constPols.pConstPolsAddressExtended : setupCtx.constPols.pConstPolsAddress; - Polinomial pol = Polinomial(&pols[offset], deg, dim, nCols, std::to_string(idPol)); + Polinomial pol = Polinomial(&buffer[offset], deg, dim, nCols, std::to_string(idPol)); #pragma omp parallel for for(uint64_t j = 0; j < deg; ++j) { std::memcpy(&dest[j*dim], pol[j], dim * sizeof(Goldilocks::Element)); @@ -112,11 +111,11 @@ void printRow(SetupCtx& setupCtx, Goldilocks::Element* buffer, uint64_t stage, u cout << "}" << endl; } -void printColById(SetupCtx& setupCtx, Goldilocks::Element* buffer, bool committed, uint64_t polId, uint64_t firstPrintValue = 0, uint64_t lastPrintValue = 0) +void printColById(SetupCtx& setupCtx, StepsParams ¶ms, bool committed, uint64_t polId, uint64_t firstPrintValue = 0, uint64_t lastPrintValue = 0) { uint64_t N = 1 << setupCtx.starkInfo.starkStruct.nBits; PolMap polInfo = committed ? setupCtx.starkInfo.cmPolsMap[polId] : setupCtx.starkInfo.constPolsMap[polId]; - Goldilocks::Element *pols = committed ? buffer : setupCtx.constPols.pConstPolsAddress; + Goldilocks::Element *pols = committed ? params.pols : params.pConstPolsAddress; Polinomial p; setupCtx.starkInfo.getPolynomial(p, pols, committed, polId, false); @@ -161,7 +160,7 @@ HintFieldInfo printByName(SetupCtx& setupCtx, StepsParams& params, string name, if(!lengths_match) continue; } if(cmPol.name == name) { - printColById(setupCtx, params.pols, true, i, firstPrintValue, lastPrintValue); + printColById(setupCtx, params, true, i, firstPrintValue, lastPrintValue); if(returnValues) { hintFieldInfo.size = cmPol.dim * N; hintFieldInfo.values = new Goldilocks::Element[hintFieldInfo.size]; @@ -187,7 +186,7 @@ HintFieldInfo printByName(SetupCtx& setupCtx, StepsParams& params, string name, if(!lengths_match) continue; } if(constPol.name == name) { - printColById(setupCtx, params.pols, false, i, firstPrintValue, lastPrintValue); + printColById(setupCtx, params, false, i, firstPrintValue, lastPrintValue); if(returnValues) { hintFieldInfo.size = N; hintFieldInfo.values = new Goldilocks::Element[hintFieldInfo.size]; diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp index 4451afad..c84c298a 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.cpp @@ -246,4 +246,13 @@ void MerkleTreeBN128::merkelize() cursor = cursorNext; cursorNext = &cursor[nextN256 * arity]; } +} + +void MerkleTreeBN128::writeFile(std::string constTreeFile) { + std::ofstream fw(constTreeFile.c_str(), std::fstream::out | std::fstream::binary); + fw.write((const char *)&(width), sizeof(width)); + fw.write((const char *)&(height), sizeof(height)); + fw.write((const char *)source, width * height * sizeof(Goldilocks::Element)); + fw.write((const char *)nodes, numNodes * sizeof(RawFr::Element)); + fw.close(); } \ No newline at end of file diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp index f1dc514a..0edece6d 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeBN128.hpp @@ -49,5 +49,7 @@ class MerkleTreeBN128 void getGroupProof(RawFr::Element *proof, uint64_t idx); void merkelize(); + + void writeFile(std::string constTreeFile); }; #endif \ No newline at end of file diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp index 665b8920..a12b8766 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.cpp @@ -176,4 +176,20 @@ void MerkleTreeGL::merkelize() #else PoseidonGoldilocks::merkletree_seq(nodes, source, width, height); #endif +} + +void MerkleTreeGL::writeFile(std::string constTreeFile) +{ + ofstream fw(constTreeFile.c_str(), std::fstream::out | std::fstream::binary); + fw.write((const char *)&(width), sizeof(uint64_t)); + fw.write((const char *)&(height), sizeof(uint64_t)); + // fw.write((const char *)source, width * height * sizeof(Goldilocks::Element)); + // fw.write((const char *)nodes, numNodes * sizeof(Goldilocks::Element)); + // fw.close(); + + uint64_t sourceOffset = sizeof(uint64_t) * 2; + uint64_t nodesOffset = sourceOffset + width * height * sizeof(Goldilocks::Element); + fw.close(); + writeFileParallel(constTreeFile, source, width * height * sizeof(Goldilocks::Element), sourceOffset); + writeFileParallel(constTreeFile, nodes, numNodes * sizeof(Goldilocks::Element), nodesOffset); } \ No newline at end of file diff --git a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp index f46be681..0b70b054 100644 --- a/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp +++ b/pil2-stark/src/starkpil/merkleTree/merkleTreeGL.hpp @@ -49,6 +49,8 @@ class MerkleTreeGL bool verifyGroupProof(Goldilocks::Element* root, std::vector> &mp, uint64_t idx, std::vector> &v); void merkelize(); + + void writeFile(std::string file); }; #endif \ No newline at end of file diff --git a/pil2-stark/src/starkpil/proof2zkinStark.cpp b/pil2-stark/src/starkpil/proof2zkinStark.cpp index 66995416..385b4875 100644 --- a/pil2-stark/src/starkpil/proof2zkinStark.cpp +++ b/pil2-stark/src/starkpil/proof2zkinStark.cpp @@ -273,7 +273,10 @@ ordered_json challenges2zkin(json& globalInfo, Goldilocks::Element* challenges) return challengesJson; } -void *publics2zkin(ordered_json &zkin, Goldilocks::Element* publics, json& globalInfo, uint64_t airgroupId, bool isAggregated) { +ordered_json publics2zkin(ordered_json &zkin_, Goldilocks::Element* publics, json& globalInfo, uint64_t airgroupId) { + ordered_json zkin = ordered_json::object(); + zkin = zkin_; + uint64_t p = 0; zkin["sv_circuitType"] = Goldilocks::toString(publics[p++]); if(globalInfo["aggTypes"][airgroupId].size() > 0) { @@ -317,47 +320,48 @@ void *publics2zkin(ordered_json &zkin, Goldilocks::Element* publics, json& globa zkin["sv_finalPolHash"][j] = Goldilocks::toString(publics[p++]); } - if(!isAggregated) { - if(uint64_t(globalInfo["nPublics"]) > 0) { - zkin["publics"] = ordered_json::array(); - for(uint64_t i = 0; i < uint64_t(globalInfo["nPublics"]); ++i) { - zkin["publics"][i] = Goldilocks::toString(publics[p++]); - } + if(uint64_t(globalInfo["nPublics"]) > 0) { + zkin["publics"] = ordered_json::array(); + for(uint64_t i = 0; i < uint64_t(globalInfo["nPublics"]); ++i) { + zkin["publics"][i] = Goldilocks::toString(publics[p++]); } + } - zkin["challenges"] = ordered_json::array(); - - uint64_t nChallenges = 0; - for(uint64_t i = 0; i < globalInfo["numChallenges"].size(); ++i) { - nChallenges += uint64_t(globalInfo["numChallenges"][i]); - } - nChallenges += 4; - for(uint64_t i = 0; i < nChallenges; ++i) { - zkin["challenges"][i] = ordered_json::array(); - for(uint64_t k = 0; k < FIELD_EXTENSION; ++k) { - zkin["challenges"][i][k] = Goldilocks::toString(publics[p++]); - } + zkin["challenges"] = ordered_json::array(); + + uint64_t nChallenges = 0; + for(uint64_t i = 0; i < globalInfo["numChallenges"].size(); ++i) { + nChallenges += uint64_t(globalInfo["numChallenges"][i]); + } + nChallenges += 4; + for(uint64_t i = 0; i < nChallenges; ++i) { + zkin["challenges"][i] = ordered_json::array(); + for(uint64_t k = 0; k < FIELD_EXTENSION; ++k) { + zkin["challenges"][i][k] = Goldilocks::toString(publics[p++]); } + } - zkin["challengesFRISteps"] = ordered_json::array(); - for(uint64_t i = 0; i < globalInfo["stepsFRI"].size() + 1; ++i) { - zkin["challengesFRISteps"][i] = ordered_json::array(); - for(uint64_t k = 0; k < FIELD_EXTENSION; ++k) { - zkin["challengesFRISteps"][i][k] = Goldilocks::toString(publics[p++]); - } + zkin["challengesFRISteps"] = ordered_json::array(); + for(uint64_t i = 0; i < globalInfo["stepsFRI"].size() + 1; ++i) { + zkin["challengesFRISteps"][i] = ordered_json::array(); + for(uint64_t k = 0; k < FIELD_EXTENSION; ++k) { + zkin["challengesFRISteps"][i][k] = Goldilocks::toString(publics[p++]); } } + - return (void *)new ordered_json(zkin); + return zkin; } -void *addRecursive2VerKey(ordered_json &zkin, Goldilocks::Element* recursive2VerKey) { - zkin["rootCAgg"] = ordered_json::array(); +ordered_json addRecursive2VerKey(ordered_json &zkin, Goldilocks::Element* recursive2VerKey) { + ordered_json zkinUpdated = ordered_json::object(); + zkinUpdated = zkin; + zkinUpdated["rootCAgg"] = ordered_json::array(); for(uint64_t i = 0; i < 4; ++i) { - zkin["rootCAgg"][i] = Goldilocks::toString(recursive2VerKey[i]); + zkinUpdated["rootCAgg"][i] = Goldilocks::toString(recursive2VerKey[i]); } - return (void *)new ordered_json(zkin); + return zkinUpdated; } ordered_json joinzkinfinal(json& globalInfo, Goldilocks::Element* publics, Goldilocks::Element* proofValues, Goldilocks::Element* challenges, void **zkin_vec, void **starkInfo_vec) { diff --git a/pil2-stark/src/starkpil/proof2zkinStark.hpp b/pil2-stark/src/starkpil/proof2zkinStark.hpp index ba8ac920..06b24a48 100644 --- a/pil2-stark/src/starkpil/proof2zkinStark.hpp +++ b/pil2-stark/src/starkpil/proof2zkinStark.hpp @@ -15,7 +15,7 @@ ordered_json joinzkin(ordered_json &zkin1, ordered_json &zkin2, ordered_json &ve ordered_json joinzkinfinal(json& globalInfo, Goldilocks::Element* publics, Goldilocks::Element *proofValues, Goldilocks::Element* challenges, void **zkin_vec, void **starkInfo_vec); ordered_json joinzkinrecursive2(json& globalInfo, uint64_t airgroupId, Goldilocks::Element* publics, Goldilocks::Element* challenges, ordered_json &zkin1, ordered_json &zkin2, StarkInfo &starkInfo); -void *publics2zkin(ordered_json &zkin, Goldilocks::Element* publics, json& globalInfo, uint64_t airgroupId, bool isAggregated); -void *addRecursive2VerKey(ordered_json &zkin, Goldilocks::Element* recursive2VerKey); +ordered_json publics2zkin(ordered_json &zkin, Goldilocks::Element* publics, json& globalInfo, uint64_t airgroupId); +ordered_json addRecursive2VerKey(ordered_json &zkin, Goldilocks::Element* recursive2VerKey); #endif \ No newline at end of file diff --git a/pil2-stark/src/starkpil/setup_ctx.hpp b/pil2-stark/src/starkpil/setup_ctx.hpp index a646e7da..c0234537 100644 --- a/pil2-stark/src/starkpil/setup_ctx.hpp +++ b/pil2-stark/src/starkpil/setup_ctx.hpp @@ -5,16 +5,170 @@ #include "const_pols.hpp" #include "expressions_bin.hpp" +class ProverHelpers { + public: + Goldilocks::Element *zi = nullptr; + Goldilocks::Element *S = nullptr; + Goldilocks::Element *x = nullptr; + Goldilocks::Element *x_n = nullptr; // Needed for PIL1 compatibility + Goldilocks::Element *x_2ns = nullptr; // Needed for PIL1 compatibility + + ProverHelpers(StarkInfo &starkInfo) { + computeZerofier(starkInfo); + + computeX(starkInfo); + + computeConnectionsX(starkInfo); // Needed for PIL1 compatibility + } + + void computeZerofier(StarkInfo& starkInfo) { + uint64_t N = 1 << starkInfo.starkStruct.nBits; + uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; + zi = new Goldilocks::Element[starkInfo.boundaries.size() * NExtended]; + + for(uint64_t i = 0; i < starkInfo.boundaries.size(); ++i) { + Boundary boundary = starkInfo.boundaries[i]; + if(boundary.name == "everyRow") { + buildZHInv(starkInfo); + } else if(boundary.name == "firstRow") { + buildOneRowZerofierInv(starkInfo, i, 0); + } else if(boundary.name == "lastRow") { + buildOneRowZerofierInv(starkInfo, i, N); + } else if(boundary.name == "everyRow") { + buildFrameZerofierInv(starkInfo, i, boundary.offsetMin, boundary.offsetMax); + } + } + } + + void computeConnectionsX(StarkInfo& starkInfo) { + uint64_t N = 1 << starkInfo.starkStruct.nBits; + uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; + x_n = new Goldilocks::Element[N]; + Goldilocks::Element xx = Goldilocks::one(); + for (uint64_t i = 0; i < N; i++) + { + x_n[i] = xx; + Goldilocks::mul(xx, xx, Goldilocks::w(starkInfo.starkStruct.nBits)); + } + xx = Goldilocks::shift(); + x_2ns = new Goldilocks::Element[NExtended]; + for (uint64_t i = 0; i < NExtended; i++) + { + x_2ns[i] = xx; + Goldilocks::mul(xx, xx, Goldilocks::w(starkInfo.starkStruct.nBitsExt)); + } + } + + void computeX(StarkInfo& starkInfo) { + uint64_t N = 1 << starkInfo.starkStruct.nBits; + uint64_t extendBits = starkInfo.starkStruct.nBitsExt - starkInfo.starkStruct.nBits; + x = new Goldilocks::Element[N << extendBits]; + x[0] = Goldilocks::shift(); + for (uint64_t k = 1; k < (N << extendBits); k++) + { + x[k] = x[k - 1] * Goldilocks::w(starkInfo.starkStruct.nBits + extendBits); + } + + S = new Goldilocks::Element[starkInfo.qDeg]; + Goldilocks::Element shiftIn = Goldilocks::exp(Goldilocks::inv(Goldilocks::shift()), N); + S[0] = Goldilocks::one(); + for(uint64_t i = 1; i < starkInfo.qDeg; i++) { + S[i] = Goldilocks::mul(S[i - 1], shiftIn); + } + } + + void buildZHInv(StarkInfo& starkInfo) + { + uint64_t NExtended = 1 << starkInfo.starkStruct.nBitsExt; + uint64_t extendBits = starkInfo.starkStruct.nBitsExt - starkInfo.starkStruct.nBits; + uint64_t extend = (1 << extendBits); + + Goldilocks::Element w = Goldilocks::one(); + Goldilocks::Element sn = Goldilocks::shift(); + + for (uint64_t i = 0; i < starkInfo.starkStruct.nBits; i++) Goldilocks::square(sn, sn); + + for (uint64_t i=0; i::computeQ(uint64_t step, Goldilocks::Element *buffer, F #pragma omp parallel for for(uint64_t i = 0; i < N; i++) { - Goldilocks3::mul((Goldilocks3::Element &)cmQ[(i * setupCtx.starkInfo.qDeg + p) * FIELD_EXTENSION], (Goldilocks3::Element &)buffer[setupCtx.starkInfo.mapOffsets[std::make_pair("q", true)] + (p * N + i) * FIELD_EXTENSION], setupCtx.constPols.S[p]); + Goldilocks3::mul((Goldilocks3::Element &)cmQ[(i * setupCtx.starkInfo.qDeg + p) * FIELD_EXTENSION], (Goldilocks3::Element &)buffer[setupCtx.starkInfo.mapOffsets[std::make_pair("q", true)] + (p * N + i) * FIELD_EXTENSION], setupCtx.proverHelpers.S[p]); } } @@ -128,10 +128,10 @@ void Starks::computeLEv(Goldilocks::Element *xiChallenge, Goldilock template -void Starks::computeEvals(Goldilocks::Element *buffer, Goldilocks::Element *LEv, Goldilocks::Element *evals, FRIProof &proof) +void Starks::computeEvals(StepsParams ¶ms, Goldilocks::Element *LEv, FRIProof &proof) { - evmap(buffer, evals, LEv); - proof.proof.setEvals(evals); + evmap(params, LEv); + proof.proof.setEvals(params.evals); } template @@ -162,7 +162,7 @@ void Starks::calculateXDivXSub(Goldilocks::Element *xiChallenge, Go #pragma omp parallel for for (uint64_t k = 0; k < NExtended; k++) { - Goldilocks3::sub((Goldilocks3::Element &)(xDivXSub[(k + i * NExtended) * FIELD_EXTENSION]), setupCtx.constPols.x[k], (Goldilocks3::Element &)(xis[i * FIELD_EXTENSION])); + Goldilocks3::sub((Goldilocks3::Element &)(xDivXSub[(k + i * NExtended) * FIELD_EXTENSION]), setupCtx.proverHelpers.x[k], (Goldilocks3::Element &)(xis[i * FIELD_EXTENSION])); } } @@ -174,13 +174,13 @@ void Starks::calculateXDivXSub(Goldilocks::Element *xiChallenge, Go #pragma omp parallel for for (uint64_t k = 0; k < NExtended; k++) { - Goldilocks3::mul((Goldilocks3::Element &)(xDivXSub[(k + i * NExtended) * FIELD_EXTENSION]), (Goldilocks3::Element &)(xDivXSub[(k + i * NExtended) * FIELD_EXTENSION]), setupCtx.constPols.x[k]); + Goldilocks3::mul((Goldilocks3::Element &)(xDivXSub[(k + i * NExtended) * FIELD_EXTENSION]), (Goldilocks3::Element &)(xDivXSub[(k + i * NExtended) * FIELD_EXTENSION]), setupCtx.proverHelpers.x[k]); } } } template -void Starks::evmap(Goldilocks::Element *buffer, Goldilocks::Element *evals, Goldilocks::Element *LEv) +void Starks::evmap(StepsParams& params, Goldilocks::Element *LEv) { uint64_t extendBits = setupCtx.starkInfo.starkStruct.nBitsExt - setupCtx.starkInfo.starkStruct.nBits; u_int64_t size_eval = setupCtx.starkInfo.evMap.size(); @@ -189,7 +189,7 @@ void Starks::evmap(Goldilocks::Element *buffer, Goldilocks::Element int num_threads = omp_get_max_threads(); int size_thread = size_eval * FIELD_EXTENSION; - Goldilocks::Element *evals_acc = &buffer[setupCtx.starkInfo.mapOffsets[std::make_pair("evals", true)]]; + Goldilocks::Element *evals_acc = ¶ms.pols[setupCtx.starkInfo.mapOffsets[std::make_pair("evals", true)]]; memset(&evals_acc[0], 0, num_threads * size_thread * sizeof(Goldilocks::Element)); Polinomial *ordPols = new Polinomial[size_eval]; @@ -198,7 +198,7 @@ void Starks::evmap(Goldilocks::Element *buffer, Goldilocks::Element { EvMap ev = setupCtx.starkInfo.evMap[i]; bool committed = ev.type == EvMap::eType::cm ? true : false; - Goldilocks::Element *pols = committed ? buffer : setupCtx.constPols.pConstPolsAddressExtended; + Goldilocks::Element *pols = committed ? params.pols : ¶ms.pConstPolsExtendedTreeAddress[2]; setupCtx.starkInfo.getPolynomial(ordPols[i], pols, committed, ev.id, true); } @@ -237,19 +237,12 @@ void Starks::evmap(Goldilocks::Element *buffer, Goldilocks::Element { Goldilocks3::add(sum, sum, (Goldilocks3::Element &)(evals_acc[k * size_thread + i * FIELD_EXTENSION])); } - std::memcpy((Goldilocks3::Element &)(evals[i * FIELD_EXTENSION]), sum, FIELD_EXTENSION * sizeof(Goldilocks::Element)); + std::memcpy((Goldilocks3::Element &)(params.evals[i * FIELD_EXTENSION]), sum, FIELD_EXTENSION * sizeof(Goldilocks::Element)); } } delete[] ordPols; } -template -void Starks::setConstTree(ConstPols &constPols) -{ - setupCtx.constPols = constPols; - treesGL[setupCtx.starkInfo.nStages + 1] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, (Goldilocks::Element *)setupCtx.constPols.pConstTreeAddress); -} - template void Starks::getChallenge(TranscriptType &transcript, Goldilocks::Element &challenge) { @@ -315,11 +308,6 @@ void Starks::calculateQuotientPolynomial(StepsParams ¶ms) { #else ExpressionsPack expressionsCtx(setupCtx); #endif - if(setupCtx.constPols.pConstTreeAddress == nullptr) { - zklog.error("Const tree is not set"); - exitProcess(); - exit(-1); - } expressionsCtx.calculateExpression(params, ¶ms.pols[setupCtx.starkInfo.mapOffsets[std::make_pair("q", true)]], setupCtx.starkInfo.cExpId); } diff --git a/pil2-stark/src/starkpil/starks.hpp b/pil2-stark/src/starkpil/starks.hpp index 83738756..6680a694 100644 --- a/pil2-stark/src/starkpil/starks.hpp +++ b/pil2-stark/src/starkpil/starks.hpp @@ -32,10 +32,10 @@ class Starks MerkleTreeType **treesFRI; public: - Starks(SetupCtx& setupCtx_) : setupCtx(setupCtx_) + Starks(SetupCtx& setupCtx_, Goldilocks::Element *pConstPolsExtendedTreeAddress) : setupCtx(setupCtx_) { treesGL = new MerkleTreeType*[setupCtx.starkInfo.nStages + 2]; - if(setupCtx.constPols.pConstTreeAddress != nullptr) treesGL[setupCtx.starkInfo.nStages + 1] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, (Goldilocks::Element *)setupCtx.constPols.pConstTreeAddress); + treesGL[setupCtx.starkInfo.nStages + 1] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, pConstPolsExtendedTreeAddress); for (uint64_t i = 0; i < setupCtx.starkInfo.nStages + 1; i++) { std::string section = "cm" + to_string(i + 1); @@ -77,7 +77,7 @@ class Starks void calculateFRIPolynomial(StepsParams& params); void computeLEv(Goldilocks::Element *xiChallenge, Goldilocks::Element *LEv); - void computeEvals(Goldilocks::Element *buffer, Goldilocks::Element *LEv, Goldilocks::Element *evals, FRIProof &proof); + void computeEvals(StepsParams ¶ms, Goldilocks::Element *LEv, FRIProof &proof); void calculateXDivXSub(Goldilocks::Element *xiChallenge, Goldilocks::Element *xDivXSub); @@ -87,12 +87,10 @@ class Starks void addTranscript(TranscriptType &transcript, ElementType* buffer, uint64_t nElements); void getChallenge(TranscriptType &transcript, Goldilocks::Element& challenge); - void setConstTree(ConstPols &constPols); - // Following function are created to be used by the ffi interface void ffi_treesGL_get_root(uint64_t index, ElementType *dst); - void evmap(Goldilocks::Element *buffer, Goldilocks::Element *evals, Goldilocks::Element *LEv); + void evmap(StepsParams& params, Goldilocks::Element *LEv); }; template class Starks; diff --git a/pil2-stark/src/starkpil/steps.hpp b/pil2-stark/src/starkpil/steps.hpp index a172aaa1..c91a0756 100644 --- a/pil2-stark/src/starkpil/steps.hpp +++ b/pil2-stark/src/starkpil/steps.hpp @@ -12,6 +12,8 @@ struct StepsParams Goldilocks::Element *airValues; Goldilocks::Element *evals; Goldilocks::Element *xDivXSub; + Goldilocks::Element *pConstPolsAddress; + Goldilocks::Element *pConstPolsExtendedTreeAddress; }; #endif \ No newline at end of file diff --git a/pil2-stark/src/utils/utils.cpp b/pil2-stark/src/utils/utils.cpp index f2e9ee93..f572e8f1 100644 --- a/pil2-stark/src/utils/utils.cpp +++ b/pil2-stark/src/utils/utils.cpp @@ -321,6 +321,44 @@ void *copyFile(const string &fileName, uint64_t size) return mapFileInternal(fileName, size, false, false); } +void loadFileParallel(void* buffer, const string &fileName, uint64_t size) { + + // Check file size + struct stat sb; + if (lstat(fileName.c_str(), &sb) == -1) { + zklog.error("loadFileParallel() failed calling lstat() of file " + fileName); + exitProcess(); + } + if ((uint64_t)sb.st_size != size) { + zklog.error("loadFileParallel() found size of file " + fileName + " to be " + to_string(sb.st_size) + " B instead of " + to_string(size) + " B"); + exitProcess(); + } + + // Determine the number of chunks and the size of each chunk + size_t numChunks = 8; //omp_get_max_threads()/2; + if(numChunks == 0 ) numChunks = 1; + size_t chunkSize = size / numChunks; + size_t remainder = size - numChunks*chunkSize; + + #pragma omp parallel for num_threads(numChunks) + for(size_t i=0; i( pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, provers: Vec>>, mut witness_lib: Box>, -) { +) -> Result<(), Box> { const MY_NAME: &str = "CstrVrfy"; log::info!("{}: --> Checking constraints", MY_NAME); @@ -24,7 +24,7 @@ pub fn verify_constraints_proof( let mut constraints = Vec::new(); for prover in provers.iter() { - let constraints_prover_info = prover.verify_constraints(pctx.clone()); + let constraints_prover_info = prover.verify_constraints(sctx.clone(), pctx.clone()); constraints.push(constraints_prover_info); } @@ -129,7 +129,13 @@ pub fn verify_constraints_proof( if valid_constraints && global_constraints_verified { log::info!("{}: ··· {}", MY_NAME, "\u{2713} All constraints were verified".bright_green().bold()); + Ok(()) } else { log::info!("{}: ··· {}", MY_NAME, "\u{2717} Not all constraints were verified.".bright_red().bold()); + Err(Box::new(std::io::Error::new( + // <-- Return a boxed error + std::io::ErrorKind::Other, + format!("{}: Not all constraints were verified.", MY_NAME), + ))) } } diff --git a/proofman/src/global_constraints.rs b/proofman/src/global_constraints.rs index 97f7d918..955eedd7 100644 --- a/proofman/src/global_constraints.rs +++ b/proofman/src/global_constraints.rs @@ -58,7 +58,7 @@ pub fn aggregate_airgroupvals(pctx: Arc>) -> Vec> { airgroupvalues } -pub fn verify_global_constraints_proof(pctx: Arc>, sctx: Arc) -> bool { +pub fn verify_global_constraints_proof(pctx: Arc>, sctx: Arc>) -> bool { const MY_NAME: &str = "GlCstVfy"; log::info!("{}: --> Checking global constraints", MY_NAME); @@ -104,7 +104,7 @@ pub fn verify_global_constraints_proof(pctx: Arc>, sctx: A pub fn get_hint_field_gc( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -146,7 +146,7 @@ pub fn get_hint_field_gc( pub fn get_hint_field_gc_a( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -194,7 +194,7 @@ pub fn get_hint_field_gc_a( pub fn get_hint_field_gc_m( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, print_expression: bool, @@ -247,7 +247,7 @@ pub fn get_hint_field_gc_m( pub fn set_hint_field( pctx: Arc>, - sctx: Arc, + sctx: Arc>, hint_id: u64, hint_field_name: &str, value: HintFieldOutput, diff --git a/proofman/src/proofman.rs b/proofman/src/proofman.rs index 87358048..dd40d875 100644 --- a/proofman/src/proofman.rs +++ b/proofman/src/proofman.rs @@ -18,11 +18,11 @@ use crate::{WitnessLibInitFn, WitnessLibrary}; use crate::verify_constraints_proof; use crate::generate_recursion_proof; -use proofman_common::{ExecutionCtx, ProofCtx, ProofOptions, ProofType, Prover, SetupCtx}; +use proofman_common::{ExecutionCtx, ProofCtx, ProofOptions, ProofType, Prover, SetupCtx, SetupsVadcop}; use std::os::raw::c_void; -use proofman_util::{timer_start_info, timer_start_debug, timer_stop_and_log_info, timer_stop_and_log_debug}; +use proofman_util::{timer_start_debug, timer_start_info, timer_stop_and_log_debug, timer_stop_and_log_info}; pub struct ProofMan { _phantom: std::marker::PhantomData, @@ -53,7 +53,7 @@ impl ProofMan { )?; let buffer_allocator: Arc = Arc::new(StarkBufferAllocator::new(proving_key_path.clone())); let ectx = ExecutionCtx::builder() - .with_rom_path(rom_path) + .with_rom_path(rom_path.clone()) .with_buffer_allocator(buffer_allocator) .with_verbose_mode(options.verbose_mode) .build(); @@ -63,11 +63,12 @@ impl ProofMan { let witness_lib: Symbol> = unsafe { library.get(b"init_library")? }; - let mut witness_lib = witness_lib(ectx.clone(), public_inputs_path)?; + let mut witness_lib = witness_lib(rom_path, public_inputs_path)?; let pctx = Arc::new(ProofCtx::create_ctx(witness_lib.pilout(), proving_key_path.clone())); - let sctx: Arc = Arc::new(SetupCtx::new(&pctx.global_info, &ProofType::Basic)); + let setups = Arc::new(SetupsVadcop::new(&pctx.global_info, options.aggregation)); + let sctx: Arc> = setups.sctx.clone(); Self::initialize_witness(&mut witness_lib, pctx.clone(), ectx.clone(), sctx.clone()); witness_lib.calculate_witness(1, pctx.clone(), ectx.clone(), sctx.clone()); @@ -80,6 +81,8 @@ impl ProofMan { Self::print_summary(pctx.clone()); } + Self::initialize_setup(setups.clone(), pctx.clone(), ectx.clone(), options.aggregation); + let mut provers: Vec>> = Vec::new(); Self::initialize_provers(sctx.clone(), &mut provers, pctx.clone(), ectx.clone()); @@ -97,7 +100,7 @@ impl ProofMan { witness_lib.calculate_witness(stage, pctx.clone(), ectx.clone(), sctx.clone()); } - Self::calculate_stage(stage, &mut provers, pctx.clone()); + Self::calculate_stage(stage, &mut provers, sctx.clone(), pctx.clone()); if !options.verify_constraints { Self::commit_stage(stage, &mut provers, pctx.clone()); @@ -127,13 +130,12 @@ impl ProofMan { } if options.verify_constraints { - verify_constraints_proof(pctx.clone(), ectx.clone(), sctx.clone(), provers, witness_lib); - return Ok(()); + return verify_constraints_proof(pctx.clone(), ectx.clone(), sctx.clone(), provers, witness_lib); } // Compute Quotient polynomial Self::get_challenges(num_commit_stages + 1, &mut provers, pctx.clone(), &transcript); - Self::calculate_stage(num_commit_stages + 1, &mut provers, pctx.clone()); + Self::calculate_stage(num_commit_stages + 1, &mut provers, sctx.clone(), pctx.clone()); Self::commit_stage(num_commit_stages + 1, &mut provers, pctx.clone()); Self::calculate_challenges( num_commit_stages + 1, @@ -145,7 +147,7 @@ impl ProofMan { ); // Compute openings - Self::opening_stages(&mut provers, pctx.clone(), ectx.clone(), &mut transcript); + Self::opening_stages(&mut provers, pctx.clone(), sctx.clone(), ectx.clone(), &mut transcript); //Generate proves_out let proves_out = Self::finalize_proof(&mut provers, pctx.clone(), output_dir_path.to_string_lossy().as_ref()); @@ -163,6 +165,7 @@ impl ProofMan { let comp_proofs = generate_recursion_proof( &pctx, &ectx, + setups.sctx_compressor.as_ref().unwrap().clone(), &proves_out, &ProofType::Compressor, output_dir_path.clone(), @@ -175,6 +178,7 @@ impl ProofMan { let recursive1_proofs = generate_recursion_proof( &pctx, &ectx, + setups.sctx_recursive1.as_ref().unwrap().clone(), &comp_proofs, &ProofType::Recursive1, output_dir_path.clone(), @@ -185,9 +189,11 @@ impl ProofMan { ectx.dctx.read().unwrap().barrier(); timer_start_info!(GENERATING_RECURSIVE2_PROOFS); + let sctx_recursive2 = setups.sctx_recursive2.clone(); let recursive2_proofs = generate_recursion_proof( &pctx, &ectx, + sctx_recursive2.as_ref().unwrap().clone(), &recursive1_proofs, &ProofType::Recursive2, output_dir_path.clone(), @@ -202,6 +208,7 @@ impl ProofMan { let _final_proof = generate_recursion_proof( &pctx, &ectx, + setups.sctx_final.as_ref().unwrap().clone(), &recursive2_proofs, &ProofType::Final, output_dir_path.clone(), @@ -220,8 +227,8 @@ impl ProofMan { fn initialize_witness( witness_lib: &mut Box>, pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, ) { timer_start_debug!(INITIALIZE_WITNESS); witness_lib.start_proof(pctx.clone(), ectx.clone(), sctx.clone()); @@ -261,10 +268,10 @@ impl ProofMan { } fn initialize_provers( - sctx: Arc, + sctx: Arc>, provers: &mut Vec>>, pctx: Arc>, - _ectx: Arc, + _ectx: Arc>, ) { timer_start_debug!(INITIALIZE_PROVERS); info!("{}: Initializing provers", Self::MY_NAME); @@ -286,6 +293,7 @@ impl ProofMan { for prover in provers.iter_mut() { prover.build(pctx.clone()); } + let mut buff_helper_size = 0_usize; for prover in provers.iter_mut() { let buff_helper_prover_size = prover.get_buff_helper_size(); @@ -300,19 +308,104 @@ impl ProofMan { timer_stop_and_log_debug!(INITIALIZE_PROVERS); } - pub fn calculate_stage(stage: u32, provers: &mut [Box>], proof_ctx: Arc>) { + fn initialize_setup( + setups: Arc>, + pctx: Arc>, + ectx: Arc>, + aggregation: bool, + ) { + info!("{}: Initializing setup fixed pols", Self::MY_NAME); + timer_start_debug!(INITIALIZE_SETUP); + timer_start_debug!(INITIALIZE_CONST_POLS); + + let mut const_pols_calculated: HashMap<(usize, usize), bool> = HashMap::new(); + + let dctx = ectx.dctx.read().unwrap(); + + for id in &dctx.my_instances { + let (airgroup_id, air_id) = dctx.instances[*id]; + const_pols_calculated.entry((airgroup_id, air_id)).or_insert_with(|| { + let setup = setups.sctx.get_setup(airgroup_id, air_id); + setup.load_const_pols(&pctx.global_info, &ProofType::Basic); + setup.load_const_pols_tree(&pctx.global_info, &ProofType::Basic, false); + true + }); + } + + timer_stop_and_log_debug!(INITIALIZE_CONST_POLS); + + if aggregation { + timer_start_debug!(INITIALIZE_CONST_POLS_COMPRESSOR); + let mut const_pols_calculated_compressor: HashMap<(usize, usize), bool> = HashMap::new(); + + let sctx_compressor = setups.sctx_compressor.as_ref().unwrap().clone(); + let sctx_recursive1 = setups.sctx_recursive1.as_ref().unwrap().clone(); + let sctx_recursive2 = setups.sctx_recursive2.as_ref().unwrap().clone(); + let sctx_final = setups.sctx_final.as_ref().unwrap().clone(); + + for id in &dctx.my_instances { + let (airgroup_id, air_id) = dctx.instances[*id]; + if pctx.global_info.get_air_has_compressor(airgroup_id, air_id) + && !const_pols_calculated_compressor.contains_key(&(airgroup_id, air_id)) + { + let setup = sctx_compressor.get_setup(airgroup_id, air_id); + setup.load_const_pols(&pctx.global_info, &ProofType::Compressor); + setup.load_const_pols_tree(&pctx.global_info, &ProofType::Compressor, false); + const_pols_calculated_compressor.insert((airgroup_id, air_id), true); + } + } + timer_stop_and_log_debug!(INITIALIZE_CONST_POLS_COMPRESSOR); + + timer_start_debug!(INITIALIZE_CONST_POLS_RECURSIVE1); + let mut const_pols_calculated_recursive1: HashMap<(usize, usize), bool> = HashMap::new(); + for id in &dctx.my_instances { + let (airgroup_id, air_id) = dctx.instances[*id]; + const_pols_calculated_recursive1.entry((airgroup_id, air_id)).or_insert_with(|| { + let setup = sctx_recursive1.get_setup(airgroup_id, air_id); + setup.load_const_pols(&pctx.global_info, &ProofType::Recursive1); + setup.load_const_pols_tree(&pctx.global_info, &ProofType::Recursive1, false); + true + }); + } + timer_stop_and_log_debug!(INITIALIZE_CONST_POLS_RECURSIVE1); + + timer_start_debug!(INITIALIZE_CONST_POLS_RECURSIVE2); + for (idx, group_instances) in dctx.airgroup_instances.iter().enumerate() { + if !group_instances.is_empty() { + let setup = sctx_recursive2.get_setup(idx, 0); + setup.load_const_pols(&pctx.global_info, &ProofType::Recursive2); + setup.load_const_pols_tree(&pctx.global_info, &ProofType::Recursive2, false); + } + } + timer_stop_and_log_debug!(INITIALIZE_CONST_POLS_RECURSIVE2); + + timer_start_debug!(INITIALIZE_CONST_POLS_FINAL); + let setup = sctx_final.get_setup(0, 0); + setup.load_const_pols(&pctx.global_info, &ProofType::Final); + setup.load_const_pols_tree(&pctx.global_info, &ProofType::Final, false); + timer_stop_and_log_debug!(INITIALIZE_CONST_POLS_FINAL); + } + timer_stop_and_log_debug!(INITIALIZE_SETUP); + } + + pub fn calculate_stage( + stage: u32, + provers: &mut [Box>], + setup_ctx: Arc>, + proof_ctx: Arc>, + ) { if stage as usize == proof_ctx.global_info.n_challenges.len() + 1 { info!("{}: Calculating Quotient Polynomials", Self::MY_NAME); timer_start_debug!(CALCULATING_QUOTIENT_POLYNOMIAL); for prover in provers.iter_mut() { - prover.calculate_stage(stage, proof_ctx.clone()); + prover.calculate_stage(stage, setup_ctx.clone(), proof_ctx.clone()); } timer_stop_and_log_debug!(CALCULATING_QUOTIENT_POLYNOMIAL); } else { info!("{}: Calculating stage {}", Self::MY_NAME, stage); timer_start_debug!(CALCULATING_STAGE); for prover in provers.iter_mut() { - prover.calculate_stage(stage, proof_ctx.clone()); + prover.calculate_stage(stage, setup_ctx.clone(), proof_ctx.clone()); } timer_stop_and_log_debug!(CALCULATING_STAGE); } @@ -371,7 +464,7 @@ impl ProofMan { stage: u32, provers: &mut [Box>], pctx: Arc>, - ectx: Arc, + ectx: Arc>, transcript: &mut FFITranscript, verify_constraints: bool, ) { @@ -434,7 +527,8 @@ impl ProofMan { pub fn opening_stages( provers: &mut [Box>], pctx: Arc>, - ectx: Arc, + sctx: Arc>, + ectx: Arc>, transcript: &mut FFITranscript, ) { let num_commit_stages = pctx.global_info.n_challenges.len() as u32; @@ -446,7 +540,7 @@ impl ProofMan { for group_idx in dctx.my_air_groups.iter() { provers[group_idx[0]].calculate_lev(pctx.clone()); for idx in group_idx.iter() { - provers[*idx].opening_stage(1, pctx.clone()); + provers[*idx].opening_stage(1, sctx.clone(), pctx.clone()); } } timer_stop_and_log_debug!(CALCULATING_EVALS); @@ -459,7 +553,7 @@ impl ProofMan { for group_idx in dctx.my_air_groups.iter() { provers[group_idx[0]].calculate_xdivxsub(pctx.clone()); for idx in group_idx.iter() { - provers[*idx].opening_stage(2, pctx.clone()); + provers[*idx].opening_stage(2, sctx.clone(), pctx.clone()); } } timer_stop_and_log_debug!(CALCULATING_FRI_POLINOMIAL); @@ -494,7 +588,7 @@ impl ProofMan { ); } for prover in provers.iter_mut() { - prover.opening_stage(opening_id + 3, pctx.clone()); + prover.opening_stage(opening_id + 3, sctx.clone(), pctx.clone()); } if opening_id < num_opening_stages { Self::calculate_challenges( @@ -520,6 +614,7 @@ impl ProofMan { let mut proves = Vec::new(); for prover in provers.iter_mut() { proves.push(prover.get_zkin_proof(proof_ctx.clone(), output_dir)); + prover.free(); } let public_inputs_guard = proof_ctx.public_inputs.inputs.read().unwrap(); let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); diff --git a/proofman/src/recursion.rs b/proofman/src/recursion.rs index 3a8d5a2c..747f933a 100644 --- a/proofman/src/recursion.rs +++ b/proofman/src/recursion.rs @@ -1,6 +1,7 @@ use libloading::{Library, Symbol}; use p3_field::Field; use std::ffi::CString; +use std::sync::Arc; use proofman_starks_lib_c::*; use std::mem::MaybeUninit; use std::path::{Path, PathBuf}; @@ -26,7 +27,8 @@ type GenWitnessResult = Result<(Vec>, Vec>), Bo pub fn generate_recursion_proof( pctx: &ProofCtx, - ectx: &ExecutionCtx, + ectx: &ExecutionCtx, + sctx: Arc>, proofs: &[*mut c_void], proof_type: &ProofType, output_dir_path: PathBuf, @@ -40,9 +42,7 @@ pub fn generate_recursion_proof( let global_info_path = pctx.global_info.get_proving_key_path().join("pilout.globalInfo.json"); let global_info_file: &str = global_info_path.to_str().unwrap(); - let sctx = SetupCtx::new(&pctx.global_info, proof_type); - - // Run proofs + // Run proves match *proof_type { ProofType::Compressor | ProofType::Recursive1 => { for (prover_idx, air_instance) in @@ -55,22 +55,21 @@ pub fn generate_recursion_proof( } else { let air_instance_name = &pctx.global_info.airs[air_instance.airgroup_id][air_instance.air_id].name; - timer_start_trace!(GET_SETUP); - let setup = sctx.get_setup(air_instance.airgroup_id, air_instance.air_id).expect("Setup not found"); + let setup = sctx.get_setup(air_instance.airgroup_id, air_instance.air_id); let p_setup: *mut c_void = (&setup.p_setup).into(); let p_stark_info: *mut c_void = setup.p_setup.p_stark_info; - timer_stop_and_log_trace!(GET_SETUP); - let mut zkin = proofs[prover_idx]; - if *proof_type == ProofType::Recursive1 { + let zkin = if *proof_type == ProofType::Recursive1 { let recursive2_verkey = pctx .global_info .get_air_setup_path(air_instance.airgroup_id, air_instance.air_id, &ProofType::Recursive2) .display() .to_string() + ".verkey.json"; - zkin = add_recursive2_verkey_c(zkin, recursive2_verkey.as_str()); - } + add_recursive2_verkey_c(proofs[prover_idx], recursive2_verkey.as_str()) + } else { + proofs[prover_idx] + }; let (buffer, publics) = generate_witness( pctx, @@ -80,6 +79,7 @@ pub fn generate_recursion_proof( zkin, proof_type, )?; + let p_publics = publics.as_ptr() as *mut c_void; let p_address = buffer.as_ptr() as *mut c_void; @@ -110,12 +110,23 @@ pub fn generate_recursion_proof( false => String::from(""), }; - let mut p_prove = gen_recursive_proof_c(p_setup, p_address, p_publics, &proof_file); - p_prove = - publics2zkin_c(p_prove, p_publics, global_info_file, air_instance.airgroup_id as u64, false); + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + + let p_prove = gen_recursive_proof_c( + p_setup, + p_address, + const_pols_ptr, + const_tree_ptr, + p_publics, + &proof_file, + global_info_file, + air_instance.airgroup_id as u64, + ); proofs_out.push(p_prove); drop(buffer); + drop(publics); log::info!("{}: ··· Proof generated.", MY_NAME); timer_stop_and_log_trace!(GENERATE_PROOF); @@ -172,11 +183,9 @@ pub fn generate_recursion_proof( if airgroup_proofs[airgroup][j + 1].is_none() { panic!("Recursive2 proof is missing"); } - timer_start_trace!(GET_RECURSIVE2_SETUP); - let setup = sctx.get_setup(airgroup, 0).expect("Setup not found"); + let setup = sctx.get_setup(airgroup, 0); let p_setup: *mut c_void = (&setup.p_setup).into(); let p_stark_info: *mut c_void = setup.p_setup.p_stark_info; - timer_stop_and_log_trace!(GET_RECURSIVE2_SETUP); let public_inputs_guard = pctx.public_inputs.inputs.read().unwrap(); let challenges_guard = pctx.challenges.challenges.read().unwrap(); @@ -184,7 +193,7 @@ pub fn generate_recursion_proof( let public_inputs = (*public_inputs_guard).as_ptr() as *mut c_void; let challenges = (*challenges_guard).as_ptr() as *mut c_void; - let mut zkin_recursive2 = join_zkin_recursive2_c( + let zkin_recursive2 = join_zkin_recursive2_c( airgroup as u64, public_inputs, challenges, @@ -200,10 +209,17 @@ pub fn generate_recursion_proof( .display() .to_string() + ".verkey.json"; - zkin_recursive2 = add_recursive2_verkey_c(zkin_recursive2, recursive2_verkey.as_str()); + let zkin_recursive2_updated = + add_recursive2_verkey_c(zkin_recursive2, recursive2_verkey.as_str()); - let (buffer, publics) = - generate_witness(pctx, airgroup, 0, p_stark_info, zkin_recursive2, proof_type)?; + let (buffer, publics) = generate_witness( + pctx, + airgroup, + 0, + p_stark_info, + zkin_recursive2_updated, + proof_type, + )?; let p_publics = publics.as_ptr() as *mut c_void; let p_address = buffer.as_ptr() as *mut c_void; @@ -228,17 +244,26 @@ pub fn generate_recursion_proof( MY_NAME, format!("··· Generating recursive2 proof for instances of {}", air_instance_name) ); - - airgroup_proofs[airgroup][j] = - Some(gen_recursive_proof_c(p_setup, p_address, p_publics, &proof_file)); - airgroup_proofs[airgroup][j] = Some(publics2zkin_c( - airgroup_proofs[airgroup][j].unwrap(), + let const_pols_ptr = + (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = + (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + + let zkin = gen_recursive_proof_c( + p_setup, + p_address, + const_pols_ptr, + const_tree_ptr, p_publics, + &proof_file, global_info_file, airgroup as u64, - false, - )); + ); + + airgroup_proofs[airgroup][j] = Some(zkin); + drop(buffer); + drop(publics); timer_stop_and_log_trace!(GENERATE_RECURSIVE2_PROOF); log::info!("{}: ··· Proof generated.", MY_NAME); } @@ -273,7 +298,7 @@ pub fn generate_recursion_proof( let mut stark_infos_recursive2 = Vec::new(); for (idx, _) in pctx.global_info.air_groups.iter().enumerate() { - stark_infos_recursive2.push(sctx.get_setup(idx, 0).unwrap().p_setup.p_stark_info); + stark_infos_recursive2.push(sctx.get_setup(idx, 0).p_setup.p_stark_info); } let proofs_recursive2_ptr = proofs_recursive2.as_mut_ptr(); @@ -293,11 +318,9 @@ pub fn generate_recursion_proof( } } ProofType::Final => { - timer_start_trace!(GET_FINAL_SETUP); - let setup = sctx.get_setup(0, 0).expect("Setup not found"); + let setup = sctx.get_setup(0, 0); let p_setup: *mut c_void = (&setup.p_setup).into(); let p_stark_info: *mut c_void = setup.p_setup.p_stark_info; - timer_stop_and_log_trace!(GET_FINAL_SETUP); let (buffer, publics) = generate_witness(pctx, 0, 0, p_stark_info, proofs[0], proof_type)?; let p_address = buffer.as_ptr() as *mut c_void; @@ -306,11 +329,17 @@ pub fn generate_recursion_proof( log::info!("{}: ··· Generating final proof", MY_NAME); timer_start_trace!(GENERATE_PROOF); // prove + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; let _p_prove = gen_recursive_proof_c( p_setup, p_address, + const_pols_ptr, + const_tree_ptr, p_publics, output_dir_path.join("proofs/final_proof.json").to_string_lossy().as_ref(), + global_info_file, + 0, ); log::info!("{}: ··· Proof generated.", MY_NAME); drop(buffer); diff --git a/proofman/src/witness_component.rs b/proofman/src/witness_component.rs index 97694446..76a5e1c0 100644 --- a/proofman/src/witness_component.rs +++ b/proofman/src/witness_component.rs @@ -3,15 +3,15 @@ use std::sync::Arc; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; pub trait WitnessComponent: Send + Sync { - fn start_proof(&self, _pctx: Arc>, _ectx: Arc, _sctx: Arc) {} + fn start_proof(&self, _pctx: Arc>, _ectx: Arc>, _sctx: Arc>) {} fn calculate_witness( &self, _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, + _ectx: Arc>, + _sctx: Arc>, ) { } diff --git a/proofman/src/witness_executor.rs b/proofman/src/witness_executor.rs index 7ac119dc..8ae748ac 100644 --- a/proofman/src/witness_executor.rs +++ b/proofman/src/witness_executor.rs @@ -1,5 +1,5 @@ use proofman_common::{ExecutionCtx, ProofCtx}; pub trait WitnessExecutor { - fn execute(&self, pctx: &mut ProofCtx, ectx: &mut ExecutionCtx); + fn execute(&self, pctx: &mut ProofCtx, ectx: &mut ExecutionCtx); } diff --git a/proofman/src/witness_library.rs b/proofman/src/witness_library.rs index 0a9e33f0..befb67bb 100644 --- a/proofman/src/witness_library.rs +++ b/proofman/src/witness_library.rs @@ -4,18 +4,24 @@ use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; /// This is the type of the function that is used to load a witness library. pub type WitnessLibInitFn = - fn(ectx: Arc, Option) -> Result>, Box>; + fn(Option, Option) -> Result>, Box>; pub trait WitnessLibrary { - fn start_proof(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc); + fn start_proof(&mut self, pctx: Arc>, ectx: Arc>, sctx: Arc>); fn end_proof(&mut self); - fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc); + fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>); - fn calculate_witness(&mut self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc); + fn calculate_witness( + &mut self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ); - fn debug(&mut self, _pctx: Arc>, _ectx: Arc, _sctx: Arc) {} + fn debug(&mut self, _pctx: Arc>, _ectx: Arc>, _sctx: Arc>) {} fn pilout(&self) -> WitnessPilout; } diff --git a/proofman/src/witness_manager.rs b/proofman/src/witness_manager.rs index f28ccf81..5f921b04 100644 --- a/proofman/src/witness_manager.rs +++ b/proofman/src/witness_manager.rs @@ -15,14 +15,14 @@ pub struct WitnessManager { airs: RwLock>, // First usize is the air_id, second usize is the index of the component in the components vector pctx: Arc>, - ectx: Arc, - sctx: Arc, + ectx: Arc>, + sctx: Arc>, } impl WitnessManager { const MY_NAME: &'static str = "WCMnager"; - pub fn new(pctx: Arc>, ectx: Arc, sctx: Arc) -> Self { + pub fn new(pctx: Arc>, ectx: Arc>, sctx: Arc>) -> Self { WitnessManager { components: RwLock::new(Vec::new()), airs: RwLock::new(HashMap::new()), pctx, ectx, sctx } } @@ -55,7 +55,7 @@ impl WitnessManager { self.airs.write().unwrap().insert((airgroup_id, air_id), component_idx); } - pub fn start_proof(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn start_proof(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { for component in self.components.read().unwrap().iter() { component.start_proof(pctx.clone(), ectx.clone(), sctx.clone()); } @@ -67,7 +67,13 @@ impl WitnessManager { } } - pub fn calculate_witness(&self, stage: u32, pctx: Arc>, ectx: Arc, sctx: Arc) { + pub fn calculate_witness( + &self, + stage: u32, + pctx: Arc>, + ectx: Arc>, + sctx: Arc>, + ) { log::info!( "{}: Calculating witness for stage {} / {}", Self::MY_NAME, @@ -116,11 +122,11 @@ impl WitnessManager { self.pctx.clone() } - pub fn get_ectx(&self) -> Arc { + pub fn get_ectx(&self) -> Arc> { self.ectx.clone() } - pub fn get_sctx(&self) -> Arc { + pub fn get_sctx(&self) -> Arc> { self.sctx.clone() } } diff --git a/provers/stark/src/stark_prover.rs b/provers/stark/src/stark_prover.rs index 67c8effd..8626448d 100644 --- a/provers/stark/src/stark_prover.rs +++ b/provers/stark/src/stark_prover.rs @@ -7,10 +7,9 @@ use std::path::PathBuf; use std::any::type_name; use std::sync::Arc; -use proofman_common::StepsParams; use proofman_common::{ BufferAllocator, ConstraintInfo, ConstraintsResults, ProofCtx, ProofType, Prover, ProverInfo, ProverStatus, - SetupCtx, + StepsParams, SetupCtx, }; use log::{debug, trace}; use transcript::FFITranscript; @@ -38,14 +37,13 @@ pub struct StarkProver { air_id: usize, airgroup_id: usize, instance_id: usize, - p_setup: *mut c_void, pub p_stark: *mut c_void, p_stark_info: *mut c_void, stark_info: StarkInfo, n_field_elements: usize, merkle_tree_arity: u64, merkle_tree_custom: bool, - p_proof: Option<*mut c_void>, + p_proof: *mut c_void, _marker: PhantomData, // Add PhantomData to track the type F } @@ -56,7 +54,7 @@ impl StarkProver { const FIELD_EXTENSION: usize = 3; pub fn new( - sctx: Arc, + sctx: Arc>, pctx: Arc>, airgroup_id: usize, air_id: usize, @@ -65,9 +63,11 @@ impl StarkProver { ) -> Self { let air_setup_path = pctx.global_info.get_air_setup_path(airgroup_id, air_id, &ProofType::Basic); - let setup = sctx.get_setup(airgroup_id, air_id).expect("REASON"); + let setup = sctx.get_setup(airgroup_id, air_id); - let p_stark = starks_new_c((&setup.p_setup).into()); + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + + let p_stark = starks_new_c((&setup.p_setup).into(), const_tree_ptr); let stark_info_path = air_setup_path.display().to_string() + ".starkinfo.json"; let stark_info_json = std::fs::read_to_string(&stark_info_path) @@ -81,16 +81,19 @@ impl StarkProver { (Self::HASH_SIZE, 2, false) }; + let p_stark_info = setup.p_setup.p_stark_info; + + let p_proof = fri_proof_new_c((&setup.p_setup).into()); + Self { initialized: true, prover_idx, air_id, airgroup_id, instance_id, - p_setup: (&setup.p_setup).into(), - p_stark_info: setup.p_setup.p_stark_info, + p_stark_info, p_stark, - p_proof: None, + p_proof, stark_info, n_field_elements, merkle_tree_arity, @@ -109,8 +112,6 @@ impl Prover for StarkProver { vec![F::zero(); self.stark_info.challenges_map.as_ref().unwrap().len() * Self::FIELD_EXTENSION]; *proof_ctx.challenges.challenges.write().unwrap() = challenges; - self.p_proof = Some(fri_proof_new_c(self.p_setup)); - let number_stage1_commits = *self.stark_info.map_sections_n.get("cm1").unwrap() as usize; for i in 0..number_stage1_commits { air_instance.set_commit_calculated(i); @@ -119,6 +120,11 @@ impl Prover for StarkProver { self.initialized = true; } + fn free(&mut self) { + starks_free_c(self.p_stark); + fri_proof_free_c(self.p_proof); + } + fn new_transcript(&self) -> FFITranscript { let p_stark: *mut std::ffi::c_void = self.p_stark; @@ -131,12 +137,17 @@ impl Prover for StarkProver { self.stark_info.n_stages } - fn verify_constraints(&self, proof_ctx: Arc>) -> Vec { + fn verify_constraints(&self, setup_ctx: Arc>, proof_ctx: Arc>) -> Vec { let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; + let setup = setup_ctx.get_setup(self.airgroup_id, self.air_id); + let public_inputs_guard = proof_ctx.public_inputs.inputs.read().unwrap(); let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: (*public_inputs_guard).as_ptr() as *mut c_void, @@ -145,9 +156,11 @@ impl Prover for StarkProver { airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; - let raw_ptr = verify_constraints_c(self.p_setup, (&steps_params).into()); + let raw_ptr = verify_constraints_c((&setup.p_setup).into(), (&steps_params).into()); unsafe { let constraints_result = Box::from_raw(raw_ptr as *mut ConstraintsResults); @@ -156,14 +169,19 @@ impl Prover for StarkProver { .to_vec() } - fn calculate_stage(&mut self, stage_id: u32, proof_ctx: Arc>) { + fn calculate_stage(&mut self, stage_id: u32, setup_ctx: Arc>, proof_ctx: Arc>) { let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; let n_commits = self.stark_info.cm_pols_map.as_ref().expect("REASON").len(); + let setup = setup_ctx.get_setup(self.airgroup_id, self.air_id); + let public_inputs_guard = proof_ctx.public_inputs.inputs.read().unwrap(); let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, public_inputs: (*public_inputs_guard).as_ptr() as *mut c_void, @@ -172,6 +190,8 @@ impl Prover for StarkProver { airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: std::ptr::null_mut(), + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; if stage_id as usize <= proof_ctx.global_info.n_challenges.len() { @@ -200,7 +220,7 @@ impl Prover for StarkProver { } if stage_id as usize == proof_ctx.global_info.n_challenges.len() { - let p_proof = self.p_proof.unwrap(); + let p_proof = self.p_proof; fri_proof_set_airgroup_values_c(p_proof, steps_params.airgroup_values); fri_proof_set_air_values_c(p_proof, steps_params.airvalues); } @@ -267,7 +287,7 @@ impl Prover for StarkProver { timer_start_trace!(STARK_COMMIT_STAGE_, stage_id); - let p_proof = self.p_proof.unwrap(); + let p_proof = self.p_proof; let element_type = if type_name::() == type_name::() { 1 } else { 0 }; let buff_helper_guard = proof_ctx.buff_helper.buff_helper.read().unwrap(); @@ -283,13 +303,18 @@ impl Prover for StarkProver { } } - fn opening_stage(&mut self, opening_id: u32, proof_ctx: Arc>) -> ProverStatus { + fn opening_stage( + &mut self, + opening_id: u32, + setup_ctx: Arc>, + proof_ctx: Arc>, + ) -> ProverStatus { let steps_fri: Vec = proof_ctx.global_info.steps_fri.iter().map(|step| step.n_bits).collect(); let last_stage_id = steps_fri.len() as u32 + 3; if opening_id == 1 { - self.compute_evals(opening_id, proof_ctx); + self.compute_evals(opening_id, setup_ctx, proof_ctx); } else if opening_id == 2 { - self.compute_fri_pol(opening_id, proof_ctx); + self.compute_fri_pol(opening_id, setup_ctx, proof_ctx); } else if opening_id < last_stage_id { let global_step_fri = steps_fri[(opening_id - 3) as usize]; let step_index = @@ -503,14 +528,14 @@ impl Prover for StarkProver { if let Some(step_index) = step_index { let n_steps = steps.len() - 1; if step_index < n_steps { - let p_proof = self.p_proof.unwrap(); + let p_proof = self.p_proof; fri_proof_get_tree_root_c(p_proof, value.as_mut_ptr() as *mut c_void, step_index as u64); } else { let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; let buffer = air_instance.get_buffer_ptr() as *mut c_void; let n_hash = (1 << (steps[n_steps].n_bits)) * Self::FIELD_EXTENSION as u64; - let fri_pol = get_fri_pol_c(self.p_setup, buffer); + let fri_pol = get_fri_pol_c(self.p_stark_info, buffer); calculate_hash_c(p_stark, value.as_mut_ptr() as *mut c_void, fri_pol, n_hash); } } @@ -563,7 +588,7 @@ impl Prover for StarkProver { } fn get_proof(&self) -> *mut c_void { - self.p_proof.unwrap() + self.p_proof } fn get_zkin_proof(&self, proof_ctx: Arc>, output_dir: &str) -> *mut c_void { @@ -579,7 +604,7 @@ impl Prover for StarkProver { fri_proof_get_zkinproof_c( gidx as u64, - self.p_proof.unwrap(), + self.p_proof, public_inputs, challenges, self.p_stark_info, @@ -623,12 +648,7 @@ impl Prover for StarkProver { } impl StarkProver { - // Return the total number of elements needed to compute the STARK - pub fn get_total_bytes(&self) -> usize { - get_map_totaln_c(self.p_setup) as usize * std::mem::size_of::() - } - - fn compute_evals(&mut self, _opening_id: u32, proof_ctx: Arc>) { + fn compute_evals(&mut self, _opening_id: u32, setup_ctx: Arc>, proof_ctx: Arc>) { let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; debug!("{}: ··· Calculating evals of instance {} of {}", Self::MY_NAME, self.instance_id, air_name); let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; @@ -637,24 +657,45 @@ impl StarkProver { let evals = air_instance.evals.as_mut_ptr() as *mut c_void; + let setup = setup_ctx.get_setup(self.airgroup_id, self.air_id); + let p_stark = self.p_stark; - let p_proof = self.p_proof.unwrap(); + let p_proof = self.p_proof; let buff_helper_guard = proof_ctx.buff_helper.buff_helper.read().unwrap(); let buff_helper = (*buff_helper_guard).as_ptr() as *mut c_void; - compute_evals_c(p_stark, buffer, buff_helper, evals, p_proof); + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + + let steps_params = StepsParams { + buffer, + public_inputs: std::ptr::null_mut(), + challenges: std::ptr::null_mut(), + airgroup_values: std::ptr::null_mut(), + airvalues: std::ptr::null_mut(), + evals, + xdivxsub: std::ptr::null_mut(), + p_const_pols: std::ptr::null_mut(), + p_const_tree: const_tree_ptr, + }; + + compute_evals_c(p_stark, (&steps_params).into(), buff_helper, p_proof); } - fn compute_fri_pol(&mut self, _opening_id: u32, proof_ctx: Arc>) { + fn compute_fri_pol(&mut self, _opening_id: u32, setup_ctx: Arc>, proof_ctx: Arc>) { let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; debug!("{}: ··· Calculating FRI polynomial of instance {} of {}", Self::MY_NAME, self.instance_id, air_name); let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; + let setup = setup_ctx.get_setup(self.airgroup_id, self.air_id); + let public_inputs_guard = proof_ctx.public_inputs.inputs.read().unwrap(); let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); let buff_helper_guard = proof_ctx.buff_helper.buff_helper.read().unwrap(); + let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let p_stark = self.p_stark; let steps_params = StepsParams { @@ -665,13 +706,15 @@ impl StarkProver { airvalues: air_instance.airvalues.as_ptr() as *mut c_void, evals: air_instance.evals.as_ptr() as *mut c_void, xdivxsub: (*buff_helper_guard).as_ptr() as *mut c_void, + p_const_pols: const_pols_ptr, + p_const_tree: const_tree_ptr, }; calculate_fri_polynomial_c(p_stark, (&steps_params).into()); } fn compute_fri_folding(&mut self, step_index: u32, proof_ctx: Arc>) { - let p_proof = self.p_proof.unwrap(); + let p_proof = self.p_proof; let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; @@ -691,7 +734,7 @@ impl StarkProver { let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; let buffer = air_instance.get_buffer_ptr() as *mut c_void; - let fri_pol = get_fri_pol_c(self.p_setup, buffer); + let fri_pol = get_fri_pol_c(self.p_stark_info, buffer); let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); let challenge: Vec = challenges_guard.iter().skip(challenges_guard.len() - 3).cloned().collect(); @@ -716,7 +759,7 @@ impl StarkProver { fn compute_fri_queries(&mut self, _opening_id: u32, proof_ctx: Arc>) { let p_stark = self.p_stark; - let p_proof = self.p_proof.unwrap(); + let p_proof = self.p_proof; let n_queries = self.stark_info.stark_struct.n_queries; let steps = &self.stark_info.stark_struct.steps; @@ -751,7 +794,7 @@ impl StarkProver { let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; let buffer = air_instance.get_buffer_ptr() as *mut c_void; - let fri_pol = get_fri_pol_c(self.p_setup, buffer); + let fri_pol = get_fri_pol_c(self.p_stark_info, buffer); compute_queries_c(p_stark, p_proof, fri_queries.as_mut_ptr(), n_queries, (self.num_stages() + 2) as u64); for (step, _) in steps.iter().enumerate().take(self.stark_info.stark_struct.steps.len()).skip(1) { @@ -783,15 +826,16 @@ impl StarkBufferAllocator { } } -impl BufferAllocator for StarkBufferAllocator { +impl BufferAllocator for StarkBufferAllocator { fn get_buffer_info( &self, - sctx: &SetupCtx, + sctx: &SetupCtx, airgroup_id: usize, air_id: usize, ) -> Result<(u64, Vec), Box> { - let ps = sctx.get_partial_setup(airgroup_id, air_id).expect("REASON"); + let ps = sctx.get_setup(airgroup_id, air_id); - Ok((get_map_totaln_c(ps.p_setup.p_stark_info), vec![get_map_offsets_c(ps.p_setup.p_stark_info, "cm1", false)])) + let p_stark_info = ps.p_setup.p_stark_info; + Ok((get_map_totaln_c(p_stark_info), vec![get_map_offsets_c(p_stark_info, "cm1", false)])) } } diff --git a/provers/starks-lib-c/bindings_starks.rs b/provers/starks-lib-c/bindings_starks.rs index 16ec636f..4aa01c86 100644 --- a/provers/starks-lib-c/bindings_starks.rs +++ b/provers/starks-lib-c/bindings_starks.rs @@ -68,14 +68,6 @@ extern "C" { #[link_name = "\u{1}_Z14fri_proof_freePv"] pub fn fri_proof_free(pFriProof: *mut ::std::os::raw::c_void); } -extern "C" { - #[link_name = "\u{1}_Z13setup_ctx_newPvS_S_"] - pub fn setup_ctx_new( - p_stark_info: *mut ::std::os::raw::c_void, - p_expression_bin: *mut ::std::os::raw::c_void, - p_const_pols: *mut ::std::os::raw::c_void, - ) -> *mut ::std::os::raw::c_void; -} extern "C" { #[link_name = "\u{1}_Z20get_hint_ids_by_namePvPc"] pub fn get_hint_ids_by_name( @@ -83,10 +75,6 @@ extern "C" { hintName: *mut ::std::os::raw::c_char, ) -> *mut ::std::os::raw::c_void; } -extern "C" { - #[link_name = "\u{1}_Z14setup_ctx_freePv"] - pub fn setup_ctx_free(pSetupCtx: *mut ::std::os::raw::c_void); -} extern "C" { #[link_name = "\u{1}_Z14stark_info_newPc"] pub fn stark_info_new(filename: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_void; @@ -142,39 +130,47 @@ extern "C" { pub fn stark_info_free(pStarkInfo: *mut ::std::os::raw::c_void); } extern "C" { - #[link_name = "\u{1}_Z14const_pols_newPcPvb"] - pub fn const_pols_new( - filename: *mut ::std::os::raw::c_char, + #[link_name = "\u{1}_Z18prover_helpers_newPv"] + pub fn prover_helpers_new( pStarkInfo: *mut ::std::os::raw::c_void, - calculate_tree: bool, ) -> *mut ::std::os::raw::c_void; } extern "C" { - #[link_name = "\u{1}_Z24const_pols_with_tree_newPcS_Pv"] - pub fn const_pols_with_tree_new( - filename: *mut ::std::os::raw::c_char, - treeFilename: *mut ::std::os::raw::c_char, - pStarkInfo: *mut ::std::os::raw::c_void, - ) -> *mut ::std::os::raw::c_void; + #[link_name = "\u{1}_Z19prover_helpers_freePv"] + pub fn prover_helpers_free(pProverHelpers: *mut ::std::os::raw::c_void); } extern "C" { - #[link_name = "\u{1}_Z15load_const_treePvS_Pc"] + #[link_name = "\u{1}_Z15load_const_treePvPcm"] pub fn load_const_tree( - pConstPols: *mut ::std::os::raw::c_void, - pStarkInfo: *mut ::std::os::raw::c_void, + pConstTree: *mut ::std::os::raw::c_void, treeFilename: *mut ::std::os::raw::c_char, + constTreeSize: u64, ); } extern "C" { - #[link_name = "\u{1}_Z20calculate_const_treePvS_"] - pub fn calculate_const_tree( + #[link_name = "\u{1}_Z15load_const_polsPvPcm"] + pub fn load_const_pols( pConstPols: *mut ::std::os::raw::c_void, - pStarkInfo: *mut ::std::os::raw::c_void, + constFilename: *mut ::std::os::raw::c_char, + constSize: u64, ); } extern "C" { - #[link_name = "\u{1}_Z15const_pols_freePv"] - pub fn const_pols_free(pConstPols: *mut ::std::os::raw::c_void); + #[link_name = "\u{1}_Z19get_const_tree_sizePv"] + pub fn get_const_tree_size(pStarkInfo: *mut ::std::os::raw::c_void) -> u64; +} +extern "C" { + #[link_name = "\u{1}_Z14get_const_sizePv"] + pub fn get_const_size(pStarkInfo: *mut ::std::os::raw::c_void) -> u64; +} +extern "C" { + #[link_name = "\u{1}_Z20calculate_const_treePvS_S_Pc"] + pub fn calculate_const_tree( + pStarkInfo: *mut ::std::os::raw::c_void, + pConstPolsAddress: *mut ::std::os::raw::c_void, + pConstTree: *mut ::std::os::raw::c_void, + treeFilename: *mut ::std::os::raw::c_char, + ); } extern "C" { #[link_name = "\u{1}_Z19expressions_bin_newPcb"] @@ -246,8 +242,11 @@ extern "C" { ) -> u64; } extern "C" { - #[link_name = "\u{1}_Z10starks_newPv"] - pub fn starks_new(pSetupCtx: *mut ::std::os::raw::c_void) -> *mut ::std::os::raw::c_void; + #[link_name = "\u{1}_Z10starks_newPvS_"] + pub fn starks_new( + pSetupCtx: *mut ::std::os::raw::c_void, + pConstTree: *mut ::std::os::raw::c_void, + ) -> *mut ::std::os::raw::c_void; } extern "C" { #[link_name = "\u{1}_Z11starks_freePv"] @@ -272,7 +271,7 @@ extern "C" { extern "C" { #[link_name = "\u{1}_Z11get_fri_polPvS_"] pub fn get_fri_pol( - pSetupCtx: *mut ::std::os::raw::c_void, + pStarkInfo: *mut ::std::os::raw::c_void, buffer: *mut ::std::os::raw::c_void, ) -> *mut ::std::os::raw::c_void; } @@ -318,12 +317,11 @@ extern "C" { ); } extern "C" { - #[link_name = "\u{1}_Z13compute_evalsPvS_S_S_S_"] + #[link_name = "\u{1}_Z13compute_evalsPvS_S_S_"] pub fn compute_evals( pStarks: *mut ::std::os::raw::c_void, - buffer: *mut ::std::os::raw::c_void, + params: *mut ::std::os::raw::c_void, LEv: *mut ::std::os::raw::c_void, - evals: *mut ::std::os::raw::c_void, pProof: *mut ::std::os::raw::c_void, ); } @@ -336,13 +334,6 @@ extern "C" { nElements: u64, ); } -extern "C" { - #[link_name = "\u{1}_Z14set_const_treePvS_"] - pub fn set_const_tree( - pStarks: *mut ::std::os::raw::c_void, - pConstPols: *mut ::std::os::raw::c_void, - ); -} extern "C" { #[link_name = "\u{1}_Z15merkle_tree_newmmmb"] pub fn merkle_tree_new( @@ -523,10 +514,14 @@ extern "C" { ); } extern "C" { - #[link_name = "\u{1}_Z19gen_recursive_proofPvS_S_Pc"] + #[link_name = "\u{1}_Z19gen_recursive_proofPvPcmS_S_S_S_S0_"] pub fn gen_recursive_proof( pSetupCtx: *mut ::std::os::raw::c_void, + globalInfoFile: *mut ::std::os::raw::c_char, + airgroupId: u64, pAddress: *mut ::std::os::raw::c_void, + pConstPols: *mut ::std::os::raw::c_void, + pConstTree: *mut ::std::os::raw::c_void, pPublicInputs: *mut ::std::os::raw::c_void, proof_file: *mut ::std::os::raw::c_char, ) -> *mut ::std::os::raw::c_void; @@ -535,16 +530,6 @@ extern "C" { #[link_name = "\u{1}_Z12get_zkin_ptrPc"] pub fn get_zkin_ptr(zkin_file: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_void; } -extern "C" { - #[link_name = "\u{1}_Z11public2zkinPvS_Pcmb"] - pub fn public2zkin( - pZkin: *mut ::std::os::raw::c_void, - pPublics: *mut ::std::os::raw::c_void, - globalInfoFile: *mut ::std::os::raw::c_char, - airgroupId: u64, - isAggregated: bool, - ) -> *mut ::std::os::raw::c_void; -} extern "C" { #[link_name = "\u{1}_Z21add_recursive2_verkeyPvPc"] pub fn add_recursive2_verkey( @@ -577,12 +562,16 @@ extern "C" { } extern "C" { #[link_name = "\u{1}_Z20get_serialized_proofPvPm"] - pub fn get_serialized_proof(zkin: *mut ::std::os::raw::c_void, size: *mut u64) -> *mut ::std::os::raw::c_char; + pub fn get_serialized_proof( + zkin: *mut ::std::os::raw::c_void, + size: *mut u64, + ) -> *mut ::std::os::raw::c_char; } - extern "C" { #[link_name = "\u{1}_Z22deserialize_zkin_proofPc"] - pub fn deserialize_zkin_proof(serialized_proof: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_void; + pub fn deserialize_zkin_proof( + serialized_proof: *mut ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_void; } extern "C" { #[link_name = "\u{1}_Z14get_zkin_proofPc"] @@ -599,4 +588,4 @@ extern "C" { extern "C" { #[link_name = "\u{1}_Z11setLogLevelm"] pub fn setLogLevel(level: u64); -} +} \ No newline at end of file diff --git a/provers/starks-lib-c/src/ffi_starks.rs b/provers/starks-lib-c/src/ffi_starks.rs index 5ba46ac8..0300be2f 100644 --- a/provers/starks-lib-c/src/ffi_starks.rs +++ b/provers/starks-lib-c/src/ffi_starks.rs @@ -107,22 +107,6 @@ pub fn fri_proof_free_c(p_fri_proof: *mut c_void) { } } -#[cfg(not(feature = "no_lib_link"))] -pub fn setup_ctx_new_c( - p_stark_info: *mut c_void, - p_expressions_bin: *mut c_void, - p_const_pols: *mut c_void, -) -> *mut c_void { - unsafe { setup_ctx_new(p_stark_info, p_expressions_bin, p_const_pols) } -} - -#[cfg(not(feature = "no_lib_link"))] -pub fn setup_ctx_free_c(p_setup_ctx: *mut c_void) { - unsafe { - setup_ctx_free(p_setup_ctx); - } -} - #[cfg(not(feature = "no_lib_link"))] pub fn stark_info_new_c(filename: &str) -> *mut c_void { unsafe { @@ -188,48 +172,61 @@ pub fn stark_info_free_c(p_stark_info: *mut c_void) { } #[cfg(not(feature = "no_lib_link"))] -pub fn const_pols_new_c(filename: &str, p_stark_info: *mut c_void, calculate_tree: bool) -> *mut c_void { - unsafe { - let filename = CString::new(filename).unwrap(); +pub fn prover_helpers_new_c(p_stark_info: *mut c_void) -> *mut c_void { + unsafe { prover_helpers_new(p_stark_info) } +} - const_pols_new(filename.as_ptr() as *mut std::os::raw::c_char, p_stark_info, calculate_tree) +#[cfg(not(feature = "no_lib_link"))] +pub fn prover_helpers_free_c(p_prover_helpers: *mut c_void) { + unsafe { + prover_helpers_free(p_prover_helpers); } } #[cfg(not(feature = "no_lib_link"))] -pub fn const_pols_with_tree_new_c(filename: &str, tree_filename: &str, p_stark_info: *mut c_void) -> *mut c_void { +pub fn load_const_pols_c(pConstPolsAddress: *mut c_void, const_filename: &str, const_size: u64) { unsafe { - let filename = CString::new(filename).unwrap(); - let tree_filename = CString::new(tree_filename).unwrap(); + let const_filename: CString = CString::new(const_filename).unwrap(); - const_pols_with_tree_new( - filename.as_ptr() as *mut std::os::raw::c_char, - tree_filename.as_ptr() as *mut std::os::raw::c_char, - p_stark_info, - ) + load_const_pols(pConstPolsAddress, const_filename.as_ptr() as *mut std::os::raw::c_char, const_size); } } #[cfg(not(feature = "no_lib_link"))] -pub fn load_const_tree_c(pConstPols: *mut c_void, pStarkInfo: *mut c_void, tree_filename: &str) { - unsafe { - let tree_filename = CString::new(tree_filename).unwrap(); +pub fn get_const_size_c(pStarkInfo: *mut c_void) -> u64 { + unsafe { get_const_size(pStarkInfo) } +} - load_const_tree(pConstPols, pStarkInfo, tree_filename.as_ptr() as *mut std::os::raw::c_char); - } +#[cfg(not(feature = "no_lib_link"))] +pub fn get_const_tree_size_c(pStarkInfo: *mut c_void) -> u64 { + unsafe { get_const_tree_size(pStarkInfo) } } #[cfg(not(feature = "no_lib_link"))] -pub fn calculate_const_tree_c(pConstPols: *mut c_void, pStarkInfo: *mut c_void) { +pub fn load_const_tree_c(pConstPolsTreeAddress: *mut c_void, tree_filename: &str, const_tree_size: u64) { unsafe { - calculate_const_tree(pConstPols, pStarkInfo); + let tree_filename: CString = CString::new(tree_filename).unwrap(); + + load_const_tree(pConstPolsTreeAddress, tree_filename.as_ptr() as *mut std::os::raw::c_char, const_tree_size); } } #[cfg(not(feature = "no_lib_link"))] -pub fn const_pols_free_c(p_const_pols: *mut c_void) { +pub fn calculate_const_tree_c( + pStarkInfo: *mut c_void, + pConstPols: *mut c_void, + pConstPolsTreeAddress: *mut c_void, + tree_filename: &str, +) { unsafe { - const_pols_free(p_const_pols); + let tree_filename: CString = CString::new(tree_filename).unwrap(); + + calculate_const_tree( + pStarkInfo, + pConstPols, + pConstPolsTreeAddress, + tree_filename.as_ptr() as *mut std::os::raw::c_char, + ); } } @@ -378,8 +375,8 @@ pub fn set_hint_field_c( } #[cfg(not(feature = "no_lib_link"))] -pub fn starks_new_c(p_setup_ctx: *mut c_void) -> *mut c_void { - unsafe { starks_new(p_setup_ctx) } +pub fn starks_new_c(p_setup_ctx: *mut c_void, p_const_tree: *mut c_void) -> *mut c_void { + unsafe { starks_new(p_setup_ctx, p_const_tree) } } #[cfg(not(feature = "no_lib_link"))] @@ -416,8 +413,8 @@ pub fn calculate_xdivxsub_c(p_stark: *mut c_void, xi_challenge: *mut c_void, xdi } #[cfg(not(feature = "no_lib_link"))] -pub fn get_fri_pol_c(p_setup_ctx: *mut c_void, buffer: *mut c_void) -> *mut c_void { - unsafe { get_fri_pol(p_setup_ctx, buffer) } +pub fn get_fri_pol_c(p_stark_info: *mut c_void, buffer: *mut c_void) -> *mut c_void { + unsafe { get_fri_pol(p_stark_info, buffer) } } #[cfg(not(feature = "no_lib_link"))] @@ -463,15 +460,9 @@ pub fn compute_lev_c(p_stark: *mut c_void, xi_challenge: *mut c_void, lev: *mut } #[cfg(not(feature = "no_lib_link"))] -pub fn compute_evals_c( - p_stark: *mut c_void, - buffer: *mut c_void, - lev: *mut c_void, - evals: *mut c_void, - pProof: *mut c_void, -) { +pub fn compute_evals_c(p_stark: *mut c_void, params: *mut c_void, lev: *mut c_void, pProof: *mut c_void) { unsafe { - compute_evals(p_stark, buffer, lev, evals, pProof); + compute_evals(p_stark, params, lev, pProof); } } @@ -544,13 +535,6 @@ pub fn calculate_hash_c(pStarks: *mut c_void, pHhash: *mut c_void, pBuffer: *mut } } -#[cfg(not(feature = "no_lib_link"))] -pub fn set_const_tree_c(pStarks: *mut c_void, pConstPols: *mut c_void) { - unsafe { - set_const_tree(pStarks, pConstPols); - } -} - #[cfg(not(feature = "no_lib_link"))] pub fn transcript_new_c(element_type: u32, arity: u64, custom: bool) -> *mut c_void { unsafe { transcript_new(element_type, arity, custom) } @@ -700,16 +684,35 @@ pub fn print_row_c(p_setup_ctx: *mut c_void, buffer: *mut c_void, stage: u64, ro } #[cfg(not(feature = "no_lib_link"))] +#[allow(clippy::too_many_arguments)] pub fn gen_recursive_proof_c( p_setup_ctx: *mut c_void, p_address: *mut c_void, + p_const_pols: *mut c_void, + p_const_tree: *mut c_void, p_public_inputs: *mut c_void, proof_file: &str, + global_info_file: &str, + airgroup_id: u64, ) -> *mut c_void { let proof_file_name = CString::new(proof_file).unwrap(); let proof_file_ptr = proof_file_name.as_ptr() as *mut std::os::raw::c_char; - unsafe { gen_recursive_proof(p_setup_ctx, p_address, p_public_inputs, proof_file_ptr) } + let global_info_file_name = CString::new(global_info_file).unwrap(); + let global_info_file_ptr = global_info_file_name.as_ptr() as *mut std::os::raw::c_char; + + unsafe { + gen_recursive_proof( + p_setup_ctx, + global_info_file_ptr, + airgroup_id, + p_address, + p_const_pols, + p_const_tree, + p_public_inputs, + proof_file_ptr, + ) + } } #[cfg(not(feature = "no_lib_link"))] @@ -720,20 +723,6 @@ pub fn get_zkin_ptr_c(zkin_file: &str) -> *mut c_void { unsafe { get_zkin_ptr(zkin_file_ptr) } } -#[cfg(not(feature = "no_lib_link"))] -pub fn publics2zkin_c( - p_zkin: *mut c_void, - p_publics: *mut c_void, - global_info_file: &str, - airgroup_id: u64, - is_aggregated: bool, -) -> *mut c_void { - let global_info_file_name = CString::new(global_info_file).unwrap(); - let global_info_file_ptr = global_info_file_name.as_ptr() as *mut std::os::raw::c_char; - - unsafe { public2zkin(p_zkin, p_publics, global_info_file_ptr, airgroup_id, is_aggregated) } -} - #[cfg(not(feature = "no_lib_link"))] pub fn add_recursive2_verkey_c(p_zkin: *mut c_void, recursive2_verkey: &str) -> *mut c_void { let recursive2_verkey_name = CString::new(recursive2_verkey).unwrap(); @@ -910,32 +899,6 @@ pub fn fri_proof_free_c(_p_fri_proof: *mut c_void) { trace!("{}: ··· {}", "ffi ", "fri_proof_free: This is a mock call because there is no linked library"); } -#[cfg(feature = "no_lib_link")] -pub fn setup_ctx_new_c( - _p_stark_info: *mut c_void, - _p_expressions_bin: *mut c_void, - _p_const_pols: *mut c_void, -) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "setup_ctx_new: This is a mock call because there is no linked library"); - std::ptr::null_mut() -} - -#[cfg(feature = "no_lib_link")] -pub fn setup_ctx_new1_c( - _stark_info_file: &str, - _expressions_bin_file: &str, - _const_pols_file: &str, - _const_tree_file: &str, -) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "setup_ctx_new1: This is a mock call because there is no linked library"); - std::ptr::null_mut() -} - -#[cfg(feature = "no_lib_link")] -pub fn setup_ctx_free_c(_p_setup_ctx: *mut c_void) { - trace!("{}: ··· {}", "ffi ", "starkinfo_free: This is a mock call because there is no linked library"); -} - #[cfg(feature = "no_lib_link")] pub fn stark_info_new_c(_filename: &str) -> *mut c_void { trace!("{}: ··· {}", "ffi ", "starkinfo_new: This is a mock call because there is no linked library"); @@ -1006,27 +969,45 @@ pub fn stark_info_free_c(_p_stark_info: *mut c_void) { } #[cfg(feature = "no_lib_link")] -pub fn const_pols_new_c(_filename: &str, _p_stark_info: *mut c_void, _calculate_tree: bool) -> *mut c_void { +pub fn prover_helpers_new_c(_p_stark_info: *mut c_void) -> *mut c_void { + trace!("{}: ··· {}", "ffi ", "prover_helpers_new_c: This is a mock call because there is no linked library"); std::ptr::null_mut() } #[cfg(feature = "no_lib_link")] -pub fn const_pols_with_tree_new_c(_filename: &str, _tree_filename: &str, _p_stark_info: *mut c_void) -> *mut c_void { - std::ptr::null_mut() +pub fn prover_helpers_free_c(_p_prover_helpers: *mut c_void) {} + +#[cfg(feature = "no_lib_link")] +pub fn load_const_pols_c(_pConstPolsAddress: *mut c_void, _const_filename: &str, _const_size: u64) { + trace!("{}: ··· {}", "ffi ", "load_const_pols_c: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] -pub fn load_const_tree_c(_pConstPols: *mut c_void, _pStarkInfo: *mut c_void, _tree_filename: &str) { - trace!("{}: ··· {}", "ffi ", "load_const_tree: This is a mock call because there is no linked library"); +pub fn get_const_tree_size_c(_pStarkInfo: *mut c_void) -> u64 { + trace!("{}: ··· {}", "ffi ", "get_const_tree_size_c: This is a mock call because there is no linked library"); + 1000000 } #[cfg(feature = "no_lib_link")] -pub fn calculate_const_tree_c(_pConstPols: *mut c_void, _pStarkInfo: *mut c_void) { - trace!("{}: ··· {}", "ffi ", "calculate_const_tree: This is a mock call because there is no linked library"); +pub fn get_const_size_c(_pStarkInfo: *mut c_void) -> u64 { + trace!("{}: ··· {}", "ffi ", "get_const_size_c: This is a mock call because there is no linked library"); + 1000000 } #[cfg(feature = "no_lib_link")] -pub fn const_pols_free_c(_p_const_pols: *mut c_void) {} +pub fn load_const_tree_c(_pConstPolsTreeAddress: *mut c_void, _tree_filename: &str, _const_tree_size: u64) { + trace!("{}: ··· {}", "ffi ", "load_const_tree_c: This is a mock call because there is no linked library"); +} + +#[cfg(feature = "no_lib_link")] +pub fn calculate_const_tree_c( + _pStarkInfo: *mut c_void, + _pConstPols: *mut c_void, + _pConstPolsTreeAddress: *mut c_void, + _tree_filename: &str, +) { + trace!("{}: ··· {}", "ffi ", "calculate_const_tree_c: This is a mock call because there is no linked library"); +} #[cfg(feature = "no_lib_link")] pub fn expressions_bin_new_c(_filename: &str, _global: bool) -> *mut c_void { @@ -1113,7 +1094,7 @@ pub fn set_hint_field_c( } #[cfg(feature = "no_lib_link")] -pub fn starks_new_c(_p_config: *mut c_void) -> *mut c_void { +pub fn starks_new_c(_p_config: *mut c_void, _p_const_tree: *mut c_void) -> *mut c_void { trace!("{}: ··· {}", "ffi ", "starks_new: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -1176,13 +1157,7 @@ pub fn compute_lev_c(_p_stark: *mut c_void, _xi_challenge: *mut c_void, _lev: *m } #[cfg(feature = "no_lib_link")] -pub fn compute_evals_c( - _p_stark: *mut c_void, - _buffer: *mut c_void, - _lev: *mut c_void, - _evals: *mut c_void, - _pProof: *mut c_void, -) { +pub fn compute_evals_c(_p_stark: *mut c_void, _params: *mut c_void, _lev: *mut c_void, _pProof: *mut c_void) { trace!("{}: ··· {}", "ffi ", "compute_evals: This is a mock call because there is no linked library"); } @@ -1192,7 +1167,7 @@ pub fn calculate_xdivxsub_c(_p_stark: *mut c_void, _xi_challenge: *mut c_void, _ } #[cfg(feature = "no_lib_link")] -pub fn get_fri_pol_c(_p_setup_ctx: *mut c_void, _buffer: *mut c_void) -> *mut c_void { +pub fn get_fri_pol_c(_p_stark_info: *mut c_void, _buffer: *mut c_void) -> *mut c_void { trace!("ffi : ··· {}", "get_fri_pol: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -1254,11 +1229,6 @@ pub fn calculate_hash_c(_pStarks: *mut c_void, _pHhash: *mut c_void, _pBuffer: * trace!("{}: ··· {}", "ffi ", "calculate_hash: This is a mock call because there is no linked library"); } -#[cfg(feature = "no_lib_link")] -pub fn set_const_tree_c(_pStarks: *mut c_void, _pConstPols: *mut c_void) { - trace!("{}: ··· {}", "ffi ", "set_const_tree: This is a mock call because there is no linked library"); -} - #[cfg(feature = "no_lib_link")] pub fn transcript_new_c(_element_type: u32, _arity: u64, _custom: bool) -> *mut c_void { trace!("{}: ··· {}", "ffi ", "transcript_new: This is a mock call because there is no linked library"); @@ -1383,11 +1353,16 @@ pub fn print_row_c(_p_setup_ctx: *mut c_void, _buffer: *mut c_void, _stage: u64, } #[cfg(feature = "no_lib_link")] +#[allow(clippy::too_many_arguments)] pub fn gen_recursive_proof_c( _p_setup_ctx: *mut c_void, _p_address: *mut c_void, + _p_const_pols: *mut c_void, + _p_const_tree: *mut c_void, _p_public_inputs: *mut c_void, _proof_file: &str, + _global_info_file: &str, + _airgroup_id: u64, ) -> *mut c_void { trace!("{}: ··· {}", "ffi ", "gen_recursive_proof_c: This is a mock call because there is no linked library"); std::ptr::null_mut() @@ -1399,18 +1374,6 @@ pub fn get_zkin_ptr_c(_zkin_file: &str) -> *mut c_void { std::ptr::null_mut() } -#[cfg(feature = "no_lib_link")] -pub fn publics2zkin_c( - _p_zkin: *mut c_void, - _p_publics: *mut c_void, - _global_info_file: &str, - _airgroup_id: u64, - _is_aggregated: bool, -) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "publics2zkin_c: This is a mock call because there is no linked library"); - std::ptr::null_mut() -} - #[cfg(feature = "no_lib_link")] pub fn add_recursive2_verkey_c(_p_zkin: *mut c_void, _recursive2_verkey: &str) -> *mut c_void { trace!("{}: ··· {}", "ffi ", "add_recursive2_verkey_c: This is a mock call because there is no linked library");