Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: bounded lookup round half to even #863

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 148 additions & 1 deletion examples/onnx/rounding_ops/input.json
Original file line number Diff line number Diff line change
@@ -1 +1,148 @@
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.6261028051376343, 0.49872446060180664, -0.04514765739440918, 0.5936200618743896, 0.9271858930587769, 0.6688600778579712, -0.20331168174743652, -0.7016235589981079, 0.025863051414489746, -0.19426143169403076, 0.9827852249145508, 0.4897397756576538, 0.2992602586746216, 0.7011144161224365, 0.9278832674026489, 0.5943725109100342, -0.573331356048584, 0.3675816059112549], [0.7803324460983276, -0.9616303443908691, 0.6070173978805542, -0.028337717056274414, -0.5080242156982422, -0.9280107021331787, 0.6150380373001099, 0.3865993022918701, -0.43668973445892334, 0.17152702808380127, 0.5144252777099609, -0.28881049156188965, 0.8932310342788696, 0.059034109115600586, 0.6865451335906982, 0.009820222854614258, 0.23011493682861328, -0.9492779970169067], [-0.21352827548980713, -0.16015326976776123, -0.38964390754699707, 0.13464701175689697, -0.8814496994018555, 0.5037975311279297, -0.804405927658081, 0.9858957529067993, 0.19567716121673584, 0.9777265787124634, 0.6151977777481079, 0.568595290184021, 0.10584986209869385, -0.8975653648376465, 0.6235959529876709, -0.547879695892334, 0.9289869070053101, 0.7567293643951416]], "output_data": [[1.0, 0.0, -0.0, 1.0, 1.0, 1.0, -0.0, -1.0, 0.0, -0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 0.0], [0.0, -1.0, 0.0, -1.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0], [-0.0, -0.0, -0.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0]]}
{
"input_shapes": [
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
]
],
"input_data": [
[
0.5,
1.5,
-0.04514765739440918,
0.5936200618743896,
0.9271858930587769,
0.6688600778579712,
-0.20331168174743652,
-0.7016235589981079,
0.025863051414489746,
-0.19426143169403076,
0.9827852249145508,
0.4897397756576538,
-1.5,
-0.5,
0.9278832674026489,
0.5943725109100342,
-0.573331356048584,
0.3675816059112549
],
[
0.7803324460983276,
-0.9616303443908691,
0.6070173978805542,
-0.028337717056274414,
-0.5080242156982422,
-0.9280107021331787,
0.6150380373001099,
0.3865993022918701,
-0.43668973445892334,
0.17152702808380127,
0.5144252777099609,
-0.28881049156188965,
0.8932310342788696,
0.059034109115600586,
0.6865451335906982,
0.009820222854614258,
0.23011493682861328,
-0.9492779970169067
],
[
-0.21352827548980713,
-0.16015326976776123,
-0.38964390754699707,
0.13464701175689697,
-0.8814496994018555,
0.5037975311279297,
-0.804405927658081,
0.9858957529067993,
0.19567716121673584,
0.9777265787124634,
0.6151977777481079,
0.568595290184021,
0.10584986209869385,
-0.8975653648376465,
0.6235959529876709,
-0.547879695892334,
0.9289869070053101,
0.7567293643951416
]
],
"output_data": [
[
1.0,
0.0,
-0.0,
1.0,
1.0,
1.0,
-0.0,
-1.0,
0.0,
-0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
1.0,
-1.0,
0.0
],
[
0.0,
-1.0,
0.0,
-1.0,
-1.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
-1.0
],
[
-0.0,
-0.0,
-0.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0
]
]
}
11 changes: 11 additions & 0 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
RoundHalfToEven {
scale: utils::F32,
legs: usize,
},
Ceil {
scale: utils::F32,
legs: usize,
Expand Down Expand Up @@ -108,9 +112,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid

fn as_string(&self) -> String {
match self {
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),

HybridOp::Max => format!("MAX"),
HybridOp::Min => format!("MIN"),
HybridOp::Recip {
Expand Down Expand Up @@ -181,6 +189,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Ceil { scale, legs } => {
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
}
Expand Down
149 changes: 149 additions & 0 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4654,6 +4654,155 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
)
}

/// round half to even layout
/// # Arguments
/// * `config` - BaseConfig
/// * `region` - RegionCtx
/// * `values` - &[ValTensor<F>; 1]
/// * `scale` - utils::F32
/// * `legs` - usize
/// # Returns
/// * ValTensor<F>
/// # Example
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::round;
/// use ezkl::tensor::val::ValTensor;
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[3, -2, 3, 1]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = round::<Fp>(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
pub fn round_half_to_even<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
legs: usize,
) -> Result<ValTensor<F>, CircuitError> {
// decompose with base scale and then set the last element to zero
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
// set the last element to zero and then recompose, we don't actually need to assign here
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
let zero = ValType::Constant(F::ZERO);

// if scale is not exactly divisible by 2 we warn
if scale.0 % 2.0 != 0.0 {
log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate");
}

let midway_point: ValTensor<F> = create_constant_tensor(
integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep),
1,
);
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;

region.increment(1);

let dims = decomposition.dims().to_vec();
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();

let mut incremented_tensor = Tensor::new(None, &first_dims)?;

let cartesian_coord = first_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();

let inner_loop_function =
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
let coord = cartesian_coord[i].clone();
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut sliced_input = decomposition.get_slice(&slice)?;
sliced_input.flatten();
let last_elem = sliced_input.last()?;

let penultimate_elem =
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?;

let is_equal_to_midway = equals(
config,
region,
&[last_elem.clone(), assigned_midway_point.clone()],
)?;
// penultimate_elem is equal to midway point and even, do nothing
let is_odd = nonlinearity(
config,
region,
&[penultimate_elem.clone()],
&LookupOp::IsOdd,
)?;

let is_odd_and_equal_to_midway = and(
config,
region,
&[is_odd.clone(), is_equal_to_midway.clone()],
)?;

let is_greater_than_midway = greater(
config,
region,
&[last_elem.clone(), assigned_midway_point.clone()],
)?;

// if the number is equal to midway point and odd increment, or if it is is_greater_than_midway
let is_odd_and_equal_to_midway_or_greater_than_midway = or(
config,
region,
&[
is_odd_and_equal_to_midway.clone(),
is_greater_than_midway.clone(),
],
)?;

// increment the penultimate element
let incremented_elem = pairwise(
config,
region,
&[
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
is_odd_and_equal_to_midway_or_greater_than_midway.clone(),
],
BaseOp::Add,
)?;

let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
inner_tensor[sliced_input.len() - 2] =
incremented_elem.get_inner_tensor()?.clone()[0].clone();

// set the last elem to zero
inner_tensor[sliced_input.len() - 1] = zero.clone();

Ok(inner_tensor.clone())
};

region.update_max_min_lookup_inputs_force(0, scale.0 as IntegerRep)?;

region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;

let mut incremented_tensor = incremented_tensor.combine()?;
incremented_tensor.reshape(&dims)?;

recompose(
config,
region,
&[incremented_tensor.into()],
&(scale.0 as usize),
)
}

pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
Expand Down
Loading
Loading