From 9799ed9d54d8e5756f47269cdc98bc2a8bf2d5af Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Wed, 12 Jan 2022 10:39:41 +0100 Subject: [PATCH 01/18] UIntGadget trait + macro implementation --- r1cs/gadgets/std/src/bits/macros.rs | 1804 +++++++++++++++++++++++++++ r1cs/gadgets/std/src/bits/mod.rs | 180 ++- r1cs/gadgets/std/src/eq.rs | 21 - r1cs/gadgets/std/src/lib.rs | 2 +- 4 files changed, 1980 insertions(+), 27 deletions(-) create mode 100644 r1cs/gadgets/std/src/bits/macros.rs diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs new file mode 100644 index 000000000..402c68244 --- /dev/null +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -0,0 +1,1804 @@ +macro_rules! impl_uint_gadget { + ($type_name: ident, $bit_size: expr, $native_type: ident, $mod_name: ident) => { + pub mod $mod_name { + + use crate::{boolean::{Boolean, AllocatedBit}, fields::{fp::FpGadget, FieldGadget}, eq::{EqGadget, MultiEq}, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8, Assignment}; + + use r1cs_core::{ConstraintSystemAbstract, SynthesisError, LinearCombination}; + use crate::alloc::{AllocGadget, ConstantGadget}; + + use algebra::{fields::{PrimeField, FpParameters}, ToConstraintField}; + + use std::{borrow::Borrow, ops::{Shl, Shr}, convert::TryInto}; + + + //ToDo: remove public use of fields + #[derive(Clone, Debug)] + pub struct $type_name { + // Least significant bit_gadget first + pub(crate) bits: Vec, + pub(crate) value: Option<$native_type>, + } + + impl $type_name { + pub fn get_value(&self) -> Option<$native_type> { + self.value + } + + pub fn constant(value: $native_type) -> Self { + let mut bits = Vec::with_capacity($bit_size); + + for i in 0..$bit_size { + let bit = (value >> i) & 1; + bits.push(Boolean::constant(bit == 1)); + } + + Self { + bits, + value: Some(value), + } + } + + pub fn alloc_vec( + mut cs: CS, + values: &[T], + ) -> Result, SynthesisError> + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + T: Into> + Copy, + { + let mut output_vec = Vec::with_capacity(values.len()); + for (i, value) in values.iter().enumerate() { + let val: Option<$native_type> = Into::into(*value); + let alloc_val = Self::alloc(&mut cs.ns(|| format!("el_{}", i)), || val.get())?; + output_vec.push(alloc_val); + } + Ok(output_vec) + } + + /// Allocates a vector of `u8`'s by first converting (chunks of) them to + /// `ConstraintF` elements, (thus reducing the number of input allocations), + /// and then converts this list of `ConstraintF` gadgets back into + /// bits and then packs chunks of such into `Self`. + pub fn alloc_input_vec( + mut cs: CS, + values: &[u8], + ) -> Result, SynthesisError> + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + { + let field_elements: Vec = + ToConstraintField::::to_field_elements(values).unwrap(); + + let max_size = (::Params::CAPACITY / 8) as usize; + + let mut allocated_bits = Vec::new(); + for (i, (field_element, byte_chunk)) in field_elements + .into_iter() + .zip(values.chunks(max_size)) + .enumerate() + { + let fe = FpGadget::alloc_input(&mut cs.ns(|| format!("Field element {}", i)), || { + Ok(field_element) + })?; + + // Let's use the length-restricted variant of the ToBitsGadget to remove the + // padding: the padding bits are not constrained to be zero, so any field element + // passed as input (as long as it has the last bits set to the proper value) can + // satisfy the constraints. This kind of freedom might not be desiderable in + // recursive SNARK circuits, where the public inputs of the inner circuit are + // usually involved in other kind of constraints inside the wrap circuit. + let to_skip: usize = + ::Params::MODULUS_BITS as usize - (byte_chunk.len() * 8); + let mut fe_bits = fe.to_bits_with_length_restriction( + cs.ns(|| format!("Convert fe to bits {}", i)), + to_skip, + )?; + + // FpGadget::to_bits outputs a big-endian binary representation of + // fe_gadget's value, so we have to reverse it to get the little-endian + // form. + fe_bits.reverse(); + + allocated_bits.extend_from_slice(fe_bits.as_slice()); + } + + // pad with additional zero bits to have a number of bits which is multiple of $bit_size + while allocated_bits.len() % $bit_size != 0 { + allocated_bits.push(Boolean::constant(false)); + } + + // Chunk up slices of $bit_size bits into bytes. + Ok(allocated_bits[..] + .chunks($bit_size) + .enumerate() + .map(|(i, chunk)| Self::from_bits_le(cs.ns(|| format!("pack input chunk {}", i)), chunk)) + .collect::>()?) + } + + /// Construct a constant vector of `Self` from a vector of `u8` + pub fn constant_vec(values: &[u8]) -> Vec { + const BYTES_PER_ELEMENT: usize = $bit_size/8; + let mut result = Vec::new(); + for bytes in values.chunks(BYTES_PER_ELEMENT) { + let mut value: $native_type = 0; + for (i, byte) in bytes.iter().enumerate() { + let byte: $native_type = (*byte).into(); + value |= byte << (i*8); + } + result.push(Self::constant(value)); + } + result + } + + // Return little endian representation of self. Will be removed when to_bits_le and + // from_bits_le will be merged. + pub fn into_bits_le(&self) -> Vec { + self.bits.to_vec() + } + + // Construct self from its little endian bit representation. Will be removed when + // to_bits_le and from_bits_le will be merged. + pub fn from_bits_le(cs: CS, bits: &[Boolean]) -> Result + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + { + let be_bits = bits.iter().rev().map(|el| *el).collect::>(); + Self::from_bits(cs, &be_bits) + } + + // Construct Self from a little endian byte representation, provided in the form of + // byte gadgets + pub fn from_bytes(bytes: &[UInt8]) -> Result { + assert!(bytes.len()*8 <= $bit_size); + let mut bits = Vec::with_capacity($bit_size); + let mut value: Option<$native_type> = Some(0); + for (i, byte) in bytes.iter().enumerate() { + value = match byte.get_value() { + Some(val) => value.as_mut().map(|v| { + let val_native_type: $native_type = val.into(); + *v |= val_native_type << (i*8); + *v}), + None => None, + }; + bits.append(&mut byte.into_bits_le()); + } + + // pad with 0 bits to get to $bit_size + while bits.len() != $bit_size { + bits.push(Boolean::constant(false)); + } + + Ok(Self{ + bits, + value, + }) + } + } + + impl PartialEq for $type_name { + fn eq(&self, other: &Self) -> bool { + self.value.is_some() && other.value.is_some() && self.value == other.value + } +} + + impl Eq for $type_name {} + + impl EqGadget for $type_name { + fn is_eq>( + &self, + cs: CS, + other: &Self, + ) -> Result { + self.bits.is_eq(cs, &other.bits) + } + + fn conditional_enforce_equal>( + &self, + cs: CS, + other: &Self, + should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + self.bits.conditional_enforce_equal(cs, &other.bits, should_enforce) + } + + //ToDO: check if the default implementation is better than the probably buggy one for [Boolean] + fn conditional_enforce_not_equal>( + &self, + cs: CS, + other: &Self, + should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + self.bits.conditional_enforce_not_equal(cs, &other.bits, should_enforce) + } + } + + impl AllocGadget<$native_type, ConstraintF> for $type_name { + fn alloc(mut cs: CS, value_gen: F) -> Result + where + CS: ConstraintSystemAbstract, + F: FnOnce() -> Result, + T: Borrow<$native_type> + { + let value = value_gen().map(|val| *val.borrow()); + let bit_values = match value { + Ok(val) => { + let mut bits = Vec::with_capacity($bit_size); + for i in 0..$bit_size { + let bit = (val >> i) & 1; + bits.push(Some(bit == 1)); + } + bits + }, + _ => vec![None; $bit_size] + }; + + let bits = bit_values.into_iter().enumerate().map(|(i, val)| { + Ok(Boolean::from(AllocatedBit::alloc( + &mut cs.ns(|| format!("allocated bit_gadget {}", i)), + || val.ok_or(SynthesisError::AssignmentMissing) + )?)) + }).collect::, SynthesisError>>()?; + + Ok(Self{ + bits, + value: value.ok(), + }) + } + + fn alloc_input(mut cs: CS, value_gen: F) -> Result + where + CS: ConstraintSystemAbstract, + F: FnOnce() -> Result, + T: Borrow<$native_type> + { + let mut value = None; + //ToDo: verify if ConstraintF must be a PrimeField + let field_element = FpGadget::::alloc_input(cs.ns(|| "alloc_input as field element"), || { + let val = value_gen().map(|val| *val.borrow())?; + value = Some(val); + Ok(ConstraintF::from(val)) + })?; + + let to_skip_bits: usize = ConstraintF::Params::MODULUS_BITS as usize - $bit_size; + + let mut bits = field_element.to_bits_with_length_restriction( + &mut cs.ns(|| "field element to bits"), to_skip_bits + )?; + + // need to reverse bits since to_bits_with_length_restriction generates a + // big-endian representation, while Self requires bits in little-endian order + bits.reverse(); + + Ok(Self{ + bits, + value, + }) + } + } + + impl ToBitsGadget for $type_name { + fn to_bits>( + &self, + _cs: CS, + ) -> Result, SynthesisError> { + //Need to reverse bits since to_bits must return a big-endian representation + let le_bits = self.bits.iter().rev().map(|el| *el).collect::>(); + Ok(le_bits) + } + + fn to_bits_strict>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + self.to_bits(cs) + } + } + + impl FromBitsGadget for $type_name { + fn from_bits>( + _cs: CS, + bits: &[Boolean], + ) -> Result { + if bits.len() != $bit_size { + let mut error_msg = String::from(concat!("error: building ", stringify!($type_name))); + error_msg.push_str(format!("from slice of {} bits", bits.len()).as_str()); + return Err(SynthesisError::Other(error_msg)) + } + let mut le_bits = Vec::with_capacity($bit_size); + let mut value: Option<$native_type> = Some(0); + for (i, el) in bits.iter().rev().enumerate() { + le_bits.push(*el); + value = match el.get_value() { + Some(bit) => value.as_mut().map(|v| {*v |= + if bit { + let mask: $native_type = 1; + mask << i + } else { + 0 + }; *v}), + None => None, + }; + } + + Ok(Self { + bits: le_bits, + value, + }) + } + } + + + impl ToBytesGadget for $type_name { + fn to_bytes>( + &self, + _cs: CS, + ) -> Result, SynthesisError> { + const NUM_BYTES: usize = $bit_size/8 + if $bit_size % 8 == 0 {0} else {1}; + let byte_values = match self.value { + Some(val) => { + let mut values = [None; NUM_BYTES]; + for i in 0..NUM_BYTES { + let byte_value: u8 = ((val >> i*8) & 255).try_into().unwrap(); + values[i] = Some(byte_value); + } + values + }, + None => [None; NUM_BYTES] + }; + Ok(self.bits.as_slice().chunks(8).zip(byte_values.iter()).map(|(el, val)| UInt8{ + bits: el.to_vec(), + value: *val, + }).collect::>()) + } + + fn to_bytes_strict>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + self.to_bytes(cs) + } + + } + + impl ConstantGadget<$native_type, ConstraintF> for $type_name { + fn from_value>(_cs: CS, value: &$native_type) -> Self { + $type_name::constant(*value) + } + + fn get_constant(&self) -> $native_type { + self.get_value().unwrap() + } + } + + impl CondSelectGadget for $type_name { + fn conditionally_select>( + mut cs: CS, + cond: &Boolean, + first: &Self, + second: &Self, + ) -> Result { + let bits = first.bits.iter().zip(second.bits.iter()).enumerate().map(|(i, (t, f))| Boolean::conditionally_select(&mut cs.ns(|| format!("cond select bit {}", i)), cond, t, f)).collect::, SynthesisError>>()?; + + assert_eq!(bits.len(), $bit_size); // this assert should always be verified if first and second are built only with public methods + + let value = match cond.get_value() { + Some(cond_bit) => if cond_bit {first.get_value()} else {second.get_value()}, + None => None, + }; + + Ok(Self{ + bits, + value, + }) + } + + fn cost() -> usize { + $bit_size * >::cost() + } + } + + impl Shl for $type_name { + type Output = Self; + + fn shl(self, rhs: usize) -> Self::Output { + let by = if rhs >= $bit_size { + $bit_size-1 + } else { + rhs + }; + + let bits = vec![Boolean::constant(false); by] + .iter() // append rhs zeros as least significant bits + .chain(self.bits.iter()) // Chain existing bits as most significant bits starting from least significant ones + .take($bit_size) // Crop after $bit_size bits + .map(|el| *el) + .collect(); + + Self { + bits, + value: self.value.map(|v| v << by as $native_type), + } + } + } + + impl Shr for $type_name { + type Output = Self; + + fn shr(self, rhs: usize) -> Self::Output { + let by = if rhs >= $bit_size { + $bit_size-1 + } else { + rhs + }; + + let bits = self + .bits + .iter() + .skip(by) // skip least significant bits which are removed by the shift + .chain(vec![Boolean::constant(false); by].iter()) // append zeros as most significant bits + .map(|el| *el) + .collect(); + + Self { + bits, + value: self.value.map(|v| v >> by as $native_type), + } + } + } + + + impl RotateUInt for $type_name { + fn rotl(&self, by: usize) -> Self { + let by = by % $bit_size; + + let bits = self + .bits + .iter() + .skip($bit_size - by) + .chain(self.bits.iter()) + .take($bit_size) + .map(|el| *el) + .collect(); + + Self { + bits, + value: self.value.map(|v| v.rotate_left(by as u32)), + } + } + + fn rotr(&self, by: usize) -> Self { + let by = by % $bit_size; + + let bits = self + .bits + .iter() + .skip(by) + .chain(self.bits.iter()) + .take($bit_size) + .map(|el| *el) + .collect(); + + Self { + bits, + value: self.value.map(|v| v.rotate_right(by as u32)), + } + } + } + + //this macro allows to implement the binary bitwise operations already available for Booleans (i.e., XOR, OR, AND) + macro_rules! impl_binary_bitwise_operation { + ($func_name: ident, $op: tt, $boolean_func: tt) => { + fn $func_name>(&self, mut cs: CS, other: &Self) + -> Result { + let bits = self.bits.iter() + .zip(other.bits.iter()) + .enumerate() + .map(|(i , (b1, b2))| Boolean::$boolean_func(cs.ns(|| format!("xor bit {}", i)), &b1, &b2)) + .collect::, SynthesisError>>()?; + + let value = match other.value { + Some(val) => self.value.map(|v| v $op val), + None => None, + }; + + Ok(Self { + bits, + value, + }) + } + } + } + + // this macro generates the code to handle the case when too many operands are provided + // to addmany/addmany_nocarry/mulmany/mulmany_nocarry functions. + // The operands are split in batches of $max_num_operands elements, + // and each batch is processed independently, aggregating the intermediate result to + // obtain the final outcome of the operation applied to all the operands + macro_rules! handle_numoperands_opmany { + ($opmany_func: tt, $cs: tt, $operands: tt, $max_num_operands: tt) => { + let num_operands = $operands.len(); + // compute the aggregate result over batches of max_num_operands + let mut result = $type_name::$opmany_func($cs.ns(|| "first batch of operands"), &$operands[..$max_num_operands])?; + let mut operands_processed = $max_num_operands; + while operands_processed < num_operands { + let last_op_to_process = if operands_processed + $max_num_operands - 1 > num_operands { + num_operands + } else { + operands_processed + $max_num_operands - 1 + }; + let mut next_operands = $operands[operands_processed..last_op_to_process].iter().cloned().collect::>(); + next_operands.push(result); + result = $type_name::$opmany_func($cs.ns(|| format!("operands from {} to {}", operands_processed, last_op_to_process)), &next_operands[..])?; + operands_processed += $max_num_operands - 1; + } + return Ok(result); + } + } + + impl UIntGadget for $type_name { + + impl_binary_bitwise_operation!(xor, ^, xor); + impl_binary_bitwise_operation!(or, |, or); + impl_binary_bitwise_operation!(and, &, and); + + fn not>(&self, _cs: CS) -> Self { + let bits = self.bits.iter().map(|el| el.not()).collect::>(); + + Self { + bits, + value: self.value.map(|el| !el), + } + } + + fn addmany(mut cs: M, operands: &[Self]) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + let num_operands = operands.len(); + let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + // in this case it is not possible to enforce the correctness of the addition + // of at least 2 elements for the field ConstraintF + assert!(field_bits > $bit_size); + assert!(num_operands >= 2); // Weird trivial cases that should never happen + + let overflow_bits = (num_operands as f64).log2().ceil() as usize; + if field_bits < $bit_size + overflow_bits { + // in this case addition of num_operands elements over field would overflow, + // thus it would not be possible to ensure the correctness of the result. + // Therefore, the operands are split in smaller slices, and the sum is + // enforced by multiple calls to addmany over these smaller slices + + // given the field ConstraintF and the $bit_size, compute the maximum number + // of operands for which we can enforce correctness of the result + let max_overflow_bits = field_bits - $bit_size; + let max_num_operands = 1usize << max_overflow_bits; + handle_numoperands_opmany!(addmany, cs, operands, max_num_operands); + } + + + // result_value is the sum of all operands in the ConstraintF field, + // which is employed in the constraint + let mut result_value: Option = Some(ConstraintF::zero()); + // modular_result_value is the sum of all operands mod 2^$bit_size, + // which represents the actual result of the operation + let mut modular_result_value: Option<$native_type> = Some(0); + + + let mut lc = LinearCombination::zero(); + + let mut all_constants = true; + + for op in operands { + // compute value of the result + match op.value { + Some(val) => { + modular_result_value = modular_result_value.as_mut().map(|v| { + let (updated_val, _overflow) = v.overflowing_add($native_type::from(val)); //don't care if addition overflows + updated_val + }); + result_value = result_value.as_mut().map(|v| { + let field_val = ConstraintF::from(val); + *v = *v + field_val; + *v}); + }, + // if at least one of the operands is unset, then the result cannot be computed + None => { modular_result_value = None; + result_value = None + }, + }; + + let mut coeff = ConstraintF::one(); + for bit in &op.bits { + lc = lc + &bit.lc(CS::one(), coeff); + + all_constants &= bit.is_constant(); + + coeff.double_in_place(); + } + } + + if all_constants && result_value.is_some() { + return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &modular_result_value.unwrap())); + } + + let result_bits = match result_value { + Some(f) => f.write_bits().iter().rev().map(|b| Some(*b)).collect::>(), + None => vec![None; ConstraintF::Params::MODULUS_BITS as usize], + }; + // create linear combination for result bits + let mut coeff = ConstraintF::one(); + let mut result_lc = LinearCombination::zero(); + let mut result_bits_gadgets = Vec::with_capacity($bit_size); + for i in 0..$bit_size+overflow_bits { + let alloc_bit = Boolean::alloc(cs.ns(|| format!("alloc result bit {}", i)), || result_bits[i].ok_or(SynthesisError::AssignmentMissing))?; + + result_lc = result_lc + &alloc_bit.lc(CS::one(), coeff); + + coeff.double_in_place(); + + if i < $bit_size { + // only the first $bit_size variables are useful for further operations on the result + result_bits_gadgets.push(alloc_bit); + } + } + + cs.get_root().enforce_equal($bit_size+overflow_bits, &lc, &result_lc); + + Ok(Self { + bits: result_bits_gadgets, + value: modular_result_value, + }) + } + + fn addmany_nocarry(mut cs: M, operands: &[Self]) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> { + let num_operands = operands.len(); + let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + // in this case it is not possible to enforce the correctness of the addition + // of at least 2 elements for the field ConstraintF + assert!(field_bits > $bit_size); + assert!(num_operands >= 2); // Weird trivial cases that should never happen + + let overflow_bits = (num_operands as f64).log2().ceil() as usize; + if field_bits < $bit_size + overflow_bits { + // in this case addition of num_operands elements over field would overflow, + // thus it would not be possible to ensure the correctness of the result. + // Therefore, the operands are split in smaller slices, and the sum is + // enforced by multiple calls to addmany_nocarry over these smaller slices + + // given the field ConstraintF and the $bit_size, compute the maximum number + // of operands for which we can enforce correctness of the result + let max_overflow_bits = field_bits - $bit_size; + let max_num_operands = 1usize << max_overflow_bits; + handle_numoperands_opmany!(addmany_nocarry, cs, operands, max_num_operands); + } + + let mut result_value: Option<$native_type> = Some(0); + // this flag allows to verify if the addition of operands overflows, which allows + // to return an error in case a set of constants whose sum is overflowing is provided + let mut is_overflowing = false; + + let mut lc = LinearCombination::zero(); + + let mut all_constants = true; + + for op in operands { + // compute value of the result + result_value = match op.value { + Some(val) => result_value.as_mut().map(|v| { + let (updated_val, overflow) = v.overflowing_add($native_type::from(val)); + is_overflowing |= overflow; + updated_val + }), + // if at least one of the operands is unset, then the result cannot be computed + None => None, + }; + + let mut coeff = ConstraintF::one(); + for bit in &op.bits { + lc = lc + &bit.lc(CS::one(), coeff); + + all_constants &= bit.is_constant(); + + coeff.double_in_place(); + } + } + + if all_constants && result_value.is_some() { + if is_overflowing { + return Err(SynthesisError::Unsatisfiable); + } + return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &result_value.unwrap())); + } + + let result_var = $type_name::alloc(cs.ns(|| "alloc result"), || result_value.ok_or(SynthesisError::AssignmentMissing))?; + + let mut coeff = ConstraintF::one(); + let mut result_lc = LinearCombination::zero(); + + for bit in result_var.bits.iter() { + result_lc = result_lc + &bit.lc(CS::one(), coeff); + + coeff.double_in_place(); + } + + cs.get_root().enforce_equal($bit_size, &lc, &result_lc); + + Ok(result_var) + } + + fn mulmany(mut cs: CS, operands: &[Self]) -> Result + where CS: ConstraintSystemAbstract { + let num_operands = operands.len(); + let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + assert!(num_operands >= 2); + assert!(field_bits >= 2*$bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field + + if field_bits < num_operands*$bit_size { + let max_num_operands = field_bits/$bit_size; + handle_numoperands_opmany!(mulmany, cs, operands, max_num_operands); + } + + // corner case: check if all operands are constants before allocating any variable + let mut all_constants = true; + let mut result_value: Option<$native_type> = Some(1); + for op in operands { + for bit in &op.bits { + all_constants &= bit.is_constant(); + } + + result_value = match op.value { + Some(val) => result_value.as_mut().map(|v| v.overflowing_mul(val).0), + None => None, + } + } + + if all_constants && result_value.is_some() { + return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &result_value.unwrap())); + } + + let op0_bits = operands[0].to_bits(cs.ns(|| "unpack first operand"))?; + let op1_bits = operands[1].to_bits(cs.ns(|| "unpack second operand"))?; + let field_op0 = FpGadget::::from_bits(cs.ns(|| "alloc operand 0 in field"), &op0_bits[..])?; + let field_op1 = FpGadget::::from_bits(cs.ns(|| "alloc operand 1 in field"), &op1_bits[..])?; + let mut result = field_op0.mul(cs.ns(|| "mul op0 and op1"), &field_op1)?; + for (i, op) in operands.iter().enumerate().skip(2) { + let op_bits = op.to_bits(cs.ns(|| format!("unpack operand {}", i)))?; + let field_op = FpGadget::::from_bits(cs.ns(|| format!("alloc operand {} in field", i)), &op_bits[..])?; + result = result.mul(cs.ns(|| format!("mul op {}", i)), &field_op)?; + } + + let skip_leading_bits = field_bits + 1 - num_operands*$bit_size; + let result_bits = result.to_bits_with_length_restriction(cs.ns(|| "unpack result field element"), skip_leading_bits)?; + let result_lsbs = result_bits + .iter() + .skip((num_operands-1)*$bit_size) + .map(|el| *el) + .collect::>(); + + $type_name::from_bits(cs.ns(|| "packing result"), &result_lsbs[..]) + } + + fn mulmany_nocarry(mut cs: CS, operands: &[Self]) -> Result + where CS: ConstraintSystemAbstract { + let num_operands = operands.len(); + let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + assert!(num_operands >= 2); + assert!(field_bits >= 2*$bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field + + if field_bits < num_operands*$bit_size { + let max_num_operands = field_bits/$bit_size; + handle_numoperands_opmany!(mulmany_nocarry, cs, operands, max_num_operands); + } + + // corner case: check if all operands are constants before allocating any variable + let mut all_constants = true; + let mut result_value: Option<$native_type> = Some(1); + let mut is_overflowing = false; + for op in operands { + for bit in &op.bits { + all_constants &= bit.is_constant(); + } + + result_value = match op.value { + Some(val) => result_value.as_mut().map(|v| { + let (updated_val, overflow) = v.overflowing_mul(val); + is_overflowing |= overflow; + updated_val + }), + None => None, + } + } + + if all_constants && result_value.is_some() { + if is_overflowing{ + return Err(SynthesisError::Unsatisfiable); + } else { + return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &result_value.unwrap())); + } + } + + let op0_bits = operands[0].to_bits(cs.ns(|| "unpack first operand"))?; + let op1_bits = operands[1].to_bits(cs.ns(|| "unpack second operand"))?; + let field_op0 = FpGadget::::from_bits(cs.ns(|| "alloc operand 0 in field"), &op0_bits[..])?; + let field_op1 = FpGadget::::from_bits(cs.ns(|| "alloc operand 1 in field"), &op1_bits[..])?; + let mut result = field_op0.mul(cs.ns(|| "mul op0 and op1"), &field_op1)?; + for (i, op) in operands.iter().enumerate().skip(2) { + let op_bits = op.to_bits(cs.ns(|| format!("unpack operand {}", i)))?; + let field_op = FpGadget::::from_bits(cs.ns(|| format!("alloc operand {} in field", i)), &op_bits[..])?; + result = result.mul(cs.ns(|| format!("mul op {}", i)), &field_op)?; + } + + let skip_leading_bits = field_bits + 1 - $bit_size; // we want to verify that the field element for the product of operands can be represented with $bit_size bits to ensure that there is no overflow + let result_bits = result.to_bits_with_length_restriction(cs.ns(|| "unpack result field element"), skip_leading_bits)?; + assert_eq!(result_bits.len(), $bit_size); + $type_name::from_bits(cs.ns(|| "packing result"), &result_bits[..]) + } + + } + + + + #[cfg(test)] + mod test { + use super::$type_name; + use rand::{Rng, thread_rng}; + use algebra::{fields::tweedle::Fr, Group, Field, FpParameters, PrimeField}; + use r1cs_core::{ + ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, SynthesisError, + }; + + use std::ops::{Shl, Shr}; + + use crate::{alloc::{AllocGadget, ConstantGadget}, eq::{EqGadget, MultiEq}, boolean::Boolean, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8}; + + + fn test_uint_gadget_value(val: $native_type, alloc_val: &$type_name, check_name: &str) { + assert_eq!(alloc_val.get_value().unwrap(), val, "assertion on value fails for check: {}", check_name); + for i in 0..$bit_size { + assert_eq!(alloc_val.bits[i].get_value().unwrap(), (val >> i) & 1 == 1, "assertion on {} bit fails for check: {}", i, check_name); + } + } + + #[derive(Copy, Clone, Debug)] + enum OperandType { + True, + False, + AllocatedTrue, + AllocatedFalse, + NegatedAllocatedTrue, + NegatedAllocatedFalse, + } + #[derive(Copy, Clone, Debug)] + enum VariableType { + Constant, + Allocated, + PublicInput, + } + + static VARIABLE_TYPES: [VariableType; 3] = [ + VariableType::Constant, + VariableType::Allocated, + VariableType::PublicInput, + ]; + + static BOOLEAN_TYPES: [OperandType; 6] = [ + OperandType::True, + OperandType::False, + OperandType::AllocatedTrue, + OperandType::AllocatedFalse, + OperandType::NegatedAllocatedTrue, + OperandType::NegatedAllocatedFalse, + ]; + + // utility function employed to allocate either a variable, a public input or a constant + fn alloc_fn(cs: &mut ConstraintSystem::, name: &str, alloc_type: &VariableType, value: $native_type) -> $type_name { + match *alloc_type { + VariableType::Allocated => $type_name::alloc(cs.ns(|| name), || Ok(value)).unwrap(), + VariableType::PublicInput => $type_name::alloc_input(cs.ns(|| name), || Ok(value)).unwrap(), + VariableType::Constant => $type_name::from_value(cs.ns(|| name), &value), + } + } + + // utility function employed to allocate a Boolean gadget for all possible types + fn alloc_boolean_cond(cs: &mut ConstraintSystem::, name: &str, alloc_type: &OperandType) -> Boolean { + let cs = cs.ns(|| name); + + match alloc_type { + OperandType::True => Boolean::constant(true), + OperandType::False => Boolean::constant(false), + OperandType::AllocatedTrue => { + Boolean::alloc(cs, || Ok(true)).unwrap() + } + OperandType::AllocatedFalse => { + Boolean::alloc(cs, || Ok(false)).unwrap() + } + OperandType::NegatedAllocatedTrue => { + Boolean::alloc(cs, || Ok(true)).unwrap().not() + } + OperandType::NegatedAllocatedFalse => { + Boolean::alloc(cs, || Ok(false)).unwrap().not() + } + } + } + + #[test] + fn test_eq_gadget() { + let rng = &mut thread_rng(); + + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let val: $native_type = rng.gen(); + + let witness = $type_name::alloc(cs.ns(|| "alloc value"), || Ok(val)).unwrap(); + let public_input = $type_name::alloc(cs.ns(|| "alloc input value"), || Ok(val)).unwrap(); + + let cmp = witness.is_eq(cs.ns(|| "witness == public_input"), &public_input).unwrap(); + assert!(cmp.get_value().unwrap()); + + witness.enforce_equal(cs.ns(|| "enforce witness == public_input"), &public_input).unwrap(); + assert!(cs.is_satisfied()); + + witness.conditional_enforce_not_equal(cs.ns(|| "fake enforce witness != public_input"), &public_input, &Boolean::constant(false)).unwrap(); + assert!(cs.is_satisfied()); //cs should still be satisfied as the previous inequality should not be enforced + + + let witness_ne = $type_name::alloc(cs.ns(|| "alloc value+1"), || Ok(val+1)).unwrap(); + + let cmp = witness_ne.is_neq(cs.ns(|| "val+1 != val"), &public_input).unwrap(); + assert!(cmp.get_value().unwrap()); + + witness_ne.enforce_not_equal(cs.ns(|| "enforce val != val+1"), &public_input).unwrap(); + assert!(cs.is_satisfied()); + + let cmp = witness.is_eq(cs.ns(|| "val == val+1"), &witness_ne).unwrap(); + assert!(!cmp.get_value().unwrap()); + + witness.conditional_enforce_equal(cs.ns(|| "fake enforce val == val+1"), &witness_ne, &Boolean::constant(false)).unwrap(); + assert!(cs.is_satisfied()); //cs should be satisfied since the previous equality should not be enforced + + witness.enforce_equal(cs.ns(|| "enforce val == val+1"), &witness_ne).unwrap(); + assert!(!cs.is_satisfied()); + } + + + #[test] + fn test_alloc() { + let rng = &mut thread_rng(); + + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let val: $native_type = rng.gen(); + + let value_gen = |val: Option<$native_type>| { + val.ok_or(SynthesisError::Other("no value".to_string())) + }; + + let alloc_var = $type_name::alloc(cs.ns(|| "alloc val"), || value_gen(Some(val))).unwrap(); + let alloc_input_var = $type_name::alloc_input(cs.ns(|| "alloc input val"), || value_gen(Some(val))).unwrap(); + + test_uint_gadget_value(val, &alloc_var, "alloc variable"); + test_uint_gadget_value(val, &alloc_input_var, "alloc public input"); + + //try allocating no value + let alloc_err = $type_name::alloc(cs.ns (|| "alloc empty val"), || value_gen(None)).unwrap_err(); + let alloc_input_err = $type_name::alloc_input(cs.ns (|| "alloc empty input val"), || value_gen(None)).unwrap_err(); + + assert!( + match alloc_err { + SynthesisError::AssignmentMissing => true, + _ => false, + } + ); + assert!( + match alloc_input_err { + SynthesisError::Other(_) => true, + _ => false, + } + ); + + //allocating no value in cs in setup mode should yield no error -> unwrap should not panic + let mut cs = ConstraintSystem::::new(SynthesisMode::Setup); + let _ = $type_name::alloc(cs.ns (|| "alloc empty val"), || value_gen(None)).unwrap(); + let _ = $type_name::alloc_input(cs.ns (|| "alloc empty input val"), || value_gen(None)).unwrap(); + + // test constant generation + let const_val: $native_type = rng.gen(); + + let uint_const = $type_name::from_value(cs.ns(|| "alloc const val"), &const_val); + + test_uint_gadget_value(const_val, &uint_const, "alloc constant"); + assert_eq!(const_val, ConstantGadget::<$native_type, Fr>::get_constant(&uint_const)); + } + + #[test] + fn test_alloc_vec() { + let rng = &mut thread_rng(); + + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let vec_len: usize = rng.gen_range(15..30); + // compute values + let values = (0..vec_len).map(|_| rng.gen()).collect::>(); + + let alloc_vec = $type_name::alloc_vec(cs.ns(|| "alloc vec"), &values).unwrap(); + + for (i, (alloc_val, val)) in alloc_vec.iter().zip(values.iter()).enumerate() { + test_uint_gadget_value(*val, alloc_val, format!("test vec element {}", i).as_str()); + } + + // try allocating no values + let empty_values = (0..vec_len).map(|_| None).collect::>>(); + + let alloc_err = $type_name::alloc_vec(cs.ns(|| "alloc empty vec"), &empty_values).unwrap_err(); + + assert!( + match alloc_err { + SynthesisError::AssignmentMissing => true, + _ => false, + } + ); + + //allocating no value in cs in setup mode should yield no error -> unwrap should not panic + let mut cs = ConstraintSystem::::new(SynthesisMode::Setup); + let _ = $type_name::alloc_vec(cs.ns (|| "alloc empty vec"), &empty_values).unwrap(); + } + + #[test] + fn test_alloc_input_vec() { + let rng = &mut thread_rng(); + + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let vec_len: usize = rng.gen_range($bit_size..$bit_size*2); + println!("vec len: {}", vec_len); + // allocate input vector of VEC_LEN random bytes + let input_vec = (0..vec_len).map(|_| rng.gen()).collect::>(); + + let alloc_vec = $type_name::alloc_input_vec(cs.ns(|| "alloc input vec"), &input_vec).unwrap(); + + for (i, (input_bytes, alloc_el)) in input_vec.chunks_exact($bit_size/8).zip(alloc_vec.iter()).enumerate() { + let input_bytes_gadgets = UInt8::constant_vec(&input_bytes); + let input_el = $type_name::from_bytes(&input_bytes_gadgets).unwrap(); + input_el.enforce_equal(cs.ns(|| format!("eq for chunk {}", i)), &alloc_el).unwrap(); + assert_eq!(input_el.get_value().unwrap(), alloc_el.get_value().unwrap()); + } + + assert!(cs.is_satisfied()); + + // test allocation of vector of constants from vector of bytes + let constant_vec = $type_name::constant_vec(&input_vec); + + for (i, (input_bytes, alloc_el)) in input_vec.chunks($bit_size/8).zip(constant_vec.iter()).enumerate() { + let input_bytes_gadgets = input_bytes.iter().enumerate() + .map(|(j, byte)| UInt8::from_value(cs.ns(|| format!("alloc byte {} in chunk {}", j, i)), byte)) + .collect::>(); + let input_el = $type_name::from_bytes(&input_bytes_gadgets).unwrap(); + input_el.enforce_equal(cs.ns(|| format!("eq for chunk {} of constant vec", i)), &alloc_el).unwrap(); + assert_eq!(input_el.get_value().unwrap(), alloc_el.get_value().unwrap()); + } + + assert!(cs.is_satisfied()); + + } + + #[test] + fn test_bit_serialization() { + let rng = &mut thread_rng(); + + for var_type in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let val: $native_type = rng.gen(); + + let alloc_var = alloc_fn(&mut cs, "alloc var", var_type, val); + + let bits = alloc_var.to_bits(cs.ns(|| "unpack variable")).unwrap(); + assert_eq!(bits.len(), $bit_size, "unpacking value"); + + let reconstructed_var = $type_name::from_bits(cs.ns(|| "pack bits"), &bits).unwrap(); + test_uint_gadget_value(val, &reconstructed_var, "packing bits"); + } + } + + #[test] + fn test_byte_serialization() { + let rng = &mut thread_rng(); + + for var_type in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let val: $native_type = rng.gen(); + + let alloc_var = alloc_fn(&mut cs, "alloc var", var_type, val); + + let bytes = alloc_var.to_bytes(cs.ns(|| "unpack variable")).unwrap(); + assert_eq!(bytes.len(), $bit_size/8); + + let reconstructed_var = $type_name::from_bytes(&bytes).unwrap(); + test_uint_gadget_value(val, &reconstructed_var, "packing bytes"); + + let bits = alloc_var.to_bits(cs.ns(|| "unpack to bits")).unwrap(); + for (i, (bit_chunk, byte)) in bits.chunks(8).zip(bytes.iter().rev()).enumerate() { + let reconstructed_byte = UInt8::from_bits(cs.ns(|| format!("pack byte {} from bits", i)),&bit_chunk).unwrap(); + reconstructed_byte.enforce_equal(cs.ns(|| format!("check equality for byte {}", i)), byte).unwrap(); + } + assert!(cs.is_satisfied()); + } + } + + #[test] + fn test_from_bits() { + let rng = &mut thread_rng(); + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let mut bits = Vec::with_capacity($bit_size); // vector of Booleans + let mut bit_values = Vec::with_capacity($bit_size); // vector of the actual values wrapped by Booleans found in bits vector + for i in 0..$bit_size { + let bit_value: bool = rng.gen(); + // we test all types of Booleans + match i % 3 { + 0 => { + bit_values.push(bit_value); + bits.push(Boolean::Constant(bit_value)) + }, + 1 => { + bit_values.push(bit_value); + let bit = Boolean::alloc(cs.ns(|| format!("alloc bit {}", i)), || Ok(bit_value)).unwrap(); + bits.push(bit) + }, + 2 => { + bit_values.push(!bit_value); + let bit = Boolean::alloc(cs.ns(|| format!("alloc bit {}", i)), || Ok(bit_value)).unwrap(); + bits.push(bit.not()) + }, + _ => {}, + } + } + + + let uint_gadget = $type_name::from_bits(cs.ns(|| "pack random bits"), &bits).unwrap(); + let value = uint_gadget.get_value().unwrap(); + + for (i, el) in uint_gadget.bits.iter().enumerate() { + let bit = el.get_value().unwrap(); + assert_eq!(bit, bits[$bit_size-1-i].get_value().unwrap()); + assert_eq!(bit, bit_values[$bit_size-1-i]); + assert_eq!(bit, (value >> i) & 1 == 1); + } + + // check that to_bits(from_bits(bits)) == bits + let unpacked_bits = uint_gadget.to_bits(cs.ns(|| "unpack bits")).unwrap(); + + for (bit1, bit2) in bits.iter().zip(unpacked_bits.iter()) { + assert_eq!(bit1, bit2); + } + + //check that an error is returned if more than $bit_size bits are unpacked + let mut bits = Vec::with_capacity($bit_size+1); + for _ in 0..$bit_size+1 { + bits.push(Boolean::constant(false)); + } + + let _ = $type_name::from_bits(cs.ns(|| "unpacking too many bits"), &bits).unwrap_err(); + } + + #[test] + fn test_shifts() { + let rng = &mut thread_rng(); + + for var_type in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let value: $native_type = rng.gen(); + + // test sequence of shifts + let mut alloc_var = alloc_fn(&mut cs, "alloc var", var_type, value); + for by in 0..$bit_size { + let shr_var = alloc_var.shr(by); + test_uint_gadget_value(value >> by, &shr_var, format!("right shift by {} bits", by).as_str()); + alloc_var = shr_var.shl(by); + test_uint_gadget_value((value >> by) << by, &alloc_var, format!("left shift by {} bits", by).as_str()); + } + + + + // check that shl(var, by) == shl(var, $bit_size-1) for by > $bit_size + let alloc_var = alloc_fn(&mut cs, "alloc var for invalid shl", var_type, value); + let by = $bit_size*2; + let shl_var = alloc_var.shl(by); + test_uint_gadget_value(value << $bit_size-1, &shl_var, "invalid left shift"); + + // check that shr(var, by) == shr(var, $bit_size) for by > $bit_size + let alloc_var = alloc_fn(&mut cs, "alloc var for invalid shr", var_type, value); + let by = $bit_size*2; + let shr_var = alloc_var.shr(by); + test_uint_gadget_value(value >> $bit_size-1, &shr_var, "invalid right shift"); + + assert!(cs.is_satisfied()); + } + } + + #[test] + fn test_rotations() { + let rng = &mut thread_rng(); + + for var_type in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let value: $native_type = rng.gen(); + + let alloc_var = alloc_fn(&mut cs, "alloc var", var_type, value); + for i in 0..$bit_size { + let rotl_var = alloc_var.rotl(i); + test_uint_gadget_value(value.rotate_left(i as u32), &rotl_var, format!("left rotation by {}", i).as_str()); + let rotr_var = rotl_var.rotr(i); + test_uint_gadget_value(value, &rotr_var, format!("right rotation by {}", i).as_str()); + } + + //check rotations are ok even if by > $bit_size + let by = $bit_size*2; + let rotl_var = alloc_var.rotl(by); + test_uint_gadget_value(value.rotate_left(by as u32), &rotl_var, format!("left rotation by {}", by).as_str()); + + let rotr_var = alloc_var.rotl(by); + test_uint_gadget_value(value.rotate_right(by as u32), &rotr_var, format!("right rotation by {}", by).as_str()); + } + } + + #[test] + fn test_bitwise_operations() { + let rng = &mut thread_rng(); + for var_type_a in VARIABLE_TYPES.iter() { + for var_type_b in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let a: $native_type = rng.gen(); + let b: $native_type = rng.gen(); + let res_xor = a ^ b; + let res_or = a | b; + let res_and = a & b; + let res_nand = !res_and; + + let alloc_a = alloc_fn(&mut cs, "alloc first value", var_type_a, a); + let alloc_b = alloc_fn(&mut cs, "alloc second value", var_type_b, b); + + let xor_var = alloc_a.xor(cs.ns(|| "a xor b"), &alloc_b).unwrap(); + let or_var = alloc_a.or(cs.ns(|| "a or b"), &alloc_b).unwrap(); + let and_var = alloc_a.and(cs.ns(|| "a and b"), &alloc_b).unwrap(); + let nand_var = and_var.not(cs.ns(|| "a nand b")); + + test_uint_gadget_value(res_xor, &xor_var, format!("xor between {:?} {:?}", var_type_a, var_type_b).as_str()); + test_uint_gadget_value(res_or, &or_var, format!("or between {:?} {:?}", var_type_a, var_type_b).as_str()); + test_uint_gadget_value(res_and, &and_var, format!("and between {:?} {:?}", var_type_a, var_type_b).as_str()); + test_uint_gadget_value(res_nand, &nand_var, format!("nand between {:?} {:?}", var_type_a, var_type_b).as_str()); + + + let alloc_xor = alloc_fn(&mut cs, "alloc xor result", var_type_a, res_xor); + let alloc_or = alloc_fn(&mut cs, "alloc or result", var_type_b, res_or); + let alloc_and = alloc_fn(&mut cs, "alloc and result", var_type_a, res_and); + let alloc_nand = alloc_fn(&mut cs, "alloc nand result", var_type_b, res_nand); + + alloc_xor.enforce_equal(cs.ns(|| "check xor result"), &xor_var).unwrap(); + alloc_or.enforce_equal(cs.ns(|| "check or result"), &or_var).unwrap(); + alloc_and.enforce_equal(cs.ns(|| "check and result"), &and_var).unwrap(); + alloc_nand.enforce_equal(cs.ns(|| "check nand result"), &nand_var).unwrap(); + + assert!(cs.is_satisfied()); + } + } + } + + #[test] + fn test_cond_select() { + let rng = &mut thread_rng(); + + //random generates a and b numbers and check all the conditions for each couple + for condition in BOOLEAN_TYPES.iter() { + for var_a_type in VARIABLE_TYPES.iter() { + for var_b_type in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let cond; + let a; + let b; + let a_val: $native_type = rng.gen(); + let b_val: $native_type = rng.gen(); + { + cond = alloc_boolean_cond(&mut cs, "cond", condition); + } + { + a = alloc_fn(&mut cs, "var_a",var_a_type,a_val); + b = alloc_fn(&mut cs, "var_b",var_b_type,b_val); + } + let before = cs.num_constraints(); + let c = $type_name::conditionally_select(&mut cs, &cond, &a, &b).unwrap(); + let after = cs.num_constraints(); + + assert!( + cs.is_satisfied(), + "failed with operands: cond: {:?}, a: {:?}, b: {:?}", + condition, + a, + b, + ); + test_uint_gadget_value(if cond.get_value().unwrap() { + a_val + } else { + b_val + }, &c, "conditional select"); + + assert!(<$type_name as CondSelectGadget>::cost() >= after - before); + } + } + } + } + + #[test] + fn test_addmany() { + const NUM_OPERANDS: usize = 10; + let rng = &mut thread_rng(); + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen()).collect::>(); + + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) + }).collect::>(); + + let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| a.overflowing_add(b).0).unwrap(); + + let result_var = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + $type_name::addmany(multi_eq.ns(|| "add operands"), &operands).unwrap() + }; + + test_uint_gadget_value(result_value, &result_var, "result correctness"); + assert!(cs.is_satisfied()); + + // negative test + let bit_gadget_path = "add operands/alloc result bit 0/boolean"; + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); + + // test with all constants + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc constant operand {}", i).as_str(), &VARIABLE_TYPES[0], *val) + }).collect::>(); + let num_constraints = cs.num_constraints(); + + let result_var = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + $type_name::addmany(multi_eq.ns(|| "add constant operands"), &operands).unwrap() + }; + + test_uint_gadget_value(result_value, &result_var, "sum of constants result correctness"); + assert!(cs.is_satisfied()); + assert_eq!(cs.num_constraints(), num_constraints); + } + + #[test] + fn test_mulmany() { + const MAX_NUM_OPERANDS: usize = (::Params::MODULUS_BITS-1) as usize/$bit_size ; + const NUM_OPERANDS: usize = MAX_NUM_OPERANDS*2+5; + // we want to test a case when the operands must be split in multiple chunks + assert!(NUM_OPERANDS > MAX_NUM_OPERANDS); + + + let rng = &mut thread_rng(); + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen()).collect::>(); + + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) + }).collect::>(); + + let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| a.overflowing_mul(b).0).unwrap(); + + let result_var = $type_name::mulmany(cs.ns(|| "mul operands"), &operands).unwrap(); + + test_uint_gadget_value(result_value, &result_var, "result correctness"); + assert!(cs.is_satisfied()); + + + + // negative test on first batch + let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); + + // set bit value back + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); + + // negative test on allocated field element + let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); + if last_batch_start_operand == NUM_OPERANDS { + last_batch_start_operand -= MAX_NUM_OPERANDS-1; + } + let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); + + // set bit value back + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); + + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc constant operand {}", i).as_str(), &VARIABLE_TYPES[0], *val) + }).collect::>(); + let num_constraints = cs.num_constraints(); + let result_var = $type_name::mulmany(cs.ns(|| "mul constant operands"), &operands).unwrap(); + + test_uint_gadget_value(result_value, &result_var, "mul of constants result correctness"); + assert!(cs.is_satisfied()); + // check that no additional constraints are introduced if the operands are all constant values + assert_eq!(cs.num_constraints(), num_constraints) + + } + + #[test] + fn test_modular_arithmetic_operations() { + let rng = &mut thread_rng(); + for condition in BOOLEAN_TYPES.iter() { + for var_type_op1 in VARIABLE_TYPES.iter() { + for var_type_op2 in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let op1: $native_type = rng.gen(); + let op2: $native_type = rng.gen(); + let add_result_val = op1.overflowing_add(op2).0; + let mul_result_val = op1.overflowing_mul(op2).0; + + let op1_var = alloc_fn(&mut cs, "alloc op1", &var_type_op1, op1); + let op2_var = alloc_fn(&mut cs, "alloc op2", &var_type_op2, op2); + let cond_var = alloc_boolean_cond(&mut cs, "alloc condition", condition); + + let add_result_var = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + op1_var.conditionally_add(&mut multi_eq, &cond_var, &op2_var).unwrap() + }; + let mul_result_var = op1_var.conditionally_mul(&mut cs, &cond_var, &op2_var).unwrap(); + + test_uint_gadget_value(if cond_var.get_value().unwrap() { + add_result_val + } else { + op1 + }, &add_result_var, "addition correctness"); + test_uint_gadget_value(if cond_var.get_value().unwrap() { + mul_result_val + } else { + op1 + }, &mul_result_var, "addition correctness"); + assert!(cs.is_satisfied()); + } + } + } + } + + #[test] + fn test_addmany_nocarry() { + const NUM_OPERANDS: $native_type = 10; + let rng = &mut thread_rng(); + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let max_value = $native_type::MAX/NUM_OPERANDS; // generate operands in this range to ensure no overflow occurs when summing them + let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen_range(0..max_value)).collect::>(); + + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) + }).collect::>(); + + // computation of result_value will panic in case of addition overflows, but it + // should never happen given how we generate operand_values + let result_value: $native_type = operand_values.iter().sum(); + + let result_var = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + $type_name::addmany_nocarry(multi_eq.ns(|| "add operands"), &operands).unwrap() + }; + + test_uint_gadget_value(result_value, &result_var, "result correctness"); + assert!(cs.is_satisfied()); + + // negative test + let bit_gadget_path = "add operands/alloc result/allocated bit_gadget 0/boolean"; + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); + + // set bit value back + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); + + // test with all constants + let num_constraints = cs.num_constraints(); + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc constant operand {}", i).as_str(), &VARIABLE_TYPES[0], *val) + }).collect::>(); + + let result_var = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + $type_name::addmany_nocarry(multi_eq.ns(|| "add constant operands"), &operands).unwrap() + }; + + test_uint_gadget_value(result_value, &result_var, "sum of constants result correctness"); + assert!(cs.is_satisfied()); + assert_eq!(num_constraints, cs.num_constraints()); // check that no constraints are added when operands are all constants + + // check that constraints are not satisfied in case of overflow + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen_range(max_value..=$native_type::MAX)).collect::>(); + + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) + }).collect::>(); + + let mut is_overflowing = false; + let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| { + let (updated_sum, overflow) = a.overflowing_add(b); + is_overflowing |= overflow; + updated_sum + }).unwrap(); + //check that the addition actually overflows, which should always happen given how we generate operand values + assert!(is_overflowing); + + let result_var = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + $type_name::addmany_nocarry(multi_eq.ns(|| "add overflowing operands"), &operands).unwrap() + }; + + // result should still be corrected, but constraints should not be verified + test_uint_gadget_value(result_value, &result_var, "result of overflowing add correctness"); + assert!(!cs.is_satisfied(), "checking overflow constraint"); + } + + #[test] + fn test_mulmany_nocarry() { + const MAX_NUM_OPERANDS: usize = (::Params::MODULUS_BITS-1) as usize/$bit_size ; + const NUM_OPERANDS: usize = MAX_NUM_OPERANDS*2+5; + // we want to test a case when the operands must be split in multiple chunks + assert!(NUM_OPERANDS > MAX_NUM_OPERANDS); + + let rng = &mut thread_rng(); + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let max_value: $native_type = 1 << ($bit_size/NUM_OPERANDS); // generate operands in this range to ensure no overflow occurs when multiplying them + let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen_range(0..max_value)).collect::>(); + + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) + }).collect::>(); + + // computation of result_value will panic in case of addition overflows, but it + // should never happen given how we generate operand_values + let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a, b| a*b).unwrap(); + + let result_var = $type_name::mulmany_nocarry(cs.ns(|| "mul operands"), &operands).unwrap(); + + test_uint_gadget_value(result_value, &result_var, "result correctness"); + assert!(cs.is_satisfied()); + + // negative test on first batch + let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); + + // set bit value back + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); + + // negative test on allocated field element + let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); + if last_batch_start_operand == NUM_OPERANDS { + last_batch_start_operand -= MAX_NUM_OPERANDS-1; + } + let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); + + // set bit value back + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); + + // test with all constants + let num_constraints = cs.num_constraints(); + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc constant operand {}", i).as_str(), &VARIABLE_TYPES[0], *val) + }).collect::>(); + + let result_var = $type_name::mulmany_nocarry(cs.ns(|| "mul constant operands"), &operands).unwrap(); + + test_uint_gadget_value(result_value, &result_var, "sum of constants result correctness"); + assert!(cs.is_satisfied()); + assert_eq!(num_constraints, cs.num_constraints()); // check that no constraints are added when operands are all constants + + // check that constraints are not satisfied in case of overflow + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen_range(max_value..=$native_type::MAX)).collect::>(); + + let operands = operand_values.iter().enumerate().map(|(i, val)| { + alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) + }).collect::>(); + + + let mut is_overflowing = false; + let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| { + let (updated_sum, overflow) = a.overflowing_mul(b); + is_overflowing |= overflow; + updated_sum + }).unwrap(); + //check that the multiplication actually overflows, which should always happen given how we generate operand values + assert!(is_overflowing); + + let result_var = $type_name::mulmany_nocarry(cs.ns(|| "mul overflowing operands"), &operands).unwrap(); + + test_uint_gadget_value(result_value, &result_var, "result of overflowing mul correctness"); + assert!(!cs.is_satisfied()); + } + + #[test] + fn test_no_carry_arithmetic_operations() { + const OPERATIONS: [&str; 2] = ["add", "mul"]; + let rng = &mut thread_rng(); + for condition in BOOLEAN_TYPES.iter() { + for var_type_op1 in VARIABLE_TYPES.iter() { + for var_type_op2 in VARIABLE_TYPES.iter() { + for op in &OPERATIONS { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let is_add = *op == "add"; + + let max_value: $native_type = if is_add { + $native_type::MAX/2 + } else { + 1 << ($bit_size/2) + }; + + let op1: $native_type = rng.gen_range(0..max_value); + let op2: $native_type = rng.gen_range(0..max_value); + let (result_val, overflow) = if is_add { + op1.overflowing_add(op2) + } else { + op1.overflowing_mul(op2) + }; + // check that performing op on operands do not overflow, which should never happen given how we generate the operands + assert!(!overflow); + + + let op1_var = alloc_fn(&mut cs, "alloc op1", &var_type_op1, op1); + let op2_var = alloc_fn(&mut cs, "alloc op2", &var_type_op2, op2); + let cond_var = alloc_boolean_cond(&mut cs, "alloc conditional", condition); + + let result_var = if is_add { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + op1_var.conditionally_add_nocarry(&mut multi_eq, &cond_var, &op2_var).unwrap() + } else { + op1_var.conditionally_mul_nocarry(&mut cs, &cond_var, &op2_var).unwrap() + }; + + test_uint_gadget_value(if cond_var.get_value().unwrap() { + result_val + } else { + op1 + }, &result_var, format!("{} correctness", op).as_str()); + assert!(cs.is_satisfied()); + + // check that addition with overflow fails + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let op1: $native_type = rng.gen_range(max_value..=$native_type::MAX); + let op2: $native_type = rng.gen_range(max_value..=$native_type::MAX); + + let (result_val, overflow) = if is_add { + op1.overflowing_add(op2) + } else { + op1.overflowing_mul(op2) + }; + // check that addition of operands overflows, which should always happen given how we generate the operands + assert!(overflow); + + let op1_var = alloc_fn(&mut cs, "alloc op1", &var_type_op1, op1); + let op2_var = alloc_fn(&mut cs, "alloc op2", &var_type_op2, op2); + + let result = if is_add { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + op1_var.conditionally_add_nocarry(&mut multi_eq, &cond_var, &op2_var) + } else { + op1_var.conditionally_mul_nocarry(&mut cs, &cond_var, &op2_var) + }; + // Need to distinguish between operands being both constant or not, + // as in the former case the operation should return an error + // rather than unsatisfied constraints + match (var_type_op1, var_type_op2) { + (VariableType::Constant, VariableType::Constant) => { + match result.unwrap_err() { + SynthesisError::Unsatisfiable => (), + _ => assert!(false, "invalid error returned by {}", if is_add {"conditionally_add_nocarry"} else {"conditionally_mul_nocarry"}) + }; + return; + }, + (_, _) => (), + }; + let result_var = result.unwrap(); + + // result should still be correct, but constraints should not be satisfied + test_uint_gadget_value(if cond_var.get_value().unwrap() { + result_val + } else { + op1 + }, &result_var, format!("{} correctness", op).as_str()); + assert!(!cs.is_satisfied(), "checking overflow constraint for {:?} {:?}", var_type_op1, var_type_op2); + + } + } + } + } + } + + } + + } + } +} + +pub mod test_mod {} \ No newline at end of file diff --git a/r1cs/gadgets/std/src/bits/mod.rs b/r1cs/gadgets/std/src/bits/mod.rs index 2137cf17a..b78464b40 100644 --- a/r1cs/gadgets/std/src/bits/mod.rs +++ b/r1cs/gadgets/std/src/bits/mod.rs @@ -1,11 +1,23 @@ -use crate::bits::{boolean::Boolean, uint8::UInt8}; -use algebra::Field; +use crate::bits::boolean::Boolean; +use algebra::{Field, PrimeField}; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; +use crate::alloc::{AllocGadget, ConstantGadget}; +use crate::eq::{EqGadget, MultiEq}; +use crate::select::CondSelectGadget; +use std::fmt::Debug; +use std::ops::{Shl, Shr}; pub mod boolean; -pub mod uint32; -pub mod uint64; -pub mod uint8; +//pub mod uint32; +//pub mod uint64; + +#[macro_use] +pub mod macros; +impl_uint_gadget!(U8, 8, u8, uint8); +impl_uint_gadget!(UInt64, 64, u64, uint64); +impl_uint_gadget!(UInt32, 32, u32, uint32); + +pub type UInt8 = uint8::U8; pub trait ToBitsGadget { fn to_bits>( @@ -63,6 +75,164 @@ where } } +// this trait allows to move out rotl and rotr from UIntGadget, in turn allowing to avoid specifying +// for the compiler a field ConstraintF every time these methods are called, which requires a +// verbose syntax (e.g., UIntGadget::::rotl(&gadget_variable, i) +pub trait RotateUInt { + /// Rotate left `self` by `by` bits. + fn rotl(&self, by: usize) -> Self; + + /// Rotate right `self` by `by` bits. + fn rotr(&self, by: usize) -> Self; +} + +pub trait UIntGadget: +Sized ++ Clone ++ Debug ++ Eq ++ PartialEq ++ EqGadget ++ ToBitsGadget ++ FromBitsGadget ++ ToBytesGadget ++ CondSelectGadget ++ AllocGadget ++ ConstantGadget ++ Shr ++ Shl ++ RotateUInt +{ + /// XOR `self` with `other` + fn xor(&self, cs: CS, other: &Self) -> Result + where + CS: ConstraintSystemAbstract; + + /// OR `self` with `other` + fn or(&self, cs: CS, other: &Self) -> Result + where + CS: ConstraintSystemAbstract; + + /// AND `self` with `other` + fn and(&self, cs: CS, other: &Self) -> Result + where + CS: ConstraintSystemAbstract; + + /// Bitwise NOT of `self` + fn not(&self, cs: CS) -> Self + where + CS: ConstraintSystemAbstract; + + + + /// Perform modular addition of several `Self` objects. + fn addmany(cs: M, operands: &[Self]) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract>; + + /// Perform modular addition of `self` and `other`. The default implementation just invokes + /// `addmany`, it may be overridden in case addition of 2 values may be performed more + /// efficiently than addition of n >= 3 values + fn add(&self, cs: M, other: &Self) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + Self::addmany(cs, &[self.clone(), other.clone()]) + } + + /// Add `self` to `other` if `cond` is True, otherwise do nothing. + fn conditionally_add( + &self, + mut cs: M, + cond: &Boolean, + other: &Self + ) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + let sum = self.add(cs.ns(|| "compute sum"), other)?; + Self::conditionally_select(cs.ns(|| "conditionally select values"), cond, &sum, self) + } + + + /// Perform addition of several `Self` objects, checking that no overflows occur. + fn addmany_nocarry(cs: M, operands: &[Self]) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract>; + + + /// Perform addition of `self` and `other`, checking that no overflows occur. + /// The default implementation just invokes `addmany`, it may be overridden in case addition + /// of 2 values may be performed more efficiently than addition of n >= 3 values + fn add_nocarry(&self, cs: M, other: &Self) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + Self::addmany_nocarry(cs, &[self.clone(), other.clone()]) + } + + /// Add `self` to `other` if `cond` is True, checking that no overflows occur, otherwise do nothing. + fn conditionally_add_nocarry( + &self, + mut cs: M, + cond: &Boolean, + other: &Self + ) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + let sum = self.add_nocarry(cs.ns(|| "compute sum"), other)?; + Self::conditionally_select(cs.ns(|| "conditionally select values"), cond, &sum, self) + } + + /// Perform modular multiplication of several `Self` objects. + fn mulmany(cs: CS, operands: &[Self]) -> Result + where + CS: ConstraintSystemAbstract; + + /// Perform modular multiplication of `self` and `other` + fn mul(&self, cs: CS, other: &Self) -> Result + where + CS: ConstraintSystemAbstract { + Self::mulmany(cs, &[self.clone(), other.clone()]) + } + + /// Multiply `self` to `other` if `cond` is true, do nothing otherwise + fn conditionally_mul(&self, mut cs: CS, cond: &Boolean, other: &Self) -> Result + where + CS: ConstraintSystemAbstract { + let product = self.mul(cs.ns(|| "mul values"), other)?; + Self::conditionally_select(cs.ns(|| "cond select mul result"), cond, &product, self) + } + + /// Perform multiplication of several `Self` objects, checking that no overflows occur + fn mulmany_nocarry(cs: CS, operands: &[Self]) -> Result + where + CS: ConstraintSystemAbstract; + + /// Multiply `self` to `other`, checking that no overflows occur + fn mul_nocarry(&self, cs: CS, other: &Self) -> Result + where + CS: ConstraintSystemAbstract { + Self::mulmany_nocarry(cs, &[self.clone(), other.clone()]) + } + + /// Multiply `self` to `other` if `cond` is true, do nothing otherwise + fn conditionally_mul_nocarry(&self, mut cs: CS, cond: &Boolean, other: &Self) -> Result + where + CS: ConstraintSystemAbstract { + let product = self.mul_nocarry(cs.ns(|| "mul values"), other)?; + Self::conditionally_select(cs.ns(|| "cond select mul result"), cond, &product, self) + } + +} + impl ToBitsGadget for Boolean { fn to_bits>( &self, diff --git a/r1cs/gadgets/std/src/eq.rs b/r1cs/gadgets/std/src/eq.rs index 2ace7a2ba..0e6386d27 100644 --- a/r1cs/gadgets/std/src/eq.rs +++ b/r1cs/gadgets/std/src/eq.rs @@ -128,27 +128,6 @@ impl, ConstraintF: Field> EqGadget for [T] } Ok(()) } - - fn conditional_enforce_not_equal>( - &self, - mut cs: CS, - other: &Self, - should_enforce: &Boolean, - ) -> Result<(), SynthesisError> { - assert_eq!(self.len(), other.len()); - let some_are_different = self.is_neq(cs.ns(|| "is_neq"), other)?; - if some_are_different.get_value().is_some() && should_enforce.get_value().is_some() { - assert!(some_are_different.get_value().unwrap()); - Ok(()) - } else { - some_are_different.conditional_enforce_equal( - cs.ns(|| "conditional_enforce_equal"), - should_enforce, - should_enforce, - )?; - Ok(()) - } - } } /// A struct for collecting identities of linear combinations of Booleans to serve diff --git a/r1cs/gadgets/std/src/lib.rs b/r1cs/gadgets/std/src/lib.rs index 00580dc66..1b2685065 100644 --- a/r1cs/gadgets/std/src/lib.rs +++ b/r1cs/gadgets/std/src/lib.rs @@ -72,7 +72,7 @@ pub mod prelude { pub use crate::{ alloc::*, bits::{ - boolean::Boolean, uint32::UInt32, uint8::UInt8, FromBitsGadget, ToBitsGadget, + boolean::Boolean, uint32::UInt32, UInt8, FromBitsGadget, ToBitsGadget, ToBytesGadget, }, eq::*, From 997bd8f398430e81d0d79a4c77fde7c6639af22d Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Thu, 13 Jan 2022 11:12:01 +0100 Subject: [PATCH 02/18] Define cmp gadget + implementation for FpGadget --- r1cs/gadgets/std/src/bits/boolean.rs | 2 +- r1cs/gadgets/std/src/cmp.rs | 60 +++++ r1cs/gadgets/std/src/fields/cmp.rs | 367 ++++++++++++--------------- r1cs/gadgets/std/src/lib.rs | 1 + 4 files changed, 226 insertions(+), 204 deletions(-) create mode 100644 r1cs/gadgets/std/src/cmp.rs diff --git a/r1cs/gadgets/std/src/bits/boolean.rs b/r1cs/gadgets/std/src/bits/boolean.rs index 6b4f8e139..d02836124 100644 --- a/r1cs/gadgets/std/src/bits/boolean.rs +++ b/r1cs/gadgets/std/src/bits/boolean.rs @@ -715,7 +715,7 @@ impl Boolean { assert!(bits_iter.next().is_none()); Ok(current_run) - } + } } impl PartialEq for Boolean { diff --git a/r1cs/gadgets/std/src/cmp.rs b/r1cs/gadgets/std/src/cmp.rs new file mode 100644 index 000000000..3a52a669c --- /dev/null +++ b/r1cs/gadgets/std/src/cmp.rs @@ -0,0 +1,60 @@ +use std::cmp::Ordering; +use algebra::Field; +use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; +use crate::boolean::Boolean; +use crate::eq::EqGadget; + +pub trait ComparisonGadget: Sized + EqGadget +{ + /// Output a `Boolean` gadget which is equal to `self < other` + fn is_smaller_than>(&self, cs: CS, other: &Self) -> Result; + + /// Enforce in the constraint system `cs` that `self < other` + fn enforce_smaller_than>(&self, cs: CS, other: &Self) -> Result<(), SynthesisError>; + + /// Output a `Boolean` gadget which is true iff the given order relationship between `self` + /// and `other` holds. If `should_also_check_equality` is true, then the order relationship + /// is not strict (e.g., `self <= other` must hold rather than `self < other`). + // The ordering relationship with equality is verified by exploiting the following identities: + // - x <= y iff !(y < x) + // - x >= y iff !(x < y) + fn is_cmp>( + &self, + mut cs: CS, + other: &Self, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result { + let (left, right) = match (ordering, should_also_check_equality) { + (Ordering::Less, false) | (Ordering::Greater, true) => (self, other), + (Ordering::Greater, false) | (Ordering::Less, true) => (other, self), + (Ordering::Equal, _) => return self.is_eq(cs, other), + }; + + + let is_smaller = left.is_smaller_than(cs.ns(|| "is smaller"), right)?; + + if should_also_check_equality { + return Ok(is_smaller.not()) + } + + Ok(is_smaller) + } + + /// Enforce the given order relationship between `self` and `other`. + /// If `should_also_check_equality` is true, then the order relationship is not strict + /// (e.g., `self <= other` is enforced rather than `self < other`). + // Default implementation calls `is_cmp` to get a Boolean which is true iff the order + // relationship holds, and then enforce this Boolean to be true + fn enforce_cmp>( + &self, + mut cs: CS, + other: &Self, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result<(), SynthesisError> { + let is_cmp = self.is_cmp(cs.ns(|| "cmp outcome"), other, ordering, should_also_check_equality)?; + + is_cmp.enforce_equal(cs.ns(|| "enforce cmp"), &Boolean::constant(true)) + } +} \ No newline at end of file diff --git a/r1cs/gadgets/std/src/fields/cmp.rs b/r1cs/gadgets/std/src/fields/cmp.rs index 38a424161..aae76128c 100644 --- a/r1cs/gadgets/std/src/fields/cmp.rs +++ b/r1cs/gadgets/std/src/fields/cmp.rs @@ -1,103 +1,12 @@ -use crate::{ - boolean::Boolean, - fields::fp::FpGadget, - prelude::*, - ToBitsGadget, -}; +use std::cmp::Ordering; use algebra::PrimeField; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; -use core::cmp::Ordering; +use crate::{boolean::Boolean, bits::ToBitsGadget, eq::EqGadget}; +use crate::cmp::ComparisonGadget; +use crate::fields::{fp::FpGadget, FieldGadget}; +// implement functions for FpGadget that are useful to implement the ComparisonGadget impl FpGadget { - /// This function enforces the ordering between `self` and `other`. The - /// constraint system will not be satisfied otherwise. If `self` should - /// also be checked for equality, e.g. `self <= other` instead of `self < - /// other`, set `should_also_check_quality` to `true`. This variant - /// verifies `self` and `other` are `<= (p-1)/2`. - pub fn enforce_cmp>( - &self, - mut cs: CS, - other: &FpGadget, - ordering: Ordering, - should_also_check_equality: bool, - ) -> Result<(), SynthesisError> { - let (left, right) = self.process_cmp_inputs(cs.ns(|| "process cmp inputs"), other, ordering, should_also_check_equality)?; - left.enforce_smaller_than(cs.ns(|| "enforce smaller"), &right) - } - - /// This function enforces the ordering between `self` and `other`. The - /// constraint system will not be satisfied otherwise. If `self` should - /// also be checked for equality, e.g. `self <= other` instead of `self < - /// other`, set `should_also_check_quality` to `true`. This variant - /// assumes `self` and `other` are `<= (p-1)/2` and does not generate - /// constraints to verify that. - pub fn enforce_cmp_unchecked>( - &self, - mut cs: CS, - other: &FpGadget, - ordering: Ordering, - should_also_check_equality: bool, - ) -> Result<(), SynthesisError> { - let (left, right) = self.process_cmp_inputs(cs.ns(|| "process cmp inputs"), other, ordering, should_also_check_equality)?; - left.enforce_smaller_than_unchecked(cs.ns(|| "enforce smaller"), &right) - } - - /// This function checks the ordering between `self` and `other`. It outputs - /// self `Boolean` that contains the result - `1` if true, `0` - /// otherwise. The constraint system will be satisfied in any case. If - /// `self` should also be checked for equality, e.g. `self <= other` - /// instead of `self < other`, set `should_also_check_quality` to - /// `true`. This variant verifies `self` and `other` are `<= (p-1)/2`. - pub fn is_cmp>( - &self, - mut cs: CS, - other: &FpGadget, - ordering: Ordering, - should_also_check_equality: bool, - ) -> Result { - let (left, right) = self.process_cmp_inputs(cs.ns(|| "process cmp inputs"), other, ordering, should_also_check_equality)?; - left.is_smaller_than(cs.ns(|| "is smaller"), &right) - } - - /// This function checks the ordering between `self` and `other`. It outputs - /// a `Boolean` that contains the result - `1` if true, `0` otherwise. - /// The constraint system will be satisfied in any case. If `self` - /// should also be checked for equality, e.g. `self <= other` instead of - /// `self < other`, set `should_also_check_quality` to `true`. This - /// variant assumes `self` and `other` are `<= (p-1)/2` and does not - /// generate constraints to verify that. - pub fn is_cmp_unchecked>( - &self, - mut cs: CS, - other: &FpGadget, - ordering: Ordering, - should_also_check_equality: bool, - ) -> Result { - let (left, right) = self.process_cmp_inputs(cs.ns(|| "process cmp inputs"), other, ordering, should_also_check_equality)?; - left.is_smaller_than_unchecked(cs.ns(|| "is smaller"), &right) - } - - fn process_cmp_inputs>( - &self, - mut cs: CS, - other: &Self, - ordering: Ordering, - should_also_check_equality: bool, - ) -> Result<(Self, Self), SynthesisError> { - let (left, right) = match ordering { - Ordering::Less => (self, other), - Ordering::Greater => (other, self), - Ordering::Equal => return Err(SynthesisError::Unsatisfiable), - }; - let one = FpGadget::::from_value(cs.ns(|| "from value"), &F::one()); - let right_for_check = if should_also_check_equality { - right.add(cs.ns(|| "add"),&one)? - } else { - right.clone() - }; - - Ok((left.clone(), right_for_check)) - } /// Helper function to enforce that `self <= (p-1)/2`. pub fn enforce_smaller_or_equal_than_mod_minus_one_div_two>( @@ -109,27 +18,19 @@ impl FpGadget { let bits_be = self.to_bits(cs.ns(|| "to bits"))?; let bits_le = bits_be.into_iter().rev().collect::>(); let _ = Boolean::enforce_smaller_or_equal_than_le( - cs.ns(|| "enforce smaller or equal"), + cs.ns(|| "enforce smaller or equal"), &bits_le, &F::modulus_minus_one_div_two(), )?; Ok(()) } - /// Helper function to check `self < other` and output a result bit. This - /// function verifies `self` and `other` are `<= (p-1)/2`. - fn is_smaller_than>(&self, mut cs: CS, other: &FpGadget) -> Result { - self.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "self smaller or equal mod"))?; - other.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "other smaller or equal mod"))?; - self.is_smaller_than_unchecked(cs.ns(|| "is smaller unchecked"), other) - } - /// Helper function to check `self < other` and output a result bit. This /// function assumes `self` and `other` are `<= (p-1)/2` and does not /// generate constraints to verify that. // Note that `len((p-1)/2) = len(p) - 1 = CAPACITY`. - fn is_smaller_than_unchecked>(&self, mut cs: CS, other: &FpGadget) -> Result { - // Since `a = self` and `b = other` are from `[0, (p-1)/2]`, we know that + pub fn is_smaller_than_unchecked>(&self, mut cs: CS, other: &Self) -> Result { + // Since `a = self` and `b = other` are from `[0, (p-1)/2]`, we know that // `` // self - other // `` @@ -139,46 +40,81 @@ impl FpGadget { // `` // 0 <= 2 * (self - other) <= (p-1), // `` - // and the least significant bit of `2 * (self - other) mod p` is zero. - // Otherwise, if `self < other`, then + // and the least significant bit of `2 * (self - other) mod p` is zero. + // Otherwise, if `self < other`, then // `` // 2 * (self - other) mod p = 2 * (self - other) + p // `` // which is a positive odd number, having least significant bit equal to `1`. // To assure the right decision we need to return the least significant - // bit of the NATIVE bit representation of `2 * (self - other)`. Hence we + // bit of the NATIVE bit representation of `2 * (self - other)`. Hence we // need to use `to_bits_strict()`. Ok(self.sub(cs.ns(|| "sub"), other)? .double(cs.ns(|| "double"))? .to_bits_strict(cs.ns(|| "to bits"))? // returns big endian - .into_iter().rev().collect::>() - .first() + .into_iter().rev().collect::>() + .first() .unwrap() .clone()) } - /// Helper function to enforce `self < other`. This function verifies `self` - /// and `other` are `<= (p-1)/2`. - fn enforce_smaller_than>(&self, mut cs: CS, other: &FpGadget) -> Result<(), SynthesisError> { + pub fn enforce_smaller_than_unchecked>(&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { + let is_smaller = self.is_smaller_than_unchecked(cs.ns(|| "is smaller unchecked"), other)?; + is_smaller.enforce_equal(cs.ns(|| "enforce smaller than"), &Boolean::constant(true)) + } + + /// Variant of `enforce_cmp` that assumes `self` and `other` are `<= (p-1)/2` and + /// does not generate constraints to verify that. + fn enforce_cmp_unchecked>( + &self, + mut cs: CS, + other: &Self, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result<(), SynthesisError> { + let is_cmp = self.is_cmp_unchecked(cs.ns(|| "is cmp unchecked"), other, ordering, should_also_check_equality)?; + is_cmp.enforce_equal(cs.ns(|| "enforce cmp"), &Boolean::constant(true)) + } + + /// Variant of `is_cmp` that assumes `self` and `other` are `<= (p-1)/2` and does not generate + /// constraints to verify that. + // It differs from the default implementation of `is_cmp` only by + // calling `is_smaller_than_unchecked` in place of `is_smaller_than` for efficiency given that + // there is no need to verify that `self` and `other` are `<= (p-1)/2` + fn is_cmp_unchecked>( + &self, + mut cs: CS, + other: &Self, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result { + let (left, right) = match (ordering, should_also_check_equality) { + (Ordering::Less, false) | (Ordering::Greater, true) => (self, other), + (Ordering::Greater, false) | (Ordering::Less, true) => (other, self), + (Ordering::Equal, _) => return self.is_eq(cs, other), + }; + + let is_smaller = left.is_smaller_than_unchecked(cs.ns(|| "is smaller"), right)?; + + if should_also_check_equality { + return Boolean::xor(cs.ns(|| "negating cmp outcome"), &is_smaller, &Boolean::constant(true)) + } + + Ok(is_smaller) + } +} + +impl ComparisonGadget for FpGadget { + fn is_smaller_than>(&self, mut cs: CS, other: &Self) -> Result { self.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "self smaller or equal mod"))?; other.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "other smaller or equal mod"))?; - self.enforce_smaller_than_unchecked(cs.ns(|| "enforce smaller unchecked"), other) + self.is_smaller_than_unchecked(cs.ns(|| "is smaller unchecked"), other) } - /// Helper function to enforce `self < other`. This function assumes `self` - /// and `other` are `<= (p-1)/2` and does not generate constraints to - /// verify that. - fn enforce_smaller_than_unchecked>(&self, mut cs: CS, other: &FpGadget) -> Result<(), SynthesisError> { - let is_smaller_than = self.is_smaller_than_unchecked(cs.ns(|| "is smaller"), other)?; - //println!("{} Is smaller then {}: {}", self.get_value().unwrap(), other.get_value().unwrap(), is_smaller_than.get_value().unwrap()); - let lc_one = CS::one(); - cs.enforce( - || "Enforce smaller then", - |lc| lc + is_smaller_than.lc(CS::one(), F::one()), - |lc| lc + lc_one.clone(), - |lc| lc + lc_one - ); - Ok(()) + fn enforce_smaller_than>(&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { + self.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "self smaller or equal mod"))?; + other.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "other smaller or equal mod"))?; + self.enforce_smaller_than_unchecked(cs.ns(|| "enforce smaller than unchecked"), other) } } @@ -186,96 +122,121 @@ impl FpGadget { mod test { use std::cmp::Ordering; use rand::{Rng, thread_rng}; - - use r1cs_core::{ConstraintSystemAbstract, ConstraintSystem, SynthesisMode, ConstraintSystemDebugger}; - use crate::{algebra::{UniformRand, PrimeField, - fields::bls12_381::Fr, + use r1cs_core::{ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode}; + use crate::{algebra::{UniformRand, PrimeField, + fields::tweedle::Fr, Group, }, fields::fp::FpGadget}; - use crate::alloc::AllocGadget; - - #[test] - fn test_cmp() { - let mut rng = &mut thread_rng(); - fn rand_in_range(rng: &mut R) -> Fr { - let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); - let mut r; - loop { - r = Fr::rand(rng); - if r <= pminusonedivtwo { - break; + use crate::{alloc::{AllocGadget, ConstantGadget}, cmp::ComparisonGadget}; + + macro_rules! test_cmp_function { + ($cmp_func: tt) => { + let mut rng = &mut thread_rng(); + fn rand_in_range(rng: &mut R) -> Fr { + let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); + let mut r; + loop { + r = Fr::rand(rng); + if r <= pminusonedivtwo { + break; + } } + r } - r - } - for i in 0..10 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let a = rand_in_range(&mut rng); - let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - let b = rand_in_range(&mut rng); - let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); + for i in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let a = rand_in_range(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let b = rand_in_range(&mut rng); + let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); + + match a.cmp(&b) { + Ordering::Less => { + a_var.$cmp_func(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less equal"), &b_var, Ordering::Less, true).unwrap(); + } + Ordering::Greater => { + a_var.$cmp_func(cs.ns(|| "enforce greater"), &b_var, Ordering::Greater, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce greater equal"), &b_var, Ordering::Greater, true).unwrap(); + } + _ => {} + } - match a.cmp(&b) { - Ordering::Less => { - a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false).unwrap(); - a_var.enforce_cmp(cs.ns(|| "enforce less equal"), &b_var, Ordering::Less, true).unwrap(); + if i == 0 { + println!("number of constraints: {}", cs.num_constraints()); } - Ordering::Greater => { - a_var.enforce_cmp(cs.ns(|| "enforce greater"), &b_var, Ordering::Greater, false).unwrap(); - a_var.enforce_cmp(cs.ns(|| "enforce greater equal"), &b_var, Ordering::Greater, true).unwrap(); + if !cs.is_satisfied(){ + println!("{:?}", cs.which_is_unsatisfied()); } - _ => {} + assert!(cs.is_satisfied()); } - - if i == 0 { - println!("number of constraints: {}", cs.num_constraints()); - } - if !cs.is_satisfied(){ - println!("{:?}", cs.which_is_unsatisfied()); + println!("Finished with satisfaction tests"); + + for _i in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a = rand_in_range(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let b = rand_in_range(&mut rng); + let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); + + match b.cmp(&a) { + Ordering::Less => { + a_var.$cmp_func(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less equal"),&b_var, Ordering::Less, true).unwrap(); + } + Ordering::Greater => { + a_var.$cmp_func(cs.ns(|| "enforce greater"),&b_var, Ordering::Greater, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce greater equal"),&b_var, Ordering::Greater, true).unwrap(); + } + _ => {} + } + assert!(!cs.is_satisfied()); } - assert!(cs.is_satisfied()); - } - println!("Finished with satisfaction tests"); - for _i in 0..10 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let a = rand_in_range(&mut rng); - let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - let b = rand_in_range(&mut rng); - let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); + for _i in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a = rand_in_range(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less"),&a_var, Ordering::Less, false).unwrap(); - match b.cmp(&a) { - Ordering::Less => { - a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false).unwrap(); - a_var.enforce_cmp(cs.ns(|| "enforce less equal"),&b_var, Ordering::Less, true).unwrap(); - } - Ordering::Greater => { - a_var.enforce_cmp(cs.ns(|| "enforce greater"),&b_var, Ordering::Greater, false).unwrap(); - a_var.enforce_cmp(cs.ns(|| "enforce greater equal"),&b_var, Ordering::Greater, true).unwrap(); + assert!(!cs.is_satisfied()); + } + + for _i in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a = rand_in_range(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less"),&a_var, Ordering::Less, true).unwrap(); + if !cs.is_satisfied(){ + println!("{:?}", cs.which_is_unsatisfied()); } - _ => {} + assert!(cs.is_satisfied()); } - assert!(!cs.is_satisfied()); - } - for _i in 0..10 { + // test corner case when operands are extreme values of range [0, (p-1)/2] of + // admissible values let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let a = rand_in_range(&mut rng); - let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - a_var.enforce_cmp(cs.ns(|| "enforce less"),&a_var, Ordering::Less, false).unwrap(); + let max_val: Fr = Fr::modulus_minus_one_div_two().into(); + let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); + let zero_var = FpGadget::::from_value(cs.ns(|| "alloc zero"), &Fr::zero()); + zero_var.$cmp_func(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var, Ordering::Less, true).unwrap(); + + assert!(cs.is_satisfied()); + // test when one of the operands is beyond (p-1)/2 + let out_range_var = FpGadget::::alloc(&mut cs.ns(|| "generate_out_range"), || Ok(max_val.double())).unwrap(); + zero_var.$cmp_func(cs.ns(|| "enforce 0 <= p-1"), &out_range_var, Ordering::Less, true).unwrap(); assert!(!cs.is_satisfied()); } + } - for _i in 0..10 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let a = rand_in_range(&mut rng); - let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - a_var.enforce_cmp(cs.ns(|| "enforce less"),&a_var, Ordering::Less, true).unwrap(); - if !cs.is_satisfied(){ - println!("{:?}", cs.which_is_unsatisfied()); - } - assert!(cs.is_satisfied()); - } + #[test] + fn test_cmp() { + test_cmp_function!(enforce_cmp); + } + + #[test] + fn test_cmp_unchecked() { + test_cmp_function!(enforce_cmp_unchecked); } } \ No newline at end of file diff --git a/r1cs/gadgets/std/src/lib.rs b/r1cs/gadgets/std/src/lib.rs index 1b2685065..e35c5645b 100644 --- a/r1cs/gadgets/std/src/lib.rs +++ b/r1cs/gadgets/std/src/lib.rs @@ -64,6 +64,7 @@ pub mod instantiated; pub use instantiated::*; pub mod alloc; +pub mod cmp; pub mod eq; pub mod select; pub mod to_field_gadget_vec; From 64756f4c10388a69443711bf4fefc5692a414f8c Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 14 Jan 2022 09:50:04 +0100 Subject: [PATCH 03/18] functions to compare arbitrary field elements --- r1cs/gadgets/std/src/fields/cmp.rs | 234 +++++++++++++++++++++++++++-- 1 file changed, 221 insertions(+), 13 deletions(-) diff --git a/r1cs/gadgets/std/src/fields/cmp.rs b/r1cs/gadgets/std/src/fields/cmp.rs index aae76128c..0073f8403 100644 --- a/r1cs/gadgets/std/src/fields/cmp.rs +++ b/r1cs/gadgets/std/src/fields/cmp.rs @@ -1,13 +1,78 @@ use std::cmp::Ordering; use algebra::PrimeField; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; -use crate::{boolean::Boolean, bits::ToBitsGadget, eq::EqGadget}; +use crate::{boolean::Boolean, bits::{ToBitsGadget, FromBitsGadget}, eq::EqGadget, select::CondSelectGadget}; use crate::cmp::ComparisonGadget; use crate::fields::{fp::FpGadget, FieldGadget}; // implement functions for FpGadget that are useful to implement the ComparisonGadget impl FpGadget { + /// Helper function that allows to compare 2 slices of 2 bits, outputting 2 Booleans: + /// the former (resp. the latter) one is true iff the big-endian integer represented by the + /// first slice is smaller (resp. is equal) than the big-endian integer represented by the second slice + fn compare_msbs>(mut cs: CS, first: &[Boolean], second: &[Boolean]) + -> Result<(Boolean, Boolean), SynthesisError> { + assert_eq!(first.len(), 2); + assert_eq!(second.len(), 2); + + let a = first[0]; + let b = first[1]; + let c = second[0]; + let d = second[1]; + + // is_less corresponds to the Boolean function: !a*(c+!b*d)+(!b*c*d), + // which is true iff first < second, where + is Boolean OR and * is Boolean AND + let bd = Boolean::and(cs.ns(|| "!bd"), &b.not(), &d)?; + let first_tmp = Boolean::or(cs.ns(|| "!a + !bd"), &a.not(), &bd)?; + let second_tmp = Boolean::and(cs.ns(|| "!a!bd"), &a.not(), &bd)?; + let is_less = Boolean::conditionally_select(cs.ns(|| "is less"), &c, &first_tmp, &second_tmp)?; + + // is_eq corresponds to the Boolean function: !((a xor c) + (b xor d)), + // which is true iff first == second + let first_tmp = Boolean::xor(cs.ns(|| "a xor c"), &a, &c)?; + let second_tmp = Boolean::xor(cs.ns(|| "b xor d"), &b, &d)?; + let is_eq = Boolean::or(cs.ns(|| "is eq"), &first_tmp, &second_tmp)?.not(); + + Ok((is_less, is_eq)) + } + + /// Output a Boolean that is true iff `self` < `other`. Here `self` and `other` + /// can be arbitrary field elements, they are not constrained to be at most (p-1)/2 + pub fn is_smaller_than_unrestricted>( + &self, + mut cs: CS, + other: &Self, + ) -> Result { + let self_bits = self.to_bits_strict(cs.ns(|| "first op to bits"))?; + let other_bits = other.to_bits_strict(cs.ns(|| "second op to bits"))?; + // extract the least significant MODULUS_BITS-2 bits and convert them to a field element, + // which is necessarily lower than (p-1)/2 + let fp_for_self_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op MSBs"), &self_bits[2..])?; + let fp_for_other_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op LSBs"), &other_bits[2..])?; + + // since the field elements are lower than (p-1)/2, we can compare it with the efficient approach + let is_less_lsbs = fp_for_self_lsbs.is_smaller_than_unchecked(cs.ns(|| "compare LSBs"), &fp_for_other_lsbs)?; + + + // obtain two Booleans: the former (resp. the latter) one is true iff the integer + // represented by the 2 MSBs of self is smaller (resp. is equal) than the integer + // represented by the 2 MSBs of other + let (is_less_msbs, is_eq_msbs) = Self::compare_msbs(cs.ns(|| "compare MSBs"), &self_bits[..2], &other_bits[..2])?; + + // Equivalent to is_less_msbs OR is_eq_msbs AND is_less_msbs, given that is_less_msbs and + // is_eq_msbs cannot be true at the same time + Boolean::conditionally_select(cs, &is_eq_msbs, &is_less_lsbs, &is_less_msbs) + } + + /// Enforce than `self` < `other`. Here `self` and `other` they are arbitrary field elements, + /// they are not constrained to be at most (p-1)/2 + pub fn enforce_smaller_than_unrestricted>(&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { + let is_smaller = self.is_smaller_than_unrestricted(cs.ns(|| "is smaller unchecked"), other)?; + is_smaller.enforce_equal(cs.ns(|| "enforce smaller than"), &Boolean::constant(true)) + } + + /// Helper function to enforce that `self <= (p-1)/2`. pub fn enforce_smaller_or_equal_than_mod_minus_one_div_two>( &self, @@ -97,7 +162,7 @@ impl FpGadget { let is_smaller = left.is_smaller_than_unchecked(cs.ns(|| "is smaller"), right)?; if should_also_check_equality { - return Boolean::xor(cs.ns(|| "negating cmp outcome"), &is_smaller, &Boolean::constant(true)) + return Ok(is_smaller.not()); } Ok(is_smaller) @@ -128,20 +193,21 @@ mod test { }, fields::fp::FpGadget}; use crate::{alloc::{AllocGadget, ConstantGadget}, cmp::ComparisonGadget}; + fn rand_in_range(rng: &mut R) -> Fr { + let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); + let mut r; + loop { + r = Fr::rand(rng); + if r <= pminusonedivtwo { + break; + } + } + r + } + macro_rules! test_cmp_function { ($cmp_func: tt) => { let mut rng = &mut thread_rng(); - fn rand_in_range(rng: &mut R) -> Fr { - let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); - let mut r; - loop { - r = Fr::rand(rng); - if r <= pminusonedivtwo { - break; - } - } - r - } for i in 0..10 { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); @@ -239,4 +305,146 @@ mod test { fn test_cmp_unchecked() { test_cmp_function!(enforce_cmp_unchecked); } + + macro_rules! test_smaller_than_func { + ($is_smaller_func: tt, $enforce_smaller_func: tt) => { + let mut rng = &mut thread_rng(); + for _ in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let a = rand_in_range(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let b = rand_in_range(&mut rng); + let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); + + let is_smaller = a_var.$is_smaller_func(cs.ns(|| "is smaller"), &b_var).unwrap(); + + a_var.$enforce_smaller_func(cs.ns(|| "enforce smaller"), &b_var).unwrap(); + + match a.cmp(&b) { + Ordering::Less => { + assert!(is_smaller.get_value().unwrap()); + assert!(cs.is_satisfied()); + } + Ordering::Greater | Ordering::Equal => { + assert!(!is_smaller.get_value().unwrap()); + assert!(!cs.is_satisfied()) + } + } + } + + for _ in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a = rand_in_range(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let is_smaller = a_var.$is_smaller_func(cs.ns(|| "is smaller"),&a_var).unwrap(); + // check that a.is_smaller(a) == false + assert!(!is_smaller.get_value().unwrap()); + a_var.$enforce_smaller_func(cs.ns(|| "enforce is smaller"), &a_var).unwrap(); + assert!(!cs.is_satisfied()); + } + + // test corner case when operands are extreme values of range [0, (p-1)/2] of + // admissible values + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let max_val: Fr = Fr::modulus_minus_one_div_two().into(); + let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); + let zero_var = FpGadget::::from_value(cs.ns(|| "alloc zero"), &Fr::zero()); + let is_smaller = zero_var.$is_smaller_func(cs.ns(|| "0 is smaller than (p-1) div 2"), &max_var).unwrap(); + assert!(is_smaller.get_value().unwrap()); + zero_var.$enforce_smaller_func(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var).unwrap(); + assert!(cs.is_satisfied()); + + // test when one of the operands is beyond (p-1)/2 + let out_range_var = FpGadget::::alloc(&mut cs.ns(|| "generate_out_range"), || Ok(max_val.double())).unwrap(); + zero_var.$enforce_smaller_func(cs.ns(|| "enforce 0 <= p-1"), &out_range_var).unwrap(); + assert!(!cs.is_satisfied()); + } + } + + #[test] + fn test_smaller_than() { + test_smaller_than_func!(is_smaller_than, enforce_smaller_than); + } + + #[test] + fn test_smaller_than_unchecked() { + test_smaller_than_func!(is_smaller_than_unchecked, enforce_smaller_than_unchecked); + } + + macro_rules! test_smaller_than_unrestricted { + ($rand_func: tt) => { + let mut rng = &mut thread_rng(); + + for _ in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let a = $rand_func(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let b = $rand_func(&mut rng); + let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); + let is_smaller = a_var.is_smaller_than_unrestricted(cs.ns(|| "is smaller"), &b_var).unwrap(); + a_var.enforce_smaller_than_unrestricted(cs.ns(|| "enforce is smaller"), &b_var).unwrap(); + + match a.cmp(&b) { + Ordering::Less => { + assert!(is_smaller.get_value().unwrap()); + assert!(cs.is_satisfied()); + } + Ordering::Greater | Ordering::Equal => { + assert!(!is_smaller.get_value().unwrap()); + assert!(!cs.is_satisfied()) + } + } + } + + for _ in 0..10 { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a = $rand_func(&mut rng); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let is_smaller = a_var.is_smaller_than_unrestricted(cs.ns(|| "is smaller"),&a_var).unwrap(); + // check that a.is_smaller(a) == false + assert!(!is_smaller.get_value().unwrap()); + a_var.enforce_smaller_than_unrestricted(cs.ns(|| "enforce is smaller"), &a_var).unwrap(); + assert!(!cs.is_satisfied()); + } + + // test corner case where the operands are extreme values of range [0, p-1] of + // admissible values + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let max_val: Fr = Fr::modulus_minus_one_div_two().into(); + let max_val = max_val.double(); + let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); + let zero_var = FpGadget::::from_value(cs.ns(|| "alloc zero"), &Fr::zero()); + let is_smaller = zero_var.is_smaller_than_unrestricted(cs.ns(|| "0 is smaller than p-1"), &max_var).unwrap(); + assert!(is_smaller.get_value().unwrap()); + zero_var.enforce_smaller_than_unrestricted(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var).unwrap(); + assert!(cs.is_satisfied()); + } + } + + #[test] + fn test_smaller_than_unrestricted() { + fn rand_higher(rng: &mut R) -> Fr { + let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); + let mut r; + loop { + r = Fr::rand(rng); + if r > pminusonedivtwo { + break; + } + } + r + } + + fn field_uniform_rand(rng: &mut R) -> Fr { + Fr::rand(rng) + } + // test with random field elements >(p-1)/2 + test_smaller_than_unrestricted!(rand_higher); + // test with random field elements <=(p-1)/2 + test_smaller_than_unrestricted!(rand_in_range); + // test with arbitrary field elements + test_smaller_than_unrestricted!(field_uniform_rand); + } } \ No newline at end of file From f78a28a0b407e7ff10adf159741e818647d1e36a Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Mon, 17 Jan 2022 16:56:34 +0100 Subject: [PATCH 04/18] Implement ComparisonGadget and sub for UIntGadet --- r1cs/gadgets/std/src/bits/boolean.rs | 50 +++ r1cs/gadgets/std/src/bits/macros.rs | 488 ++++++++++++++++++++++++++- r1cs/gadgets/std/src/bits/mod.rs | 42 +++ 3 files changed, 570 insertions(+), 10 deletions(-) diff --git a/r1cs/gadgets/std/src/bits/boolean.rs b/r1cs/gadgets/std/src/bits/boolean.rs index d02836124..ce6fa4291 100644 --- a/r1cs/gadgets/std/src/bits/boolean.rs +++ b/r1cs/gadgets/std/src/bits/boolean.rs @@ -631,6 +631,56 @@ impl Boolean { } } + /// Enforce that at least one operand is true, given that bits.len() is less than the size of + /// the field + pub fn enforce_or(mut cs: CS, bits: &[Self]) -> Result<(), SynthesisError> + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + { + // this is done with a single constraint as follows: + // - Compute a linear combination sum_lc which is the sum of all the bits + // - enforce that the sum != 0 with a single constraint: sum*v = 1, where v can only be + // chosen as the inverse of sum (which exists iff sum != 0) + let mut sum_lc = LinearCombination::zero(); + let mut sum_of_bits = Some(ConstraintF::zero()); + let mut all_constants = true; + for bit in bits { + sum_lc = sum_lc + &bit.lc(CS::one(), ConstraintF::one()); + + all_constants &= bit.is_constant(); + + sum_of_bits = match bit.get_value() { + Some(bitval) => sum_of_bits.as_mut().map(|sum| { + if bitval { + *sum += ConstraintF::one(); + }; + *sum + }), + None => None, + } + } + + if all_constants { + if sum_of_bits.unwrap().is_zero() { + return Err(SynthesisError::Unsatisfiable); + } + return Ok(()); + } + + let inv = sum_of_bits.map(|sum| + match sum.inverse() { + Some(val) => val, + None => ConstraintF::one(), // if sum == 0, then inverse can be any value, the constraint should never be verified + }); + + let inv_var = FpGadget::::alloc(cs.ns(|| "alloc inv"), || inv.ok_or(SynthesisError::AssignmentMissing))?; + + cs.enforce(|| "enforce self != other", |_| sum_lc, |lc| &inv_var.get_variable() + lc, |_| (ConstraintF::one(), CS::one()).into()); + + Ok(()) + } + /// Asserts that this bit_gadget representation is "in /// the field" when interpreted in big endian. pub fn enforce_in_field( diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index 402c68244..1401c2cb7 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -2,14 +2,14 @@ macro_rules! impl_uint_gadget { ($type_name: ident, $bit_size: expr, $native_type: ident, $mod_name: ident) => { pub mod $mod_name { - use crate::{boolean::{Boolean, AllocatedBit}, fields::{fp::FpGadget, FieldGadget}, eq::{EqGadget, MultiEq}, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8, Assignment}; + use crate::{boolean::{Boolean, AllocatedBit}, fields::{fp::FpGadget, FieldGadget}, eq::{EqGadget, MultiEq}, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8, Assignment, cmp::ComparisonGadget}; use r1cs_core::{ConstraintSystemAbstract, SynthesisError, LinearCombination}; use crate::alloc::{AllocGadget, ConstantGadget}; use algebra::{fields::{PrimeField, FpParameters}, ToConstraintField}; - use std::{borrow::Borrow, ops::{Shl, Shr}, convert::TryInto}; + use std::{borrow::Borrow, ops::{Shl, Shr}, convert::TryInto, cmp::Ordering}; //ToDo: remove public use of fields @@ -133,6 +133,22 @@ macro_rules! impl_uint_gadget { result } + /// enfroces that self >= other. This function is provided as it is much efficient + /// in terms of constraints with respect to the default implementation of + /// ComparisonGadget, which relies on the smaller_than functions + pub fn enforce_greater_or_equal_than(&self, mut cs: CS, other: &Self) + -> Result<(), SynthesisError> + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + { + // self >= other iff self - other does not underflow + let mut multi_eq = MultiEq::new(&mut cs); + let _ = self.sub_noborrow(&mut multi_eq, other)?; + + Ok(()) + } + // Return little endian representation of self. Will be removed when to_bits_le and // from_bits_le will be merged. pub fn into_bits_le(&self) -> Vec { @@ -205,7 +221,6 @@ macro_rules! impl_uint_gadget { self.bits.conditional_enforce_equal(cs, &other.bits, should_enforce) } - //ToDO: check if the default implementation is better than the probably buggy one for [Boolean] fn conditional_enforce_not_equal>( &self, cs: CS, @@ -842,6 +857,239 @@ macro_rules! impl_uint_gadget { $type_name::from_bits(cs.ns(|| "packing result"), &result_bits[..]) } + fn sub(&self, mut cs: M, other: &Self) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + // this assertion checks that the field is big enough: that is, the field must + // be able to represent integers up to 2^($bit_size+1) + assert!(ConstraintF::Params::MODULUS_BITS - 1 > $bit_size); + + // Overall idea: allocate $bit_size+1 bits representing a field element diff + // and enforce that diff == self - other + 2^$bit_size. + // The addition of 2^$bit_size is useful in case other >= self to avoid + // field underflows, which would require to allocate as many bits as the field + // modulus. Only the first $bit_size bits are returned as the result of the + // subtraction, hence the addition of 2^$bit_size has no impact on the final + // result + + // max_value is a field element equal to 2^$bit_size + let max_value = ConstraintF::from($native_type::MAX) + ConstraintF::one(); + let mut lc = (max_value, CS::one()).into(); + let mut coeff = ConstraintF::one(); + let mut all_constants = true; + for (self_bit, other_bit) in self.bits.iter().zip(other.bits.iter()) { + lc = lc + &self_bit.lc(CS::one(), coeff); + lc = lc - &other_bit.lc(CS::one(), coeff); + + all_constants &= self_bit.is_constant() && other_bit.is_constant(); + + coeff.double_in_place(); + } + + let (diff, diff_in_field) = match (self.value, other.value) { + (Some(val1), Some(val2)) => { + let (diff, _) = val1.overflowing_sub(val2); // don't care if there is an underflow + let fe1 = ConstraintF::from(val1); + let fe2 = ConstraintF::from(val2); + (Some(diff), Some(fe1 - fe2 + max_value)) + }, + _ => (None, None), + }; + + if all_constants && diff.is_some() { + return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &diff.unwrap())); + } + + let diff_bits = match diff_in_field { + Some(diff) => diff.write_bits().iter().rev().map(|b| Some(*b)).collect::>(), + None => vec![None; $bit_size+1], + }; + + let mut result_bits = Vec::with_capacity($bit_size); + let mut result_lc = LinearCombination::zero(); + let mut coeff = ConstraintF::one(); + for i in 0..$bit_size+1 { + let diff_bit = Boolean::alloc(cs.ns(|| format!("alloc diff bit {}", i)), || diff_bits[i].ok_or(SynthesisError::AssignmentMissing))?; + + result_lc = result_lc + &diff_bit.lc(CS::one(), coeff); + + coeff.double_in_place(); + + if i < $bit_size { + result_bits.push(diff_bit); + } + } + + cs.get_root().enforce_equal($bit_size+1, &lc, &result_lc); + + Ok(Self{ + bits: result_bits, + value: diff, + }) + + } + + fn sub_noborrow(&self, mut cs: M, other: &Self) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + // this assertion checks that the field is big enough: subtraction of any 2 + // values in $native_type must be a field element that cannot be represented as + // $native_type + assert!(ConstraintF::Params::MODULUS_BITS - 1 > $bit_size); + + let mut lc = LinearCombination::zero(); + let mut coeff = ConstraintF::one(); + let mut all_constants = true; + for (self_bit, other_bit) in self.bits.iter().zip(other.bits.iter()) { + lc = lc + &self_bit.lc(CS::one(), coeff); + lc = lc - &other_bit.lc(CS::one(), coeff); + + all_constants &= self_bit.is_constant() && other_bit.is_constant(); + + coeff.double_in_place(); + } + + let (diff, is_underflowing) = match (self.value, other.value) { + (Some(val1), Some(val2)) => { + let (diff, underflow) = val1.overflowing_sub(val2); + (Some(diff), underflow) + }, + _ => (None, false), + }; + + + if all_constants && diff.is_some() { + if is_underflowing { + // in this case self < other + return Err(SynthesisError::Unsatisfiable) + } else { + return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &diff.unwrap())) + } + } + + let diff_var = Self::alloc(cs.ns(|| "alloc diff"), || diff.ok_or(SynthesisError::AssignmentMissing))?; + + let mut diff_lc = LinearCombination::zero(); + let mut coeff = ConstraintF::one(); + for diff_bit in diff_var.bits.iter() { + diff_lc = diff_lc + &diff_bit.lc(CS::one(), coeff); + + coeff.double_in_place(); + } + + cs.get_root().enforce_equal($bit_size, &lc, &diff_lc); + + Ok(diff_var) + } + } + + impl ComparisonGadget for $type_name { + fn is_smaller_than>(&self, mut cs: CS, other: &Self) + -> Result + { + // this assertion checks that the field is big enough: subtraction of any 2 + // values in $native_type must be a field element that cannot be represented as + // $native_type + assert!(ConstraintF::Params::MODULUS_BITS - 1 > $bit_size); + + let mut delta_lc = LinearCombination::zero(); + let mut coeff = ConstraintF::one(); + let mut all_constants = true; + for (self_bit, other_bit) in self.bits.iter().zip(other.bits.iter()) { + delta_lc = delta_lc + &self_bit.lc(CS::one(), coeff); + delta_lc = delta_lc - &other_bit.lc(CS::one(), coeff); + + all_constants &= self_bit.is_constant() && other_bit.is_constant(); + + coeff.double_in_place(); + } + + let mut is_underflowing = None; + // delta = self - other - diff in the field, where diff = self - other over the uint type + let mut delta = None; + let diff = match (self.get_value(), other.get_value()) { + (Some(value1), Some(value2)) => { + let (diff, underflow) = value1.overflowing_sub(value2); + is_underflowing = Some(underflow); + // compute self - other - diff over the field + let self_in_field = ConstraintF::from(value1); + let other_in_field = ConstraintF::from(value2); + let diff_in_field = ConstraintF::from(diff); + delta = Some(self_in_field - other_in_field - diff_in_field); + Some(diff) + }, + _ => None, + }; + + if all_constants && diff.is_some() { + return Ok(Boolean::constant(is_underflowing.unwrap())) + } + + let diff_var = Self::alloc(cs.ns(|| "alloc diff"), || diff.ok_or(SynthesisError::AssignmentMissing))?; + let mut coeff = ConstraintF::one(); + for diff_bit in diff_var.bits.iter() { + delta_lc = delta_lc - &diff_bit.lc(CS::one(), coeff); + coeff.double_in_place(); + } + // ToDo: It should not be necessary to allocate it as a Boolean gadget + let is_smaller = Boolean::alloc(cs.ns(|| "alloc result"), || is_underflowing.ok_or(SynthesisError::AssignmentMissing))?; + + let inv = delta.map(|delta| { + match delta.inverse() { + Some(inv) => inv, + None => ConstraintF::one(), // delta is 0, so we can set any value + } + }); + + let inv_var = FpGadget::::alloc(cs.ns(|| "alloc inv"), || inv.ok_or(SynthesisError::AssignmentMissing))?; + + // enforce constraints: + // (1 - is_smaller) * delta_lc = 0 enforces that is_smaller == 1 when delta != 0, i.e., when a < b + // inv * delta_lc = is_smaller enforces that is_smaller == 0 when delta == 0, i.e., when b >= a + cs.enforce(|| "enforce is smaller == true", |_| is_smaller.not().lc(CS::one(), ConstraintF::one()), |lc| lc + &delta_lc, |lc| lc); + cs.enforce(|| "enforce is smaller == false", |lc| &inv_var.get_variable() + lc, |lc| lc + &delta_lc, |_| is_smaller.lc(CS::one(), ConstraintF::one())); + + Ok(is_smaller) + } + + fn enforce_smaller_than> + (&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { + // first enforce that self <= other, which holds iff other - self does not underflow + let diff = { + let mut multi_eq = MultiEq::new(&mut cs); + other.sub_noborrow(&mut multi_eq, self)? + }; + // then, enforce that other - self is non zero, which holds iff the difference + // has at least a non zero bit + Boolean::enforce_or(cs.ns(|| "enforce self != other"), &diff.bits) + } + + // override the default implementation to exploit the fact that enforcing constraint + // is cheaper than computing a Boolean gadget with the comparison outcome + fn enforce_cmp>( + &self, + mut cs: CS, + other: &Self, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result<(), SynthesisError> { + let (left, right) = match (ordering, should_also_check_equality) { + (Ordering::Less, false) | (Ordering::Greater, true) => (self, other), + (Ordering::Greater, false) | (Ordering::Less, true) => (other, self), + (Ordering::Equal, _) => return self.enforce_equal(cs, other), + }; + + if should_also_check_equality { + left.enforce_greater_or_equal_than(cs.ns(|| "enforce greater equal"), right) + } else { + left.enforce_smaller_than(cs.ns(|| "enforce smaller than"), right) + } + } + } @@ -855,9 +1103,9 @@ macro_rules! impl_uint_gadget { ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, SynthesisError, }; - use std::ops::{Shl, Shr}; + use std::{ops::{Shl, Shr}, cmp::Ordering}; - use crate::{alloc::{AllocGadget, ConstantGadget}, eq::{EqGadget, MultiEq}, boolean::Boolean, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8}; + use crate::{alloc::{AllocGadget, ConstantGadget}, eq::{EqGadget, MultiEq}, boolean::Boolean, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8, cmp::ComparisonGadget}; fn test_uint_gadget_value(val: $native_type, alloc_val: &$type_name, check_name: &str) { @@ -1479,15 +1727,18 @@ macro_rules! impl_uint_gadget { let op2: $native_type = rng.gen(); let add_result_val = op1.overflowing_add(op2).0; let mul_result_val = op1.overflowing_mul(op2).0; + let sub_result_val = op1.overflowing_sub(op2).0; let op1_var = alloc_fn(&mut cs, "alloc op1", &var_type_op1, op1); let op2_var = alloc_fn(&mut cs, "alloc op2", &var_type_op2, op2); let cond_var = alloc_boolean_cond(&mut cs, "alloc condition", condition); - let add_result_var = { + let (add_result_var, sub_result_var) = { // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped let mut multi_eq = MultiEq::new(&mut cs); - op1_var.conditionally_add(&mut multi_eq, &cond_var, &op2_var).unwrap() + let add_result = op1_var.conditionally_add(multi_eq.ns(|| "conditionally add"), &cond_var, &op2_var).unwrap(); + let sub_result = op1_var.conditionally_sub(multi_eq.ns(|| "conditionally sub"), &cond_var, &op2_var).unwrap(); + (add_result, sub_result) }; let mul_result_var = op1_var.conditionally_mul(&mut cs, &cond_var, &op2_var).unwrap(); @@ -1500,7 +1751,13 @@ macro_rules! impl_uint_gadget { mul_result_val } else { op1 - }, &mul_result_var, "addition correctness"); + }, &mul_result_var, "multiplication correctness"); + test_uint_gadget_value(if cond_var.get_value().unwrap() { + sub_result_val + } else { + op1 + }, &sub_result_var, "subtraction correctness"); + assert!(cs.is_satisfied()); } } @@ -1795,10 +2052,221 @@ macro_rules! impl_uint_gadget { } } + #[test] + fn test_subtraction() { + let rng = &mut thread_rng(); + for condition in BOOLEAN_TYPES.iter() { + for var_type_op1 in VARIABLE_TYPES.iter() { + for var_type_op2 in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let op1: $native_type = rng.gen(); + let op2: $native_type = rng.gen(); + + let (left, right, var_type_left, var_type_right) = match op1.cmp(&op2) { + Ordering::Less => (op2, op1, &var_type_op2, &var_type_op1), + Ordering::Greater | Ordering::Equal => (op1, op2, &var_type_op1, &var_type_op2), + }; + + // compute subtraction and check that no underflow occurs + let (diff, underflow) = left.overflowing_sub(right); + assert!(!underflow); + + let left_op = alloc_fn(&mut cs, "alloc left op", var_type_left, left); + let right_op = alloc_fn(&mut cs, "alloc right op", var_type_right, right); + let cond_var = alloc_boolean_cond(&mut cs, "alloc conditional", condition); + + let result_var = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + left_op.conditionally_sub_noborrow(&mut multi_eq, &cond_var, &right_op).unwrap() + }; + test_uint_gadget_value(if cond_var.get_value().unwrap() { + diff + } else { + left + }, &result_var, "sub without underflow correctness"); + assert!(cs.is_satisfied()); + + // check that subtraction with underflows fails + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let op1: $native_type = rng.gen(); + let op2: $native_type = rng.gen(); + + let (left, right, var_type_left, var_type_right) = match op1.cmp(&op2) { + Ordering::Less => (op1, op2, &var_type_op1, &var_type_op2), + Ordering::Greater => (op2, op1, &var_type_op2, &var_type_op1), + Ordering::Equal => { + // make left < right by choosing left=op1 and right=op2+1 + let (right, overflow) = op2.overflowing_add(1); + if overflow { + // if op2+1 overflows, then it is zero, hence swap with op1 + (right, op1, &var_type_op2, &var_type_op1) + } else { + (op1, right, &var_type_op1, &var_type_op2) + } + }, + }; + + // compute subtraction and check that underflow occurs + let (diff, underflow) = left.overflowing_sub(right); + assert!(underflow); + + let left_op = alloc_fn(&mut cs, "alloc left op", var_type_left, left); + let right_op = alloc_fn(&mut cs, "alloc right op", var_type_right, right); + let result = { + // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped + let mut multi_eq = MultiEq::new(&mut cs); + left_op.conditionally_sub_noborrow(&mut multi_eq, &cond_var, &right_op) + }; + + // Need to distinguish between operands being both constant or not, + // as in the former case the operation should return an error + // rather than unsatisfied constraints + let result_var = match (var_type_op1, var_type_op2) { + (VariableType::Constant, VariableType::Constant) => { + match result.unwrap_err() { + SynthesisError::Unsatisfiable => (), + err => assert!(false, "invalid error returned by sub_noborrow: {}", err) + }; + return; + }, + (_, _) => result.unwrap(), + }; + + test_uint_gadget_value(if cond_var.get_value().unwrap() { + diff + } else { + left + }, &result_var, "sub with underflow correctness"); + assert!(!cs.is_satisfied()); + } + } + } + } + + #[test] + fn test_cmp_gadget() { + let rng = &mut thread_rng(); + const NUM_RUNS: usize = 10; + + // helper closure which is useful to deal with the error returned by enforce cmp + // function if both the operands are constant and the comparison is + // unsatisfiable on such constants + let handle_constant_operands = |cs: &ConstraintSystem::, must_be_satisfied: bool, cmp_result: Result<(), SynthesisError>, var_type_op1: &VariableType, var_type_op2: &VariableType, assertion_label| { + match (*var_type_op1, *var_type_op2) { + (VariableType::Constant, VariableType::Constant) => { + if must_be_satisfied { + cmp_result.unwrap() + } else { + match cmp_result.unwrap_err() { + SynthesisError::Unsatisfiable | SynthesisError::AssignmentMissing => assert!(true), + err => assert!(false, "wrong error returned with constant operands in {}: {}", assertion_label, err), + } + } + }, + _ => { + cmp_result.unwrap(); + assert!(!(cs.is_satisfied() ^ must_be_satisfied), "{} for {:?} {:?}", assertion_label, var_type_op1, var_type_op2); + } + } + }; + + for var_type_op1 in VARIABLE_TYPES.iter() { + for var_type_op2 in VARIABLE_TYPES.iter() { + for _ in 0..NUM_RUNS { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let a: $native_type = rng.gen(); + let b: $native_type = rng.gen(); + + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); + + let is_smaller_var = a_var.is_smaller_than(cs.ns(|| "a < b"), &b_var).unwrap(); + let is_smaller = match a.cmp(&b) { + Ordering::Less => { + assert!(is_smaller_var.get_value().unwrap()); + assert!(cs.is_satisfied(), "is smaller"); + true + } + Ordering::Greater | Ordering::Equal => { + assert!(!is_smaller_var.get_value().unwrap()); + assert!(cs.is_satisfied(), "is not smaller"); + false + } + }; + + // test when operands are equal + let is_smaller_var = a_var.is_smaller_than(cs.ns(|| "a < a"), &a_var).unwrap(); + assert!(!is_smaller_var.get_value().unwrap()); + assert!(cs.is_satisfied()); + + // test enforce_smaller_than + let enforce_ret = a_var.enforce_smaller_than(cs.ns(|| "enforce a < b"), &b_var); + handle_constant_operands(&cs, is_smaller, enforce_ret, var_type_op1, var_type_op2, "enforce_smaller_than test"); + + // test equality + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let enforce_ret = a_var.enforce_smaller_than(cs.ns(|| "enforce a < a"), &a_var); + handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, "enforce a < a test"); + + + // test all comparisons + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); + match a.cmp(&b) { + Ordering::Less => { + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce less test"); + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less equal"), &b_var, Ordering::Less, true); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce less equal test"); + } + Ordering::Greater => { + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater"), &b_var, Ordering::Greater, false); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce greater test"); + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater equal"), &b_var, Ordering::Greater, true); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce greater equal test"); + } + _ => {} + } + + + // negative test + match b.cmp(&a) { + Ordering::Less => { + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce less negative test"); + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less equal"),&b_var, Ordering::Less, true); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce less equal negative test"); + + } + Ordering::Greater => { + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater"),&b_var, Ordering::Greater, false); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce greater negative test"); + let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater equal"),&b_var, Ordering::Greater, true); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce greater equal negative test"); + } + _ => {} + } + + // test equality with enforce_cmp + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + + let enforce_ret = a_var.enforce_cmp(cs.ns(|| "enforce a <= a"), &a_var, Ordering::Less, true); + handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, "enforce less equal on same variable test"); + let enforce_ret = a_var.enforce_cmp(cs.ns(|| "enforce a < a"), &a_var, Ordering::Less, false); + handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, "enforce less on same variable test"); + } + } + } } } + + } } } - -pub mod test_mod {} \ No newline at end of file diff --git a/r1cs/gadgets/std/src/bits/mod.rs b/r1cs/gadgets/std/src/bits/mod.rs index b78464b40..2654afe52 100644 --- a/r1cs/gadgets/std/src/bits/mod.rs +++ b/r1cs/gadgets/std/src/bits/mod.rs @@ -191,6 +191,48 @@ Sized Self::conditionally_select(cs.ns(|| "conditionally select values"), cond, &sum, self) } + /// Perform modular subtraction of `other` from `self` + fn sub(&self, cs: M, other: &Self) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract>; + + /// Perform modular subtraction of `other` from `self` if `cond` is True, otherwise do nothing + fn conditionally_sub( + &self, + mut cs: M, + cond: &Boolean, + other: &Self + ) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + let diff = self.sub(cs.ns(|| "sub"), other)?; + Self::conditionally_select(cs.ns(|| "conditionally select result"), cond, &diff, self) + } + + /// Subtract `other` from `self`, checking that no borrows occur (i.e., that self - other >= 0) + fn sub_noborrow(&self, cs: M, other: &Self) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract>; + + /// Subtract `other` from `self` if `cond` is True, checking that no borrows occur, otherwise do nothing + fn conditionally_sub_noborrow( + &self, + mut cs: M, + cond: &Boolean, + other: &Self + ) -> Result + where + CS: ConstraintSystemAbstract, + M: ConstraintSystemAbstract> + { + let diff = self.sub_noborrow(cs.ns(|| "sub"), other)?; + Self::conditionally_select(cs.ns(|| "conditionally select result"), cond, &diff, self) + } + /// Perform modular multiplication of several `Self` objects. fn mulmany(cs: CS, operands: &[Self]) -> Result where From 856f5bed74966f032b8ad91733df1c3232870ccd Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Tue, 18 Jan 2022 15:20:39 +0100 Subject: [PATCH 05/18] Fix security issue on is_smaller_than + add comments --- r1cs/gadgets/std/src/bits/macros.rs | 180 ++++++++++++++++++++++------ 1 file changed, 142 insertions(+), 38 deletions(-) diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index 1401c2cb7..21e649f13 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -271,7 +271,6 @@ macro_rules! impl_uint_gadget { T: Borrow<$native_type> { let mut value = None; - //ToDo: verify if ConstraintF must be a PrimeField let field_element = FpGadget::::alloc_input(cs.ns(|| "alloc_input as field element"), || { let val = value_gen().map(|val| *val.borrow())?; value = Some(val); @@ -575,7 +574,7 @@ macro_rules! impl_uint_gadget { M: ConstraintSystemAbstract> { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + let field_bits = (ConstraintF::Params::CAPACITY) as usize; // in this case it is not possible to enforce the correctness of the addition // of at least 2 elements for the field ConstraintF assert!(field_bits > $bit_size); @@ -595,6 +594,19 @@ macro_rules! impl_uint_gadget { handle_numoperands_opmany!(addmany, cs, operands, max_num_operands); } + /* + Result is computed as follows: + Without loss of generality, consider 2 operands a, b, with their little-endian + bit representations, and n = $bit_size. + The addition is computed as ADD(a,b)=2^0a_0 + 2^1a_1 + ... + 2^(n-1)a_{n-1} + + 2^0b_0 + 2^1b_1 + ... + 2^(n-1)b_{n-1} in the ConstraintF field. + Then, m = $bit_size + overflow_bits Booleans res_0,...,res_{m-1} are allocated as + witness, and it is enforced that ADD(a,b) == 2^0res_0 + 2^1res_1 + ... + 2^(m-1)res_{m-1}. + Then, the Booleans res_0,...,res_{n-1} represent the result modulo 2^n (assuming + that no field overflow occurs in computing ADD(a,b), + which is checked with the initial assertions), which is returned. + */ + // result_value is the sum of all operands in the ConstraintF field, // which is employed in the constraint @@ -675,7 +687,7 @@ macro_rules! impl_uint_gadget { CS: ConstraintSystemAbstract, M: ConstraintSystemAbstract> { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + let field_bits = (ConstraintF::Params::CAPACITY) as usize; // in this case it is not possible to enforce the correctness of the addition // of at least 2 elements for the field ConstraintF assert!(field_bits > $bit_size); @@ -695,6 +707,19 @@ macro_rules! impl_uint_gadget { handle_numoperands_opmany!(addmany_nocarry, cs, operands, max_num_operands); } + /* + Result is computed as follows. + Without loss of generality, consider 2 operands a, b, with their little-endian + bit representations, and n = $bit_size. + The addition is computed as ADD(a,b)=2^0a_0 + 2^1a_1 + ... + 2^(n-1)a_{n-1} + + 2^0b_0 + 2^1b_1 + ... + 2^(n-1)b_{n-1} in the ConstraintF field. + Then, n Booleans res_0,...,res_{n-1} are allocated as witness, and it is + enforced that ADD(a,b) == 2^0res_0 + 2^1res_1 + ... + 2^(n-1)res_{n-1}. + Such constraint is verified iff ADD(a,b) can be represented with at most n bits, + that is iff the addition does not overflow (assuming that ADD(a,b) does not + overflow in the field, which is checked with the initial assertions) + */ + let mut result_value: Option<$native_type> = Some(0); // this flag allows to verify if the addition of operands overflows, which allows // to return an error in case a set of constants whose sum is overflowing is provided @@ -752,7 +777,7 @@ macro_rules! impl_uint_gadget { fn mulmany(mut cs: CS, operands: &[Self]) -> Result where CS: ConstraintSystemAbstract { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + let field_bits = (ConstraintF::Params::CAPACITY) as usize; assert!(num_operands >= 2); assert!(field_bits >= 2*$bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field @@ -761,6 +786,22 @@ macro_rules! impl_uint_gadget { handle_numoperands_opmany!(mulmany, cs, operands, max_num_operands); } + + /* + Result is computed as follows. + Without loss of generality, consider 3 operands a, b, c and n = $bit_size. + The operands are converted to field gadgets fa, fb, fc employing their big-endian + bit representations. + Then, the product of all this elements over the field is computed with + num_operands-1 (i.e., 2 in this case) constraints: + - a*b=tmp + - tmp*c=res + Field gadget res is then converted to its big-endian bit representation, and only + the n least significant bits are returned as the result, which thus corresponds + to the a*b*c mod 2^n (assuming that no field overflow occurs in computing a*b*c, + which is checked with the initial assertions) + */ + // corner case: check if all operands are constants before allocating any variable let mut all_constants = true; let mut result_value: Option<$native_type> = Some(1); @@ -804,7 +845,7 @@ macro_rules! impl_uint_gadget { fn mulmany_nocarry(mut cs: CS, operands: &[Self]) -> Result where CS: ConstraintSystemAbstract { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::MODULUS_BITS - 1) as usize; + let field_bits = (ConstraintF::Params::CAPACITY) as usize; assert!(num_operands >= 2); assert!(field_bits >= 2*$bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field @@ -813,6 +854,22 @@ macro_rules! impl_uint_gadget { handle_numoperands_opmany!(mulmany_nocarry, cs, operands, max_num_operands); } + /* + Result is computed as follows. + Without loss of generality, consider 3 operands a, b, c and n = $bit_size. + The operands are converted to field gadgets fa, fb, fc employing their big-endian + bit representations. + Then, the product of all this elements over the field is computed with + num_operands-1 (i.e., 2 in this case) constraints: + - a*b=tmp + - tmp*c=res + Field gadget res is then converted to a big-endian bit representation employing + only n bits. If this conversion succeeds, then it means that a*b*c does not + overflow 2^n (assuming that no field overflow occurs in computing a*b*c, which + is checked by the initial assertions), and such n bits represent the final + product + */ + // corner case: check if all operands are constants before allocating any variable let mut all_constants = true; let mut result_value: Option<$native_type> = Some(1); @@ -864,18 +921,25 @@ macro_rules! impl_uint_gadget { { // this assertion checks that the field is big enough: that is, the field must // be able to represent integers up to 2^($bit_size+1) - assert!(ConstraintF::Params::MODULUS_BITS - 1 > $bit_size); - - // Overall idea: allocate $bit_size+1 bits representing a field element diff - // and enforce that diff == self - other + 2^$bit_size. - // The addition of 2^$bit_size is useful in case other >= self to avoid - // field underflows, which would require to allocate as many bits as the field - // modulus. Only the first $bit_size bits are returned as the result of the - // subtraction, hence the addition of 2^$bit_size has no impact on the final - // result + assert!(ConstraintF::Params::CAPACITY > $bit_size); + + /* Result is computed as follows. + Consider 2 operands a,b with their little-endian bit representations, + and n=$bit_size. + The subtraction is computed as SUB(a,b)=2^0a_0 + 2^1a_1 + ... + 2^(n-1)a_{n-1} - + 2^0b_0 - 2^1b_1 - ... - 2^(n-1)b_{n-1} in the ConstraintF field. + Then, allocate n+1 bits res_0,...,res_{n} and enforce that + SUB(a,b) + 2^n == 2^0res_0 + 2^1res_1 + ... + 2^nres_n. + The addition of 2^n is useful in case other >= self to avoid + field underflows, which would require to allocate as many bits as the field + modulus. Only the first n bits are returned as the result of the + subtraction (since the result must be computed modulo 2^n), hence the addition + of 2^n has no impact on the final result + */ // max_value is a field element equal to 2^$bit_size let max_value = ConstraintF::from($native_type::MAX) + ConstraintF::one(); + // lc will be constructed as SUB(self,other)+2^$bit_size let mut lc = (max_value, CS::one()).into(); let mut coeff = ConstraintF::one(); let mut all_constants = true; @@ -888,6 +952,8 @@ macro_rules! impl_uint_gadget { coeff.double_in_place(); } + // diff = self - other mod 2^$bit_size, + // while diff_in_field = self - other + 2^$bit_size over the ConstraintF field let (diff, diff_in_field) = match (self.value, other.value) { (Some(val1), Some(val2)) => { let (diff, _) = val1.overflowing_sub(val2); // don't care if there is an underflow @@ -902,12 +968,14 @@ macro_rules! impl_uint_gadget { return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &diff.unwrap())); } + // convert diff_in_field to little-endian bit representation let diff_bits = match diff_in_field { Some(diff) => diff.write_bits().iter().rev().map(|b| Some(*b)).collect::>(), None => vec![None; $bit_size+1], }; let mut result_bits = Vec::with_capacity($bit_size); + // result_lc is constructed as 2^0diff_bits[0]+2^1diff_bits[1]+...+2^($bit_size)diff_bits[$bit_size] let mut result_lc = LinearCombination::zero(); let mut coeff = ConstraintF::one(); for i in 0..$bit_size+1 { @@ -918,6 +986,8 @@ macro_rules! impl_uint_gadget { coeff.double_in_place(); if i < $bit_size { + // only $bit_size bit are useful for the result, + // as the result must be modulo 2^$bit_size result_bits.push(diff_bit); } } @@ -939,8 +1009,21 @@ macro_rules! impl_uint_gadget { // this assertion checks that the field is big enough: subtraction of any 2 // values in $native_type must be a field element that cannot be represented as // $native_type - assert!(ConstraintF::Params::MODULUS_BITS - 1 > $bit_size); - + assert!(ConstraintF::Params::CAPACITY > $bit_size); + + /* Result is computed as follows. + Consider 2 operands a,b with their little-endian bit representations, + and n=$bit_size. + The subtraction is computed as SUB(a,b)=2^0a_0 + 2^1a_1 + ... + 2^(n-1)a_{n-1} - + 2^0b_0 - 2^1b_1 - ... - 2^(n-1)b_{n-1} in the ConstraintF field. + Then, allocate n bits res_0,...,res_{n-1} and enforce that + SUB(a,b) == 2^0res_0 + 2^1res_1 + ... + 2^(n-1)res_{n-1}. + Such constraint is satisfied iff SUB(a,b) can be represented with n bits, + that is iff no field underflow occurs in the computation of SUB(a,b), which holds + iff a - b does not underflow + */ + + // lc is constructed as SUB(self, other) let mut lc = LinearCombination::zero(); let mut coeff = ConstraintF::one(); let mut all_constants = true; @@ -964,7 +1047,6 @@ macro_rules! impl_uint_gadget { if all_constants && diff.is_some() { if is_underflowing { - // in this case self < other return Err(SynthesisError::Unsatisfiable) } else { return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &diff.unwrap())) @@ -994,48 +1076,70 @@ macro_rules! impl_uint_gadget { // this assertion checks that the field is big enough: subtraction of any 2 // values in $native_type must be a field element that cannot be represented as // $native_type - assert!(ConstraintF::Params::MODULUS_BITS - 1 > $bit_size); + assert!(ConstraintF::Params::CAPACITY > $bit_size); + + + /* Result is computed as follows. + Consider 2 operands a,b with their little-endian bit representations, + and n=$bit_size. + Compute n bits diff_0,...,diff_{n-1} which represents a - b mod 2^n employing + sub function, and then compute the term + DELTA(a,b,diff) = 2^0a_0 + 2^1a_1 + ... + 2^(n-1)a_{n-1} - + 2^0b_0 - 2^1b_1 - ... - 2^(n-1)b_{n-1} - 2^0diff_0 - 2^1diff_1 - ... + - 2^(n-1)diff_{n-1} in the ConstraintF field. + Since diff is a field element equal to a-b mod 2^n, then it holds that + DELTA(a,b,diff) == 0 iff a >= b. + To compute the result, allocate a Boolean res and a field element v, + and enforces these 2 constraints: + - (1-res)*DELTA(a,b,diff) = 0 + - v*DELTA(a,b,diff) == res + The first constraint ensures that res=1 when DELTA(a,b,diff)!=0, which holds iff + a < b; the second constraint ensures that res=0 when DELTA(a,b,diff)=0, which holds + iff a >= b. Note that to satisfy the second constraint when res=1, the prover + must set v to the multiplicative inverse of DELTA(a,b,diff), which necessarily + exists when res=1 as DELTA(a,b,diff) != 0 + */ + + let diff_var = { + // add a scope for multi_eq as constraints are enforced when variable is + // dropped + let mut multi_eq = MultiEq::new(&mut cs); + self.sub(multi_eq.ns(|| "a - b mod 2^n"), other)? + }; let mut delta_lc = LinearCombination::zero(); let mut coeff = ConstraintF::one(); let mut all_constants = true; - for (self_bit, other_bit) in self.bits.iter().zip(other.bits.iter()) { + for ((self_bit, other_bit), diff_bit) in self.bits.iter().zip(other.bits.iter()).zip(diff_var.bits.iter()) { delta_lc = delta_lc + &self_bit.lc(CS::one(), coeff); delta_lc = delta_lc - &other_bit.lc(CS::one(), coeff); + delta_lc = delta_lc - &diff_bit.lc(CS::one(), coeff); all_constants &= self_bit.is_constant() && other_bit.is_constant(); coeff.double_in_place(); } - let mut is_underflowing = None; - // delta = self - other - diff in the field, where diff = self - other over the uint type - let mut delta = None; - let diff = match (self.get_value(), other.get_value()) { + let (diff_val, is_underflowing, delta) = match (self.get_value(), other.get_value()) { (Some(value1), Some(value2)) => { let (diff, underflow) = value1.overflowing_sub(value2); - is_underflowing = Some(underflow); - // compute self - other - diff over the field + // compute delta = self - other - diff in the field, + // where diff = self - other mod 2^$bit_size let self_in_field = ConstraintF::from(value1); let other_in_field = ConstraintF::from(value2); let diff_in_field = ConstraintF::from(diff); - delta = Some(self_in_field - other_in_field - diff_in_field); - Some(diff) + let delta = self_in_field - other_in_field - diff_in_field; + (Some(diff), Some(underflow), Some(delta)) }, - _ => None, + _ => (None, None, None), }; - if all_constants && diff.is_some() { + if all_constants && diff_val.is_some() { return Ok(Boolean::constant(is_underflowing.unwrap())) } - let diff_var = Self::alloc(cs.ns(|| "alloc diff"), || diff.ok_or(SynthesisError::AssignmentMissing))?; - let mut coeff = ConstraintF::one(); - for diff_bit in diff_var.bits.iter() { - delta_lc = delta_lc - &diff_bit.lc(CS::one(), coeff); - coeff.double_in_place(); - } - // ToDo: It should not be necessary to allocate it as a Boolean gadget + // ToDo: It should not be necessary to allocate it as a Boolean gadget, + // can be done when a Boolean::from(FieldGadget) will be implemented let is_smaller = Boolean::alloc(cs.ns(|| "alloc result"), || is_underflowing.ok_or(SynthesisError::AssignmentMissing))?; let inv = delta.map(|delta| { @@ -1640,7 +1744,7 @@ macro_rules! impl_uint_gadget { #[test] fn test_mulmany() { - const MAX_NUM_OPERANDS: usize = (::Params::MODULUS_BITS-1) as usize/$bit_size ; + const MAX_NUM_OPERANDS: usize = (::Params::CAPACITY) as usize/$bit_size ; const NUM_OPERANDS: usize = MAX_NUM_OPERANDS*2+5; // we want to test a case when the operands must be split in multiple chunks assert!(NUM_OPERANDS > MAX_NUM_OPERANDS); @@ -1853,7 +1957,7 @@ macro_rules! impl_uint_gadget { #[test] fn test_mulmany_nocarry() { - const MAX_NUM_OPERANDS: usize = (::Params::MODULUS_BITS-1) as usize/$bit_size ; + const MAX_NUM_OPERANDS: usize = (::Params::CAPACITY) as usize/$bit_size ; const NUM_OPERANDS: usize = MAX_NUM_OPERANDS*2+5; // we want to test a case when the operands must be split in multiple chunks assert!(NUM_OPERANDS > MAX_NUM_OPERANDS); From 2e126d5b0b594bd1962f91fb9bfec06f3a38b8d7 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Wed, 19 Jan 2022 10:45:47 +0100 Subject: [PATCH 06/18] Implement uint16/128 + deal with field too small for mul --- r1cs/gadgets/std/src/bits/macros.rs | 320 ++++++++++++++++++---------- r1cs/gadgets/std/src/bits/mod.rs | 3 + 2 files changed, 212 insertions(+), 111 deletions(-) diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index 21e649f13..d0c3016a5 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -133,7 +133,7 @@ macro_rules! impl_uint_gadget { result } - /// enfroces that self >= other. This function is provided as it is much efficient + /// enforces that self >= other. This function is provided as it is much efficient /// in terms of constraints with respect to the default implementation of /// ComparisonGadget, which relies on the smaller_than functions pub fn enforce_greater_or_equal_than(&self, mut cs: CS, other: &Self) @@ -193,6 +193,86 @@ macro_rules! impl_uint_gadget { value, }) } + + /// This function allows to multiply `self` and `other` with a variant of + /// double & add algorithm. + /// It is useful when the field ConstraintF is too small to employ the much more + /// efficient algorithm employed in multiplication functions of the UIntGadget. + /// If `no_carry` is true, then the function checks that the multiplication + /// does not overflow, otherwise modular multiplication with no overflow checking + /// is performed + fn mul_with_double_and_add(&self, mut cs: CS, other: &Self, + no_carry: bool) -> Result + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + + { + let field_bits = ConstraintF::Params::CAPACITY as usize; + // max_overflow_bits are the maximum number of non-zero bits a `Self` element + // can have to be multiplied to another `Self` without overflowing the field + let max_overflow_bits = field_bits - $bit_size; + // given a base b = 2^m, where m=2^max_overflow_bits, the `other` operand is + // represented in base b with digits of m bits. Then, the product self*other + // is computed by the following summation: + // sum_{i from 0 to h-1} ((self*b^i) % 2^$bit_size * digit_i), where h is the + // number of digits employed to represent other in base b + let mut coeff = self.clone(); // coeff will hold self*b^i mod 2^$bit_size for digit i + let mut operands = Vec::new(); // operands will accumulate all the operands of the summation + for (i, digit) in other.bits.chunks(max_overflow_bits).enumerate() { + // multiply digit to coeff over the field, since digit < b, + // then we are sure no field overflow will occur + let be_bits = digit.iter().rev().map(|bit| *bit).collect::>(); + let digit_in_field = FpGadget::::from_bits(cs.ns(|| format!("digit {} to field", i)), &be_bits[..])?; + let coeff_bits = coeff.to_bits(cs.ns(|| format!("unpack coeff for digit {}", i)))?; + let coeff_in_field = FpGadget::::from_bits(cs.ns(|| format!("coeff for digit {} to field", i)), &coeff_bits[..])?; + let tmp_result = coeff_in_field.mul(cs.ns(|| format!("tmp result for digit {}", i)), &digit_in_field)?; + let result_bits = if no_carry { + // ensure that tmp_result can be represented with $bit_size bits to + // ensure that no native type overflow has happened in the multiplication + tmp_result.to_bits_with_length_restriction(cs.ns(|| format!("to bits for digit {}", i)), field_bits + 1 - $bit_size)? + } else { + let result_bits = tmp_result.to_bits_with_length_restriction(cs.ns(|| format!("to bits for digit {}", i)), 1)?; + result_bits + .iter() + .skip(max_overflow_bits) + .map(|el| *el) + .collect::>() + }; + // addend is equal to coeff*digit mod 2^$bit_size + let addend = $type_name::from_bits(cs.ns(|| format!("packing addend for digit {}", i)), &result_bits[..])?; + operands.push(addend); + // move coeff from self*b^i mod 2^$bit_size to self*b^(i+1) mod 2^$bit_size + coeff = coeff.shl(max_overflow_bits); + } + let mut multi_eq = MultiEq::new(&mut cs); + if no_carry { + return $type_name::addmany_nocarry(multi_eq.ns(|| "add operands"), &operands) + } else { + return $type_name::addmany(multi_eq.ns(|| "add operands"), &operands) + } + } + + /// This function allows to multiply a set of operands with a variant of + /// double & add algorithm. + /// It is useful when the field ConstraintF is too small to employ the much more + /// efficient algorithm employed in multiplication functions of the UIntGadget. + /// If `no_carry` is true, then the function checks that the multiplication + /// does not overflow, otherwise modular multiplication with no overflow checking + /// is performed + fn mulmany_with_double_and_add(mut cs: CS, operands: &[Self], + no_carry: bool) -> Result + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + { + let mut result = operands[0].mul_with_double_and_add(cs.ns(|| "double and add first operands"), &operands[1], no_carry)?; + for (i, op) in operands.iter().skip(2).enumerate() { + result = result.mul_with_double_and_add(cs.ns(|| format!("double and add operand {}", i)),op, no_carry)?; + } + Ok(result) + } + } impl PartialEq for $type_name { @@ -779,7 +859,30 @@ macro_rules! impl_uint_gadget { let num_operands = operands.len(); let field_bits = (ConstraintF::Params::CAPACITY) as usize; assert!(num_operands >= 2); - assert!(field_bits >= 2*$bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field + + // corner case: check if all operands are constants before allocating any variable + let mut all_constants = true; + let mut result_value: Option<$native_type> = Some(1); + for op in operands { + for bit in &op.bits { + all_constants &= bit.is_constant(); + } + + result_value = match op.value { + Some(val) => result_value.as_mut().map(|v| v.overflowing_mul(val).0), + None => None, + } + } + + if all_constants && result_value.is_some() { + return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &result_value.unwrap())); + } + + assert!(field_bits > $bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field + + if field_bits < 2*$bit_size { + return $type_name::mulmany_with_double_and_add(cs.ns(|| "double and add"), operands, false); + } if field_bits < num_operands*$bit_size { let max_num_operands = field_bits/$bit_size; @@ -802,24 +905,6 @@ macro_rules! impl_uint_gadget { which is checked with the initial assertions) */ - // corner case: check if all operands are constants before allocating any variable - let mut all_constants = true; - let mut result_value: Option<$native_type> = Some(1); - for op in operands { - for bit in &op.bits { - all_constants &= bit.is_constant(); - } - - result_value = match op.value { - Some(val) => result_value.as_mut().map(|v| v.overflowing_mul(val).0), - None => None, - } - } - - if all_constants && result_value.is_some() { - return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &result_value.unwrap())); - } - let op0_bits = operands[0].to_bits(cs.ns(|| "unpack first operand"))?; let op1_bits = operands[1].to_bits(cs.ns(|| "unpack second operand"))?; let field_op0 = FpGadget::::from_bits(cs.ns(|| "alloc operand 0 in field"), &op0_bits[..])?; @@ -847,28 +932,6 @@ macro_rules! impl_uint_gadget { let num_operands = operands.len(); let field_bits = (ConstraintF::Params::CAPACITY) as usize; assert!(num_operands >= 2); - assert!(field_bits >= 2*$bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field - - if field_bits < num_operands*$bit_size { - let max_num_operands = field_bits/$bit_size; - handle_numoperands_opmany!(mulmany_nocarry, cs, operands, max_num_operands); - } - - /* - Result is computed as follows. - Without loss of generality, consider 3 operands a, b, c and n = $bit_size. - The operands are converted to field gadgets fa, fb, fc employing their big-endian - bit representations. - Then, the product of all this elements over the field is computed with - num_operands-1 (i.e., 2 in this case) constraints: - - a*b=tmp - - tmp*c=res - Field gadget res is then converted to a big-endian bit representation employing - only n bits. If this conversion succeeds, then it means that a*b*c does not - overflow 2^n (assuming that no field overflow occurs in computing a*b*c, which - is checked by the initial assertions), and such n bits represent the final - product - */ // corner case: check if all operands are constants before allocating any variable let mut all_constants = true; @@ -897,6 +960,33 @@ macro_rules! impl_uint_gadget { } } + assert!(field_bits > $bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field + + if field_bits < 2*$bit_size { + return $type_name::mulmany_with_double_and_add(cs.ns(|| "double and add"), operands, true); + } + + if field_bits < num_operands*$bit_size { + let max_num_operands = field_bits/$bit_size; + handle_numoperands_opmany!(mulmany_nocarry, cs, operands, max_num_operands); + } + + /* + Result is computed as follows. + Without loss of generality, consider 3 operands a, b, c and n = $bit_size. + The operands are converted to field gadgets fa, fb, fc employing their big-endian + bit representations. + Then, the product of all this elements over the field is computed with + num_operands-1 (i.e., 2 in this case) constraints: + - a*b=tmp + - tmp*c=res + Field gadget res is then converted to a big-endian bit representation employing + only n bits. If this conversion succeeds, then it means that a*b*c does not + overflow 2^n (assuming that no field overflow occurs in computing a*b*c, which + is checked by the initial assertions), and such n bits represent the final + product + */ + let op0_bits = operands[0].to_bits(cs.ns(|| "unpack first operand"))?; let op1_bits = operands[1].to_bits(cs.ns(|| "unpack second operand"))?; let field_op0 = FpGadget::::from_bits(cs.ns(|| "alloc operand 0 in field"), &op0_bits[..])?; @@ -1743,6 +1833,7 @@ macro_rules! impl_uint_gadget { } #[test] + #[allow(unconditional_panic)] // otherwise test will not compile for uint128, as field is too small fn test_mulmany() { const MAX_NUM_OPERANDS: usize = (::Params::CAPACITY) as usize/$bit_size ; const NUM_OPERANDS: usize = MAX_NUM_OPERANDS*2+5; @@ -1767,44 +1858,45 @@ macro_rules! impl_uint_gadget { assert!(cs.is_satisfied()); + if MAX_NUM_OPERANDS >= 2 { // negative tests are skipped if if double and add must be used because the field is too small + // negative test on first batch + let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); - // negative test on first batch - let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; - if cs.get(bit_gadget_path).is_zero() { - cs.set(bit_gadget_path, Fr::one()); - } else { - cs.set(bit_gadget_path, Fr::zero()); - } - assert!(!cs.is_satisfied()); - - // set bit value back - if cs.get(bit_gadget_path).is_zero() { - cs.set(bit_gadget_path, Fr::one()); - } else { - cs.set(bit_gadget_path, Fr::zero()); - } - assert!(cs.is_satisfied()); + // set bit value back + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); - // negative test on allocated field element - let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); - if last_batch_start_operand == NUM_OPERANDS { - last_batch_start_operand -= MAX_NUM_OPERANDS-1; - } - let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); - if cs.get(&bit_gadget_path).is_zero() { - cs.set(&bit_gadget_path, Fr::one()); - } else { - cs.set(&bit_gadget_path, Fr::zero()); - } - assert!(!cs.is_satisfied()); + // negative test on allocated field element: skip if double and add must be used because the field is too small + let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); + if last_batch_start_operand == NUM_OPERANDS { + last_batch_start_operand -= MAX_NUM_OPERANDS-1; + } + let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); - // set bit value back - if cs.get(&bit_gadget_path).is_zero() { - cs.set(&bit_gadget_path, Fr::one()); - } else { - cs.set(&bit_gadget_path, Fr::zero()); + // set bit value back + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); } - assert!(cs.is_satisfied()); let operands = operand_values.iter().enumerate().map(|(i, val)| { alloc_fn(&mut cs, format!("alloc constant operand {}", i).as_str(), &VARIABLE_TYPES[0], *val) @@ -1956,9 +2048,11 @@ macro_rules! impl_uint_gadget { } #[test] + #[allow(unconditional_panic)] // otherwise test will not compile for uint128, as field is too small fn test_mulmany_nocarry() { const MAX_NUM_OPERANDS: usize = (::Params::CAPACITY) as usize/$bit_size ; const NUM_OPERANDS: usize = MAX_NUM_OPERANDS*2+5; + // we want to test a case when the operands must be split in multiple chunks assert!(NUM_OPERANDS > MAX_NUM_OPERANDS); @@ -1981,43 +2075,47 @@ macro_rules! impl_uint_gadget { test_uint_gadget_value(result_value, &result_var, "result correctness"); assert!(cs.is_satisfied()); - // negative test on first batch - let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; - if cs.get(bit_gadget_path).is_zero() { - cs.set(bit_gadget_path, Fr::one()); - } else { - cs.set(bit_gadget_path, Fr::zero()); - } - assert!(!cs.is_satisfied()); + if MAX_NUM_OPERANDS >= 2 { // negative tests are skipped if if double and add must be used because the field is too small + // negative test on first batch + let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); - // set bit value back - if cs.get(bit_gadget_path).is_zero() { - cs.set(bit_gadget_path, Fr::one()); - } else { - cs.set(bit_gadget_path, Fr::zero()); - } - assert!(cs.is_satisfied()); + // set bit value back + if cs.get(bit_gadget_path).is_zero() { + cs.set(bit_gadget_path, Fr::one()); + } else { + cs.set(bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); - // negative test on allocated field element - let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); - if last_batch_start_operand == NUM_OPERANDS { - last_batch_start_operand -= MAX_NUM_OPERANDS-1; - } - let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); - if cs.get(&bit_gadget_path).is_zero() { - cs.set(&bit_gadget_path, Fr::one()); - } else { - cs.set(&bit_gadget_path, Fr::zero()); - } - assert!(!cs.is_satisfied()); + // negative test on allocated field element - // set bit value back - if cs.get(&bit_gadget_path).is_zero() { - cs.set(&bit_gadget_path, Fr::one()); - } else { - cs.set(&bit_gadget_path, Fr::zero()); + let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); + if last_batch_start_operand == NUM_OPERANDS { + last_batch_start_operand -= MAX_NUM_OPERANDS-1; + } + let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(!cs.is_satisfied()); + + + // set bit value back + if cs.get(&bit_gadget_path).is_zero() { + cs.set(&bit_gadget_path, Fr::one()); + } else { + cs.set(&bit_gadget_path, Fr::zero()); + } + assert!(cs.is_satisfied()); } - assert!(cs.is_satisfied()); // test with all constants let num_constraints = cs.num_constraints(); @@ -2136,7 +2234,7 @@ macro_rules! impl_uint_gadget { SynthesisError::Unsatisfiable => (), _ => assert!(false, "invalid error returned by {}", if is_add {"conditionally_add_nocarry"} else {"conditionally_mul_nocarry"}) }; - return; + continue; }, (_, _) => (), }; @@ -2148,7 +2246,7 @@ macro_rules! impl_uint_gadget { } else { op1 }, &result_var, format!("{} correctness", op).as_str()); - assert!(!cs.is_satisfied(), "checking overflow constraint for {:?} {:?}", var_type_op1, var_type_op2); + assert!(!cs.is_satisfied(), "checking overflow constraint for {:?} {:?} {}", var_type_op1, var_type_op2, is_add); } } diff --git a/r1cs/gadgets/std/src/bits/mod.rs b/r1cs/gadgets/std/src/bits/mod.rs index 2654afe52..1343e1f07 100644 --- a/r1cs/gadgets/std/src/bits/mod.rs +++ b/r1cs/gadgets/std/src/bits/mod.rs @@ -16,6 +16,9 @@ pub mod macros; impl_uint_gadget!(U8, 8, u8, uint8); impl_uint_gadget!(UInt64, 64, u64, uint64); impl_uint_gadget!(UInt32, 32, u32, uint32); +impl_uint_gadget!(UInt16, 16, u16, uint16); +impl_uint_gadget!(UInt128, 128, u128, uint128); + pub type UInt8 = uint8::U8; From 4fc7f89d221eae9ae2de0a25387a5ad1c29ca5b3 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Wed, 19 Jan 2022 15:57:01 +0100 Subject: [PATCH 07/18] Optimize EqGadget for vector of bits --- .../crypto/src/crh/bowe_hopwood/mod.rs | 2 +- r1cs/gadgets/std/src/eq.rs | 316 ++++++++++++++++++ 2 files changed, 317 insertions(+), 1 deletion(-) diff --git a/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs b/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs index 183bfb1ac..6e98029b0 100644 --- a/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs +++ b/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs @@ -8,7 +8,7 @@ use primitives::{ crh::pedersen::PedersenWindow, }; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; -use r1cs_std::{alloc::AllocGadget, groups::GroupGadget, uint8::UInt8}; +use r1cs_std::{alloc::AllocGadget, groups::GroupGadget, UInt8}; use r1cs_std::bits::boolean::Boolean; use std::{borrow::Borrow, marker::PhantomData}; diff --git a/r1cs/gadgets/std/src/eq.rs b/r1cs/gadgets/std/src/eq.rs index 0e6386d27..e0a0704c4 100644 --- a/r1cs/gadgets/std/src/eq.rs +++ b/r1cs/gadgets/std/src/eq.rs @@ -1,6 +1,7 @@ use crate::prelude::*; use algebra::{Field, FpParameters, PrimeField}; use r1cs_core::{ConstraintSystemAbstract, LinearCombination, SynthesisError, Variable}; +use crate::fields::fp::FpGadget; /// Specifies how to generate constraints that check for equality for two variables of type `Self`. pub trait EqGadget: Eq { @@ -130,6 +131,222 @@ impl, ConstraintF: Field> EqGadget for [T] } } +// wrapper type employed to implement helper functions for the implementation of EqGadget for +// Vec +struct BooleanVec<'a>(&'a [Boolean]); +impl BooleanVec<'_> { + #[inline] + // helper function that computes a linear combination of the bits of `self` and `other` which + // corresponds to the difference between two field elements a,b, where a (resp. b) is the field + // element whose little-endian bit representation is `self` (resp. other). + // The function returns also a-b over the field (wrapped in an Option) and a flag + // that specifies if all the bits in both `self` and `other` are constants. + fn compute_diff(&self, _cs: CS, other: &Self) -> (LinearCombination, Option, bool) + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + { + let self_bits = self.0; + let other_bits = other.0; + let field_bits = ConstraintF::Params::CAPACITY as usize; + assert!(self_bits.len() <= field_bits); + assert!(other_bits.len() <= field_bits); + + let mut self_lc = LinearCombination::zero(); + let mut other_lc = LinearCombination::zero(); + let mut coeff = ConstraintF::one(); + let mut diff_in_field = Some(ConstraintF::zero()); + let mut all_constants = true; + for (self_bit, other_bit) in self_bits.iter().zip(other_bits.iter()) { + self_lc = self_lc + &self_bit.lc(CS::one(), coeff); + other_lc = other_lc + &other_bit.lc(CS::one(), coeff); + + all_constants &= self_bit.is_constant() && other_bit.is_constant(); + + diff_in_field = match (self_bit.get_value(), other_bit.get_value()) { + (Some(bit1), Some(bit2)) => diff_in_field.as_mut().map(|diff| { + let self_term = if bit1 { + coeff + } else { + ConstraintF::zero() + }; + let other_term = if bit2 { + coeff + } else { + ConstraintF::zero() + }; + *diff += self_term - other_term; + *diff + }), + _ => None, + }; + + coeff.double_in_place(); + } + + (self_lc - other_lc, diff_in_field, all_constants) + } + + // is_eq computes a Boolean which is true iff `self` == `other`. This function requires that + // `self` and `other` are bit sequences with length at most the capacity of the field ConstraintF. + fn is_eq(&self, mut cs: CS, other: &Self) -> Result + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + { + + let (diff_lc, diff_in_field, all_constants) = self.compute_diff(&mut cs, other); + + if all_constants && diff_in_field.is_some() { + return Ok(Boolean::constant(diff_in_field.unwrap().is_zero())); + } + + let is_eq = Boolean::alloc(cs.ns(|| "alloc result"), || { + let diff = diff_in_field.ok_or(SynthesisError::AssignmentMissing)?; + Ok(diff.is_zero()) + })?; + + let inv = diff_in_field.map(|diff| { + match diff.inverse() { + Some(inv) => inv, + None => ConstraintF::one(), // in this case the value of inv does not matter for the constraint + } + }); + + let inv_var = FpGadget::::alloc(cs.ns(|| "alloc inv"), || {inv.ok_or(SynthesisError::AssignmentMissing)})?; + + // enforce constraints: + // is_eq * diff_lc = 0 enforces that is_eq == 0 when diff_lc != 0, i.e., when self != other + // inv * diff_lc = 1 - is_eq enforces that is_eq == 1 when diff_lc == 0, i.e., when self == other + cs.enforce(|| "enforce is not eq", |_| is_eq.lc(CS::one(), ConstraintF::one()), |lc| lc + &diff_lc, |lc| lc); + cs.enforce(|| "enforce is eq", |lc| &inv_var.get_variable() + lc, |lc| lc + &diff_lc, |_| is_eq.not().lc(CS::one(), ConstraintF::one())); + + Ok(is_eq) + } + + // conditional_enforce_equal enforces that `self` == `other` if `should_enforce` is true, + // enforce nothing otherwise. This function requires that `self` and `other` are bit sequences + // with length at most the capacity of the field ConstraintF. + fn conditional_enforce_equal(&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract + { + let (diff_lc, diff_in_field, all_constants) = self.compute_diff(&mut cs, other); + + if all_constants && diff_in_field.is_some() && should_enforce.is_constant() { + if should_enforce.get_value().unwrap() && !diff_in_field.unwrap().is_zero() { + return Err(SynthesisError::Unsatisfiable) + } + return Ok(()) + } + + // enforce that diff_lc*should_enforce = 0, which enforces that diff_lc = 0 if should_enforce=1, while it enforces nothing if should_enforce=0 + cs.enforce(|| "conditionally enforce equal", |lc| lc + &diff_lc, |_| should_enforce.lc(CS::one(), ConstraintF::one()), |lc| lc); + + Ok(()) + } + + // conditional_enforce_not_equal enforces that `self` != `other` if `should_enforce` is true, + // enforce nothing otherwise. This function requires that `self` and `other` are bit sequences + // with length at most the capacity of the field ConstraintF. + fn conditional_enforce_not_equal(&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract + { + let (diff_lc, diff_in_field, all_constants) = self.compute_diff(&mut cs, other); + + if all_constants && diff_in_field.is_some() && should_enforce.is_constant() { + if should_enforce.get_value().unwrap() && diff_in_field.unwrap().is_zero() { + return Err(SynthesisError::Unsatisfiable); + } + return Ok(()) + } + + let inv = diff_in_field.map(|diff| { + match diff.inverse() { + Some(inv) => inv, + None => ConstraintF::one(), //in this case the value of inv does not matter for the constraint + } + }); + + let inv_var = FpGadget::::alloc(cs.ns(|| "alloc inv"), || { + let cond = should_enforce.get_value().ok_or(SynthesisError::AssignmentMissing)?; + if cond { + return inv.ok_or(SynthesisError::AssignmentMissing) + } + // should not enforce anything, so set inv_var to 0 to trivially satisfy the constraint + Ok(ConstraintF::zero()) + })?; + + // enforce that diff_lc*inv_var = should_enforce, which enforces that diff_lc != 0 if + // should_enforce=1, while it enforces no constraint on diff_lc when should_enforce = 0, + // since inv can be trivially set by the prover to 0 to satisfy the constraint + cs.enforce(|| "conditionally enforce not equal", |lc| lc + &diff_lc, |lc| &inv_var.get_variable() + lc, |_| should_enforce.lc(CS::one(), ConstraintF::one())); + + Ok(()) + } +} + +impl EqGadget for Vec { + fn is_eq>( + &self, + mut cs: CS, + other: &Self, + ) -> Result { + assert_eq!(self.len(), other.len()); + let len = self.len(); + let field_bits = ConstraintF::Params::CAPACITY as usize; + if field_bits < len { + // if `self` and `other` cannot be packed in a single field element, + // then we split them in chunks of size field_bits and then leverage + // `self` == `other` iff each pair of chunks are equal + let mut chunk_eq_gadgets = Vec::new(); + for (i, (self_chunk, other_chunk)) in self.chunks(field_bits).zip(other.chunks(field_bits)).enumerate() { + let is_eq = BooleanVec(self_chunk).is_eq(cs.ns(|| format!("equality for chunk {}", i)), &BooleanVec(other_chunk))?; + chunk_eq_gadgets.push(is_eq); + } + return Boolean::kary_and(cs.ns(|| "is eq"), chunk_eq_gadgets.as_slice()) + } + + BooleanVec(self).is_eq(cs, &BooleanVec(other)) + } + + fn conditional_enforce_equal> + (&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> { + assert_eq!(self.len(), other.len()); + // split `self` and `other` in chunks of size field_bits and enforce equality between each + // pair of chunks + let field_bits = ConstraintF::Params::CAPACITY as usize; + for (i, (self_chunk, other_chunk)) in self.chunks(field_bits).zip(other.chunks(field_bits)).enumerate() { + BooleanVec(self_chunk).conditional_enforce_equal(cs.ns(|| format!("enforce equal for chunk {}", i)), &BooleanVec(other_chunk), should_enforce)?; + } + + Ok(()) + } + + fn conditional_enforce_not_equal> + (&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> { + assert_eq!(self.len(), other.len()); + let field_bits = ConstraintF::Params::CAPACITY as usize; + let len = self.len(); + if field_bits < len { + // in this case, it is not useful to split `self` and `other` in chunks, + // as `self` != `other` iff at least one pair of chunks are different, but we do not + // know on which pair we should enforce inequality. Therefore, we + // compute a Boolean which is true iff `self != `other` and we conditionally + // enforce it to be true + let is_neq = self.is_neq(cs.ns(|| "is not equal"), other)?; + return is_neq.conditional_enforce_equal(cs, &Boolean::constant(true), should_enforce) + } + // instead, if `self` and `other` can be packed in a single field element, we can + // conditionally enforce their inequality, which is more efficient that calling is_neq + BooleanVec(self).conditional_enforce_not_equal(cs, &BooleanVec(other), should_enforce) + } + +} + /// A struct for collecting identities of linear combinations of Booleans to serve /// a more efficient equality enforcement (by packing them in parallel into constraint /// field elements). @@ -269,3 +486,102 @@ impl> self.cs.num_constraints() } } + +#[cfg(test)] +mod test { + use rand::{thread_rng, Rng}; + use r1cs_core::{ConstraintSystem, SynthesisMode, ConstraintSystemAbstract, ConstraintSystemDebugger}; + use algebra::fields::{tweedle::Fr, PrimeField, FpParameters}; + use crate::{boolean::Boolean, alloc::AllocGadget, eq::EqGadget}; + + #[test] + fn test_eq_for_boolean_vec() { + let rng = &mut thread_rng(); + const FIELD_BITS: usize = ::Params::CAPACITY as usize; + // test with vectors whose length is either smaller or greater than the field capacity + const VEC_LENGTHS: [usize; 3] = [FIELD_BITS /2, FIELD_BITS, FIELD_BITS *2]; + for len in &VEC_LENGTHS { + let num_chunks = *len/ FIELD_BITS + if *len % FIELD_BITS != 0 {1} else {0}; + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let vec1 = (0..*len).map(|_| rng.gen()).collect::>(); + let vec2 = (0..*len).map(|_| rng.gen()).collect::>(); + + let vec1_var = vec1.iter().enumerate().map(|(i, bit)| { + match i % 3 { + 0 => Boolean::constant(*bit), + 1 => Boolean::alloc(cs.ns(|| format!("alloc vec1 bit {}", i)), || Ok(*bit)).unwrap(), + 2 => Boolean::alloc(cs.ns(|| format!("alloc vec1 bit {}", i)), || Ok(*bit)).unwrap().not(), + _ => Boolean::Constant(false), + } + }).collect::>(); + let vec2_var = vec2.iter().enumerate().map(|(i, bit)| { + match i % 3 { + 0 => Boolean::alloc(cs.ns(|| format!("alloc vec2 bit {}", i)), || Ok(*bit)).unwrap().not(), + 1 => Boolean::constant(*bit), + 2 => Boolean::alloc(cs.ns(|| format!("alloc vec2 bit {}", i)), || Ok(*bit)).unwrap(), + _ => Boolean::Constant(false), + } + }).collect::>(); + + // test functions on vectors which are distinct + let is_eq = vec1_var.is_eq(cs.ns(|| "vec1 == vec2"), &vec2_var).unwrap(); + assert!(!is_eq.get_value().unwrap()); + assert!(cs.is_satisfied()); + + let is_neq = vec1_var.is_neq(cs.ns(|| "vec1 != vec2"), &vec2_var).unwrap(); + assert!(is_neq.get_value().unwrap()); + assert!(cs.is_satisfied()); + + vec1_var.enforce_not_equal(cs.ns(|| "enforce vec1 != vec2"), &vec2_var).unwrap(); + assert!(cs.is_satisfied()); + vec1_var.conditional_enforce_equal(cs.ns(|| "fake enforce vec1 == vec2"), &vec2_var, &Boolean::constant(false)).unwrap(); + assert!(cs.is_satisfied()); + vec1_var.conditional_enforce_equal(cs.ns(|| "enforce vec1 == vec2"), &vec2_var, &Boolean::constant(true)).unwrap(); + assert!(!cs.is_satisfied()); + + // test functions on vectors which are equal + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let vec1_var = vec1.iter().enumerate().map(|(i, bit)| { + match i % 3 { + 0 => Boolean::constant(*bit), + 1 => Boolean::alloc(cs.ns(|| format!("alloc vec1 bit {}", i)), || Ok(*bit)).unwrap(), + 2 => Boolean::alloc(cs.ns(|| format!("alloc vec1 bit {}", i)), || Ok(*bit)).unwrap().not(), + _ => Boolean::Constant(false), + } + }).collect::>(); + let vec1_var_copy = vec1.iter().enumerate().map(|(i, bit)| { + match i % 3 { + 0 => Boolean::alloc(cs.ns(|| format!("alloc vec1 copy bit {}", i)), || Ok(*bit)).unwrap(), + 1 => Boolean::constant(*bit), + 2 => Boolean::alloc(cs.ns(|| format!("alloc vec1 copy bit {}", i)), || Ok(*bit)).unwrap().not(), + _ => Boolean::Constant(false), + } + }).collect::>(); + let num_constraints = cs.num_constraints(); + let is_eq = vec1_var.is_eq(cs.ns(|| "vec1 == vec1"), &vec1_var_copy).unwrap(); + assert!(is_eq.get_value().unwrap()); + assert!(cs.is_satisfied()); + assert_eq!(num_constraints + 3 + 4*(num_chunks-1), cs.num_constraints()); + + let num_constraints = cs.num_constraints(); + let is_neq = vec1_var.is_neq(cs.ns(|| "vec1 != vec1"), &vec1_var_copy).unwrap(); + assert!(!is_neq.get_value().unwrap()); + assert!(cs.is_satisfied()); + assert_eq!(num_constraints + 3 + 4*(num_chunks-1), cs.num_constraints()); + + let num_constraints = cs.num_constraints(); + vec1_var.enforce_equal(cs.ns(|| "enforce vec1 == vec1"), &vec1_var_copy).unwrap(); + assert!(cs.is_satisfied()); + assert_eq!(num_constraints + num_chunks, cs.num_constraints()); + vec1_var.conditional_enforce_not_equal(cs.ns(|| "fake enforce vec1!=vec1"), &vec1_var_copy, &Boolean::constant(false)).unwrap(); + assert!(cs.is_satisfied()); + let num_constraints = cs.num_constraints(); + vec1_var.conditional_enforce_not_equal(cs.ns(|| "enforce vec1!=vec1"), &vec1_var_copy, &Boolean::constant(true)).unwrap(); + assert!(!cs.is_satisfied()); + assert_eq!(num_constraints + 4*(num_chunks-1) + if num_chunks == 1 {1} else {4}, cs.num_constraints()); + } + + + } +} From 3b0858f91e3a56e2681d6d1209d8ea20e4b72778 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Wed, 19 Jan 2022 16:46:24 +0100 Subject: [PATCH 08/18] Fix non-native tests --- .../src/fields/nonnative/nonnative_field_gadget.rs | 12 ++++++------ .../short_weierstrass/short_weierstrass_jacobian.rs | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/r1cs/gadgets/std/src/fields/nonnative/nonnative_field_gadget.rs b/r1cs/gadgets/std/src/fields/nonnative/nonnative_field_gadget.rs index df0c0a47f..fbae720a6 100644 --- a/r1cs/gadgets/std/src/fields/nonnative/nonnative_field_gadget.rs +++ b/r1cs/gadgets/std/src/fields/nonnative/nonnative_field_gadget.rs @@ -1657,14 +1657,14 @@ impl ToBytesGadget = bits_per_byte.to_vec(); if bits_per_byte.len() < 8 { bits_per_byte.resize_with(8, || Boolean::constant(false)); } - bytes.push(UInt8::from_bits_le(&bits_per_byte)); - }); + bytes.push(UInt8::from_bits_le(cs.ns(|| format!("from bits of chunk {} to byte", i)), &bits_per_byte)?); + } Ok(bytes) } @@ -1678,14 +1678,14 @@ impl ToBytesGadget::new(); - bits.chunks(8).for_each(|bits_per_byte| { + for (i, bits_per_byte) in bits.chunks(8).enumerate() { let mut bits_per_byte: Vec = bits_per_byte.to_vec(); if bits_per_byte.len() < 8 { bits_per_byte.resize_with(8, || Boolean::constant(false)); } - bytes.push(UInt8::from_bits_le(&bits_per_byte)); - }); + bytes.push(UInt8::from_bits_le(cs.ns(|| format!("from bits of chunk {} to byte", i)), &bits_per_byte)?); + } Ok(bytes) } diff --git a/r1cs/gadgets/std/src/groups/nonnative/short_weierstrass/short_weierstrass_jacobian.rs b/r1cs/gadgets/std/src/groups/nonnative/short_weierstrass/short_weierstrass_jacobian.rs index 55d015f80..683fb25e8 100644 --- a/r1cs/gadgets/std/src/groups/nonnative/short_weierstrass/short_weierstrass_jacobian.rs +++ b/r1cs/gadgets/std/src/groups/nonnative/short_weierstrass/short_weierstrass_jacobian.rs @@ -22,7 +22,7 @@ use crate::{ }, prelude::EqGadget, select::{CondSelectGadget, TwoBitLookupGadget}, - uint8::UInt8, + UInt8, Assignment, ToBitsGadget, ToBytesGadget, }; use std::{ From 150cf1fc77865bab68bc0cd5260c5b8ed48679b2 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Thu, 20 Jan 2022 10:01:16 +0100 Subject: [PATCH 09/18] deleting files with old implementations of uint8. uint32 and uint64 --- r1cs/gadgets/std/src/bits/uint32.rs | 660 ------------------------- r1cs/gadgets/std/src/bits/uint64.rs | 724 ---------------------------- r1cs/gadgets/std/src/bits/uint8.rs | 623 ------------------------ 3 files changed, 2007 deletions(-) delete mode 100644 r1cs/gadgets/std/src/bits/uint32.rs delete mode 100644 r1cs/gadgets/std/src/bits/uint64.rs delete mode 100644 r1cs/gadgets/std/src/bits/uint8.rs diff --git a/r1cs/gadgets/std/src/bits/uint32.rs b/r1cs/gadgets/std/src/bits/uint32.rs deleted file mode 100644 index 53b98f547..000000000 --- a/r1cs/gadgets/std/src/bits/uint32.rs +++ /dev/null @@ -1,660 +0,0 @@ -//! A module for representing 32 bit unsigned integers over a prime constraint field. -//! Besides elementary gadgets (such as toBits, toBytes, etc.) implements the bitwise -//! operations -//! - rotl, rotr, shr, xor, -//! as well as -//! - add_many, which performs the addition modulo 2^32 of a slice of operands, -//! the result of which does exceed in length the capacity bound of the constraint -//! field. -use algebra::{Field, FpParameters, PrimeField}; - -use r1cs_core::{ConstraintSystemAbstract, LinearCombination, SynthesisError}; - -use crate::{ - boolean::{AllocatedBit, Boolean}, - eq::MultiEq, - prelude::*, - Assignment, -}; - -/// Represents an interpretation of 32 `Boolean` objects as an -/// unsigned integer. -#[derive(Clone, Debug)] -pub struct UInt32 { - // Least significant bit_gadget first - pub bits: Vec, - pub value: Option, -} - -impl UInt32 { - /// Construct a constant `UInt32` from a `u32` - pub fn constant(value: u32) -> Self { - let mut bits = Vec::with_capacity(32); - - let mut tmp = value; - for _ in 0..32 { - if tmp & 1 == 1 { - bits.push(Boolean::constant(true)) - } else { - bits.push(Boolean::constant(false)) - } - - tmp >>= 1; - } - - UInt32 { - bits, - value: Some(value), - } - } - - /// Allocate a `UInt32` in the constraint system - pub fn alloc(mut cs: CS, value: Option) -> Result - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, - { - let values = match value { - Some(mut val) => { - let mut v = Vec::with_capacity(32); - - for _ in 0..32 { - v.push(Some(val & 1 == 1)); - val >>= 1; - } - - v - } - None => vec![None; 32], - }; - - let bits = values - .into_iter() - .enumerate() - .map(|(i, v)| { - Ok(Boolean::from(AllocatedBit::alloc( - cs.ns(|| format!("allocated bit_gadget {}", i)), - || v.get(), - )?)) - }) - .collect::, SynthesisError>>()?; - - Ok(UInt32 { bits, value }) - } - - /// Turns this `UInt32` into its little-endian byte order representation. - pub fn to_bits_le(&self) -> Vec { - self.bits.clone() - } - - /// Converts a little-endian byte order representation of bits into a - /// `UInt32`. - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), 32); - - let bits = bits.to_vec(); - - let mut value = Some(0u32); - for b in bits.iter().rev() { - value.as_mut().map(|v| *v <<= 1); - - match b { - &Boolean::Constant(b) => { - if b { - value.as_mut().map(|v| *v |= 1); - } - } - &Boolean::Is(ref b) => match b.get_value() { - Some(true) => { - value.as_mut().map(|v| *v |= 1); - } - Some(false) => {} - None => value = None, - }, - &Boolean::Not(ref b) => match b.get_value() { - Some(false) => { - value.as_mut().map(|v| *v |= 1); - } - Some(true) => {} - None => value = None, - }, - } - } - - Self { bits, value } - } - - pub fn into_bits_be(self) -> Vec { - let mut ret = self.bits; - ret.reverse(); - ret - } - - pub fn from_bits_be(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), 32); - - let mut value = Some(0u32); - for b in bits { - value.as_mut().map(|v| *v <<= 1); - - match b.get_value() { - Some(true) => { - value.as_mut().map(|v| *v |= 1); - } - Some(false) => {} - None => { - value = None; - } - } - } - - UInt32 { - value, - bits: bits.iter().rev().cloned().collect(), - } - } - - pub fn rotl(&self, by: usize) -> Self { - let by = by % 32; - - let new_bits = self - .bits - .iter() - .skip(32 - by) - .chain(self.bits.iter()) - .take(32) - .cloned() - .collect(); - - UInt32 { - bits: new_bits, - value: self.value.map(|v| v.rotate_left(by as u32)), - } - } - - pub fn rotr(&self, by: usize) -> Self { - let by = by % 32; - - let new_bits = self - .bits - .iter() - .skip(by) - .chain(self.bits.iter()) - .take(32) - .cloned() - .collect(); - - UInt32 { - bits: new_bits, - value: self.value.map(|v| v.rotate_right(by as u32)), - } - } - - pub fn shr(&self, by: usize) -> Self { - let by = by % 32; - - let fill = Boolean::constant(false); - - let new_bits = self - .bits - .iter() // The bits are least significant first - .skip(by) // Skip the bits that will be lost during the shift - .chain(Some(&fill).into_iter().cycle()) // Rest will be zeros - .take(32) // Only 32 bits needed! - .cloned() - .collect(); - - UInt32 { - bits: new_bits, - value: self.value.map(|v| v >> by as u32), - } - } - - /// XOR this `UInt32` with another `UInt32` - pub fn xor(&self, mut cs: CS, other: &Self) -> Result - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, - { - let new_value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a ^ b), - _ => None, - }; - - let bits = self - .bits - .iter() - .zip(other.bits.iter()) - .enumerate() - .map(|(i, (a, b))| Boolean::xor(cs.ns(|| format!("xor of bit_gadget {}", i)), a, b)) - .collect::>()?; - - Ok(UInt32 { - bits, - value: new_value, - }) - } - - /// Perform addition modulo 2^32 of several `UInt32` objects. - pub fn addmany(mut cs: M, operands: &[Self]) -> Result - where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract, - M: ConstraintSystemAbstract>, - { - // Make some arbitrary bounds for ourselves to avoid overflows - // in the scalar field - - assert!(ConstraintF::Params::MODULUS_BITS >= 64); - assert!(operands.len() >= 2); // Weird trivial cases that should never happen - // TODO: Check this bound. Is it really needed ? - assert!(operands.len() <= 10); - - // Compute the maximum value of the sum so we allocate enough bits for - // the result - let mut max_value = (operands.len() as u64) * (u64::from(u32::max_value())); - - // Keep track of the resulting value - let mut result_value = Some(0u64); - - // This is a linear combination that we will enforce to equal the - // output - let mut lc = LinearCombination::zero(); - - let mut all_constants = true; - - // Iterate over the operands - for op in operands { - // Accumulate the value - match op.value { - Some(val) => { - result_value.as_mut().map(|v| *v += u64::from(val)); - } - None => { - // If any of our operands have unknown value, we won't - // know the value of the result - result_value = None; - } - } - - // Cumulate the terms that correspond to the bits in op to the - // overall LC - let mut coeff = ConstraintF::one(); - for bit in &op.bits { - // adds 2^i * bit[i] to the lc - lc = lc + &bit.lc(CS::one(), coeff); - - // all_constants = all_constants & bit.is_constant() - all_constants &= bit.is_constant(); - - coeff = coeff.double(); - } - } - - // The value of the actual result is modulo 2^32 - let modular_value = result_value.map(|v| v as u32); - - // In case that all operants are constant UInt32 it is enough to return a constant. - if all_constants && modular_value.is_some() { - // We can just return a constant, rather than - // unpacking the result into allocated bits. - - return Ok(UInt32::constant(modular_value.unwrap())); - } - - // Storage area for the resulting bits - let mut result_bits = vec![]; - - // Linear combination representing the output, - // for comparison with the sum of the operands - let mut result_lc = LinearCombination::zero(); - - // Allocate each bit of the result from result_val - let mut coeff = ConstraintF::one(); - let mut i = 0; - while max_value != 0 { - // Allocate the bit using result_value - let b = AllocatedBit::alloc(cs.ns(|| format!("result bit {}", i)), || { - result_value.map(|v| (v >> i) & 1 == 1).get() - })?; - - // Add this bit to the result combination - result_lc += (coeff, b.get_variable()); - - result_bits.push(b.into()); - - max_value >>= 1; - i += 1; - coeff = coeff.double(); - } - - // Enforce equality between the sum and result by aggregating it - // in the MultiEq - cs.get_root().enforce_equal(i, &lc, &result_lc); - - // Discard carry bits that we don't care about - result_bits.truncate(32); - - Ok(UInt32 { - bits: result_bits, - value: modular_value, - }) - } -} - -impl ToBytesGadget for UInt32 { - #[inline] - fn to_bytes>( - &self, - _cs: CS, - ) -> Result, SynthesisError> { - let value_chunks = match self.value.map(|val| { - use algebra::bytes::ToBytes; - let mut bytes = [0u8; 4]; - val.write(bytes.as_mut()).unwrap(); - bytes - }) { - Some(chunks) => [ - Some(chunks[0]), - Some(chunks[1]), - Some(chunks[2]), - Some(chunks[3]), - ], - None => [None, None, None, None], - }; - let mut bytes = Vec::new(); - for (i, chunk8) in self.to_bits_le().chunks(8).into_iter().enumerate() { - let byte = UInt8 { - bits: chunk8.to_vec(), - value: value_chunks[i], - }; - bytes.push(byte); - } - - Ok(bytes) - } - - fn to_bytes_strict>( - &self, - cs: CS, - ) -> Result, SynthesisError> { - self.to_bytes(cs) - } -} - -impl PartialEq for UInt32 { - fn eq(&self, other: &Self) -> bool { - self.value.is_some() && other.value.is_some() && self.value == other.value - } -} - -impl Eq for UInt32 {} - -impl EqGadget for UInt32 { - fn is_eq>( - &self, - cs: CS, - other: &Self, - ) -> Result { - self.bits.as_slice().is_eq(cs, &other.bits) - } - - fn conditional_enforce_equal>( - &self, - cs: CS, - other: &Self, - should_enforce: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_equal(cs, &other.bits, should_enforce) - } - - fn conditional_enforce_not_equal>( - &self, - cs: CS, - other: &Self, - should_enforce: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(cs, &other.bits, should_enforce) - } -} - -#[cfg(all(test, feature = "bls12_381"))] -mod test { - use super::UInt32; - use crate::{bits::boolean::Boolean, eq::MultiEq}; - use algebra::fields::{bls12_381::Fr, Field}; - use r1cs_core::{ - ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, - }; - use rand::{Rng, SeedableRng}; - use rand_xorshift::XorShiftRng; - - #[test] - fn test_uint32_from_bits() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let v = (0..32) - .map(|_| Boolean::constant(rng.gen())) - .collect::>(); - - let b = UInt32::from_bits_le(&v); - - for (i, bit_gadget) in b.bits.iter().enumerate() { - match bit_gadget { - &Boolean::Constant(bit_gadget) => { - assert!(bit_gadget == ((b.value.unwrap() >> i) & 1 == 1)); - } - _ => unreachable!(), - } - } - - let expected_to_be_same = b.to_bits_le(); - - for x in v.iter().zip(expected_to_be_same.iter()) { - match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {} - (&Boolean::Constant(false), &Boolean::Constant(false)) => {} - _ => unreachable!(), - } - } - } - } - - #[test] - fn test_uint32_xor() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let a: u32 = rng.gen(); - let b: u32 = rng.gen(); - let c: u32 = rng.gen(); - - let mut expected = a ^ b ^ c; - - let a_bit = UInt32::alloc(cs.ns(|| "a_bit"), Some(a)).unwrap(); - let b_bit = UInt32::constant(b); - let c_bit = UInt32::alloc(cs.ns(|| "c_bit"), Some(c)).unwrap(); - - let r = a_bit.xor(cs.ns(|| "first xor"), &b_bit).unwrap(); - let r = r.xor(cs.ns(|| "second xor"), &c_bit).unwrap(); - - assert!(cs.is_satisfied()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - &Boolean::Is(ref b) => { - assert!(b.get_value().unwrap() == (expected & 1 == 1)); - } - &Boolean::Not(ref b) => { - assert!(b.get_value().unwrap() != (expected & 1 == 1)); - } - &Boolean::Constant(b) => { - assert!(b == (expected & 1 == 1)); - } - } - - expected >>= 1; - } - } - } - - #[test] - fn test_uint32_addmany_constants() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let num_operands = 10; - - let operands_val = (0..num_operands).map(|_| rng.gen()).collect::>(); - let mut expected = operands_val.iter().fold(0u32, |acc, x| x.wrapping_add(acc)); - - let operands_gadget = operands_val - .into_iter() - .map(UInt32::constant) - .collect::>(); - - let r = { - let mut cs = MultiEq::new(&mut cs); - let r = UInt32::addmany(cs.ns(|| "addition"), operands_gadget.as_slice()).unwrap(); - r - }; - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - &Boolean::Is(_) => panic!(), - &Boolean::Not(_) => panic!(), - &Boolean::Constant(b) => { - assert!(b == (expected & 1 == 1)); - } - } - - expected >>= 1; - } - - assert!(cs.is_satisfied()); - } - } - - #[test] - fn test_uint32_addmany() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let num_operands = 10; - - let operands_val = (0..num_operands).map(|_| rng.gen()).collect::>(); - let mut expected = operands_val.iter().fold(0u32, |acc, x| x.wrapping_add(acc)); - - let operands_gadget = operands_val - .into_iter() - .enumerate() - .map(|(i, val)| { - UInt32::alloc(cs.ns(|| format!("alloc u32 {}", i)), Some(val)).unwrap() - }) - .collect::>(); - - let r = { - let mut cs = MultiEq::new(&mut cs); - let r = UInt32::addmany(cs.ns(|| "addition"), operands_gadget.as_slice()).unwrap(); - r - }; - - assert!(cs.is_satisfied()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - &Boolean::Is(ref b) => { - assert!(b.get_value().unwrap() == (expected & 1 == 1)); - } - &Boolean::Not(ref b) => { - assert!(b.get_value().unwrap() != (expected & 1 == 1)); - } - &Boolean::Constant(_) => unreachable!(), - } - - expected >>= 1; - } - - // Flip a bit_gadget and see if the addition constraint still works - if cs.get("addition/result bit 0/boolean").is_zero() { - cs.set("addition/result bit 0/boolean", Field::one()); - } else { - cs.set("addition/result bit 0/boolean", Field::zero()); - } - - assert!(!cs.is_satisfied()); - } - } - - #[test] - fn test_uint32_rotr() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - let mut num = rng.gen(); - - let a = UInt32::constant(num); - - for i in 0..32 { - let b = a.rotr(i); - - assert!(b.value.unwrap() == num); - - let mut tmp = num; - for b in &b.bits { - match b { - &Boolean::Constant(b) => { - assert_eq!(b, tmp & 1 == 1); - } - _ => unreachable!(), - } - - tmp >>= 1; - } - - num = num.rotate_right(1); - } - } - - #[test] - fn test_uint32_rotl() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - let mut num = rng.gen(); - - let a = UInt32::constant(num); - - for i in 0..32 { - let b = a.rotl(i); - - assert!(b.value.unwrap() == num); - - let mut tmp = num; - for b in &b.bits { - match b { - &Boolean::Constant(b) => { - assert_eq!(b, tmp & 1 == 1); - } - _ => unreachable!(), - } - - tmp >>= 1; - } - - num = num.rotate_left(1); - } - } -} diff --git a/r1cs/gadgets/std/src/bits/uint64.rs b/r1cs/gadgets/std/src/bits/uint64.rs deleted file mode 100644 index de17674e4..000000000 --- a/r1cs/gadgets/std/src/bits/uint64.rs +++ /dev/null @@ -1,724 +0,0 @@ -use algebra::{Field, FpParameters, PrimeField}; - -use r1cs_core::{ConstraintSystemAbstract, LinearCombination, SynthesisError}; - -use crate::{ - boolean::{AllocatedBit, Boolean}, - prelude::*, - Assignment, -}; - -/// Represents an interpretation of 64 `Boolean` objects as an -/// unsigned integer. -#[derive(Clone, Debug)] -pub struct UInt64 { - // Least significant bit_gadget first - bits: Vec, - value: Option, -} - -impl UInt64 { - pub fn get_value(&self) -> Option { - self.value - } - - /// Construct a constant `UInt64` from a `u64` - pub fn constant(value: u64) -> Self { - let mut bits = Vec::with_capacity(64); - - let mut tmp = value; - for _ in 0..64 { - if tmp & 1 == 1 { - bits.push(Boolean::constant(true)) - } else { - bits.push(Boolean::constant(false)) - } - - tmp >>= 1; - } - - UInt64 { - bits, - value: Some(value), - } - } - - /// Allocate a `UInt64` in the constraint system - pub fn alloc(mut cs: CS, value: Option) -> Result - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, - { - let values = match value { - Some(mut val) => { - let mut v = Vec::with_capacity(64); - - for _ in 0..64 { - v.push(Some(val & 1 == 1)); - val >>= 1; - } - - v - } - None => vec![None; 64], - }; - - let bits = values - .into_iter() - .enumerate() - .map(|(i, v)| { - Ok(Boolean::from(AllocatedBit::alloc( - cs.ns(|| format!("allocated bit_gadget {}", i)), - || v.get(), - )?)) - }) - .collect::, SynthesisError>>()?; - - Ok(UInt64 { bits, value }) - } - - /// Turns this `UInt64` into its little-endian byte order representation. - pub fn to_bits_le(&self) -> Vec { - self.bits.clone() - } - - /// Converts a little-endian byte order representation of bits into a - /// `UInt64`. - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), 64); - - let bits = bits.to_vec(); - - let mut value = Some(0u64); - for b in bits.iter().rev() { - value.as_mut().map(|v| *v <<= 1); - - match b { - &Boolean::Constant(b) => { - if b { - value.as_mut().map(|v| *v |= 1); - } - } - &Boolean::Is(ref b) => match b.get_value() { - Some(true) => { - value.as_mut().map(|v| *v |= 1); - } - Some(false) => {} - None => value = None, - }, - &Boolean::Not(ref b) => match b.get_value() { - Some(false) => { - value.as_mut().map(|v| *v |= 1); - } - Some(true) => {} - None => value = None, - }, - } - } - - Self { bits, value } - } - - pub fn rotr(&self, by: usize) -> Self { - let by = by % 64; - - let new_bits = self - .bits - .iter() - .skip(by) - .chain(self.bits.iter()) - .take(64) - .cloned() - .collect(); - - UInt64 { - bits: new_bits, - value: self.value.map(|v| v.rotate_right(by as u32)), - } - } - - /// XOR this `UInt64` with another `UInt64` - pub fn xor(&self, mut cs: CS, other: &Self) -> Result - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, - { - let new_value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a ^ b), - _ => None, - }; - - let bits = self - .bits - .iter() - .zip(other.bits.iter()) - .enumerate() - .map(|(i, (a, b))| Boolean::xor(cs.ns(|| format!("xor of bit_gadget {}", i)), a, b)) - .collect::>()?; - - Ok(UInt64 { - bits, - value: new_value, - }) - } - - /// Perform modular addition of several `UInt64` objects. - pub fn addmany(mut cs: CS, operands: &[Self]) -> Result - where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract, - { - // Make some arbitrary bounds for ourselves to avoid overflows - // in the scalar field - assert!(ConstraintF::Params::MODULUS_BITS >= 128); - - assert!(!operands.is_empty()); - assert!(operands.len() <= 10); - - if operands.len() == 1 { - return Ok(operands[0].clone()); - } - - // Compute the maximum value of the sum so we allocate enough bits for - // the result - let mut max_value = (operands.len() as u128) * u128::from(u64::max_value()); - - // Keep track of the resulting value - let mut result_value = Some(0u64 as u128); - - // This is a linear combination that we will enforce to be "zero" - let mut lc = LinearCombination::zero(); - - let mut all_constants = true; - - // Iterate over the operands - for op in operands { - // Accumulate the value - match op.value { - Some(val) => { - result_value.as_mut().map(|v| *v += u128::from(val)); - } - None => { - // If any of our operands have unknown value, we won't - // know the value of the result - result_value = None; - } - } - - // Iterate over each bit_gadget of the operand and add the operand to - // the linear combination - let mut coeff = ConstraintF::one(); - for bit in &op.bits { - match *bit { - Boolean::Is(ref bit) => { - all_constants = false; - - // Add coeff * bit_gadget - lc += (coeff, bit.get_variable()); - } - Boolean::Not(ref bit) => { - all_constants = false; - - // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * bit_gadget - lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable()); - } - Boolean::Constant(bit) => { - if bit { - lc += (coeff, CS::one()); - } - } - } - - coeff.double_in_place(); - } - } - - // The value of the actual result is modulo 2^64 - let modular_value = result_value.map(|v| v as u64); - - if all_constants && modular_value.is_some() { - // We can just return a constant, rather than - // unpacking the result into allocated bits. - - return Ok(UInt64::constant(modular_value.unwrap())); - } - - // Storage area for the resulting bits - let mut result_bits = vec![]; - - // Allocate each bit_gadget of the result - let mut coeff = ConstraintF::one(); - let mut i = 0; - while max_value != 0 { - // Allocate the bit_gadget - let b = AllocatedBit::alloc(cs.ns(|| format!("result bit_gadget {}", i)), || { - result_value.map(|v| (v >> i) & 1 == 1).get() - })?; - - // Subtract this bit_gadget from the linear combination to ensure the sums - // balance out - lc = lc - (coeff, b.get_variable()); - - result_bits.push(b.into()); - - max_value >>= 1; - i += 1; - coeff.double_in_place(); - } - - // Enforce that the linear combination equals zero - cs.enforce(|| "modular addition", |lc| lc, |lc| lc, |_| lc); - - // Discard carry bits that we don't care about - result_bits.truncate(64); - - Ok(UInt64 { - bits: result_bits, - value: modular_value, - }) - } - - pub fn conditionally_add( - mut cs: CS, - bit: &Boolean, - first: Self, - second: Self - ) -> Result - where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract, - { - let added_values_g = UInt64::addmany(cs.ns(|| "added values"),&[first.clone(),second])?; - Self::conditionally_select( - cs.ns(|| "select added_values or original value"), - bit, - &added_values_g, - &first - ) - } -} - -impl ToBytesGadget for UInt64 { - #[inline] - fn to_bytes>( - &self, - _cs: CS, - ) -> Result, SynthesisError> { - let value_chunks = match self.value.map(|val| { - use algebra::bytes::ToBytes; - let mut bytes = [0u8; 8]; - val.write(bytes.as_mut()).unwrap(); - bytes - }) { - Some(chunks) => [ - Some(chunks[0]), - Some(chunks[1]), - Some(chunks[2]), - Some(chunks[3]), - Some(chunks[4]), - Some(chunks[5]), - Some(chunks[6]), - Some(chunks[7]), - ], - None => [None, None, None, None, None, None, None, None], - }; - let mut bytes = Vec::new(); - for (i, chunk8) in self.to_bits_le().chunks(8).enumerate() { - let byte = UInt8 { - bits: chunk8.to_vec(), - value: value_chunks[i], - }; - bytes.push(byte); - } - - Ok(bytes) - } - - fn to_bytes_strict>( - &self, - cs: CS, - ) -> Result, SynthesisError> { - self.to_bytes(cs) - } -} - -impl PartialEq for UInt64 { - fn eq(&self, other: &Self) -> bool { - self.value.is_some() && other.value.is_some() && self.value == other.value - } -} - -impl Eq for UInt64 {} - -impl EqGadget for UInt64 { - fn is_eq>( - &self, - cs: CS, - other: &Self, - ) -> Result { - self.bits.as_slice().is_eq(cs, &other.bits) - } - - fn conditional_enforce_equal>( - &self, - cs: CS, - other: &Self, - should_enforce: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_equal(cs, &other.bits, should_enforce) - } - - fn conditional_enforce_not_equal>( - &self, - cs: CS, - other: &Self, - should_enforce: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(cs, &other.bits, should_enforce) - } -} - -impl CondSelectGadget for UInt64 { - fn conditionally_select>( - mut cs: CS, - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value - .bits - .iter() - .zip(&false_value.bits) - .enumerate() - .map(|(i, (t, f))| { - Boolean::conditionally_select(&mut cs.ns(|| format!("bit {}", i)), cond, t, f) - }); - let mut bits = [Boolean::Constant(false); 64]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; - } - - let value = cond.get_value().and_then(|cond| { - if cond { - true_value.get_value() - } else { - false_value.get_value() - } - }); - Ok(Self { - bits: bits.to_vec(), - value, - }) - } - - fn cost() -> usize { - 64 * >::cost() - } -} - -#[cfg(test)] -mod test { - use super::UInt64; - use crate::{alloc::AllocGadget, bits::boolean::Boolean, boolean::AllocatedBit, select::CondSelectGadget}; - use algebra::fields::{bls12_381::Fr, Field}; - use r1cs_core::{ - ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, - }; - use rand::{Rng, SeedableRng}; - use rand_xorshift::XorShiftRng; - - #[test] - fn test_uint64_from_bits() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let v = (0..64) - .map(|_| Boolean::constant(rng.gen())) - .collect::>(); - - let b = UInt64::from_bits_le(&v); - - for (i, bit_gadget) in b.bits.iter().enumerate() { - match bit_gadget { - &Boolean::Constant(bit_gadget) => { - assert!(bit_gadget == ((b.value.unwrap() >> i) & 1 == 1)); - } - _ => unreachable!(), - } - } - - let expected_to_be_same = b.to_bits_le(); - - for x in v.iter().zip(expected_to_be_same.iter()) { - match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {} - (&Boolean::Constant(false), &Boolean::Constant(false)) => {} - _ => unreachable!(), - } - } - } - } - - #[test] - fn test_uint64_xor() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let a: u64 = rng.gen(); - let b: u64 = rng.gen(); - let c: u64 = rng.gen(); - - let mut expected = a ^ b ^ c; - - let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), Some(a)).unwrap(); - let b_bit = UInt64::constant(b); - let c_bit = UInt64::alloc(cs.ns(|| "c_bit"), Some(c)).unwrap(); - - let r = a_bit.xor(cs.ns(|| "first xor"), &b_bit).unwrap(); - let r = r.xor(cs.ns(|| "second xor"), &c_bit).unwrap(); - - assert!(cs.is_satisfied()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - &Boolean::Is(ref b) => { - assert!(b.get_value().unwrap() == (expected & 1 == 1)); - } - &Boolean::Not(ref b) => { - assert!(b.get_value().unwrap() != (expected & 1 == 1)); - } - &Boolean::Constant(b) => { - assert!(b == (expected & 1 == 1)); - } - } - - expected >>= 1; - } - } - } - - #[test] - fn test_uint64_addmany_constants() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let a: u64 = rng.gen(); - let b: u64 = rng.gen(); - let c: u64 = rng.gen(); - - let a_bit = UInt64::constant(a); - let b_bit = UInt64::constant(b); - let c_bit = UInt64::constant(c); - - let mut expected = a.wrapping_add(b).wrapping_add(c); - - let r = UInt64::addmany(cs.ns(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap(); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - &Boolean::Is(_) => panic!(), - &Boolean::Not(_) => panic!(), - &Boolean::Constant(b) => { - assert!(b == (expected & 1 == 1)); - } - } - - expected >>= 1; - } - } - } - - #[test] - fn test_uint64_addmany() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let a: u64 = rng.gen(); - let b: u64 = rng.gen(); - let c: u64 = rng.gen(); - let d: u64 = rng.gen(); - - let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - - let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), Some(a)).unwrap(); - let b_bit = UInt64::constant(b); - let c_bit = UInt64::constant(c); - let d_bit = UInt64::alloc(cs.ns(|| "d_bit"), Some(d)).unwrap(); - - let r = a_bit.xor(cs.ns(|| "xor"), &b_bit).unwrap(); - let r = UInt64::addmany(cs.ns(|| "addition"), &[r, c_bit, d_bit]).unwrap(); - - assert!(cs.is_satisfied()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - &Boolean::Is(ref b) => { - assert!(b.get_value().unwrap() == (expected & 1 == 1)); - } - &Boolean::Not(ref b) => { - assert!(b.get_value().unwrap() != (expected & 1 == 1)); - } - &Boolean::Constant(_) => unreachable!(), - } - - expected >>= 1; - } - - // Flip a bit_gadget and see if the addition constraint still works - if cs.get("addition/result bit_gadget 0/boolean").is_zero() { - cs.set("addition/result bit_gadget 0/boolean", Fr::one()); - } else { - cs.set("addition/result bit_gadget 0/boolean", Fr::zero()); - } - - assert!(!cs.is_satisfied()); - } - } - - #[test] - fn test_uint64_rotr() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - let mut num = rng.gen(); - - let a = UInt64::constant(num); - - for i in 0..64 { - let b = a.rotr(i); - - assert!(b.value.unwrap() == num); - - let mut tmp = num; - for b in &b.bits { - match b { - &Boolean::Constant(b) => { - assert_eq!(b, tmp & 1 == 1); - } - _ => unreachable!(), - } - - tmp >>= 1; - } - - num = num.rotate_right(1); - } - } - - #[derive(Copy, Clone, Debug)] - enum OperandType { - True, - False, - AllocatedTrue, - AllocatedFalse, - NegatedAllocatedTrue, - NegatedAllocatedFalse, - } - #[derive(Copy, Clone, Debug)] - enum VariableType { - Constant, - Allocated, - } - - #[test] - fn test_uint64_cond_select() { - let variants = [ - OperandType::True, - OperandType::False, - OperandType::AllocatedTrue, - OperandType::AllocatedFalse, - OperandType::NegatedAllocatedTrue, - OperandType::NegatedAllocatedFalse, - ]; - let var_type = [ - VariableType::Constant, - VariableType::Allocated, - ]; - - use rand::thread_rng; - let rng = &mut thread_rng(); - - //random generates a and b numbers and check all the conditions for each couple - for _ in 0..1000 { - for condition in variants.iter().cloned() { - for var_a_type in var_type.iter().cloned() { - for var_b_type in var_type.iter().cloned() { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let cond; - let a; - let b; - - { - let mut dyn_construct = |operand, name| { - let cs = cs.ns(|| name); - - match operand { - OperandType::True => Boolean::constant(true), - OperandType::False => Boolean::constant(false), - OperandType::AllocatedTrue => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()) - } - OperandType::AllocatedFalse => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()) - } - OperandType::NegatedAllocatedTrue => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()).not() - } - OperandType::NegatedAllocatedFalse => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()).not() - } - } - }; - cond = dyn_construct(condition, "cond"); - } - { - let mut dyn_construct_var = |var_type, name, value| { - let cs = cs.ns(|| name); - match var_type { - VariableType::Constant => UInt64::constant(value), - VariableType::Allocated => UInt64::alloc(cs, Some(value)).unwrap(), - } - }; - - a = dyn_construct_var(var_a_type,"var_a",rng.gen()); - b = dyn_construct_var(var_b_type,"var_b",rng.gen()); - } - let before = cs.num_constraints(); - let c = UInt64::conditionally_select(&mut cs, &cond, &a, &b).unwrap(); - let after = cs.num_constraints(); - - assert!( - cs.is_satisfied(), - "failed with operands: cond: {:?}, a: {:?}, b: {:?}", - condition, - a, - b, - ); - assert_eq!( - c.get_value(), - if cond.get_value().unwrap() { - a.get_value() - } else { - b.get_value() - } - ); - assert!(>::cost() >= after - before); - } - } - } - } - } -} diff --git a/r1cs/gadgets/std/src/bits/uint8.rs b/r1cs/gadgets/std/src/bits/uint8.rs deleted file mode 100644 index af22f9d9d..000000000 --- a/r1cs/gadgets/std/src/bits/uint8.rs +++ /dev/null @@ -1,623 +0,0 @@ -use algebra::{Field, FpParameters, PrimeField, ToConstraintField}; - -use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; - -use crate::{boolean::AllocatedBit, fields::fp::FpGadget, prelude::*, Assignment}; -use std::borrow::Borrow; - -/// Represents an interpretation of 8 `Boolean` objects as an -/// unsigned integer. -#[derive(Clone, Debug)] -pub struct UInt8 { - // Least significant bit_gadget first - pub(crate) bits: Vec, - pub(crate) value: Option, -} - -impl UInt8 { - pub fn get_value(&self) -> Option { - self.value - } - - /// Construct a constant vector of `UInt8` from a vector of `u8` - pub fn constant_vec(values: &[u8]) -> Vec { - let mut result = Vec::new(); - for value in values { - result.push(UInt8::constant(*value)); - } - result - } - - /// Construct a constant `UInt8` from a `u8` - pub fn constant(value: u8) -> Self { - let mut bits = Vec::with_capacity(8); - - let mut tmp = value; - for _ in 0..8 { - // If last bit is one, push one. - if tmp & 1 == 1 { - bits.push(Boolean::constant(true)) - } else { - bits.push(Boolean::constant(false)) - } - - tmp >>= 1; - } - - Self { - bits, - value: Some(value), - } - } - - pub fn alloc_vec( - mut cs: CS, - values: &[T], - ) -> Result, SynthesisError> - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, - T: Into> + Copy, - { - let mut output_vec = Vec::with_capacity(values.len()); - for (i, value) in values.iter().enumerate() { - let byte: Option = Into::into(*value); - let alloc_byte = Self::alloc(&mut cs.ns(|| format!("byte_{}", i)), || byte.get())?; - output_vec.push(alloc_byte); - } - Ok(output_vec) - } - - /// Allocates a vector of `u8`'s by first converting (chunks of) them to - /// `ConstraintF` elements, (thus reducing the number of input allocations), - /// and then converts this list of `ConstraintF` gadgets back into - /// bytes. - pub fn alloc_input_vec( - mut cs: CS, - values: &[u8], - ) -> Result, SynthesisError> - where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract, - { - let values_len = values.len(); - let field_elements: Vec = - ToConstraintField::::to_field_elements(values).unwrap(); - - let max_size = (::Params::CAPACITY / 8) as usize; - - let mut allocated_bits = Vec::new(); - for (i, (field_element, byte_chunk)) in field_elements - .into_iter() - .zip(values.chunks(max_size)) - .enumerate() - { - let fe = FpGadget::alloc_input(&mut cs.ns(|| format!("Field element {}", i)), || { - Ok(field_element) - })?; - - // Let's use the length-restricted variant of the ToBitsGadget to remove the - // padding: the padding bits are not constrained to be zero, so any field element - // passed as input (as long as it has the last bits set to the proper value) can - // satisfy the constraints. This kind of freedom might not be desiderable in - // recursive SNARK circuits, where the public inputs of the inner circuit are - // usually involved in other kind of constraints inside the wrap circuit. - let to_skip: usize = - ::Params::MODULUS_BITS as usize - (byte_chunk.len() * 8); - let mut fe_bits = fe.to_bits_with_length_restriction( - cs.ns(|| format!("Convert fe to bits {}", i)), - to_skip, - )?; - - // FpGadget::to_bits outputs a big-endian binary representation of - // fe_gadget's value, so we have to reverse it to get the little-endian - // form. - fe_bits.reverse(); - - allocated_bits.extend_from_slice(fe_bits.as_slice()); - } - - // Chunk up slices of 8 bit into bytes. - Ok(allocated_bits[0..8 * values_len] - .chunks(8) - .map(Self::from_bits_le) - .collect()) - } - - /// Turns this `UInt8` into its big-endian byte order representation. - pub fn into_bits_be(&self) -> Vec { - self.bits.iter().rev().cloned().collect() - } - - /// Turns this `UInt8` into its little-endian byte order representation. - /// LSB-first means that we can easily get the corresponding field element - /// via double and add. - pub fn into_bits_le(&self) -> Vec { - self.bits.to_vec() - } - - /// Converts a little-endian byte order representation of bits into a - /// `UInt8`. - pub fn from_bits_le(bits: &[Boolean]) -> Self { - assert_eq!(bits.len(), 8); - - let bits = bits.to_vec(); - - let mut value = Some(0u8); - for b in bits.iter().rev() { - value.as_mut().map(|v| *v <<= 1); - - match *b { - Boolean::Constant(b) => { - if b { - value.as_mut().map(|v| *v |= 1); - } - } - Boolean::Is(ref b) => match b.get_value() { - Some(true) => { - value.as_mut().map(|v| *v |= 1); - } - Some(false) => {} - None => value = None, - }, - Boolean::Not(ref b) => match b.get_value() { - Some(false) => { - value.as_mut().map(|v| *v |= 1); - } - Some(true) => {} - None => value = None, - }, - } - } - - Self { bits, value } - } - - /// XOR this `UInt8` with another `UInt8` - pub fn xor(&self, mut cs: CS, other: &Self) -> Result - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, - { - let new_value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a ^ b), - _ => None, - }; - - let bits = self - .bits - .iter() - .zip(other.bits.iter()) - .enumerate() - .map(|(i, (a, b))| Boolean::xor(cs.ns(|| format!("xor of bit_gadget {}", i)), a, b)) - .collect::>()?; - - Ok(Self { - bits, - value: new_value, - }) - } - - /// OR this `UInt8` with another `UInt8` - pub fn or(&self, mut cs: CS, other: &Self) -> Result - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, - { - let new_value = match (self.value, other.value) { - (Some(a), Some(b)) => Some(a | b), - _ => None, - }; - - let bits = self - .bits - .iter() - .zip(other.bits.iter()) - .enumerate() - .map(|(i, (a, b))| Boolean::or(cs.ns(|| format!("or of bit_gadget {}", i)), a, b)) - .collect::>()?; - - Ok(Self { - bits, - value: new_value, - }) - } -} - -impl PartialEq for UInt8 { - fn eq(&self, other: &Self) -> bool { - self.value.is_some() && other.value.is_some() && self.value == other.value - } -} - -impl Eq for UInt8 {} - -impl EqGadget for UInt8 { - fn is_eq>( - &self, - cs: CS, - other: &Self, - ) -> Result { - self.bits.as_slice().is_eq(cs, &other.bits) - } - - fn conditional_enforce_equal>( - &self, - cs: CS, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_equal(cs, &other.bits, condition) - } - - fn conditional_enforce_not_equal>( - &self, - cs: CS, - other: &Self, - condition: &Boolean, - ) -> Result<(), SynthesisError> { - self.bits - .conditional_enforce_not_equal(cs, &other.bits, condition) - } -} - -impl AllocGadget for UInt8 { - fn alloc>( - mut cs: CS, - value_gen: F, - ) -> Result - where - F: FnOnce() -> Result, - T: Borrow, - { - let value = value_gen().map(|val| *val.borrow()); - let values = match value { - Ok(mut val) => { - let mut v = Vec::with_capacity(8); - - for _ in 0..8 { - v.push(Some(val & 1 == 1)); - val >>= 1; - } - - v - } - _ => vec![None; 8], - }; - - let bits = values - .into_iter() - .enumerate() - .map(|(i, v)| { - Ok(Boolean::from(AllocatedBit::alloc( - &mut cs.ns(|| format!("allocated bit_gadget {}", i)), - || v.ok_or(SynthesisError::AssignmentMissing), - )?)) - }) - .collect::, SynthesisError>>()?; - - Ok(Self { - bits, - value: value.ok(), - }) - } - - fn alloc_input>( - mut cs: CS, - value_gen: F, - ) -> Result - where - F: FnOnce() -> Result, - T: Borrow, - { - let value = value_gen().map(|val| *val.borrow()); - let values = match value { - Ok(mut val) => { - let mut v = Vec::with_capacity(8); - for _ in 0..8 { - v.push(Some(val & 1 == 1)); - val >>= 1; - } - - v - } - _ => vec![None; 8], - }; - - let bits = values - .into_iter() - .enumerate() - .map(|(i, v)| { - Ok(Boolean::from(AllocatedBit::alloc_input( - &mut cs.ns(|| format!("allocated bit_gadget {}", i)), - || v.ok_or(SynthesisError::AssignmentMissing), - )?)) - }) - .collect::, SynthesisError>>()?; - - Ok(Self { - bits, - value: value.ok(), - }) - } -} - -impl CondSelectGadget for UInt8 { - fn conditionally_select>( - mut cs: CS, - cond: &Boolean, - true_value: &Self, - false_value: &Self, - ) -> Result { - let selected_bits = true_value - .bits - .iter() - .zip(&false_value.bits) - .enumerate() - .map(|(i, (t, f))| { - Boolean::conditionally_select(&mut cs.ns(|| format!("bit {}", i)), cond, t, f) - }); - let mut bits = [Boolean::Constant(false); 8]; - for (result, new) in bits.iter_mut().zip(selected_bits) { - *result = new?; - } - - let value = cond.get_value().and_then(|cond| { - if cond { - true_value.get_value() - } else { - false_value.get_value() - } - }); - Ok(Self { - bits: bits.to_vec(), - value, - }) - } - - fn cost() -> usize { - 8 * >::cost() - } -} - -#[cfg(all(test, feature = "bls12_381"))] -mod test { - use super::UInt8; - use crate::{boolean::AllocatedBit, prelude::*}; - use algebra::fields::bls12_381::Fr; - use r1cs_core::{ - ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, - }; - use rand::{Rng, RngCore, SeedableRng}; - use rand_xorshift::XorShiftRng; - - #[test] - fn test_uint8_from_bits_to_bits() { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let byte_val = 0b01110001; - let byte = UInt8::alloc(cs.ns(|| "alloc value"), || Ok(byte_val)).unwrap(); - let bits = byte.into_bits_le(); - for (i, bit) in bits.iter().enumerate() { - assert_eq!(bit.get_value().unwrap(), (byte_val >> i) & 1 == 1) - } - } - - #[test] - fn test_uint8_alloc_input_vec() { - use algebra::{to_bytes, Field, FpParameters, PrimeField, ToBytes, UniformRand}; - use rand::thread_rng; - - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let rng = &mut thread_rng(); - - //Random test - let samples = 100; - for i in 0..samples { - // Test with random field - let byte_vals = to_bytes!(Fr::rand(rng)).unwrap(); - let bytes = - UInt8::alloc_input_vec(cs.ns(|| format!("alloc value {}", i)), &byte_vals).unwrap(); - assert_eq!(byte_vals.len(), bytes.len()); - for (native_byte, gadget_byte) in byte_vals.into_iter().zip(bytes) { - assert_eq!(gadget_byte.get_value().unwrap(), native_byte); - } - - // Test with random bytes - let mut byte_vals = vec![0u8; rng.gen_range(1..200)]; - rng.fill_bytes(byte_vals.as_mut_slice()); - let bytes = UInt8::alloc_input_vec(cs.ns(|| format!("alloc random {}", i)), &byte_vals) - .unwrap(); - assert_eq!(byte_vals.len(), bytes.len()); - for (native_byte, gadget_byte) in byte_vals.into_iter().zip(bytes) { - assert_eq!(gadget_byte.get_value().unwrap(), native_byte); - } - } - - //Test one - let byte_vals = to_bytes!(Fr::one()).unwrap(); - let bytes = UInt8::alloc_input_vec(cs.ns(|| "alloc one bytes"), &byte_vals).unwrap(); - assert_eq!(byte_vals.len(), bytes.len()); - for (native_byte, gadget_byte) in byte_vals.into_iter().zip(bytes) { - assert_eq!(gadget_byte.get_value().unwrap(), native_byte); - } - - //Test zero - let byte_vals = to_bytes!(Fr::zero()).unwrap(); - let bytes = UInt8::alloc_input_vec(cs.ns(|| "alloc zero bytes"), &byte_vals).unwrap(); - assert_eq!(byte_vals.len(), bytes.len()); - for (native_byte, gadget_byte) in byte_vals.into_iter().zip(bytes) { - assert_eq!(gadget_byte.get_value().unwrap(), native_byte); - } - - //Test over the modulus byte vec - let byte_vals = vec![ - std::u8::MAX; - ((::Params::MODULUS_BITS - + ::Params::REPR_SHAVE_BITS) - / 8) as usize - ]; - let bytes = UInt8::alloc_input_vec(cs.ns(|| "alloc all 1s byte vec"), &byte_vals).unwrap(); - assert_eq!(byte_vals.len(), bytes.len()); - for (native_byte, gadget_byte) in byte_vals.into_iter().zip(bytes) { - assert_eq!(gadget_byte.get_value().unwrap(), native_byte); - } - } - - #[test] - fn test_uint8_from_bits() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let v = (0..8) - .map(|_| Boolean::constant(rng.gen())) - .collect::>(); - - let b = UInt8::from_bits_le(&v); - - for (i, bit_gadget) in b.bits.iter().enumerate() { - match bit_gadget { - &Boolean::Constant(bit_gadget) => { - assert!(bit_gadget == ((b.value.unwrap() >> i) & 1 == 1)); - } - _ => unreachable!(), - } - } - - let expected_to_be_same = b.into_bits_le(); - - for x in v.iter().zip(expected_to_be_same.iter()) { - match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {} - (&Boolean::Constant(false), &Boolean::Constant(false)) => {} - _ => unreachable!(), - } - } - } - } - - #[test] - fn test_uint8_xor() { - let mut rng = XorShiftRng::seed_from_u64(1231275789u64); - - for _ in 0..1000 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let a: u8 = rng.gen(); - let b: u8 = rng.gen(); - let c: u8 = rng.gen(); - - let mut expected = a ^ b ^ c; - - let a_bit = UInt8::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap(); - let b_bit = UInt8::constant(b); - let c_bit = UInt8::alloc(cs.ns(|| "c_bit"), || Ok(c)).unwrap(); - - let r = a_bit.xor(cs.ns(|| "first xor"), &b_bit).unwrap(); - let r = r.xor(cs.ns(|| "second xor"), &c_bit).unwrap(); - - assert!(cs.is_satisfied()); - - assert!(r.value == Some(expected)); - - for b in r.bits.iter() { - match b { - &Boolean::Is(ref b) => { - assert!(b.get_value().unwrap() == (expected & 1 == 1)); - } - &Boolean::Not(ref b) => { - assert!(b.get_value().unwrap() != (expected & 1 == 1)); - } - &Boolean::Constant(b) => { - assert!(b == (expected & 1 == 1)); - } - } - - expected >>= 1; - } - } - } - - #[derive(Copy, Clone, Debug)] - enum OperandType { - True, - False, - AllocatedTrue, - AllocatedFalse, - NegatedAllocatedTrue, - NegatedAllocatedFalse, - } - - #[test] - fn test_uint8_cond_select() { - let variants = [ - OperandType::True, - OperandType::False, - OperandType::AllocatedTrue, - OperandType::AllocatedFalse, - OperandType::NegatedAllocatedTrue, - OperandType::NegatedAllocatedFalse, - ]; - - use rand::thread_rng; - let rng = &mut thread_rng(); - - //random generates a and b numbers and check all the conditions for each couple - for _ in 0..1000 { - for condition in variants.iter().cloned() { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let cond; - let a; - let b; - - { - let mut dyn_construct = |operand, name| { - let cs = cs.ns(|| name); - - match operand { - OperandType::True => Boolean::constant(true), - OperandType::False => Boolean::constant(false), - OperandType::AllocatedTrue => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()) - } - OperandType::AllocatedFalse => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()) - } - OperandType::NegatedAllocatedTrue => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()).not() - } - OperandType::NegatedAllocatedFalse => { - Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()).not() - } - } - }; - - cond = dyn_construct(condition, "cond"); - a = UInt8::constant(rng.gen()); - b = UInt8::constant(rng.gen()); - } - - let before = cs.num_constraints(); - let c = UInt8::conditionally_select(&mut cs, &cond, &a, &b).unwrap(); - let after = cs.num_constraints(); - - assert!( - cs.is_satisfied(), - "failed with operands: cond: {:?}, a: {:?}, b: {:?}", - condition, - a, - b, - ); - assert_eq!( - c.get_value(), - if cond.get_value().unwrap() { - a.get_value() - } else { - b.get_value() - } - ); - assert!(>::cost() >= after - before); - } - } - } -} From 5dec26f30be8416db7d904978995b2603f3ace38 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 28 Jan 2022 12:37:16 +0100 Subject: [PATCH 10/18] Add conditional_enforce_cmp to comparison gadget + fix conditional_enforce_eq for Boolean --- r1cs/gadgets/std/src/bits/boolean.rs | 11 +++- r1cs/gadgets/std/src/bits/macros.rs | 83 ++++++++++++++++++++++------ r1cs/gadgets/std/src/cmp.rs | 21 ++++++- r1cs/gadgets/std/src/fields/cmp.rs | 68 +++++++++++++---------- 4 files changed, 135 insertions(+), 48 deletions(-) diff --git a/r1cs/gadgets/std/src/bits/boolean.rs b/r1cs/gadgets/std/src/bits/boolean.rs index ce6fa4291..d69de5c0c 100644 --- a/r1cs/gadgets/std/src/bits/boolean.rs +++ b/r1cs/gadgets/std/src/bits/boolean.rs @@ -841,7 +841,16 @@ impl EqGadget for Boolean { // 1 - 1 = 0 - 0 = 0 (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), // false != true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), + (Constant(_), Constant(_)) => { + if should_enforce.is_constant() { + return if should_enforce.get_value().unwrap() { + Err(SynthesisError::AssignmentMissing) + } else { + Ok(()) + } + } + LinearCombination::zero() + CS::one() // set difference != 0 + }, // 1 - a (Constant(true), Is(a)) | (Is(a), Constant(true)) => { LinearCombination::zero() + one - a.get_variable() diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index d0c3016a5..efe1347d1 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -1500,7 +1500,7 @@ macro_rules! impl_uint_gadget { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); let vec_len: usize = rng.gen_range($bit_size..$bit_size*2); - println!("vec len: {}", vec_len); + // allocate input vector of VEC_LEN random bytes let input_vec = (0..vec_len).map(|_| rng.gen()).collect::>(); @@ -2356,9 +2356,9 @@ macro_rules! impl_uint_gadget { // helper closure which is useful to deal with the error returned by enforce cmp // function if both the operands are constant and the comparison is // unsatisfiable on such constants - let handle_constant_operands = |cs: &ConstraintSystem::, must_be_satisfied: bool, cmp_result: Result<(), SynthesisError>, var_type_op1: &VariableType, var_type_op2: &VariableType, assertion_label| { - match (*var_type_op1, *var_type_op2) { - (VariableType::Constant, VariableType::Constant) => { + let handle_constant_operands = |cs: &ConstraintSystem::, must_be_satisfied: bool, cmp_result: Result<(), SynthesisError>, var_type_op1: &VariableType, var_type_op2: &VariableType, is_constant: bool, assertion_label| { + match (*var_type_op1, *var_type_op2, is_constant) { + (VariableType::Constant, VariableType::Constant, true) => { if must_be_satisfied { cmp_result.unwrap() } else { @@ -2407,12 +2407,12 @@ macro_rules! impl_uint_gadget { // test enforce_smaller_than let enforce_ret = a_var.enforce_smaller_than(cs.ns(|| "enforce a < b"), &b_var); - handle_constant_operands(&cs, is_smaller, enforce_ret, var_type_op1, var_type_op2, "enforce_smaller_than test"); + handle_constant_operands(&cs, is_smaller, enforce_ret, var_type_op1, var_type_op2, true, "enforce_smaller_than test"); // test equality let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); let enforce_ret = a_var.enforce_smaller_than(cs.ns(|| "enforce a < a"), &a_var); - handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, "enforce a < a test"); + handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce a < a test"); // test all comparisons @@ -2422,15 +2422,15 @@ macro_rules! impl_uint_gadget { match a.cmp(&b) { Ordering::Less => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce less test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce less test"); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less equal"), &b_var, Ordering::Less, true); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce less equal test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce less equal test"); } Ordering::Greater => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater"), &b_var, Ordering::Greater, false); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce greater test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce greater test"); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater equal"), &b_var, Ordering::Greater, true); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, "enforce greater equal test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce greater equal test"); } _ => {} } @@ -2440,16 +2440,16 @@ macro_rules! impl_uint_gadget { match b.cmp(&a) { Ordering::Less => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce less negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce less negative test"); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less equal"),&b_var, Ordering::Less, true); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce less equal negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce less equal negative test"); } Ordering::Greater => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater"),&b_var, Ordering::Greater, false); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce greater negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce greater negative test"); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater equal"),&b_var, Ordering::Greater, true); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, "enforce greater equal negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce greater equal negative test"); } _ => {} } @@ -2459,9 +2459,60 @@ macro_rules! impl_uint_gadget { let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); let enforce_ret = a_var.enforce_cmp(cs.ns(|| "enforce a <= a"), &a_var, Ordering::Less, true); - handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, "enforce less equal on same variable test"); + handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce less equal on same variable test"); let enforce_ret = a_var.enforce_cmp(cs.ns(|| "enforce a < a"), &a_var, Ordering::Less, false); - handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, "enforce less on same variable test"); + handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce less on same variable test"); + + // test conditional_enforce_cmp + for condition in BOOLEAN_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a: $native_type = rng.gen(); + let b: $native_type = rng.gen(); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); + let cond = alloc_boolean_cond(&mut cs, "alloc cond", condition); + match a.cmp(&b) { + Ordering::Less => { + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less"), &b_var, &cond, Ordering::Less, false); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce less test"); + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less equal"), &b_var, &cond, Ordering::Less, true); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce less equal test"); + } + Ordering::Greater => { + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater"), &b_var, &cond, Ordering::Greater, false); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce greater test"); + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater equal"), &b_var, &cond, Ordering::Greater, true); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce greater equal test"); + } + _ => {} + } + // negative tests + match b.cmp(&a) { + Ordering::Less => { + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less"), &b_var, &cond, Ordering::Less, false); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce less negative test"); + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less equal"),&b_var, &cond, Ordering::Less, true); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce less equal negative test"); + + } + Ordering::Greater => { + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater"),&b_var, &cond, Ordering::Greater, false); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce greater negative test"); + let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater equal"),&b_var, &cond, Ordering::Greater, true); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce greater equal negative test"); + } + _ => {} + } + // test with the same variable + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let cond = alloc_boolean_cond(&mut cs, "alloc cond", condition); + + let enforce_ret = a_var.conditional_enforce_cmp(cs.ns(|| "enforce a <= a"), &a_var, &cond, Ordering::Less, true); + handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, cond.is_constant(), "cond enforce less equal on same variable test"); + let enforce_ret = a_var.conditional_enforce_cmp(cs.ns(|| "enforce a < a"), &a_var, &cond, Ordering::Less, false); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_ret, var_type_op1, &VariableType::Constant, cond.is_constant(), "cond enforce less on same variable test"); + } } } } diff --git a/r1cs/gadgets/std/src/cmp.rs b/r1cs/gadgets/std/src/cmp.rs index 3a52a669c..6d690cd2a 100644 --- a/r1cs/gadgets/std/src/cmp.rs +++ b/r1cs/gadgets/std/src/cmp.rs @@ -44,17 +44,32 @@ pub trait ComparisonGadget: Sized + EqGadget /// Enforce the given order relationship between `self` and `other`. /// If `should_also_check_equality` is true, then the order relationship is not strict /// (e.g., `self <= other` is enforced rather than `self < other`). - // Default implementation calls `is_cmp` to get a Boolean which is true iff the order - // relationship holds, and then enforce this Boolean to be true fn enforce_cmp>( &self, mut cs: CS, other: &Self, ordering: Ordering, should_also_check_equality: bool, + ) -> Result<(), SynthesisError> { + self.conditional_enforce_cmp(&mut cs, other, &Boolean::constant(true), ordering, should_also_check_equality) + } + + /// Enforce the given order relationship between `self` and `other` if `should_enforce` is true, + /// enforce nothing otherwise. + /// If `should_also_check_equality` is true, then the order relationship is not strict + /// (e.g., `self <= other` is enforced rather than `self < other`). + // Default implementation calls `is_cmp` to get a Boolean which is true iff the order + // relationship holds, and then conditionally enforce this Boolean to be true + fn conditional_enforce_cmp>( + &self, + mut cs: CS, + other: &Self, + should_enforce: &Boolean, + ordering: Ordering, + should_also_check_equality: bool, ) -> Result<(), SynthesisError> { let is_cmp = self.is_cmp(cs.ns(|| "cmp outcome"), other, ordering, should_also_check_equality)?; - is_cmp.enforce_equal(cs.ns(|| "enforce cmp"), &Boolean::constant(true)) + is_cmp.conditional_enforce_equal(cs.ns(|| "enforce cmp"), &Boolean::constant(true), should_enforce) } } \ No newline at end of file diff --git a/r1cs/gadgets/std/src/fields/cmp.rs b/r1cs/gadgets/std/src/fields/cmp.rs index 0073f8403..1c5daf38d 100644 --- a/r1cs/gadgets/std/src/fields/cmp.rs +++ b/r1cs/gadgets/std/src/fields/cmp.rs @@ -130,15 +130,28 @@ impl FpGadget { /// Variant of `enforce_cmp` that assumes `self` and `other` are `<= (p-1)/2` and /// does not generate constraints to verify that. - fn enforce_cmp_unchecked>( + pub fn enforce_cmp_unchecked>( &self, mut cs: CS, other: &Self, ordering: Ordering, should_also_check_equality: bool, + ) -> Result<(), SynthesisError> { + self.conditional_enforce_cmp_unchecked(&mut cs, other, &Boolean::constant(true), ordering, should_also_check_equality) + } + + /// Variant of `conditional_enforce_cmp` that assumes `self` and `other` are `<= (p-1)/2` and + /// does not generate constraints to verify that. + pub fn conditional_enforce_cmp_unchecked>( + &self, + mut cs: CS, + other: &Self, + should_enforce: &Boolean, + ordering: Ordering, + should_also_check_equality: bool, ) -> Result<(), SynthesisError> { let is_cmp = self.is_cmp_unchecked(cs.ns(|| "is cmp unchecked"), other, ordering, should_also_check_equality)?; - is_cmp.enforce_equal(cs.ns(|| "enforce cmp"), &Boolean::constant(true)) + is_cmp.conditional_enforce_equal(cs.ns(|| "conditionally enforce cmp"), &Boolean::constant(true), should_enforce) } /// Variant of `is_cmp` that assumes `self` and `other` are `<= (p-1)/2` and does not generate @@ -190,8 +203,8 @@ mod test { use r1cs_core::{ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode}; use crate::{algebra::{UniformRand, PrimeField, fields::tweedle::Fr, Group, - }, fields::fp::FpGadget}; - use crate::{alloc::{AllocGadget, ConstantGadget}, cmp::ComparisonGadget}; + }, fields::{fp::FpGadget, FieldGadget}}; + use crate::{alloc::{AllocGadget, ConstantGadget}, cmp::ComparisonGadget, boolean::Boolean}; fn rand_in_range(rng: &mut R) -> Fr { let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); @@ -206,9 +219,10 @@ mod test { } macro_rules! test_cmp_function { - ($cmp_func: tt) => { + ($cmp_func: tt, $should_enforce: expr, $should_fail_with_invalid_operands: expr) => { let mut rng = &mut thread_rng(); - for i in 0..10 { + let should_enforce = Boolean::constant($should_enforce); + for _i in 0..10 { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); let a = rand_in_range(&mut rng); @@ -218,19 +232,15 @@ mod test { match a.cmp(&b) { Ordering::Less => { - a_var.$cmp_func(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false).unwrap(); - a_var.$cmp_func(cs.ns(|| "enforce less equal"), &b_var, Ordering::Less, true).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less"), &b_var, &should_enforce, Ordering::Less, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less equal"), &b_var, &should_enforce, Ordering::Less, true).unwrap(); } Ordering::Greater => { - a_var.$cmp_func(cs.ns(|| "enforce greater"), &b_var, Ordering::Greater, false).unwrap(); - a_var.$cmp_func(cs.ns(|| "enforce greater equal"), &b_var, Ordering::Greater, true).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce greater"), &b_var, &should_enforce, Ordering::Greater, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce greater equal"), &b_var, &should_enforce, Ordering::Greater, true).unwrap(); } _ => {} } - - if i == 0 { - println!("number of constraints: {}", cs.num_constraints()); - } if !cs.is_satisfied(){ println!("{:?}", cs.which_is_unsatisfied()); } @@ -247,32 +257,32 @@ mod test { match b.cmp(&a) { Ordering::Less => { - a_var.$cmp_func(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false).unwrap(); - a_var.$cmp_func(cs.ns(|| "enforce less equal"),&b_var, Ordering::Less, true).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less"), &b_var, &should_enforce, Ordering::Less, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less equal"),&b_var, &should_enforce, Ordering::Less, true).unwrap(); } Ordering::Greater => { - a_var.$cmp_func(cs.ns(|| "enforce greater"),&b_var, Ordering::Greater, false).unwrap(); - a_var.$cmp_func(cs.ns(|| "enforce greater equal"),&b_var, Ordering::Greater, true).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce greater"),&b_var, &should_enforce, Ordering::Greater, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce greater equal"),&b_var, &should_enforce, Ordering::Greater, true).unwrap(); } _ => {} } - assert!(!cs.is_satisfied()); + assert!($should_enforce ^ cs.is_satisfied()); // check that constraints are satisfied iff should_enforce == false } for _i in 0..10 { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); let a = rand_in_range(&mut rng); let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - a_var.$cmp_func(cs.ns(|| "enforce less"),&a_var, Ordering::Less, false).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less"),&a_var, &should_enforce, Ordering::Less, false).unwrap(); - assert!(!cs.is_satisfied()); + assert!($should_enforce ^ cs.is_satisfied()); } for _i in 0..10 { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); let a = rand_in_range(&mut rng); let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - a_var.$cmp_func(cs.ns(|| "enforce less"),&a_var, Ordering::Less, true).unwrap(); + a_var.$cmp_func(cs.ns(|| "enforce less"),&a_var, &should_enforce, Ordering::Less, true).unwrap(); if !cs.is_satisfied(){ println!("{:?}", cs.which_is_unsatisfied()); } @@ -284,26 +294,28 @@ mod test { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); let max_val: Fr = Fr::modulus_minus_one_div_two().into(); let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); - let zero_var = FpGadget::::from_value(cs.ns(|| "alloc zero"), &Fr::zero()); - zero_var.$cmp_func(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var, Ordering::Less, true).unwrap(); + let zero_var = FpGadget::::zero(cs.ns(|| "alloc zero")).unwrap(); + zero_var.$cmp_func(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var, &should_enforce, Ordering::Less, true).unwrap(); assert!(cs.is_satisfied()); // test when one of the operands is beyond (p-1)/2 let out_range_var = FpGadget::::alloc(&mut cs.ns(|| "generate_out_range"), || Ok(max_val.double())).unwrap(); - zero_var.$cmp_func(cs.ns(|| "enforce 0 <= p-1"), &out_range_var, Ordering::Less, true).unwrap(); - assert!(!cs.is_satisfied()); + zero_var.$cmp_func(cs.ns(|| "enforce 0 <= p-1"), &out_range_var, &should_enforce, Ordering::Less, true).unwrap(); + assert!($should_fail_with_invalid_operands ^ cs.is_satisfied()); } } #[test] fn test_cmp() { - test_cmp_function!(enforce_cmp); + test_cmp_function!(conditional_enforce_cmp, true, true); + test_cmp_function!(conditional_enforce_cmp, false, true); } #[test] fn test_cmp_unchecked() { - test_cmp_function!(enforce_cmp_unchecked); + test_cmp_function!(conditional_enforce_cmp_unchecked, true, true); + test_cmp_function!(conditional_enforce_cmp_unchecked, false, false); } macro_rules! test_smaller_than_func { From 7a293f8b516336597a7dc12aa86ab055ee024099 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Thu, 10 Mar 2022 11:51:26 +0100 Subject: [PATCH 11/18] Address requested changes in PR --- r1cs/gadgets/std/Cargo.toml | 1 + r1cs/gadgets/std/src/bits/boolean.rs | 62 +++- r1cs/gadgets/std/src/bits/macros.rs | 383 ++++++++++++----------- r1cs/gadgets/std/src/bits/mod.rs | 12 +- r1cs/gadgets/std/src/cmp.rs | 14 +- r1cs/gadgets/std/src/eq.rs | 264 ++++++---------- r1cs/gadgets/std/src/fields/cmp.rs | 447 +++++++++++++++------------ r1cs/gadgets/std/src/lib.rs | 2 +- 8 files changed, 639 insertions(+), 546 deletions(-) diff --git a/r1cs/gadgets/std/Cargo.toml b/r1cs/gadgets/std/Cargo.toml index c8a1b9337..821c96988 100644 --- a/r1cs/gadgets/std/Cargo.toml +++ b/r1cs/gadgets/std/Cargo.toml @@ -35,6 +35,7 @@ num-traits = { version = "=0.2.14", default-features = false, optional = true } num-bigint = { version = "=0.4.3", default-features = false, optional = true } num-integer = { version = "=0.1.44", default-features = false, optional = true } hex = "=0.4.3" +paste = "=1.0.6" [features] llvm_asm = ["algebra/llvm_asm"] diff --git a/r1cs/gadgets/std/src/bits/boolean.rs b/r1cs/gadgets/std/src/bits/boolean.rs index d69de5c0c..b170186ab 100644 --- a/r1cs/gadgets/std/src/bits/boolean.rs +++ b/r1cs/gadgets/std/src/bits/boolean.rs @@ -638,6 +638,7 @@ impl Boolean { ConstraintF: PrimeField, CS: ConstraintSystemAbstract, { + assert!(bits.len() <= ConstraintF::Params::CAPACITY as usize); // this is done with a single constraint as follows: // - Compute a linear combination sum_lc which is the sum of all the bits // - enforce that the sum != 0 with a single constraint: sum*v = 1, where v can only be @@ -672,11 +673,11 @@ impl Boolean { match sum.inverse() { Some(val) => val, None => ConstraintF::one(), // if sum == 0, then inverse can be any value, the constraint should never be verified - }); + }) ; let inv_var = FpGadget::::alloc(cs.ns(|| "alloc inv"), || inv.ok_or(SynthesisError::AssignmentMissing))?; - cs.enforce(|| "enforce self != other", |_| sum_lc, |lc| &inv_var.get_variable() + lc, |_| (ConstraintF::one(), CS::one()).into()); + cs.enforce(|| "enforce or", |_| sum_lc, |lc| &inv_var.get_variable() + lc, |_| (ConstraintF::one(), CS::one()).into()); Ok(()) } @@ -712,10 +713,10 @@ impl Boolean { mut cs: CS, bits: &[Self], element: impl AsRef<[u64]>, - ) -> Result, SynthesisError> - where - ConstraintF: Field, - CS: ConstraintSystemAbstract, + ) -> Result, SynthesisError> + where + ConstraintF: Field, + CS: ConstraintSystemAbstract, { let b: &[u64] = element.as_ref(); @@ -733,13 +734,20 @@ impl Boolean { if bits.len() > element_num_bits { let mut or_result = Boolean::constant(false); for (i, should_be_zero) in bits[element_num_bits..].into_iter().enumerate() { - or_result = Boolean::or(cs.ns(|| format!("or {} {}", should_be_zero.get_value().unwrap(), i)), &or_result, should_be_zero)?; + or_result = Boolean::or( + cs.ns(|| format!("or {} {}", should_be_zero.get_value().unwrap(), i)), + &or_result, + should_be_zero, + )?; let _ = bits_iter.next().unwrap(); } or_result.enforce_equal(cs.ns(|| "enforce equal"), &Boolean::constant(false))?; } - for (i, (b, a)) in BitIterator::without_leading_zeros(b).zip(bits_iter.by_ref()).enumerate() { + for (i, (b, a)) in BitIterator::without_leading_zeros(b) + .zip(bits_iter.by_ref()) + .enumerate() + { if b { // This is part of a run of ones. current_run.push(a.clone()); @@ -759,13 +767,49 @@ impl Boolean { // If `last_run` is false, `a` can be true or false. // // Ergo, at least one of `last_run` and `a` must be false. - Self::enforce_nand(cs.ns(|| format!("enforce nand {}", i)), &[last_run.clone(), a.clone()])?; + Self::enforce_nand( + cs.ns(|| format!("enforce nand {}", i)), + &[last_run.clone(), a.clone()], + )?; } } assert!(bits_iter.next().is_none()); Ok(current_run) } + + /// Given a sequence `bits` of Booleans, constructs a `LinearCombination` + /// representing the field element whose little-endian bit representation corresponds to the + /// input `bits`. `one` represents the fixed variable of a constraint system employed to represent + /// constants in the linear combination. The function returns the constructed linear combination and + /// the field element represented by the linear combination, if all the `bits` in the sequence have values. + /// In addition, the function also returns a flag which specifies if there are no "real" variables + /// in the linear combination, that is if the sequence of Booleans comprises all constant values. + /// Assumes that `bits` can be packed in a single field element (i.e., bits.len() <= ConstraintF::Params::CAPACITY). + pub fn bits_to_linear_combination<'a, ConstraintF:Field>(bits: impl Iterator, one: Variable) -> (LinearCombination, Option, bool) + { + let mut lc = LinearCombination::zero(); + let mut coeff = ConstraintF::one(); + let mut lc_in_field = Some(ConstraintF::zero()); + let mut all_constants = true; + for bit in bits { + lc = lc + &bit.lc(one, coeff); + all_constants &= bit.is_constant(); + lc_in_field = match bit.get_value() { + Some(b) => lc_in_field.as_mut().map(|val| { + if b { + *val += coeff + } + *val + }), + None => None, + }; + + coeff.double_in_place(); + } + + (lc, lc_in_field, all_constants) + } } impl PartialEq for Boolean { diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index efe1347d1..b5a080bff 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -12,7 +12,6 @@ macro_rules! impl_uint_gadget { use std::{borrow::Borrow, ops::{Shl, Shr}, convert::TryInto, cmp::Ordering}; - //ToDo: remove public use of fields #[derive(Clone, Debug)] pub struct $type_name { // Least significant bit_gadget first @@ -61,7 +60,7 @@ macro_rules! impl_uint_gadget { /// `ConstraintF` elements, (thus reducing the number of input allocations), /// and then converts this list of `ConstraintF` gadgets back into /// bits and then packs chunks of such into `Self`. - pub fn alloc_input_vec( + fn alloc_input_vec_from_bytes( mut cs: CS, values: &[u8], ) -> Result, SynthesisError> @@ -118,17 +117,36 @@ macro_rules! impl_uint_gadget { .collect::>()?) } - /// Construct a constant vector of `Self` from a vector of `u8` - pub fn constant_vec(values: &[u8]) -> Vec { - const BYTES_PER_ELEMENT: usize = $bit_size/8; + /// Allocates a vector of `Self` from a slice of values of $native_type by serializing + /// them to sequence of bytes and then converting them to a sequence of `ConstraintF` + /// elements, (thus reducing the number of input allocations); + ///Then, this list of `ConstraintF` gadgets is converted back into + /// bits and then packs chunks of such into `Self`. + pub fn alloc_input_vec( + cs: CS, + values: &[T], + ) -> Result, SynthesisError> + where + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, + T: Into<$native_type> + Copy, + { + // convert values to vector of bytes + let mut values_to_bytes = Vec::new(); + for val in values { + let val_bytes = Into::<$native_type>::into(*val).to_le_bytes(); + values_to_bytes.extend_from_slice(&val_bytes[..]); + } + // alloc vector of `Self` from vector of bytes, minimizing the number of + // `ConstraintF` elements allocated + Self::alloc_input_vec_from_bytes(cs, values_to_bytes.as_slice()) + } + + /// Construct a constant vector of `Self` from a vector of `$native_type` + pub fn constant_vec + Copy>(values: &[T]) -> Vec { let mut result = Vec::new(); - for bytes in values.chunks(BYTES_PER_ELEMENT) { - let mut value: $native_type = 0; - for (i, byte) in bytes.iter().enumerate() { - let byte: $native_type = (*byte).into(); - value |= byte << (i*8); - } - result.push(Self::constant(value)); + for val in values { + result.push(Self::constant((*val).into())); } result } @@ -350,6 +368,7 @@ macro_rules! impl_uint_gadget { F: FnOnce() -> Result, T: Borrow<$native_type> { + assert!($bit_size <= ConstraintF::Params::CAPACITY); let mut value = None; let field_element = FpGadget::::alloc_input(cs.ns(|| "alloc_input as field element"), || { let val = value_gen().map(|val| *val.borrow())?; @@ -475,7 +494,7 @@ macro_rules! impl_uint_gadget { first: &Self, second: &Self, ) -> Result { - let bits = first.bits.iter().zip(second.bits.iter()).enumerate().map(|(i, (t, f))| Boolean::conditionally_select(&mut cs.ns(|| format!("cond select bit {}", i)), cond, t, f)).collect::, SynthesisError>>()?; + let bits = first.bits.iter().zip(second.bits.iter()).enumerate().map(|(i, (t, f))| Boolean::conditionally_select(cs.ns(|| format!("cond select bit {}", i)), cond, t, f)).collect::, SynthesisError>>()?; assert_eq!(bits.len(), $bit_size); // this assert should always be verified if first and second are built only with public methods @@ -500,7 +519,7 @@ macro_rules! impl_uint_gadget { fn shl(self, rhs: usize) -> Self::Output { let by = if rhs >= $bit_size { - $bit_size-1 + panic!("overflow due to left shift of {} bits for {}", rhs, stringify!($type_name)); } else { rhs }; @@ -524,7 +543,7 @@ macro_rules! impl_uint_gadget { fn shr(self, rhs: usize) -> Self::Output { let by = if rhs >= $bit_size { - $bit_size-1 + panic!("overflow due to right shift of {} bits for {}", rhs, stringify!($type_name)); } else { rhs }; @@ -591,7 +610,7 @@ macro_rules! impl_uint_gadget { let bits = self.bits.iter() .zip(other.bits.iter()) .enumerate() - .map(|(i , (b1, b2))| Boolean::$boolean_func(cs.ns(|| format!("xor bit {}", i)), &b1, &b2)) + .map(|(i , (b1, b2))| Boolean::$boolean_func(cs.ns(|| format!("apply binary operation to bit {}", i)), &b1, &b2)) .collect::, SynthesisError>>()?; let value = match other.value { @@ -614,20 +633,12 @@ macro_rules! impl_uint_gadget { // obtain the final outcome of the operation applied to all the operands macro_rules! handle_numoperands_opmany { ($opmany_func: tt, $cs: tt, $operands: tt, $max_num_operands: tt) => { - let num_operands = $operands.len(); // compute the aggregate result over batches of max_num_operands let mut result = $type_name::$opmany_func($cs.ns(|| "first batch of operands"), &$operands[..$max_num_operands])?; - let mut operands_processed = $max_num_operands; - while operands_processed < num_operands { - let last_op_to_process = if operands_processed + $max_num_operands - 1 > num_operands { - num_operands - } else { - operands_processed + $max_num_operands - 1 - }; - let mut next_operands = $operands[operands_processed..last_op_to_process].iter().cloned().collect::>(); - next_operands.push(result); - result = $type_name::$opmany_func($cs.ns(|| format!("operands from {} to {}", operands_processed, last_op_to_process)), &next_operands[..])?; - operands_processed += $max_num_operands - 1; + for (i, next_operands) in $operands[$max_num_operands..].chunks($max_num_operands-1).enumerate() { + let mut current_batch = vec![result]; + current_batch.extend_from_slice(next_operands); + result = $type_name::$opmany_func($cs.ns(|| format!("{}-th batch of operands", i+1)), current_batch.as_slice())?; } return Ok(result); } @@ -719,14 +730,9 @@ macro_rules! impl_uint_gadget { }, }; - let mut coeff = ConstraintF::one(); - for bit in &op.bits { - lc = lc + &bit.lc(CS::one(), coeff); - - all_constants &= bit.is_constant(); - - coeff.double_in_place(); - } + let (current_lc, _, is_op_constant) = Boolean::bits_to_linear_combination(op.bits.iter(), CS::one()); + lc = lc + current_lc; + all_constants &= is_op_constant; } if all_constants && result_value.is_some() { @@ -821,14 +827,9 @@ macro_rules! impl_uint_gadget { None => None, }; - let mut coeff = ConstraintF::one(); - for bit in &op.bits { - lc = lc + &bit.lc(CS::one(), coeff); - - all_constants &= bit.is_constant(); - - coeff.double_in_place(); - } + let (current_lc, _, is_op_constant) = Boolean::bits_to_linear_combination(op.bits.iter(), CS::one()); + lc = lc + current_lc; + all_constants &= is_op_constant; } if all_constants && result_value.is_some() { @@ -837,17 +838,9 @@ macro_rules! impl_uint_gadget { } return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &result_value.unwrap())); } - let result_var = $type_name::alloc(cs.ns(|| "alloc result"), || result_value.ok_or(SynthesisError::AssignmentMissing))?; - let mut coeff = ConstraintF::one(); - let mut result_lc = LinearCombination::zero(); - - for bit in result_var.bits.iter() { - result_lc = result_lc + &bit.lc(CS::one(), coeff); - - coeff.double_in_place(); - } + let (result_lc, _, _) = Boolean::bits_to_linear_combination(result_var.bits.iter(), CS::one()); cs.get_root().enforce_equal($bit_size, &lc, &result_lc); @@ -1031,16 +1024,10 @@ macro_rules! impl_uint_gadget { let max_value = ConstraintF::from($native_type::MAX) + ConstraintF::one(); // lc will be constructed as SUB(self,other)+2^$bit_size let mut lc = (max_value, CS::one()).into(); - let mut coeff = ConstraintF::one(); - let mut all_constants = true; - for (self_bit, other_bit) in self.bits.iter().zip(other.bits.iter()) { - lc = lc + &self_bit.lc(CS::one(), coeff); - lc = lc - &other_bit.lc(CS::one(), coeff); - - all_constants &= self_bit.is_constant() && other_bit.is_constant(); - - coeff.double_in_place(); - } + let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(self.bits.iter(), CS::one()); + let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(other.bits.iter(), CS::one()); + lc = lc + self_lc - other_lc; + let all_constants = is_self_constant && is_other_constant; // diff = self - other mod 2^$bit_size, // while diff_in_field = self - other + 2^$bit_size over the ConstraintF field @@ -1114,17 +1101,10 @@ macro_rules! impl_uint_gadget { */ // lc is constructed as SUB(self, other) - let mut lc = LinearCombination::zero(); - let mut coeff = ConstraintF::one(); - let mut all_constants = true; - for (self_bit, other_bit) in self.bits.iter().zip(other.bits.iter()) { - lc = lc + &self_bit.lc(CS::one(), coeff); - lc = lc - &other_bit.lc(CS::one(), coeff); - - all_constants &= self_bit.is_constant() && other_bit.is_constant(); - - coeff.double_in_place(); - } + let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(self.bits.iter(), CS::one()); + let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(other.bits.iter(), CS::one()); + let lc = self_lc - other_lc; + let all_constants = is_self_constant && is_other_constant; let (diff, is_underflowing) = match (self.value, other.value) { (Some(val1), Some(val2)) => { @@ -1145,13 +1125,7 @@ macro_rules! impl_uint_gadget { let diff_var = Self::alloc(cs.ns(|| "alloc diff"), || diff.ok_or(SynthesisError::AssignmentMissing))?; - let mut diff_lc = LinearCombination::zero(); - let mut coeff = ConstraintF::one(); - for diff_bit in diff_var.bits.iter() { - diff_lc = diff_lc + &diff_bit.lc(CS::one(), coeff); - - coeff.double_in_place(); - } + let (diff_lc, _, _) = Boolean::bits_to_linear_combination(diff_var.bits.iter(), CS::one()); cs.get_root().enforce_equal($bit_size, &lc, &diff_lc); @@ -1197,18 +1171,12 @@ macro_rules! impl_uint_gadget { self.sub(multi_eq.ns(|| "a - b mod 2^n"), other)? }; - let mut delta_lc = LinearCombination::zero(); - let mut coeff = ConstraintF::one(); - let mut all_constants = true; - for ((self_bit, other_bit), diff_bit) in self.bits.iter().zip(other.bits.iter()).zip(diff_var.bits.iter()) { - delta_lc = delta_lc + &self_bit.lc(CS::one(), coeff); - delta_lc = delta_lc - &other_bit.lc(CS::one(), coeff); - delta_lc = delta_lc - &diff_bit.lc(CS::one(), coeff); - - all_constants &= self_bit.is_constant() && other_bit.is_constant(); + let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(self.bits.iter(), CS::one()); + let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(other.bits.iter(), CS::one()); + let (diff_lc, _, is_diff_constant) = Boolean::bits_to_linear_combination(diff_var.bits.iter(), CS::one()); - coeff.double_in_place(); - } + let delta_lc = self_lc - other_lc - diff_lc; + let all_constants = is_self_constant && is_other_constant && is_diff_constant; let (diff_val, is_underflowing, delta) = match (self.get_value(), other.get_value()) { (Some(value1), Some(value2)) => { @@ -1228,7 +1196,7 @@ macro_rules! impl_uint_gadget { return Ok(Boolean::constant(is_underflowing.unwrap())) } - // ToDo: It should not be necessary to allocate it as a Boolean gadget, + //ToDo: It should not be necessary to allocate it as a Boolean gadget, // can be done when a Boolean::from(FieldGadget) will be implemented let is_smaller = Boolean::alloc(cs.ns(|| "alloc result"), || is_underflowing.ok_or(SynthesisError::AssignmentMissing))?; @@ -1243,7 +1211,7 @@ macro_rules! impl_uint_gadget { // enforce constraints: // (1 - is_smaller) * delta_lc = 0 enforces that is_smaller == 1 when delta != 0, i.e., when a < b - // inv * delta_lc = is_smaller enforces that is_smaller == 0 when delta == 0, i.e., when b >= a + // inv * delta_lc = is_smaller enforces that is_smaller == 0 when delta == 0, i.e., when a >= b cs.enforce(|| "enforce is smaller == true", |_| is_smaller.not().lc(CS::one(), ConstraintF::one()), |lc| lc + &delta_lc, |lc| lc); cs.enforce(|| "enforce is smaller == false", |lc| &inv_var.get_variable() + lc, |lc| lc + &delta_lc, |_| is_smaller.lc(CS::one(), ConstraintF::one())); @@ -1297,7 +1265,7 @@ macro_rules! impl_uint_gadget { ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, SynthesisError, }; - use std::{ops::{Shl, Shr}, cmp::Ordering}; + use std::{ops::{Shl, Shr}, cmp::Ordering, cmp::max}; use crate::{alloc::{AllocGadget, ConstantGadget}, eq::{EqGadget, MultiEq}, boolean::Boolean, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8, cmp::ComparisonGadget}; @@ -1408,6 +1376,7 @@ macro_rules! impl_uint_gadget { witness.enforce_equal(cs.ns(|| "enforce val == val+1"), &witness_ne).unwrap(); assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "enforce val == val+1/conditionally enforce equal for chunk 0"); } @@ -1501,32 +1470,29 @@ macro_rules! impl_uint_gadget { let vec_len: usize = rng.gen_range($bit_size..$bit_size*2); - // allocate input vector of VEC_LEN random bytes - let input_vec = (0..vec_len).map(|_| rng.gen()).collect::>(); + // allocate input vector of VEC_LEN $native_type elements + let input_vec = (0..vec_len).map(|_| rng.gen()).collect::>(); let alloc_vec = $type_name::alloc_input_vec(cs.ns(|| "alloc input vec"), &input_vec).unwrap(); - for (i, (input_bytes, alloc_el)) in input_vec.chunks_exact($bit_size/8).zip(alloc_vec.iter()).enumerate() { - let input_bytes_gadgets = UInt8::constant_vec(&input_bytes); - let input_el = $type_name::from_bytes(&input_bytes_gadgets).unwrap(); + for (i, (input_value, alloc_el)) in input_vec.iter().zip(alloc_vec.iter()).enumerate() { + let input_el = $type_name::constant(*input_value); input_el.enforce_equal(cs.ns(|| format!("eq for chunk {}", i)), &alloc_el).unwrap(); - assert_eq!(input_el.get_value().unwrap(), alloc_el.get_value().unwrap()); + assert_eq!(*input_value, alloc_el.get_value().unwrap()); } assert!(cs.is_satisfied()); - // test allocation of vector of constants from vector of bytes + // test allocation of vector of constants from vector of $native_type elements let constant_vec = $type_name::constant_vec(&input_vec); - for (i, (input_bytes, alloc_el)) in input_vec.chunks($bit_size/8).zip(constant_vec.iter()).enumerate() { - let input_bytes_gadgets = input_bytes.iter().enumerate() - .map(|(j, byte)| UInt8::from_value(cs.ns(|| format!("alloc byte {} in chunk {}", j, i)), byte)) - .collect::>(); - let input_el = $type_name::from_bytes(&input_bytes_gadgets).unwrap(); + for (i, (input_value, alloc_el)) in input_vec.iter().zip(constant_vec.iter()).enumerate() { + let input_el = $type_name::constant(*input_value); input_el.enforce_equal(cs.ns(|| format!("eq for chunk {} of constant vec", i)), &alloc_el).unwrap(); - assert_eq!(input_el.get_value().unwrap(), alloc_el.get_value().unwrap()); + assert_eq!(*input_value, alloc_el.get_value().unwrap()); } + assert!(cs.is_satisfied()); } @@ -1650,22 +1616,35 @@ macro_rules! impl_uint_gadget { test_uint_gadget_value((value >> by) << by, &alloc_var, format!("left shift by {} bits", by).as_str()); } + assert!(cs.is_satisfied()); + } + } - // check that shl(var, by) == shl(var, $bit_size-1) for by > $bit_size - let alloc_var = alloc_fn(&mut cs, "alloc var for invalid shl", var_type, value); - let by = $bit_size*2; - let shl_var = alloc_var.shl(by); - test_uint_gadget_value(value << $bit_size-1, &shl_var, "invalid left shift"); + #[test] + #[should_panic(expected="overflow due to left shift")] + fn test_invalid_left_shift() { + let rng = &mut thread_rng(); - // check that shr(var, by) == shr(var, $bit_size) for by > $bit_size - let alloc_var = alloc_fn(&mut cs, "alloc var for invalid shr", var_type, value); - let by = $bit_size*2; - let shr_var = alloc_var.shr(by); - test_uint_gadget_value(value >> $bit_size-1, &shr_var, "invalid right shift"); + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let var_type = &VARIABLE_TYPES[rng.gen::() % 3]; + let value: $native_type = rng.gen(); - assert!(cs.is_satisfied()); - } + let alloc_var = alloc_fn(&mut cs, "alloc var", var_type, value); + let _shl_var = alloc_var.shl($bit_size*2); + } + + #[test] + #[should_panic(expected="overflow due to right shift")] + fn test_invalid_right_shift() { + let rng = &mut thread_rng(); + + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let var_type = &VARIABLE_TYPES[rng.gen::() % 3]; + let value: $native_type = rng.gen(); + + let alloc_var = alloc_fn(&mut cs, "alloc var", var_type, value); + let _shl_var = alloc_var.shr($bit_size*2); } #[test] @@ -1722,17 +1701,6 @@ macro_rules! impl_uint_gadget { test_uint_gadget_value(res_and, &and_var, format!("and between {:?} {:?}", var_type_a, var_type_b).as_str()); test_uint_gadget_value(res_nand, &nand_var, format!("nand between {:?} {:?}", var_type_a, var_type_b).as_str()); - - let alloc_xor = alloc_fn(&mut cs, "alloc xor result", var_type_a, res_xor); - let alloc_or = alloc_fn(&mut cs, "alloc or result", var_type_b, res_or); - let alloc_and = alloc_fn(&mut cs, "alloc and result", var_type_a, res_and); - let alloc_nand = alloc_fn(&mut cs, "alloc nand result", var_type_b, res_nand); - - alloc_xor.enforce_equal(cs.ns(|| "check xor result"), &xor_var).unwrap(); - alloc_or.enforce_equal(cs.ns(|| "check or result"), &or_var).unwrap(); - alloc_and.enforce_equal(cs.ns(|| "check and result"), &and_var).unwrap(); - alloc_nand.enforce_equal(cs.ns(|| "check nand result"), &nand_var).unwrap(); - assert!(cs.is_satisfied()); } } @@ -1813,6 +1781,7 @@ macro_rules! impl_uint_gadget { cs.set(bit_gadget_path, Fr::zero()); } assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "multieq 0"); // test with all constants let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); @@ -1867,6 +1836,7 @@ macro_rules! impl_uint_gadget { cs.set(bit_gadget_path, Fr::zero()); } assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul operands/first batch of operands/unpack result field element/unpacking_constraint"); // set bit value back if cs.get(bit_gadget_path).is_zero() { @@ -1877,17 +1847,15 @@ macro_rules! impl_uint_gadget { assert!(cs.is_satisfied()); // negative test on allocated field element: skip if double and add must be used because the field is too small - let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); - if last_batch_start_operand == NUM_OPERANDS { - last_batch_start_operand -= MAX_NUM_OPERANDS-1; - } - let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); + let num_batches = (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1); + let bit_gadget_path = format!("mul operands/{}-th batch of operands/unpack result field element/bit 0/boolean", num_batches); if cs.get(&bit_gadget_path).is_zero() { cs.set(&bit_gadget_path, Fr::one()); } else { cs.set(&bit_gadget_path, Fr::zero()); } assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), format!("mul operands/{}-th batch of operands/unpack result field element/unpacking_constraint", num_batches)); // set bit value back if cs.get(&bit_gadget_path).is_zero() { @@ -1912,7 +1880,7 @@ macro_rules! impl_uint_gadget { } #[test] - fn test_modular_arithmetic_operations() { + fn test_conditional_modular_arithmetic_operations() { let rng = &mut thread_rng(); for condition in BOOLEAN_TYPES.iter() { for var_type_op1 in VARIABLE_TYPES.iter() { @@ -1994,6 +1962,7 @@ macro_rules! impl_uint_gadget { cs.set(bit_gadget_path, Fr::zero()); } assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "multieq 0"); // set bit value back if cs.get(bit_gadget_path).is_zero() { @@ -2045,6 +2014,7 @@ macro_rules! impl_uint_gadget { // result should still be corrected, but constraints should not be verified test_uint_gadget_value(result_value, &result_var, "result of overflowing add correctness"); assert!(!cs.is_satisfied(), "checking overflow constraint"); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "multieq 0"); } #[test] @@ -2075,7 +2045,7 @@ macro_rules! impl_uint_gadget { test_uint_gadget_value(result_value, &result_var, "result correctness"); assert!(cs.is_satisfied()); - if MAX_NUM_OPERANDS >= 2 { // negative tests are skipped if if double and add must be used because the field is too small + if MAX_NUM_OPERANDS >= 2 { // negative tests are skipped if double and add must be used because the field is too small // negative test on first batch let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; if cs.get(bit_gadget_path).is_zero() { @@ -2084,6 +2054,7 @@ macro_rules! impl_uint_gadget { cs.set(bit_gadget_path, Fr::zero()); } assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul operands/first batch of operands/unpack result field element/unpacking_constraint"); // set bit value back if cs.get(bit_gadget_path).is_zero() { @@ -2095,17 +2066,15 @@ macro_rules! impl_uint_gadget { // negative test on allocated field element - let mut last_batch_start_operand = MAX_NUM_OPERANDS + (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1)*(MAX_NUM_OPERANDS-1); - if last_batch_start_operand == NUM_OPERANDS { - last_batch_start_operand -= MAX_NUM_OPERANDS-1; - } - let bit_gadget_path = format!("mul operands/operands from {} to {}/unpack result field element/bit 0/boolean", last_batch_start_operand, NUM_OPERANDS); + let num_batches = (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1); + let bit_gadget_path = format!("mul operands/{}-th batch of operands/unpack result field element/bit 0/boolean", num_batches); if cs.get(&bit_gadget_path).is_zero() { cs.set(&bit_gadget_path, Fr::one()); } else { cs.set(&bit_gadget_path, Fr::zero()); } assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), format!("mul operands/{}-th batch of operands/unpack result field element/unpacking_constraint", num_batches)); // set bit value back @@ -2131,7 +2100,8 @@ macro_rules! impl_uint_gadget { // check that constraints are not satisfied in case of overflow let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen_range(max_value..=$native_type::MAX)).collect::>(); + let min_value = 1 << ($bit_size/max(MAX_NUM_OPERANDS, 2)); // generate operands higher than this value to ensure that overflow occurs when processing the first batch of operands + let operand_values = (0..NUM_OPERANDS).map(|_| rng.gen_range(min_value..=$native_type::MAX)).collect::>(); let operands = operand_values.iter().enumerate().map(|(i, val)| { alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) @@ -2151,10 +2121,16 @@ macro_rules! impl_uint_gadget { test_uint_gadget_value(result_value, &result_var, "result of overflowing mul correctness"); assert!(!cs.is_satisfied()); + if MAX_NUM_OPERANDS >= 2 { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul overflowing operands/first batch of operands/unpack result field element/unpacking_constraint"); + } else { // double and add case + assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul overflowing operands/double and add/double and add first operands/to bits for digit 0/unpacking_constraint"); + } + } #[test] - fn test_no_carry_arithmetic_operations() { + fn test_conditional_no_carry_arithmetic_operations() { const OPERATIONS: [&str; 2] = ["add", "mul"]; let rng = &mut thread_rng(); for condition in BOOLEAN_TYPES.iter() { @@ -2217,11 +2193,12 @@ macro_rules! impl_uint_gadget { let op1_var = alloc_fn(&mut cs, "alloc op1", &var_type_op1, op1); let op2_var = alloc_fn(&mut cs, "alloc op2", &var_type_op2, op2); + let cond_var = alloc_boolean_cond(&mut cs, "alloc conditional", condition); let result = if is_add { // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped let mut multi_eq = MultiEq::new(&mut cs); - op1_var.conditionally_add_nocarry(&mut multi_eq, &cond_var, &op2_var) + op1_var.conditionally_add_nocarry(multi_eq.ns(|| "conditionally add no carry"), &cond_var, &op2_var) } else { op1_var.conditionally_mul_nocarry(&mut cs, &cond_var, &op2_var) }; @@ -2247,7 +2224,16 @@ macro_rules! impl_uint_gadget { op1 }, &result_var, format!("{} correctness", op).as_str()); assert!(!cs.is_satisfied(), "checking overflow constraint for {:?} {:?} {}", var_type_op1, var_type_op2, is_add); - + if is_add { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "multieq 0"); + } else { + let field_bits = (::Params::CAPACITY) as usize; + if field_bits < 2*$bit_size { // double and add case + assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul values/double and add/double and add first operands/to bits for digit 0/unpacking_constraint"); + } else { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul values/unpack result field element/unpacking_constraint"); + } + } } } } @@ -2317,6 +2303,8 @@ macro_rules! impl_uint_gadget { let left_op = alloc_fn(&mut cs, "alloc left op", var_type_left, left); let right_op = alloc_fn(&mut cs, "alloc right op", var_type_right, right); + let cond_var = alloc_boolean_cond(&mut cs, "alloc conditional", condition); + let result = { // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped let mut multi_eq = MultiEq::new(&mut cs); @@ -2332,7 +2320,7 @@ macro_rules! impl_uint_gadget { SynthesisError::Unsatisfiable => (), err => assert!(false, "invalid error returned by sub_noborrow: {}", err) }; - return; + continue; }, (_, _) => result.unwrap(), }; @@ -2343,6 +2331,7 @@ macro_rules! impl_uint_gadget { left }, &result_var, "sub with underflow correctness"); assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "multieq 0"); } } } @@ -2356,7 +2345,7 @@ macro_rules! impl_uint_gadget { // helper closure which is useful to deal with the error returned by enforce cmp // function if both the operands are constant and the comparison is // unsatisfiable on such constants - let handle_constant_operands = |cs: &ConstraintSystem::, must_be_satisfied: bool, cmp_result: Result<(), SynthesisError>, var_type_op1: &VariableType, var_type_op2: &VariableType, is_constant: bool, assertion_label| { + let handle_constant_operands = |cs: &ConstraintSystem::, must_be_satisfied: bool, cmp_result: Result<(), SynthesisError>, var_type_op1: &VariableType, var_type_op2: &VariableType, is_constant: bool, assertion_label, unsatisfied_constraint| { match (*var_type_op1, *var_type_op2, is_constant) { (VariableType::Constant, VariableType::Constant, true) => { if must_be_satisfied { @@ -2371,6 +2360,9 @@ macro_rules! impl_uint_gadget { _ => { cmp_result.unwrap(); assert!(!(cs.is_satisfied() ^ must_be_satisfied), "{} for {:?} {:?}", assertion_label, var_type_op1, var_type_op2); + if !must_be_satisfied { + assert_eq!(cs.which_is_unsatisfied().unwrap(), unsatisfied_constraint); + } } } }; @@ -2387,16 +2379,21 @@ macro_rules! impl_uint_gadget { let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); let is_smaller_var = a_var.is_smaller_than(cs.ns(|| "a < b"), &b_var).unwrap(); - let is_smaller = match a.cmp(&b) { + let (is_smaller, is_equal) = match a.cmp(&b) { Ordering::Less => { assert!(is_smaller_var.get_value().unwrap()); assert!(cs.is_satisfied(), "is smaller"); - true + (true, false) } - Ordering::Greater | Ordering::Equal => { + Ordering::Greater => { assert!(!is_smaller_var.get_value().unwrap()); assert!(cs.is_satisfied(), "is not smaller"); - false + (false, false) + } + Ordering::Equal => { + assert!(!is_smaller_var.get_value().unwrap()); + assert!(cs.is_satisfied(), "is not smaller"); + (false, true) } }; @@ -2407,13 +2404,13 @@ macro_rules! impl_uint_gadget { // test enforce_smaller_than let enforce_ret = a_var.enforce_smaller_than(cs.ns(|| "enforce a < b"), &b_var); - handle_constant_operands(&cs, is_smaller, enforce_ret, var_type_op1, var_type_op2, true, "enforce_smaller_than test"); + handle_constant_operands(&cs, is_smaller, enforce_ret, var_type_op1, var_type_op2, true, "enforce_smaller_than test", if is_equal {"enforce a < b/enforce self != other/enforce or"} else {"enforce a < b/multieq 0"}); // test equality let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); let enforce_ret = a_var.enforce_smaller_than(cs.ns(|| "enforce a < a"), &a_var); - handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce a < a test"); - + handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce a < a test", "enforce a < a/enforce self != other/enforce or"); // test all comparisons let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); @@ -2422,15 +2419,15 @@ macro_rules! impl_uint_gadget { match a.cmp(&b) { Ordering::Less => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce less test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce less test", ""); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less equal"), &b_var, Ordering::Less, true); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce less equal test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce less equal test", ""); } Ordering::Greater => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater"), &b_var, Ordering::Greater, false); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce greater test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce greater test", ""); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater equal"), &b_var, Ordering::Greater, true); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce greater equal test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, true, "enforce greater equal test", ""); } _ => {} } @@ -2440,16 +2437,24 @@ macro_rules! impl_uint_gadget { match b.cmp(&a) { Ordering::Less => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less"), &b_var, Ordering::Less, false); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce less negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce less negative test", "enforce less/enforce smaller than/multieq 0"); + // reinitialize cs to test also equality + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce less equal"),&b_var, Ordering::Less, true); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce less equal negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce less equal negative test", "enforce less equal/enforce greater equal/multieq 0"); } Ordering::Greater => { let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater"),&b_var, Ordering::Greater, false); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce greater negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce greater negative test", "enforce greater/enforce smaller than/multieq 0"); + // reinitialize cs to test also equality + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); let enforce_res = a_var.enforce_cmp(cs.ns(|| "enforce greater equal"),&b_var, Ordering::Greater, true); - handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce greater equal negative test"); + handle_constant_operands(&cs, false, enforce_res, var_type_op1, var_type_op2, true, "enforce greater equal negative test", "enforce greater equal/enforce greater equal/multieq 0"); } _ => {} } @@ -2459,9 +2464,9 @@ macro_rules! impl_uint_gadget { let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); let enforce_ret = a_var.enforce_cmp(cs.ns(|| "enforce a <= a"), &a_var, Ordering::Less, true); - handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce less equal on same variable test"); + handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce less equal on same variable test", ""); let enforce_ret = a_var.enforce_cmp(cs.ns(|| "enforce a < a"), &a_var, Ordering::Less, false); - handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce less on same variable test"); + handle_constant_operands(&cs, false, enforce_ret, var_type_op1, &VariableType::Constant, true, "enforce less on same variable test", "enforce a < a/enforce smaller than/enforce self != other/enforce or"); // test conditional_enforce_cmp for condition in BOOLEAN_TYPES.iter() { @@ -2474,32 +2479,50 @@ macro_rules! impl_uint_gadget { match a.cmp(&b) { Ordering::Less => { let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less"), &b_var, &cond, Ordering::Less, false); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce less test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce less test", ""); let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less equal"), &b_var, &cond, Ordering::Less, true); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce less equal test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce less equal test", ""); + let enforce_res = a_var.conditional_enforce_smaller_than(cs.ns(|| "conditional enforce smaller than"), &b_var, &cond); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "conditional enforce smaller than positive test", ""); } Ordering::Greater => { let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater"), &b_var, &cond, Ordering::Greater, false); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce greater test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce greater test", ""); let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater equal"), &b_var, &cond, Ordering::Greater, true); - handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce greater equal test"); + handle_constant_operands(&cs, true, enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "enforce greater equal test", ""); + let enforce_res = a_var.conditional_enforce_smaller_than(cs.ns(|| "conditional enforce smaller than"), &b_var, &cond); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "conditional enforce smaller than negative test", "conditional enforce smaller than/conditional enforce is smaller/conditional_equals"); } _ => {} } // negative tests + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); + let cond = alloc_boolean_cond(&mut cs, "alloc cond", condition); match b.cmp(&a) { Ordering::Less => { let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less"), &b_var, &cond, Ordering::Less, false); - handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce less negative test"); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce less negative test", "enforce less/conditional enforce cmp/conditional_equals"); + // reinitialize cs to test also equality + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); + let cond = alloc_boolean_cond(&mut cs, "alloc cond", condition); let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce less equal"),&b_var, &cond, Ordering::Less, true); - handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce less equal negative test"); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce less equal negative test", "enforce less equal/conditional enforce cmp/conditional_equals"); } Ordering::Greater => { let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater"),&b_var, &cond, Ordering::Greater, false); - handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce greater negative test"); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce greater negative test", "enforce greater/conditional enforce cmp/conditional_equals"); + // reinitialize cs to test also equality + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let b_var = alloc_fn(&mut cs, "alloc b", var_type_op2, b); + let cond = alloc_boolean_cond(&mut cs, "alloc cond", condition); let enforce_res = a_var.conditional_enforce_cmp(cs.ns(|| "enforce greater equal"),&b_var, &cond, Ordering::Greater, true); - handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce greater equal negative test"); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_res, var_type_op1, var_type_op2, cond.is_constant(), "cond enforce greater equal negative test", "enforce greater equal/conditional enforce cmp/conditional_equals"); } _ => {} } @@ -2509,9 +2532,15 @@ macro_rules! impl_uint_gadget { let cond = alloc_boolean_cond(&mut cs, "alloc cond", condition); let enforce_ret = a_var.conditional_enforce_cmp(cs.ns(|| "enforce a <= a"), &a_var, &cond, Ordering::Less, true); - handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, cond.is_constant(), "cond enforce less equal on same variable test"); - let enforce_ret = a_var.conditional_enforce_cmp(cs.ns(|| "enforce a < a"), &a_var, &cond, Ordering::Less, false); - handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_ret, var_type_op1, &VariableType::Constant, cond.is_constant(), "cond enforce less on same variable test"); + handle_constant_operands(&cs, true, enforce_ret, var_type_op1, &VariableType::Constant, cond.is_constant(), "cond enforce less equal on same variable test", ""); + let enforce_ret = a_var.conditional_enforce_cmp(cs.ns(|| "conditional enforce a > a"), &a_var, &cond, Ordering::Greater, false); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_ret, var_type_op1, &VariableType::Constant, cond.is_constant(), "cond enforce grater on same variable test", "conditional enforce a > a/conditional enforce cmp/conditional_equals"); + // reinitialize cs to test also enforce_smaller_than + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = alloc_fn(&mut cs, "alloc a", var_type_op1, a); + let cond = alloc_boolean_cond(&mut cs, "alloc cond", condition); + let enforce_ret = a_var.conditional_enforce_smaller_than(cs.ns(|| "conditional enforce a < a"), &a_var, &cond); + handle_constant_operands(&cs, !cond.get_value().unwrap(), enforce_ret, var_type_op1, &VariableType::Constant, cond.is_constant(), "cond enforce smaller than on same variable test", "conditional enforce a < a/conditional enforce is smaller/conditional_equals"); } } } diff --git a/r1cs/gadgets/std/src/bits/mod.rs b/r1cs/gadgets/std/src/bits/mod.rs index 1343e1f07..7c23efd01 100644 --- a/r1cs/gadgets/std/src/bits/mod.rs +++ b/r1cs/gadgets/std/src/bits/mod.rs @@ -6,10 +6,9 @@ use crate::eq::{EqGadget, MultiEq}; use crate::select::CondSelectGadget; use std::fmt::Debug; use std::ops::{Shl, Shr}; +//use rand::{Rng, thread_rng}; pub mod boolean; -//pub mod uint32; -//pub mod uint64; #[macro_use] pub mod macros; @@ -19,7 +18,12 @@ impl_uint_gadget!(UInt32, 32, u32, uint32); impl_uint_gadget!(UInt16, 16, u16, uint16); impl_uint_gadget!(UInt128, 128, u128, uint128); - +// This type alias allows to implement byte serialization/de-serialization functions inside the +// `impl_uint_gadget` macro. +// Indeed, the macro providing implementations of UIntGadget requires to refer to the u8 gadget +// type for byte serialization/de-serialization functions. The type alias allows to employ a type +// defined outside the macro in the interface of byte serialization functions, hence allowing to +// implement them inside the `impl_uint_gadget` macro pub type UInt8 = uint8::U8; pub trait ToBitsGadget { @@ -233,7 +237,7 @@ Sized M: ConstraintSystemAbstract> { let diff = self.sub_noborrow(cs.ns(|| "sub"), other)?; - Self::conditionally_select(cs.ns(|| "conditionally select result"), cond, &diff, self) + Self::conditionally_select(cs.ns(|| "conditionally select result"), cond, &diff, &self) } /// Perform modular multiplication of several `Self` objects. diff --git a/r1cs/gadgets/std/src/cmp.rs b/r1cs/gadgets/std/src/cmp.rs index 6d690cd2a..1fdb729b6 100644 --- a/r1cs/gadgets/std/src/cmp.rs +++ b/r1cs/gadgets/std/src/cmp.rs @@ -10,7 +10,17 @@ pub trait ComparisonGadget: Sized + EqGadget fn is_smaller_than>(&self, cs: CS, other: &Self) -> Result; /// Enforce in the constraint system `cs` that `self < other` - fn enforce_smaller_than>(&self, cs: CS, other: &Self) -> Result<(), SynthesisError>; + fn enforce_smaller_than>(&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { + self.is_smaller_than(cs.ns(|| "is smaller"), other)?.enforce_equal(cs.ns(|| "enforce is smaller"), &Boolean::constant(true)) + } + + /// Enforce that `self` < `other` if `cond` is true, enforce nothing otherwise + fn conditional_enforce_smaller_than>(&self, mut cs: CS, other: &Self, cond: &Boolean) -> Result<(), SynthesisError> { + if cond.is_constant() && !cond.get_value().unwrap() { + return Ok(()) // no need to enforce anything as cond is false + } + self.is_smaller_than(cs.ns(|| "is smaller"), other)?.conditional_enforce_equal(cs.ns(|| "conditional enforce is smaller"), &Boolean::constant(true), cond) + } /// Output a `Boolean` gadget which is true iff the given order relationship between `self` /// and `other` holds. If `should_also_check_equality` is true, then the order relationship @@ -70,6 +80,6 @@ pub trait ComparisonGadget: Sized + EqGadget ) -> Result<(), SynthesisError> { let is_cmp = self.is_cmp(cs.ns(|| "cmp outcome"), other, ordering, should_also_check_equality)?; - is_cmp.conditional_enforce_equal(cs.ns(|| "enforce cmp"), &Boolean::constant(true), should_enforce) + is_cmp.conditional_enforce_equal(cs.ns(|| "conditional enforce cmp"), &Boolean::constant(true), should_enforce) } } \ No newline at end of file diff --git a/r1cs/gadgets/std/src/eq.rs b/r1cs/gadgets/std/src/eq.rs index e0a0704c4..8733d2a4a 100644 --- a/r1cs/gadgets/std/src/eq.rs +++ b/r1cs/gadgets/std/src/eq.rs @@ -131,135 +131,130 @@ impl, ConstraintF: Field> EqGadget for [T] } } -// wrapper type employed to implement helper functions for the implementation of EqGadget for -// Vec -struct BooleanVec<'a>(&'a [Boolean]); -impl BooleanVec<'_> { - #[inline] - // helper function that computes a linear combination of the bits of `self` and `other` which - // corresponds to the difference between two field elements a,b, where a (resp. b) is the field - // element whose little-endian bit representation is `self` (resp. other). - // The function returns also a-b over the field (wrapped in an Option) and a flag - // that specifies if all the bits in both `self` and `other` are constants. - fn compute_diff(&self, _cs: CS, other: &Self) -> (LinearCombination, Option, bool) - where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract, - { - let self_bits = self.0; - let other_bits = other.0; - let field_bits = ConstraintF::Params::CAPACITY as usize; - assert!(self_bits.len() <= field_bits); - assert!(other_bits.len() <= field_bits); - - let mut self_lc = LinearCombination::zero(); - let mut other_lc = LinearCombination::zero(); - let mut coeff = ConstraintF::one(); - let mut diff_in_field = Some(ConstraintF::zero()); - let mut all_constants = true; - for (self_bit, other_bit) in self_bits.iter().zip(other_bits.iter()) { - self_lc = self_lc + &self_bit.lc(CS::one(), coeff); - other_lc = other_lc + &other_bit.lc(CS::one(), coeff); - - all_constants &= self_bit.is_constant() && other_bit.is_constant(); - - diff_in_field = match (self_bit.get_value(), other_bit.get_value()) { - (Some(bit1), Some(bit2)) => diff_in_field.as_mut().map(|diff| { - let self_term = if bit1 { - coeff - } else { - ConstraintF::zero() - }; - let other_term = if bit2 { - coeff - } else { - ConstraintF::zero() - }; - *diff += self_term - other_term; - *diff - }), - _ => None, - }; - - coeff.double_in_place(); - } - (self_lc - other_lc, diff_in_field, all_constants) - } - // is_eq computes a Boolean which is true iff `self` == `other`. This function requires that - // `self` and `other` are bit sequences with length at most the capacity of the field ConstraintF. - fn is_eq(&self, mut cs: CS, other: &Self) -> Result +// helper function that computes a linear combination of the bits of `self` and `other` which +// corresponds to the difference between two field elements a,b, where a (resp. b) is the field +// element whose little-endian bit representation is `self` (resp. other). +// The function returns also a-b over the field (wrapped in an Option) and a flag +// that specifies if all the bits in both `self` and `other` are constants. +fn compute_diff(self_bits: &[Boolean], mut _cs: CS, other_bits: &[Boolean]) -> (LinearCombination, Option, bool) where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract, - { + ConstraintF: PrimeField, + CS: ConstraintSystemAbstract, +{ + let field_bits = ConstraintF::Params::CAPACITY as usize; + assert!(self_bits.len() <= field_bits); + assert!(other_bits.len() <= field_bits); - let (diff_lc, diff_in_field, all_constants) = self.compute_diff(&mut cs, other); + let (self_lc, self_in_field, is_self_constant) = Boolean::bits_to_linear_combination(self_bits.iter(), CS::one()); + let (other_lc, other_in_field, is_other_constant) = Boolean::bits_to_linear_combination(other_bits.iter(), CS::one()); - if all_constants && diff_in_field.is_some() { - return Ok(Boolean::constant(diff_in_field.unwrap().is_zero())); - } + let diff_in_field = match (self_in_field, other_in_field) { + (Some(self_val), Some(other_val)) => Some(self_val-other_val), + _ => None, + }; + let all_constants = is_self_constant && is_other_constant; - let is_eq = Boolean::alloc(cs.ns(|| "alloc result"), || { - let diff = diff_in_field.ok_or(SynthesisError::AssignmentMissing)?; - Ok(diff.is_zero()) - })?; + (self_lc - other_lc, diff_in_field, all_constants) +} - let inv = diff_in_field.map(|diff| { - match diff.inverse() { - Some(inv) => inv, - None => ConstraintF::one(), // in this case the value of inv does not matter for the constraint +impl EqGadget for Vec { + fn is_eq>( + &self, + mut cs: CS, + other: &Self, + ) -> Result { + assert_eq!(self.len(), other.len()); + //let len = self.len(); + let field_bits = ConstraintF::Params::CAPACITY as usize; + // Since `self` and `other` may not be packed in a single field element, + // we process them in chunks of size field_bits and then employ + // `self` == `other` iff each pair of chunks are equal + let mut chunk_eq_gadgets = Vec::new(); + for (i, (self_chunk, other_chunk)) in self.chunks(field_bits).zip(other.chunks(field_bits)).enumerate() { + + let (diff_lc, diff_in_field, all_constants) = compute_diff(self_chunk, cs.ns(|| format!("compute diff for chunk {}", i)), other_chunk); + + if all_constants && diff_in_field.is_some() { + return Ok(Boolean::constant(diff_in_field.unwrap().is_zero())); } - }); - let inv_var = FpGadget::::alloc(cs.ns(|| "alloc inv"), || {inv.ok_or(SynthesisError::AssignmentMissing)})?; + let is_chunk_eq = Boolean::alloc(cs.ns(|| format!("alloc is_eq flag for chunk {}", i)), || { + let diff = diff_in_field.ok_or(SynthesisError::AssignmentMissing)?; + Ok(diff.is_zero()) + })?; + + let inv = diff_in_field.map(|diff| { + match diff.inverse() { + Some(inv) => inv, + None => ConstraintF::one(), // in this case the value of inv does not matter for the constraint + } + }); - // enforce constraints: - // is_eq * diff_lc = 0 enforces that is_eq == 0 when diff_lc != 0, i.e., when self != other - // inv * diff_lc = 1 - is_eq enforces that is_eq == 1 when diff_lc == 0, i.e., when self == other - cs.enforce(|| "enforce is not eq", |_| is_eq.lc(CS::one(), ConstraintF::one()), |lc| lc + &diff_lc, |lc| lc); - cs.enforce(|| "enforce is eq", |lc| &inv_var.get_variable() + lc, |lc| lc + &diff_lc, |_| is_eq.not().lc(CS::one(), ConstraintF::one())); + let inv_var = FpGadget::::alloc(cs.ns(|| format!("alloc inv for chunk {}", i)), || {inv.ok_or(SynthesisError::AssignmentMissing)})?; - Ok(is_eq) + // enforce constraints: + // is_eq * diff_lc = 0 enforces that is_eq == 0 when diff_lc != 0, i.e., when self != other + // inv * diff_lc = 1 - is_eq enforces that is_eq == 1 when diff_lc == 0, i.e., when self == other + cs.enforce(|| format!("enforce is not eq for chunk {}", i), |_| is_chunk_eq.lc(CS::one(), ConstraintF::one()), |lc| lc + &diff_lc, |lc| lc); + cs.enforce(|| format!("enforce is eq for chunk {}", i), |lc| &inv_var.get_variable() + lc, |lc| lc + &diff_lc, |_| is_chunk_eq.not().lc(CS::one(), ConstraintF::one())); + + // let is_eq = BooleanVec(self_chunk).is_eq(cs.ns(|| format!("equality for chunk {}", i)), &BooleanVec(other_chunk))?; + chunk_eq_gadgets.push(is_chunk_eq); + } + + if chunk_eq_gadgets.len() == 0 { + return Ok(Boolean::constant(true)) + } + + Boolean::kary_and(cs.ns(|| "is eq"), chunk_eq_gadgets.as_slice()) } - // conditional_enforce_equal enforces that `self` == `other` if `should_enforce` is true, - // enforce nothing otherwise. This function requires that `self` and `other` are bit sequences - // with length at most the capacity of the field ConstraintF. - fn conditional_enforce_equal(&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> - where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract - { - let (diff_lc, diff_in_field, all_constants) = self.compute_diff(&mut cs, other); + fn conditional_enforce_equal> + (&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> { + assert_eq!(self.len(), other.len()); + // split `self` and `other` in chunks of size field_bits and enforce equality between each + // pair of chunks + let field_bits = ConstraintF::Params::CAPACITY as usize; + for (i, (self_chunk, other_chunk)) in self.chunks(field_bits).zip(other.chunks(field_bits)).enumerate() { - if all_constants && diff_in_field.is_some() && should_enforce.is_constant() { - if should_enforce.get_value().unwrap() && !diff_in_field.unwrap().is_zero() { - return Err(SynthesisError::Unsatisfiable) + let (diff_lc, diff_in_field, all_constants) = compute_diff(self_chunk, cs.ns(|| format!("compute diff for chunk {}", i)), other_chunk); + + if all_constants && diff_in_field.is_some() && should_enforce.is_constant() { + if should_enforce.get_value().unwrap() && !diff_in_field.unwrap().is_zero() { + return Err(SynthesisError::Unsatisfiable) + } + return Ok(()) } - return Ok(()) - } - // enforce that diff_lc*should_enforce = 0, which enforces that diff_lc = 0 if should_enforce=1, while it enforces nothing if should_enforce=0 - cs.enforce(|| "conditionally enforce equal", |lc| lc + &diff_lc, |_| should_enforce.lc(CS::one(), ConstraintF::one()), |lc| lc); + // enforce that diff_lc*should_enforce = 0, which enforces that diff_lc = 0 if should_enforce=1, while it enforces nothing if should_enforce=0 + cs.enforce(|| format!("conditionally enforce equal for chunk {}", i), |lc| lc + &diff_lc, |_| should_enforce.lc(CS::one(), ConstraintF::one()), |lc| lc); + } Ok(()) } - // conditional_enforce_not_equal enforces that `self` != `other` if `should_enforce` is true, - // enforce nothing otherwise. This function requires that `self` and `other` are bit sequences - // with length at most the capacity of the field ConstraintF. - fn conditional_enforce_not_equal(&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> - where - ConstraintF: PrimeField, - CS: ConstraintSystemAbstract - { - let (diff_lc, diff_in_field, all_constants) = self.compute_diff(&mut cs, other); + fn conditional_enforce_not_equal> + (&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> { + assert_eq!(self.len(), other.len()); + let field_bits = ConstraintF::Params::CAPACITY as usize; + let len = self.len(); + if field_bits < len { + // In this case we cannot split in chunks here, as it's not true that if two bit vectors + // are not equal, then they are not equal chunkwise too. Therefore, we + // compute a Boolean which is true iff `self != `other` and we conditionally + // enforce it to be true + let is_neq = self.is_neq(cs.ns(|| "is not equal"), other)?; + return is_neq.conditional_enforce_equal(cs, &Boolean::constant(true), should_enforce) + } + // instead, if `self` and `other` can be packed in a single field element, we can + // conditionally enforce their inequality, which is more efficient that calling is_neq + let (diff_lc, diff_in_field, all_constants) = compute_diff(self.as_slice(), &mut cs, other.as_slice()); if all_constants && diff_in_field.is_some() && should_enforce.is_constant() { if should_enforce.get_value().unwrap() && diff_in_field.unwrap().is_zero() { - return Err(SynthesisError::Unsatisfiable); + return Err(SynthesisError::Unsatisfiable); } return Ok(()) } @@ -287,63 +282,6 @@ impl BooleanVec<'_> { Ok(()) } -} - -impl EqGadget for Vec { - fn is_eq>( - &self, - mut cs: CS, - other: &Self, - ) -> Result { - assert_eq!(self.len(), other.len()); - let len = self.len(); - let field_bits = ConstraintF::Params::CAPACITY as usize; - if field_bits < len { - // if `self` and `other` cannot be packed in a single field element, - // then we split them in chunks of size field_bits and then leverage - // `self` == `other` iff each pair of chunks are equal - let mut chunk_eq_gadgets = Vec::new(); - for (i, (self_chunk, other_chunk)) in self.chunks(field_bits).zip(other.chunks(field_bits)).enumerate() { - let is_eq = BooleanVec(self_chunk).is_eq(cs.ns(|| format!("equality for chunk {}", i)), &BooleanVec(other_chunk))?; - chunk_eq_gadgets.push(is_eq); - } - return Boolean::kary_and(cs.ns(|| "is eq"), chunk_eq_gadgets.as_slice()) - } - - BooleanVec(self).is_eq(cs, &BooleanVec(other)) - } - - fn conditional_enforce_equal> - (&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> { - assert_eq!(self.len(), other.len()); - // split `self` and `other` in chunks of size field_bits and enforce equality between each - // pair of chunks - let field_bits = ConstraintF::Params::CAPACITY as usize; - for (i, (self_chunk, other_chunk)) in self.chunks(field_bits).zip(other.chunks(field_bits)).enumerate() { - BooleanVec(self_chunk).conditional_enforce_equal(cs.ns(|| format!("enforce equal for chunk {}", i)), &BooleanVec(other_chunk), should_enforce)?; - } - - Ok(()) - } - - fn conditional_enforce_not_equal> - (&self, mut cs: CS, other: &Self, should_enforce: &Boolean) -> Result<(), SynthesisError> { - assert_eq!(self.len(), other.len()); - let field_bits = ConstraintF::Params::CAPACITY as usize; - let len = self.len(); - if field_bits < len { - // in this case, it is not useful to split `self` and `other` in chunks, - // as `self` != `other` iff at least one pair of chunks are different, but we do not - // know on which pair we should enforce inequality. Therefore, we - // compute a Boolean which is true iff `self != `other` and we conditionally - // enforce it to be true - let is_neq = self.is_neq(cs.ns(|| "is not equal"), other)?; - return is_neq.conditional_enforce_equal(cs, &Boolean::constant(true), should_enforce) - } - // instead, if `self` and `other` can be packed in a single field element, we can - // conditionally enforce their inequality, which is more efficient that calling is_neq - BooleanVec(self).conditional_enforce_not_equal(cs, &BooleanVec(other), should_enforce) - } } diff --git a/r1cs/gadgets/std/src/fields/cmp.rs b/r1cs/gadgets/std/src/fields/cmp.rs index 1c5daf38d..9c4bcdb30 100644 --- a/r1cs/gadgets/std/src/fields/cmp.rs +++ b/r1cs/gadgets/std/src/fields/cmp.rs @@ -5,6 +5,60 @@ use crate::{boolean::Boolean, bits::{ToBitsGadget, FromBitsGadget}, eq::EqGadget use crate::cmp::ComparisonGadget; use crate::fields::{fp::FpGadget, FieldGadget}; +// this macro allows to implement the `unchecked` and `restricted` variants of the `enforce_cmp`, +// `conditional_enforce_cmp` and `is_cmp` functions. The macro is useful as the implementations +// are the same except for the call to the correspondent `is_smaller_than_restricted` or +// `is_smaller_than_unchecked` function. +macro_rules! implement_cmp_functions_variants { + ($variant: tt) => { + paste::item! { + pub fn []>( + &self, + mut cs: CS, + other: &Self, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result<(), SynthesisError> { + self.[](&mut cs, other, &Boolean::constant(true), ordering, should_also_check_equality) + } + + pub fn []>( + &self, + mut cs: CS, + other: &Self, + should_enforce: &Boolean, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result<(), SynthesisError> { + let is_cmp = self.[](cs.ns(|| "is cmp"), other, ordering, should_also_check_equality)?; + is_cmp.conditional_enforce_equal(cs.ns(|| "conditional enforce cmp"), &Boolean::constant(true), should_enforce) + } + + pub fn []>( + &self, + mut cs: CS, + other: &Self, + ordering: Ordering, + should_also_check_equality: bool, + ) -> Result { + let (left, right) = match (ordering, should_also_check_equality) { + (Ordering::Less, false) | (Ordering::Greater, true) => (self, other), + (Ordering::Greater, false) | (Ordering::Less, true) => (other, self), + (Ordering::Equal, _) => return self.is_eq(cs.ns(|| "is equal is is_cmp"), other), + }; + + let is_smaller = left.[](cs, right)?; + + if should_also_check_equality { + return Ok(is_smaller.not()); + } + + Ok(is_smaller) + } + } + }; + } + // implement functions for FpGadget that are useful to implement the ComparisonGadget impl FpGadget { @@ -16,13 +70,34 @@ impl FpGadget { assert_eq!(first.len(), 2); assert_eq!(second.len(), 2); - let a = first[0]; - let b = first[1]; - let c = second[0]; - let d = second[1]; + let a = first[0]; // a = msb(first) + let b = first[1]; // b = lsb(first) + let c = second[0]; // c = msb(second) + let d = second[1]; // d = lsb(second) + + // is_less corresponds to the Boolean function: !a*(c+!b*d)+(!b*c*d) + // which is true iff first < second, where + is Boolean OR and * is Boolean AND. Indeed: + // | first | second | a | b | c | d | is_less | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 1 | 0 | 0 | 0 | 1 | 1 | + // | 0 | 2 | 0 | 0 | 1 | 0 | 1 | + // | 0 | 3 | 0 | 0 | 1 | 1 | 1 | + // | 1 | 0 | 0 | 1 | 0 | 0 | 0 | + // | 1 | 1 | 0 | 1 | 0 | 1 | 0 | + // | 1 | 2 | 0 | 1 | 1 | 0 | 1 | + // | 1 | 3 | 0 | 1 | 1 | 1 | 1 | + // | 2 | 0 | 1 | 0 | 0 | 0 | 0 | + // | 2 | 1 | 1 | 0 | 0 | 1 | 0 | + // | 2 | 2 | 1 | 0 | 1 | 0 | 0 | + // | 2 | 3 | 1 | 0 | 1 | 1 | 1 | + // | 3 | 0 | 1 | 1 | 0 | 0 | 0 | + // | 3 | 1 | 1 | 1 | 0 | 1 | 0 | + // | 3 | 2 | 1 | 1 | 1 | 0 | 0 | + // | 3 | 3 | 1 | 1 | 1 | 1 | 0 | + + // To reduce the number of constraints, the Boolean function is computed as follows: + // is_less = !a + !b*d if c=1, !a*!b*d if c=0 - // is_less corresponds to the Boolean function: !a*(c+!b*d)+(!b*c*d), - // which is true iff first < second, where + is Boolean OR and * is Boolean AND let bd = Boolean::and(cs.ns(|| "!bd"), &b.not(), &d)?; let first_tmp = Boolean::or(cs.ns(|| "!a + !bd"), &a.not(), &bd)?; let second_tmp = Boolean::and(cs.ns(|| "!a!bd"), &a.not(), &bd)?; @@ -37,42 +112,6 @@ impl FpGadget { Ok((is_less, is_eq)) } - /// Output a Boolean that is true iff `self` < `other`. Here `self` and `other` - /// can be arbitrary field elements, they are not constrained to be at most (p-1)/2 - pub fn is_smaller_than_unrestricted>( - &self, - mut cs: CS, - other: &Self, - ) -> Result { - let self_bits = self.to_bits_strict(cs.ns(|| "first op to bits"))?; - let other_bits = other.to_bits_strict(cs.ns(|| "second op to bits"))?; - // extract the least significant MODULUS_BITS-2 bits and convert them to a field element, - // which is necessarily lower than (p-1)/2 - let fp_for_self_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op MSBs"), &self_bits[2..])?; - let fp_for_other_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op LSBs"), &other_bits[2..])?; - - // since the field elements are lower than (p-1)/2, we can compare it with the efficient approach - let is_less_lsbs = fp_for_self_lsbs.is_smaller_than_unchecked(cs.ns(|| "compare LSBs"), &fp_for_other_lsbs)?; - - - // obtain two Booleans: the former (resp. the latter) one is true iff the integer - // represented by the 2 MSBs of self is smaller (resp. is equal) than the integer - // represented by the 2 MSBs of other - let (is_less_msbs, is_eq_msbs) = Self::compare_msbs(cs.ns(|| "compare MSBs"), &self_bits[..2], &other_bits[..2])?; - - // Equivalent to is_less_msbs OR is_eq_msbs AND is_less_msbs, given that is_less_msbs and - // is_eq_msbs cannot be true at the same time - Boolean::conditionally_select(cs, &is_eq_msbs, &is_less_lsbs, &is_less_msbs) - } - - /// Enforce than `self` < `other`. Here `self` and `other` they are arbitrary field elements, - /// they are not constrained to be at most (p-1)/2 - pub fn enforce_smaller_than_unrestricted>(&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { - let is_smaller = self.is_smaller_than_unrestricted(cs.ns(|| "is smaller unchecked"), other)?; - is_smaller.enforce_equal(cs.ns(|| "enforce smaller than"), &Boolean::constant(true)) - } - - /// Helper function to enforce that `self <= (p-1)/2`. pub fn enforce_smaller_or_equal_than_mod_minus_one_div_two>( &self, @@ -90,6 +129,25 @@ impl FpGadget { Ok(()) } + /// Helper function to check `self < other` and output a result bit. This + /// function requires that `self` and `other` are `<= (p-1)/2` and imposes + /// constraints to verify that. + pub fn is_smaller_than_restricted>(&self, mut cs: CS, other: &Self) -> Result { + self.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "self smaller or equal mod"))?; + other.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "other smaller or equal mod"))?; + self.is_smaller_than_unchecked(cs.ns(|| "is smaller unchecked"), other) + } + + + /// Helper function to enforce that `self < other`. This + /// function requires that `self` and `other` are `<= (p-1)/2` and imposes + /// constraints to verify that. + pub fn enforce_smaller_than_restricted>(&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { + self.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "self smaller or equal mod"))?; + other.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "other smaller or equal mod"))?; + self.enforce_smaller_than_unchecked(cs.ns(|| "enforce smaller than unchecked"), other) + } + /// Helper function to check `self < other` and output a result bit. This /// function assumes `self` and `other` are `<= (p-1)/2` and does not /// generate constraints to verify that. @@ -128,71 +186,47 @@ impl FpGadget { is_smaller.enforce_equal(cs.ns(|| "enforce smaller than"), &Boolean::constant(true)) } - /// Variant of `enforce_cmp` that assumes `self` and `other` are `<= (p-1)/2` and - /// does not generate constraints to verify that. - pub fn enforce_cmp_unchecked>( - &self, - mut cs: CS, - other: &Self, - ordering: Ordering, - should_also_check_equality: bool, - ) -> Result<(), SynthesisError> { - self.conditional_enforce_cmp_unchecked(&mut cs, other, &Boolean::constant(true), ordering, should_also_check_equality) - } + // Variants of cmp functions that assume `self` and `other` are `<= (p-1)/2` and + // do not generate constraints to verify that. + implement_cmp_functions_variants!(unchecked); + // Variants of cmp functions that require `self` and `other` are `<= (p-1)/2` and + // impose constraints to verify that. + implement_cmp_functions_variants!(restricted); - /// Variant of `conditional_enforce_cmp` that assumes `self` and `other` are `<= (p-1)/2` and - /// does not generate constraints to verify that. - pub fn conditional_enforce_cmp_unchecked>( - &self, - mut cs: CS, - other: &Self, - should_enforce: &Boolean, - ordering: Ordering, - should_also_check_equality: bool, - ) -> Result<(), SynthesisError> { - let is_cmp = self.is_cmp_unchecked(cs.ns(|| "is cmp unchecked"), other, ordering, should_also_check_equality)?; - is_cmp.conditional_enforce_equal(cs.ns(|| "conditionally enforce cmp"), &Boolean::constant(true), should_enforce) - } +} - /// Variant of `is_cmp` that assumes `self` and `other` are `<= (p-1)/2` and does not generate - /// constraints to verify that. - // It differs from the default implementation of `is_cmp` only by - // calling `is_smaller_than_unchecked` in place of `is_smaller_than` for efficiency given that - // there is no need to verify that `self` and `other` are `<= (p-1)/2` - fn is_cmp_unchecked>( +impl ComparisonGadget for FpGadget { + /// Output a Boolean that is true iff `self` < `other`. Here `self` and `other` + /// can be arbitrary field elements, they are not constrained to be at most (p-1)/2 + fn is_smaller_than>( &self, mut cs: CS, other: &Self, - ordering: Ordering, - should_also_check_equality: bool, ) -> Result { - let (left, right) = match (ordering, should_also_check_equality) { - (Ordering::Less, false) | (Ordering::Greater, true) => (self, other), - (Ordering::Greater, false) | (Ordering::Less, true) => (other, self), - (Ordering::Equal, _) => return self.is_eq(cs, other), - }; - - let is_smaller = left.is_smaller_than_unchecked(cs.ns(|| "is smaller"), right)?; + let self_bits = self.to_bits_strict(cs.ns(|| "first op to bits"))?; + let other_bits = other.to_bits_strict(cs.ns(|| "second op to bits"))?; + // extract the least significant MODULUS_BITS-2 bits and convert them to a field element, + // which is necessarily lower than (p-1)/2 + let fp_for_self_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op MSBs"), &self_bits[2..])?; + let fp_for_other_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op LSBs"), &other_bits[2..])?; - if should_also_check_equality { - return Ok(is_smaller.not()); - } + // since the field elements are lower than (p-1)/2, we can compare it with the efficient approach + let is_less_lsbs = fp_for_self_lsbs.is_smaller_than_unchecked(cs.ns(|| "compare LSBs"), &fp_for_other_lsbs)?; - Ok(is_smaller) - } -} -impl ComparisonGadget for FpGadget { - fn is_smaller_than>(&self, mut cs: CS, other: &Self) -> Result { - self.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "self smaller or equal mod"))?; - other.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "other smaller or equal mod"))?; - self.is_smaller_than_unchecked(cs.ns(|| "is smaller unchecked"), other) - } + // obtain two Booleans: + // - `is_less_msbs` is true iff the integer represented by the 2 MSBs of self is smaller + // than the integer represented by the 2 MSBs of other + // - `is_eq_msbs` is true iff the integer represented by the 2 MSBs of self is equal + // to the integer represented by the 2 MSBs of other + let (is_less_msbs, is_eq_msbs) = Self::compare_msbs(cs.ns(|| "compare MSBs"), &self_bits[..2], &other_bits[..2])?; - fn enforce_smaller_than>(&self, mut cs: CS, other: &Self) -> Result<(), SynthesisError> { - self.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "self smaller or equal mod"))?; - other.enforce_smaller_or_equal_than_mod_minus_one_div_two(cs.ns(|| "other smaller or equal mod"))?; - self.enforce_smaller_than_unchecked(cs.ns(|| "enforce smaller than unchecked"), other) + // `self < other` iff `is_less_msbs OR is_eq_msbs AND is_less_lsbs` + // Given that `is_less_msbs` and `is_eq_msbs` cannot be true at the same time, + // the formula is equivalent to the following conditionally select; indeed: + // - if `is_eq_msbs = true`, then `is_less_msbs = false`, thus `self < other` iff `is_less_lsbs = true` + // - if `is_eq_msbs = false`, then `self < other` iff `is_less_msbs = true` + Boolean::conditionally_select(cs, &is_eq_msbs, &is_less_lsbs, &is_less_msbs) } } @@ -218,8 +252,24 @@ mod test { r } + fn rand_higher(rng: &mut R) -> Fr { + let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); + let mut r; + loop { + r = Fr::rand(rng); + if r > pminusonedivtwo { + break; + } + } + r + } + + fn field_uniform_rand(rng: &mut R) -> Fr { + Fr::rand(rng) + } + macro_rules! test_cmp_function { - ($cmp_func: tt, $should_enforce: expr, $should_fail_with_invalid_operands: expr) => { + ($cmp_func: tt, $should_enforce: expr) => { let mut rng = &mut thread_rng(); let should_enforce = Boolean::constant($should_enforce); for _i in 0..10 { @@ -258,15 +308,37 @@ mod test { match b.cmp(&a) { Ordering::Less => { a_var.$cmp_func(cs.ns(|| "enforce less"), &b_var, &should_enforce, Ordering::Less, false).unwrap(); + assert!($should_enforce ^ cs.is_satisfied()); // check that constraints are satisfied iff should_enforce == false + if $should_enforce { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "enforce less/conditional enforce cmp/conditional_equals"); + } + // reinitialize cs to check for equality + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); a_var.$cmp_func(cs.ns(|| "enforce less equal"),&b_var, &should_enforce, Ordering::Less, true).unwrap(); + assert!($should_enforce ^ cs.is_satisfied()); // check that constraints are satisfied iff should_enforce == false + if $should_enforce { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "enforce less equal/conditional enforce cmp/conditional_equals"); + } } Ordering::Greater => { a_var.$cmp_func(cs.ns(|| "enforce greater"),&b_var, &should_enforce, Ordering::Greater, false).unwrap(); + if $should_enforce { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "enforce greater/conditional enforce cmp/conditional_equals"); + } + // reinitialize cs to check for equality + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); + let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); a_var.$cmp_func(cs.ns(|| "enforce greater equal"),&b_var, &should_enforce, Ordering::Greater, true).unwrap(); + if $should_enforce { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "enforce greater equal/conditional enforce cmp/conditional_equals"); + } + } _ => {} } - assert!($should_enforce ^ cs.is_satisfied()); // check that constraints are satisfied iff should_enforce == false } for _i in 0..10 { @@ -276,6 +348,9 @@ mod test { a_var.$cmp_func(cs.ns(|| "enforce less"),&a_var, &should_enforce, Ordering::Less, false).unwrap(); assert!($should_enforce ^ cs.is_satisfied()); + if $should_enforce { + assert_eq!(cs.which_is_unsatisfied().unwrap(), "enforce less/conditional enforce cmp/conditional_equals"); + } } for _i in 0..10 { @@ -288,45 +363,78 @@ mod test { } assert!(cs.is_satisfied()); } + } + } + + fn test_corner_cases_cmp(should_enforce_flag: bool) { + // test corner case where the operands are extreme values of range [0, p-1] of + // admissible values + let should_enforce = Boolean::constant(should_enforce_flag); + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let max_val: Fr = Fr::modulus_minus_one_div_two().into(); + let max_val = max_val.double(); + let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); + let zero_var = FpGadget::::from_value(cs.ns(|| "alloc zero"), &Fr::zero()); + max_var.conditional_enforce_cmp(cs.ns(|| "enforce p-1 > 0"), &zero_var, &should_enforce, Ordering::Greater, false).unwrap(); + assert!(cs.is_satisfied()); + } + macro_rules! test_corner_case_restricted_cmp { + ($conditional_enforce_cmp_func: tt, $should_enforce_flag: expr, $should_fail_with_invalid_operands: expr, $unsatisfied_constraint: expr) => { // test corner case when operands are extreme values of range [0, (p-1)/2] of // admissible values + let should_enforce = Boolean::constant($should_enforce_flag); let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); let max_val: Fr = Fr::modulus_minus_one_div_two().into(); let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); let zero_var = FpGadget::::zero(cs.ns(|| "alloc zero")).unwrap(); - zero_var.$cmp_func(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var, &should_enforce, Ordering::Less, true).unwrap(); + zero_var.$conditional_enforce_cmp_func(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var, &should_enforce, Ordering::Less, true).unwrap(); assert!(cs.is_satisfied()); // test when one of the operands is beyond (p-1)/2 let out_range_var = FpGadget::::alloc(&mut cs.ns(|| "generate_out_range"), || Ok(max_val.double())).unwrap(); - zero_var.$cmp_func(cs.ns(|| "enforce 0 <= p-1"), &out_range_var, &should_enforce, Ordering::Less, true).unwrap(); + zero_var.$conditional_enforce_cmp_func(cs.ns(|| "enforce 0 <= p-1"), &out_range_var, &should_enforce, Ordering::Less, true).unwrap(); assert!($should_fail_with_invalid_operands ^ cs.is_satisfied()); + if $should_fail_with_invalid_operands { + assert_eq!(cs.which_is_unsatisfied().unwrap(), $unsatisfied_constraint); + } } } #[test] fn test_cmp() { - test_cmp_function!(conditional_enforce_cmp, true, true); - test_cmp_function!(conditional_enforce_cmp, false, true); + test_cmp_function!(conditional_enforce_cmp, true); + test_corner_cases_cmp(true); + test_cmp_function!(conditional_enforce_cmp, false); + test_corner_cases_cmp(false); } #[test] fn test_cmp_unchecked() { - test_cmp_function!(conditional_enforce_cmp_unchecked, true, true); - test_cmp_function!(conditional_enforce_cmp_unchecked, false, false); + test_cmp_function!(conditional_enforce_cmp_unchecked, true); + test_corner_case_restricted_cmp!(conditional_enforce_cmp_unchecked, true, true, "enforce 0 <= p-1/conditional enforce cmp/conditional_equals"); + test_cmp_function!(conditional_enforce_cmp_unchecked, false); + test_corner_case_restricted_cmp!(conditional_enforce_cmp_unchecked, false, false, "enforce 0 <= p-1/conditional enforce cmp/conditional_equals"); + } + + #[test] + fn test_cmp_restricted() { + test_cmp_function!(conditional_enforce_cmp_restricted, true); + test_corner_case_restricted_cmp!(conditional_enforce_cmp_restricted, true, true, "enforce 0 <= p-1/is cmp/self smaller or equal mod/enforce smaller or equal/enforce equal/conditional_equals"); + test_cmp_function!(conditional_enforce_cmp_restricted, false); + test_corner_case_restricted_cmp!(conditional_enforce_cmp_restricted, false, true, "enforce 0 <= p-1/is cmp/self smaller or equal mod/enforce smaller or equal/enforce equal/conditional_equals"); } macro_rules! test_smaller_than_func { - ($is_smaller_func: tt, $enforce_smaller_func: tt) => { + ($is_smaller_func: tt, $enforce_smaller_func: tt, $rand_func: tt, $unsatisfied_constraint: expr) => { let mut rng = &mut thread_rng(); for _ in 0..10 { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let a = rand_in_range(&mut rng); + let a = $rand_func(&mut rng); let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - let b = rand_in_range(&mut rng); + let b = $rand_func(&mut rng); let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); let is_smaller = a_var.$is_smaller_func(cs.ns(|| "is smaller"), &b_var).unwrap(); @@ -340,22 +448,54 @@ mod test { } Ordering::Greater | Ordering::Equal => { assert!(!is_smaller.get_value().unwrap()); - assert!(!cs.is_satisfied()) + assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), $unsatisfied_constraint); } } } for _ in 0..10 { let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let a = rand_in_range(&mut rng); + let a = $rand_func(&mut rng); let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); let is_smaller = a_var.$is_smaller_func(cs.ns(|| "is smaller"),&a_var).unwrap(); // check that a.is_smaller(a) == false assert!(!is_smaller.get_value().unwrap()); - a_var.$enforce_smaller_func(cs.ns(|| "enforce is smaller"), &a_var).unwrap(); + a_var.$enforce_smaller_func(cs.ns(|| "enforce smaller"), &a_var).unwrap(); assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), $unsatisfied_constraint); } + } + } + + fn test_corner_cases_smaller_than() { + // test corner case where the operands are extreme values of range [0, p-1] of + // admissible values + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let max_val: Fr = Fr::modulus_minus_one_div_two().into(); + let max_val = max_val.double(); + let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); + let zero_var = FpGadget::::from_value(cs.ns(|| "alloc zero"), &Fr::zero()); + let is_smaller = zero_var.is_smaller_than(cs.ns(|| "0 is smaller than p-1"), &max_var).unwrap(); + assert!(is_smaller.get_value().unwrap()); + zero_var.enforce_smaller_than(cs.ns(|| "enforce 0 < p-1"), &max_var).unwrap(); + assert!(cs.is_satisfied()); + } + + #[test] + fn test_smaller_than() { + // test with random field elements >(p-1)/2 + test_smaller_than_func!(is_smaller_than, enforce_smaller_than, rand_higher, "enforce smaller/enforce is smaller/conditional_equals"); + // test with random field elements <=(p-1)/2 + test_smaller_than_func!(is_smaller_than, enforce_smaller_than, rand_in_range, "enforce smaller/enforce is smaller/conditional_equals"); + // test with arbitrary field elements + test_smaller_than_func!(is_smaller_than, enforce_smaller_than, field_uniform_rand, "enforce smaller/enforce is smaller/conditional_equals"); + // test corner case + test_corner_cases_smaller_than(); + } + macro_rules! test_corner_case_smaller_than_restricted { + ($is_smaller_func: tt, $enforce_smaller_func: tt, $unsatisfied_constraint: expr) => { // test corner case when operands are extreme values of range [0, (p-1)/2] of // admissible values let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); @@ -371,92 +511,19 @@ mod test { let out_range_var = FpGadget::::alloc(&mut cs.ns(|| "generate_out_range"), || Ok(max_val.double())).unwrap(); zero_var.$enforce_smaller_func(cs.ns(|| "enforce 0 <= p-1"), &out_range_var).unwrap(); assert!(!cs.is_satisfied()); + assert_eq!(cs.which_is_unsatisfied().unwrap(), $unsatisfied_constraint); } } #[test] - fn test_smaller_than() { - test_smaller_than_func!(is_smaller_than, enforce_smaller_than); + fn test_smaller_than_restricted() { + test_smaller_than_func!(is_smaller_than_restricted, enforce_smaller_than_restricted, rand_in_range, "enforce smaller/enforce smaller than unchecked/enforce smaller than/conditional_equals"); + test_corner_case_smaller_than_restricted!(is_smaller_than_restricted, enforce_smaller_than_restricted, "enforce 0 <= p-1/other smaller or equal mod/enforce smaller or equal/enforce equal/conditional_equals"); } #[test] fn test_smaller_than_unchecked() { - test_smaller_than_func!(is_smaller_than_unchecked, enforce_smaller_than_unchecked); - } - - macro_rules! test_smaller_than_unrestricted { - ($rand_func: tt) => { - let mut rng = &mut thread_rng(); - - for _ in 0..10 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let a = $rand_func(&mut rng); - let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - let b = $rand_func(&mut rng); - let b_var = FpGadget::::alloc(&mut cs.ns(|| "generate_b"), || Ok(b)).unwrap(); - let is_smaller = a_var.is_smaller_than_unrestricted(cs.ns(|| "is smaller"), &b_var).unwrap(); - a_var.enforce_smaller_than_unrestricted(cs.ns(|| "enforce is smaller"), &b_var).unwrap(); - - match a.cmp(&b) { - Ordering::Less => { - assert!(is_smaller.get_value().unwrap()); - assert!(cs.is_satisfied()); - } - Ordering::Greater | Ordering::Equal => { - assert!(!is_smaller.get_value().unwrap()); - assert!(!cs.is_satisfied()) - } - } - } - - for _ in 0..10 { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let a = $rand_func(&mut rng); - let a_var = FpGadget::::alloc(&mut cs.ns(|| "generate_a"), || Ok(a)).unwrap(); - let is_smaller = a_var.is_smaller_than_unrestricted(cs.ns(|| "is smaller"),&a_var).unwrap(); - // check that a.is_smaller(a) == false - assert!(!is_smaller.get_value().unwrap()); - a_var.enforce_smaller_than_unrestricted(cs.ns(|| "enforce is smaller"), &a_var).unwrap(); - assert!(!cs.is_satisfied()); - } - - // test corner case where the operands are extreme values of range [0, p-1] of - // admissible values - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let max_val: Fr = Fr::modulus_minus_one_div_two().into(); - let max_val = max_val.double(); - let max_var = FpGadget::::alloc(&mut cs.ns(|| "generate_max"), || Ok(max_val)).unwrap(); - let zero_var = FpGadget::::from_value(cs.ns(|| "alloc zero"), &Fr::zero()); - let is_smaller = zero_var.is_smaller_than_unrestricted(cs.ns(|| "0 is smaller than p-1"), &max_var).unwrap(); - assert!(is_smaller.get_value().unwrap()); - zero_var.enforce_smaller_than_unrestricted(cs.ns(|| "enforce 0 <= (p-1) div 2"), &max_var).unwrap(); - assert!(cs.is_satisfied()); - } - } - - #[test] - fn test_smaller_than_unrestricted() { - fn rand_higher(rng: &mut R) -> Fr { - let pminusonedivtwo: Fr = Fr::modulus_minus_one_div_two().into(); - let mut r; - loop { - r = Fr::rand(rng); - if r > pminusonedivtwo { - break; - } - } - r - } - - fn field_uniform_rand(rng: &mut R) -> Fr { - Fr::rand(rng) - } - // test with random field elements >(p-1)/2 - test_smaller_than_unrestricted!(rand_higher); - // test with random field elements <=(p-1)/2 - test_smaller_than_unrestricted!(rand_in_range); - // test with arbitrary field elements - test_smaller_than_unrestricted!(field_uniform_rand); + test_smaller_than_func!(is_smaller_than_unchecked, enforce_smaller_than_unchecked, rand_in_range, "enforce smaller/enforce smaller than/conditional_equals"); + test_corner_case_smaller_than_restricted!(is_smaller_than_unchecked, enforce_smaller_than_unchecked, "enforce 0 <= p-1/enforce smaller than/conditional_equals"); } } \ No newline at end of file diff --git a/r1cs/gadgets/std/src/lib.rs b/r1cs/gadgets/std/src/lib.rs index e35c5645b..db85f4613 100644 --- a/r1cs/gadgets/std/src/lib.rs +++ b/r1cs/gadgets/std/src/lib.rs @@ -74,7 +74,7 @@ pub mod prelude { alloc::*, bits::{ boolean::Boolean, uint32::UInt32, UInt8, FromBitsGadget, ToBitsGadget, - ToBytesGadget, + ToBytesGadget, UIntGadget, }, eq::*, fields::{cubic_extension::*, quadratic_extension::*, FieldGadget}, From d60b85b58ab0534356f11022a6c1311c48cd8a8a Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Thu, 10 Mar 2022 18:17:50 +0100 Subject: [PATCH 12/18] merge to little-endian bits functions --- .../crypto/src/commitment/blake2s/mod.rs | 4 +- .../crypto/src/commitment/pedersen/mod.rs | 7 +- .../crypto/src/crh/bowe_hopwood/mod.rs | 6 +- r1cs/gadgets/crypto/src/crh/pedersen/mod.rs | 7 +- r1cs/gadgets/crypto/src/prf/blake2s/mod.rs | 10 +- r1cs/gadgets/crypto/src/prf/ripemd160.rs | 12 +- r1cs/gadgets/crypto/src/prf/sha256.rs | 27 +- .../crypto/src/signature/schnorr/mod.rs | 5 +- r1cs/gadgets/std/src/bits/macros.rs | 320 ++++++++++++------ r1cs/gadgets/std/src/bits/mod.rs | 8 +- r1cs/gadgets/std/src/fields/fp.rs | 12 +- r1cs/gadgets/std/src/lib.rs | 2 +- 12 files changed, 250 insertions(+), 170 deletions(-) diff --git a/r1cs/gadgets/crypto/src/commitment/blake2s/mod.rs b/r1cs/gadgets/crypto/src/commitment/blake2s/mod.rs index 99b7f2dd0..0f39cfb55 100644 --- a/r1cs/gadgets/crypto/src/commitment/blake2s/mod.rs +++ b/r1cs/gadgets/crypto/src/commitment/blake2s/mod.rs @@ -31,8 +31,8 @@ impl CommitmentGadget r: &Self::RandomnessGadget, ) -> Result { let mut input_bits = Vec::with_capacity(512); - for byte in input.iter().chain(r.0.iter()) { - input_bits.extend_from_slice(&byte.into_bits_le()); + for (i, byte) in input.iter().chain(r.0.iter()).enumerate() { + input_bits.extend_from_slice(&byte.to_bits_le(cs.ns(|| format!("convert byte {} to bits", i)))?); } let mut result = Vec::new(); for (i, int) in blake2s_gadget(cs.ns(|| "Blake2s Eval"), &input_bits)? diff --git a/r1cs/gadgets/crypto/src/commitment/pedersen/mod.rs b/r1cs/gadgets/crypto/src/commitment/pedersen/mod.rs index b00ce3875..a7237b968 100644 --- a/r1cs/gadgets/crypto/src/commitment/pedersen/mod.rs +++ b/r1cs/gadgets/crypto/src/commitment/pedersen/mod.rs @@ -79,10 +79,7 @@ where } // Allocate new variable for commitment output. - let input_in_bits: Vec<_> = padded_input - .iter() - .flat_map(|byte| byte.into_bits_le()) - .collect(); + let input_in_bits: Vec<_> = padded_input.to_bits_le(cs.ns(|| "padded input to bits"))?; let input_in_bits = input_in_bits.chunks(W::WINDOW_SIZE); let mut result = GG::precomputed_base_multiscalar_mul( cs.ns(|| "multiexp"), @@ -91,7 +88,7 @@ where )?; // Compute h^r - let rand_bits: Vec<_> = r.0.iter().flat_map(|byte| byte.into_bits_le()).collect(); + let rand_bits: Vec<_> = r.0.to_bits_le(cs.ns(|| "pedersen randomness to bits"))?; result.precomputed_base_scalar_mul( cs.ns(|| "Randomizer"), rand_bits diff --git a/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs b/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs index 6e98029b0..51447dd69 100644 --- a/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs +++ b/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs @@ -8,7 +8,7 @@ use primitives::{ crh::pedersen::PedersenWindow, }; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; -use r1cs_std::{alloc::AllocGadget, groups::GroupGadget, UInt8}; +use r1cs_std::{alloc::AllocGadget, groups::GroupGadget, ToBitsGadget, UInt8}; use r1cs_std::bits::boolean::Boolean; use std::{borrow::Borrow, marker::PhantomData}; @@ -51,12 +51,12 @@ where type ParametersGadget = BoweHopwoodPedersenCRHGadgetParameters; fn check_evaluation_gadget>( - cs: CS, + mut cs: CS, parameters: &Self::ParametersGadget, input: &[UInt8], ) -> Result { // Pad the input if it is not the current length. - let mut input_in_bits: Vec<_> = input.iter().flat_map(|byte| byte.into_bits_le()).collect(); + let mut input_in_bits: Vec<_> = input.to_bits_le(cs.ns(|| "input to bits"))?; if (input_in_bits.len()) % CHUNK_SIZE != 0 { let current_length = input_in_bits.len(); for _ in 0..(CHUNK_SIZE - current_length % CHUNK_SIZE) { diff --git a/r1cs/gadgets/crypto/src/crh/pedersen/mod.rs b/r1cs/gadgets/crypto/src/crh/pedersen/mod.rs index 1a9cefb74..5e1705221 100644 --- a/r1cs/gadgets/crypto/src/crh/pedersen/mod.rs +++ b/r1cs/gadgets/crypto/src/crh/pedersen/mod.rs @@ -42,7 +42,7 @@ where type ParametersGadget = PedersenCRHGadgetParameters; fn check_evaluation_gadget>( - cs: CS, + mut cs: CS, parameters: &Self::ParametersGadget, input: &[UInt8], ) -> Result { @@ -70,10 +70,7 @@ where } // Allocate new variable for the result. - let input_in_bits: Vec<_> = padded_input - .iter() - .flat_map(|byte| byte.into_bits_le()) - .collect(); + let input_in_bits: Vec<_> = padded_input.to_bits_le(cs.ns(|| "padded input to bits"))?; let input_in_bits = input_in_bits.chunks(W::WINDOW_SIZE); let result = GG::precomputed_base_multiscalar_mul(cs, ¶meters.params.generators, input_in_bits)?; diff --git a/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs b/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs index 4e3f42af3..b78581d52 100644 --- a/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs +++ b/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs @@ -339,7 +339,7 @@ pub fn blake2s_gadget PRFGadget for Blake2sGadget assert_eq!(seed.len(), 32); // assert_eq!(input.len(), 32); let mut gadget_input = Vec::with_capacity(512); - for byte in seed.iter().chain(input) { - gadget_input.extend_from_slice(&byte.into_bits_le()); + for (i, byte) in seed.iter().chain(input).enumerate() { + gadget_input.extend_from_slice(&byte.to_bits_le(cs.ns(|| format!("covert byte {} to bits", i)))?); } let mut result = Vec::new(); for (i, int) in blake2s_gadget(cs.ns(|| "Blake2s Eval"), &gadget_input)? @@ -652,8 +652,8 @@ mod test { .iter() .flat_map(|&byte| (0..8).map(move |i| (byte >> i) & 1u8 == 1u8)); - for chunk in r { - for b in chunk.to_bits_le() { + for (i, chunk) in r.iter().enumerate() { + for b in chunk.to_bits_le(cs.ns(|| format!("chunk {} to bits", i))).unwrap() { match b { Boolean::Is(b) => { assert!(s.next().unwrap() == b.get_value().unwrap()); diff --git a/r1cs/gadgets/crypto/src/prf/ripemd160.rs b/r1cs/gadgets/crypto/src/prf/ripemd160.rs index 963dce446..145a2d13d 100644 --- a/r1cs/gadgets/crypto/src/prf/ripemd160.rs +++ b/r1cs/gadgets/crypto/src/prf/ripemd160.rs @@ -6,7 +6,8 @@ use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; use r1cs_std::boolean::Boolean; use r1cs_std::eq::MultiEq; use r1cs_std::uint32::UInt32; -use r1cs_std::uint8::UInt8; +use r1cs_std::UInt8; +use r1cs_std::{UIntGadget, RotateUInt, ToBitsGadget, FromBitsGadget}; use crate::sha256::{sha256_ch_boolean, triop}; @@ -125,12 +126,7 @@ where { assert_eq!(input.len(), 512); - Ok( - ripemd160_compression_function(&mut cs, &input, &get_ripemd160_iv())? - .into_iter() - .flat_map(|e| e.to_bits_le()) - .collect(), - ) + ripemd160_compression_function(&mut cs, &input, &get_ripemd160_iv())?.to_bits_le(cs) } /// The full domain RIPEMD160 hash function. @@ -166,7 +162,7 @@ where cur = ripemd160_compression_function(cs.ns(|| format!("block {}", i)), block, &cur)?; } - Ok(cur.into_iter().flat_map(|e| e.to_bits_le()).collect()) + cur.to_bits_le(cs) } fn get_ripemd160_iv() -> Vec { diff --git a/r1cs/gadgets/crypto/src/prf/sha256.rs b/r1cs/gadgets/crypto/src/prf/sha256.rs index 7e9e90116..b20933eab 100644 --- a/r1cs/gadgets/crypto/src/prf/sha256.rs +++ b/r1cs/gadgets/crypto/src/prf/sha256.rs @@ -2,12 +2,13 @@ //! function. //! This is a port from the implementation in [Bellman](https://docs.rs/bellman/0.8.0/src/bellman/gadgets/sha256.rs.html#47-74) +use std::ops::Shr; use algebra::PrimeField; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; use r1cs_std::boolean::AllocatedBit; use r1cs_std::eq::MultiEq; use r1cs_std::uint32::UInt32; -use r1cs_std::{boolean::Boolean, Assignment}; +use r1cs_std::{boolean::Boolean, Assignment, UIntGadget, RotateUInt, ToBitsGadget, FromBitsGadget}; #[allow(clippy::unreadable_literal)] const ROUND_CONSTANTS: [u32; 64] = [ @@ -73,7 +74,8 @@ where cur = sha256_compression_function(cs.ns(|| format!("block {}", i)), block, &cur)?; } - Ok(cur.into_iter().flat_map(|e| e.into_bits_be()).collect()) + cur.reverse(); + cur.to_bits(cs) } fn get_sha256_iv() -> Vec { @@ -81,7 +83,7 @@ fn get_sha256_iv() -> Vec { } fn sha256_compression_function( - cs: CS, + mut cs: CS, input: &[Boolean], current_hash_value: &[UInt32], ) -> Result, SynthesisError> @@ -99,8 +101,8 @@ where // Initialize the first 16 words in the array w let mut w = input .chunks(32) - .map(|e| UInt32::from_bits_be(e)) - .collect::>(); + .map(|e| UInt32::from_bits(cs.ns(|| format!("pack input word {}", i)),e)) + .collect::, SynthesisError>>()?; let mut cs = MultiEq::new(cs); @@ -112,13 +114,13 @@ where // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) let mut s0 = UInt32::rotr(&w[i - 15], 7); s0 = s0.xor(cs.ns(|| "first xor for s0"), &UInt32::rotr(&w[i - 15], 18))?; - s0 = s0.xor(cs.ns(|| "second xor for s0"), &UInt32::shr(&w[i - 15], 3))?; + s0 = s0.xor(cs.ns(|| "second xor for s0"), &w[i - 15].clone()(3))?; // Compute SHA256_sigma1(w[i-2]) // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) let mut s1 = UInt32::rotr(&w[i - 2], 17); s1 = s1.xor(cs.ns(|| "first xor for s1"), &UInt32::rotr(&w[i - 2], 19))?; - s1 = s1.xor(cs.ns(|| "second xor for s1"), &UInt32::shr(&w[i - 2], 10))?; + s1 = s1.xor(cs.ns(|| "second xor for s1"), &w[i - 2].clone().shr(10))?; // Compute W[i] = SHA256_sigma1(W[i-2]) + W[i-7] + // SHA256_sigma0(W[i-15]) + W[i-16] mod 2^32 @@ -330,18 +332,15 @@ where }; let bits = a - .bits + .to_bits_le(cs.ns(|| "a to bits"))? .iter() - .zip(b.bits.iter()) - .zip(c.bits.iter()) + .zip(b.to_bits_le(cs.ns(|| "b to bits"))?.iter()) + .zip(c.to_bits_le(cs.ns(|| "c to bits"))?.iter()) .enumerate() .map(|(i, ((a, b), c))| circuit_fn(&mut cs, i, a, b, c)) .collect::>()?; - Ok(UInt32 { - bits, - value: new_value, - }) + Ok(UInt32::new(bits, new_value)) } /// Computes (a and b) xor ((not a) and c) diff --git a/r1cs/gadgets/crypto/src/signature/schnorr/mod.rs b/r1cs/gadgets/crypto/src/signature/schnorr/mod.rs index 99061345a..c9ddca96d 100644 --- a/r1cs/gadgets/crypto/src/signature/schnorr/mod.rs +++ b/r1cs/gadgets/crypto/src/signature/schnorr/mod.rs @@ -74,10 +74,7 @@ where randomness: &[UInt8], ) -> Result { let base = parameters.generator.clone(); - let randomness = randomness - .iter() - .flat_map(|b| b.into_bits_le()) - .collect::>(); + let randomness = randomness.to_bits_le(cs.ns(|| "randomness to bits"))?; let rand_pk = { let base_pow_rand = base.mul_bits(&mut cs.ns(|| "Compute randomizer"), randomness.iter())?; diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index b5a080bff..3b34ad867 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -7,7 +7,7 @@ macro_rules! impl_uint_gadget { use r1cs_core::{ConstraintSystemAbstract, SynthesisError, LinearCombination}; use crate::alloc::{AllocGadget, ConstantGadget}; - use algebra::{fields::{PrimeField, FpParameters}, ToConstraintField}; + use algebra::{fields::{PrimeField, FpParameters, Field}, ToConstraintField}; use std::{borrow::Borrow, ops::{Shl, Shr}, convert::TryInto, cmp::Ordering}; @@ -167,21 +167,39 @@ macro_rules! impl_uint_gadget { Ok(()) } - // Return little endian representation of self. Will be removed when to_bits_le and - // from_bits_le will be merged. - pub fn into_bits_le(&self) -> Vec { - self.bits.to_vec() - } - - // Construct self from its little endian bit representation. Will be removed when - // to_bits_le and from_bits_le will be merged. - pub fn from_bits_le(cs: CS, bits: &[Boolean]) -> Result + // Helper function that constructs self from an iterator over Booleans. + // It is employed as a building block by the public functions of the `FromBitsGadget` + fn from_bit_iterator<'a, ConstraintF, CS>(_cs: CS, bits_iter: impl Iterator) -> Result where ConstraintF: PrimeField, CS: ConstraintSystemAbstract, { - let be_bits = bits.iter().rev().map(|el| *el).collect::>(); - Self::from_bits(cs, &be_bits) + let mut bits = Vec::with_capacity($bit_size); + let mut value: Option<$native_type> = Some(0); + for (i, el) in bits_iter.enumerate() { + bits.push(*el); + value = match el.get_value() { + Some(bit) => value.as_mut().map(|v| {*v |= + if bit { + let mask: $native_type = 1; + mask << i + } else { + 0 + }; *v}), + None => None, + }; + } + + if bits.len() != $bit_size { + let mut error_msg = String::from(concat!("error: building ", stringify!($type_name))); + error_msg.push_str(format!("from slice of {} bits", bits.len()).as_str()); + return Err(SynthesisError::Other(error_msg)) + } + + Ok(Self { + bits, + value, + }) } // Construct Self from a little endian byte representation, provided in the form of @@ -198,7 +216,7 @@ macro_rules! impl_uint_gadget { *v}), None => None, }; - bits.append(&mut byte.into_bits_le()); + bits.extend_from_slice(byte.bits.as_slice()); } // pad with 0 bits to get to $bit_size @@ -254,7 +272,7 @@ macro_rules! impl_uint_gadget { result_bits .iter() .skip(max_overflow_bits) - .map(|el| *el) + .cloned() .collect::>() }; // addend is equal to coeff*digit mod 2^$bit_size @@ -394,53 +412,83 @@ macro_rules! impl_uint_gadget { } impl ToBitsGadget for $type_name { - fn to_bits>( - &self, - _cs: CS, - ) -> Result, SynthesisError> { - //Need to reverse bits since to_bits must return a big-endian representation - let le_bits = self.bits.iter().rev().map(|el| *el).collect::>(); - Ok(le_bits) - } + fn to_bits>( + &self, + _cs: CS, + ) -> Result, SynthesisError> { + //Need to reverse bits since to_bits must return a big-endian representation + let be_bits = self.bits.iter().rev().cloned().collect::>(); + Ok(be_bits) + } - fn to_bits_strict>( - &self, - cs: CS, - ) -> Result, SynthesisError> { - self.to_bits(cs) - } + fn to_bits_strict>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + self.to_bits(cs) + } + + fn to_bits_le>( + &self, + _cs: CS, + ) -> Result, SynthesisError> { + Ok(self.bits.clone()) + } + + fn to_bits_strict_le>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + self.to_bits_le(cs) + } + } + + impl ToBitsGadget for Vec<$type_name> { + fn to_bits>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + let mut le_bits = self.to_bits_le(cs)?; + //Need to reverse bits since to_bits must return a big-endian representation + le_bits.reverse(); + Ok(le_bits) + } + + fn to_bits_strict>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + self.to_bits(cs) + } + + fn to_bits_le>( + &self, + _cs: CS, + ) -> Result, SynthesisError> { + Ok(self.iter().flat_map(|el| el.bits.clone()).collect::>()) + } + + fn to_bits_strict_le>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + self.to_bits_le(cs) + } } impl FromBitsGadget for $type_name { fn from_bits>( - _cs: CS, + cs: CS, bits: &[Boolean], ) -> Result { - if bits.len() != $bit_size { - let mut error_msg = String::from(concat!("error: building ", stringify!($type_name))); - error_msg.push_str(format!("from slice of {} bits", bits.len()).as_str()); - return Err(SynthesisError::Other(error_msg)) - } - let mut le_bits = Vec::with_capacity($bit_size); - let mut value: Option<$native_type> = Some(0); - for (i, el) in bits.iter().rev().enumerate() { - le_bits.push(*el); - value = match el.get_value() { - Some(bit) => value.as_mut().map(|v| {*v |= - if bit { - let mask: $native_type = 1; - mask << i - } else { - 0 - }; *v}), - None => None, - }; - } + Self::from_bit_iterator(cs, bits.iter().rev()) + } - Ok(Self { - bits: le_bits, - value, - }) + fn from_bits_le>( + cs: CS, + bits: &[Boolean], + ) -> Result { + Self::from_bit_iterator(cs, bits.iter()) } } @@ -528,7 +576,7 @@ macro_rules! impl_uint_gadget { .iter() // append rhs zeros as least significant bits .chain(self.bits.iter()) // Chain existing bits as most significant bits starting from least significant ones .take($bit_size) // Crop after $bit_size bits - .map(|el| *el) + .cloned() .collect(); Self { @@ -553,7 +601,7 @@ macro_rules! impl_uint_gadget { .iter() .skip(by) // skip least significant bits which are removed by the shift .chain(vec![Boolean::constant(false); by].iter()) // append zeros as most significant bits - .map(|el| *el) + .cloned() .collect(); Self { @@ -574,7 +622,7 @@ macro_rules! impl_uint_gadget { .skip($bit_size - by) .chain(self.bits.iter()) .take($bit_size) - .map(|el| *el) + .cloned() .collect(); Self { @@ -592,7 +640,7 @@ macro_rules! impl_uint_gadget { .skip(by) .chain(self.bits.iter()) .take($bit_size) - .map(|el| *el) + .cloned() .collect(); Self { @@ -650,6 +698,7 @@ macro_rules! impl_uint_gadget { impl_binary_bitwise_operation!(or, |, or); impl_binary_bitwise_operation!(and, &, and); + fn not>(&self, _cs: CS) -> Self { let bits = self.bits.iter().map(|el| el.not()).collect::>(); @@ -914,7 +963,7 @@ macro_rules! impl_uint_gadget { let result_lsbs = result_bits .iter() .skip((num_operands-1)*$bit_size) - .map(|el| *el) + .cloned() .collect::>(); $type_name::from_bits(cs.ns(|| "packing result"), &result_lsbs[..]) @@ -1507,12 +1556,19 @@ macro_rules! impl_uint_gadget { let val: $native_type = rng.gen(); let alloc_var = alloc_fn(&mut cs, "alloc var", var_type, val); + // test big endian serialization + let bits = alloc_var.to_bits(cs.ns(|| "unpack variable to big-endian bits")).unwrap(); + assert_eq!(bits.len(), $bit_size, "unpacking value to big endian"); + + let reconstructed_var = $type_name::from_bits(cs.ns(|| "pack big-endian bits"), &bits).unwrap(); + test_uint_gadget_value(val, &reconstructed_var, "packing big-endian bits"); - let bits = alloc_var.to_bits(cs.ns(|| "unpack variable")).unwrap(); - assert_eq!(bits.len(), $bit_size, "unpacking value"); + // test little endian serialization + let bits = alloc_var.to_bits_le(cs.ns(|| "unpack variable to little-endian bits")).unwrap(); + assert_eq!(bits.len(), $bit_size, "unpacking value to little-endian"); - let reconstructed_var = $type_name::from_bits(cs.ns(|| "pack bits"), &bits).unwrap(); - test_uint_gadget_value(val, &reconstructed_var, "packing bits"); + let reconstructed_var = $type_name::from_bits_le(cs.ns(|| "pack little-endian bits"), &bits).unwrap(); + test_uint_gadget_value(val, &reconstructed_var, "packing little-endian bits"); } } @@ -1545,57 +1601,105 @@ macro_rules! impl_uint_gadget { #[test] fn test_from_bits() { let rng = &mut thread_rng(); - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + const NUM_RUNS: usize = 5; + for _ in 0..NUM_RUNS { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - let mut bits = Vec::with_capacity($bit_size); // vector of Booleans - let mut bit_values = Vec::with_capacity($bit_size); // vector of the actual values wrapped by Booleans found in bits vector - for i in 0..$bit_size { - let bit_value: bool = rng.gen(); - // we test all types of Booleans - match i % 3 { - 0 => { - bit_values.push(bit_value); - bits.push(Boolean::Constant(bit_value)) - }, - 1 => { - bit_values.push(bit_value); - let bit = Boolean::alloc(cs.ns(|| format!("alloc bit {}", i)), || Ok(bit_value)).unwrap(); - bits.push(bit) - }, - 2 => { - bit_values.push(!bit_value); - let bit = Boolean::alloc(cs.ns(|| format!("alloc bit {}", i)), || Ok(bit_value)).unwrap(); - bits.push(bit.not()) - }, - _ => {}, - } - } + let mut bits = Vec::with_capacity($bit_size); // vector of Booleans + let mut bit_values = Vec::with_capacity($bit_size); // vector of the actual values wrapped by Booleans found in bits vector + for i in 0..$bit_size { + let bit_value: bool = rng.gen(); + // we test all types of Booleans + match i % 3 { + 0 => { + bit_values.push(bit_value); + bits.push(Boolean::Constant(bit_value)) + }, + 1 => { + bit_values.push(bit_value); + let bit = Boolean::alloc(cs.ns(|| format!("alloc bit {}", i)), || Ok(bit_value)).unwrap(); + bits.push(bit) + }, + 2 => { + bit_values.push(!bit_value); + let bit = Boolean::alloc(cs.ns(|| format!("alloc bit {}", i)), || Ok(bit_value)).unwrap(); + bits.push(bit.not()) + }, + _ => {}, + } + } + let little_endian: bool = rng.gen(); - let uint_gadget = $type_name::from_bits(cs.ns(|| "pack random bits"), &bits).unwrap(); - let value = uint_gadget.get_value().unwrap(); + // construct $type_name from bits and check correctness + let uint_gadget = if little_endian { + $type_name::from_bits_le(cs.ns(|| "pack random bits"), &bits).unwrap() + } else { + $type_name::from_bits(cs.ns(|| "pack random bits"), &bits).unwrap() + }; + let value = uint_gadget.get_value().unwrap(); + + for (i, el) in uint_gadget.bits.iter().enumerate() { + let bit = el.get_value().unwrap(); + let bit_index = if little_endian {i} else {$bit_size-1-i}; + assert_eq!(bit, bits[bit_index].get_value().unwrap()); + assert_eq!(bit, bit_values[bit_index]); + assert_eq!(bit, (value >> i) & 1 == 1); + } - for (i, el) in uint_gadget.bits.iter().enumerate() { - let bit = el.get_value().unwrap(); - assert_eq!(bit, bits[$bit_size-1-i].get_value().unwrap()); - assert_eq!(bit, bit_values[$bit_size-1-i]); - assert_eq!(bit, (value >> i) & 1 == 1); - } + // check that to_bits(from_bits(bits)) == bits + let unpacked_bits = if little_endian { + uint_gadget.to_bits_le(cs.ns(|| "unpack bits")).unwrap() + } else { + uint_gadget.to_bits(cs.ns(|| "unpack bits")).unwrap() + }; + for (bit1, bit2) in bits.iter().zip(unpacked_bits.iter()) { + assert_eq!(bit1, bit2); + } - // check that to_bits(from_bits(bits)) == bits - let unpacked_bits = uint_gadget.to_bits(cs.ns(|| "unpack bits")).unwrap(); + //check that an error is returned if more than $bit_size bits are unpacked + let mut bits = Vec::with_capacity($bit_size+1); + for _ in 0..$bit_size+1 { + bits.push(Boolean::constant(false)); + } + + if little_endian { + $type_name::from_bits_le(cs.ns(|| "unpacking too many bits"), &bits).unwrap_err(); + } else { + $type_name::from_bits(cs.ns(|| "unpacking too many bits"), &bits).unwrap_err(); + } - for (bit1, bit2) in bits.iter().zip(unpacked_bits.iter()) { - assert_eq!(bit1, bit2); } + } - //check that an error is returned if more than $bit_size bits are unpacked - let mut bits = Vec::with_capacity($bit_size+1); - for _ in 0..$bit_size+1 { - bits.push(Boolean::constant(false)); + #[test] + fn test_vec_serialization() { + let rng = &mut thread_rng(); + for var_type in VARIABLE_TYPES.iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + let vec_len = rng.gen_range(0..20); + + let values = (0..vec_len).map(|_| rng.gen()).collect::>(); + + let alloc_vec = match var_type { + VariableType::Allocated => $type_name::alloc_vec(cs.ns(|| "alloc vec"), &values).unwrap(), + VariableType::PublicInput => $type_name::alloc_input_vec(cs.ns(|| "alloc vec"), &values).unwrap(), + VariableType::Constant => $type_name::constant_vec(&values), + }; + + let bits = alloc_vec.to_bits_le(cs.ns(|| "vec to little-endian bits")).unwrap(); + for (i, (val, bit_chunk)) in values.iter().zip(bits.chunks($bit_size)).enumerate() { + let chunk_val = $type_name::from_bits_le(cs.ns(|| format!("pack chunk {} of little-endian bits", i)), bit_chunk).unwrap(); + test_uint_gadget_value(*val, &chunk_val, format!("check chunk {} of little-endian bits", i).as_str()); + } + + let bits = alloc_vec.to_bits(cs.ns(|| "vec to big-endian bits")).unwrap(); + for (i, (val, bit_chunk)) in values.iter().rev().zip(bits.chunks($bit_size)).enumerate() { + let chunk_val = $type_name::from_bits(cs.ns(|| format!("pack chunk {} of big-endian bits", i)), bit_chunk).unwrap(); + test_uint_gadget_value(*val, &chunk_val, format!("check chunk {} of big-endian bits", i).as_str()); + } } - let _ = $type_name::from_bits(cs.ns(|| "unpacking too many bits"), &bits).unwrap_err(); } #[test] @@ -1762,7 +1866,7 @@ macro_rules! impl_uint_gadget { alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) }).collect::>(); - let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| a.overflowing_add(b).0).unwrap(); + let result_value: $native_type = operand_values.iter().cloned().reduce(|a,b| a.overflowing_add(b).0).unwrap(); let result_var = { // add a scope for multi_eq CS as the constraints are enforced when the variable is dropped @@ -1819,7 +1923,7 @@ macro_rules! impl_uint_gadget { alloc_fn(&mut cs, format!("alloc operand {}", i).as_str(), &VARIABLE_TYPES[i % 3], *val) }).collect::>(); - let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| a.overflowing_mul(b).0).unwrap(); + let result_value: $native_type = operand_values.iter().cloned().reduce(|a,b| a.overflowing_mul(b).0).unwrap(); let result_var = $type_name::mulmany(cs.ns(|| "mul operands"), &operands).unwrap(); @@ -1997,7 +2101,7 @@ macro_rules! impl_uint_gadget { }).collect::>(); let mut is_overflowing = false; - let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| { + let result_value: $native_type = operand_values.iter().cloned().reduce(|a,b| { let (updated_sum, overflow) = a.overflowing_add(b); is_overflowing |= overflow; updated_sum @@ -2038,7 +2142,7 @@ macro_rules! impl_uint_gadget { // computation of result_value will panic in case of addition overflows, but it // should never happen given how we generate operand_values - let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a, b| a*b).unwrap(); + let result_value: $native_type = operand_values.iter().cloned().reduce(|a, b| a*b).unwrap(); let result_var = $type_name::mulmany_nocarry(cs.ns(|| "mul operands"), &operands).unwrap(); @@ -2109,7 +2213,7 @@ macro_rules! impl_uint_gadget { let mut is_overflowing = false; - let result_value: $native_type = operand_values.iter().map(|el| *el).reduce(|a,b| { + let result_value: $native_type = operand_values.iter().cloned().reduce(|a,b| { let (updated_sum, overflow) = a.overflowing_mul(b); is_overflowing |= overflow; updated_sum diff --git a/r1cs/gadgets/std/src/bits/mod.rs b/r1cs/gadgets/std/src/bits/mod.rs index 7c23efd01..41a8080a6 100644 --- a/r1cs/gadgets/std/src/bits/mod.rs +++ b/r1cs/gadgets/std/src/bits/mod.rs @@ -332,13 +332,9 @@ impl ToBitsGadget for Vec { impl ToBitsGadget for [UInt8] { fn to_bits>( &self, - _cs: CS, + cs: CS, ) -> Result, SynthesisError> { - let mut result = Vec::with_capacity(&self.len() * 8); - for byte in self { - result.extend_from_slice(&byte.into_bits_le()); - } - Ok(result) + self.to_vec().to_bits(cs) } fn to_bits_strict>( diff --git a/r1cs/gadgets/std/src/fields/fp.rs b/r1cs/gadgets/std/src/fields/fp.rs index 13a907471..669648cd6 100644 --- a/r1cs/gadgets/std/src/fields/fp.rs +++ b/r1cs/gadgets/std/src/fields/fp.rs @@ -104,9 +104,7 @@ impl FpGadget { let mut lc = LinearCombination::zero(); let mut coeff = F::one(); - for bit in bytes - .iter() - .flat_map(|byte_gadget| byte_gadget.bits.clone()) + for bit in bytes.to_bits_le(cs.ns(|| "convert allocated bytes to bits"))? { match bit { Boolean::Is(bit) => { @@ -553,14 +551,10 @@ impl ToBytesGadget for FpGadget { mut cs: CS, ) -> Result, SynthesisError> { let bytes = self.to_bytes(&mut cs)?; + let be_bits = bytes.to_bits(cs.ns(|| "convert to bits"))?; Boolean::enforce_in_field::<_, _, F>( &mut cs, - &bytes - .iter() - .flat_map(|byte_gadget| byte_gadget.into_bits_le()) - // This reverse maps the bits into big-endian form, as required by `enforce_in_field`. - .rev() - .collect::>(), + be_bits.as_slice(), )?; Ok(bytes) diff --git a/r1cs/gadgets/std/src/lib.rs b/r1cs/gadgets/std/src/lib.rs index db85f4613..6d0f7b402 100644 --- a/r1cs/gadgets/std/src/lib.rs +++ b/r1cs/gadgets/std/src/lib.rs @@ -74,7 +74,7 @@ pub mod prelude { alloc::*, bits::{ boolean::Boolean, uint32::UInt32, UInt8, FromBitsGadget, ToBitsGadget, - ToBytesGadget, UIntGadget, + ToBytesGadget, UIntGadget, RotateUInt, }, eq::*, fields::{cubic_extension::*, quadratic_extension::*, FieldGadget}, From a4eab00f759698055eb798743b91d64448a61f15 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 11 Mar 2022 12:24:52 +0100 Subject: [PATCH 13/18] Move rotate functions to UIntGadget --- r1cs/gadgets/crypto/src/prf/blake2s/mod.rs | 8 ++-- r1cs/gadgets/crypto/src/prf/ripemd160.rs | 10 ++--- r1cs/gadgets/crypto/src/prf/sha256.rs | 34 +++++++++------ r1cs/gadgets/std/src/bits/macros.rs | 50 ++++++++++++++++++---- r1cs/gadgets/std/src/bits/mod.rs | 22 +++++----- r1cs/gadgets/std/src/lib.rs | 2 +- 6 files changed, 82 insertions(+), 44 deletions(-) diff --git a/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs b/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs index b78581d52..88cb40540 100644 --- a/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs +++ b/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs @@ -92,16 +92,16 @@ fn mixing_g>( cs.ns(|| "mixing step 1"), &[v[a].clone(), v[b].clone(), x.clone()], )?; - v[d] = v[d].xor(cs.ns(|| "mixing step 2"), &v[a])?.rotr(R1); + v[d] = v[d].xor(cs.ns(|| "mixing step 2"), &v[a])?.rotr(R1, &mut cs); v[c] = UInt32::addmany(cs.ns(|| "mixing step 3"), &[v[c].clone(), v[d].clone()])?; - v[b] = v[b].xor(cs.ns(|| "mixing step 4"), &v[c])?.rotr(R2); + v[b] = v[b].xor(cs.ns(|| "mixing step 4"), &v[c])?.rotr(R2, &mut cs); v[a] = UInt32::addmany( cs.ns(|| "mixing step 5"), &[v[a].clone(), v[b].clone(), y.clone()], )?; - v[d] = v[d].xor(cs.ns(|| "mixing step 6"), &v[a])?.rotr(R3); + v[d] = v[d].xor(cs.ns(|| "mixing step 6"), &v[a])?.rotr(R3, &mut cs); v[c] = UInt32::addmany(cs.ns(|| "mixing step 7"), &[v[c].clone(), v[d].clone()])?; - v[b] = v[b].xor(cs.ns(|| "mixing step 8"), &v[c])?.rotr(R4); + v[b] = v[b].xor(cs.ns(|| "mixing step 8"), &v[c])?.rotr(R4, &mut cs); Ok(()) } diff --git a/r1cs/gadgets/crypto/src/prf/ripemd160.rs b/r1cs/gadgets/crypto/src/prf/ripemd160.rs index 145a2d13d..1c9508c95 100644 --- a/r1cs/gadgets/crypto/src/prf/ripemd160.rs +++ b/r1cs/gadgets/crypto/src/prf/ripemd160.rs @@ -7,7 +7,7 @@ use r1cs_std::boolean::Boolean; use r1cs_std::eq::MultiEq; use r1cs_std::uint32::UInt32; use r1cs_std::UInt8; -use r1cs_std::{UIntGadget, RotateUInt, ToBitsGadget, FromBitsGadget}; +use r1cs_std::{UIntGadget, ToBitsGadget, FromBitsGadget}; use crate::sha256::{sha256_ch_boolean, triop}; @@ -217,7 +217,7 @@ where cs.ns(|| format!("first rotl(a + f + x + k) {}", i)), &[a, f, selected_input_word, get_round_constants(i).0], )? - .rotl(S[i]); + .rotl(S[i], &mut *cs); UInt32::addmany( cs.ns(|| format!("compute first T {}", i)), &[result, e.clone()], @@ -226,7 +226,7 @@ where a = e; e = d; - d = c.rotl(10); + d = c.rotl(10, &mut *cs); c = b; b = t; @@ -245,7 +245,7 @@ where cs.ns(|| format!("second rotl(a + f + x + k) {}", i)), &[a_prime, f, selected_input_word, get_round_constants(i).1], )? - .rotl(S_PRIME[i]); + .rotl(S_PRIME[i], &mut *cs); UInt32::addmany( cs.ns(|| format!("compute second T {}", i)), &[result, e_prime.clone()], @@ -254,7 +254,7 @@ where a_prime = e_prime; e_prime = d_prime; - d_prime = c_prime.rotl(10); + d_prime = c_prime.rotl(10, &mut *cs); c_prime = b_prime; b_prime = t; } diff --git a/r1cs/gadgets/crypto/src/prf/sha256.rs b/r1cs/gadgets/crypto/src/prf/sha256.rs index b20933eab..71a3604ea 100644 --- a/r1cs/gadgets/crypto/src/prf/sha256.rs +++ b/r1cs/gadgets/crypto/src/prf/sha256.rs @@ -7,8 +7,8 @@ use algebra::PrimeField; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; use r1cs_std::boolean::AllocatedBit; use r1cs_std::eq::MultiEq; -use r1cs_std::uint32::UInt32; -use r1cs_std::{boolean::Boolean, Assignment, UIntGadget, RotateUInt, ToBitsGadget, FromBitsGadget}; +use r1cs_std::{uint32::UInt32, UIntGadget}; +use r1cs_std::{boolean::Boolean, Assignment, ToBitsGadget, FromBitsGadget}; #[allow(clippy::unreadable_literal)] const ROUND_CONSTANTS: [u32; 64] = [ @@ -112,15 +112,17 @@ where // Compute SHA256_sigma0(w[i-15]) // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) - let mut s0 = UInt32::rotr(&w[i - 15], 7); - s0 = s0.xor(cs.ns(|| "first xor for s0"), &UInt32::rotr(&w[i - 15], 18))?; - s0 = s0.xor(cs.ns(|| "second xor for s0"), &w[i - 15].clone()(3))?; + let mut s0 = UInt32::rotr(&w[i - 15], 7, &mut *cs); + let rotated_word = UInt32::rotr(&w[i - 15], 18, &mut *cs); + s0 = s0.xor(cs.ns(|| "first xor for s0"), &rotated_word)?; + s0 = s0.xor(cs.ns(|| "second xor for s0"), &w[i-15].clone().shr(3))?; // Compute SHA256_sigma1(w[i-2]) // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) - let mut s1 = UInt32::rotr(&w[i - 2], 17); - s1 = s1.xor(cs.ns(|| "first xor for s1"), &UInt32::rotr(&w[i - 2], 19))?; - s1 = s1.xor(cs.ns(|| "second xor for s1"), &w[i - 2].clone().shr(10))?; + let mut s1 = UInt32::rotr(&w[i - 2], 17, &mut *cs); + let rotated_word = UInt32::rotr(&w[i - 2], 19, &mut *cs); + s1 = s1.xor(cs.ns(|| "first xor for s1"), &rotated_word)?; + s1 = s1.xor(cs.ns(|| "second xor for s1"), &w[i-2].clone().shr(10))?; // Compute W[i] = SHA256_sigma1(W[i-2]) + W[i-7] + // SHA256_sigma0(W[i-15]) + W[i-16] mod 2^32 @@ -176,9 +178,11 @@ where // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25) let new_e = e.compute(cs.ns(|| "deferred e computation"), &[])?; - let mut s1 = new_e.rotr(6); - s1 = s1.xor(cs.ns(|| "first xor for s1"), &new_e.rotr(11))?; - s1 = s1.xor(cs.ns(|| "second xor for s1"), &new_e.rotr(25))?; + let mut s1 = new_e.rotr(6, &mut *cs); + let new_e_rotated = new_e.rotr(11, &mut *cs); + s1 = s1.xor(cs.ns(|| "first xor for s1"), &new_e_rotated)?; + let new_e_rotated = new_e.rotr(25, &mut *cs); + s1 = s1.xor(cs.ns(|| "second xor for s1"), &new_e_rotated)?; // ch := (e and f) xor ((not e) and g) let ch = sha256_ch_uint32(cs.ns(|| "ch"), &new_e, &f, &g)?; @@ -194,9 +198,11 @@ where // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22) let new_a = a.compute(cs.ns(|| "deferred a computation"), &[])?; - let mut s0 = new_a.rotr(2); - s0 = s0.xor(cs.ns(|| "first xor for s0"), &new_a.rotr(13))?; - s0 = s0.xor(cs.ns(|| "second xor for s0"), &new_a.rotr(22))?; + let mut s0 = new_a.rotr(2, &mut *cs); + let new_a_rotated = new_a.rotr(13, &mut *cs); + s0 = s0.xor(cs.ns(|| "first xor for s0"), &new_a_rotated)?; + let new_a_rotated = new_a.rotr(22, &mut *cs); + s0 = s0.xor(cs.ns(|| "second xor for s0"), &new_a_rotated)?; // maj := (a and b) xor (a and c) xor (b and c) let maj = sha256_maj_uint32(cs.ns(|| "maj"), &new_a, &b, &c)?; diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index 3b34ad867..0e71ac317 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -2,7 +2,7 @@ macro_rules! impl_uint_gadget { ($type_name: ident, $bit_size: expr, $native_type: ident, $mod_name: ident) => { pub mod $mod_name { - use crate::{boolean::{Boolean, AllocatedBit}, fields::{fp::FpGadget, FieldGadget}, eq::{EqGadget, MultiEq}, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8, Assignment, cmp::ComparisonGadget}; + use crate::{boolean::{Boolean, AllocatedBit}, fields::{fp::FpGadget, FieldGadget}, eq::{EqGadget, MultiEq}, ToBitsGadget, FromBitsGadget, ToBytesGadget, UIntGadget, select::CondSelectGadget, bits::UInt8, Assignment, cmp::ComparisonGadget}; use r1cs_core::{ConstraintSystemAbstract, SynthesisError, LinearCombination}; use crate::alloc::{AllocGadget, ConstantGadget}; @@ -612,7 +612,7 @@ macro_rules! impl_uint_gadget { } - impl RotateUInt for $type_name { + /*impl RotateUInt for $type_name { fn rotl(&self, by: usize) -> Self { let by = by % $bit_size; @@ -648,7 +648,7 @@ macro_rules! impl_uint_gadget { value: self.value.map(|v| v.rotate_right(by as u32)), } } - } + }*/ //this macro allows to implement the binary bitwise operations already available for Booleans (i.e., XOR, OR, AND) macro_rules! impl_binary_bitwise_operation { @@ -693,7 +693,41 @@ macro_rules! impl_uint_gadget { } impl UIntGadget for $type_name { + fn rotl>(&self, by: usize, _cs: CS) -> Self { + let by = by % $bit_size; + + let bits = self + .bits + .iter() + .skip($bit_size - by) + .chain(self.bits.iter()) + .take($bit_size) + .cloned() + .collect(); + + Self { + bits, + value: self.value.map(|v| v.rotate_left(by as u32)), + } + } + fn rotr>(&self, by: usize, _cs: CS) -> Self { + let by = by % $bit_size; + + let bits = self + .bits + .iter() + .skip(by) + .chain(self.bits.iter()) + .take($bit_size) + .cloned() + .collect(); + + Self { + bits, + value: self.value.map(|v| v.rotate_right(by as u32)), + } + } impl_binary_bitwise_operation!(xor, ^, xor); impl_binary_bitwise_operation!(or, |, or); impl_binary_bitwise_operation!(and, &, and); @@ -1316,7 +1350,7 @@ macro_rules! impl_uint_gadget { use std::{ops::{Shl, Shr}, cmp::Ordering, cmp::max}; - use crate::{alloc::{AllocGadget, ConstantGadget}, eq::{EqGadget, MultiEq}, boolean::Boolean, ToBitsGadget, FromBitsGadget, ToBytesGadget, RotateUInt, UIntGadget, select::CondSelectGadget, bits::UInt8, cmp::ComparisonGadget}; + use crate::{alloc::{AllocGadget, ConstantGadget}, eq::{EqGadget, MultiEq}, boolean::Boolean, ToBitsGadget, FromBitsGadget, ToBytesGadget, UIntGadget, select::CondSelectGadget, bits::UInt8, cmp::ComparisonGadget}; fn test_uint_gadget_value(val: $native_type, alloc_val: &$type_name, check_name: &str) { @@ -1762,18 +1796,18 @@ macro_rules! impl_uint_gadget { let alloc_var = alloc_fn(&mut cs, "alloc var", var_type, value); for i in 0..$bit_size { - let rotl_var = alloc_var.rotl(i); + let rotl_var = alloc_var.rotl(i, &mut cs); test_uint_gadget_value(value.rotate_left(i as u32), &rotl_var, format!("left rotation by {}", i).as_str()); - let rotr_var = rotl_var.rotr(i); + let rotr_var = rotl_var.rotr(i, &mut cs); test_uint_gadget_value(value, &rotr_var, format!("right rotation by {}", i).as_str()); } //check rotations are ok even if by > $bit_size let by = $bit_size*2; - let rotl_var = alloc_var.rotl(by); + let rotl_var = alloc_var.rotl(by, &mut cs); test_uint_gadget_value(value.rotate_left(by as u32), &rotl_var, format!("left rotation by {}", by).as_str()); - let rotr_var = alloc_var.rotl(by); + let rotr_var = alloc_var.rotl(by, &mut cs); test_uint_gadget_value(value.rotate_right(by as u32), &rotr_var, format!("right rotation by {}", by).as_str()); } } diff --git a/r1cs/gadgets/std/src/bits/mod.rs b/r1cs/gadgets/std/src/bits/mod.rs index 41a8080a6..886c53450 100644 --- a/r1cs/gadgets/std/src/bits/mod.rs +++ b/r1cs/gadgets/std/src/bits/mod.rs @@ -82,17 +82,6 @@ where } } -// this trait allows to move out rotl and rotr from UIntGadget, in turn allowing to avoid specifying -// for the compiler a field ConstraintF every time these methods are called, which requires a -// verbose syntax (e.g., UIntGadget::::rotl(&gadget_variable, i) -pub trait RotateUInt { - /// Rotate left `self` by `by` bits. - fn rotl(&self, by: usize) -> Self; - - /// Rotate right `self` by `by` bits. - fn rotr(&self, by: usize) -> Self; -} - pub trait UIntGadget: Sized + Clone @@ -108,8 +97,17 @@ Sized + ConstantGadget + Shr + Shl -+ RotateUInt { + /// Rotate left `self` by `by` bits. + fn rotl(&self, by: usize, cs: CS) -> Self + where + CS: ConstraintSystemAbstract; + + /// Rotate right `self` by `by` bits. + fn rotr(&self, by: usize, cs: CS) -> Self + where + CS: ConstraintSystemAbstract; + /// XOR `self` with `other` fn xor(&self, cs: CS, other: &Self) -> Result where diff --git a/r1cs/gadgets/std/src/lib.rs b/r1cs/gadgets/std/src/lib.rs index 6d0f7b402..db85f4613 100644 --- a/r1cs/gadgets/std/src/lib.rs +++ b/r1cs/gadgets/std/src/lib.rs @@ -74,7 +74,7 @@ pub mod prelude { alloc::*, bits::{ boolean::Boolean, uint32::UInt32, UInt8, FromBitsGadget, ToBitsGadget, - ToBytesGadget, UIntGadget, RotateUInt, + ToBytesGadget, UIntGadget, }, eq::*, fields::{cubic_extension::*, quadratic_extension::*, FieldGadget}, From 2cfea88dd47147241914721c64a3529d14507acb Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Thu, 24 Mar 2022 12:22:47 +0100 Subject: [PATCH 14/18] remove modules for each uint implementation --- r1cs/gadgets/crypto/src/prf/ripemd160.rs | 2 +- r1cs/gadgets/crypto/src/prf/sha256.rs | 2 +- r1cs/gadgets/std/src/bits/boolean.rs | 4 +- r1cs/gadgets/std/src/bits/macros.rs | 78 +++++------------------- r1cs/gadgets/std/src/bits/mod.rs | 23 +++++-- r1cs/gadgets/std/src/eq.rs | 6 +- r1cs/gadgets/std/src/lib.rs | 2 +- 7 files changed, 39 insertions(+), 78 deletions(-) diff --git a/r1cs/gadgets/crypto/src/prf/ripemd160.rs b/r1cs/gadgets/crypto/src/prf/ripemd160.rs index 1c9508c95..21dc275c6 100644 --- a/r1cs/gadgets/crypto/src/prf/ripemd160.rs +++ b/r1cs/gadgets/crypto/src/prf/ripemd160.rs @@ -5,7 +5,7 @@ use algebra::PrimeField; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; use r1cs_std::boolean::Boolean; use r1cs_std::eq::MultiEq; -use r1cs_std::uint32::UInt32; +use r1cs_std::uint::UInt32; use r1cs_std::UInt8; use r1cs_std::{UIntGadget, ToBitsGadget, FromBitsGadget}; diff --git a/r1cs/gadgets/crypto/src/prf/sha256.rs b/r1cs/gadgets/crypto/src/prf/sha256.rs index 71a3604ea..16915ae58 100644 --- a/r1cs/gadgets/crypto/src/prf/sha256.rs +++ b/r1cs/gadgets/crypto/src/prf/sha256.rs @@ -7,7 +7,7 @@ use algebra::PrimeField; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; use r1cs_std::boolean::AllocatedBit; use r1cs_std::eq::MultiEq; -use r1cs_std::{uint32::UInt32, UIntGadget}; +use r1cs_std::{uint::UInt32, UIntGadget}; use r1cs_std::{boolean::Boolean, Assignment, ToBitsGadget, FromBitsGadget}; #[allow(clippy::unreadable_literal)] diff --git a/r1cs/gadgets/std/src/bits/boolean.rs b/r1cs/gadgets/std/src/bits/boolean.rs index b170186ab..3c4304c5e 100644 --- a/r1cs/gadgets/std/src/bits/boolean.rs +++ b/r1cs/gadgets/std/src/bits/boolean.rs @@ -786,14 +786,14 @@ impl Boolean { /// In addition, the function also returns a flag which specifies if there are no "real" variables /// in the linear combination, that is if the sequence of Booleans comprises all constant values. /// Assumes that `bits` can be packed in a single field element (i.e., bits.len() <= ConstraintF::Params::CAPACITY). - pub fn bits_to_linear_combination<'a, ConstraintF:Field>(bits: impl Iterator, one: Variable) -> (LinearCombination, Option, bool) + pub fn bits_to_linear_combination<'a, ConstraintF:Field, CS: ConstraintSystemAbstract>(_cs: CS, bits: impl Iterator) -> (LinearCombination, Option, bool) { let mut lc = LinearCombination::zero(); let mut coeff = ConstraintF::one(); let mut lc_in_field = Some(ConstraintF::zero()); let mut all_constants = true; for bit in bits { - lc = lc + &bit.lc(one, coeff); + lc = lc + &bit.lc(CS::one(), coeff); all_constants &= bit.is_constant(); lc_in_field = match bit.get_value() { Some(b) => lc_in_field.as_mut().map(|val| { diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index 0e71ac317..6c5cc4d38 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -1,16 +1,5 @@ macro_rules! impl_uint_gadget { - ($type_name: ident, $bit_size: expr, $native_type: ident, $mod_name: ident) => { - pub mod $mod_name { - - use crate::{boolean::{Boolean, AllocatedBit}, fields::{fp::FpGadget, FieldGadget}, eq::{EqGadget, MultiEq}, ToBitsGadget, FromBitsGadget, ToBytesGadget, UIntGadget, select::CondSelectGadget, bits::UInt8, Assignment, cmp::ComparisonGadget}; - - use r1cs_core::{ConstraintSystemAbstract, SynthesisError, LinearCombination}; - use crate::alloc::{AllocGadget, ConstantGadget}; - - use algebra::{fields::{PrimeField, FpParameters, Field}, ToConstraintField}; - - use std::{borrow::Borrow, ops::{Shl, Shr}, convert::TryInto, cmp::Ordering}; - + ($type_name: ident, $bit_size: expr, $native_type: ident) => { #[derive(Clone, Debug)] pub struct $type_name { @@ -611,45 +600,6 @@ macro_rules! impl_uint_gadget { } } - - /*impl RotateUInt for $type_name { - fn rotl(&self, by: usize) -> Self { - let by = by % $bit_size; - - let bits = self - .bits - .iter() - .skip($bit_size - by) - .chain(self.bits.iter()) - .take($bit_size) - .cloned() - .collect(); - - Self { - bits, - value: self.value.map(|v| v.rotate_left(by as u32)), - } - } - - fn rotr(&self, by: usize) -> Self { - let by = by % $bit_size; - - let bits = self - .bits - .iter() - .skip(by) - .chain(self.bits.iter()) - .take($bit_size) - .cloned() - .collect(); - - Self { - bits, - value: self.value.map(|v| v.rotate_right(by as u32)), - } - } - }*/ - //this macro allows to implement the binary bitwise operations already available for Booleans (i.e., XOR, OR, AND) macro_rules! impl_binary_bitwise_operation { ($func_name: ident, $op: tt, $boolean_func: tt) => { @@ -813,7 +763,7 @@ macro_rules! impl_uint_gadget { }, }; - let (current_lc, _, is_op_constant) = Boolean::bits_to_linear_combination(op.bits.iter(), CS::one()); + let (current_lc, _, is_op_constant) = Boolean::bits_to_linear_combination(&mut cs, op.bits.iter()); lc = lc + current_lc; all_constants &= is_op_constant; } @@ -910,7 +860,7 @@ macro_rules! impl_uint_gadget { None => None, }; - let (current_lc, _, is_op_constant) = Boolean::bits_to_linear_combination(op.bits.iter(), CS::one()); + let (current_lc, _, is_op_constant) = Boolean::bits_to_linear_combination(&mut cs, op.bits.iter()); lc = lc + current_lc; all_constants &= is_op_constant; } @@ -923,7 +873,7 @@ macro_rules! impl_uint_gadget { } let result_var = $type_name::alloc(cs.ns(|| "alloc result"), || result_value.ok_or(SynthesisError::AssignmentMissing))?; - let (result_lc, _, _) = Boolean::bits_to_linear_combination(result_var.bits.iter(), CS::one()); + let (result_lc, _, _) = Boolean::bits_to_linear_combination(&mut cs, result_var.bits.iter()); cs.get_root().enforce_equal($bit_size, &lc, &result_lc); @@ -1107,8 +1057,8 @@ macro_rules! impl_uint_gadget { let max_value = ConstraintF::from($native_type::MAX) + ConstraintF::one(); // lc will be constructed as SUB(self,other)+2^$bit_size let mut lc = (max_value, CS::one()).into(); - let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(self.bits.iter(), CS::one()); - let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(other.bits.iter(), CS::one()); + let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(&mut cs, self.bits.iter()); + let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(&mut cs, other.bits.iter()); lc = lc + self_lc - other_lc; let all_constants = is_self_constant && is_other_constant; @@ -1184,8 +1134,8 @@ macro_rules! impl_uint_gadget { */ // lc is constructed as SUB(self, other) - let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(self.bits.iter(), CS::one()); - let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(other.bits.iter(), CS::one()); + let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(&mut cs, self.bits.iter()); + let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(&mut cs, other.bits.iter()); let lc = self_lc - other_lc; let all_constants = is_self_constant && is_other_constant; @@ -1208,7 +1158,7 @@ macro_rules! impl_uint_gadget { let diff_var = Self::alloc(cs.ns(|| "alloc diff"), || diff.ok_or(SynthesisError::AssignmentMissing))?; - let (diff_lc, _, _) = Boolean::bits_to_linear_combination(diff_var.bits.iter(), CS::one()); + let (diff_lc, _, _) = Boolean::bits_to_linear_combination(&mut cs, diff_var.bits.iter()); cs.get_root().enforce_equal($bit_size, &lc, &diff_lc); @@ -1254,9 +1204,9 @@ macro_rules! impl_uint_gadget { self.sub(multi_eq.ns(|| "a - b mod 2^n"), other)? }; - let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(self.bits.iter(), CS::one()); - let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(other.bits.iter(), CS::one()); - let (diff_lc, _, is_diff_constant) = Boolean::bits_to_linear_combination(diff_var.bits.iter(), CS::one()); + let (self_lc, _, is_self_constant) = Boolean::bits_to_linear_combination(&mut cs, self.bits.iter()); + let (other_lc, _, is_other_constant) = Boolean::bits_to_linear_combination(&mut cs, other.bits.iter()); + let (diff_lc, _, is_diff_constant) = Boolean::bits_to_linear_combination(&mut cs, diff_var.bits.iter()); let delta_lc = self_lc - other_lc - diff_lc; let all_constants = is_self_constant && is_other_constant && is_diff_constant; @@ -1340,7 +1290,8 @@ macro_rules! impl_uint_gadget { #[cfg(test)] - mod test { + paste::item! { + mod [] { use super::$type_name; use rand::{Rng, thread_rng}; use algebra::{fields::tweedle::Fr, Group, Field, FpParameters, PrimeField}; @@ -2686,7 +2637,6 @@ macro_rules! impl_uint_gadget { } } - } } } diff --git a/r1cs/gadgets/std/src/bits/mod.rs b/r1cs/gadgets/std/src/bits/mod.rs index 886c53450..c36863279 100644 --- a/r1cs/gadgets/std/src/bits/mod.rs +++ b/r1cs/gadgets/std/src/bits/mod.rs @@ -12,11 +12,22 @@ pub mod boolean; #[macro_use] pub mod macros; -impl_uint_gadget!(U8, 8, u8, uint8); -impl_uint_gadget!(UInt64, 64, u64, uint64); -impl_uint_gadget!(UInt32, 32, u32, uint32); -impl_uint_gadget!(UInt16, 16, u16, uint16); -impl_uint_gadget!(UInt128, 128, u128, uint128); +pub mod uint { + use crate::{boolean::{Boolean, AllocatedBit}, fields::{fp::FpGadget, FieldGadget}, eq::{EqGadget, MultiEq}, ToBitsGadget, FromBitsGadget, ToBytesGadget, UIntGadget, select::CondSelectGadget, bits::UInt8, Assignment, cmp::ComparisonGadget}; + + use r1cs_core::{ConstraintSystemAbstract, SynthesisError, LinearCombination}; + use crate::alloc::{AllocGadget, ConstantGadget}; + + use algebra::{fields::{PrimeField, FpParameters, Field}, ToConstraintField}; + + use std::{borrow::Borrow, ops::{Shl, Shr}, convert::TryInto, cmp::Ordering}; + + impl_uint_gadget!(U8, 8, u8); + impl_uint_gadget!(UInt64, 64, u64); + impl_uint_gadget!(UInt32, 32, u32); + impl_uint_gadget!(UInt16, 16, u16); + impl_uint_gadget!(UInt128, 128, u128); +} // This type alias allows to implement byte serialization/de-serialization functions inside the // `impl_uint_gadget` macro. @@ -24,7 +35,7 @@ impl_uint_gadget!(UInt128, 128, u128, uint128); // type for byte serialization/de-serialization functions. The type alias allows to employ a type // defined outside the macro in the interface of byte serialization functions, hence allowing to // implement them inside the `impl_uint_gadget` macro -pub type UInt8 = uint8::U8; +pub type UInt8 = uint::U8; pub trait ToBitsGadget { fn to_bits>( diff --git a/r1cs/gadgets/std/src/eq.rs b/r1cs/gadgets/std/src/eq.rs index 8733d2a4a..28d5b459d 100644 --- a/r1cs/gadgets/std/src/eq.rs +++ b/r1cs/gadgets/std/src/eq.rs @@ -138,7 +138,7 @@ impl, ConstraintF: Field> EqGadget for [T] // element whose little-endian bit representation is `self` (resp. other). // The function returns also a-b over the field (wrapped in an Option) and a flag // that specifies if all the bits in both `self` and `other` are constants. -fn compute_diff(self_bits: &[Boolean], mut _cs: CS, other_bits: &[Boolean]) -> (LinearCombination, Option, bool) +fn compute_diff(self_bits: &[Boolean], mut cs: CS, other_bits: &[Boolean]) -> (LinearCombination, Option, bool) where ConstraintF: PrimeField, CS: ConstraintSystemAbstract, @@ -147,8 +147,8 @@ fn compute_diff(self_bits: &[Boolean], mut _cs: CS, other_bits: assert!(self_bits.len() <= field_bits); assert!(other_bits.len() <= field_bits); - let (self_lc, self_in_field, is_self_constant) = Boolean::bits_to_linear_combination(self_bits.iter(), CS::one()); - let (other_lc, other_in_field, is_other_constant) = Boolean::bits_to_linear_combination(other_bits.iter(), CS::one()); + let (self_lc, self_in_field, is_self_constant) = Boolean::bits_to_linear_combination(&mut cs, self_bits.iter()); + let (other_lc, other_in_field, is_other_constant) = Boolean::bits_to_linear_combination(&mut cs, other_bits.iter()); let diff_in_field = match (self_in_field, other_in_field) { (Some(self_val), Some(other_val)) => Some(self_val-other_val), diff --git a/r1cs/gadgets/std/src/lib.rs b/r1cs/gadgets/std/src/lib.rs index db85f4613..7853a233c 100644 --- a/r1cs/gadgets/std/src/lib.rs +++ b/r1cs/gadgets/std/src/lib.rs @@ -73,7 +73,7 @@ pub mod prelude { pub use crate::{ alloc::*, bits::{ - boolean::Boolean, uint32::UInt32, UInt8, FromBitsGadget, ToBitsGadget, + boolean::Boolean, uint::*, UInt8, FromBitsGadget, ToBitsGadget, ToBytesGadget, UIntGadget, }, eq::*, From b1649ea12cc9a68ec33cfb85968fd8b8a1e9cc6f Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Thu, 21 Apr 2022 14:52:27 +0200 Subject: [PATCH 15/18] Fix conditional_enforce_not_equal for boolean + improve unit test --- r1cs/core/src/constraint_system.rs | 75 +++++++++-- r1cs/gadgets/std/src/bits/boolean.rs | 186 +++++++++++++++++++-------- r1cs/gadgets/std/src/eq.rs | 4 + 3 files changed, 199 insertions(+), 66 deletions(-) diff --git a/r1cs/core/src/constraint_system.rs b/r1cs/core/src/constraint_system.rs index a30aaacf3..33bb72225 100644 --- a/r1cs/core/src/constraint_system.rs +++ b/r1cs/core/src/constraint_system.rs @@ -73,6 +73,12 @@ pub trait ConstraintSystemAbstract: Sized { /// Output the number of constraints in the system. fn num_constraints(&self) -> usize; + + /// Evaluate linear combination lc with the values assigned to variables of lc + /// in the constraint system `self`. Returns an error if one variable of lc is not found among + /// allocated variables in `self`. The result is None if variables have not been assigned to a + /// value yet + fn eval_lc(&self, lc: &LinearCombination) -> Result, SynthesisError>; } /// Defines debugging functionalities for a constraint system, which allow to verify which @@ -346,14 +352,48 @@ impl ConstraintSystemAbstract for ConstraintSystem { fn num_constraints(&self) -> usize { self.num_constraints } + + fn eval_lc(&self, lc: &LinearCombination) -> Result, SynthesisError> { + let mut acc = F::zero(); + + if self.is_in_setup_mode() { + // in setup mode there are for sure no assignment to variables, so we can return None + return Ok(None) + } + + for &(ref var, coeff) in lc.as_ref() { + let mut tmp = match var.0 { + Index::Input(index) => *self.input_assignment.get(index).ok_or( + SynthesisError::Other( + format!( + "no public input variable with index {} found in the constraint system" + , index)))?, + Index::Aux(index) => *self.aux_assignment.get(index).ok_or( + SynthesisError::Other( + format!( + "no private variable with index {} found in the constraint system" + , index)))?, + }; + + tmp.mul_assign(coeff); + acc.add_assign(tmp); + } + + Ok(Some(acc)) + } } impl ConstraintSystemDebugger for ConstraintSystem { fn which_is_unsatisfied(&self) -> Option<&str> { for i in 0..self.num_constraints { - let mut a = Self::eval_lc(&self.at[i], &self.input_assignment, &self.aux_assignment); - let b = Self::eval_lc(&self.bt[i], &self.input_assignment, &self.aux_assignment); - let c = Self::eval_lc(&self.ct[i], &self.input_assignment, &self.aux_assignment); + // Note that the following `eval_lc` calls cannot return an error, as `get_constraint` + // constructs a LinearCombination whose variables are necessarily in `self`. + // Furthermore, we assume that this function is never called in setup mode, therefore + // `eval_lc` should never return None. + // Thus, we can safely unwrap the return values of `eval_lc` invocations + let mut a = self.eval_lc(&Self::get_constraint(&self.at, i)).unwrap().unwrap(); + let b = self.eval_lc(&Self::get_constraint(&self.bt, i)).unwrap().unwrap(); + let c = self.eval_lc(&Self::get_constraint(&self.ct, i)).unwrap().unwrap(); a.mul_assign(&b); if a != c { @@ -455,20 +495,19 @@ impl ConstraintSystem { } } } - fn eval_lc(terms: &[(F, Index)], inputs: &[F], aux: &[F]) -> F { - let mut acc = F::zero(); - - for &(ref coeff, idx) in terms { - let mut tmp = match idx { - Index::Input(index) => inputs[index], - Index::Aux(index) => aux[index], - }; - tmp.mul_assign(coeff); - acc.add_assign(tmp); + fn get_constraint( + constraints: &[Vec<(F, Index)>], + this_constraint: usize, + ) -> LinearCombination { + let constraint = &constraints[this_constraint]; + // build a linear combination representing the constraint + let mut lc = LinearCombination::zero(); + for (coeff, idx) in constraint { + lc += (*coeff, Variable(idx.clone())); } - acc + lc } } @@ -557,6 +596,10 @@ impl> ConstraintSystemAbstract fn num_constraints(&self) -> usize { self.0.num_constraints() } + + fn eval_lc(&self, lc: &LinearCombination) -> Result, SynthesisError> { + self.0.eval_lc(lc) + } } impl + ConstraintSystemDebugger> @@ -650,6 +693,10 @@ impl> ConstraintSystemAbstract for fn num_constraints(&self) -> usize { (**self).num_constraints() } + + fn eval_lc(&self, lc: &LinearCombination) -> Result, SynthesisError> { + (**self).eval_lc(lc) + } } impl + ConstraintSystemDebugger> diff --git a/r1cs/gadgets/std/src/bits/boolean.rs b/r1cs/gadgets/std/src/bits/boolean.rs index 3c4304c5e..5525a5eba 100644 --- a/r1cs/gadgets/std/src/bits/boolean.rs +++ b/r1cs/gadgets/std/src/bits/boolean.rs @@ -946,7 +946,16 @@ impl EqGadget for Boolean { // 1 != 0; 0 != 1 (Constant(true), Constant(false)) | (Constant(false), Constant(true)) => return Ok(()), // false == false and true == true - (Constant(_), Constant(_)) => return Err(SynthesisError::AssignmentMissing), + (Constant(_), Constant(_)) => { + if should_enforce.is_constant() { + return if should_enforce.get_value().unwrap() { + Err(SynthesisError::AssignmentMissing) + } else { + Ok(()) + }; + } + LinearCombination::zero() + }, // 1 - a (Constant(true), Is(a)) | (Is(a), Constant(true)) => { LinearCombination::zero() + one - a.get_variable() @@ -976,10 +985,33 @@ impl EqGadget for Boolean { if let Constant(false) = should_enforce { Ok(()) } else { + // we need to enforce that difference != 0 if should_enforce is true. + // we let the prover allocate a variable d and we enforce that + // difference * d = should_enforce + // in this way, if difference = 0 and should_enforce = 1, the constraint is + // unsatisfiable for any value of d; otherwise, the prover can always find a value for + // d that satisfies the constraint + let diff_value = cs.eval_lc(&difference)?; + let d = cs.alloc( + || "alloc d", + || { + Ok( + if should_enforce.get_value().ok_or(SynthesisError::AssignmentMissing)? { + // if should_enforce is true, then d must be the inverse of difference + diff_value.ok_or(SynthesisError::AssignmentMissing)? + .inverse().unwrap_or(ConstraintF::zero()) + } else { + // otherwise we set d to zero to trivially satisfy the constraint + ConstraintF::zero() + } + ) + } + )?; + cs.enforce( - || "conditional_equals", - |lc| difference + &lc, - |lc| should_enforce.lc(one, ConstraintF::one()) + &lc, + || "conditional_not_equal", + |lc| &difference + &lc, + |lc| lc + (ConstraintF::one(), d), |lc| should_enforce.lc(one, ConstraintF::one()) + &lc, ); Ok(()) @@ -1319,59 +1351,109 @@ mod test { } #[test] - fn test_conditional_enforce_equal() { + fn test_conditional_enforce_equal_functions() { + enum TestedFunction { + ConditionalEnforceEqual, + ConditionalEnforceNotEqual + } for a_bool in [false, true].iter().cloned() { for b_bool in [false, true].iter().cloned() { for a_neg in [false, true].iter().cloned() { for b_neg in [false, true].iter().cloned() { - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - // First test if constraint system is satisfied - // when we do want to enforce the condition. - let mut a: Boolean = AllocatedBit::alloc(cs.ns(|| "a"), || Ok(a_bool)) - .unwrap() - .into(); - let mut b: Boolean = AllocatedBit::alloc(cs.ns(|| "b"), || Ok(b_bool)) - .unwrap() - .into(); - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); - } - - a.conditional_enforce_equal(&mut cs, &b, &Boolean::constant(true)) - .unwrap(); - - assert_eq!(cs.is_satisfied(), (a_bool ^ a_neg) == (b_bool ^ b_neg)); - - // Now test if constraint system is satisfied even - // when we don't want to enforce the condition. - let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); - - let mut a: Boolean = AllocatedBit::alloc(cs.ns(|| "a"), || Ok(a_bool)) - .unwrap() - .into(); - let mut b: Boolean = AllocatedBit::alloc(cs.ns(|| "b"), || Ok(b_bool)) - .unwrap() - .into(); - - if a_neg { - a = a.not(); - } - if b_neg { - b = b.not(); + for a_const in [false, true].iter().cloned() { + for b_const in [false, true].iter().cloned() { + for cond_const in [false, true].iter().cloned() { + for tested_function in [TestedFunction::ConditionalEnforceEqual, TestedFunction::ConditionalEnforceNotEqual].iter() { + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let alloc_variables = |cs: &mut ConstraintSystem| { + let a: Boolean = + if a_const { + Boolean::constant(a_bool) + } else { + AllocatedBit::alloc(cs.ns(|| "a"), || Ok(a_bool)) + .unwrap() + .into() + }; + let b: Boolean = + if b_const { + Boolean::constant(b_bool) + } else { + AllocatedBit::alloc(cs.ns(|| "b"), || Ok(b_bool)) + .unwrap() + .into() + }; + + let true_cond = + if cond_const { + Boolean::constant(true) + } else { + AllocatedBit::alloc(cs.ns(|| "should_enforce"), || Ok(true)) + .unwrap() + .into() + }; + (a, b, true_cond) + }; + + // First test if constraint system is satisfied + // when we do want to enforce the condition. + let (mut a, mut b, true_cond) = alloc_variables(&mut cs); + + if a_neg { + a = a.not(); + } + if b_neg { + b = b.not(); + } + + let (must_be_satisfied, enforce_ret) = match tested_function { + TestedFunction::ConditionalEnforceEqual => + ( + (a_bool ^ a_neg) == (b_bool ^ b_neg), + a.conditional_enforce_equal(&mut cs, &b, &true_cond) + ), + TestedFunction::ConditionalEnforceNotEqual => ( + (a_bool ^ a_neg) != (b_bool ^ b_neg), + a.conditional_enforce_not_equal(&mut cs, &b, &true_cond) + ), + }; + if a_const && b_const && cond_const && !must_be_satisfied { + // in this case the function returns an error rather than + // unsatisfied constraints + enforce_ret.unwrap_err(); + } else { + enforce_ret + .unwrap(); + assert_eq!(cs.is_satisfied(), must_be_satisfied); + } + + // Now test if constraint system is satisfied even + // when we don't want to enforce the condition. + let mut cs = ConstraintSystem::::new(SynthesisMode::Debug); + + let (mut a, mut b, true_cond) = alloc_variables(&mut cs); + + if a_neg { + a = a.not(); + } + if b_neg { + b = b.not(); + } + + let false_cond = true_cond.not(); + + match tested_function { + TestedFunction::ConditionalEnforceEqual => a.conditional_enforce_equal(&mut cs, &b, &false_cond) + .unwrap(), + TestedFunction::ConditionalEnforceNotEqual => a.conditional_enforce_not_equal(&mut cs, &b, &false_cond) + .unwrap(), + } + + assert!(cs.is_satisfied()); + } + } + } } - - let false_cond = AllocatedBit::alloc(cs.ns(|| "cond"), || Ok(false)) - .unwrap() - .into(); - a.conditional_enforce_equal(&mut cs, &b, &false_cond) - .unwrap(); - - assert!(cs.is_satisfied()); } } } diff --git a/r1cs/gadgets/std/src/eq.rs b/r1cs/gadgets/std/src/eq.rs index 28d5b459d..539940681 100644 --- a/r1cs/gadgets/std/src/eq.rs +++ b/r1cs/gadgets/std/src/eq.rs @@ -423,6 +423,10 @@ impl> fn num_constraints(&self) -> usize { self.cs.num_constraints() } + + fn eval_lc(&self, lc: &LinearCombination) -> Result, SynthesisError> { + self.cs.eval_lc(lc) + } } #[cfg(test)] From e50670f50eb440fc988dcbfd0be5cd112a2b9ce0 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 22 Apr 2022 10:50:51 +0200 Subject: [PATCH 16/18] Addressing changes --- r1cs/gadgets/std/src/bits/boolean.rs | 35 +++++++++--- r1cs/gadgets/std/src/bits/macros.rs | 79 ++++++++++++++++------------ r1cs/gadgets/std/src/eq.rs | 5 +- 3 files changed, 77 insertions(+), 42 deletions(-) diff --git a/r1cs/gadgets/std/src/bits/boolean.rs b/r1cs/gadgets/std/src/bits/boolean.rs index 5525a5eba..7e39a7de3 100644 --- a/r1cs/gadgets/std/src/bits/boolean.rs +++ b/r1cs/gadgets/std/src/bits/boolean.rs @@ -581,6 +581,7 @@ impl Boolean { } } + /// Compute the Boolean AND of `bits` pub fn kary_and(mut cs: CS, bits: &[Self]) -> Result where ConstraintF: Field, @@ -722,15 +723,18 @@ impl Boolean { let mut bits_iter = bits.iter().rev(); // Iterate in big-endian - // Runs of ones in r + // Runs of ones in b let mut last_run = Boolean::constant(true); let mut current_run = vec![]; + // compute number of bits necessary to represent element let mut element_num_bits = 0; for _ in BitIterator::without_leading_zeros(b) { element_num_bits += 1; } + // check that the most significant bits of `bits` exceeding `element_num_bits` are all zero, + // computing the Boolean OR of such bits and enforce the result to be 0 if bits.len() > element_num_bits { let mut or_result = Boolean::constant(false); for (i, should_be_zero) in bits[element_num_bits..].into_iter().enumerate() { @@ -743,15 +747,29 @@ impl Boolean { } or_result.enforce_equal(cs.ns(|| "enforce equal"), &Boolean::constant(false))?; } - + // compare the least significant `element_num_bits` bits of `bits` with the bit + // representation of `element` for (i, (b, a)) in BitIterator::without_leading_zeros(b) .zip(bits_iter.by_ref()) .enumerate() { if b { - // This is part of a run of ones. + // This is part of a run of ones. Save in `current_run` the bits of `bits` + // corresponding to such bit 1 in `b` current_run.push(a.clone()); } else { + // The bit of `element` is 0. Therefore, in order for `bits` to be smaller than `b`, + // either some bits of `bits` corresponding to the last run of ones were 0 + // (and to check this we compute the boolean AND of all such bits, saving it + // in `last_run`) or the current bit `a` of `bits` must be 0. Thus, we enforce + // that either `a` or `last_run` must be 0 with an `enforce_nand`. + // Note that when `last_run` becomes 0, which happens as soon as there is a bit 0 + // in `bits` whose corresponding bit in `element` is 1, `last_run`` will always + // remain 0 for the rest of the loop; thus, in this case the `enforce_nand` + // will hold independently from the remaining bits of `bits`, which is correct as + // once a bit difference is spot between the 2 bit representations, then the lesser + // significant bits do not affect the outcome of the comparison + if !current_run.is_empty() { // This is the start of a run of zeros, but we need // to k-ary AND against `last_run` first. @@ -886,6 +904,7 @@ impl EqGadget for Boolean { (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), // false != true (Constant(_), Constant(_)) => { + // in this case the enforcement should fail, unless `should_enforce` is false if should_enforce.is_constant() { return if should_enforce.get_value().unwrap() { Err(SynthesisError::AssignmentMissing) @@ -893,9 +912,10 @@ impl EqGadget for Boolean { Ok(()) } } - LinearCombination::zero() + CS::one() // set difference != 0 - }, - // 1 - a + // set difference != 0 to ensure the constraint enforced later on is violated if + // `should_enforce` is true + LinearCombination::zero() + CS::one() + }// 1 - a (Constant(true), Is(a)) | (Is(a), Constant(true)) => { LinearCombination::zero() + one - a.get_variable() } @@ -948,12 +968,15 @@ impl EqGadget for Boolean { // false == false and true == true (Constant(_), Constant(_)) => { if should_enforce.is_constant() { + // in this case the enforcement should fail, unless `should_enforce` is false return if should_enforce.get_value().unwrap() { Err(SynthesisError::AssignmentMissing) } else { Ok(()) }; } + // set difference = 0 to ensure the constraint enforced later on is violated if + // `should_enforce` is true LinearCombination::zero() }, // 1 - a diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index 6c5cc4d38..e1d528c44 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -156,7 +156,9 @@ macro_rules! impl_uint_gadget { Ok(()) } - // Helper function that constructs self from an iterator over Booleans. + // Helper function that constructs `Self` from an iterator over Booleans. + // The Booleans provided by the iterator are considered as the little-endian bit + // representation of `Self`. // It is employed as a building block by the public functions of the `FromBitsGadget` fn from_bit_iterator<'a, ConstraintF, CS>(_cs: CS, bits_iter: impl Iterator) -> Result where @@ -233,11 +235,11 @@ macro_rules! impl_uint_gadget { CS: ConstraintSystemAbstract, { - let field_bits = ConstraintF::Params::CAPACITY as usize; + let capacity_bits = ConstraintF::Params::CAPACITY as usize; // max_overflow_bits are the maximum number of non-zero bits a `Self` element // can have to be multiplied to another `Self` without overflowing the field - let max_overflow_bits = field_bits - $bit_size; - // given a base b = 2^m, where m=2^max_overflow_bits, the `other` operand is + let max_overflow_bits = capacity_bits - $bit_size; + // given a base b = 2^m, where m=max_overflow_bits, the `other` operand is // represented in base b with digits of m bits. Then, the product self*other // is computed by the following summation: // sum_{i from 0 to h-1} ((self*b^i) % 2^$bit_size * digit_i), where h is the @@ -255,7 +257,7 @@ macro_rules! impl_uint_gadget { let result_bits = if no_carry { // ensure that tmp_result can be represented with $bit_size bits to // ensure that no native type overflow has happened in the multiplication - tmp_result.to_bits_with_length_restriction(cs.ns(|| format!("to bits for digit {}", i)), field_bits + 1 - $bit_size)? + tmp_result.to_bits_with_length_restriction(cs.ns(|| format!("to bits for digit {}", i)), capacity_bits + 1 - $bit_size)? } else { let result_bits = tmp_result.to_bits_with_length_restriction(cs.ns(|| format!("to bits for digit {}", i)), 1)?; result_bits @@ -531,9 +533,17 @@ macro_rules! impl_uint_gadget { first: &Self, second: &Self, ) -> Result { - let bits = first.bits.iter().zip(second.bits.iter()).enumerate().map(|(i, (t, f))| Boolean::conditionally_select(cs.ns(|| format!("cond select bit {}", i)), cond, t, f)).collect::, SynthesisError>>()?; - - assert_eq!(bits.len(), $bit_size); // this assert should always be verified if first and second are built only with public methods + let bits = first.bits.iter() + .zip(second.bits.iter()) + .enumerate() + .map(|(i, (t, f))| + Boolean::conditionally_select( + cs.ns(|| format!("cond select bit {}", i)), cond, t, f) + ).collect::, SynthesisError>>()?; + + // this assert should always be verified if first and second are built only + // with public methods + assert_eq!(bits.len(), $bit_size); let value = match cond.get_value() { Some(cond_bit) => if cond_bit {first.get_value()} else {second.get_value()}, @@ -698,14 +708,14 @@ macro_rules! impl_uint_gadget { M: ConstraintSystemAbstract> { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::CAPACITY) as usize; + let capacity_bits = (ConstraintF::Params::CAPACITY) as usize; // in this case it is not possible to enforce the correctness of the addition // of at least 2 elements for the field ConstraintF - assert!(field_bits > $bit_size); + assert!(capacity_bits > $bit_size); assert!(num_operands >= 2); // Weird trivial cases that should never happen let overflow_bits = (num_operands as f64).log2().ceil() as usize; - if field_bits < $bit_size + overflow_bits { + if capacity_bits < $bit_size + overflow_bits { // in this case addition of num_operands elements over field would overflow, // thus it would not be possible to ensure the correctness of the result. // Therefore, the operands are split in smaller slices, and the sum is @@ -713,7 +723,7 @@ macro_rules! impl_uint_gadget { // given the field ConstraintF and the $bit_size, compute the maximum number // of operands for which we can enforce correctness of the result - let max_overflow_bits = field_bits - $bit_size; + let max_overflow_bits = capacity_bits - $bit_size; let max_num_operands = 1usize << max_overflow_bits; handle_numoperands_opmany!(addmany, cs, operands, max_num_operands); } @@ -806,14 +816,14 @@ macro_rules! impl_uint_gadget { CS: ConstraintSystemAbstract, M: ConstraintSystemAbstract> { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::CAPACITY) as usize; + let capacity_bits = (ConstraintF::Params::CAPACITY) as usize; // in this case it is not possible to enforce the correctness of the addition // of at least 2 elements for the field ConstraintF - assert!(field_bits > $bit_size); + assert!(capacity_bits > $bit_size); assert!(num_operands >= 2); // Weird trivial cases that should never happen let overflow_bits = (num_operands as f64).log2().ceil() as usize; - if field_bits < $bit_size + overflow_bits { + if capacity_bits < $bit_size + overflow_bits { // in this case addition of num_operands elements over field would overflow, // thus it would not be possible to ensure the correctness of the result. // Therefore, the operands are split in smaller slices, and the sum is @@ -821,7 +831,7 @@ macro_rules! impl_uint_gadget { // given the field ConstraintF and the $bit_size, compute the maximum number // of operands for which we can enforce correctness of the result - let max_overflow_bits = field_bits - $bit_size; + let max_overflow_bits = capacity_bits - $bit_size; let max_num_operands = 1usize << max_overflow_bits; handle_numoperands_opmany!(addmany_nocarry, cs, operands, max_num_operands); } @@ -883,7 +893,7 @@ macro_rules! impl_uint_gadget { fn mulmany(mut cs: CS, operands: &[Self]) -> Result where CS: ConstraintSystemAbstract { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::CAPACITY) as usize; + let capacity_bits = (ConstraintF::Params::CAPACITY) as usize; assert!(num_operands >= 2); // corner case: check if all operands are constants before allocating any variable @@ -904,14 +914,14 @@ macro_rules! impl_uint_gadget { return Ok($type_name::from_value(cs.ns(|| "alloc constant result"), &result_value.unwrap())); } - assert!(field_bits > $bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field + assert!(capacity_bits > $bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field - if field_bits < 2*$bit_size { + if capacity_bits < 2*$bit_size { return $type_name::mulmany_with_double_and_add(cs.ns(|| "double and add"), operands, false); } - if field_bits < num_operands*$bit_size { - let max_num_operands = field_bits/$bit_size; + if capacity_bits < num_operands*$bit_size { + let max_num_operands = capacity_bits/$bit_size; handle_numoperands_opmany!(mulmany, cs, operands, max_num_operands); } @@ -942,7 +952,7 @@ macro_rules! impl_uint_gadget { result = result.mul(cs.ns(|| format!("mul op {}", i)), &field_op)?; } - let skip_leading_bits = field_bits + 1 - num_operands*$bit_size; + let skip_leading_bits = capacity_bits + 1 - num_operands*$bit_size; let result_bits = result.to_bits_with_length_restriction(cs.ns(|| "unpack result field element"), skip_leading_bits)?; let result_lsbs = result_bits .iter() @@ -956,7 +966,7 @@ macro_rules! impl_uint_gadget { fn mulmany_nocarry(mut cs: CS, operands: &[Self]) -> Result where CS: ConstraintSystemAbstract { let num_operands = operands.len(); - let field_bits = (ConstraintF::Params::CAPACITY) as usize; + let capacity_bits = (ConstraintF::Params::CAPACITY) as usize; assert!(num_operands >= 2); // corner case: check if all operands are constants before allocating any variable @@ -986,14 +996,14 @@ macro_rules! impl_uint_gadget { } } - assert!(field_bits > $bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field + assert!(capacity_bits > $bit_size); // minimum requirement on field size to compute multiplication of at least 2 elements without overflowing the field - if field_bits < 2*$bit_size { + if capacity_bits < 2*$bit_size { return $type_name::mulmany_with_double_and_add(cs.ns(|| "double and add"), operands, true); } - if field_bits < num_operands*$bit_size { - let max_num_operands = field_bits/$bit_size; + if capacity_bits < num_operands*$bit_size { + let max_num_operands = capacity_bits/$bit_size; handle_numoperands_opmany!(mulmany_nocarry, cs, operands, max_num_operands); } @@ -1024,7 +1034,7 @@ macro_rules! impl_uint_gadget { result = result.mul(cs.ns(|| format!("mul op {}", i)), &field_op)?; } - let skip_leading_bits = field_bits + 1 - $bit_size; // we want to verify that the field element for the product of operands can be represented with $bit_size bits to ensure that there is no overflow + let skip_leading_bits = capacity_bits + 1 - $bit_size; // we want to verify that the field element for the product of operands can be represented with $bit_size bits to ensure that there is no overflow let result_bits = result.to_bits_with_length_restriction(cs.ns(|| "unpack result field element"), skip_leading_bits)?; assert_eq!(result_bits.len(), $bit_size); $type_name::from_bits(cs.ns(|| "packing result"), &result_bits[..]) @@ -1862,7 +1872,7 @@ macro_rules! impl_uint_gadget { test_uint_gadget_value(result_value, &result_var, "result correctness"); assert!(cs.is_satisfied()); - // negative test + // negative test: change a bit of the result to verify that constraints no longer hold let bit_gadget_path = "add operands/alloc result bit 0/boolean"; if cs.get(bit_gadget_path).is_zero() { cs.set(bit_gadget_path, Fr::one()); @@ -1916,8 +1926,9 @@ macro_rules! impl_uint_gadget { assert!(cs.is_satisfied()); - if MAX_NUM_OPERANDS >= 2 { // negative tests are skipped if if double and add must be used because the field is too small - // negative test on first batch + if MAX_NUM_OPERANDS >= 2 { + // negative tests are skipped if `mul_with_double_and_add` must be used because the field is too small, which happens when MAX_NUM_OPERANDS < 2 + // negative test on first batch: change a bit of the result to verify that the constraints no longer hold let bit_gadget_path = "mul operands/first batch of operands/unpack result field element/bit 0/boolean"; if cs.get(bit_gadget_path).is_zero() { cs.set(bit_gadget_path, Fr::one()); @@ -1935,7 +1946,7 @@ macro_rules! impl_uint_gadget { } assert!(cs.is_satisfied()); - // negative test on allocated field element: skip if double and add must be used because the field is too small + // negative test on allocated field element: skipped if `mul_with_double_and_add` must be used because the field is too small let num_batches = (NUM_OPERANDS-MAX_NUM_OPERANDS)/(MAX_NUM_OPERANDS-1); let bit_gadget_path = format!("mul operands/{}-th batch of operands/unpack result field element/bit 0/boolean", num_batches); if cs.get(&bit_gadget_path).is_zero() { @@ -2316,8 +2327,8 @@ macro_rules! impl_uint_gadget { if is_add { assert_eq!(cs.which_is_unsatisfied().unwrap(), "multieq 0"); } else { - let field_bits = (::Params::CAPACITY) as usize; - if field_bits < 2*$bit_size { // double and add case + let capacity_bits = (::Params::CAPACITY) as usize; + if capacity_bits < 2*$bit_size { // double and add case assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul values/double and add/double and add first operands/to bits for digit 0/unpacking_constraint"); } else { assert_eq!(cs.which_is_unsatisfied().unwrap(), "mul values/unpack result field element/unpacking_constraint"); diff --git a/r1cs/gadgets/std/src/eq.rs b/r1cs/gadgets/std/src/eq.rs index 539940681..68e6057ea 100644 --- a/r1cs/gadgets/std/src/eq.rs +++ b/r1cs/gadgets/std/src/eq.rs @@ -241,8 +241,9 @@ impl EqGadget for Vec { let field_bits = ConstraintF::Params::CAPACITY as usize; let len = self.len(); if field_bits < len { - // In this case we cannot split in chunks here, as it's not true that if two bit vectors - // are not equal, then they are not equal chunkwise too. Therefore, we + // In this case we cannot split `self` and `other` in chunks here and enforce that each + // pair of chunks differ, as it's not true that if two bit vectors + // are not equal, then all their corresponding chunks differ. Therefore, we // compute a Boolean which is true iff `self != `other` and we conditionally // enforce it to be true let is_neq = self.is_neq(cs.ns(|| "is not equal"), other)?; From 4661c2051f741903d82302bdb7ff009618acc681 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 22 Apr 2022 16:32:28 +0200 Subject: [PATCH 17/18] Optimize comparison for FpGadget --- r1cs/gadgets/std/src/fields/cmp.rs | 95 +++++++++--------------------- 1 file changed, 29 insertions(+), 66 deletions(-) diff --git a/r1cs/gadgets/std/src/fields/cmp.rs b/r1cs/gadgets/std/src/fields/cmp.rs index 9c4bcdb30..7f53bc016 100644 --- a/r1cs/gadgets/std/src/fields/cmp.rs +++ b/r1cs/gadgets/std/src/fields/cmp.rs @@ -1,9 +1,14 @@ use std::cmp::Ordering; -use algebra::PrimeField; -use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; -use crate::{boolean::Boolean, bits::{ToBitsGadget, FromBitsGadget}, eq::EqGadget, select::CondSelectGadget}; use crate::cmp::ComparisonGadget; use crate::fields::{fp::FpGadget, FieldGadget}; +use crate::{ + bits::{FromBitsGadget, ToBitsGadget}, + boolean::Boolean, + eq::EqGadget, + select::CondSelectGadget, +}; +use algebra::{FpParameters, PrimeField}; +use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; // this macro allows to implement the `unchecked` and `restricted` variants of the `enforce_cmp`, // `conditional_enforce_cmp` and `is_cmp` functions. The macro is useful as the implementations @@ -61,57 +66,6 @@ macro_rules! implement_cmp_functions_variants { // implement functions for FpGadget that are useful to implement the ComparisonGadget impl FpGadget { - - /// Helper function that allows to compare 2 slices of 2 bits, outputting 2 Booleans: - /// the former (resp. the latter) one is true iff the big-endian integer represented by the - /// first slice is smaller (resp. is equal) than the big-endian integer represented by the second slice - fn compare_msbs>(mut cs: CS, first: &[Boolean], second: &[Boolean]) - -> Result<(Boolean, Boolean), SynthesisError> { - assert_eq!(first.len(), 2); - assert_eq!(second.len(), 2); - - let a = first[0]; // a = msb(first) - let b = first[1]; // b = lsb(first) - let c = second[0]; // c = msb(second) - let d = second[1]; // d = lsb(second) - - // is_less corresponds to the Boolean function: !a*(c+!b*d)+(!b*c*d) - // which is true iff first < second, where + is Boolean OR and * is Boolean AND. Indeed: - // | first | second | a | b | c | d | is_less | - // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | - // | 0 | 1 | 0 | 0 | 0 | 1 | 1 | - // | 0 | 2 | 0 | 0 | 1 | 0 | 1 | - // | 0 | 3 | 0 | 0 | 1 | 1 | 1 | - // | 1 | 0 | 0 | 1 | 0 | 0 | 0 | - // | 1 | 1 | 0 | 1 | 0 | 1 | 0 | - // | 1 | 2 | 0 | 1 | 1 | 0 | 1 | - // | 1 | 3 | 0 | 1 | 1 | 1 | 1 | - // | 2 | 0 | 1 | 0 | 0 | 0 | 0 | - // | 2 | 1 | 1 | 0 | 0 | 1 | 0 | - // | 2 | 2 | 1 | 0 | 1 | 0 | 0 | - // | 2 | 3 | 1 | 0 | 1 | 1 | 1 | - // | 3 | 0 | 1 | 1 | 0 | 0 | 0 | - // | 3 | 1 | 1 | 1 | 0 | 1 | 0 | - // | 3 | 2 | 1 | 1 | 1 | 0 | 0 | - // | 3 | 3 | 1 | 1 | 1 | 1 | 0 | - - // To reduce the number of constraints, the Boolean function is computed as follows: - // is_less = !a + !b*d if c=1, !a*!b*d if c=0 - - let bd = Boolean::and(cs.ns(|| "!bd"), &b.not(), &d)?; - let first_tmp = Boolean::or(cs.ns(|| "!a + !bd"), &a.not(), &bd)?; - let second_tmp = Boolean::and(cs.ns(|| "!a!bd"), &a.not(), &bd)?; - let is_less = Boolean::conditionally_select(cs.ns(|| "is less"), &c, &first_tmp, &second_tmp)?; - - // is_eq corresponds to the Boolean function: !((a xor c) + (b xor d)), - // which is true iff first == second - let first_tmp = Boolean::xor(cs.ns(|| "a xor c"), &a, &c)?; - let second_tmp = Boolean::xor(cs.ns(|| "b xor d"), &b, &d)?; - let is_eq = Boolean::or(cs.ns(|| "is eq"), &first_tmp, &second_tmp)?.not(); - - Ok((is_less, is_eq)) - } - /// Helper function to enforce that `self <= (p-1)/2`. pub fn enforce_smaller_or_equal_than_mod_minus_one_div_two>( &self, @@ -205,28 +159,37 @@ impl ComparisonGadget for FpGadget Result { let self_bits = self.to_bits_strict(cs.ns(|| "first op to bits"))?; let other_bits = other.to_bits_strict(cs.ns(|| "second op to bits"))?; - // extract the least significant MODULUS_BITS-2 bits and convert them to a field element, + + let num_bits = ConstraintF::Params::MODULUS_BITS as usize; + + // For both operands, extract the most significant MODULUS_BITS-1 bits and convert them to a field element, // which is necessarily lower than (p-1)/2 - let fp_for_self_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op MSBs"), &self_bits[2..])?; - let fp_for_other_lsbs = FpGadget::::from_bits(cs.ns(|| "pack second op LSBs"), &other_bits[2..])?; + let fp_for_self_msbs = + FpGadget::::from_bits(cs.ns(|| "pack first op MSBs"), &self_bits[..num_bits-1])?; + let fp_for_other_msbs = + FpGadget::::from_bits(cs.ns(|| "pack second op MSBs"), &other_bits[..num_bits-1])?; // since the field elements are lower than (p-1)/2, we can compare it with the efficient approach - let is_less_lsbs = fp_for_self_lsbs.is_smaller_than_unchecked(cs.ns(|| "compare LSBs"), &fp_for_other_lsbs)?; - + let is_less_msbs = fp_for_self_msbs + .is_smaller_than_unchecked(cs.ns(|| "compare MSBs"), &fp_for_other_msbs)?; + // check is the field elements represented by the MSBs of `self` and `other` are equal + let is_eq_msbs = fp_for_self_msbs.is_eq(cs.ns(|| "eq of MSBs"), + &fp_for_other_msbs, + )?; - // obtain two Booleans: - // - `is_less_msbs` is true iff the integer represented by the 2 MSBs of self is smaller - // than the integer represented by the 2 MSBs of other - // - `is_eq_msbs` is true iff the integer represented by the 2 MSBs of self is equal - // to the integer represented by the 2 MSBs of other - let (is_less_msbs, is_eq_msbs) = Self::compare_msbs(cs.ns(|| "compare MSBs"), &self_bits[..2], &other_bits[..2])?; + // compute a Boolean `is_less_lsb` which is true iff the least significant bit of `self` + // is 0 and the least significant bit of `other` is 1 + let is_less_lsb = Boolean::and(cs.ns(|| "compare lsb"), + &self_bits[num_bits-1].not(), + &other_bits[num_bits-1], + )?; // `self < other` iff `is_less_msbs OR is_eq_msbs AND is_less_lsbs` // Given that `is_less_msbs` and `is_eq_msbs` cannot be true at the same time, // the formula is equivalent to the following conditionally select; indeed: // - if `is_eq_msbs = true`, then `is_less_msbs = false`, thus `self < other` iff `is_less_lsbs = true` // - if `is_eq_msbs = false`, then `self < other` iff `is_less_msbs = true` - Boolean::conditionally_select(cs, &is_eq_msbs, &is_less_lsbs, &is_less_msbs) + Boolean::conditionally_select(cs, &is_eq_msbs, &is_less_lsb, &is_less_msbs) } } From 8ea676bdb4c6bf9d620e7eab82d118583dfb662f Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Mon, 9 May 2022 16:58:10 +0200 Subject: [PATCH 18/18] Fix issues after rebase --- .../src/commitment/injective_map/mod.rs | 2 +- .../crypto/src/crh/bowe_hopwood/mod.rs | 2 +- r1cs/gadgets/crypto/src/prf/blake2s/mod.rs | 6 ++--- r1cs/gadgets/crypto/src/prf/ripemd160.rs | 9 ++++---- r1cs/gadgets/crypto/src/prf/sha256.rs | 22 ++++++++++-------- r1cs/gadgets/std/src/bits/macros.rs | 23 ++++++++++++++----- r1cs/gadgets/std/src/fields/cmp.rs | 2 +- .../curves/short_weierstrass/mnt/mnt4/mod.rs | 2 +- .../curves/short_weierstrass/mnt/mnt6/mod.rs | 2 +- 9 files changed, 43 insertions(+), 27 deletions(-) diff --git a/r1cs/gadgets/crypto/src/commitment/injective_map/mod.rs b/r1cs/gadgets/crypto/src/commitment/injective_map/mod.rs index 618adaf22..da2c56bcf 100644 --- a/r1cs/gadgets/crypto/src/commitment/injective_map/mod.rs +++ b/r1cs/gadgets/crypto/src/commitment/injective_map/mod.rs @@ -14,7 +14,7 @@ use crate::commitment::{ pub use crate::crh::injective_map::InjectiveMapGadget; use algebra::groups::Group; use r1cs_core::{ConstraintSystemAbstract, SynthesisError}; -use r1cs_std::{groups::GroupGadget, uint8::UInt8}; +use r1cs_std::{groups::GroupGadget, UInt8}; use std::marker::PhantomData; diff --git a/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs b/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs index 51447dd69..cc1cf0955 100644 --- a/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs +++ b/r1cs/gadgets/crypto/src/crh/bowe_hopwood/mod.rs @@ -154,7 +154,7 @@ mod test { use r1cs_core::{ ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, }; - use r1cs_std::{alloc::AllocGadget, instantiated::edwards_sw6::EdwardsSWGadget, uint8::UInt8}; + use r1cs_std::{alloc::AllocGadget, instantiated::edwards_sw6::EdwardsSWGadget, UInt8}; use rand::{thread_rng, Rng}; type TestCRH = BoweHopwoodPedersenCRH; diff --git a/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs b/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs index 88cb40540..5c6e44485 100644 --- a/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs +++ b/r1cs/gadgets/crypto/src/prf/blake2s/mod.rs @@ -332,14 +332,14 @@ pub fn blake2s_gadget> = vec![]; - for block in input.chunks(512) { + for (i, block) in input.chunks(512).enumerate() { let mut this_block = Vec::with_capacity(16); - for word in block.chunks(32) { + for (j, word) in block.chunks(32).enumerate() { let mut tmp = word.to_vec(); while tmp.len() < 32 { tmp.push(Boolean::constant(false)); } - this_block.push(UInt32::from_bits_le(&tmp)?); + this_block.push(UInt32::from_bits_le(cs.ns(|| format!("convert {}-th chunk of {}-th block to uint32", j, i)), &tmp)?); } while this_block.len() < 16 { this_block.push(UInt32::constant(0)); diff --git a/r1cs/gadgets/crypto/src/prf/ripemd160.rs b/r1cs/gadgets/crypto/src/prf/ripemd160.rs index 21dc275c6..f2b08e093 100644 --- a/r1cs/gadgets/crypto/src/prf/ripemd160.rs +++ b/r1cs/gadgets/crypto/src/prf/ripemd160.rs @@ -145,7 +145,7 @@ where let mut padded = input.to_vec(); let plen = padded.len() as u64; // append bit "1", and alread seven 0 bits (recall that our input length is a mult. of 8) - padded.append(&mut UInt8::constant(1).into_bits_be()); + padded.append(&mut UInt8::constant(1).to_bits(cs.ns(|| "append bit 1"))?); // append remaining K '0' bits such that L + 7 + K is 64 bit shy of being a multiple of 512 while (padded.len() + 64) % 512 != 0 { @@ -171,7 +171,7 @@ fn get_ripemd160_iv() -> Vec { /// The RIPEMD160 block compression function fn ripemd160_compression_function( - cs: CS, + mut cs: CS, input: &[Boolean], current_hash_value: &[UInt32], ) -> Result, SynthesisError> @@ -195,8 +195,9 @@ where let x = input .chunks(32) - .map(|e| UInt32::from_bits_le(e)) - .collect::>(); + .enumerate() + .map(|(i, e)| UInt32::from_bits_le(cs.ns(|| format!("pack input chunk {}", i)),e)) + .collect::, SynthesisError>>()?; let mut cs = MultiEq::new(cs); diff --git a/r1cs/gadgets/crypto/src/prf/sha256.rs b/r1cs/gadgets/crypto/src/prf/sha256.rs index 16915ae58..9c5f6c042 100644 --- a/r1cs/gadgets/crypto/src/prf/sha256.rs +++ b/r1cs/gadgets/crypto/src/prf/sha256.rs @@ -37,12 +37,13 @@ where { assert_eq!(input.len(), 512); - Ok( - sha256_compression_function(&mut cs, &input, &get_sha256_iv())? - .into_iter() - .flat_map(|e| e.into_bits_be()) - .collect(), - ) + let mut digest = sha256_compression_function(&mut cs, &input, &get_sha256_iv())?; + + // we need to convert each word of digest to its big-endian bit representation. + // We first need to reverse digest since to_bits also reverts the vector before converting + // each word to big-endian bit representation + digest.reverse(); + digest.to_bits(cs.ns(|| "sha256 block to bits")) } pub fn sha256( @@ -73,7 +74,9 @@ where for (i, block) in padded.chunks(512).enumerate() { cur = sha256_compression_function(cs.ns(|| format!("block {}", i)), block, &cur)?; } - + // we need to convert each word of cur to its big-endian bit representation. + // We first need to reverse cur since to_bits also reverts the vector before converting + // each word to big-endian bit representation cur.reverse(); cur.to_bits(cs) } @@ -101,7 +104,8 @@ where // Initialize the first 16 words in the array w let mut w = input .chunks(32) - .map(|e| UInt32::from_bits(cs.ns(|| format!("pack input word {}", i)),e)) + .enumerate() + .map(|(i,e)| UInt32::from_bits(cs.ns(|| format!("pack input word {}", i)),e)) .collect::, SynthesisError>>()?; let mut cs = MultiEq::new(cs); @@ -332,7 +336,7 @@ where F: Fn(u32, u32, u32) -> u32, U: Fn(&mut CS, usize, &Boolean, &Boolean, &Boolean) -> Result, { - let new_value = match (a.value, b.value, c.value) { + let new_value = match (a.get_value(), b.get_value(), c.get_value()) { (Some(a), Some(b), Some(c)) => Some(tri_fn(a, b, c)), _ => None, }; diff --git a/r1cs/gadgets/std/src/bits/macros.rs b/r1cs/gadgets/std/src/bits/macros.rs index e1d528c44..df9ce9954 100644 --- a/r1cs/gadgets/std/src/bits/macros.rs +++ b/r1cs/gadgets/std/src/bits/macros.rs @@ -9,6 +9,13 @@ macro_rules! impl_uint_gadget { } impl $type_name { + pub fn new(bits: Vec, value: Option<$native_type>) -> Self { + Self{ + bits, + value, + } + } + pub fn get_value(&self) -> Option<$native_type> { self.value } @@ -435,14 +442,16 @@ macro_rules! impl_uint_gadget { } impl ToBitsGadget for Vec<$type_name> { + // Compute the concatenation of the big-endian bit representations of the elements + // in the reversed vector. Such definition of to_bits mandates to reverse the + // vector to ensure that self.to_bits(cs) == self.to_bits_le(cs).reverse() fn to_bits>( &self, - cs: CS, + _cs: CS, ) -> Result, SynthesisError> { - let mut le_bits = self.to_bits_le(cs)?; - //Need to reverse bits since to_bits must return a big-endian representation - le_bits.reverse(); - Ok(le_bits) + Ok(self.iter().rev().flat_map(|el| + el.bits.iter().rev().cloned().collect::>() + ).collect::>()) } fn to_bits_strict>( @@ -452,6 +461,8 @@ macro_rules! impl_uint_gadget { self.to_bits(cs) } + // Compute the concatenation of the little-endian bit representations of the elements + // in the vector fn to_bits_le>( &self, _cs: CS, @@ -1304,7 +1315,7 @@ macro_rules! impl_uint_gadget { mod [] { use super::$type_name; use rand::{Rng, thread_rng}; - use algebra::{fields::tweedle::Fr, Group, Field, FpParameters, PrimeField}; + use algebra::{fields::tweedle::Fr, Field, FpParameters, PrimeField}; use r1cs_core::{ ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode, SynthesisError, }; diff --git a/r1cs/gadgets/std/src/fields/cmp.rs b/r1cs/gadgets/std/src/fields/cmp.rs index 7f53bc016..99f0545b2 100644 --- a/r1cs/gadgets/std/src/fields/cmp.rs +++ b/r1cs/gadgets/std/src/fields/cmp.rs @@ -199,7 +199,7 @@ mod test { use rand::{Rng, thread_rng}; use r1cs_core::{ConstraintSystem, ConstraintSystemAbstract, ConstraintSystemDebugger, SynthesisMode}; use crate::{algebra::{UniformRand, PrimeField, - fields::tweedle::Fr, Group, + fields::tweedle::Fr, Field, }, fields::{fp::FpGadget, FieldGadget}}; use crate::{alloc::{AllocGadget, ConstantGadget}, cmp::ComparisonGadget, boolean::Boolean}; diff --git a/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt4/mod.rs b/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt4/mod.rs index 58b5c9206..018b7d24a 100644 --- a/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt4/mod.rs +++ b/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt4/mod.rs @@ -2,7 +2,7 @@ use algebra::Field; use crate::{ alloc::AllocGadget, - bits::uint8::UInt8, + bits::UInt8, fields::{fp::FpGadget, fp2::Fp2Gadget, FieldGadget}, groups::curves::short_weierstrass::short_weierstrass_projective::AffineGadget, Assignment, ToBytesGadget, diff --git a/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt6/mod.rs b/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt6/mod.rs index d6ed62358..1d08fb114 100644 --- a/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt6/mod.rs +++ b/r1cs/gadgets/std/src/groups/curves/short_weierstrass/mnt/mnt6/mod.rs @@ -2,7 +2,7 @@ use algebra::Field; use crate::{ alloc::AllocGadget, - bits::uint8::UInt8, + bits::UInt8, bits::ToBytesGadget, fields::{fp::FpGadget, fp3::Fp3Gadget, FieldGadget}, groups::curves::short_weierstrass::short_weierstrass_projective::AffineGadget,