From ebaee9e2b1e07570ad94a278717ebcf629c006ad Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Sat, 26 Oct 2024 08:00:27 -0400 Subject: [PATCH] feat: lookupless min/max ops (#854) --- src/circuit/ops/hybrid.rs | 8 ++++ src/circuit/ops/layouts.rs | 42 +++++++++++++++++ src/circuit/ops/lookup.rs | 18 ------- src/graph/utilities.rs | 89 +++++++++-------------------------- tests/py_integration_tests.rs | 69 +++++++++++++-------------- 5 files changed, 107 insertions(+), 119 deletions(-) diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index ae47a5583..4a081426d 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -45,6 +45,8 @@ pub enum HybridOp { ReduceArgMin { dim: usize, }, + Max, + Min, Softmax { input_scale: utils::F32, output_scale: utils::F32, @@ -79,6 +81,8 @@ impl Op for Hybrid | HybridOp::Less { .. } | HybridOp::Equals { .. } | HybridOp::GreaterEqual { .. } + | HybridOp::Max + | HybridOp::Min | HybridOp::LessEqual { .. } => { vec![0, 1] } @@ -93,6 +97,8 @@ impl Op for Hybrid fn as_string(&self) -> String { match self { + HybridOp::Max => format!("MAX"), + HybridOp::Min => format!("MIN"), HybridOp::Recip { input_scale, output_scale, @@ -162,6 +168,8 @@ impl Op for Hybrid values: &[ValTensor], ) -> Result>, CircuitError> { Ok(Some(match self { + HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?, + HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?, HybridOp::SumPool { padding, stride, diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 12966f4f8..725396688 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -4155,6 +4155,48 @@ pub(crate) fn argmin( Ok(assigned_argmin) } +/// max layout +pub(crate) fn max_comp( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, CircuitError> { + let is_greater = greater(config, region, values)?; + let is_less = not(config, region, &[is_greater.clone()])?; + + let max_val_p1 = pairwise( + config, + region, + &[values[0].clone(), is_greater], + BaseOp::Mult, + )?; + + let max_val_p2 = pairwise(config, region, &[values[1].clone(), is_less], BaseOp::Mult)?; + + pairwise(config, region, &[max_val_p1, max_val_p2], BaseOp::Add) +} + +/// min comp layout +pub(crate) fn min_comp( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, CircuitError> { + let is_greater = greater(config, region, values)?; + let is_less = not(config, region, &[is_greater.clone()])?; + + let min_val_p1 = pairwise(config, region, &[values[0].clone(), is_less], BaseOp::Mult)?; + + let min_val_p2 = pairwise( + config, + region, + &[values[1].clone(), is_greater], + BaseOp::Mult, + )?; + + pairwise(config, region, &[min_val_p1, min_val_p2], BaseOp::Add) +} + /// max layout pub(crate) fn max( config: &BaseConfig, diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index 1a4c545e9..f6c30d1b3 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -21,14 +21,6 @@ pub enum LookupOp { Cast { scale: utils::F32, }, - Max { - scale: utils::F32, - a: utils::F32, - }, - Min { - scale: utils::F32, - a: utils::F32, - }, Ceil { scale: utils::F32, }, @@ -129,8 +121,6 @@ impl LookupOp { LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale), LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a), LookupOp::KroneckerDelta => "kronecker_delta".into(), - LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a), - LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a), LookupOp::Div { denom } => format!("div_{}", denom), LookupOp::Cast { scale } => format!("cast_{}", scale), LookupOp::Recip { @@ -186,12 +176,6 @@ impl LookupOp { LookupOp::KroneckerDelta => { Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x)) } - LookupOp::Max { scale, a } => Ok::<_, TensorError>( - tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()), - ), - LookupOp::Min { scale, a } => Ok::<_, TensorError>( - tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()), - ), LookupOp::Div { denom } => Ok::<_, TensorError>( tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()), ), @@ -289,8 +273,6 @@ impl Op for Lookup LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale), LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a), LookupOp::KroneckerDelta => "K_DELTA".into(), - LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a), - LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a), LookupOp::Recip { input_scale, output_scale, diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 8c77269f7..94b08de1c 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -763,81 +763,38 @@ pub fn new_op_from_onnx( .map(|(i, _)| i) .collect::>(); - if const_inputs.len() != 1 { - return Err(GraphError::OpMismatch(idx, "Max".to_string())); - } - - let const_idx = const_inputs[0]; - let boxed_op = inputs[const_idx].opkind(); - let unit = if let Some(c) = extract_const_raw_values(boxed_op) { - if c.len() == 1 { - c[0] - } else { - return Err(GraphError::InvalidDims(idx, "max".to_string())); - } - } else { - return Err(GraphError::OpMismatch(idx, "Max".to_string())); - }; - if inputs.len() == 2 { - if let Some(node) = inputs.get_mut(const_idx) { - node.decrement_use(); - deleted_indices.push(const_idx); - } - if unit == 0. { - SupportedOp::Linear(PolyOp::ReLU) + if const_inputs.len() > 0 { + let const_idx = const_inputs[0]; + let boxed_op = inputs[const_idx].opkind(); + let unit = if let Some(c) = extract_const_raw_values(boxed_op) { + if c.len() == 1 { + c[0] + } else { + return Err(GraphError::InvalidDims(idx, "max".to_string())); + } + } else { + return Err(GraphError::OpMismatch(idx, "Max".to_string())); + }; + if unit == 0. { + if let Some(node) = inputs.get_mut(const_idx) { + node.decrement_use(); + deleted_indices.push(const_idx); + } + SupportedOp::Linear(PolyOp::ReLU) + } else { + SupportedOp::Hybrid(HybridOp::Max) + } } else { - // get the non-constant index - let non_const_idx = if const_idx == 0 { 1 } else { 0 }; - SupportedOp::Nonlinear(LookupOp::Max { - scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(), - a: crate::circuit::utils::F32(unit), - }) + SupportedOp::Hybrid(HybridOp::Max) } } else { return Err(GraphError::InvalidDims(idx, "max".to_string())); } } "Min" => { - // Extract the min value - // first find the input that is a constant - // and then extract the value - let const_inputs = inputs - .iter() - .enumerate() - .filter(|(_, n)| n.is_constant()) - .map(|(i, _)| i) - .collect::>(); - - if const_inputs.len() != 1 { - return Err(GraphError::OpMismatch(idx, "Min".to_string())); - } - - let const_idx = const_inputs[0]; - let boxed_op = inputs[const_idx].opkind(); - let unit = if let Some(c) = extract_const_raw_values(boxed_op) { - if c.len() == 1 { - c[0] - } else { - return Err(GraphError::InvalidDims(idx, "min".to_string())); - } - } else { - return Err(GraphError::OpMismatch(idx, "Min".to_string())); - }; - if inputs.len() == 2 { - if let Some(node) = inputs.get_mut(const_idx) { - node.decrement_use(); - deleted_indices.push(const_idx); - } - - // get the non-constant index - let non_const_idx = if const_idx == 0 { 1 } else { 0 }; - - SupportedOp::Nonlinear(LookupOp::Min { - scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(), - a: crate::circuit::utils::F32(unit), - }) + SupportedOp::Hybrid(HybridOp::Min) } else { return Err(GraphError::InvalidDims(idx, "min".to_string())); } diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index 3d3dcb2e6..8c5829b28 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -124,41 +124,40 @@ mod py_tests { } const TESTS: [&str; 34] = [ - "ezkl_demo_batch.ipynb", - "proof_splitting.ipynb", // 0 - "variance.ipynb", - "mnist_gan.ipynb", - // "mnist_vae.ipynb", - "keras_simple_demo.ipynb", - "mnist_gan_proof_splitting.ipynb", // 4 - "hashed_vis.ipynb", // 5 - "simple_demo_all_public.ipynb", - "data_attest.ipynb", - "little_transformer.ipynb", - "simple_demo_aggregated_proofs.ipynb", - "ezkl_demo.ipynb", // 10 - "lstm.ipynb", - "set_membership.ipynb", // 12 - "decision_tree.ipynb", - "random_forest.ipynb", - "gradient_boosted_trees.ipynb", // 15 - "xgboost.ipynb", - "lightgbm.ipynb", - "svm.ipynb", - "simple_demo_public_input_output.ipynb", - "simple_demo_public_network_output.ipynb", // 20 - "gcn.ipynb", - "linear_regression.ipynb", - "stacked_regression.ipynb", - "data_attest_hashed.ipynb", - "kzg_vis.ipynb", // 25 - "kmeans.ipynb", - "solvency.ipynb", - "sklearn_mlp.ipynb", - "generalized_inverse.ipynb", - "mnist_classifier.ipynb", // 30 - "world_rotation.ipynb", - "logistic_regression.ipynb", + "ezkl_demo_batch.ipynb", // 0 + "proof_splitting.ipynb", // 1 + "variance.ipynb", // 2 + "mnist_gan.ipynb", // 3 + "keras_simple_demo.ipynb", // 4 + "mnist_gan_proof_splitting.ipynb", // 5 + "hashed_vis.ipynb", // 6 + "simple_demo_all_public.ipynb", // 7 + "data_attest.ipynb", // 8 + "little_transformer.ipynb", // 9 + "simple_demo_aggregated_proofs.ipynb", // 10 + "ezkl_demo.ipynb", // 11 + "lstm.ipynb", // 12 + "set_membership.ipynb", // 13 + "decision_tree.ipynb", // 14 + "random_forest.ipynb", // 15 + "gradient_boosted_trees.ipynb", // 16 + "xgboost.ipynb", // 17 + "lightgbm.ipynb", // 18 + "svm.ipynb", // 19 + "simple_demo_public_input_output.ipynb", // 20 + "simple_demo_public_network_output.ipynb", // 21 + "gcn.ipynb", // 22 + "linear_regression.ipynb", // 23 + "stacked_regression.ipynb", // 24 + "data_attest_hashed.ipynb", // 25 + "kzg_vis.ipynb", // 26 + "kmeans.ipynb", // 27 + "solvency.ipynb", // 28 + "sklearn_mlp.ipynb", // 29 + "generalized_inverse.ipynb", // 30 + "mnist_classifier.ipynb", // 31 + "world_rotation.ipynb", // 32 + "logistic_regression.ipynb", // 33 ]; macro_rules! test_func {