Skip to content

Commit

Permalink
refactor: set defaults as exportable consts (zkonduit#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Dec 12, 2023
1 parent 78c0aff commit 865532b
Show file tree
Hide file tree
Showing 16 changed files with 316 additions and 271 deletions.
2 changes: 1 addition & 1 deletion src/circuit/modules/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl<'a, F: Field, CS: Assignment<F> + 'a + SyncDeps> Layouter<F> for ModuleLayo
};
let index = index.parse::<usize>().map_err(|_| {
log::error!("Invalid module name");
return Error::Synthesis;
Error::Synthesis
})?;
if !self.regions.contains_key(&index) {
warn!("spawning module {}", index)
Expand Down
4 changes: 2 additions & 2 deletions src/circuit/ops/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
let mut nonaccum_selectors = BTreeMap::new();
let mut accum_selectors = BTreeMap::new();

if !(inputs[0].num_cols() == inputs[1].num_cols()) {
if inputs[0].num_cols() != inputs[1].num_cols() {
log::warn!("input shapes do not match");
}
if !(inputs[0].num_cols() == output.num_cols()) {
if inputs[0].num_cols() != output.num_cols() {
log::warn!("input and output shapes do not match");
}

Expand Down
2 changes: 1 addition & 1 deletion src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
input.flatten();

// assert we have a single index
if !(index.dims().iter().product::<usize>() == 1) {
if index.dims().iter().product::<usize>() != 1 {
return Err("index must be a single element".into());
}

Expand Down
2 changes: 1 addition & 1 deletion src/circuit/ops/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let diff = self.num_inner_cols - remainder;
self.increment(diff);
}
if !(self.linear_coord % self.num_inner_cols == 0) {
if self.linear_coord % self.num_inner_cols != 0 {
return Err("flush: linear coord is not aligned with the next row".into());
}
Ok(())
Expand Down
243 changes: 142 additions & 101 deletions src/commands.rs

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,10 @@ pub(crate) async fn get_srs_cmd(
logrows: Option<u32>,
check_mode: CheckMode,
) -> Result<(), Box<dyn Error>> {
let k = if let Some(settings_p) = settings_path {
// logrows overrides settings
let k = if let Some(k) = logrows {
k
} else if let Some(settings_p) = settings_path {
if settings_p.exists() {
let settings = GraphSettings::load(&settings_p)?;
settings.run_args.logrows
Expand All @@ -438,8 +441,6 @@ pub(crate) async fn get_srs_cmd(
);
return Err(err_string.into());
}
} else if let Some(k) = logrows {
k
} else {
let err_string = format!(
"You will need to provide a settings file or set the logrows. You should run gen-settings to generate a settings file (and calibrate-settings to pick optimal logrows)."
Expand Down
46 changes: 20 additions & 26 deletions src/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,50 +254,49 @@ impl ToPyObject for GraphWitness {
.map(|x| x.iter().map(field_to_vecu64_montgomery).collect())
.collect();

dict.set_item("inputs", &inputs).unwrap();
dict.set_item("outputs", &outputs).unwrap();
dict.set_item("max_lookup_inputs", &self.max_lookup_inputs)
dict.set_item("inputs", inputs).unwrap();
dict.set_item("outputs", outputs).unwrap();
dict.set_item("max_lookup_inputs", self.max_lookup_inputs)
.unwrap();

if let Some(processed_inputs) = &self.processed_inputs {
//poseidon_hash
if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash {
insert_poseidon_hash_pydict(&dict_inputs, &processed_inputs_poseidon_hash).unwrap();
insert_poseidon_hash_pydict(dict_inputs, processed_inputs_poseidon_hash).unwrap();
}
if let Some(processed_inputs_elgamal) = &processed_inputs.elgamal {
insert_elgamal_results_pydict(py, dict_inputs, processed_inputs_elgamal).unwrap();
}
if let Some(processed_inputs_kzg_commit) = &processed_inputs.kzg_commit {
insert_kzg_commit_pydict(&dict_inputs, &processed_inputs_kzg_commit).unwrap();
insert_kzg_commit_pydict(dict_inputs, processed_inputs_kzg_commit).unwrap();
}

dict.set_item("processed_inputs", dict_inputs).unwrap();
}

if let Some(processed_params) = &self.processed_params {
if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash {
insert_poseidon_hash_pydict(dict_params, &processed_params_poseidon_hash).unwrap();
insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash).unwrap();
}
if let Some(processed_params_elgamal) = &processed_params.elgamal {
insert_elgamal_results_pydict(py, dict_params, processed_params_elgamal).unwrap();
}
if let Some(processed_params_kzg_commit) = &processed_params.kzg_commit {
insert_kzg_commit_pydict(&dict_inputs, &processed_params_kzg_commit).unwrap();
insert_kzg_commit_pydict(dict_inputs, processed_params_kzg_commit).unwrap();
}

dict.set_item("processed_params", dict_params).unwrap();
}

if let Some(processed_outputs) = &self.processed_outputs {
if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash {
insert_poseidon_hash_pydict(dict_outputs, &processed_outputs_poseidon_hash)
.unwrap();
insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash).unwrap();
}
if let Some(processed_outputs_elgamal) = &processed_outputs.elgamal {
insert_elgamal_results_pydict(py, dict_outputs, processed_outputs_elgamal).unwrap();
}
if let Some(processed_outputs_kzg_commit) = &processed_outputs.kzg_commit {
insert_kzg_commit_pydict(&dict_inputs, &processed_outputs_kzg_commit).unwrap();
insert_kzg_commit_pydict(dict_inputs, processed_outputs_kzg_commit).unwrap();
}

dict.set_item("processed_outputs", dict_outputs).unwrap();
Expand Down Expand Up @@ -470,29 +469,19 @@ impl GraphSettings {

/// if any visibility is encrypted or hashed
pub fn module_requires_fixed(&self) -> bool {
if self.run_args.input_visibility.is_encrypted()
self.run_args.input_visibility.is_encrypted()
|| self.run_args.input_visibility.is_hashed()
|| self.run_args.output_visibility.is_encrypted()
|| self.run_args.output_visibility.is_hashed()
|| self.run_args.param_visibility.is_encrypted()
|| self.run_args.param_visibility.is_hashed()
{
true
} else {
false
}
}

/// any kzg visibility
pub fn module_requires_kzg(&self) -> bool {
if self.run_args.input_visibility.is_kzgcommit()
self.run_args.input_visibility.is_kzgcommit()
|| self.run_args.output_visibility.is_kzgcommit()
|| self.run_args.param_visibility.is_kzgcommit()
{
true
} else {
false
}
}
}

Expand Down Expand Up @@ -1188,6 +1177,11 @@ impl GraphCircuit {
) -> Result<(), Box<dyn std::error::Error>> {
// Set up local anvil instance for reading on-chain data

let input_scales = self.model().graph.get_input_scales();
let output_scales = self.model().graph.get_output_scales()?;
let input_shapes = self.model().graph.input_shapes()?;
let output_shapes = self.model().graph.output_shapes()?;

if matches!(
test_on_chain_data.data_sources.input,
TestDataSource::OnChain
Expand All @@ -1210,8 +1204,8 @@ impl GraphCircuit {

let datam: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
input_data,
self.model().graph.get_input_scales(),
self.model().graph.input_shapes()?,
input_scales,
input_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
Expand Down Expand Up @@ -1239,8 +1233,8 @@ impl GraphCircuit {
};
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
output_data,
self.model().graph.get_output_scales()?,
self.model().graph.output_shapes()?,
output_scales,
output_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
Expand Down
3 changes: 1 addition & 2 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,14 +395,13 @@ impl ParsedNodes {
pub fn get_input_scales(&self) -> Vec<crate::Scale> {
let input_nodes = self.inputs.iter();
input_nodes
.map(|idx| {
.flat_map(|idx| {
self.nodes
.get(idx)
.ok_or(GraphError::MissingNode(*idx))
.map(|n| n.out_scales())
.unwrap_or_default()
})
.flatten()
.collect()
}

Expand Down
5 changes: 4 additions & 1 deletion src/graph/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,10 @@ impl Node {

input_ids
.iter()
.map(|(i, _)| Ok(inputs.push(other_nodes.get(i).ok_or("input not found")?.clone())))
.map(|(i, _)| {
inputs.push(other_nodes.get(i).ok_or("input not found")?.clone());
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;

let (mut opkind, deleted_indices) = new_op_from_onnx(
Expand Down
2 changes: 1 addition & 1 deletion src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ pub fn new_op_from_onnx(
.map(|(i, _)| i)
.collect::<Vec<_>>();

if !(const_idx.len() <= 1) {
if const_idx.len() > 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "mul".to_string())));
}

Expand Down
20 changes: 19 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub mod wasm;
pub type Scale = i32;

/// Parameters specific to a proving run
#[derive(Debug, Args, Deserialize, Serialize, Clone, Default, PartialEq, PartialOrd)]
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[arg(short = 'T', long, default_value = "0")]
Expand Down Expand Up @@ -112,6 +112,24 @@ pub struct RunArgs {
pub param_visibility: Visibility,
}

/// Default values for a proving run.
///
/// NOTE(review): these literals presumably mirror the clap `default_value`
/// attributes on the corresponding `RunArgs` fields (e.g. `tolerance` is
/// declared with `default_value = "0"` above) — confirm the two stay in sync,
/// ideally by having the clap attributes reference shared consts.
impl Default for RunArgs {
    fn default() -> Self {
        Self {
            // no error tolerance on model outputs
            tolerance: Tolerance::default(),
            // fixed-point scale exponents for inputs and parameters
            input_scale: 7,
            param_scale: 7,
            scale_rebase_multiplier: 1,
            // symmetric range covered by the lookup tables
            lookup_range: (-32768, 32768),
            // log2 of the number of rows in the circuit
            logrows: 17,
            num_inner_cols: 2,
            // symbolic graph dimensions; batch size defaults to 1
            variables: vec![("batch_size".to_string(), 1)],
            // inputs and params stay private by default; outputs are public
            input_visibility: Visibility::Private,
            output_visibility: Visibility::Public,
            param_visibility: Visibility::Private,
        }
    }
}

impl RunArgs {
///
pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
Expand Down
16 changes: 11 additions & 5 deletions src/pfsys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ use thiserror::Error as thisError;
use halo2curves::bn256::{Bn256, Fr, G1Affine};

#[allow(missing_docs)]
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
/// The kind of proof to produce.
#[derive(
    ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
)]
pub enum ProofType {
    /// A standalone proof (the default).
    #[default]
    Single,
    /// A proof prepared for aggregation — presumably consumed by an
    /// aggregation circuit; NOTE(review): confirm against the prover config.
    ForAggr,
}
Expand Down Expand Up @@ -142,9 +145,12 @@ pub enum PfSysError {
}

#[allow(missing_docs)]
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
/// Which transcript hash the proof system uses.
#[derive(
    ValueEnum, Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
)]
pub enum TranscriptType {
    /// Poseidon-based transcript — presumably for in-circuit / aggregation
    /// verification; NOTE(review): confirm intended use.
    Poseidon,
    /// EVM-compatible transcript (the default), for on-chain verification.
    #[default]
    EVM,
}

Expand Down Expand Up @@ -230,10 +236,10 @@ where
.iter()
.map(|x| x.iter().map(|fp| field_to_vecu64_montgomery(fp)).collect())
.collect::<Vec<_>>();
dict.set_item("instances", &field_elems).unwrap();
dict.set_item("instances", field_elems).unwrap();
let hex_proof = hex::encode(&self.proof);
dict.set_item("proof", &hex_proof).unwrap();
dict.set_item("transcript_type", &self.transcript_type)
dict.set_item("proof", hex_proof).unwrap();
dict.set_item("transcript_type", self.transcript_type)
.unwrap();
dict.to_object(py)
}
Expand Down
Loading

0 comments on commit 865532b

Please sign in to comment.