diff --git a/.circleci/config.yml b/.circleci/config.yml index 0bfef41ee..98c18e57e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -96,7 +96,7 @@ jobs: zokrates_js_build: docker: - image: zokrates/env:latest - resource_class: large + resource_class: xlarge working_directory: ~/project/zokrates_js steps: - checkout: @@ -111,7 +111,7 @@ jobs: zokrates_js_test: docker: - image: zokrates/env:latest - resource_class: large + resource_class: xlarge working_directory: ~/project/zokrates_js steps: - checkout: diff --git a/Cargo.lock b/Cargo.lock index f250f1fbb..33389a5d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2109,6 +2109,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "reduce" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d2dc47b68ac15ea328cd7ebe01d7d512ed29787f7d534ad2a3c341328b35d7" + [[package]] name = "regex" version = "0.2.11" @@ -2991,6 +2997,29 @@ dependencies = [ "zokrates_field", ] +[[package]] +name = "zokrates_analysis" +version = "0.1.0" +dependencies = [ + "cfg-if 0.1.10", + "csv", + "lazy_static", + "log", + "num 0.1.42", + "num-bigint 0.2.6", + "pretty_assertions 0.6.1", + "reduce", + "serde", + "serde_json", + "typed-arena", + "zokrates_ast", + "zokrates_common", + "zokrates_embed", + "zokrates_field", + "zokrates_fs_resolver", + "zokrates_pest_ast", +] + [[package]] name = "zokrates_ark" version = "0.1.1" @@ -3107,9 +3136,23 @@ dependencies = [ "zokrates_solidity_test", ] +[[package]] +name = "zokrates_codegen" +version = "0.1.0" +dependencies = [ + "zokrates_ast", + "zokrates_common", + "zokrates_embed", + "zokrates_field", + "zokrates_interpreter", +] + [[package]] name = "zokrates_common" version = "0.1.1" +dependencies = [ + "serde", +] [[package]] name = "zokrates_core" @@ -3125,7 +3168,9 @@ dependencies = [ "serde", "serde_json", "typed-arena", + "zokrates_analysis", "zokrates_ast", + "zokrates_codegen", "zokrates_common", "zokrates_embed", "zokrates_field", @@ -3202,6 +3247,7 @@ dependencies = [ "pairing_ce", "serde", "zokrates_abi", + "zokrates_analysis", "zokrates_ast", "zokrates_embed", "zokrates_field", diff --git a/Cargo.toml b/Cargo.toml index 92c7358e2..8c5b23352 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,8 @@ members = [ "zokrates_cli", "zokrates_fs_resolver", "zokrates_stdlib", + "zokrates_codegen", + "zokrates_analysis", "zokrates_embed", "zokrates_abi", "zokrates_test", diff --git a/changelogs/unreleased/1246-dark64 b/changelogs/unreleased/1246-dark64 new file mode 100644 index 000000000..6c45d9a10 --- /dev/null +++ b/changelogs/unreleased/1246-dark64 @@ -0,0 +1 @@ +Introduce constraint generation through assembly blocks \ No newline at end of file diff --git a/zokrates_analysis/Cargo.toml b/zokrates_analysis/Cargo.toml new file mode 100644 index 000000000..e347abcdf --- /dev/null +++ b/zokrates_analysis/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "zokrates_analysis" +version = "0.1.0" +edition = "2021" + +[features] +default = ["ark", "bellman"] +ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark"] +bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman"] + +[dependencies] +log = "0.4" +cfg-if = "0.1" +num = { version = "0.1.36", default-features = false } +num-bigint = { version = "0.2", default-features = false } +lazy_static = "1.4" +typed-arena = "1.4.1" +reduce = "0.1.1" +# serialization and deserialization +serde = { version = "1.0", features = ["derive"] } +serde_json = { version = "1.0", features = ["preserve_order"] } +zokrates_field = { version = "0.5.0", path = "../zokrates_field", default-features = false } +zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } +zokrates_common = { version = "0.1", path = "../zokrates_common", default-features = false } +zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } +zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } +csv = "1" + +[dev-dependencies] +pretty_assertions = "0.6.1" +zokrates_fs_resolver = { version = "0.5", path = "../zokrates_fs_resolver"} \ No newline at end of file diff --git a/zokrates_analysis/src/assembly_transformer.rs b/zokrates_analysis/src/assembly_transformer.rs new file mode 100644 index 000000000..7c8e856e0 --- /dev/null +++ b/zokrates_analysis/src/assembly_transformer.rs @@ -0,0 +1,412 @@ +// A static analyser pass to transform user-defined constraints to the form `lin_comb === quad_comb` +// This pass can fail if a non-quadratic constraint is found which cannot be transformed to the expected form + +use crate::ZirPropagator; +use std::fmt; +use zokrates_ast::zir::lqc::LinQuadComb; +use zokrates_ast::zir::result_folder::ResultFolder; +use zokrates_ast::zir::{FieldElementExpression, Id, Identifier, ZirAssemblyStatement, ZirProgram}; +use zokrates_field::Field; + +#[derive(Debug)] +pub struct Error(String); + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +pub struct AssemblyTransformer; + +impl AssemblyTransformer { + pub fn transform(p: ZirProgram) -> Result, Error> { + AssemblyTransformer.fold_program(p) + } +} + +impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer { + type Error = Error; + + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + match s { + ZirAssemblyStatement::Assignment(_, _) => Ok(vec![s]), + ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { + let lhs = self.fold_field_expression(lhs)?; + let rhs = self.fold_field_expression(rhs)?; + + let (is_quadratic, lhs, rhs) = match (lhs, rhs) { + ( + lhs @ FieldElementExpression::Identifier(..), + rhs @ FieldElementExpression::Identifier(..), + ) => (true, lhs, rhs), + (FieldElementExpression::Mult(x, y), other) + | (other, FieldElementExpression::Mult(x, y)) + if other.is_linear() => + { + ( + x.is_linear() && y.is_linear(), + other, + FieldElementExpression::Mult(x, y), + ) + } + (lhs, rhs) => (false, lhs, rhs), + }; + + match is_quadratic { + true => Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]), + false => { + let sub = FieldElementExpression::Sub(box lhs, box rhs); + let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| { + Error("Non-quadratic constraints are not allowed".to_string()) + })?; + + let linear = lqc + .linear + .into_iter() + .map(|(c, i)| { + FieldElementExpression::Mult( + box FieldElementExpression::Number(c), + box FieldElementExpression::identifier(i), + ) + }) + .fold(FieldElementExpression::Number(T::from(0)), |acc, e| { + FieldElementExpression::Add(box acc, box e) + }); + + let lhs = FieldElementExpression::Add( + box FieldElementExpression::Number(lqc.constant), + box linear, + ); + + let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 { + let common_factor = lqc + .quadratic + .iter() + .scan(None, |state: &mut Option>, (_, a, b)| { + // short circuit if we do not have any common factors anymore + if *state == Some(vec![]) { + None + } else { + match state { + // only keep factors found in this term + Some(factors) => { + factors.retain(|&x| x == a || x == b); + } + // initialisation step, start with the two factors in the first term + None => { + *state = Some(vec![a, b]); + } + }; + state.clone() + } + }) + .last() + .and_then(|mut v| v.pop().cloned()); + + match common_factor { + Some(factor) => Ok(FieldElementExpression::Mult( + box lqc + .quadratic + .into_iter() + .map(|(c, i0, i1)| { + let c = T::zero() - c; + let e = match (i0, i1) { + (i0, i1) if factor.eq(&i0) => { + FieldElementExpression::identifier(i1) + } + (i0, i1) if factor.eq(&i1) => { + FieldElementExpression::identifier(i0) + } + _ => unreachable!(), + }; + FieldElementExpression::Mult( + box FieldElementExpression::Number(c), + box e, + ) + }) + .fold( + FieldElementExpression::Number(T::from(0)), + |acc, e| FieldElementExpression::Add(box acc, box e), + ), + box FieldElementExpression::identifier(factor), + )), + None => Err(Error( + "Non-quadratic constraints are not allowed".to_string(), + )), + }? + } else { + lqc.quadratic + .pop() + .map(|(c, i0, i1)| { + FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(T::zero() - c), + box FieldElementExpression::identifier(i0), + ), + box FieldElementExpression::identifier(i1), + ) + }) + .unwrap_or_else(|| FieldElementExpression::Number(T::from(0))) + }; + + let mut propagator = ZirPropagator::default(); + let lhs = propagator + .fold_field_expression(lhs) + .map_err(|e| Error(e.to_string()))?; + + let rhs = propagator + .fold_field_expression(rhs) + .map_err(|e| Error(e.to_string()))?; + + Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]) + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use zokrates_ast::common::SourceMetadata; + use zokrates_field::Bn128Field; + + #[test] + fn quadratic() { + // x === a * b; + let lhs = FieldElementExpression::::identifier("x".into()); + let rhs = FieldElementExpression::Mult( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("b".into()), + ); + + let expected = vec![ZirAssemblyStatement::Constraint( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::Mult( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("b".into()), + ), + SourceMetadata::default(), + )]; + let result = AssemblyTransformer + .fold_assembly_statement(ZirAssemblyStatement::Constraint( + lhs, + rhs, + SourceMetadata::default(), + )) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn non_quadratic() { + // x === ((a * b) * c); + let lhs = FieldElementExpression::::identifier("x".into()); + let rhs = FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("b".into()), + ), + box FieldElementExpression::identifier("c".into()), + ); + + let result = AssemblyTransformer.fold_assembly_statement(ZirAssemblyStatement::Constraint( + lhs, + rhs, + SourceMetadata::default(), + )); + + assert!(result.is_err()); + } + + #[test] + fn transform() { + // x === 1 - a * b; --> (-1) + x === (((-1) * a) * b); + let lhs = FieldElementExpression::identifier("x".into()); + let rhs = FieldElementExpression::Sub( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Mult( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("b".into()), + ), + ); + + let expected = vec![ZirAssemblyStatement::Constraint( + FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(-1)), + box FieldElementExpression::identifier("x".into()), + ), + FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(-1)), + box FieldElementExpression::identifier("a".into()), + ), + box FieldElementExpression::identifier("b".into()), + ), + SourceMetadata::default(), + )]; + + let result = AssemblyTransformer + .fold_assembly_statement(ZirAssemblyStatement::Constraint( + lhs, + rhs, + SourceMetadata::default(), + )) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn factorize() { + // x === (a * b) + (b * c); --> x === ((a + c) * b); + let lhs = FieldElementExpression::::identifier("x".into()); + let rhs = FieldElementExpression::Add( + box FieldElementExpression::Mult( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("b".into()), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::identifier("b".into()), + box FieldElementExpression::identifier("c".into()), + ), + ); + + let expected = vec![ZirAssemblyStatement::Constraint( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::Mult( + box FieldElementExpression::Add( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("c".into()), + ), + box FieldElementExpression::identifier("b".into()), + ), + SourceMetadata::default(), + )]; + let result = AssemblyTransformer + .fold_assembly_statement(ZirAssemblyStatement::Constraint( + lhs, + rhs, + SourceMetadata::default(), + )) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn transform_complex() { + // mid = b*c; + // x === a+b+c - 2*a*b - 2*a*c - 2*mid + 4*a*mid; // x === a ^ b ^ c + // --> + // ((((x + ((-1)*a)) + ((-1)*b)) + ((-1)*c)) + (2*mid)) === (((((-2)*b) + ((-2)*c)) + (4*mid)) * a); + let lhs = FieldElementExpression::::identifier("x".into()); + let rhs = FieldElementExpression::Add( + box FieldElementExpression::Sub( + box FieldElementExpression::Sub( + box FieldElementExpression::Sub( + box FieldElementExpression::Add( + box FieldElementExpression::Add( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("b".into()), + ), + box FieldElementExpression::identifier("c".into()), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("a".into()), + ), + box FieldElementExpression::identifier("b".into()), + ), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("a".into()), + ), + box FieldElementExpression::identifier("c".into()), + ), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("mid".into()), + ), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(4)), + box FieldElementExpression::identifier("a".into()), + ), + box FieldElementExpression::identifier("mid".into()), + ), + ); + + let lhs_expected = FieldElementExpression::Add( + box FieldElementExpression::Add( + box FieldElementExpression::Add( + box FieldElementExpression::Add( + box FieldElementExpression::identifier("x".into()), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(-1)), + box FieldElementExpression::identifier("a".into()), + ), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(-1)), + box FieldElementExpression::identifier("b".into()), + ), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(-1)), + box FieldElementExpression::identifier("c".into()), + ), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("mid".into()), + ), + ); + + let rhs_expected = FieldElementExpression::Mult( + box FieldElementExpression::Add( + box FieldElementExpression::Add( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(-2)), + box FieldElementExpression::identifier("b".into()), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(-2)), + box FieldElementExpression::identifier("c".into()), + ), + ), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(4)), + box FieldElementExpression::identifier("mid".into()), + ), + ), + box FieldElementExpression::identifier("a".into()), + ); + + let expected = vec![ZirAssemblyStatement::Constraint( + lhs_expected, + rhs_expected, + SourceMetadata::default(), + )]; + let result = AssemblyTransformer + .fold_assembly_statement(ZirAssemblyStatement::Constraint( + lhs, + rhs, + SourceMetadata::default(), + )) + .unwrap(); + + assert_eq!(result, expected); + } +} diff --git a/zokrates_core/src/static_analysis/boolean_array_comparator.rs b/zokrates_analysis/src/boolean_array_comparator.rs similarity index 100% rename from zokrates_core/src/static_analysis/boolean_array_comparator.rs rename to zokrates_analysis/src/boolean_array_comparator.rs diff --git a/zokrates_core/src/static_analysis/branch_isolator.rs b/zokrates_analysis/src/branch_isolator.rs similarity index 100% rename from zokrates_core/src/static_analysis/branch_isolator.rs rename to zokrates_analysis/src/branch_isolator.rs diff --git a/zokrates_core/src/static_analysis/condition_redefiner.rs b/zokrates_analysis/src/condition_redefiner.rs similarity index 100% rename from zokrates_core/src/static_analysis/condition_redefiner.rs rename to zokrates_analysis/src/condition_redefiner.rs diff --git a/zokrates_core/src/static_analysis/constant_argument_checker.rs b/zokrates_analysis/src/constant_argument_checker.rs similarity index 59% rename from zokrates_core/src/static_analysis/constant_argument_checker.rs rename to zokrates_analysis/src/constant_argument_checker.rs index c7ee629a3..485cf9181 100644 --- a/zokrates_core/src/static_analysis/constant_argument_checker.rs +++ b/zokrates_analysis/src/constant_argument_checker.rs @@ -1,9 +1,7 @@ use std::fmt; use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::{ - result_folder::ResultFolder, - result_folder::{fold_statement, fold_uint_expression_inner}, - Constant, EmbedCall, TypedStatement, UBitwidth, UExpressionInner, + result_folder::fold_statement, result_folder::ResultFolder, Constant, EmbedCall, TypedStatement, }; use zokrates_ast::typed::{DefinitionRhs, TypedProgram}; use zokrates_field::Field; @@ -71,40 +69,4 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { s => fold_statement(self, s), } } - - fn fold_uint_expression_inner( - &mut self, - bitwidth: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> Result, Error> { - match e { - UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - - match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), - by => Err(Error(format!( - "Cannot shift by a variable value, found `{} << {}`", - e, - by.clone().annotate(UBitwidth::B32) - ))), - } - } - UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - - match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), - by => Err(Error(format!( - "Cannot shift by a variable value, found `{} >> {}`", - e, - by.clone().annotate(UBitwidth::B32) - ))), - } - } - e => fold_uint_expression_inner(self, bitwidth, e), - } - } } diff --git a/zokrates_core/src/static_analysis/constant_resolver.rs b/zokrates_analysis/src/constant_resolver.rs similarity index 100% rename from zokrates_core/src/static_analysis/constant_resolver.rs rename to zokrates_analysis/src/constant_resolver.rs diff --git a/zokrates_core/src/static_analysis/dead_code.rs b/zokrates_analysis/src/dead_code.rs similarity index 100% rename from zokrates_core/src/static_analysis/dead_code.rs rename to zokrates_analysis/src/dead_code.rs diff --git a/zokrates_analysis/src/expression_validator.rs b/zokrates_analysis/src/expression_validator.rs new file mode 100644 index 000000000..d3a0a96ef --- /dev/null +++ b/zokrates_analysis/src/expression_validator.rs @@ -0,0 +1,107 @@ +use std::fmt; +use zokrates_ast::typed::result_folder::{ + fold_assembly_statement, fold_field_expression, fold_uint_expression_inner, ResultFolder, +}; +use zokrates_ast::typed::{ + FieldElementExpression, TypedAssemblyStatement, TypedProgram, UBitwidth, UExpressionInner, +}; +use zokrates_field::Field; + +#[derive(Debug, PartialEq, Eq)] +pub struct Error(String); + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +pub struct ExpressionValidator; + +impl ExpressionValidator { + pub fn validate(p: TypedProgram) -> Result, Error> { + ExpressionValidator.fold_program(p) + } +} + +impl<'ast, T: Field> ResultFolder<'ast, T> for ExpressionValidator { + type Error = Error; + + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + match s { + // we allow more dynamic expressions in witness generation + TypedAssemblyStatement::Assignment(_, _) => Ok(vec![s]), + s => fold_assembly_statement(self, s), + } + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + // these should have been propagated away + FieldElementExpression::And(_, _) + | FieldElementExpression::Or(_, _) + | FieldElementExpression::Xor(_, _) + | FieldElementExpression::LeftShift(_, _) + | FieldElementExpression::RightShift(_, _) => Err(Error(format!( + "Found non-constant bitwise operation in field element expression `{}`", + e + ))), + FieldElementExpression::Pow(box e, box exp) => { + let e = self.fold_field_expression(e)?; + let exp = self.fold_uint_expression(exp)?; + + match exp.as_inner() { + UExpressionInner::Value(_) => Ok(FieldElementExpression::Pow(box e, box exp)), + exp => Err(Error(format!( + "Found non-constant exponent in power expression `{}**{}`", + e, + exp.clone().annotate(UBitwidth::B32) + ))), + } + } + e => fold_field_expression(self, e), + } + } + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Error> { + match e { + UExpressionInner::LeftShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), + by => Err(Error(format!( + "Cannot shift by a variable value, found `{} << {}`", + e, + by.clone().annotate(UBitwidth::B32) + ))), + } + } + UExpressionInner::RightShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), + by => Err(Error(format!( + "Cannot shift by a variable value, found `{} >> {}`", + e, + by.clone().annotate(UBitwidth::B32) + ))), + } + } + e => fold_uint_expression_inner(self, bitwidth, e), + } + } +} diff --git a/zokrates_core/src/static_analysis/flat_propagation.rs b/zokrates_analysis/src/flat_propagation.rs similarity index 96% rename from zokrates_core/src/static_analysis/flat_propagation.rs rename to zokrates_analysis/src/flat_propagation.rs index f69b93137..155d803c8 100644 --- a/zokrates_core/src/static_analysis/flat_propagation.rs +++ b/zokrates_analysis/src/flat_propagation.rs @@ -14,8 +14,8 @@ struct Propagator { constants: HashMap, } -impl Folder for Propagator { - fn fold_statement(&mut self, s: FlatStatement) -> Vec> { +impl<'ast, T: Field> Folder<'ast, T> for Propagator { + fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec> { match s { FlatStatement::Definition(var, expr) => match self.fold_expression(expr) { FlatExpression::Number(n) => { diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs similarity index 88% rename from zokrates_core/src/static_analysis/flatten_complex_types.rs rename to zokrates_analysis/src/flatten_complex_types.rs index b7571f9df..f4b81d8e1 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -1,11 +1,12 @@ +use std::collections::HashMap; +use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth}; use zokrates_ast::typed::{self, Expr, Typed}; -use zokrates_ast::zir::{self, Id, Select}; +use zokrates_ast::zir::IntoType as ZirIntoType; +use zokrates_ast::zir::{self, Folder, Id, Select}; use zokrates_field::Field; -use std::convert::{TryFrom, TryInto}; - #[derive(Default)] pub struct Flattener { phantom: PhantomData, @@ -272,6 +273,14 @@ impl<'ast, T: Field> Flattener { } } + fn fold_assembly_statement( + &mut self, + statements_buffer: &mut Vec>, + s: typed::TypedAssemblyStatement<'ast, T>, + ) -> zir::ZirAssemblyStatement<'ast, T> { + fold_assembly_statement(self, statements_buffer, s) + } + fn fold_statement( &mut self, statements_buffer: &mut Vec>, @@ -449,12 +458,126 @@ impl<'ast, T: Field> Flattener { } } +// This finder looks for identifiers that were not defined in some block of statements +// These identifiers are used as function arguments when moving witness assignment expression +// to a zir function. +// +// Example: +// def main(field a, field mut b) -> field { +// asm { +// b <== a * a; +// } +// return b; +// } +// is turned into +// def main(field a, field mut b) -> field { +// asm { +// b <-- (field a) -> field { +// return a * a; +// } +// b == a * a; +// } +// return b; +// } +#[derive(Default)] +pub struct ArgumentFinder<'ast, T> { + pub identifiers: HashMap, zir::Type>, + _phantom: PhantomData, +} + +impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> { + fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec> { + match s { + zir::ZirStatement::Definition(assignee, expr) => { + let assignee = self.fold_assignee(assignee); + let expr = self.fold_expression(expr); + self.identifiers.remove(&assignee.id); + vec![zir::ZirStatement::Definition(assignee, expr)] + } + zir::ZirStatement::MultipleDefinition(assignees, list) => { + let assignees: Vec> = assignees + .into_iter() + .map(|v| self.fold_assignee(v)) + .collect(); + let list = self.fold_expression_list(list); + for a in &assignees { + self.identifiers.remove(&a.id); + } + vec![zir::ZirStatement::MultipleDefinition(assignees, list)] + } + s => zir::folder::fold_statement(self, s), + } + } + + fn fold_identifier_expression + Id<'ast, T>>( + &mut self, + ty: &E::Ty, + e: zir::IdentifierExpression<'ast, E>, + ) -> zir::IdentifierOrExpression<'ast, T, E> { + self.identifiers + .insert(e.id.clone(), ty.clone().into_type()); + zir::IdentifierOrExpression::Identifier(e) + } +} + +fn fold_assembly_statement<'ast, T: Field>( + f: &mut Flattener, + statements_buffer: &mut Vec>, + s: typed::TypedAssemblyStatement<'ast, T>, +) -> zir::ZirAssemblyStatement<'ast, T> { + match s { + typed::TypedAssemblyStatement::Assignment(a, e) => { + let mut statements_buffer: Vec> = vec![]; + let a = f.fold_assignee(a); + let e = f.fold_expression(&mut statements_buffer, e); + statements_buffer.push(zir::ZirStatement::Return(e)); + + let mut finder = ArgumentFinder::default(); + let mut statements_buffer: Vec> = statements_buffer + .into_iter() + .rev() + .flat_map(|s| finder.fold_statement(s)) + .collect(); + statements_buffer.reverse(); + + let function = zir::ZirFunction { + signature: zir::types::Signature::default() + .inputs(finder.identifiers.values().cloned().collect()) + .outputs(a.iter().map(|a| a.get_type()).collect()), + arguments: finder + .identifiers + .into_iter() + .map(|(id, ty)| zir::Parameter { + id: zir::Variable::with_id_and_type(id, ty), + private: true, + }) + .collect(), + statements: statements_buffer, + }; + + zir::ZirAssemblyStatement::Assignment(a, function) + } + typed::TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { + let lhs = f.fold_field_expression(statements_buffer, lhs); + let rhs = f.fold_field_expression(statements_buffer, rhs); + zir::ZirAssemblyStatement::Constraint(lhs, rhs, metadata) + } + } +} + fn fold_statement<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, s: typed::TypedStatement<'ast, T>, ) { let res = match s { + typed::TypedStatement::Assembly(statements) => { + let statements = statements + .into_iter() + .map(|s| f.fold_assembly_statement(statements_buffer, s)) + .collect(); + vec![zir::ZirStatement::Assembly(statements)] + } typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return( f.fold_expression(statements_buffer, expression), )], @@ -471,7 +594,7 @@ fn fold_statement<'ast, T: Field>( let e = f.fold_boolean_expression(statements_buffer, e); let error = match error { typed::RuntimeError::SourceAssertion(metadata) => { - zir::RuntimeError::SourceAssertion(metadata.to_string()) + zir::RuntimeError::SourceAssertion(metadata) } typed::RuntimeError::SelectRangeCheck => zir::RuntimeError::SelectRangeCheck, typed::RuntimeError::DivisionByZero => zir::RuntimeError::DivisionByZero, @@ -896,6 +1019,36 @@ fn fold_field_expression<'ast, T: Field>( ) } typed::FieldElementExpression::Pos(box e) => f.fold_field_expression(statements_buffer, e), + typed::FieldElementExpression::Xor(box left, box right) => { + let left = f.fold_field_expression(statements_buffer, left); + let right = f.fold_field_expression(statements_buffer, right); + + zir::FieldElementExpression::Xor(box left, box right) + } + typed::FieldElementExpression::And(box left, box right) => { + let left = f.fold_field_expression(statements_buffer, left); + let right = f.fold_field_expression(statements_buffer, right); + + zir::FieldElementExpression::And(box left, box right) + } + typed::FieldElementExpression::Or(box left, box right) => { + let left = f.fold_field_expression(statements_buffer, left); + let right = f.fold_field_expression(statements_buffer, right); + + zir::FieldElementExpression::Or(box left, box right) + } + typed::FieldElementExpression::LeftShift(box e, box by) => { + let e = f.fold_field_expression(statements_buffer, e); + let by = f.fold_uint_expression(statements_buffer, by); + + zir::FieldElementExpression::LeftShift(box e, box by) + } + typed::FieldElementExpression::RightShift(box e, box by) => { + let e = f.fold_field_expression(statements_buffer, e); + let by = f.fold_uint_expression(statements_buffer, by); + + zir::FieldElementExpression::RightShift(box e, box by) + } typed::FieldElementExpression::Conditional(c) => f .fold_conditional_expression(statements_buffer, c) .pop() diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_analysis/src/lib.rs similarity index 78% rename from zokrates_core/src/static_analysis/mod.rs rename to zokrates_analysis/src/lib.rs index bf394dc29..c628e7283 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_analysis/src/lib.rs @@ -1,15 +1,19 @@ +#![feature(box_patterns, box_syntax)] + //! Module containing static analysis //! //! @file mod.rs //! @author Thibaut Schaeffer //! @date 2018 +mod assembly_transformer; mod boolean_array_comparator; mod branch_isolator; mod condition_redefiner; mod constant_argument_checker; mod constant_resolver; mod dead_code; +mod expression_validator; mod flat_propagation; mod flatten_complex_types; mod log_ignorer; @@ -34,14 +38,16 @@ use self::reducer::reduce_program; use self::struct_concretizer::StructConcretizer; use self::uint_optimizer::UintOptimizer; use self::variable_write_remover::VariableWriteRemover; -use crate::compile::CompileConfig; -use crate::static_analysis::constant_resolver::ConstantResolver; -use crate::static_analysis::dead_code::DeadCodeEliminator; -use crate::static_analysis::panic_extractor::PanicExtractor; -use crate::static_analysis::zir_propagation::ZirPropagator; +use crate::assembly_transformer::AssemblyTransformer; +use crate::constant_resolver::ConstantResolver; +use crate::dead_code::DeadCodeEliminator; +use crate::expression_validator::ExpressionValidator; +use crate::panic_extractor::PanicExtractor; +pub use crate::zir_propagation::ZirPropagator; use std::fmt; use zokrates_ast::typed::{abi::Abi, TypedProgram}; use zokrates_ast::zir::ZirProgram; +use zokrates_common::CompileConfig; use zokrates_field::Field; #[derive(Debug)] @@ -51,6 +57,9 @@ pub enum Error { ZirPropagation(self::zir_propagation::Error), NonConstantArgument(self::constant_argument_checker::Error), OutOfBounds(self::out_of_bounds::Error), + Assembly(self::assembly_transformer::Error), + VariableIndex(self::variable_write_remover::Error), + InvalidExpression(self::expression_validator::Error), } impl From for Error { @@ -83,6 +92,24 @@ impl From for Error { } } +impl From for Error { + fn from(e: assembly_transformer::Error) -> Self { + Error::Assembly(e) + } +} + +impl From for Error { + fn from(e: variable_write_remover::Error) -> Self { + Error::VariableIndex(e) + } +} + +impl From for Error { + fn from(e: expression_validator::Error) -> Self { + Error::InvalidExpression(e) + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -91,6 +118,9 @@ impl fmt::Display for Error { Error::ZirPropagation(e) => write!(f, "{}", e), Error::NonConstantArgument(e) => write!(f, "{}", e), Error::OutOfBounds(e) => write!(f, "{}", e), + Error::Assembly(e) => write!(f, "{}", e), + Error::VariableIndex(e) => write!(f, "{}", e), + Error::InvalidExpression(e) => write!(f, "{}", e), } } } @@ -139,6 +169,10 @@ pub fn analyse<'ast, T: Field>( let r = StructConcretizer::concretize(r); log::trace!("\n{}", r); + // validate expressions + log::debug!("Static analyser: Validate expressions"); + let r = ExpressionValidator::validate(r).map_err(Error::from)?; + // generate abi log::debug!("Static analyser: Generate abi"); let abi = r.abi(); @@ -155,7 +189,7 @@ pub fn analyse<'ast, T: Field>( // remove assignment to variable index log::debug!("Static analyser: Remove variable index"); - let r = VariableWriteRemover::apply(r); + let r = VariableWriteRemover::apply(r).map_err(Error::from)?; log::trace!("\n{}", r); // detect non constant shifts and constant lt bounds @@ -196,5 +230,9 @@ pub fn analyse<'ast, T: Field>( let zir = UintOptimizer::optimize(zir); log::trace!("\n{}", zir); + log::debug!("Static analyser: Apply constraint transformations in assembly"); + let zir = AssemblyTransformer::transform(zir).map_err(Error::from)?; + log::trace!("\n{}", zir); + Ok((zir, abi)) } diff --git a/zokrates_core/src/static_analysis/log_ignorer.rs b/zokrates_analysis/src/log_ignorer.rs similarity index 100% rename from zokrates_core/src/static_analysis/log_ignorer.rs rename to zokrates_analysis/src/log_ignorer.rs diff --git a/zokrates_core/src/static_analysis/out_of_bounds.rs b/zokrates_analysis/src/out_of_bounds.rs similarity index 100% rename from zokrates_core/src/static_analysis/out_of_bounds.rs rename to zokrates_analysis/src/out_of_bounds.rs diff --git a/zokrates_core/src/static_analysis/panic_extractor.rs b/zokrates_analysis/src/panic_extractor.rs similarity index 100% rename from zokrates_core/src/static_analysis/panic_extractor.rs rename to zokrates_analysis/src/panic_extractor.rs diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_analysis/src/propagation.rs similarity index 83% rename from zokrates_core/src/static_analysis/propagation.rs rename to zokrates_analysis/src/propagation.rs index 737e29b9b..b7e5c0a17 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -7,9 +7,12 @@ //! @author Thibaut Schaeffer //! @date 2018 +use num::traits::Pow; +use num_bigint::BigUint; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::fmt; +use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub}; use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::types::Type; @@ -21,28 +24,22 @@ pub type Constants<'ast, T> = HashMap, TypedExpression<'ast, T> #[derive(Debug, PartialEq, Eq)] pub enum Error { Type(String), - AssertionFailed(String), - ValueTooLarge(String), + AssertionFailed(RuntimeError), + InvalidValue(String), OutOfBounds(u128, u128), - NonConstantExponent(String), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Error::Type(s) => write!(f, "{}", s), - Error::AssertionFailed(s) => write!(f, "{}", s), - Error::ValueTooLarge(s) => write!(f, "{}", s), + Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err), + Error::InvalidValue(s) => write!(f, "{}", s), Error::OutOfBounds(index, size) => write!( f, "Out of bounds index ({} >= {}) found during static analysis", index, size ), - Error::NonConstantExponent(s) => write!( - f, - "Non-constant exponent `{}` detected during static analysis", - s - ), } } } @@ -179,13 +176,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } } - fn fold_function( - &mut self, - f: TypedFunction<'ast, T>, - ) -> Result, Error> { - fold_function(self, f) - } - fn fold_conditional_expression< E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold<'ast, T>, >( @@ -215,11 +205,101 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { ) } + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + match s { + TypedAssemblyStatement::Assignment(assignee, expr) => { + let assignee = self.fold_assignee(assignee)?; + let expr = self.fold_expression(expr)?; + + if expr.is_constant() { + match assignee { + TypedAssignee::Identifier(var) => { + let expr = expr.into_canonical_constant(); + + assert!(self.constants.insert(var.id, expr).is_none()); + + Ok(vec![]) + } + assignee => match self.try_get_constant_mut(&assignee) { + Ok((_, c)) => { + *c = expr.into_canonical_constant(); + Ok(vec![]) + } + Err(v) => match self.constants.remove(&v.id) { + // invalidate the cache for this identifier, and define the latest + // version of the constant in the program, if any + Some(c) => Ok(vec![ + TypedAssemblyStatement::Assignment(v.clone().into(), c), + TypedAssemblyStatement::Assignment(assignee, expr), + ]), + None => { + Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)]) + } + }, + }, + } + } else { + // the expression being assigned is not constant, invalidate the cache + let v = self + .try_get_constant_mut(&assignee) + .map(|(v, _)| v) + .unwrap_or_else(|v| v); + + match self.constants.remove(&v.id) { + Some(c) => Ok(vec![ + TypedAssemblyStatement::Assignment(v.clone().into(), c), + TypedAssemblyStatement::Assignment(assignee, expr), + ]), + None => Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)]), + } + } + } + TypedAssemblyStatement::Constraint(left, right, metadata) => { + let left = self.fold_field_expression(left)?; + let right = self.fold_field_expression(right)?; + + // a bit hacky, but we use a fake boolean expression to check this + let is_equal = + BooleanExpression::FieldEq(EqExpression::new(left.clone(), right.clone())); + let is_equal = self.fold_boolean_expression(is_equal)?; + + match is_equal { + BooleanExpression::Value(true) => Ok(vec![]), + BooleanExpression::Value(false) => { + Err(Error::AssertionFailed(RuntimeError::SourceAssertion( + metadata + .message(Some(format!("In asm block: `{} !== {}`", left, right))), + ))) + } + _ => Ok(vec![TypedAssemblyStatement::Constraint( + left, right, metadata, + )]), + } + } + } + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, ) -> Result>, Error> { match s { + TypedStatement::Assembly(statements) => { + let statements: Vec<_> = statements + .into_iter() + .map(|s| self.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(); + match statements.len() { + 0 => Ok(vec![]), + _ => Ok(vec![TypedStatement::Assembly(statements)]), + } + } // propagation to the defined variable if rhs is a constant TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { let assignee = self.fold_assignee(assignee)?; @@ -373,6 +453,26 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { true => { let r: Option> = match embed_call.embed { FlatEmbed::BitArrayLe => Ok(None), // todo + FlatEmbed::FieldToBoolUnsafe => { + match FieldElementExpression::try_from_typed( + embed_call.arguments[0].clone(), + ) { + Ok(FieldElementExpression::Number(n)) if n == T::from(0) => { + Ok(Some(BooleanExpression::Value(false).into())) + } + Ok(FieldElementExpression::Number(n)) if n == T::from(1) => { + Ok(Some(BooleanExpression::Value(true).into())) + } + Ok(FieldElementExpression::Number(n)) => { + Err(Error::InvalidValue(format!( + "Cannot call `{}` with value `{}`: should be 0 or 1", + embed_call.embed.id(), + n + ))) + } + _ => Ok(None), + } + } FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits( &embed_call.arguments, UBitwidth::B64, @@ -430,7 +530,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } if acc != T::zero() { - Err(Error::ValueTooLarge(format!( + Err(Error::InvalidValue(format!( "Cannot unpack `{}` to `{}`: value is too large", num, assignee.get_type() @@ -521,15 +621,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } } } - TypedStatement::Assertion(e, ty) => { - let e_str = e.to_string(); + TypedStatement::Assertion(e, err) => { let expr = self.fold_boolean_expression(e)?; match expr { - BooleanExpression::Value(false) => { - Err(Error::AssertionFailed(format!("{}: ({})", ty, e_str))) - } + BooleanExpression::Value(false) => Err(Error::AssertionFailed(err)), BooleanExpression::Value(true) => Ok(vec![]), - _ => Ok(vec![TypedStatement::Assertion(expr, ty)]), + _ => Ok(vec![TypedStatement::Assertion(expr, err)]), } } s @ TypedStatement::PushCallLog(..) => Ok(vec![s]), @@ -827,11 +924,140 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { box e1, box UExpressionInner::Value(n2).annotate(UBitwidth::B32), )), - (_, e2) => Err(Error::NonConstantExponent( - e2.annotate(UBitwidth::B32).to_string(), + (e1, e2) => Ok(FieldElementExpression::Pow( + box e1, + box e2.annotate(UBitwidth::B32), )), } } + FieldElementExpression::Xor(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; + + match (e1, e2) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + Ok(FieldElementExpression::Number( + T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(), + )) + } + (FieldElementExpression::Number(n), e) + | (e, FieldElementExpression::Number(n)) + if n == T::from(0) => + { + Ok(e) + } + (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))), + (e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)), + } + } + + FieldElementExpression::And(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; + + match (e1, e2) { + (_, FieldElementExpression::Number(n)) + | (FieldElementExpression::Number(n), _) + if n == T::from(0) => + { + Ok(FieldElementExpression::Number(n)) + } + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + Ok(FieldElementExpression::Number( + T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(), + )) + } + (e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)), + } + } + FieldElementExpression::Or(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; + + match (e1, e2) { + (e, FieldElementExpression::Number(n)) + | (FieldElementExpression::Number(n), e) + if n == T::from(0) => + { + Ok(e) + } + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + Ok(FieldElementExpression::Number( + T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(), + )) + } + (e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)), + } + } + FieldElementExpression::LeftShift(box e, box by) => { + let e = self.fold_field_expression(e)?; + let by = self.fold_uint_expression(by)?; + match (e, by) { + ( + e, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by == 0 => Ok(e), + ( + _, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::Number(T::from(0))) + } + ( + FieldElementExpression::Number(n), + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) => { + let two = BigUint::from(2usize); + let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); + + Ok(FieldElementExpression::Number( + T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(), + )) + } + (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), + } + } + FieldElementExpression::RightShift(box e, box by) => { + let e = self.fold_field_expression(e)?; + let by = self.fold_uint_expression(by)?; + match (e, by) { + ( + e, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by == 0 => Ok(e), + ( + _, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::Number(T::from(0))) + } + ( + FieldElementExpression::Number(n), + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) => Ok(FieldElementExpression::Number( + T::try_from(n.to_biguint().shr(by as usize)).unwrap(), + )), + (e, by) => Ok(FieldElementExpression::RightShift(box e, box by)), + } + } e => fold_field_expression(self, e), } } @@ -1333,6 +1559,113 @@ mod tests { ); } + #[test] + fn left_shift() { + let mut constants = Constants::new(); + let mut propagator = Propagator::with_constants(&mut constants); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::identifier("a".into()), + box 0u32.into(), + )), + Ok(FieldElementExpression::identifier("a".into())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box 2u32.into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(8))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(1)), + box ((Bn128Field::get_required_bits() - 1) as u32).into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(3)), + box ((Bn128Field::get_required_bits() - 3) as u32).into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(1)), + box (Bn128Field::get_required_bits() as u32).into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + } + + #[test] + fn right_shift() { + let mut constants = Constants::new(); + let mut propagator = Propagator::with_constants(&mut constants); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::identifier("a".into()), + box 0u32.into(), + )), + Ok(FieldElementExpression::identifier("a".into())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::identifier("a".into()), + box (Bn128Field::get_required_bits() as u32).into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(3)), + box 1u32.into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(1))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box 2u32.into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box 4u32.into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::max_value()), + box ((Bn128Field::get_required_bits() - 1) as u32).into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(1))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::max_value()), + box (Bn128Field::get_required_bits() as u32).into(), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + } + #[test] fn if_else_true() { let e = FieldElementExpression::conditional( diff --git a/zokrates_core/src/static_analysis/reducer/constants_reader.rs b/zokrates_analysis/src/reducer/constants_reader.rs similarity index 99% rename from zokrates_core/src/static_analysis/reducer/constants_reader.rs rename to zokrates_analysis/src/reducer/constants_reader.rs index b04ed4064..4ee0d1359 100644 --- a/zokrates_core/src/static_analysis/reducer/constants_reader.rs +++ b/zokrates_analysis/src/reducer/constants_reader.rs @@ -1,6 +1,6 @@ // given a (partial) map of values for program constants, replace where applicable constants by their value -use crate::static_analysis::reducer::ConstantDefinitions; +use crate::reducer::ConstantDefinitions; use zokrates_ast::typed::{ folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Id, Identifier, IdentifierExpression, diff --git a/zokrates_core/src/static_analysis/reducer/constants_writer.rs b/zokrates_analysis/src/reducer/constants_writer.rs similarity index 99% rename from zokrates_core/src/static_analysis/reducer/constants_writer.rs rename to zokrates_analysis/src/reducer/constants_writer.rs index 4917a6b9a..d4e03d3d4 100644 --- a/zokrates_core/src/static_analysis/reducer/constants_writer.rs +++ b/zokrates_analysis/src/reducer/constants_writer.rs @@ -1,6 +1,6 @@ // A folder to inline all constant definitions down to a single literal and register them in the state for later use. -use crate::static_analysis::reducer::{ +use crate::reducer::{ constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error, }; use std::collections::{BTreeMap, HashSet}; diff --git a/zokrates_core/src/static_analysis/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs similarity index 97% rename from zokrates_core/src/static_analysis/reducer/inline.rs rename to zokrates_analysis/src/reducer/inline.rs index 09aa6c932..31f237e82 100644 --- a/zokrates_core/src/static_analysis/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -26,9 +26,9 @@ // - The body of the function is in SSA form // - The return value(s) are assigned to internal variables -use crate::static_analysis::reducer::Output; -use crate::static_analysis::reducer::ShallowTransformer; -use crate::static_analysis::reducer::Versions; +use crate::reducer::Output; +use crate::reducer::ShallowTransformer; +use crate::reducer::Versions; use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType}; diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs similarity index 99% rename from zokrates_core/src/static_analysis/reducer/mod.rs rename to zokrates_analysis/src/reducer/mod.rs index 0585f9663..ea6feb536 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -36,7 +36,7 @@ use zokrates_field::Field; use self::constants_writer::ConstantsWriter; use self::shallow_ssa::ShallowTransformer; -use crate::static_analysis::propagation::{Constants, Propagator}; +use crate::propagation::{Constants, Propagator}; use std::fmt; diff --git a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs similarity index 97% rename from zokrates_core/src/static_analysis/reducer/shallow_ssa.rs rename to zokrates_analysis/src/reducer/shallow_ssa.rs index f0e667870..a071a0446 100644 --- a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -121,33 +121,42 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { fold_function(self, f) } + + fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { + match a { + TypedAssignee::Identifier(v) => { + let v = self.issue_next_ssa_variable(v); + TypedAssignee::Identifier(self.fold_variable(v)) + } + a => fold_assignee(self, a), + } + } } impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Vec> { + match s { + TypedAssemblyStatement::Assignment(a, e) => { + let e = self.fold_expression(e); + let a = self.fold_assignee(a); + vec![TypedAssemblyStatement::Assignment(a, e)] + } + s => fold_assembly_statement(self, s), + } + } fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { match s { TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => { let e = self.fold_expression(e); - - let a = match a { - TypedAssignee::Identifier(v) => { - let v = self.issue_next_ssa_variable(v); - TypedAssignee::Identifier(self.fold_variable(v)) - } - a => fold_assignee(self, a), - }; - + let a = self.fold_assignee(a); vec![TypedStatement::definition(a, e)] } TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - let assignee = match assignee { - TypedAssignee::Identifier(v) => { - let v = self.issue_next_ssa_variable(v); - TypedAssignee::Identifier(self.fold_variable(v)) - } - a => fold_assignee(self, a), - }; let embed_call = self.fold_embed_call(embed_call); + let assignee = self.fold_assignee(assignee); vec![TypedStatement::embed_call_definition(assignee, embed_call)] } TypedStatement::For(v, from, to, stats) => { diff --git a/zokrates_core/src/static_analysis/struct_concretizer.rs b/zokrates_analysis/src/struct_concretizer.rs similarity index 100% rename from zokrates_core/src/static_analysis/struct_concretizer.rs rename to zokrates_analysis/src/struct_concretizer.rs diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_analysis/src/uint_optimizer.rs similarity index 100% rename from zokrates_core/src/static_analysis/uint_optimizer.rs rename to zokrates_analysis/src/uint_optimizer.rs diff --git a/zokrates_core/src/static_analysis/variable_write_remover.rs b/zokrates_analysis/src/variable_write_remover.rs similarity index 94% rename from zokrates_core/src/static_analysis/variable_write_remover.rs rename to zokrates_analysis/src/variable_write_remover.rs index 7d88336cc..b218acc33 100644 --- a/zokrates_core/src/static_analysis/variable_write_remover.rs +++ b/zokrates_analysis/src/variable_write_remover.rs @@ -5,11 +5,22 @@ //! @date 2018 use std::collections::HashSet; -use zokrates_ast::typed::folder::*; +use std::fmt; +use zokrates_ast::typed::result_folder::ResultFolder; +use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::types::{MemberId, Type}; use zokrates_ast::typed::*; use zokrates_field::Field; +#[derive(Debug)] +pub struct Error(String); + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + pub struct VariableWriteRemover; impl<'ast> VariableWriteRemover { @@ -17,7 +28,7 @@ impl<'ast> VariableWriteRemover { VariableWriteRemover } - pub fn apply(p: TypedProgram) -> TypedProgram { + pub fn apply(p: TypedProgram) -> Result, Error> { let mut remover = VariableWriteRemover::new(); remover.fold_program(p) } @@ -452,14 +463,35 @@ fn is_constant(assignee: &TypedAssignee) -> bool { } } -impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover { - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { +impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { + type Error = Error; + + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + match s { + TypedAssemblyStatement::Assignment(a, e) if is_constant(&a) => { + Ok(vec![TypedAssemblyStatement::Assignment(a, e)]) + } + TypedAssemblyStatement::Assignment(a, _) => Err(Error(format!( + "Cannot assign to an assignee with a variable index `{}`", + a + ))), + s => Ok(vec![s]), + } + } + + fn fold_statement( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Self::Error> { match s { TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { - let expr = self.fold_expression(expr); + let expr = self.fold_expression(expr)?; if is_constant(&assignee) { - vec![TypedStatement::definition(assignee, expr)] + Ok(vec![TypedStatement::definition(assignee, expr)]) } else { // Note: here we redefine the whole object, ideally we would only redefine some of it // Example: `a[0][i] = 42` we redefine `a` but we could redefine just `a[0]` @@ -486,28 +518,28 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover { .into(), }; - let base = self.fold_expression(base); + let base = self.fold_expression(base)?; let indices = indices .into_iter() .map(|a| match a { Access::Select(box i) => { - Access::Select(box self.fold_uint_expression(i)) + Ok(Access::Select(box self.fold_uint_expression(i)?)) } - a => a, + a => Ok(a), }) - .collect(); + .collect::>()?; let mut range_checks = HashSet::new(); let e = Self::choose_many(base, indices, expr, &mut range_checks); - range_checks + Ok(range_checks .into_iter() .chain(std::iter::once(TypedStatement::definition( TypedAssignee::Identifier(variable), e, ))) - .collect() + .collect()) } } s => fold_statement(self, s), diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_analysis/src/zir_propagation.rs similarity index 80% rename from zokrates_core/src/static_analysis/zir_propagation.rs rename to zokrates_analysis/src/zir_propagation.rs index 403555584..4a08a7144 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_analysis/src/zir_propagation.rs @@ -1,9 +1,13 @@ +use num::traits::Pow; +use num_bigint::BigUint; use std::collections::HashMap; use std::fmt; +use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub}; use zokrates_ast::zir::types::UBitwidth; use zokrates_ast::zir::{ - result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id, - IdentifierExpression, IdentifierOrExpression, SelectExpression, SelectOrExpression, + result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Constant, Expr, + Id, IdentifierExpression, IdentifierOrExpression, SelectExpression, SelectOrExpression, + ZirAssemblyStatement, }; use zokrates_ast::zir::{ BooleanExpression, FieldElementExpression, Identifier, RuntimeError, UExpression, @@ -31,7 +35,7 @@ impl fmt::Display for Error { Error::DivisionByZero => { write!(f, "Division by zero detected in zir during static analysis",) } - Error::AssertionFailed(err) => write!(f, "{}", err), + Error::AssertionFailed(err) => write!(f, "Assertion failed ({})", err), } } } @@ -42,6 +46,9 @@ pub struct ZirPropagator<'ast, T> { } impl<'ast, T: Field> ZirPropagator<'ast, T> { + pub fn with_constants(constants: Constants<'ast, T>) -> Self { + Self { constants } + } pub fn propagate(p: ZirProgram) -> Result, Error> { ZirPropagator::default().fold_program(p) } @@ -50,6 +57,68 @@ impl<'ast, T: Field> ZirPropagator<'ast, T> { impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { type Error = Error; + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + match s { + ZirAssemblyStatement::Assignment(assignees, function) => { + let assignees: Vec<_> = assignees + .into_iter() + .map(|a| self.fold_assignee(a)) + .collect::>()?; + + let function = self.fold_function(function)?; + + match &function.statements.last().unwrap() { + ZirStatement::Return(values) => { + if values.iter().all(|v| v.is_constant()) { + self.constants.extend( + assignees + .into_iter() + .zip(values.iter()) + .map(|(a, v)| (a.id, v.clone())), + ); + Ok(vec![]) + } else { + assignees.iter().for_each(|a| { + self.constants.remove(&a.id); + }); + Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)]) + } + } + _ => { + assignees.iter().for_each(|a| { + self.constants.remove(&a.id); + }); + Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)]) + } + } + } + ZirAssemblyStatement::Constraint(left, right, metadata) => { + let left = self.fold_field_expression(left)?; + let right = self.fold_field_expression(right)?; + + // a bit hacky, but we use a fake boolean expression to check this + let is_equal = BooleanExpression::FieldEq(box left.clone(), box right.clone()); + let is_equal = self.fold_boolean_expression(is_equal)?; + + match is_equal { + BooleanExpression::Value(true) => Ok(vec![]), + BooleanExpression::Value(false) => { + Err(Error::AssertionFailed(RuntimeError::SourceAssertion( + metadata + .message(Some(format!("In asm block: `{} !== {}`", left, right))), + ))) + } + _ => Ok(vec![ZirAssemblyStatement::Constraint( + left, right, metadata, + )]), + } + } + } + } + fn fold_statement( &mut self, s: ZirStatement<'ast, T>, @@ -122,6 +191,19 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { self.fold_expression_list(list)?, )]) } + ZirStatement::Assembly(statements) => { + let statements: Vec<_> = statements + .into_iter() + .map(|s| self.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(); + match statements.len() { + 0 => Ok(vec![]), + _ => Ok(vec![ZirStatement::Assembly(statements)]), + } + } _ => fold_statement(self, s), } } @@ -226,6 +308,127 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { )), } } + FieldElementExpression::Xor(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; + + match (e1, e2) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + Ok(FieldElementExpression::Number( + T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(), + )) + } + (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))), + (e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)), + } + } + FieldElementExpression::And(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; + + match (e1, e2) { + (_, FieldElementExpression::Number(n)) + | (FieldElementExpression::Number(n), _) + if n == T::from(0) => + { + Ok(FieldElementExpression::Number(n)) + } + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + Ok(FieldElementExpression::Number( + T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(), + )) + } + (e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)), + } + } + FieldElementExpression::Or(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; + + match (e1, e2) { + (e, FieldElementExpression::Number(n)) + | (FieldElementExpression::Number(n), e) + if n == T::from(0) => + { + Ok(e) + } + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + Ok(FieldElementExpression::Number( + T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(), + )) + } + (e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)), + } + } + FieldElementExpression::LeftShift(box e, box by) => { + let e = self.fold_field_expression(e)?; + let by = self.fold_uint_expression(by)?; + match (e, by) { + ( + e, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by == 0 => Ok(e), + ( + _, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::Number(T::from(0))) + } + ( + FieldElementExpression::Number(n), + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) => { + let two = BigUint::from(2usize); + let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); + + Ok(FieldElementExpression::Number( + T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(), + )) + } + (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), + } + } + FieldElementExpression::RightShift(box e, box by) => { + let e = self.fold_field_expression(e)?; + let by = self.fold_uint_expression(by)?; + match (e, by) { + ( + e, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by == 0 => Ok(e), + ( + _, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::Number(T::from(0))) + } + ( + FieldElementExpression::Number(n), + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) => Ok(FieldElementExpression::Number( + T::try_from(n.to_biguint().shr(by as usize)).unwrap(), + )), + (e, by) => Ok(FieldElementExpression::RightShift(box e, box by)), + } + } e => fold_field_expression(self, e), } } @@ -587,22 +790,28 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { e: ConditionalExpression<'ast, T, E>, ) -> Result, Self::Error> { let condition = self.fold_boolean_expression(*e.condition)?; - let consequence = e.consequence.fold(self)?; - let alternative = e.alternative.fold(self)?; - match (condition, consequence, alternative) { - (_, consequence, alternative) if consequence == alternative => Ok( - ConditionalOrExpression::Expression(consequence.into_inner()), - ), - (BooleanExpression::Value(true), consequence, _) => Ok( - ConditionalOrExpression::Expression(consequence.into_inner()), - ), - (BooleanExpression::Value(false), _, alternative) => Ok( - ConditionalOrExpression::Expression(alternative.into_inner()), - ), - (condition, consequence, alternative) => Ok(ConditionalOrExpression::Conditional( - ConditionalExpression::new(condition, consequence, alternative), + match condition { + BooleanExpression::Value(true) => Ok(ConditionalOrExpression::Expression( + e.consequence.fold(self)?.into_inner(), + )), + BooleanExpression::Value(false) => Ok(ConditionalOrExpression::Expression( + e.alternative.fold(self)?.into_inner(), )), + condition => { + let consequence = e.consequence.fold(self)?; + let alternative = e.alternative.fold(self)?; + + if consequence == alternative { + Ok(ConditionalOrExpression::Expression( + consequence.into_inner(), + )) + } else { + Ok(ConditionalOrExpression::Conditional( + ConditionalExpression::new(condition, consequence, alternative), + )) + } + } } } } @@ -821,6 +1030,115 @@ mod tests { ); } + #[test] + fn left_shift() { + let mut propagator = ZirPropagator::::default(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::identifier("a".into()), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::identifier("a".into())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(8))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(1)), + box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(3)), + box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(1)), + box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + } + + #[test] + fn right_shift() { + let mut propagator = ZirPropagator::::default(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::identifier("a".into()), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::identifier("a".into())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::identifier("a".into()), + box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(3)), + box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(1))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::max_value()), + box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(1))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::max_value()), + box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + } + #[test] fn if_else() { let mut propagator = ZirPropagator::default(); diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index 209a1abdd..0e119eee8 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -16,8 +16,8 @@ use zokrates_proof_systems::Scheme; use zokrates_proof_systems::{Backend, NonUniversalBackend, Proof, SetupKeypair}; impl NonUniversalBackend for Ark { - fn setup>>( - program: ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, ) -> SetupKeypair { let computation = Computation::without_witness(program); @@ -41,8 +41,8 @@ impl NonUniversalBackend for Ark { } impl Backend for Ark { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index 617d34ede..0de188453 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -19,8 +19,8 @@ use zokrates_proof_systems::Scheme; const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications."; impl Backend for Ark { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { @@ -86,8 +86,8 @@ impl Backend for Ark { } impl NonUniversalBackend for Ark { - fn setup>>( - program: ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, ) -> SetupKeypair { println!("{}", G16_WARNING); diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index f5c3b320b..425be3a8d 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -17,20 +17,20 @@ pub use self::parse::*; pub struct Ark; #[derive(Clone)] -pub struct Computation>> { - program: ProgIterator, +pub struct Computation<'a, T, I: IntoIterator>> { + program: ProgIterator<'a, T, I>, witness: Option>, } -impl>> Computation { - pub fn with_witness(program: ProgIterator, witness: Witness) -> Self { +impl<'a, T, I: IntoIterator>> Computation<'a, T, I> { + pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness) -> Self { Computation { program, witness: Some(witness), } } - pub fn without_witness(program: ProgIterator) -> Self { + pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self { Computation { program, witness: None, @@ -72,9 +72,9 @@ fn ark_combination( .fold(LinearCombination::zero(), |acc, e| acc + e) } -impl>> +impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator>> ConstraintSynthesizer<<::ArkEngine as PairingEngine>::Fr> - for Computation + for Computation<'a, T, I> { fn generate_constraints( self, @@ -143,7 +143,9 @@ impl>> } } -impl>> Computation { +impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator>> + Computation<'a, T, I> +{ pub fn public_inputs_values(&self) -> Vec<::Fr> { self.program .public_inputs_values(self.witness.as_ref().unwrap()) diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index cc85f6a24..24204a6cd 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -134,9 +134,9 @@ impl UniversalBackend for Ark res } - fn setup>>( + fn setup<'a, I: IntoIterator>>( srs: Vec, - program: ProgIterator, + program: ProgIterator<'a, T, I>, ) -> Result, String> { let program = program.collect(); @@ -210,8 +210,8 @@ impl UniversalBackend for Ark } impl Backend for Ark { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { diff --git a/zokrates_ast/src/common/embed.rs b/zokrates_ast/src/common/embed.rs index 4133c5c8a..58294142f 100644 --- a/zokrates_ast/src/common/embed.rs +++ b/zokrates_ast/src/common/embed.rs @@ -9,6 +9,7 @@ use crate::untyped::{ types::{UnresolvedSignature, UnresolvedType}, ConstantGenericNode, Expression, }; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use zokrates_field::Field; @@ -28,8 +29,9 @@ cfg_if::cfg_if! { /// A low level function that contains non-deterministic introduction of variables. It is carried out as is until /// the flattening step when it can be inlined. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)] pub enum FlatEmbed { + FieldToBoolUnsafe, BitArrayLe, Unpack, U8ToBits, @@ -49,6 +51,9 @@ pub enum FlatEmbed { impl FlatEmbed { pub fn signature(&self) -> UnresolvedSignature { match self { + FlatEmbed::FieldToBoolUnsafe => UnresolvedSignature::new() + .inputs(vec![UnresolvedType::FieldElement.into()]) + .output(UnresolvedType::Boolean.into()), FlatEmbed::BitArrayLe => UnresolvedSignature::new() .generics(vec![ConstantGenericNode::mock("N")]) .inputs(vec![ @@ -185,6 +190,9 @@ impl FlatEmbed { pub fn typed_signature(&self) -> DeclarationSignature<'static, T> { match self { + FlatEmbed::FieldToBoolUnsafe => DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .output(DeclarationType::Boolean), FlatEmbed::BitArrayLe => DeclarationSignature::new() .generics(vec![Some(DeclarationConstant::Generic( GenericIdentifier::with_name("N").with_index(0), @@ -291,6 +299,7 @@ impl FlatEmbed { pub fn id(&self) -> &'static str { match self { + FlatEmbed::FieldToBoolUnsafe => "_FIELD_TO_BOOL_UNSAFE", FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", FlatEmbed::Unpack => "_UNPACK", FlatEmbed::U8ToBits => "_U8_TO_BITS", @@ -317,8 +326,8 @@ impl FlatEmbed { /// - constraint system variables /// - arguments #[cfg(feature = "bellman")] -pub fn sha256_round( -) -> FlatFunctionIterator>> { +pub fn sha256_round<'ast, T: Field>( +) -> FlatFunctionIterator<'ast, T, impl IntoIterator>> { use zokrates_field::Bn128Field; assert_eq!(T::id(), Bn128Field::id()); @@ -420,9 +429,9 @@ pub fn sha256_round( } #[cfg(feature = "ark")] -pub fn snark_verify_bls12_377( +pub fn snark_verify_bls12_377<'ast, T: Field>( n: usize, -) -> FlatFunctionIterator>> { +) -> FlatFunctionIterator<'ast, T, impl IntoIterator>> { use zokrates_field::Bw6_761Field; assert_eq!(T::id(), Bw6_761Field::id()); @@ -546,9 +555,9 @@ fn use_variable( /// # Remarks /// * the return value of the `FlatFunction` is not deterministic if `bit_width >= T::get_required_bits()` /// as some elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)` -pub fn unpack_to_bitwidth( +pub fn unpack_to_bitwidth<'ast, T: Field>( bit_width: usize, -) -> FlatFunctionIterator>> { +) -> FlatFunctionIterator<'ast, T, impl IntoIterator>> { let mut counter = 0; let mut layout = HashMap::new(); diff --git a/zokrates_ast/src/common/error.rs b/zokrates_ast/src/common/error.rs index 45ef04223..82a201881 100644 --- a/zokrates_ast/src/common/error.rs +++ b/zokrates_ast/src/common/error.rs @@ -1,5 +1,7 @@ +use crate::common::SourceMetadata; use serde::{Deserialize, Serialize}; use std::fmt; +use std::fmt::Write; #[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] pub enum RuntimeError { @@ -25,7 +27,8 @@ pub enum RuntimeError { Euclidean, ShaXor, Division, - SourceAssertion(String), + SourceAssertion(SourceMetadata), + SourceAssemblyConstraint(SourceMetadata), ArgumentBitness, SelectRangeCheck, } @@ -33,7 +36,9 @@ pub enum RuntimeError { impl From for RuntimeError { fn from(error: crate::zir::RuntimeError) -> Self { match error { - crate::zir::RuntimeError::SourceAssertion(s) => RuntimeError::SourceAssertion(s), + crate::zir::RuntimeError::SourceAssertion(metadata) => { + RuntimeError::SourceAssertion(metadata) + } crate::zir::RuntimeError::SelectRangeCheck => RuntimeError::SelectRangeCheck, crate::zir::RuntimeError::DivisionByZero => RuntimeError::Inverse, crate::zir::RuntimeError::IncompleteDynamicRange => { @@ -49,7 +54,8 @@ impl RuntimeError { !matches!( self, - SourceAssertion(_) + SourceAssemblyConstraint(_) + | SourceAssertion(_) | Inverse | SelectRangeCheck | ArgumentBitness @@ -62,6 +68,7 @@ impl fmt::Display for RuntimeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use RuntimeError::*; + let mut buf = String::new(); let msg = match self { BellmanConstraint => "Bellman constraint is unsatisfied", BellmanOneBinding => "Bellman ~one binding is unsatisfied", @@ -87,7 +94,14 @@ impl fmt::Display for RuntimeError { Euclidean => "Euclidean check failed", ShaXor => "Internal Sha check failed", Division => "Division check failed", - SourceAssertion(m) => m.as_str(), + SourceAssertion(m) => { + write!(&mut buf, "Assertion failed at {}", m).unwrap(); + buf.as_str() + } + SourceAssemblyConstraint(m) => { + write!(&mut buf, "Unsatisfied constraint at {}", m).unwrap(); + buf.as_str() + } ArgumentBitness => "Argument bitness check failed", SelectRangeCheck => "Out of bounds array access", }; diff --git a/zokrates_ast/src/common/metadata.rs b/zokrates_ast/src/common/metadata.rs new file mode 100644 index 000000000..efe9d235e --- /dev/null +++ b/zokrates_ast/src/common/metadata.rs @@ -0,0 +1,34 @@ +use crate::untyped::Position; +use serde::{Deserialize, Serialize}; +use std::fmt; + +#[derive(Clone, Debug, PartialEq, Hash, Eq, Default, PartialOrd, Ord, Serialize, Deserialize)] +pub struct SourceMetadata { + pub file: String, + pub position: Position, + pub message: Option, +} + +impl SourceMetadata { + pub fn new(file: String, position: Position) -> Self { + Self { + file, + position, + message: None, + } + } + pub fn message(mut self, message: Option) -> Self { + self.message = message; + self + } +} + +impl fmt::Display for SourceMetadata { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{}", self.file, self.position)?; + match &self.message { + Some(m) => write!(f, ": \"{}\"", m), + None => write!(f, ""), + } + } +} diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 95a3245fd..13d23bfdb 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -1,12 +1,14 @@ pub mod embed; mod error; mod format_string; +mod metadata; mod parameter; mod solvers; mod variable; pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; +pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; pub use self::solvers::Solver; pub use self::variable::Variable; diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index d8387f26d..9b4f5c900 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -1,8 +1,9 @@ +use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] -pub enum Solver { +pub enum Solver<'ast, T> { ConditionEq, Bits(usize), Div, @@ -11,19 +12,35 @@ pub enum Solver { ShaAndXorAndXorAnd, ShaCh, EuclideanDiv, + #[serde(borrow)] + Zir(ZirFunction<'ast, T>), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] SnarkVerifyBls12377(usize), } -impl fmt::Display for Solver { +impl<'ast, T> fmt::Display for Solver<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?}", self) + match self { + Solver::ConditionEq => write!(f, "ConditionEq"), + Solver::Bits(n) => write!(f, "Bits({})", n), + Solver::Div => write!(f, "Div"), + Solver::Xor => write!(f, "Xor"), + Solver::Or => write!(f, "Or"), + Solver::ShaAndXorAndXorAnd => write!(f, "ShaAndXorAndXorAnd"), + Solver::ShaCh => write!(f, "ShaCh"), + Solver::EuclideanDiv => write!(f, "EuclideanDiv"), + Solver::Zir(_) => write!(f, "Zir(..)"), + #[cfg(feature = "bellman")] + Solver::Sha256Round => write!(f, "Sha256Round"), + #[cfg(feature = "ark")] + Solver::SnarkVerifyBls12377(n) => write!(f, "SnarkVerifyBls12377({})", n), + } } } -impl Solver { +impl<'ast, T> Solver<'ast, T> { pub fn get_signature(&self) -> (usize, usize) { match self { Solver::ConditionEq => (1, 2), @@ -34,6 +51,7 @@ impl Solver { Solver::ShaAndXorAndXorAnd => (3, 1), Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), + Solver::Zir(f) => (f.arguments.len(), 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] @@ -42,7 +60,7 @@ impl Solver { } } -impl Solver { +impl<'ast, T> Solver<'ast, T> { pub fn bits(width: usize) -> Self { Solver::Bits(width) } diff --git a/zokrates_ast/src/flat/folder.rs b/zokrates_ast/src/flat/folder.rs index 4c9baeb00..ce50d7dac 100644 --- a/zokrates_ast/src/flat/folder.rs +++ b/zokrates_ast/src/flat/folder.rs @@ -4,8 +4,8 @@ use super::*; use crate::common::Variable; use zokrates_field::Field; -pub trait Folder: Sized { - fn fold_program(&mut self, p: FlatProg) -> FlatProg { +pub trait Folder<'ast, T: Field>: Sized { + fn fold_program(&mut self, p: FlatProg<'ast, T>) -> FlatProg<'ast, T> { fold_program(self, p) } @@ -17,7 +17,7 @@ pub trait Folder: Sized { fold_variable(self, v) } - fn fold_statement(&mut self, s: FlatStatement) -> Vec> { + fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -25,12 +25,15 @@ pub trait Folder: Sized { fold_expression(self, e) } - fn fold_directive(&mut self, d: FlatDirective) -> FlatDirective { + fn fold_directive(&mut self, d: FlatDirective<'ast, T>) -> FlatDirective<'ast, T> { fold_directive(self, d) } } -pub fn fold_program>(f: &mut F, p: FlatProg) -> FlatProg { +pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + p: FlatProg<'ast, T>, +) -> FlatProg<'ast, T> { FlatProg { arguments: p .arguments @@ -46,11 +49,17 @@ pub fn fold_program>(f: &mut F, p: FlatProg) -> FlatPr } } -pub fn fold_statement>( +pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - s: FlatStatement, -) -> Vec> { + s: FlatStatement<'ast, T>, +) -> Vec> { match s { + FlatStatement::Block(statements) => vec![FlatStatement::Block( + statements + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + )], FlatStatement::Condition(left, right, error) => vec![FlatStatement::Condition( f.fold_expression(left), f.fold_expression(right), @@ -70,7 +79,7 @@ pub fn fold_statement>( } } -pub fn fold_expression>( +pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: FlatExpression, ) -> FlatExpression { @@ -89,7 +98,10 @@ pub fn fold_expression>( } } -pub fn fold_directive>(f: &mut F, ds: FlatDirective) -> FlatDirective { +pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ds: FlatDirective<'ast, T>, +) -> FlatDirective<'ast, T> { FlatDirective { inputs: ds .inputs @@ -101,13 +113,13 @@ pub fn fold_directive>(f: &mut F, ds: FlatDirective) - } } -pub fn fold_argument>(f: &mut F, a: Parameter) -> Parameter { +pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), private: a.private, } } -pub fn fold_variable>(_f: &mut F, v: Variable) -> Variable { +pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable { v } diff --git a/zokrates_ast/src/flat/mod.rs b/zokrates_ast/src/flat/mod.rs index 903fff2c4..015e670a5 100644 --- a/zokrates_ast/src/flat/mod.rs +++ b/zokrates_ast/src/flat/mod.rs @@ -24,14 +24,14 @@ use std::collections::HashMap; use std::fmt; use zokrates_field::Field; -pub type FlatProg = FlatFunction; +pub type FlatProg<'ast, T> = FlatFunction<'ast, T>; -pub type FlatFunction = FlatFunctionIterator>>; +pub type FlatFunction<'ast, T> = FlatFunctionIterator<'ast, T, Vec>>; -pub type FlatProgIterator = FlatFunctionIterator; +pub type FlatProgIterator<'ast, T, I> = FlatFunctionIterator<'ast, T, I>; #[derive(Clone, PartialEq, Eq, Debug)] -pub struct FlatFunctionIterator>> { +pub struct FlatFunctionIterator<'ast, T, I: IntoIterator>> { /// Arguments of the function pub arguments: Vec, /// Vector of statements that are executed when running the function @@ -40,8 +40,8 @@ pub struct FlatFunctionIterator>> { pub return_count: usize, } -impl>> FlatFunctionIterator { - pub fn collect(self) -> FlatFunction { +impl<'ast, T, I: IntoIterator>> FlatFunctionIterator<'ast, T, I> { + pub fn collect(self) -> FlatFunction<'ast, T> { FlatFunction { statements: self.statements.into_iter().collect(), arguments: self.arguments, @@ -50,7 +50,7 @@ impl>> FlatFunctionIterator { } } -impl fmt::Display for FlatFunction { +impl<'ast, T: Field> fmt::Display for FlatFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -81,16 +81,24 @@ impl fmt::Display for FlatFunction { /// * r1cs - R1CS in standard JSON data format #[derive(Clone, PartialEq, Eq, Debug)] -pub enum FlatStatement { +pub enum FlatStatement<'ast, T> { + Block(Vec>), Condition(FlatExpression, FlatExpression, RuntimeError), Definition(Variable, FlatExpression), - Directive(FlatDirective), + Directive(FlatDirective<'ast, T>), Log(FormatString, Vec<(ConcreteType, Vec>)>), } -impl fmt::Display for FlatStatement { +impl<'ast, T: Field> fmt::Display for FlatStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { + FlatStatement::Block(ref statements) => { + writeln!(f, "{{")?; + for s in statements { + writeln!(f, "{}", s)?; + } + writeln!(f, "}}") + } FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), FlatStatement::Condition(ref lhs, ref rhs, ref message) => { write!(f, "{} == {} // {}", lhs, rhs, message) @@ -116,12 +124,18 @@ impl fmt::Display for FlatStatement { } } -impl FlatStatement { +impl<'ast, T: Field> FlatStatement<'ast, T> { pub fn apply_substitution( self, - substitution: &HashMap, + substitution: &'ast HashMap, ) -> FlatStatement { match self { + FlatStatement::Block(statements) => FlatStatement::Block( + statements + .into_iter() + .map(|s| s.apply_substitution(substitution)) + .collect(), + ), FlatStatement::Definition(id, x) => FlatStatement::Definition( *id.apply_substitution(substitution), x.apply_substitution(substitution), @@ -167,16 +181,16 @@ impl FlatStatement { } #[derive(Clone, Hash, Debug, PartialEq, Eq)] -pub struct FlatDirective { +pub struct FlatDirective<'ast, T> { pub inputs: Vec>, pub outputs: Vec, - pub solver: Solver, + pub solver: Solver<'ast, T>, } -impl FlatDirective { +impl<'ast, T> FlatDirective<'ast, T> { pub fn new>>( outputs: Vec, - solver: Solver, + solver: Solver<'ast, T>, inputs: Vec, ) -> Self { let (in_len, out_len) = solver.get_signature(); @@ -190,7 +204,7 @@ impl FlatDirective { } } -impl fmt::Display for FlatDirective { +impl<'ast, T: Field> fmt::Display for FlatDirective<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, diff --git a/zokrates_ast/src/ir/check.rs b/zokrates_ast/src/ir/check.rs index 11c5fd841..41cac7b0d 100644 --- a/zokrates_ast/src/ir/check.rs +++ b/zokrates_ast/src/ir/check.rs @@ -13,7 +13,9 @@ pub struct UnconstrainedVariableDetector { } impl UnconstrainedVariableDetector { - pub fn new>>(p: &ProgIterator) -> Self { + pub fn new<'ast, T: Field, I: IntoIterator>>( + p: &ProgIterator<'ast, T, I>, + ) -> Self { UnconstrainedVariableDetector { variables: p .arguments @@ -32,7 +34,7 @@ impl UnconstrainedVariableDetector { } } -impl Folder for UnconstrainedVariableDetector { +impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector { fn fold_argument(&mut self, p: Parameter) -> Parameter { p } @@ -40,7 +42,7 @@ impl Folder for UnconstrainedVariableDetector { self.variables.remove(&v); v } - fn fold_directive(&mut self, d: Directive) -> Directive { + fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { self.variables.extend(d.outputs.iter()); d } diff --git a/zokrates_ast/src/ir/clean.rs b/zokrates_ast/src/ir/clean.rs new file mode 100644 index 000000000..b4fb8f445 --- /dev/null +++ b/zokrates_ast/src/ir/clean.rs @@ -0,0 +1,31 @@ +use super::folder::Folder; +use super::{ProgIterator, Statement}; +use zokrates_field::Field; + +#[derive(Default)] +pub struct Cleaner; + +impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { + pub fn clean(self) -> ProgIterator<'ast, T, impl IntoIterator>> { + ProgIterator { + arguments: self.arguments, + return_count: self.return_count, + statements: self + .statements + .into_iter() + .flat_map(|s| Cleaner::default().fold_statement(s)), + } + } +} + +impl<'ast, T: Field> Folder<'ast, T> for Cleaner { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { + match s { + Statement::Block(statements) => statements + .into_iter() + .flat_map(|s| self.fold_statement(s)) + .collect(), + s => vec![s], + } + } +} diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 22245252b..6e67c15de 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -4,8 +4,8 @@ use super::*; use crate::common::Variable; use zokrates_field::Field; -pub trait Folder: Sized { - fn fold_program(&mut self, p: Prog) -> Prog { +pub trait Folder<'ast, T: Field>: Sized { + fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> { fold_program(self, p) } @@ -17,7 +17,7 @@ pub trait Folder: Sized { fold_variable(self, v) } - fn fold_statement(&mut self, s: Statement) -> Vec> { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -29,12 +29,15 @@ pub trait Folder: Sized { fold_quadratic_combination(self, es) } - fn fold_directive(&mut self, d: Directive) -> Directive { + fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { fold_directive(self, d) } } -pub fn fold_program>(f: &mut F, p: Prog) -> Prog { +pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + p: Prog<'ast, T>, +) -> Prog<'ast, T> { Prog { arguments: p .arguments @@ -50,8 +53,17 @@ pub fn fold_program>(f: &mut F, p: Prog) -> Prog { } } -pub fn fold_statement>(f: &mut F, s: Statement) -> Vec> { +pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: Statement<'ast, T>, +) -> Vec> { match s { + Statement::Block(statements) => vec![Statement::Block( + statements + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + )], Statement::Constraint(quad, lin, message) => vec![Statement::Constraint( f.fold_quadratic_combination(quad), f.fold_linear_combination(lin), @@ -74,7 +86,10 @@ pub fn fold_statement>(f: &mut F, s: Statement) -> Vec } } -pub fn fold_linear_combination>(f: &mut F, e: LinComb) -> LinComb { +pub fn fold_linear_combination<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: LinComb, +) -> LinComb { LinComb( e.0.into_iter() .map(|(variable, coefficient)| (f.fold_variable(variable), coefficient)) @@ -82,7 +97,7 @@ pub fn fold_linear_combination>(f: &mut F, e: LinComb) ) } -pub fn fold_quadratic_combination>( +pub fn fold_quadratic_combination<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: QuadComb, ) -> QuadComb { @@ -92,7 +107,10 @@ pub fn fold_quadratic_combination>( } } -pub fn fold_directive>(f: &mut F, ds: Directive) -> Directive { +pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ds: Directive<'ast, T>, +) -> Directive<'ast, T> { Directive { inputs: ds .inputs @@ -104,13 +122,13 @@ pub fn fold_directive>(f: &mut F, ds: Directive) -> Di } } -pub fn fold_argument>(f: &mut F, a: Parameter) -> Parameter { +pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), private: a.private, } } -pub fn fold_variable>(_f: &mut F, v: Variable) -> Variable { +pub fn fold_variable<'ast, T: Field, F: Folder<'ast, T>>(_f: &mut F, v: Variable) -> Variable { v } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index e5e3cb733..fc961cd85 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -17,9 +17,9 @@ impl QuadComb { } } -pub fn from_flat>>( - flat_prog_iterator: FlatProgIterator, -) -> ProgIterator>> { +pub fn from_flat<'ast, T: Field, I: IntoIterator>>( + flat_prog_iterator: FlatProgIterator<'ast, T, I>, +) -> ProgIterator>> { ProgIterator { statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, @@ -52,9 +52,12 @@ impl From> for LinComb { } } -impl From> for Statement { - fn from(flat_statement: FlatStatement) -> Statement { +impl<'ast, T: Field> From> for Statement<'ast, T> { + fn from(flat_statement: FlatStatement<'ast, T>) -> Statement<'ast, T> { match flat_statement { + FlatStatement::Block(statements) => { + Statement::Block(statements.into_iter().map(Statement::from).collect()) + } FlatStatement::Condition(linear, quadratic, message) => match quadratic { FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( QuadComb::from_linear_combinations(lhs.into(), rhs.into()), @@ -83,8 +86,8 @@ impl From> for Statement { } } -impl From> for Directive { - fn from(ds: FlatDirective) -> Directive { +impl<'ast, T: Field> From> for Directive<'ast, T> { + fn from(ds: FlatDirective<'ast, T>) -> Directive { Directive { inputs: ds .inputs diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 36d01d758..78b48f808 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -8,6 +8,7 @@ use std::hash::Hash; use zokrates_field::Field; mod check; +mod clean; mod expression; pub mod folder; pub mod from_flat; @@ -28,19 +29,22 @@ pub use self::witness::Witness; #[derive(Debug, Serialize, Deserialize, Clone, Derivative)] #[derivative(Hash, PartialEq, Eq)] -pub enum Statement { +pub enum Statement<'ast, T> { + #[serde(skip)] + Block(Vec>), Constraint( QuadComb, LinComb, #[derivative(Hash = "ignore")] Option, ), - Directive(Directive), + #[serde(borrow)] + Directive(Directive<'ast, T>), Log(FormatString, Vec<(ConcreteType, Vec>)>), } pub type PublicInputs = BTreeSet; -impl Statement { +impl<'ast, T: Field> Statement<'ast, T> { pub fn definition>>(v: Variable, e: U) -> Self { Statement::Constraint(e.into(), v.into(), None) } @@ -51,13 +55,14 @@ impl Statement { } #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub struct Directive { +pub struct Directive<'ast, T> { pub inputs: Vec>, pub outputs: Vec, - pub solver: Solver, + #[serde(borrow)] + pub solver: Solver<'ast, T>, } -impl fmt::Display for Directive { +impl<'ast, T: Field> fmt::Display for Directive<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -77,9 +82,16 @@ impl fmt::Display for Directive { } } -impl fmt::Display for Statement { +impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { + Statement::Block(ref statements) => { + writeln!(f, "{{")?; + for s in statements { + writeln!(f, "{}", s)?; + } + write!(f, "}}") + } Statement::Constraint(ref quad, ref lin, ref error) => write!( f, "{} == {}{}", @@ -111,16 +123,16 @@ impl fmt::Display for Statement { } } -pub type Prog = ProgIterator>>; +pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] -pub struct ProgIterator>> { +pub struct ProgIterator<'ast, T, I: IntoIterator>> { pub arguments: Vec, pub return_count: usize, pub statements: I, } -impl>> ProgIterator { +impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { pub fn new(arguments: Vec, statements: I, return_count: usize) -> Self { Self { arguments, @@ -129,7 +141,7 @@ impl>> ProgIterator { } } - pub fn collect(self) -> ProgIterator>> { + pub fn collect(self) -> ProgIterator<'ast, T, Vec>> { ProgIterator { statements: self.statements.into_iter().collect::>(), arguments: self.arguments, @@ -154,7 +166,7 @@ impl>> ProgIterator { } } -impl>> ProgIterator { +impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { pub fn public_inputs_values(&self, witness: &Witness) -> Vec { self.arguments .iter() @@ -165,7 +177,7 @@ impl>> ProgIterator { } } -impl Prog { +impl<'ast, T> Prog<'ast, T> { pub fn constraint_count(&self) -> usize { self.statements .iter() @@ -173,7 +185,9 @@ impl Prog { .count() } - pub fn into_prog_iter(self) -> ProgIterator>> { + pub fn into_prog_iter( + self, + ) -> ProgIterator<'ast, T, impl IntoIterator>> { ProgIterator { statements: self.statements.into_iter(), arguments: self.arguments, @@ -182,7 +196,7 @@ impl Prog { } } -impl fmt::Display for Prog { +impl<'ast, T: Field> fmt::Display for Prog<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let returns = (0..self.return_count) .map(Variable::public) diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 39c746368..09d003900 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -12,32 +12,35 @@ const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< - Bls12_381I: IntoIterator>, - Bn128I: IntoIterator>, - Bls12_377I: IntoIterator>, - Bw6_761I: IntoIterator>, + 'ast, + Bls12_381I: IntoIterator>, + Bn128I: IntoIterator>, + Bls12_377I: IntoIterator>, + Bw6_761I: IntoIterator>, > { - Bls12_381Program(ProgIterator), - Bn128Program(ProgIterator), - Bls12_377Program(ProgIterator), - Bw6_761Program(ProgIterator), + Bls12_381Program(ProgIterator<'ast, Bls12_381Field, Bls12_381I>), + Bn128Program(ProgIterator<'ast, Bn128Field, Bn128I>), + Bls12_377Program(ProgIterator<'ast, Bls12_377Field, Bls12_377I>), + Bw6_761Program(ProgIterator<'ast, Bw6_761Field, Bw6_761I>), } -type MemoryProgEnum = ProgEnum< - Vec>, - Vec>, - Vec>, - Vec>, +type MemoryProgEnum<'ast> = ProgEnum< + 'ast, + Vec>, + Vec>, + Vec>, + Vec>, >; impl< - Bls12_381I: IntoIterator>, - Bn128I: IntoIterator>, - Bls12_377I: IntoIterator>, - Bw6_761I: IntoIterator>, - > ProgEnum + 'ast, + Bls12_381I: IntoIterator>, + Bn128I: IntoIterator>, + Bls12_377I: IntoIterator>, + Bw6_761I: IntoIterator>, + > ProgEnum<'ast, Bls12_381I, Bn128I, Bls12_377I, Bw6_761I> { - pub fn collect(self) -> MemoryProgEnum { + pub fn collect(self) -> MemoryProgEnum<'ast> { match self { ProgEnum::Bls12_381Program(p) => ProgEnum::Bls12_381Program(p.collect()), ProgEnum::Bn128Program(p) => ProgEnum::Bn128Program(p.collect()), @@ -55,7 +58,7 @@ impl< } } -impl>> ProgIterator { +impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives pub fn serialize(self, mut w: W) -> Result { @@ -106,10 +109,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator impl<'de, R: Read> ProgEnum< - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, - UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement>, + 'de, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_381Field>>, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bn128Field>>, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_377Field>>, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bw6_761Field>>, > { pub fn deserialize(mut r: R) -> Result { diff --git a/zokrates_ast/src/ir/smtlib2.rs b/zokrates_ast/src/ir/smtlib2.rs index 8bdd04d37..bc1188518 100644 --- a/zokrates_ast/src/ir/smtlib2.rs +++ b/zokrates_ast/src/ir/smtlib2.rs @@ -12,9 +12,9 @@ pub trait SMTLib2 { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result; } -pub struct SMTLib2Display<'a, T>(pub &'a Prog); +pub struct SMTLib2Display<'a, 'ast, T>(pub &'a Prog<'ast, T>); -impl fmt::Display for SMTLib2Display<'_, T> { +impl<'ast, T: Field> fmt::Display for SMTLib2Display<'_, 'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.0.to_smtlib2(f) } @@ -30,7 +30,7 @@ impl Visitor for VariableCollector { } } -impl SMTLib2 for Prog { +impl<'ast, T: Field> SMTLib2 for Prog<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut collector = VariableCollector { variables: BTreeSet::::new(), @@ -75,9 +75,10 @@ fn format_prefix_op_smtlib2( write!(f, ")") } -impl SMTLib2 for Statement { +impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { + Statement::Block(..) => unreachable!(), Statement::Constraint(ref quad, ref lin, _) => { write!(f, "(= (mod ")?; quad.to_smtlib2(f)?; @@ -91,7 +92,7 @@ impl SMTLib2 for Statement { } } -impl SMTLib2 for Directive { +impl<'ast, T: Field> SMTLib2 for Directive<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "") } diff --git a/zokrates_ast/src/ir/visitor.rs b/zokrates_ast/src/ir/visitor.rs index 6d8ffc00f..d3894ca6b 100644 --- a/zokrates_ast/src/ir/visitor.rs +++ b/zokrates_ast/src/ir/visitor.rs @@ -53,6 +53,11 @@ pub fn visit_module>(f: &mut F, p: &Prog) { pub fn visit_statement>(f: &mut F, s: &Statement) { match s { + Statement::Block(statements) => { + for s in statements { + f.visit_statement(s); + } + } Statement::Constraint(quad, lin, error) => { f.visit_quadratic_combination(quad); f.visit_linear_combination(lin); diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index e687d33d1..d3e87fcd0 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -260,6 +260,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Vec> { + fold_assembly_statement(self, s) + } + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -515,6 +522,27 @@ pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> Vec> { + match s { + TypedAssemblyStatement::Assignment(a, e) => { + vec![TypedAssemblyStatement::Assignment( + f.fold_assignee(a), + f.fold_expression(e), + )] + } + TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { + vec![TypedAssemblyStatement::Constraint( + f.fold_field_expression(lhs), + f.fold_field_expression(rhs), + metadata, + )] + } + } +} + pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, @@ -539,6 +567,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( TypedStatement::Log(s, e) => { TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect()) } + TypedStatement::Assembly(statements) => TypedStatement::Assembly( + statements + .into_iter() + .flat_map(|s| f.fold_assembly_statement(s)) + .collect(), + ), s => s, }; vec![res] @@ -761,6 +795,36 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( Pos(box e) } + And(box left, box right) => { + let left = f.fold_field_expression(left); + let right = f.fold_field_expression(right); + + And(box left, box right) + } + Or(box left, box right) => { + let left = f.fold_field_expression(left); + let right = f.fold_field_expression(right); + + Or(box left, box right) + } + Xor(box left, box right) => { + let left = f.fold_field_expression(left); + let right = f.fold_field_expression(right); + + Xor(box left, box right) + } + LeftShift(box e, box by) => { + let e = f.fold_field_expression(e); + let by = f.fold_uint_expression(by); + + LeftShift(box e, box by) + } + RightShift(box e, box by) => { + let e = f.fold_field_expression(e); + let by = f.fold_uint_expression(by); + + RightShift(box e, box by) + } Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c) { ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, diff --git a/zokrates_ast/src/typed/identifier.rs b/zokrates_ast/src/typed/identifier.rs index abcd2f400..772eb2bf2 100644 --- a/zokrates_ast/src/typed/identifier.rs +++ b/zokrates_ast/src/typed/identifier.rs @@ -1,10 +1,12 @@ use crate::typed::CanonicalConstantIdentifier; +use serde::{Deserialize, Serialize}; use std::fmt; -pub type SourceIdentifier<'ast> = &'ast str; +pub type SourceIdentifier<'ast> = std::borrow::Cow<'ast, str>; -#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum CoreIdentifier<'ast> { + #[serde(borrow)] Source(ShadowedIdentifier<'ast>), Call(usize), Constant(CanonicalConstantIdentifier<'ast>), @@ -29,16 +31,18 @@ impl<'ast> From> for CoreIdentifier<'ast> { } /// A identifier for a variable -#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Identifier<'ast> { /// the id of the variable + #[serde(borrow)] pub id: CoreIdentifier<'ast>, /// the version of the variable, used after SSA transformation pub version: usize, } -#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct ShadowedIdentifier<'ast> { + #[serde(borrow)] pub id: SourceIdentifier<'ast>, pub shadow: usize, } @@ -97,7 +101,7 @@ impl<'ast> Identifier<'ast> { // these two From implementations are only used in tests but somehow cfg(test) doesn't work impl<'ast> From<&'ast str> for CoreIdentifier<'ast> { fn from(s: &str) -> CoreIdentifier { - CoreIdentifier::Source(ShadowedIdentifier::shadow(s, 0)) + CoreIdentifier::Source(ShadowedIdentifier::shadow(std::borrow::Cow::Borrowed(s), 0)) } } diff --git a/zokrates_ast/src/typed/integer.rs b/zokrates_ast/src/typed/integer.rs index 4af5de121..507ce9174 100644 --- a/zokrates_ast/src/typed/integer.rs +++ b/zokrates_ast/src/typed/integer.rs @@ -446,6 +446,24 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { box Self::try_from_int(e1)?, box Self::try_from_int(e2)?, )), + IntExpression::And(box e1, box e2) => Ok(Self::And( + box Self::try_from_int(e1)?, + box Self::try_from_int(e2)?, + )), + IntExpression::Or(box e1, box e2) => Ok(Self::Or( + box Self::try_from_int(e1)?, + box Self::try_from_int(e2)?, + )), + IntExpression::Xor(box e1, box e2) => Ok(Self::Xor( + box Self::try_from_int(e1)?, + box Self::try_from_int(e2)?, + )), + IntExpression::LeftShift(box e1, box e2) => { + Ok(Self::LeftShift(box Self::try_from_int(e1)?, box e2)) + } + IntExpression::RightShift(box e1, box e2) => { + Ok(Self::RightShift(box Self::try_from_int(e1)?, box e2)) + } IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)), IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)), IntExpression::Conditional(c) => Ok(Self::Conditional(ConditionalExpression::new( @@ -843,11 +861,6 @@ mod tests { let should_error = vec![ BigUint::parse_bytes(b"99999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10).unwrap().into(), - IntExpression::xor(n.clone(), n.clone()), - IntExpression::or(n.clone(), n.clone()), - IntExpression::and(n.clone(), n.clone()), - IntExpression::left_shift(n.clone(), i.clone()), - IntExpression::right_shift(n.clone(), i.clone()), IntExpression::not(n.clone()), ]; diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index 9d89718e1..003e43c47 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -27,9 +27,7 @@ pub use self::types::{ UBitwidth, }; use self::types::{ConcreteArrayType, ConcreteStructType}; - use crate::typed::types::{ConcreteGenericsAssignment, IntoType}; -use crate::untyped::Position; pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable}; use std::marker::PhantomData; @@ -38,7 +36,7 @@ use std::path::{Path, PathBuf}; pub use crate::typed::integer::IntExpression; pub use crate::typed::uint::{bitwidth, UExpression, UExpressionInner, UMetadata}; -use crate::common::{FlatEmbed, FormatString}; +use crate::common::{FlatEmbed, FormatString, SourceMetadata}; use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; @@ -569,26 +567,9 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedAssignee<'ast, T> { } } -#[derive(Clone, Debug, PartialEq, Hash, Eq, Default, PartialOrd, Ord)] -pub struct AssertionMetadata { - pub file: String, - pub position: Position, - pub message: Option, -} - -impl fmt::Display for AssertionMetadata { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Assertion failed at {}:{}", self.file, self.position)?; - match &self.message { - Some(m) => write!(f, ": \"{}\"", m), - None => write!(f, ""), - } - } -} - #[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] pub enum RuntimeError { - SourceAssertion(AssertionMetadata), + SourceAssertion(SourceMetadata), SelectRangeCheck, DivisionByZero, } @@ -677,6 +658,29 @@ impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> { } } +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub enum TypedAssemblyStatement<'ast, T> { + Assignment(TypedAssignee<'ast, T>, TypedExpression<'ast, T>), + Constraint( + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + SourceMetadata, + ), +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedAssemblyStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + TypedAssemblyStatement::Assignment(ref lhs, ref rhs) => { + write!(f, "{} <-- {};", lhs, rhs) + } + TypedAssemblyStatement::Constraint(ref lhs, ref rhs, _) => { + write!(f, "{} === {};", lhs, rhs) + } + } + } +} + /// A statement in a `TypedFunction` #[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] @@ -697,6 +701,7 @@ pub enum TypedStatement<'ast, T> { ConcreteGenericsAssignment<'ast>, ), PopCallLog, + Assembly(Vec>), } impl<'ast, T> TypedStatement<'ast, T> { @@ -721,6 +726,14 @@ impl<'ast, T: fmt::Display> TypedStatement<'ast, T> { } write!(f, "{}}}", "\t".repeat(depth)) } + TypedStatement::Assembly(statements) => { + write!(f, "{}", "\t".repeat(depth))?; + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; + } + write!(f, "{}}}", "\t".repeat(depth)) + } s => write!(f, "{}{}", "\t".repeat(depth), s), } } @@ -768,6 +781,13 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { generics, ), TypedStatement::PopCallLog => write!(f, "// POP CALL",), + TypedStatement::Assembly(ref statements) => { + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "\t\t{}", s)?; + } + write!(f, "\t}}") + } } } } @@ -1188,6 +1208,26 @@ pub enum FieldElementExpression<'ast, T> { Box>, Box>, ), + And( + Box>, + Box>, + ), + Or( + Box>, + Box>, + ), + Xor( + Box>, + Box>, + ), + LeftShift( + Box>, + Box>, + ), + RightShift( + Box>, + Box>, + ), Conditional(ConditionalExpression<'ast, T, Self>), Neg(Box>), Pos(Box>), @@ -1196,6 +1236,73 @@ pub enum FieldElementExpression<'ast, T> { Select(SelectExpression<'ast, T, Self>), Element(ElementExpression<'ast, T, Self>), } + +impl<'ast, T: Field> From> for TupleExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => { + let inner = TupleExpression::identifier(v.id); + match v._type { + GType::Tuple(tuple_ty) => inner.annotate(tuple_ty), + _ => unreachable!(), + } + } + TypedAssignee::Select(box a, box index) => TupleExpression::select(a.into(), index), + TypedAssignee::Member(box a, id) => TupleExpression::member(a.into(), id), + TypedAssignee::Element(box a, index) => TupleExpression::element(a.into(), index), + } + } +} + +impl<'ast, T: Field> From> for StructExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => { + let inner = StructExpression::identifier(v.id); + match v._type { + GType::Struct(struct_ty) => inner.annotate(struct_ty), + _ => unreachable!(), + } + } + TypedAssignee::Select(box a, box index) => StructExpression::select(a.into(), index), + TypedAssignee::Member(box a, id) => StructExpression::member(a.into(), id), + TypedAssignee::Element(box a, index) => StructExpression::element(a.into(), index), + } + } +} + +impl<'ast, T: Field> From> for ArrayExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => { + let inner = ArrayExpression::identifier(v.id); + match v._type { + GType::Array(array_ty) => inner.annotate(*array_ty.ty, *array_ty.size), + _ => unreachable!(), + } + } + TypedAssignee::Select(box a, box index) => ArrayExpression::select(a.into(), index), + TypedAssignee::Member(box a, id) => ArrayExpression::member(a.into(), id), + TypedAssignee::Element(box a, index) => ArrayExpression::element(a.into(), index), + } + } +} + +impl<'ast, T: Field> From> for FieldElementExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => FieldElementExpression::identifier(v.id), + TypedAssignee::Element(box a, index) => { + FieldElementExpression::element(a.into(), index) + } + TypedAssignee::Member(box a, id) => FieldElementExpression::member(a.into(), id), + TypedAssignee::Select(box a, box index) => { + FieldElementExpression::select(a.into(), index) + } + } + } +} + impl<'ast, T> Add for FieldElementExpression<'ast, T> { type Output = Self; @@ -1676,6 +1783,11 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e), FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e), + FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), + FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), + FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), + FieldElementExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), + FieldElementExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), FieldElementExpression::Conditional(ref c) => write!(f, "{}", c), FieldElementExpression::FunctionCall(ref function_call) => { write!(f, "{}", function_call) diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index 16f3f68cb..e4146c504 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -386,6 +386,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_statement(self, s) + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, @@ -516,6 +523,27 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } } +pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(match s { + TypedAssemblyStatement::Assignment(a, e) => { + vec![TypedAssemblyStatement::Assignment( + f.fold_assignee(a)?, + f.fold_expression(e)?, + )] + } + TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { + vec![TypedAssemblyStatement::Constraint( + f.fold_field_expression(lhs)?, + f.fold_field_expression(rhs)?, + metadata, + )] + } + }) +} + pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, @@ -546,6 +574,15 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( .map(|e| f.fold_expression(e)) .collect::, _>>()?, ), + TypedStatement::Assembly(statements) => TypedStatement::Assembly( + statements + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ), s => s, }; Ok(vec![res]) @@ -780,6 +817,36 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( Pos(box e) } + And(box left, box right) => { + let left = f.fold_field_expression(left)?; + let right = f.fold_field_expression(right)?; + + And(box left, box right) + } + Or(box left, box right) => { + let left = f.fold_field_expression(left)?; + let right = f.fold_field_expression(right)?; + + Or(box left, box right) + } + Xor(box left, box right) => { + let left = f.fold_field_expression(left)?; + let right = f.fold_field_expression(right)?; + + Xor(box left, box right) + } + LeftShift(box e, box by) => { + let e = f.fold_field_expression(e)?; + let by = f.fold_uint_expression(by)?; + + LeftShift(box e, box by) + } + RightShift(box e, box by) => { + let e = f.fold_field_expression(e)?; + let by = f.fold_uint_expression(by)?; + + RightShift(box e, box by) + } Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c)? { ConditionalOrExpression::Conditional(c) => Conditional(c), ConditionalOrExpression::Expression(u) => u, diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index f983ed4f2..a453fef3f 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -52,7 +52,10 @@ pub struct GenericIdentifier<'ast> { impl<'ast> From> for CoreIdentifier<'ast> { fn from(g: GenericIdentifier<'ast>) -> CoreIdentifier<'ast> { // generic identifiers are always declared in the function scope, which is shadow 0 - CoreIdentifier::Source(ShadowedIdentifier::shadow(g.name(), 0)) + CoreIdentifier::Source(ShadowedIdentifier::shadow( + std::borrow::Cow::Borrowed(g.name()), + 0, + )) } } @@ -120,9 +123,10 @@ pub struct SpecializationError; pub type ConstantIdentifier<'ast> = &'ast str; -#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] +#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub struct CanonicalConstantIdentifier<'ast> { pub module: OwnedTypedModuleId, + #[serde(borrow)] pub id: ConstantIdentifier<'ast>, } diff --git a/zokrates_ast/src/untyped/from_ast.rs b/zokrates_ast/src/untyped/from_ast.rs index 9dce23390..88c12d6ce 100644 --- a/zokrates_ast/src/untyped/from_ast.rs +++ b/zokrates_ast/src/untyped/from_ast.rs @@ -277,6 +277,7 @@ impl<'ast> From> for untyped::StatementNode<'ast> { pest::Statement::Assertion(s) => untyped::StatementNode::from(s), pest::Statement::Return(s) => untyped::StatementNode::from(s), pest::Statement::Log(s) => untyped::StatementNode::from(s), + pest::Statement::Assembly(s) => untyped::StatementNode::from(s), } } } @@ -340,6 +341,32 @@ impl<'ast> From> for untyped::StatementNode<'ast> } } +impl<'ast> From> for untyped::StatementNode<'ast> { + fn from(statement: pest::AssemblyStatement<'ast>) -> untyped::StatementNode<'ast> { + use crate::untyped::NodeValue; + + let statements = statement + .inner + .into_iter() + .map(|s| match s { + pest::AssemblyStatementInner::Assignment(a) => { + untyped::AssemblyStatement::Assignment( + a.assignee.into(), + a.expression.into(), + matches!(a.operator, pest::AssignmentOperator::AssignConstrain(_)), + ) + .span(a.span) + } + pest::AssemblyStatementInner::Constraint(c) => { + untyped::AssemblyStatement::Constraint(c.lhs.into(), c.rhs.into()).span(c.span) + } + }) + .collect(); + + untyped::Statement::Assembly(statements).span(statement.span) + } +} + impl<'ast> From> for untyped::ExpressionNode<'ast> { fn from(expression: pest::Expression<'ast>) -> untyped::ExpressionNode<'ast> { match expression { diff --git a/zokrates_ast/src/untyped/mod.rs b/zokrates_ast/src/untyped/mod.rs index 8fabc2ec5..07541edec 100644 --- a/zokrates_ast/src/untyped/mod.rs +++ b/zokrates_ast/src/untyped/mod.rs @@ -382,6 +382,33 @@ impl<'ast> fmt::Display for Assignee<'ast> { } } +#[derive(Debug, Clone, PartialEq)] +pub enum AssemblyStatement<'ast> { + Assignment(AssigneeNode<'ast>, ExpressionNode<'ast>, bool), + Constraint(ExpressionNode<'ast>, ExpressionNode<'ast>), +} + +pub type AssemblyStatementNode<'ast> = Node>; + +impl<'ast> fmt::Display for AssemblyStatement<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + AssemblyStatement::Assignment(ref lhs, ref rhs, ref constrained) => { + write!( + f, + "{} <{} {}", + lhs, + if *constrained { "==" } else { "--" }, + rhs + ) + } + AssemblyStatement::Constraint(ref lhs, ref rhs) => { + write!(f, "{} === {}", lhs, rhs) + } + } + } +} + /// A statement in a `Function` #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq)] @@ -397,6 +424,7 @@ pub enum Statement<'ast> { Vec>, ), Log(&'ast str, Vec>), + Assembly(Vec>), } pub type StatementNode<'ast> = Node>; @@ -431,7 +459,7 @@ impl<'ast> fmt::Display for Statement<'ast> { } Statement::Log(ref l, ref expressions) => write!( f, - "log({}, {})", + "log({}, {});", l, expressions .iter() @@ -439,6 +467,13 @@ impl<'ast> fmt::Display for Statement<'ast> { .collect::>() .join(", ") ), + Statement::Assembly(ref statements) => { + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "\t\t{};", s)?; + } + write!(f, "\t}}") + } } } } diff --git a/zokrates_ast/src/untyped/node.rs b/zokrates_ast/src/untyped/node.rs index 44cda49f9..62ef299df 100644 --- a/zokrates_ast/src/untyped/node.rs +++ b/zokrates_ast/src/untyped/node.rs @@ -84,6 +84,7 @@ use super::*; impl<'ast> NodeValue for Expression<'ast> {} impl<'ast> NodeValue for Assignee<'ast> {} impl<'ast> NodeValue for Statement<'ast> {} +impl<'ast> NodeValue for AssemblyStatement<'ast> {} impl<'ast> NodeValue for SymbolDeclaration<'ast> {} impl<'ast> NodeValue for UnresolvedType<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {} diff --git a/zokrates_ast/src/untyped/position.rs b/zokrates_ast/src/untyped/position.rs index 12394209a..786055551 100644 --- a/zokrates_ast/src/untyped/position.rs +++ b/zokrates_ast/src/untyped/position.rs @@ -1,6 +1,7 @@ +use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord)] +#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Serialize, Deserialize)] pub struct Position { pub line: usize, pub col: usize, diff --git a/zokrates_ast/src/zir/folder.rs b/zokrates_ast/src/zir/folder.rs index 917f8131d..770eb0f03 100644 --- a/zokrates_ast/src/zir/folder.rs +++ b/zokrates_ast/src/zir/folder.rs @@ -56,6 +56,13 @@ pub trait Folder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Vec> { + fold_assembly_statement(self, s) + } + fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { fold_statement(self, s) } @@ -135,6 +142,24 @@ pub trait Folder<'ast, T: Field>: Sized { } } +pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> Vec> { + match s { + ZirAssemblyStatement::Assignment(assignees, function) => { + let assignees = assignees.into_iter().map(|a| f.fold_assignee(a)).collect(); + let function = f.fold_function(function); + vec![ZirAssemblyStatement::Assignment(assignees, function)] + } + ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { + let lhs = f.fold_field_expression(lhs); + let rhs = f.fold_field_expression(rhs); + vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)] + } + } +} + pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: ZirStatement<'ast, T>, @@ -173,6 +198,12 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) .collect(), ), + ZirStatement::Assembly(statements) => ZirStatement::Assembly( + statements + .into_iter() + .flat_map(|s| f.fold_assembly_statement(s)) + .collect(), + ), }; vec![res] } @@ -233,6 +264,36 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( let e2 = f.fold_uint_expression(e2); FieldElementExpression::Pow(box e1, box e2) } + FieldElementExpression::And(box left, box right) => { + let left = f.fold_field_expression(left); + let right = f.fold_field_expression(right); + + FieldElementExpression::And(box left, box right) + } + FieldElementExpression::Or(box left, box right) => { + let left = f.fold_field_expression(left); + let right = f.fold_field_expression(right); + + FieldElementExpression::Or(box left, box right) + } + FieldElementExpression::Xor(box left, box right) => { + let left = f.fold_field_expression(left); + let right = f.fold_field_expression(right); + + FieldElementExpression::Xor(box left, box right) + } + FieldElementExpression::LeftShift(box e, box by) => { + let e = f.fold_field_expression(e); + let by = f.fold_uint_expression(by); + + FieldElementExpression::LeftShift(box e, box by) + } + FieldElementExpression::RightShift(box e, box by) => { + let e = f.fold_field_expression(e); + let by = f.fold_uint_expression(by); + + FieldElementExpression::RightShift(box e, box by) + } FieldElementExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::FieldElement, c) { ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s), diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index aae839d30..249b2630e 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -1,15 +1,18 @@ use crate::zir::types::MemberId; +use serde::{Deserialize, Serialize}; use std::fmt; use crate::typed::Identifier as CoreIdentifier; -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] pub enum Identifier<'ast> { + #[serde(borrow)] Source(SourceIdentifier<'ast>), } -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] pub enum SourceIdentifier<'ast> { + #[serde(borrow)] Basic(CoreIdentifier<'ast>), Select(Box>, u32), Member(Box>, MemberId), diff --git a/zokrates_ast/src/zir/lqc.rs b/zokrates_ast/src/zir/lqc.rs new file mode 100644 index 000000000..121b2a396 --- /dev/null +++ b/zokrates_ast/src/zir/lqc.rs @@ -0,0 +1,287 @@ +use crate::zir::{FieldElementExpression, Identifier}; +use zokrates_field::Field; + +pub type LinearTerm<'ast, T> = (T, Identifier<'ast>); +pub type QuadraticTerm<'ast, T> = (T, Identifier<'ast>, Identifier<'ast>); + +#[derive(Clone, PartialEq, Hash, Eq, Debug, Default)] +pub struct LinQuadComb<'ast, T> { + // the constant terms + pub constant: T, + // the linear terms + pub linear: Vec>, + // the quadratic terms + pub quadratic: Vec>, +} + +impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> { + type Output = Self; + + fn add(self, mut other: Self) -> Self::Output { + Self { + constant: self.constant + other.constant, + linear: { + let mut l = self.linear; + l.append(&mut other.linear); + l + }, + quadratic: { + let mut q = self.quadratic; + q.append(&mut other.quadratic); + q + }, + } + } +} + +impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { + type Output = Self; + + fn sub(self, mut other: Self) -> Self::Output { + Self { + constant: self.constant - other.constant, + linear: { + let mut l = self.linear; + other.linear.iter_mut().for_each(|(c, _)| { + *c = T::zero() - &*c; + }); + l.append(&mut other.linear); + l + }, + quadratic: { + let mut q = self.quadratic; + other.quadratic.iter_mut().for_each(|(c, _, _)| { + *c = T::zero() - &*c; + }); + q.append(&mut other.quadratic); + q + }, + } + } +} + +impl<'ast, T: Field> LinQuadComb<'ast, T> { + fn try_mul(self, rhs: Self) -> Result { + // fail if the result has degree higher than 2 + if !(self.quadratic.is_empty() || rhs.quadratic.is_empty()) { + return Err(()); + } + + Ok(Self { + constant: self.constant.clone() * rhs.constant.clone(), + linear: { + // lin0 * const1 + lin1 * const0 + self.linear + .clone() + .into_iter() + .map(|(c, i)| (c * rhs.constant.clone(), i)) + .chain( + rhs.linear + .clone() + .into_iter() + .map(|(c, i)| (c * self.constant.clone(), i)), + ) + .collect() + }, + quadratic: { + // quad0 * const1 + quad1 * const0 + lin0 * lin1 + self.quadratic + .into_iter() + .map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1)) + .chain( + rhs.quadratic + .into_iter() + .map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)), + ) + .chain(self.linear.iter().flat_map(|(cl, l)| { + rhs.linear + .iter() + .map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone())) + })) + .collect() + }, + }) + } +} + +impl<'ast, T: Field> TryFrom> for LinQuadComb<'ast, T> { + type Error = (); + + fn try_from(e: FieldElementExpression<'ast, T>) -> Result { + match e { + FieldElementExpression::Number(v) => Ok(Self { + constant: v, + ..Self::default() + }), + FieldElementExpression::Identifier(id) => Ok(Self { + linear: vec![(T::one(), id.id)], + ..Self::default() + }), + FieldElementExpression::Add(box left, box right) => { + Ok(Self::try_from(left)? + Self::try_from(right)?) + } + FieldElementExpression::Sub(box left, box right) => { + Ok(Self::try_from(left)? - Self::try_from(right)?) + } + FieldElementExpression::Mult(box left, box right) => { + let left = Self::try_from(left)?; + let right = Self::try_from(right)?; + + left.try_mul(right) + } + _ => Err(()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::zir::Id; + use zokrates_field::Bn128Field; + + #[test] + fn add() { + // (2 + 2*a) + let a = LinQuadComb::try_from(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("a".into()), + ), + )) + .unwrap(); + + // (2 + 2*a*b) + let b = LinQuadComb::try_from(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("a".into()), + ), + box FieldElementExpression::identifier("b".into()), + ), + )) + .unwrap(); + + // (2 + 2*a) + (2 + 2*a*b) => 4 + 2*a + 2*a*b + let c = a + b; + + assert_eq!(c.constant, Bn128Field::from(4)); + assert_eq!( + c.linear, + vec![ + (Bn128Field::from(2), "a".into()), + (Bn128Field::from(0), "a".into()), + (Bn128Field::from(0), "b".into()) + ] + ); + assert_eq!( + c.quadratic, + vec![(Bn128Field::from(2), "a".into(), "b".into())] + ); + } + + #[test] + fn sub() { + // (2 + 2*a) + let a = LinQuadComb::try_from(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("a".into()), + ), + )) + .unwrap(); + + // (2 + 2*a*b) + let b = LinQuadComb::try_from(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("a".into()), + ), + box FieldElementExpression::identifier("b".into()), + ), + )) + .unwrap(); + + // (2 + 2*a) - (2 + 2*a*b) => 0 + 2*a + (-2)*a*b + let c = a - b; + + assert_eq!(c.constant, Bn128Field::from(0)); + assert_eq!( + c.linear, + vec![ + (Bn128Field::from(2), "a".into()), + (Bn128Field::from(0), "a".into()), + (Bn128Field::from(0), "b".into()) + ] + ); + assert_eq!( + c.quadratic, + vec![(Bn128Field::from(-2), "a".into(), "b".into())] + ); + } + + #[test] + fn mult() { + // (2 + 2*a) + let a = LinQuadComb::try_from(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("a".into()), + ), + )) + .unwrap(); + + // (2 + 2*b) + let b = LinQuadComb::try_from(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::identifier("b".into()), + ), + )) + .unwrap(); + + // (2 + 2*a) * (2 + 2*b) => 4 + 4*b + 4*a + 4*a*b + let c = a.try_mul(b).unwrap(); + + assert_eq!(c.constant, Bn128Field::from(4)); + assert_eq!( + c.linear, + vec![ + (Bn128Field::from(4), "a".into()), + (Bn128Field::from(4), "b".into()), + ] + ); + assert_eq!( + c.quadratic, + vec![(Bn128Field::from(4), "a".into(), "b".into())] + ); + } + + #[test] + fn mult_degree_error() { + // 2*a*b + let a = LinQuadComb::try_from(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Mult( + box FieldElementExpression::identifier("a".into()), + box FieldElementExpression::identifier("b".into()), + ), + )) + .unwrap(); + + // 2*a*b + let b = a.clone(); + + // (2*a*b) * (2*a*b) would result in a higher degree than expected + let c = a.try_mul(b); + assert!(c.is_err()); + } +} diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index d8f325936..60dc1467f 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -1,6 +1,7 @@ pub mod folder; mod from_typed; mod identifier; +pub mod lqc; mod parameter; pub mod result_folder; pub mod types; @@ -10,7 +11,7 @@ mod variable; pub use self::parameter::Parameter; pub use self::types::{Type, UBitwidth}; pub use self::variable::Variable; -use crate::common::{FlatEmbed, FormatString}; +use crate::common::{FlatEmbed, FormatString, SourceMetadata}; use crate::typed::ConcreteType; pub use crate::zir::uint::{ShouldReduce, UExpression, UExpressionInner, UMetadata}; @@ -21,6 +22,7 @@ use zokrates_field::Field; pub use self::folder::Folder; pub use self::identifier::{Identifier, SourceIdentifier}; +use serde::{Deserialize, Serialize}; /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Eq, Debug, Clone)] @@ -34,11 +36,13 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> { } } /// A typed function -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct ZirFunction<'ast, T> { /// Arguments of the function + #[serde(borrow)] pub arguments: Vec>, /// Vector of statements that are executed when running the function + #[serde(borrow)] pub statements: Vec>, /// function signature pub signature: Signature, @@ -67,7 +71,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> { writeln!(f)?; } - writeln!(f, "}}") + write!(f, "}}") } } @@ -88,9 +92,9 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirFunction<'ast, T> { pub type ZirAssignee<'ast> = Variable<'ast>; -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub enum RuntimeError { - SourceAssertion(String), + SourceAssertion(SourceMetadata), SelectRangeCheck, DivisionByZero, IncompleteDynamicRange, @@ -99,7 +103,7 @@ pub enum RuntimeError { impl fmt::Display for RuntimeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - RuntimeError::SourceAssertion(message) => write!(f, "{}", message), + RuntimeError::SourceAssertion(metadata) => write!(f, "{}", metadata), RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"), RuntimeError::DivisionByZero => write!(f, "Division by zero"), RuntimeError::IncompleteDynamicRange => write!(f, "Dynamic comparison is incomplete"), @@ -109,12 +113,46 @@ impl fmt::Display for RuntimeError { impl RuntimeError { pub fn mock() -> Self { - RuntimeError::SourceAssertion(String::default()) + RuntimeError::SourceAssertion(SourceMetadata::default()) + } +} + +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +pub enum ZirAssemblyStatement<'ast, T> { + Assignment( + #[serde(borrow)] Vec>, + ZirFunction<'ast, T>, + ), + Constraint( + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + SourceMetadata, + ), +} + +impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ZirAssemblyStatement::Assignment(ref lhs, ref rhs) => { + write!( + f, + "{} <-- {};", + lhs.iter() + .map(|a| a.to_string()) + .collect::>() + .join(", "), + rhs + ) + } + ZirAssemblyStatement::Constraint(ref lhs, ref rhs, _) => { + write!(f, "{} === {};", lhs, rhs) + } + } } } /// A statement in a `ZirFunction` -#[derive(Clone, PartialEq, Hash, Eq, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum ZirStatement<'ast, T> { Return(Vec>), Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>), @@ -129,6 +167,8 @@ pub enum ZirStatement<'ast, T> { FormatString, Vec<(ConcreteType, Vec>)>, ), + #[serde(borrow)] + Assembly(Vec>), } impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { @@ -142,15 +182,19 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { write!(f, "{}", "\t".repeat(depth))?; match self { ZirStatement::Return(ref exprs) => { - write!( - f, - "return {};", - exprs - .iter() - .map(|e| e.to_string()) - .collect::>() - .join(", ") - ) + write!(f, "return")?; + if !exprs.is_empty() { + write!( + f, + " {}", + exprs + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", ") + )?; + } + write!(f, ";") } ZirStatement::Definition(ref lhs, ref rhs) => { write!(f, "{} = {};", lhs, rhs) @@ -166,7 +210,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { s.fmt_indented(f, depth + 1)?; writeln!(f)?; } - write!(f, "{}}};", "\t".repeat(depth)) + write!(f, "{}}}", "\t".repeat(depth)) } ZirStatement::Assertion(ref e, ref error) => { write!(f, "assert({}", e)?; @@ -200,6 +244,13 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { .collect::>() .join(", ") ), + ZirStatement::Assembly(statements) => { + writeln!(f, "asm {{")?; + for s in statements { + writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; + } + write!(f, "{}}}", "\t".repeat(depth)) + } } } } @@ -208,8 +259,9 @@ pub trait Typed { fn get_type(&self) -> Type; } -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct IdentifierExpression<'ast, E> { + #[serde(borrow)] pub id: Identifier<'ast>, ty: PhantomData, } @@ -229,8 +281,9 @@ impl<'ast, E> IdentifierExpression<'ast, E> { } } -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct ConditionalExpression<'ast, T, E> { + #[serde(borrow)] pub condition: Box>, pub consequence: Box, pub alternative: Box, @@ -256,9 +309,10 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress } } -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct SelectExpression<'ast, T, E> { pub array: Vec, + #[serde(borrow)] pub index: Box>, } @@ -287,11 +341,11 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<' } /// A typed expression -#[derive(Clone, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub enum ZirExpression<'ast, T> { Boolean(BooleanExpression<'ast, T>), FieldElement(FieldElementExpression<'ast, T>), - Uint(UExpression<'ast, T>), + Uint(#[serde(borrow)] UExpression<'ast, T>), } impl<'ast, T: Field> From> for ZirExpression<'ast, T> { @@ -364,15 +418,20 @@ pub trait MultiTyped { fn get_types(&self) -> &Vec; } -#[derive(Clone, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub enum ZirExpressionList<'ast, T> { - EmbedCall(FlatEmbed, Vec, Vec>), + EmbedCall( + FlatEmbed, + Vec, + #[serde(borrow)] Vec>, + ), } /// An expression of type `field` -#[derive(Clone, PartialEq, Hash, Eq, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum FieldElementExpression<'ast, T> { Number(T), + #[serde(borrow)] Identifier(IdentifierExpression<'ast, Self>), Select(SelectExpression<'ast, T, Self>), Add( @@ -392,16 +451,57 @@ pub enum FieldElementExpression<'ast, T> { Box>, ), Pow( + Box>, + #[serde(borrow)] Box>, + ), + And( + Box>, + Box>, + ), + Or( + Box>, + Box>, + ), + Xor( + Box>, + Box>, + ), + LeftShift( + Box>, + Box>, + ), + RightShift( Box>, Box>, ), Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>), } +impl<'ast, T> FieldElementExpression<'ast, T> { + pub fn is_linear(&self) -> bool { + match self { + FieldElementExpression::Number(_) => true, + FieldElementExpression::Identifier(_) => true, + FieldElementExpression::Add(box left, box right) => { + left.is_linear() && right.is_linear() + } + FieldElementExpression::Sub(box left, box right) => { + left.is_linear() && right.is_linear() + } + FieldElementExpression::Mult(box left, box right) => matches!( + (left, right), + (FieldElementExpression::Number(_), _) | (_, FieldElementExpression::Number(_)) + ), + _ => false, + } + } +} + /// An expression of type `bool` -#[derive(Clone, PartialEq, Hash, Eq, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum BooleanExpression<'ast, T> { Value(bool), + #[serde(borrow)] Identifier(IdentifierExpression<'ast, Self>), Select(SelectExpression<'ast, T, Self>), FieldLt( @@ -501,6 +601,15 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), + FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), + FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), + FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), + FieldElementExpression::LeftShift(ref lhs, ref rhs) => { + write!(f, "({} << {})", lhs, rhs) + } + FieldElementExpression::RightShift(ref lhs, ref rhs) => { + write!(f, "({} >> {})", lhs, rhs) + } FieldElementExpression::Conditional(ref c) => { write!(f, "{}", c) } @@ -804,3 +913,36 @@ impl IntoType for UBitwidth { Type::Uint(self) } } + +pub trait Constant: Sized { + // return whether this is constant + fn is_constant(&self) -> bool; +} + +impl<'ast, T: Field> Constant for ZirExpression<'ast, T> { + fn is_constant(&self) -> bool { + match self { + ZirExpression::FieldElement(e) => e.is_constant(), + ZirExpression::Boolean(e) => e.is_constant(), + ZirExpression::Uint(e) => e.is_constant(), + } + } +} + +impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self, FieldElementExpression::Number(..)) + } +} + +impl<'ast, T: Field> Constant for BooleanExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self, BooleanExpression::Value(..)) + } +} + +impl<'ast, T: Field> Constant for UExpression<'ast, T> { + fn is_constant(&self) -> bool { + matches!(self.as_inner(), UExpressionInner::Value(..)) + } +} diff --git a/zokrates_ast/src/zir/parameter.rs b/zokrates_ast/src/zir/parameter.rs index 08d26c935..203a291a7 100644 --- a/zokrates_ast/src/zir/parameter.rs +++ b/zokrates_ast/src/zir/parameter.rs @@ -1,8 +1,10 @@ use crate::zir::Variable; +use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct Parameter<'ast> { + #[serde(borrow)] pub id: Variable<'ast>, pub private: bool, } diff --git a/zokrates_ast/src/zir/result_folder.rs b/zokrates_ast/src/zir/result_folder.rs index 803e3ca61..6ea741cf6 100644 --- a/zokrates_ast/src/zir/result_folder.rs +++ b/zokrates_ast/src/zir/result_folder.rs @@ -61,6 +61,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_statement(self, s) + } + fn fold_statement( &mut self, s: ZirStatement<'ast, T>, @@ -152,6 +159,26 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_uint_expression_inner(self, bitwidth, e) } } +pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(match s { + ZirAssemblyStatement::Assignment(assignees, function) => { + let assignees = assignees + .into_iter() + .map(|a| f.fold_assignee(a)) + .collect::>()?; + let function = f.fold_function(function)?; + vec![ZirAssemblyStatement::Assignment(assignees, function)] + } + ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { + let lhs = f.fold_field_expression(lhs)?; + let rhs = f.fold_field_expression(rhs)?; + vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)] + } + }) +} pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, @@ -207,6 +234,16 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( ZirStatement::Log(l, e) } + ZirStatement::Assembly(statements) => { + let statements = statements + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(); + ZirStatement::Assembly(statements) + } }; Ok(vec![res]) } @@ -254,6 +291,36 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( let e2 = f.fold_uint_expression(e2)?; FieldElementExpression::Pow(box e1, box e2) } + FieldElementExpression::Xor(box left, box right) => { + let left = f.fold_field_expression(left)?; + let right = f.fold_field_expression(right)?; + + FieldElementExpression::Xor(box left, box right) + } + FieldElementExpression::And(box left, box right) => { + let left = f.fold_field_expression(left)?; + let right = f.fold_field_expression(right)?; + + FieldElementExpression::And(box left, box right) + } + FieldElementExpression::Or(box left, box right) => { + let left = f.fold_field_expression(left)?; + let right = f.fold_field_expression(right)?; + + FieldElementExpression::Or(box left, box right) + } + FieldElementExpression::LeftShift(box e, box by) => { + let e = f.fold_field_expression(e)?; + let by = f.fold_uint_expression(by)?; + + FieldElementExpression::LeftShift(box e, box by) + } + FieldElementExpression::RightShift(box e, box by) => { + let e = f.fold_field_expression(e)?; + let by = f.fold_uint_expression(by)?; + + FieldElementExpression::RightShift(box e, box by) + } FieldElementExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::FieldElement, c)? { ConditionalOrExpression::Conditional(s) => FieldElementExpression::Conditional(s), diff --git a/zokrates_ast/src/zir/uint.rs b/zokrates_ast/src/zir/uint.rs index 6360372c3..9ae30d40e 100644 --- a/zokrates_ast/src/zir/uint.rs +++ b/zokrates_ast/src/zir/uint.rs @@ -1,5 +1,6 @@ use crate::zir::types::UBitwidth; use crate::zir::IdentifierExpression; +use serde::{Deserialize, Serialize}; use zokrates_field::Field; use super::{ConditionalExpression, SelectExpression}; @@ -91,7 +92,7 @@ impl<'ast, T> From for UExpression<'ast, T> { } } -#[derive(Debug, PartialEq, Eq, Clone, Hash)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub enum ShouldReduce { Unknown, True, @@ -135,7 +136,7 @@ impl ShouldReduce { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct UMetadata { pub max: T, pub should_reduce: ShouldReduce, @@ -162,16 +163,18 @@ impl UMetadata { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct UExpression<'ast, T> { pub bitwidth: UBitwidth, pub metadata: Option>, + #[serde(borrow)] pub inner: UExpressionInner<'ast, T>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum UExpressionInner<'ast, T> { Value(u128), + #[serde(borrow)] Identifier(IdentifierExpression<'ast, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), Add(Box>, Box>), diff --git a/zokrates_ast/src/zir/variable.rs b/zokrates_ast/src/zir/variable.rs index 91fda8c72..14d329727 100644 --- a/zokrates_ast/src/zir/variable.rs +++ b/zokrates_ast/src/zir/variable.rs @@ -1,9 +1,11 @@ use crate::zir::types::{Type, UBitwidth}; use crate::zir::Identifier; +use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] pub struct Variable<'ast> { + #[serde(borrow)] pub id: Identifier<'ast>, pub _type: Type, } diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index c918ccf5d..79d699ee4 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -21,8 +21,8 @@ use zokrates_proof_systems::Scheme; const G16_WARNING: &str = "WARNING: You are using the G16 scheme which is subject to malleability. See zokrates.github.io/toolbox/proving_schemes.html#g16-malleability for implications."; impl Backend for Bellman { - fn generate_proof>>( - program: ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, witness: Witness, proving_key: Vec, ) -> Proof { @@ -84,8 +84,8 @@ impl Backend for Bellman { } impl NonUniversalBackend for Bellman { - fn setup>>( - program: ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, ) -> SetupKeypair { println!("{}", G16_WARNING); @@ -99,8 +99,8 @@ impl NonUniversalBackend for Bellman } impl MpcBackend for Bellman { - fn initialize>>( - program: ProgIterator, + fn initialize<'a, R: Read, W: Write, I: IntoIterator>>( + program: ProgIterator<'a, T, I>, phase1_radix: &mut R, output: &mut W, ) -> Result<(), String> { @@ -124,9 +124,9 @@ impl MpcBackend for Bellman { Ok(hash) } - fn verify>>( + fn verify<'a, P: Read, R: Read, I: IntoIterator>>( params: &mut P, - program: ProgIterator, + program: ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String> { let params = diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 4bf396245..26bcf392e 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -22,20 +22,20 @@ pub use self::parse::*; pub struct Bellman; #[derive(Clone)] -pub struct Computation>> { - program: ProgIterator, +pub struct Computation<'a, T, I: IntoIterator>> { + program: ProgIterator<'a, T, I>, witness: Option>, } -impl>> Computation { - pub fn with_witness(program: ProgIterator, witness: Witness) -> Self { +impl<'a, T: Field, I: IntoIterator>> Computation<'a, T, I> { + pub fn with_witness(program: ProgIterator<'a, T, I>, witness: Witness) -> Self { Computation { program, witness: Some(witness), } } - pub fn without_witness(program: ProgIterator) -> Self { + pub fn without_witness(program: ProgIterator<'a, T, I>) -> Self { Computation { program, witness: None, @@ -83,8 +83,8 @@ fn bellman_combination>> - Circuit for Computation +impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator>> + Circuit for Computation<'a, T, I> { fn synthesize>( self, @@ -148,7 +148,9 @@ impl>> } } -impl>> Computation { +impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator>> + Computation<'a, T, I> +{ fn get_random_seed(&self) -> Result<[u32; 8], getrandom::Error> { let mut seed = [0u8; 32]; getrandom::getrandom(&mut seed)?; diff --git a/zokrates_book/src/SUMMARY.md b/zokrates_book/src/SUMMARY.md index 298f90438..5d2830cbf 100644 --- a/zokrates_book/src/SUMMARY.md +++ b/zokrates_book/src/SUMMARY.md @@ -16,6 +16,7 @@ - [Comments](language/comments.md) - [Macros](language/macros.md) - [Logging](language/logging.md) + - [Assembly](language/assembly.md) - [Toolbox](toolbox/index.md) - [CLI](toolbox/cli.md) diff --git a/zokrates_book/src/language/assembly.md b/zokrates_book/src/language/assembly.md new file mode 100644 index 000000000..aa37ddab7 --- /dev/null +++ b/zokrates_book/src/language/assembly.md @@ -0,0 +1,75 @@ +## Assembly + +ZoKrates allows developers to define constraints through assembly blocks. Assembly blocks are considered **unsafe**, as safety and correctness of the resulting arithmetic circuit is in the hands of the developer. Usage of assembly is recommended only in optimization efforts for the experienced developers to minimize constraint count of an arithmetic circuit. + +## Writing assembly + +All constraints must be enclosed within an `asm` block. In an assembly block we can do the following: + +1. Assign to a witness variable using `<--` +2. Define a constraint using `===` + +Assigning a value, in general, should be combined with adding a constraint: + +```zok +{{#include ../../../zokrates_cli/examples/book/assembly/division.zok}} +``` + +> The operator `<--` can be sometimes misused, as this operator does not generate any constraints, resulting in unconstrained variables in the constraint system. + +In some cases we can combine the witness assignment and constraint generation with the `<==` operator: + +```zok +asm { + c <== 1 - a*b; +} +``` + +which is equivalent to: + +```zok +asm { + c <-- 1 - a*b; + c === 1 - a*b; +} +``` + +A constraint can contain arithmetic expressions that are built using multiplication, addition, and other variables or `field` values. Only quadratic expressions are allowed to be included in constraints. Non-quadratic expressions or usage of other arithmetic operators like division or power are not allowed as constraints, but can be used in the witness assignment expression. + +The following code is not allowed: + +```zok +asm { + d === a*b*c; +} +``` + +as the constraint `d === a*b*c` is not quadratic. + +In some cases, ZoKrates will apply minor transformations on the defined constraints in order to meet the correct format: + +```zok +asm { + x * (x - 1) === 0; +} +``` + +will be transformed to: + +```zok +asm { + x === x * x; +} +``` + +## Type casting + +Assembly is a low-level part of the compiler which does not have type safety. In some cases we might want to do zero-cost conversions between `field` and `bool` type. + +### field_to_bool_unsafe + +This call is unsafe because it is the responsibility of the user to constrain the field input: + +```zok +{{#include ../../../zokrates_cli/examples/book/assembly/field_to_bool.zok}} +``` \ No newline at end of file diff --git a/zokrates_circom/src/r1cs.rs b/zokrates_circom/src/r1cs.rs index 8bdab0ac0..854bc0eab 100644 --- a/zokrates_circom/src/r1cs.rs +++ b/zokrates_circom/src/r1cs.rs @@ -72,6 +72,7 @@ pub fn r1cs_program(prog: Prog) -> (Vec, usize, Vec Some((quad, lin)), Statement::Directive(..) => None, + Statement::Block(..) => unreachable!(), Statement::Log(..) => None, }) { for (k, _) in &quad.left.0 { @@ -95,6 +96,7 @@ pub fn r1cs_program(prog: Prog) -> (Vec, usize, Vec Some((quad, lin)), + Statement::Block(..) => unreachable!(), Statement::Directive(..) => None, Statement::Log(..) => None, }) { diff --git a/zokrates_cli/examples/book/assembly/division.zok b/zokrates_cli/examples/book/assembly/division.zok new file mode 100644 index 000000000..b9f65af22 --- /dev/null +++ b/zokrates_cli/examples/book/assembly/division.zok @@ -0,0 +1,11 @@ +def main(field a, field b) -> field { + field mut c = 0; + field mut invb = 0; + asm { + invb <-- b == 0 ? 0 : 1 / b; + invb * b === 1; + c <-- invb * a; + a === b * c; + } + return c; +} \ No newline at end of file diff --git a/zokrates_cli/examples/book/assembly/field_to_bool.zok b/zokrates_cli/examples/book/assembly/field_to_bool.zok new file mode 100644 index 000000000..2e9f05523 --- /dev/null +++ b/zokrates_cli/examples/book/assembly/field_to_bool.zok @@ -0,0 +1,13 @@ +from "EMBED" import field_to_bool_unsafe; + +def main(field x) -> bool { + // we constrain `x` to be 0 or 1 + asm { + x * (x - 1) === 0; + } + // we can convert `x` to `bool` afterwards, as we constrained it properly + // if we failed to constrain `x` to `0` or `1`, the call to `field_to_bool_unsafe` introduces undefined behavior + // `field_to_bool_unsafe` call does not produce any extra constraints + bool out = field_to_bool_unsafe(x); + return out; +} \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/assembly/bitwise_op_in_constraint.zok b/zokrates_cli/examples/compile_errors/assembly/bitwise_op_in_constraint.zok new file mode 100644 index 000000000..3b54fa1ff --- /dev/null +++ b/zokrates_cli/examples/compile_errors/assembly/bitwise_op_in_constraint.zok @@ -0,0 +1,6 @@ +def main(field mut a, u32 i) { + asm { + a <-- a << i; // bitwise operations are allowed in witness generation + a === a << i; // but not in constraints + } +} \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/assembly/variable_index_assignment.zok b/zokrates_cli/examples/compile_errors/assembly/variable_index_assignment.zok new file mode 100644 index 000000000..ceed47f65 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/assembly/variable_index_assignment.zok @@ -0,0 +1,6 @@ +def main(field[2] mut a, u32 i) -> field[2] { + asm { + a[i] <== 42; // assigning to a variable index is not allowed in assembly + } + return a; +} \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/variable_exponent_in_pow.zok b/zokrates_cli/examples/compile_errors/variable_exponent_in_pow.zok new file mode 100644 index 000000000..14586fadd --- /dev/null +++ b/zokrates_cli/examples/compile_errors/variable_exponent_in_pow.zok @@ -0,0 +1,3 @@ +def main(field a, u32 b) -> field { + return a**b; +} \ No newline at end of file diff --git a/zokrates_cli/src/bin.rs b/zokrates_cli/src/bin.rs index 99a1b0838..a5f8a31f8 100644 --- a/zokrates_cli/src/bin.rs +++ b/zokrates_cli/src/bin.rs @@ -121,7 +121,8 @@ mod tests { use std::io::{BufReader, Read}; use std::string::String; use typed_arena::Arena; - use zokrates_core::compile::{compile, CompilationArtifacts, CompileConfig}; + use zokrates_common::CompileConfig; + use zokrates_core::compile::{compile, CompilationArtifacts}; use zokrates_field::Bn128Field; use zokrates_fs_resolver::FileSystemResolver; @@ -219,7 +220,7 @@ mod tests { let interpreter = zokrates_interpreter::Interpreter::default(); let _ = interpreter - .execute(artifacts.prog(), &[Bn128Field::from(0)]) + .execute(artifacts.prog(), &[Bn128Field::from(0u32)]) .unwrap(); } } diff --git a/zokrates_cli/src/ops/check.rs b/zokrates_cli/src/ops/check.rs index a32be7f6e..97d697efd 100644 --- a/zokrates_cli/src/ops/check.rs +++ b/zokrates_cli/src/ops/check.rs @@ -5,8 +5,8 @@ use std::fs::File; use std::io::{BufReader, Read}; use std::path::{Path, PathBuf}; use zokrates_common::constants::BN128; -use zokrates_common::helpers::CurveParameter; -use zokrates_core::compile::{check, CompileConfig, CompileError}; +use zokrates_common::{helpers::CurveParameter, CompileConfig}; +use zokrates_core::compile::{check, CompileError}; use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field}; use zokrates_fs_resolver::FileSystemResolver; diff --git a/zokrates_cli/src/ops/compile.rs b/zokrates_cli/src/ops/compile.rs index b00debe3d..74d26b6f2 100644 --- a/zokrates_cli/src/ops/compile.rs +++ b/zokrates_cli/src/ops/compile.rs @@ -8,8 +8,8 @@ use std::path::{Path, PathBuf}; use typed_arena::Arena; use zokrates_circom::write_r1cs; use zokrates_common::constants::BN128; -use zokrates_common::helpers::CurveParameter; -use zokrates_core::compile::{compile, CompileConfig, CompileError}; +use zokrates_common::{helpers::CurveParameter, CompileConfig}; +use zokrates_core::compile::{compile, CompileError}; use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field}; use zokrates_fs_resolver::FileSystemResolver; diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index 865f3c9b4..ab2a959bd 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -85,8 +85,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } } -fn cli_compute>>( - ir_prog: ir::ProgIterator, +fn cli_compute<'a, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Computing witness..."); diff --git a/zokrates_cli/src/ops/generate_proof.rs b/zokrates_cli/src/ops/generate_proof.rs index 2a62042a4..319cf561c 100644 --- a/zokrates_cli/src/ops/generate_proof.rs +++ b/zokrates_cli/src/ops/generate_proof.rs @@ -136,12 +136,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_generate_proof< + 'a, T: Field, - I: Iterator>, + I: Iterator>, S: Scheme, B: Backend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Generating proof..."); diff --git a/zokrates_cli/src/ops/generate_smtlib2.rs b/zokrates_cli/src/ops/generate_smtlib2.rs index b1bf6f6a8..ac58f8e56 100644 --- a/zokrates_cli/src/ops/generate_smtlib2.rs +++ b/zokrates_cli/src/ops/generate_smtlib2.rs @@ -47,8 +47,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } } -fn cli_smtlib2>>( - ir_prog: ir::ProgIterator, +fn cli_smtlib2<'a, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Generating SMTLib2..."); diff --git a/zokrates_cli/src/ops/inspect.rs b/zokrates_cli/src/ops/inspect.rs index 523d664a4..b8ca37545 100644 --- a/zokrates_cli/src/ops/inspect.rs +++ b/zokrates_cli/src/ops/inspect.rs @@ -43,8 +43,8 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } } -fn cli_inspect>>( - ir_prog: ir::ProgIterator, +fn cli_inspect<'a, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { let ir_prog: ir::Prog = ir_prog.collect(); @@ -52,6 +52,9 @@ fn cli_inspect>>( let curve = format!("{:<17} {}", "curve:", T::name()); let constraint_count = format!("{:<17} {}", "constraint_count:", ir_prog.constraint_count()); + println!("{}", curve); + println!("{}", constraint_count); + if sub_matches.is_present("ztf") { let output_path = PathBuf::from(sub_matches.value_of("input").unwrap()).with_extension("ztf"); diff --git a/zokrates_cli/src/ops/mpc/init.rs b/zokrates_cli/src/ops/mpc/init.rs index 71136366c..eb7ba16e4 100644 --- a/zokrates_cli/src/ops/mpc/init.rs +++ b/zokrates_cli/src/ops/mpc/init.rs @@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_mpc_init< + 'a, T: Field + BellmanFieldExtensions, - I: Iterator>, + I: Iterator>, S: MpcScheme, B: MpcBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Initializing MPC..."); diff --git a/zokrates_cli/src/ops/mpc/verify.rs b/zokrates_cli/src/ops/mpc/verify.rs index 6beca03db..fa014bd06 100644 --- a/zokrates_cli/src/ops/mpc/verify.rs +++ b/zokrates_cli/src/ops/mpc/verify.rs @@ -58,12 +58,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_mpc_verify< + 'a, T: Field + BellmanFieldExtensions, - I: Iterator>, + I: Iterator>, S: MpcScheme, B: MpcBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Verifying contributions..."); diff --git a/zokrates_cli/src/ops/setup.rs b/zokrates_cli/src/ops/setup.rs index e9e8a2166..0fc569538 100644 --- a/zokrates_cli/src/ops/setup.rs +++ b/zokrates_cli/src/ops/setup.rs @@ -167,12 +167,13 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { } fn cli_setup_non_universal< + 'a, T: Field, - I: Iterator>, + I: Iterator>, S: NonUniversalScheme, B: NonUniversalBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, sub_matches: &ArgMatches, ) -> Result<(), String> { println!("Performing setup..."); @@ -211,12 +212,13 @@ fn cli_setup_non_universal< } fn cli_setup_universal< + 'a, T: Field, - I: Iterator>, + I: Iterator>, S: UniversalScheme, B: UniversalBackend, >( - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, srs: Vec, sub_matches: &ArgMatches, ) -> Result<(), String> { diff --git a/zokrates_codegen/Cargo.toml b/zokrates_codegen/Cargo.toml new file mode 100644 index 000000000..6bb566f4c --- /dev/null +++ b/zokrates_codegen/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "zokrates_codegen" +version = "0.1.0" +edition = "2021" + +[features] +default = ["ark", "bellman"] +ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark", "zokrates_interpreter/ark"] +bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman", "zokrates_interpreter/bellman"] + +[dependencies] +zokrates_field = { version = "0.5.0", path = "../zokrates_field", default-features = false } +zokrates_common = { version = "0.1.0", path = "../zokrates_common", default-features = false } +zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } +zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false } +zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } \ No newline at end of file diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_codegen/src/lib.rs similarity index 97% rename from zokrates_core/src/flatten/mod.rs rename to zokrates_codegen/src/lib.rs index 2e3e95b8b..cf4e8cbb0 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_codegen/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(box_patterns, box_syntax)] + //! Module containing the `Flattener` to process a program that is R1CS-able. //! //! @file flatten.rs @@ -9,11 +11,11 @@ mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::zir::{ - ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirExpressionList, + ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, + ZirExpressionList, }; use zokrates_interpreter::Interpreter; -use crate::compile::CompileConfig; use std::collections::{ hash_map::{Entry, HashMap}, VecDeque, @@ -29,9 +31,10 @@ use zokrates_ast::zir::{ UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirFunction, ZirStatement, }; +use zokrates_common::CompileConfig; use zokrates_field::Field; -type FlatStatements = VecDeque>; +type FlatStatements<'ast, T> = VecDeque>; /// Flattens a function /// @@ -63,14 +66,14 @@ pub fn from_function_and_config( pub struct FlattenerIteratorInner<'ast, T> { pub statements: VecDeque>, - pub statements_flattened: FlatStatements, + pub statements_flattened: FlatStatements<'ast, T>, pub flattener: Flattener<'ast, T>, } -pub type FlattenerIterator<'ast, T> = FlatProgIterator>; +pub type FlattenerIterator<'ast, T> = FlatProgIterator<'ast, T, FlattenerIteratorInner<'ast, T>>; impl<'ast, T: Field> Iterator for FlattenerIteratorInner<'ast, T> { - type Item = FlatStatement; + type Item = FlatStatement<'ast, T>; fn next(&mut self) -> Option { while self.statements_flattened.is_empty() { @@ -124,7 +127,7 @@ trait Flatten<'ast, T: Field>: From> + Conditional<'ast, fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output; } @@ -134,7 +137,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_field_expression(statements_flattened, self) } @@ -146,7 +149,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for UExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_uint_expression(statements_flattened, self) } @@ -158,7 +161,7 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_boolean_expression(statements_flattened, self) } @@ -224,7 +227,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn define( &mut self, e: FlatExpression, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Variable { match e { FlatExpression::Identifier(id) => id, @@ -273,7 +276,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { #[must_use] fn constant_le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, a: &[FlatExpression], b: &[bool], ) -> Vec> { @@ -378,7 +381,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise fn eq_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, left: FlatExpression, right: FlatExpression, ) -> FlatExpression { @@ -431,7 +434,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `b` - the big-endian bit decomposition of the upper bound of the range fn enforce_constant_le_check_bits( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, a: &[FlatExpression], c: &[bool], error: RuntimeError, @@ -461,7 +464,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `c` - the constant upper bound of the range fn enforce_constant_le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, error: RuntimeError, @@ -497,7 +500,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `c` - the constant upper bound of the range fn enforce_constant_lt_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, error: RuntimeError, @@ -516,9 +519,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn make_conditional( &mut self, - statements: FlatStatements, + statements: FlatStatements<'ast, T>, condition: FlatExpression, - ) -> FlatStatements { + ) -> FlatStatements<'ast, T> { statements .into_iter() .flat_map(|s| match s { @@ -579,7 +582,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * U is the type of the expression fn flatten_conditional_expression>( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: ConditionalExpression<'ast, T, U>, ) -> FlatUExpression { let condition = *e.condition; @@ -677,7 +680,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise fn constant_lt_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, ) -> FlatExpression { @@ -701,7 +704,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise fn constant_field_le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, ) -> FlatExpression { @@ -742,7 +745,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { #[must_use] fn le_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, lhs_flattened: FlatExpression, rhs_flattened: FlatExpression, bit_width: usize, @@ -765,7 +768,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { #[must_use] fn lt_check( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, lhs_flattened: FlatExpression, rhs_flattened: FlatExpression, bit_width: usize, @@ -824,7 +827,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * in order to preserve composability. fn flatten_boolean_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { match expression { @@ -1030,7 +1033,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `param_expressions` - Arguments of this call fn flatten_embed_call( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, embed: FlatEmbed, generics: Vec, param_expressions: Vec>, @@ -1046,6 +1049,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .collect(); match embed { + FlatEmbed::FieldToBoolUnsafe => vec![params.pop().unwrap()], FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()), FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()), FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()), @@ -1131,9 +1135,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn flatten_embed_call_aux( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, params: Vec>, - funct: FlatFunctionIterator>>, + funct: FlatFunctionIterator<'ast, T, impl IntoIterator>>, ) -> Vec> { let mut replacement_map = HashMap::new(); @@ -1152,6 +1156,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // add all flattened statements, adapt return statements let statements = funct.statements.into_iter().map(|stat| match stat { + FlatStatement::Block(..) => unreachable!(), FlatStatement::Definition(var, rhs) => { let new_var = self.use_sym(); replacement_map.insert(var, new_var); @@ -1216,7 +1221,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `expr` - `ZirExpression` that will be flattened. fn flatten_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expr: ZirExpression<'ast, T>, ) -> FlatUExpression { match expr { @@ -1232,7 +1237,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn default_xor( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, left: UExpression<'ast, T>, right: UExpression<'ast, T>, ) -> FlatUExpression { @@ -1293,7 +1298,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn euclidean_division( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, target_bitwidth: UBitwidth, left: UExpression<'ast, T>, right: UExpression<'ast, T>, @@ -1379,7 +1384,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `expr` - `UExpression` that will be flattened. fn flatten_uint_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expr: UExpression<'ast, T>, ) -> FlatUExpression { // the bitwidth for this type of uint (8, 16 or 32) @@ -1872,7 +1877,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { e: &FlatUExpression, from: usize, to: usize, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, error: RuntimeError, ) -> Vec> { assert!(from <= T::get_required_bits()); @@ -1966,7 +1971,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn flatten_select_expression>( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, e: SelectExpression<'ast, T, U>, ) -> FlatUExpression { let array = e.array; @@ -2030,7 +2035,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `expr` - `FieldElementExpression` that will be flattened. fn flatten_field_expression( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, expr: FieldElementExpression<'ast, T>, ) -> FlatExpression { match expr { @@ -2215,6 +2220,39 @@ impl<'ast, T: Field> Flattener<'ast, T> { FieldElementExpression::Conditional(e) => self .flatten_conditional_expression(statements_flattened, e) .get_field_unchecked(), + _ => unreachable!(), + } + } + + fn flatten_assembly_statement( + &mut self, + statements_flattened: &mut FlatStatements<'ast, T>, + stat: ZirAssemblyStatement<'ast, T>, + ) { + match stat { + ZirAssemblyStatement::Assignment(assignees, function) => { + let inputs: Vec> = function + .arguments + .iter() + .cloned() + .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into()) + .collect(); + let outputs: Vec = assignees + .into_iter() + .map(|assignee| self.use_variable(&assignee)) + .collect(); + let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); + statements_flattened.push_back(FlatStatement::Directive(directive)); + } + ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { + let lhs = self.flatten_field_expression(statements_flattened, lhs); + let rhs = self.flatten_field_expression(statements_flattened, rhs); + statements_flattened.push_back(FlatStatement::Condition( + lhs, + rhs, + RuntimeError::SourceAssemblyConstraint(metadata), + )); + } } } @@ -2226,10 +2264,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * `stat` - `ZirStatement` that will be flattened. fn flatten_statement( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, stat: ZirStatement<'ast, T>, ) { match stat { + ZirStatement::Assembly(statements) => { + let mut block_statements = VecDeque::new(); + for s in statements { + self.flatten_assembly_statement(&mut block_statements, s); + } + statements_flattened.push_back(FlatStatement::Block(block_statements.into())); + } ZirStatement::Return(exprs) => { #[allow(clippy::needless_collect)] // clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator, @@ -2630,12 +2675,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `statements_flattened` - `FlatStatements` Vector where new flattened statements can be added. + /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. /// * `lhs` - `FlatExpression` Left-hand side of the equality expression. /// * `rhs` - `FlatExpression` Right-hand side of the equality expression. fn flatten_equality_assertion( &mut self, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, lhs: FlatExpression, rhs: FlatExpression, error: RuntimeError, @@ -2664,11 +2709,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// # Arguments /// /// * `e` - `FlatExpression` Expression to be assigned to an identifier. - /// * `statements_flattened` - `FlatStatements` Vector where new flattened statements can be added. + /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. fn identify_expression( &mut self, e: FlatExpression, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> FlatExpression { match e.is_linear() { true => e, @@ -2707,7 +2752,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn use_parameter( &mut self, parameter: &ZirParameter<'ast>, - statements_flattened: &mut FlatStatements, + statements_flattened: &mut FlatStatements<'ast, T>, ) -> Parameter { let variable = self.use_variable(¶meter.id); diff --git a/zokrates_core/src/flatten/utils.rs b/zokrates_codegen/src/utils.rs similarity index 100% rename from zokrates_core/src/flatten/utils.rs rename to zokrates_codegen/src/utils.rs diff --git a/zokrates_common/Cargo.toml b/zokrates_common/Cargo.toml index ef70f90d6..40ee04f0b 100644 --- a/zokrates_common/Cargo.toml +++ b/zokrates_common/Cargo.toml @@ -12,4 +12,5 @@ bellman = [] ark = [] -[dependencies] \ No newline at end of file +[dependencies] +serde = { version = "1.0", features = ["derive"] } \ No newline at end of file diff --git a/zokrates_common/src/lib.rs b/zokrates_common/src/lib.rs index 9bfc0c79b..a37d9112c 100644 --- a/zokrates_common/src/lib.rs +++ b/zokrates_common/src/lib.rs @@ -1,6 +1,7 @@ pub mod constants; pub mod helpers; +use serde::{Deserialize, Serialize}; use std::path::PathBuf; pub trait Resolver { @@ -10,3 +11,23 @@ pub trait Resolver { import_location: PathBuf, ) -> Result<(String, PathBuf), E>; } + +#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy)] +pub struct CompileConfig { + #[serde(default)] + pub isolate_branches: bool, + #[serde(default)] + pub debug: bool, +} + +impl CompileConfig { + pub fn isolate_branches(mut self, flag: bool) -> Self { + self.isolate_branches = flag; + self + } + + pub fn debug(mut self, debug: bool) -> Self { + self.debug = debug; + self + } +} diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index 10fa24a1a..44da27eab 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -8,8 +8,8 @@ readme = "README.md" [features] default = ["ark", "bellman"] -ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark", "zokrates_interpreter/ark"] -bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman", "zokrates_interpreter/bellman"] +ark = ["zokrates_ast/ark", "zokrates_embed/ark", "zokrates_common/ark", "zokrates_interpreter/ark", "zokrates_codegen/ark", "zokrates_analysis/ark"] +bellman = ["zokrates_ast/bellman", "zokrates_embed/bellman", "zokrates_common/bellman", "zokrates_interpreter/bellman", "zokrates_codegen/bellman", "zokrates_analysis/bellman"] [dependencies] log = "0.4" @@ -26,6 +26,8 @@ zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } zokrates_common = { version = "0.1", path = "../zokrates_common", default-features = false } zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false } +zokrates_codegen = { version = "0.1", path = "../zokrates_codegen", default-features = false } +zokrates_analysis = { version = "0.1", path = "../zokrates_analysis", default-features = false } zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } csv = "1" diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 68f35560d..0c7583fa2 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -3,35 +3,34 @@ //! @file compile.rs //! @author Thibaut Schaeffer //! @date 2018 -use crate::flatten::from_function_and_config; use crate::imports::{self, Importer}; use crate::macros; use crate::optimizer::optimize; use crate::semantics::{self, Checker}; -use crate::static_analysis::{self, analyse}; use macros::process_macros; -use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; use std::io; use std::path::{Path, PathBuf}; use typed_arena::Arena; +use zokrates_analysis::{self, analyse}; use zokrates_ast::ir::{self, from_flat::from_flat}; use zokrates_ast::typed::abi::Abi; use zokrates_ast::untyped::{Module, OwnedModuleId, Program}; use zokrates_ast::zir::ZirProgram; -use zokrates_common::Resolver; +use zokrates_codegen::from_function_and_config; +use zokrates_common::{CompileConfig, Resolver}; use zokrates_field::Field; use zokrates_pest_ast as pest; #[derive(Debug)] -pub struct CompilationArtifacts>> { - prog: ir::ProgIterator, +pub struct CompilationArtifacts<'ast, T, I: IntoIterator>> { + prog: ir::ProgIterator<'ast, T, I>, abi: Abi, } -impl>> CompilationArtifacts { - pub fn prog(self) -> ir::ProgIterator { +impl<'ast, T, I: IntoIterator>> CompilationArtifacts<'ast, T, I> { + pub fn prog(self) -> ir::ProgIterator<'ast, T, I> { self.prog } @@ -39,11 +38,11 @@ impl>> CompilationArtifacts { &self.abi } - pub fn into_inner(self) -> (ir::ProgIterator, Abi) { + pub fn into_inner(self) -> (ir::ProgIterator<'ast, T, I>, Abi) { (self.prog, self.abi) } - pub fn collect(self) -> CompilationArtifacts>> { + pub fn collect(self) -> CompilationArtifacts<'ast, T, Vec>> { CompilationArtifacts { prog: self.prog.collect(), abi: self.abi, @@ -67,7 +66,7 @@ pub enum CompileErrorInner { MacroError(macros::Error), SemanticError(semantics::ErrorInner), ReadError(io::Error), - AnalysisError(static_analysis::Error), + AnalysisError(zokrates_analysis::Error), } impl CompileErrorInner { @@ -142,8 +141,8 @@ impl From for CompileError { } } -impl From for CompileErrorInner { - fn from(error: static_analysis::Error) -> Self { +impl From for CompileErrorInner { + fn from(error: zokrates_analysis::Error) -> Self { CompileErrorInner::AnalysisError(error) } } @@ -173,26 +172,6 @@ impl fmt::Display for CompileErrorInner { } } -#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy)] -pub struct CompileConfig { - #[serde(default)] - pub isolate_branches: bool, - #[serde(default)] - pub debug: bool, -} - -impl CompileConfig { - pub fn isolate_branches(mut self, flag: bool) -> Self { - self.isolate_branches = flag; - self - } - - pub fn debug(mut self, debug: bool) -> Self { - self.debug = debug; - self - } -} - type FilePath = PathBuf; pub fn compile<'ast, T: Field, E: Into>( @@ -201,8 +180,10 @@ pub fn compile<'ast, T: Field, E: Into>( resolver: Option<&dyn Resolver>, config: CompileConfig, arena: &'ast Arena, -) -> Result> + 'ast>, CompileErrors> -{ +) -> Result< + CompilationArtifacts<'ast, T, impl IntoIterator> + 'ast>, + CompileErrors, +> { let (typed_ast, abi): (zokrates_ast::zir::ZirProgram<'_, T>, _) = check_with_arena(source, location, resolver, &config, arena)?; @@ -218,8 +199,11 @@ pub fn compile<'ast, T: Field, E: Into>( log::debug!("Optimise IR"); let optimized_ir_prog = optimize(ir_prog); + // clean (remove blocks) + let clean_ir_prog = optimized_ir_prog.clean(); + Ok(CompilationArtifacts { - prog: optimized_ir_prog, + prog: clean_ir_prog, abi, }) } diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 9abd9c1ba..f7fa95f13 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -147,6 +147,10 @@ impl Importer { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::Unpack), }, + "field_to_bool_unsafe" => SymbolDeclaration { + id: symbol.get_alias(), + symbol: Symbol::Flat(FlatEmbed::FieldToBoolUnsafe), + }, "bit_array_le" => SymbolDeclaration { id: symbol.get_alias(), symbol: Symbol::Flat(FlatEmbed::BitArrayLe), diff --git a/zokrates_core/src/lib.rs b/zokrates_core/src/lib.rs index b6cebfcb6..51de21db8 100644 --- a/zokrates_core/src/lib.rs +++ b/zokrates_core/src/lib.rs @@ -1,9 +1,7 @@ #![feature(box_patterns, box_syntax)] pub mod compile; -mod flatten; pub mod imports; mod macros; mod optimizer; mod semantics; -mod static_analysis; diff --git a/zokrates_core/src/optimizer/canonicalizer.rs b/zokrates_core/src/optimizer/canonicalizer.rs index 4a65bc85f..57810aedd 100644 --- a/zokrates_core/src/optimizer/canonicalizer.rs +++ b/zokrates_core/src/optimizer/canonicalizer.rs @@ -4,7 +4,7 @@ use zokrates_field::Field; #[derive(Default)] pub struct Canonicalizer; -impl Folder for Canonicalizer { +impl<'ast, T: Field> Folder<'ast, T> for Canonicalizer { fn fold_linear_combination(&mut self, l: LinComb) -> LinComb { l.into_canonical().into() } diff --git a/zokrates_core/src/optimizer/directive.rs b/zokrates_core/src/optimizer/directive.rs index afabc87bd..4d140637a 100644 --- a/zokrates_core/src/optimizer/directive.rs +++ b/zokrates_core/src/optimizer/directive.rs @@ -14,19 +14,21 @@ use zokrates_ast::ir::folder::*; use zokrates_ast::ir::*; use zokrates_field::Field; +type SolverCall<'ast, T> = (Solver<'ast, T>, Vec>); + #[derive(Debug, Default)] -pub struct DirectiveOptimizer { - calls: HashMap<(Solver, Vec>), Vec>, +pub struct DirectiveOptimizer<'ast, T> { + calls: HashMap, Vec>, /// Map of renamings for reassigned variables while processing the program. substitution: HashMap, } -impl Folder for DirectiveOptimizer { +impl<'ast, T: Field> Folder<'ast, T> for DirectiveOptimizer<'ast, T> { fn fold_variable(&mut self, v: Variable) -> Variable { *self.substitution.get(&v).unwrap_or(&v) } - fn fold_statement(&mut self, s: Statement) -> Vec> { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { match s { Statement::Directive(d) => { let d = self.fold_directive(d); diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 68c954096..664cfc2db 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -21,8 +21,8 @@ pub struct DuplicateOptimizer { seen: HashSet, } -impl Folder for DuplicateOptimizer { - fn fold_program(&mut self, p: Prog) -> Prog { +impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer { + fn fold_program(&mut self, p: Prog<'ast, T>) -> Prog<'ast, T> { // in order to correctly identify duplicates, we need to first canonicalize the statements let mut canonicalizer = Canonicalizer; @@ -38,7 +38,7 @@ impl Folder for DuplicateOptimizer { fold_program(self, p) } - fn fold_statement(&mut self, s: Statement) -> Vec> { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { let hashed = hash(&s); let result = match self.seen.get(&hashed) { Some(_) => vec![], diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index cecee2e3f..1f94740a9 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -19,9 +19,9 @@ use self::tautology::TautologyOptimizer; use zokrates_ast::ir::{ProgIterator, Statement}; use zokrates_field::Field; -pub fn optimize>>( - p: ProgIterator, -) -> ProgIterator>> { +pub fn optimize<'ast, T: Field, I: IntoIterator>>( + p: ProgIterator<'ast, T, I>, +) -> ProgIterator<'ast, T, impl IntoIterator>> { // remove redefinitions log::debug!("Optimizer: Remove redefinitions and tautologies and directives and duplicates"); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index e853be115..b0877fd02 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -52,8 +52,10 @@ pub struct RedefinitionOptimizer { pub ignore: HashSet, } -impl RedefinitionOptimizer { - pub fn init>>(p: &ProgIterator) -> Self { +impl RedefinitionOptimizer { + pub fn init<'ast, I: IntoIterator>>( + p: &ProgIterator<'ast, T, I>, + ) -> Self { RedefinitionOptimizer { substitution: HashMap::new(), ignore: vec![Variable::one()] @@ -64,10 +66,12 @@ impl RedefinitionOptimizer { .collect(), } } -} -impl Folder for RedefinitionOptimizer { - fn fold_statement(&mut self, s: Statement) -> Vec> { + fn fold_statement<'ast>( + &mut self, + s: Statement<'ast, T>, + aggressive: bool, + ) -> Vec> { match s { Statement::Constraint(quad, lin, message) => { let quad = self.fold_quadratic_combination(quad); @@ -161,9 +165,11 @@ impl Folder for RedefinitionOptimizer { .unwrap_or_else(|q| q) }) .collect(); - // to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set - for o in d.outputs.iter().cloned() { - self.ignore.insert(o); + if !aggressive { + // to prevent the optimiser from replacing variables introduced by directives, add them to the ignored set + for o in d.outputs.iter().cloned() { + self.ignore.insert(o); + } } vec![Statement::Directive(Directive { inputs, ..d })] } @@ -172,6 +178,36 @@ impl Folder for RedefinitionOptimizer { s => fold_statement(self, s), } } +} + +impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { + match s { + Statement::Block(statements) => { + #[allow(clippy::needless_collect)] + // optimize aggressively and clean up in a second pass (we need to collect here) + let statements: Vec<_> = statements + .into_iter() + .flat_map(|s| self.fold_statement(s, true)) + .collect(); + + // clean up + let statements = statements + .into_iter() + .filter(|s| match s { + // we remove a directive iff it has a single output and this output is in the substitution map, meaning it was propagated + Statement::Directive(d) => { + d.outputs.len() > 1 || !self.substitution.contains_key(&d.outputs[0]) + } + _ => true, + }) + .collect(); + + vec![Statement::Block(statements)] + } + s => self.fold_statement(s, false), + } + } fn fold_linear_combination(&mut self, lc: LinComb) -> LinComb { match lc diff --git a/zokrates_core/src/optimizer/tautology.rs b/zokrates_core/src/optimizer/tautology.rs index 4a9ce8472..855efa11d 100644 --- a/zokrates_core/src/optimizer/tautology.rs +++ b/zokrates_core/src/optimizer/tautology.rs @@ -13,8 +13,8 @@ use zokrates_field::Field; #[derive(Default)] pub struct TautologyOptimizer; -impl Folder for TautologyOptimizer { - fn fold_statement(&mut self, s: Statement) -> Vec> { +impl<'ast, T: Field> Folder<'ast, T> for TautologyOptimizer { + fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { match s { Statement::Constraint(quad, lin, message) => match quad.try_linear() { Ok(l) => { diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index e3aeb8688..ad584e717 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -8,7 +8,7 @@ use num_bigint::BigUint; use std::collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt; use std::path::PathBuf; -use zokrates_ast::common::FormatString; +use zokrates_ast::common::{FormatString, SourceMetadata}; use zokrates_ast::typed::types::{GGenericsAssignment, GTupleType, GenericsAssignment}; use zokrates_ast::typed::SourceIdentifier; use zokrates_ast::typed::*; @@ -283,11 +283,12 @@ struct Scope<'ast, T> { impl<'ast, T: Field> Scope<'ast, T> { // insert into the scope and return whether we are shadowing an existing variable - fn insert( + fn insert>>( &mut self, - id: SourceIdentifier<'ast>, + id: I, info: IdentifierInfo<'ast, T, CoreIdentifier<'ast>>, ) -> bool { + let id = id.into(); let existed = self .map .get(&id) @@ -299,12 +300,12 @@ impl<'ast, T: Field> Scope<'ast, T> { } /// get the current version of this variable - fn get( + fn get>>( &self, - id: &SourceIdentifier<'ast>, + id: I, ) -> Option>> { self.map - .get(id) + .get(&id.into()) .and_then(|versions| versions.values().next_back().cloned()) } @@ -1084,14 +1085,14 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(var) } - fn id_in_this_scope(&self, id: SourceIdentifier<'ast>) -> CoreIdentifier<'ast> { + fn id_in_this_scope>>(&self, id: I) -> CoreIdentifier<'ast> { // in the semantic checker, 0 is top level, 1 is function level. For shadowing, we start with 0 at function level // hence the offset of 1 assert!( self.scope.level > 0, "CoreIdentifier cannot be declared in the global scope" ); - CoreIdentifier::from(ShadowedIdentifier::shadow(id, self.scope.level - 1)) + CoreIdentifier::from(ShadowedIdentifier::shadow(id.into(), self.scope.level - 1)) } fn check_function( @@ -1782,6 +1783,83 @@ impl<'ast, T: Field> Checker<'ast, T> { } } + fn check_assembly_statement( + &mut self, + stat: AssemblyStatementNode<'ast>, + module_id: &ModuleId, + types: &TypeMap<'ast, T>, + ) -> Result>, ErrorInner> { + let pos = stat.pos(); + + match stat.value { + AssemblyStatement::Assignment(assignee, expression, constrained) => { + let assignee = self.check_assignee(assignee, module_id, types)?; + let e = self.check_expression(expression, module_id, types)?; + + let e = FieldElementExpression::try_from_typed(e).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected right hand side of an assembly assignment to be of type field, found {}", + e.get_type(), + ), + })?; + + match constrained { + true => { + let e = FieldElementExpression::block(vec![], e); + match assignee.get_type() { + Type::FieldElement => Ok(vec![ + TypedAssemblyStatement::Assignment( + assignee.clone(), + e.clone().into(), + ), + TypedAssemblyStatement::Constraint( + assignee.into(), + e, + SourceMetadata::new(module_id.display().to_string(), pos.0), + ), + ]), + ty => Err(ErrorInner { + pos: Some(pos), + message: format!("Assignee must be of type field, found {}", ty), + }), + } + } + false => { + let e = FieldElementExpression::block(vec![], e); + Ok(vec![TypedAssemblyStatement::Assignment(assignee, e.into())]) + } + } + } + AssemblyStatement::Constraint(lhs, rhs) => { + let lhs = self.check_expression(lhs, module_id, types)?; + let rhs = self.check_expression(rhs, module_id, types)?; + + let lhs = FieldElementExpression::try_from_typed(lhs).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected left hand side of a constraint to be of type field, found {}", + e.get_type(), + ), + })?; + + let rhs = FieldElementExpression::try_from_typed(rhs).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected right hand side of a constraint to be of type field, found {}", + e.get_type(), + ), + })?; + + Ok(vec![TypedAssemblyStatement::Constraint( + lhs, + rhs, + SourceMetadata::new(module_id.display().to_string(), pos.0), + )]) + } + } + } + fn check_statement( &mut self, stat: StatementNode<'ast>, @@ -1791,6 +1869,16 @@ impl<'ast, T: Field> Checker<'ast, T> { let pos = stat.pos(); match stat.value { + Statement::Assembly(statements) => { + let mut checked_statements = vec![]; + for s in statements { + checked_statements.extend( + self.check_assembly_statement(s, module_id, types) + .map_err(|e| vec![e])?, + ); + } + Ok(TypedStatement::Assembly(checked_statements)) + } Statement::Log(l, expressions) => { let l = FormatString::from(l); @@ -2011,11 +2099,10 @@ impl<'ast, T: Field> Checker<'ast, T> { match e { TypedExpression::Boolean(e) => Ok(TypedStatement::Assertion( e, - RuntimeError::SourceAssertion(AssertionMetadata { - file: module_id.display().to_string(), - position: pos.0, - message, - }), + RuntimeError::SourceAssertion( + SourceMetadata::new(module_id.display().to_string(), pos.0) + .message(message), + ), )), e => Err(ErrorInner { pos: Some(pos), @@ -2049,7 +2136,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let pos = assignee.pos(); // check that the assignee is declared match assignee.value { - Assignee::Identifier(variable_name) => match self.scope.get(&variable_name) { + Assignee::Identifier(variable_name) => match self.scope.get(variable_name) { Some(info) => match info.is_mutable { false => Err(ErrorInner { pos: Some(assignee.pos()), @@ -2358,7 +2445,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()), Expression::Identifier(name) => { // check that `id` is defined in the scope - match self.scope.get(&name) { + match self.scope.get(name) { Some(info) => { let id = info.id; match info.ty.clone() { @@ -3463,9 +3550,11 @@ impl<'ast, T: Field> Checker<'ast, T> { match e1 { TypedExpression::Int(e1) => Ok(IntExpression::LeftShift(box e1, box e2).into()), TypedExpression::Uint(e1) => Ok(UExpression::left_shift(e1, e2).into()), + TypedExpression::FieldElement(e1) => { + Ok(FieldElementExpression::LeftShift(box e1, box e2).into()) + } e1 => Err(ErrorInner { pos: Some(pos), - message: format!( "Cannot left-shift {} by {}", e1.get_type(), @@ -3492,9 +3581,11 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(IntExpression::RightShift(box e1, box e2).into()) } TypedExpression::Uint(e1) => Ok(UExpression::right_shift(e1, e2).into()), + TypedExpression::FieldElement(e1) => { + Ok(FieldElementExpression::RightShift(box e1, box e2).into()) + } e1 => Err(ErrorInner { pos: Some(pos), - message: format!( "Cannot right-shift {} by {}", e1.get_type(), @@ -3519,6 +3610,9 @@ impl<'ast, T: Field> Checker<'ast, T> { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { Ok(IntExpression::Or(box e1, box e2).into()) } + (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { + Ok(FieldElementExpression::Or(box e1, box e2).into()) + } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => { @@ -3526,7 +3620,6 @@ impl<'ast, T: Field> Checker<'ast, T> { } (e1, e2) => Err(ErrorInner { pos: Some(pos), - message: format!( "Cannot apply `|` to {}, {}", e1.get_type(), @@ -3551,6 +3644,9 @@ impl<'ast, T: Field> Checker<'ast, T> { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { Ok(IntExpression::And(box e1, box e2).into()) } + (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { + Ok(FieldElementExpression::And(box e1, box e2).into()) + } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => { @@ -3558,7 +3654,6 @@ impl<'ast, T: Field> Checker<'ast, T> { } (e1, e2) => Err(ErrorInner { pos: Some(pos), - message: format!( "Cannot apply `&` to {}, {}", e1.get_type(), @@ -3583,6 +3678,9 @@ impl<'ast, T: Field> Checker<'ast, T> { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { Ok(IntExpression::Xor(box e1, box e2).into()) } + (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { + Ok(FieldElementExpression::Xor(box e1, box e2).into()) + } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => { @@ -3590,7 +3688,6 @@ impl<'ast, T: Field> Checker<'ast, T> { } (e1, e2) => Err(ErrorInner { pos: Some(pos), - message: format!( "Cannot apply `^` to {}, {}", e1.get_type(), @@ -3614,14 +3711,14 @@ impl<'ast, T: Field> Checker<'ast, T> { } } - fn insert_into_scope( + fn insert_into_scope>>( &mut self, - id: SourceIdentifier<'ast>, + id: I, ty: Type<'ast, T>, is_mutable: bool, ) -> bool { let info = IdentifierInfo { - id: self.id_in_this_scope(id), + id: self.id_in_this_scope(id.clone()), ty, is_mutable, }; @@ -4742,12 +4839,12 @@ mod tests { let for_statements_checked = vec![TypedStatement::definition( typed::Variable::uint( - CoreIdentifier::Source(ShadowedIdentifier::shadow("a", 1)), + CoreIdentifier::Source(ShadowedIdentifier::shadow("a".into(), 1)), UBitwidth::B32, ) .into(), UExpression::identifier( - CoreIdentifier::Source(ShadowedIdentifier::shadow("i", 1)).into(), + CoreIdentifier::Source(ShadowedIdentifier::shadow("i".into(), 1)).into(), ) .annotate(UBitwidth::B32) .into(), @@ -4756,7 +4853,7 @@ mod tests { let foo_statements_checked = vec![ TypedStatement::For( typed::Variable::uint( - CoreIdentifier::Source(ShadowedIdentifier::shadow("i", 1)), + CoreIdentifier::Source(ShadowedIdentifier::shadow("i".into(), 1)), UBitwidth::B32, ), 0u32.into(), @@ -5302,10 +5399,7 @@ mod tests { &TypeMap::new(), ); assert!(s2_checked.is_ok()); - assert_eq!( - checker.scope.get(&"a").unwrap().ty, - DeclarationType::Boolean - ); + assert_eq!(checker.scope.get("a").unwrap().ty, DeclarationType::Boolean); } #[test] @@ -5363,16 +5457,16 @@ mod tests { let expected = vec![ TypedStatement::definition( typed::Variable::new( - CoreIdentifier::from(ShadowedIdentifier::shadow("a", 0)), + CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, true, ) .into(), - FieldElementExpression::Number(2.into()).into(), + FieldElementExpression::Number(2u32.into()).into(), ), TypedStatement::For( typed::Variable::new( - CoreIdentifier::from(ShadowedIdentifier::shadow("i", 1)), + CoreIdentifier::from(ShadowedIdentifier::shadow("i".into(), 1)), Type::Uint(UBitwidth::B32), false, ), @@ -5381,32 +5475,32 @@ mod tests { vec![ TypedStatement::definition( typed::Variable::new( - CoreIdentifier::from(ShadowedIdentifier::shadow("a", 0)), + CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, true, ) .into(), - FieldElementExpression::Number(3.into()).into(), + FieldElementExpression::Number(3u32.into()).into(), ), TypedStatement::definition( typed::Variable::new( - CoreIdentifier::from(ShadowedIdentifier::shadow("a", 1)), + CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 1)), Type::FieldElement, false, ) .into(), - FieldElementExpression::Number(4.into()).into(), + FieldElementExpression::Number(4u32.into()).into(), ), ], ), TypedStatement::definition( typed::Variable::new( - CoreIdentifier::from(ShadowedIdentifier::shadow("a", 0)), + CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, true, ) .into(), - FieldElementExpression::Number(5.into()).into(), + FieldElementExpression::Number(5u32.into()).into(), ), ]; diff --git a/zokrates_core_test/tests/tests/assembly/binary_check.json b/zokrates_core_test/tests/tests/assembly/binary_check.json new file mode 100644 index 000000000..01df2d425 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/binary_check.json @@ -0,0 +1,46 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": [] + } + } + }, + { + "input": { + "values": ["1"] + }, + "output": { + "Ok": { + "value": [] + } + } + }, + { + "input": { + "values": ["2"] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "error": { + "SourceAssemblyConstraint": { + "file": "tests/tests/assembly/binary_check.zok", + "position": { + "line": 3, + "col": 9 + } + } + } + } + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/binary_check.zok b/zokrates_core_test/tests/tests/assembly/binary_check.zok new file mode 100644 index 000000000..60b08d7cf --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/binary_check.zok @@ -0,0 +1,5 @@ +def main(field x) { + asm { + x * (x - 1) === 0; + } +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/binary_sub.json b/zokrates_core_test/tests/tests/assembly/binary_sub.json new file mode 100644 index 000000000..f74256be0 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/binary_sub.json @@ -0,0 +1,71 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 10, + "tests": [ + { + "input": { + "values": [ + ["0", "0", "1", "1"], + ["0", "0", "0", "1"] + ] + }, + "output": { + "Ok": { + "value": ["0", "0", "1", "0"] + } + } + }, + { + "input": { + "values": [ + ["1", "1", "1", "1"], + ["0", "1", "1", "0"] + ] + }, + "output": { + "Ok": { + "value": ["1", "0", "0", "1"] + } + } + }, + { + "input": { + "values": [ + ["0", "1", "1", "0"], + ["1", "1", "1", "1"] + ] + }, + "output": { + "Ok": { + "value": ["1", "1", "1", "0"] + } + } + }, + { + "input": { + "values": [ + ["0", "0", "0", "0"], + ["1", "1", "1", "1"] + ] + }, + "output": { + "Ok": { + "value": ["1", "0", "0", "0"] + } + } + }, + { + "input": { + "values": [ + ["1", "1", "1", "1"], + ["1", "1", "1", "1"] + ] + }, + "output": { + "Ok": { + "value": ["0", "0", "0", "0"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/binary_sub.zok b/zokrates_core_test/tests/tests/assembly/binary_sub.zok new file mode 100644 index 000000000..88d60780d --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/binary_sub.zok @@ -0,0 +1,36 @@ +// Subtraction of two binary numbers (comprising only two digits, 0 and 1) +// Precondition: caller has to ensure `a` and `b` are constrained by a binary check +def bin_sub(field[N] a, field[N] b) -> field[N] { + field mut lin = 2**N; + field mut lout = 0; + + for u32 i in 0..N { + lin = lin + a[i] * (2**i); + lin = lin - b[i] * (2**i); + } + + field[N] mut out = [0; N]; + for u32 i in 0..N { + asm { + out[i] <-- (lin >> i) & 1; + out[i] * (out[i] - 1) === 0; + } + lout = lout + out[i] * (2**i); + } + + field mut aux = 0; + asm { + aux <-- (lin >> N) & 1; + aux * (aux - 1) === 0; + } + + lout = lout + aux * (2**N); + asm { + lin === lout; + } + return out; +} + +def main(field[4] a, field[4] b) -> field[4] { + return bin_sub(a, b); +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/bitify.json b/zokrates_core_test/tests/tests/assembly/bitify.json new file mode 100644 index 000000000..516aed2d9 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/bitify.json @@ -0,0 +1,65 @@ +{ + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": ["1"] + }, + "output": { + "Ok": { + "value": ["0", "0", "0", "1"] + } + } + }, + { + "input": { + "values": ["2"] + }, + "output": { + "Ok": { + "value": ["0", "0", "1", "0"] + } + } + }, + { + "input": { + "values": ["3"] + }, + "output": { + "Ok": { + "value": ["0", "0", "1", "1"] + } + } + }, + { + "input": { + "values": ["15"] + }, + "output": { + "Ok": { + "value": ["1", "1", "1", "1"] + } + } + }, + { + "input": { + "values": ["16"] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "error": { + "SourceAssemblyConstraint": { + "file": "tests/tests/assembly/bitify.zok", + "position": { + "line": 13, + "col": 9 + } + } + } + } + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/bitify.zok b/zokrates_core_test/tests/tests/assembly/bitify.zok new file mode 100644 index 000000000..ae1f01a5e --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/bitify.zok @@ -0,0 +1,20 @@ +def bitify(field num) -> field[N] { + field[N] mut out = [0; N]; + field mut aux = 0; + for u32 i in 0..N { + u32 j = N - i - 1; + asm { + out[i] <-- (num >> j) & 1; + out[i] * (out[i] - 1) === 0; + } + aux = aux + out[i] * (2**j); + } + asm { + aux === num; + } + return out; +} + +def main(field input) -> field[4] { + return bitify(input); +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/condition.json b/zokrates_core_test/tests/tests/assembly/condition.json new file mode 100644 index 000000000..83d535f24 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/condition.json @@ -0,0 +1,26 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["1", "1", "0"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["0", "1", "0"] + }, + "output": { + "Ok": { + "value": "0" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/condition.zok b/zokrates_core_test/tests/tests/assembly/condition.zok new file mode 100644 index 000000000..05aff5e3e --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/condition.zok @@ -0,0 +1,7 @@ +def main(field c, field l, field r) -> field { + field mut out = 0; + asm { + out <== c * l + (1 - c) * r; + } + return out; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/constraint.json b/zokrates_core_test/tests/tests/assembly/constraint.json new file mode 100644 index 000000000..89dddc24f --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/constraint.json @@ -0,0 +1,36 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": ["1", "1"] + }, + "output": { + "Ok": { + "value": [] + } + } + }, + { + "input": { + "values": ["0", "1"] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "error": { + "SourceAssemblyConstraint": { + "file": "tests/tests/assembly/constraint.zok", + "position": { + "line": 3, + "col": 9 + } + } + } + } + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/constraint.zok b/zokrates_core_test/tests/tests/assembly/constraint.zok new file mode 100644 index 000000000..2be7b61b1 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/constraint.zok @@ -0,0 +1,5 @@ +def main(field x, field y) { + asm { + x === y; + } +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/division.json b/zokrates_core_test/tests/tests/assembly/division.json new file mode 100644 index 000000000..5041b3a0a --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/division.json @@ -0,0 +1,36 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 3, + "tests": [ + { + "input": { + "values": ["4", "2"] + }, + "output": { + "Ok": { + "value": "2" + } + } + }, + { + "input": { + "values": ["0", "0"] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "error": { + "SourceAssemblyConstraint": { + "file": "tests/tests/assembly/division.zok", + "position": { + "line": 6, + "col": 9 + } + } + } + } + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/division.zok b/zokrates_core_test/tests/tests/assembly/division.zok new file mode 100644 index 000000000..b9f65af22 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/division.zok @@ -0,0 +1,11 @@ +def main(field a, field b) -> field { + field mut c = 0; + field mut invb = 0; + asm { + invb <-- b == 0 ? 0 : 1 / b; + invb * b === 1; + c <-- invb * a; + a === b * c; + } + return c; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/gates/and.json b/zokrates_core_test/tests/tests/assembly/gates/and.json new file mode 100644 index 000000000..9002c6c8a --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/and.json @@ -0,0 +1,46 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["0", "0"] + }, + "output": { + "Ok": { + "value": "0" + } + } + }, + { + "input": { + "values": ["1", "0"] + }, + "output": { + "Ok": { + "value": "0" + } + } + }, + { + "input": { + "values": ["0", "1"] + }, + "output": { + "Ok": { + "value": "0" + } + } + }, + { + "input": { + "values": ["1", "1"] + }, + "output": { + "Ok": { + "value": "1" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/gates/and.zok b/zokrates_core_test/tests/tests/assembly/gates/and.zok new file mode 100644 index 000000000..708e95451 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/and.zok @@ -0,0 +1,7 @@ +def main(field a, field b) -> field { + field mut c = 0; + asm { + c <== a * b; + } + return c; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/gates/not.json b/zokrates_core_test/tests/tests/assembly/gates/not.json new file mode 100644 index 000000000..92ab00e11 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/not.json @@ -0,0 +1,26 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["1"] + }, + "output": { + "Ok": { + "value": "0" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/gates/not.zok b/zokrates_core_test/tests/tests/assembly/gates/not.zok new file mode 100644 index 000000000..fa7cdd76e --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/not.zok @@ -0,0 +1,7 @@ +def main(field inp) -> field { + field mut out = 0; + asm { + out <== 1 - inp; + } + return out; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/gates/or.json b/zokrates_core_test/tests/tests/assembly/gates/or.json new file mode 100644 index 000000000..72cb84223 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/or.json @@ -0,0 +1,46 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["0", "0"] + }, + "output": { + "Ok": { + "value": "0" + } + } + }, + { + "input": { + "values": ["1", "0"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["0", "1"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["1", "1"] + }, + "output": { + "Ok": { + "value": "1" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/gates/or.zok b/zokrates_core_test/tests/tests/assembly/gates/or.zok new file mode 100644 index 000000000..f5969b3fd --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/or.zok @@ -0,0 +1,7 @@ +def main(field a, field b) -> field { + field mut c = 0; + asm { + c <== a + b - a*b; + } + return c; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/gates/xor.json b/zokrates_core_test/tests/tests/assembly/gates/xor.json new file mode 100644 index 000000000..14908d6a0 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/xor.json @@ -0,0 +1,46 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["0", "0"] + }, + "output": { + "Ok": { + "value": "0" + } + } + }, + { + "input": { + "values": ["1", "0"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["0", "1"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["1", "1"] + }, + "output": { + "Ok": { + "value": "0" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/gates/xor.zok b/zokrates_core_test/tests/tests/assembly/gates/xor.zok new file mode 100644 index 000000000..ef49721e2 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/gates/xor.zok @@ -0,0 +1,7 @@ +def main(field a, field b) -> field { + field mut c = 0; + asm { + c <== a + b - 2*a*b; + } + return c; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/is_equal.json b/zokrates_core_test/tests/tests/assembly/is_equal.json new file mode 100644 index 000000000..597a5ece3 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/is_equal.json @@ -0,0 +1,36 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 3, + "tests": [ + { + "input": { + "values": ["1", "1"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["2", "4"] + }, + "output": { + "Ok": { + "value": "0" + } + } + }, + { + "input": { + "values": ["4", "2"] + }, + "output": { + "Ok": { + "value": "0" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/is_equal.zok b/zokrates_core_test/tests/tests/assembly/is_equal.zok new file mode 100644 index 000000000..2992008aa --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/is_equal.zok @@ -0,0 +1,5 @@ +import "./is_zero.zok"; + +def main(field a, field b) -> field { + return is_zero(b - a); +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/is_zero.json b/zokrates_core_test/tests/tests/assembly/is_zero.json new file mode 100644 index 000000000..ad409b2cc --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/is_zero.json @@ -0,0 +1,26 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 3, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": "1" + } + } + }, + { + "input": { + "values": ["1"] + }, + "output": { + "Ok": { + "value": "0" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/is_zero.zok b/zokrates_core_test/tests/tests/assembly/is_zero.zok new file mode 100644 index 000000000..b366620df --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/is_zero.zok @@ -0,0 +1,10 @@ +def main(field inp) -> field { + field mut out = 0; + field mut inv = 0; + asm { + inv <-- inp != 0 ? 1 / inp : 0; + out <== -inp * inv + 1; + inp * out === 0; + } + return out; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/less_than.json b/zokrates_core_test/tests/tests/assembly/less_than.json new file mode 100644 index 000000000..91eaaa073 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/less_than.json @@ -0,0 +1,36 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 7, + "tests": [ + { + "input": { + "values": ["2", "2"] + }, + "output": { + "Ok": { + "value": false + } + } + }, + { + "input": { + "values": ["4", "2"] + }, + "output": { + "Ok": { + "value": false + } + } + }, + { + "input": { + "values": ["2", "4"] + }, + "output": { + "Ok": { + "value": true + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/less_than.zok b/zokrates_core_test/tests/tests/assembly/less_than.zok new file mode 100644 index 000000000..6d307aeed --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/less_than.zok @@ -0,0 +1,14 @@ +from "EMBED" import field_to_bool_unsafe; +from "field" import FIELD_SIZE_IN_BITS; +from "./bitify" import bitify; + +def less_than(field a, field b) -> bool { + assert(N < FIELD_SIZE_IN_BITS - 1); + field[N + 1] bits = bitify(a + 2**N - b); + bool out = field_to_bool_unsafe(1 - bits[0]); + return out; +} + +def main(field a, field b) -> bool { + return less_than::<4>(a, b); +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/array_rewrite.json b/zokrates_core_test/tests/tests/assembly/propagation/array_rewrite.json new file mode 100644 index 000000000..307badbc6 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/array_rewrite.json @@ -0,0 +1,16 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["2"] + }, + "output": { + "Ok": { + "value": ["4"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/array_rewrite.zok b/zokrates_core_test/tests/tests/assembly/propagation/array_rewrite.zok new file mode 100644 index 000000000..d6b9fec58 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/array_rewrite.zok @@ -0,0 +1,8 @@ +def main(field x) -> field[1] { + field[1] mut a = [x]; + asm { + a[0] <-- a[0] + a[0]; + a[0] === x + x; + } + return a; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/array_write_constant.json b/zokrates_core_test/tests/tests/assembly/propagation/array_write_constant.json new file mode 100644 index 000000000..bfd4a7277 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/array_write_constant.json @@ -0,0 +1,16 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "value": ["0", "2"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/array_write_constant.zok b/zokrates_core_test/tests/tests/assembly/propagation/array_write_constant.zok new file mode 100644 index 000000000..6ced6eabb --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/array_write_constant.zok @@ -0,0 +1,9 @@ +def main() -> field[2] { + field[2] mut a = [1, 2]; + u32 i = 0; + asm { + a[i] <-- 0; + a[i] === 0; + } + return a; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/array_write_variable.json b/zokrates_core_test/tests/tests/assembly/propagation/array_write_variable.json new file mode 100644 index 000000000..854325640 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/array_write_variable.json @@ -0,0 +1,16 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["42"] + }, + "output": { + "Ok": { + "value": ["42", "2"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/array_write_variable.zok b/zokrates_core_test/tests/tests/assembly/propagation/array_write_variable.zok new file mode 100644 index 000000000..e114e85be --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/array_write_variable.zok @@ -0,0 +1,9 @@ +def main(field v) -> field[2] { + field[2] mut a = [1, 2]; + u32 i = 0; + asm { + a[i] <-- v; + a[i] === v; + } + return a; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/definition.json b/zokrates_core_test/tests/tests/assembly/propagation/definition.json new file mode 100644 index 000000000..40a30f3ca --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/definition.json @@ -0,0 +1,5 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 0, + "tests": [] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/definition.zok b/zokrates_core_test/tests/tests/assembly/propagation/definition.zok new file mode 100644 index 000000000..f63743a83 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/definition.zok @@ -0,0 +1,8 @@ +def main() { + field mut a = 0; + asm { + a <-- 1; + a === 1; + } + return; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/empty.json b/zokrates_core_test/tests/tests/assembly/propagation/empty.json new file mode 100644 index 000000000..40a30f3ca --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/empty.json @@ -0,0 +1,5 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 0, + "tests": [] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/empty.zok b/zokrates_core_test/tests/tests/assembly/propagation/empty.zok new file mode 100644 index 000000000..33152ccc4 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/empty.zok @@ -0,0 +1,5 @@ +def main() { + asm { + } + return; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/redefinition.json b/zokrates_core_test/tests/tests/assembly/propagation/redefinition.json new file mode 100644 index 000000000..40a30f3ca --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/redefinition.json @@ -0,0 +1,5 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 0, + "tests": [] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/redefinition.zok b/zokrates_core_test/tests/tests/assembly/propagation/redefinition.zok new file mode 100644 index 000000000..5643e5274 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/redefinition.zok @@ -0,0 +1,10 @@ +def main() { + field mut a = 0; + field mut b = 0; + asm { + a <-- 1; + b <-- a; + b === 1; + } + return; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/struct_write_variable.json b/zokrates_core_test/tests/tests/assembly/propagation/struct_write_variable.json new file mode 100644 index 000000000..95caa1b2d --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/struct_write_variable.json @@ -0,0 +1,23 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 4, + "tests": [ + { + "input": { + "values": [ + { + "a": ["0", "0"] + }, + "42" + ] + }, + "output": { + "Ok": { + "value": { + "a": ["42", "84"] + } + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/struct_write_variable.zok b/zokrates_core_test/tests/tests/assembly/propagation/struct_write_variable.zok new file mode 100644 index 000000000..f1fa62f53 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/struct_write_variable.zok @@ -0,0 +1,12 @@ +struct Foo { + field[2] a; +} + +def main(Foo mut foo, field v) -> Foo { + u32 i = 0; + asm { + foo.a[i] <== v; + foo.a[i + 1] <== foo.a[i] * 2; + } + return foo; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/variable_mutation.json b/zokrates_core_test/tests/tests/assembly/propagation/variable_mutation.json new file mode 100644 index 000000000..40a30f3ca --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/variable_mutation.json @@ -0,0 +1,5 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 0, + "tests": [] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/variable_mutation.zok b/zokrates_core_test/tests/tests/assembly/propagation/variable_mutation.zok new file mode 100644 index 000000000..fde2e15ef --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/variable_mutation.zok @@ -0,0 +1,6 @@ +def main(field mut x, field y) { + asm { + x <-- y; // `x` is mutated here and gets a new ssa version + x === y; // so this constraint is going to be removed by the redefinition optimizer + } +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/propagation/write_variable.json b/zokrates_core_test/tests/tests/assembly/propagation/write_variable.json new file mode 100644 index 000000000..f827f1db2 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/write_variable.json @@ -0,0 +1,16 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": ["42"] + }, + "output": { + "Ok": { + "value": "42" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/propagation/write_variable.zok b/zokrates_core_test/tests/tests/assembly/propagation/write_variable.zok new file mode 100644 index 000000000..9585140ab --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/propagation/write_variable.zok @@ -0,0 +1,8 @@ +def main(field v) -> field { + field mut a = 0; + asm { + a <-- v; + a === v; + } + return a; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assembly/sha256_maj.json b/zokrates_core_test/tests/tests/assembly/sha256_maj.json new file mode 100644 index 000000000..f67b210a6 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/sha256_maj.json @@ -0,0 +1,62 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 6, + "tests": [ + { + "input": { + "values": [ + ["0", "1"], + ["1", "0"], + ["1", "1"] + ] + }, + "output": { + "Ok": { + "value": ["1", "1"] + } + } + }, + { + "input": { + "values": [ + ["1", "0"], + ["1", "0"], + ["1", "0"] + ] + }, + "output": { + "Ok": { + "value": ["1", "0"] + } + } + }, + { + "input": { + "values": [ + ["1", "0"], + ["0", "1"], + ["0", "0"] + ] + }, + "output": { + "Ok": { + "value": ["0", "0"] + } + } + }, + { + "input": { + "values": [ + ["1", "0"], + ["1", "0"], + ["1", "0"] + ] + }, + "output": { + "Ok": { + "value": ["1", "0"] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/assembly/sha256_maj.zok b/zokrates_core_test/tests/tests/assembly/sha256_maj.zok new file mode 100644 index 000000000..9d3f7d4e1 --- /dev/null +++ b/zokrates_core_test/tests/tests/assembly/sha256_maj.zok @@ -0,0 +1,17 @@ +// Maj function for sha256 => (a & b) ^ (a & c) ^ (b & c) +// Precondition: caller has to ensure inputs are constrained by a binary check +def maj(field[N] a, field[N] b, field[N] c) -> field[N] { + field[N] mut out = [0; N]; + field[N] mut bc = [0; N]; + for u32 i in 0..N { + asm { + bc[i] <== b[i] * c[i]; + out[i] <== a[i] * (b[i] + c[i] - 2*bc[i]) + bc[i]; + } + } + return out; +} + +def main(field[2] a, field[2] b, field[2] c) -> field[2] { + return maj(a, b, c); +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/assert_array_equality.json b/zokrates_core_test/tests/tests/assert_array_equality.json index 783f433b0..0abafed8d 100644 --- a/zokrates_core_test/tests/tests/assert_array_equality.json +++ b/zokrates_core_test/tests/tests/assert_array_equality.json @@ -22,7 +22,13 @@ "left": "0", "right": "1", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/assert_array_equality.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/assert_array_equality.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/assert_one.json b/zokrates_core_test/tests/tests/assert_one.json index ff1b80946..d11e4e1b7 100644 --- a/zokrates_core_test/tests/tests/assert_one.json +++ b/zokrates_core_test/tests/tests/assert_one.json @@ -12,7 +12,13 @@ "left": "0", "right": "1", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/assert_one.zok:2:2" + "SourceAssertion": { + "file": "./tests/tests/assert_one.zok", + "position": { + "line": 2, + "col": 2 + } + } } } } diff --git a/zokrates_core_test/tests/tests/embed/field_to_bool.json b/zokrates_core_test/tests/tests/embed/field_to_bool.json new file mode 100644 index 000000000..d0f8980d8 --- /dev/null +++ b/zokrates_core_test/tests/tests/embed/field_to_bool.json @@ -0,0 +1,46 @@ +{ + "curves": ["Bn128"], + "max_constraint_count": 2, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": false + } + } + }, + { + "input": { + "values": ["1"] + }, + "output": { + "Ok": { + "value": true + } + } + }, + { + "input": { + "values": ["2"] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "error": { + "SourceAssemblyConstraint": { + "file": "tests/tests/embed/field_to_bool.zok", + "position": { + "line": 5, + "col": 9 + } + } + } + } + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/embed/field_to_bool.zok b/zokrates_core_test/tests/tests/embed/field_to_bool.zok new file mode 100644 index 000000000..ebf77e9b7 --- /dev/null +++ b/zokrates_core_test/tests/tests/embed/field_to_bool.zok @@ -0,0 +1,9 @@ +from "EMBED" import field_to_bool_unsafe; + +def main(field x) -> bool { + asm { + x * (x - 1) === 0; + } + bool out = field_to_bool_unsafe(x); + return out; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/field_bitwise_op.json b/zokrates_core_test/tests/tests/field_bitwise_op.json new file mode 100644 index 000000000..2960fe4c8 --- /dev/null +++ b/zokrates_core_test/tests/tests/field_bitwise_op.json @@ -0,0 +1,15 @@ +{ + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "value": "4" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/field_bitwise_op.zok b/zokrates_core_test/tests/tests/field_bitwise_op.zok new file mode 100644 index 000000000..cecae79cf --- /dev/null +++ b/zokrates_core_test/tests/tests/field_bitwise_op.zok @@ -0,0 +1,4 @@ +def main() -> field { + field a = 1 << 2; // constant bitwise operations are allowed as they get propagated away + return a; +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json b/zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json index 820ffbeb1..a9ebff5d6 100644 --- a/zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json +++ b/zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json @@ -12,7 +12,13 @@ "left": "0", "right": "1", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/panics/conditional_bound_throw.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/panics/conditional_bound_throw.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -28,7 +34,13 @@ "left": "1", "right": "0", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/panics/conditional_bound_throw.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/panics/conditional_bound_throw.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -44,7 +56,13 @@ "left": "2", "right": "0", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/panics/conditional_bound_throw.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/panics/conditional_bound_throw.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/panics/deep_branch_no_isolation.json b/zokrates_core_test/tests/tests/panics/deep_branch_no_isolation.json index 1a10b3853..430a2c180 100644 --- a/zokrates_core_test/tests/tests/panics/deep_branch_no_isolation.json +++ b/zokrates_core_test/tests/tests/panics/deep_branch_no_isolation.json @@ -12,7 +12,13 @@ "left": "0", "right": "1", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/panics/deep_branch.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/panics/deep_branch.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/panics/loop_bound.json b/zokrates_core_test/tests/tests/panics/loop_bound.json index 5caf89e0a..b70feec4b 100644 --- a/zokrates_core_test/tests/tests/panics/loop_bound.json +++ b/zokrates_core_test/tests/tests/panics/loop_bound.json @@ -12,7 +12,13 @@ "left": "0", "right": "1", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/panics/loop_bound.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/panics/loop_bound.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/panics/panic_isolation.json b/zokrates_core_test/tests/tests/panics/panic_isolation.json index 434b45823..b928bcd17 100644 --- a/zokrates_core_test/tests/tests/panics/panic_isolation.json +++ b/zokrates_core_test/tests/tests/panics/panic_isolation.json @@ -15,7 +15,13 @@ "left": "1", "right": "21888242871839275222246405745257275088548364400416034343698204186575808495577", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/panics/panic_isolation.zok:22:5" + "SourceAssertion": { + "file": "./tests/tests/panics/panic_isolation.zok", + "position": { + "line": 22, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/panics/panic_no_isolation.json b/zokrates_core_test/tests/tests/panics/panic_no_isolation.json index 349358227..107f79cd4 100644 --- a/zokrates_core_test/tests/tests/panics/panic_no_isolation.json +++ b/zokrates_core_test/tests/tests/panics/panic_no_isolation.json @@ -15,7 +15,13 @@ "left": "1", "right": "0", "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/panics/panic_isolation.zok:17:5" + "SourceAssertion": { + "file": "./tests/tests/panics/panic_isolation.zok", + "position": { + "line": 17, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/range_check/assert_ge.json b/zokrates_core_test/tests/tests/range_check/assert_ge.json index 9cd01522f..21f8219b0 100644 --- a/zokrates_core_test/tests/tests/range_check/assert_ge.json +++ b/zokrates_core_test/tests/tests/range_check/assert_ge.json @@ -11,7 +11,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_ge.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_ge.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -25,7 +31,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_ge.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_ge.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/range_check/assert_gt.json b/zokrates_core_test/tests/tests/range_check/assert_gt.json index ea3314b06..7e800f8bb 100644 --- a/zokrates_core_test/tests/tests/range_check/assert_gt.json +++ b/zokrates_core_test/tests/tests/range_check/assert_gt.json @@ -11,7 +11,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_gt.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_gt.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -25,7 +31,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_gt.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_gt.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -39,7 +51,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_gt.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_gt.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/range_check/assert_gt_big_constant.json b/zokrates_core_test/tests/tests/range_check/assert_gt_big_constant.json index 0d92e7b45..37df775ca 100644 --- a/zokrates_core_test/tests/tests/range_check/assert_gt_big_constant.json +++ b/zokrates_core_test/tests/tests/range_check/assert_gt_big_constant.json @@ -11,7 +11,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_gt_big_constant.zok:4:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_gt_big_constant.zok", + "position": { + "line": 4, + "col": 5 + } + } } } } @@ -27,7 +33,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_gt_big_constant.zok:4:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_gt_big_constant.zok", + "position": { + "line": 4, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/range_check/assert_le.json b/zokrates_core_test/tests/tests/range_check/assert_le.json index 27502fdc9..694379525 100644 --- a/zokrates_core_test/tests/tests/range_check/assert_le.json +++ b/zokrates_core_test/tests/tests/range_check/assert_le.json @@ -31,7 +31,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_le.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_le.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -45,7 +51,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_le.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_le.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/range_check/assert_lt.json b/zokrates_core_test/tests/tests/range_check/assert_lt.json index bdaee03b4..05c0e90eb 100644 --- a/zokrates_core_test/tests/tests/range_check/assert_lt.json +++ b/zokrates_core_test/tests/tests/range_check/assert_lt.json @@ -31,7 +31,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_lt.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_lt.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -45,7 +51,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_lt.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_lt.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/range_check/assert_lt_big_constant.json b/zokrates_core_test/tests/tests/range_check/assert_lt_big_constant.json index 9baeb5aa3..2d3b1b107 100644 --- a/zokrates_core_test/tests/tests/range_check/assert_lt_big_constant.json +++ b/zokrates_core_test/tests/tests/range_check/assert_lt_big_constant.json @@ -35,7 +35,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_lt_big_constant.zok:4:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_lt_big_constant.zok", + "position": { + "line": 4, + "col": 5 + } + } } } } diff --git a/zokrates_core_test/tests/tests/range_check/assert_lt_u8.json b/zokrates_core_test/tests/tests/range_check/assert_lt_u8.json index 3c2107c4c..cad68703b 100644 --- a/zokrates_core_test/tests/tests/range_check/assert_lt_u8.json +++ b/zokrates_core_test/tests/tests/range_check/assert_lt_u8.json @@ -31,7 +31,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_lt_u8.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_lt_u8.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } @@ -45,7 +51,13 @@ "Err": { "UnsatisfiedConstraint": { "error": { - "SourceAssertion": "Assertion failed at ./tests/tests/range_check/assert_lt_u8.zok:2:5" + "SourceAssertion": { + "file": "./tests/tests/range_check/assert_lt_u8.zok", + "position": { + "line": 2, + "col": 5 + } + } } } } diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 290962557..7e36f1f9f 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -13,6 +13,7 @@ zokrates_field = { version = "0.5", path = "../zokrates_field", default-features zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } zokrates_abi = { version = "0.1", path = "../zokrates_abi", default-features = false } +zokrates_analysis = { version = "0.1", path = "../zokrates_analysis", default-features = false } num = { version = "0.1.36", default-features = false } num-bigint = { version = "0.2", default-features = false } diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 3776d84fa..f33ab3de4 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -1,9 +1,11 @@ use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fmt; use zokrates_abi::{Decode, Value}; use zokrates_ast::ir::{ LinComb, ProgIterator, QuadComb, RuntimeError, Solver, Statement, Variable, Witness, }; +use zokrates_ast::zir; use zokrates_field::Field; pub type ExecutionResult = Result, Error>; @@ -24,21 +26,22 @@ impl Interpreter { } impl Interpreter { - pub fn execute>>( + pub fn execute<'ast, T: Field, I: IntoIterator>>( &self, - program: ProgIterator, + program: ProgIterator<'ast, T, I>, inputs: &[T], ) -> ExecutionResult { self.execute_with_log_stream(program, inputs, &mut std::io::sink()) } pub fn execute_with_log_stream< + 'ast, W: std::io::Write, T: Field, - I: IntoIterator>, + I: IntoIterator>, >( &self, - program: ProgIterator, + program: ProgIterator<'ast, T, I>, inputs: &[T], log_stream: &mut W, ) -> ExecutionResult { @@ -52,6 +55,7 @@ impl Interpreter { for statement in program.statements.into_iter() { match statement { + Statement::Block(..) => unreachable!(), Statement::Constraint(quad, lin, error) => match lin.is_assignee(&witness) { true => { let val = evaluate_quad(&witness, &quad).unwrap(); @@ -81,7 +85,7 @@ impl Interpreter { } _ => Self::execute_solver(&d.solver, &inputs), } - .map_err(|_| Error::Solver)?; + .map_err(Error::Solver)?; for (i, o) in d.outputs.iter().enumerate() { witness.insert(*o, res[i].clone()); @@ -142,9 +146,9 @@ impl Interpreter { .collect() } - fn check_inputs>, U>( + fn check_inputs<'ast, T: Field, I: IntoIterator>, U>( &self, - program: &ProgIterator, + program: &ProgIterator<'ast, T, I>, inputs: &[U], ) -> Result<(), Error> { if program.arguments.len() == inputs.len() { @@ -157,11 +161,79 @@ impl Interpreter { } } - pub fn execute_solver(solver: &Solver, inputs: &[T]) -> Result, String> { + pub fn execute_solver<'ast, T: Field>( + solver: &Solver<'ast, T>, + inputs: &[T], + ) -> Result, String> { let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); let res = match solver { + Solver::Zir(func) => { + use zokrates_ast::zir::result_folder::ResultFolder; + assert_eq!(func.arguments.len(), inputs.len()); + + let constants = func + .arguments + .iter() + .zip(inputs) + .map(|(a, v)| match &a.id._type { + zir::Type::FieldElement => Ok(( + a.id.id.clone(), + zokrates_ast::zir::FieldElementExpression::Number(v.clone()).into(), + )), + zir::Type::Boolean => match v { + v if *v == T::from(0) => Ok(( + a.id.id.clone(), + zokrates_ast::zir::BooleanExpression::Value(false).into(), + )), + v if *v == T::from(1) => Ok(( + a.id.id.clone(), + zokrates_ast::zir::BooleanExpression::Value(true).into(), + )), + v => Err(format!("`{}` has unexpected value `{}`", a.id, v)), + }, + zir::Type::Uint(bitwidth) => match v.bits() <= bitwidth.to_usize() as u32 { + true => Ok(( + a.id.id.clone(), + zokrates_ast::zir::UExpressionInner::Value( + v.to_dec_string().parse::().unwrap(), + ) + .annotate(*bitwidth) + .into(), + )), + false => Err(format!( + "`{}` has unexpected bitwidth (got {} but expected {})", + a.id, + v.bits(), + bitwidth + )), + }, + }) + .collect::, _>>()?; + + let mut propagator = zokrates_analysis::ZirPropagator::with_constants(constants); + + let folded_function = propagator + .fold_function(func.clone()) + .map_err(|e| e.to_string())?; + + assert_eq!(folded_function.statements.len(), 1); + if let zokrates_ast::zir::ZirStatement::Return(v) = + folded_function.statements[0].clone() + { + v.into_iter() + .map(|v| match v { + zokrates_ast::zir::ZirExpression::FieldElement( + zokrates_ast::zir::FieldElementExpression::Number(n), + ) => n, + _ => unreachable!(), + }) + .collect() + } else { + unreachable!() + } + } Solver::ConditionEq => match inputs[0].is_zero() { true => vec![T::zero(), T::one()], false => vec![ @@ -276,7 +348,7 @@ pub struct EvaluationError; #[derive(PartialEq, Eq, Clone, Serialize, Deserialize)] pub enum Error { UnsatisfiedConstraint { error: Option }, - Solver, + Solver(String), WrongInputCount { expected: usize, received: usize }, LogStream, } @@ -319,7 +391,7 @@ impl fmt::Display for Error { _ => write!(f, ""), } } - Error::Solver => write!(f, ""), + Error::Solver(ref e) => write!(f, "Solver error: {}", e), Error::WrongInputCount { expected, received } => write!( f, "Program takes {} input{} but was passed {} value{}", diff --git a/zokrates_js/index.d.ts b/zokrates_js/index.d.ts index 979603e86..7c7bc8e05 100644 --- a/zokrates_js/index.d.ts +++ b/zokrates_js/index.d.ts @@ -60,6 +60,7 @@ declare module "zokrates-js" { snarkjs?: { program: Uint8Array; }; + constraintCount?: number; } export interface SetupKeypair { diff --git a/zokrates_js/lib.js b/zokrates_js/lib.js index 717541231..b3e1ff4ef 100644 --- a/zokrates_js/lib.js +++ b/zokrates_js/lib.js @@ -16,6 +16,7 @@ module.exports = (pkg) => { { program: ptr.program(), abi: ptr.abi(), + constraintCount: ptr.constraint_count(), }, snarkjs ? { snarkjs: { program: ptr.snarkjs_program() } } : {} ); diff --git a/zokrates_js/package-lock.json b/zokrates_js/package-lock.json index 1e3c42f27..b15cb706f 100644 --- a/zokrates_js/package-lock.json +++ b/zokrates_js/package-lock.json @@ -6,13 +6,13 @@ "packages": { "": { "name": "zokrates-js", - "version": "1.1.1", + "version": "1.1.2", "license": "GPLv3", "devDependencies": { "dree": "^2.6.1", "mocha": "^9.2.0", "rimraf": "^3.0.2", - "snarkjs": "^0.4.24", + "snarkjs": "^0.4.25", "wasm-pack": "^0.10.2" } }, diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index cf24252e0..07082c3e5 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -20,10 +20,8 @@ use zokrates_ast::typed::types::{ConcreteSignature, ConcreteType, GTupleType}; use zokrates_bellman::Bellman; use zokrates_circom::{write_r1cs, write_witness}; use zokrates_common::helpers::{BackendParameter, CurveParameter, SchemeParameter}; -use zokrates_common::Resolver; -use zokrates_core::compile::{ - compile as core_compile, CompilationArtifacts, CompileConfig, CompileError, -}; +use zokrates_common::{CompileConfig, Resolver}; +use zokrates_core::compile::{compile as core_compile, CompilationArtifacts, CompileError}; use zokrates_core::imports::Error; use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field}; use zokrates_proof_systems::groth16::G16; @@ -43,6 +41,7 @@ pub struct CompilationResult { program: Vec, abi: Abi, snarkjs_program: Option>, + constraint_count: u32, } #[wasm_bindgen] @@ -63,6 +62,10 @@ impl CompilationResult { arr }) } + + pub fn constraint_count(&self) -> JsValue { + JsValue::from_serde(&self.constraint_count).unwrap() + } } #[derive(Serialize, Deserialize)] @@ -255,6 +258,7 @@ mod internal { let abi = artifacts.abi().clone(); let program = artifacts.prog().collect(); + let constraint_count = program.constraint_count() as u32; let snarkjs_program = with_snarkjs_program.then(|| { let mut buffer = Cursor::new(vec![]); write_r1cs(&mut buffer, program.clone()).unwrap(); @@ -267,6 +271,7 @@ mod internal { abi, program: buffer.into_inner(), snarkjs_program, + constraint_count, }) } @@ -349,13 +354,14 @@ mod internal { } pub fn setup_universal< + 'a, T: Field, - I: IntoIterator>, + I: IntoIterator>, S: UniversalScheme + Serialize, B: UniversalBackend, >( srs: &[u8], - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, ) -> Result { let keypair = B::setup(srs.to_vec(), program).map_err(|e| JsValue::from_str(&e))?; Ok(JsValue::from_serde(&TaggedKeypair::::new(keypair)).unwrap()) diff --git a/zokrates_js/tests/tests.js b/zokrates_js/tests/tests.js index 86ed3bcd0..8202baa12 100644 --- a/zokrates_js/tests/tests.js +++ b/zokrates_js/tests/tests.js @@ -44,6 +44,7 @@ describe("tests", () => { ); assert.ok(artifacts); assert.ok(artifacts.snarkjs === undefined); + assert.equal(artifacts.constraintCount, 1); }); }); diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index b3de98d53..af4998e09 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -52,7 +52,7 @@ _mut = {"mut"} // Statements -statement = { (iteration_statement // does not require semicolon +statement = { (iteration_statement | asm_statement // does not require semicolon | ((log_statement |return_statement | definition_statement @@ -66,6 +66,16 @@ return_statement = { "return" ~ expression? } definition_statement = { typed_identifier_or_assignee ~ "=" ~ expression } assertion_statement = {"assert" ~ "(" ~ expression ~ ("," ~ quoted_string)? ~ ")"} +op_asm_assign = @{"<--"} +op_asm_assign_constrain = @{"<=="} +op_asm = { op_asm_assign | op_asm_assign_constrain } + +asm_assignment = { assignee ~ op_asm ~ expression } +asm_constraint = { expression ~ "===" ~ expression } + +asm_statement_inner = { (asm_assignment | asm_constraint) ~ semicolon } +asm_statement = { "asm" ~ "{" ~ NEWLINE* ~ asm_statement_inner* ~ NEWLINE* ~ "}" } + typed_identifier_or_assignee = { typed_identifier | assignee } // Expressions diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index f528cb2bd..65268b304 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -8,11 +8,12 @@ use zokrates_parser::Rule; extern crate lazy_static; pub use ast::{ - Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, - Assignee, AssigneeAccess, BasicOrStructOrTupleType, BasicType, BinaryExpression, - BinaryOperator, CallAccess, ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, - DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, - File, FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression, + Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssemblyStatement, + AssemblyStatementInner, AssertionStatement, Assignee, AssigneeAccess, AssignmentOperator, + BasicOrStructOrTupleType, BasicType, BinaryExpression, BinaryOperator, CallAccess, + ConstantDefinition, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, + DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, + FromExpression, FunctionDefinition, HexLiteralExpression, HexNumberExpression, IdentifierExpression, IdentifierOrDecimal, IfElseExpression, ImportDirective, ImportSymbol, InlineArrayExpression, InlineStructExpression, InlineStructMember, InlineTupleExpression, IterationStatement, LiteralExpression, LogStatement, Parameter, PostfixExpression, Range, @@ -366,6 +367,7 @@ mod ast { Assertion(AssertionStatement<'ast>), Iteration(IterationStatement<'ast>), Log(LogStatement<'ast>), + Assembly(AssemblyStatement<'ast>), } #[derive(Debug, FromPest, PartialEq, Clone)] @@ -431,6 +433,55 @@ mod ast { pub span: Span<'ast>, } + #[derive(Debug, FromPest, PartialEq, Eq, Clone)] + #[pest_ast(rule(Rule::op_asm))] + pub enum AssignmentOperator { + Assign(AssignOperator), + AssignConstrain(AssignConstrainOperator), + } + + #[derive(Debug, FromPest, PartialEq, Eq, Clone)] + #[pest_ast(rule(Rule::op_asm_assign))] + pub struct AssignOperator; + + #[derive(Debug, FromPest, PartialEq, Eq, Clone)] + #[pest_ast(rule(Rule::op_asm_assign_constrain))] + pub struct AssignConstrainOperator; + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_assignment))] + pub struct AssemblyAssignment<'ast> { + pub assignee: Assignee<'ast>, + pub operator: AssignmentOperator, + pub expression: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_constraint))] + pub struct AssemblyConstraint<'ast> { + pub lhs: Expression<'ast>, + pub rhs: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_statement_inner))] + pub enum AssemblyStatementInner<'ast> { + Assignment(AssemblyAssignment<'ast>), + Constraint(AssemblyConstraint<'ast>), + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::asm_statement))] + pub struct AssemblyStatement<'ast> { + pub inner: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, PartialEq, Eq, Clone)] pub enum BinaryOperator { BitXor, diff --git a/zokrates_proof_systems/src/lib.rs b/zokrates_proof_systems/src/lib.rs index 231fbeee7..7076cde06 100644 --- a/zokrates_proof_systems/src/lib.rs +++ b/zokrates_proof_systems/src/lib.rs @@ -96,8 +96,8 @@ impl ToString for G2AffineFq2 { } pub trait Backend> { - fn generate_proof>>( - program: ir::ProgIterator, + fn generate_proof<'a, I: IntoIterator>>( + program: ir::ProgIterator<'a, T, I>, witness: ir::Witness, proving_key: Vec, ) -> Proof; @@ -105,23 +105,23 @@ pub trait Backend> { fn verify(vk: S::VerificationKey, proof: Proof) -> bool; } pub trait NonUniversalBackend>: Backend { - fn setup>>( - program: ir::ProgIterator, + fn setup<'a, I: IntoIterator>>( + program: ir::ProgIterator<'a, T, I>, ) -> SetupKeypair; } pub trait UniversalBackend>: Backend { fn universal_setup(size: u32) -> Vec; - fn setup>>( + fn setup<'a, I: IntoIterator>>( srs: Vec, - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, ) -> Result, String>; } pub trait MpcBackend> { - fn initialize>>( - program: ir::ProgIterator, + fn initialize<'a, R: Read, W: Write, I: IntoIterator>>( + program: ir::ProgIterator<'a, T, I>, phase1_radix: &mut R, output: &mut W, ) -> Result<(), String>; @@ -132,9 +132,9 @@ pub trait MpcBackend> { output: &mut W, ) -> Result<[u8; 64], String>; - fn verify>>( + fn verify<'a, P: Read, R: Read, I: IntoIterator>>( params: &mut P, - program: ir::ProgIterator, + program: ir::ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String>; diff --git a/zokrates_test/src/lib.rs b/zokrates_test/src/lib.rs index fd4b0f7a4..0aa85f907 100644 --- a/zokrates_test/src/lib.rs +++ b/zokrates_test/src/lib.rs @@ -8,7 +8,8 @@ use std::path::{Path, PathBuf}; use zokrates_ast::typed::types::GTupleType; use zokrates_ast::typed::ConcreteSignature; use zokrates_ast::typed::ConcreteType; -use zokrates_core::compile::{compile, CompileConfig}; +use zokrates_common::CompileConfig; +use zokrates_core::compile::compile; use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field}; use zokrates_fs_resolver::FileSystemResolver; diff --git a/zokrates_test/tests/out_of_range.rs b/zokrates_test/tests/out_of_range.rs index 9304757d8..ea2800252 100644 --- a/zokrates_test/tests/out_of_range.rs +++ b/zokrates_test/tests/out_of_range.rs @@ -4,8 +4,8 @@ extern crate zokrates_field; use std::io; use typed_arena::Arena; +use zokrates_common::CompileConfig; use zokrates_common::Resolver; -use zokrates_core::compile::CompileConfig; use zokrates_core::compile::{compile, CompilationArtifacts}; use zokrates_field::Bn128Field; use zokrates_fs_resolver::FileSystemResolver;