Skip to content

Commit

Permalink
chore: unify leakyrelu and relu (#858)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Oct 29, 2024
1 parent ebaee9e commit 17f1d42
Show file tree
Hide file tree
Showing 15 changed files with 148 additions and 97 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,20 @@ harness = false


[[bench]]
name = "relu"
name = "sigmoid"
harness = false

[[bench]]
name = "relu_lookupless"
harness = false

[[bench]]
name = "accum_matmul_relu"
name = "accum_matmul_sigmoid"
harness = false


[[bench]]
name = "accum_matmul_relu_overflow"
name = "accum_matmul_sigmoid_overflow"
harness = false

[[bin]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();

Expand Down Expand Up @@ -93,7 +93,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
k,
&LookupOp::LeakyReLU { slope: 0.0.into() },
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();

Expand Down Expand Up @@ -94,7 +94,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
Expand Down
9 changes: 8 additions & 1 deletion benches/relu_lookupless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ impl Circuit<Fr> for NLCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
Ok(())
},
Expand Down
4 changes: 2 additions & 2 deletions benches/relu.rs → benches/sigmoid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();

let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
let nl = LookupOp::Sigmoid { scale: 1.0.into() };

let mut config = Config::default();

Expand All @@ -68,7 +68,7 @@ impl Circuit<Fr> for NLCircuit {
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
Expand Down
26 changes: 16 additions & 10 deletions examples/conv2d_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ where
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);

let _constant = VarTensor::constant_cols(cs, K, LEN, false);

println!("INPUT COL {:#?}", input);

let mut layer_config = PolyConfig::configure(
Expand All @@ -156,15 +158,11 @@ where
);

layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();

layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.unwrap();

layer_config
Expand Down Expand Up @@ -195,6 +193,11 @@ where
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();

config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();

let x = layouter
.assign_region(
|| "mlp_4d",
Expand Down Expand Up @@ -224,7 +227,10 @@ where
.layout(
&mut region,
&[x.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();

Expand Down
34 changes: 22 additions & 12 deletions examples/mlp_4d_einsum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,23 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
let output = VarTensor::new_advice(cs, K, 1, LEN);
// tells the config layer to add an affine op to the circuit gate

let _constant = VarTensor::constant_cols(cs, K, LEN, false);

println!("INPUT COL {:#?}", input);

let mut layer_config = PolyConfig::<F>::configure(
cs,
&[input.clone(), params.clone()],
&output,
CheckMode::SAFE,
);

// sets up a new ReLU table and resuses it for l1 and l3 non linearities
layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();

layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.unwrap();

// sets up a new ReLU table and resuses it for l1 and l3 non linearities
Expand Down Expand Up @@ -104,6 +103,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();

config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();

let x = layouter
.assign_region(
|| "mlp_4d",
Expand Down Expand Up @@ -144,7 +148,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap()
.unwrap();
Expand Down Expand Up @@ -184,7 +191,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap();
println!("6");
Expand Down
3 changes: 3 additions & 0 deletions src/circuit/ops/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,7 @@ pub enum CircuitError {
#[error("[io] {0}")]
/// IO error
IoError(#[from] std::io::Error),
/// Invalid scale
#[error("negative scale for an op that requires positive inputs {0}")]
NegativeScale(String),
}
44 changes: 39 additions & 5 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4305,7 +4305,6 @@ pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
) -> Result<ValTensor<F>, CircuitError> {
let mut decomp = decompose(config, region, values, &region.base(), &region.legs())?;
// get every n elements now, which correspond to the sign bit

decomp.get_every_n(region.legs() + 1)?;
decomp.reshape(values[0].dims())?;

Expand All @@ -4322,10 +4321,12 @@ pub(crate) fn abs<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult)
}

pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
alpha: &utils::F32,
input_scale: &i32,
) -> Result<ValTensor<F>, CircuitError> {
let sign = sign(config, region, values)?;

Expand All @@ -4334,12 +4335,45 @@ pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

let relu_mask = equals(config, region, &[sign, unit])?;

pairwise(
let positive = pairwise(
config,
region,
&[values[0].clone(), relu_mask],
&[values[0].clone(), relu_mask.clone()],
BaseOp::Mult,
)
)?;

if alpha.0 == 0. {
return Ok(positive);
}

if input_scale < &0 {
return Err(CircuitError::NegativeScale("leaky_relu".to_string()));
}

let scale_constant = create_constant_tensor(F::from(2_i32.pow(*input_scale as u32) as u64), 1);

let rescaled_positive = pairwise(config, region, &[positive, scale_constant], BaseOp::Mult)?;

let neg_mask = not(config, region, &[relu_mask])?;

let quantized_alpha = quantize_tensor(
Tensor::from([alpha.0; 1].into_iter()),
*input_scale,
&crate::graph::Visibility::Fixed,
)?;

let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);

let scaled_neg_mask = pairwise(config, region, &[neg_mask, alpha_tensor], BaseOp::Mult)?;

let neg_part = pairwise(
config,
region,
&[values[0].clone(), scaled_neg_mask],
BaseOp::Mult,
)?;

pairwise(config, region, &[rescaled_positive, neg_part], BaseOp::Add)
}

fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Expand Down
15 changes: 0 additions & 15 deletions src/circuit/ops/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ pub enum LookupOp {
input_scale: utils::F32,
output_scale: utils::F32,
},
LeakyReLU {
slope: utils::F32,
},
Sigmoid {
scale: utils::F32,
},
Expand Down Expand Up @@ -94,7 +91,6 @@ pub enum LookupOp {
Erf {
scale: utils::F32,
},
KroneckerDelta,
Pow {
scale: utils::F32,
a: utils::F32,
Expand All @@ -120,14 +116,12 @@ impl LookupOp {
LookupOp::Round { scale } => format!("round_{}", scale),
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::KroneckerDelta => "kronecker_delta".into(),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Recip {
input_scale,
output_scale,
} => format!("recip_{}_{}", input_scale, output_scale),
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
Expand Down Expand Up @@ -173,9 +167,6 @@ impl LookupOp {
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
LookupOp::KroneckerDelta => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
}
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
Expand All @@ -190,9 +181,6 @@ impl LookupOp {
input_scale.into(),
output_scale.into(),
)),
LookupOp::LeakyReLU { slope: a } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
}
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
Expand Down Expand Up @@ -272,7 +260,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::KroneckerDelta => "K_DELTA".into(),
LookupOp::Recip {
input_scale,
output_scale,
Expand All @@ -283,7 +270,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
Expand Down Expand Up @@ -327,7 +313,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::KroneckerDelta => 0,
_ => inputs_scale[0],
};
Ok(scale)
Expand Down
Loading

0 comments on commit 17f1d42

Please sign in to comment.