diff --git a/lib/Dialect/ArithExt/Transforms/BUILD b/lib/Dialect/ArithExt/Transforms/BUILD new file mode 100644 index 000000000..d2596662d --- /dev/null +++ b/lib/Dialect/ArithExt/Transforms/BUILD @@ -0,0 +1,78 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Transforms", + hdrs = [ + "Passes.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/ArithExt/IR:Dialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "RewriteModulo", + srcs = ["RewriteModulo.cpp"], + hdrs = [ + "RewriteModulo.h", + ], + deps = [ + ":rewrite_modulo_inc_gen", + ":pass_inc_gen", + "@heir//lib/Dialect/ArithExt/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", + ], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ArithExt", + ], + "Passes.h.inc", + ), + ( + ["-gen-pass-doc"], + "ArithExtPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "rewrite_modulo_inc_gen", + tbl_outs = [ + ( + ["-gen-rewriters"], + "RewriteModulo.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "RewriteModulo.td", + deps = [ + "@heir//lib/DRR", + "@heir//lib/Dialect/ArithExt/IR:ops_inc_gen", + "@heir//lib/Dialect/ArithExt/IR:td_files", + "@heir//lib/Dialect/Polynomial/IR:td_files", + "@llvm-project//mlir:ArithOpsTdFiles", + ], +) diff --git a/lib/Dialect/ArithExt/Transforms/Passes.h b/lib/Dialect/ArithExt/Transforms/Passes.h new file mode 100644 index 000000000..725513ba0 --- /dev/null +++ b/lib/Dialect/ArithExt/Transforms/Passes.h @@ -0,0 +1,18 @@ +#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_ +#define LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_ + +#include "lib/Dialect/ArithExt/IR/ArithExtDialect.h" +#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h" + +namespace mlir { +namespace heir { +namespace arith_ext { + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc" + +} // namespace arith_ext +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_ diff --git a/lib/Dialect/ArithExt/Transforms/Passes.td b/lib/Dialect/ArithExt/Transforms/Passes.td new file mode 100644 index 000000000..ec0b6e544 --- /dev/null +++ b/lib/Dialect/ArithExt/Transforms/Passes.td @@ -0,0 +1,15 @@ +#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_ +#define LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def RewriteModulo : Pass<"convert-remui-to-arith-ext"> { + let summary = "Rewrites arith remui patterns to their arith_ext equivalents"; + let description = [{ + Applies a rewrite pattern to convert the following arith patterns to their + equivalent using the arith_ext operations. + }]; + let dependentDialects = ["mlir::heir::arith_ext::ArithExtDialect"]; +} + +#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_ diff --git a/lib/Dialect/ArithExt/Transforms/RewriteModulo.cpp b/lib/Dialect/ArithExt/Transforms/RewriteModulo.cpp new file mode 100644 index 000000000..425bed658 --- /dev/null +++ b/lib/Dialect/ArithExt/Transforms/RewriteModulo.cpp @@ -0,0 +1,32 @@ +#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h" + +#include "lib/Dialect/ArithExt/IR/ArithExtOps.h" +#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace arith_ext { + +#define GEN_PASS_DEF_REWRITEMODULO +#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc" + +namespace rewrites { +// In an inner namespace to avoid conflicts with canonicalization patterns +#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.cpp.inc" +} // namespace rewrites + +struct RewriteModulo : impl::RewriteModuloBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + rewrites::populateWithGenerated(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace arith_ext +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/ArithExt/Transforms/RewriteModulo.h b/lib/Dialect/ArithExt/Transforms/RewriteModulo.h new file mode 100644 index 000000000..482924854 --- /dev/null +++ b/lib/Dialect/ArithExt/Transforms/RewriteModulo.h @@ -0,0 +1,18 @@ +#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_ +#define LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_ + +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace arith_ext { + +#define GEN_PASS_DECL_REWRITEMODULO +#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc" + +} // namespace arith_ext +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_ diff --git a/lib/Dialect/ArithExt/Transforms/RewriteModulo.td b/lib/Dialect/ArithExt/Transforms/RewriteModulo.td new file mode 100644 index 000000000..855689ef6 --- /dev/null +++ b/lib/Dialect/ArithExt/Transforms/RewriteModulo.td @@ -0,0 +1,64 @@ +#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_ +#define LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_ + +include "lib/DRR/Utils.td" +include "lib/Dialect/Polynomial/IR/PolynomialAttributes.td" +include "lib/Dialect/ArithExt/IR/ArithExtOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "mlir/IR/PatternBase.td" + +def DefaultBitWidth : NativeCodeCall< + "IntegerAttr::get(" + " IntegerType::get($_builder.getContext(), 64), " + " (dyn_cast($0.getType())).getElementTypeBitWidth() / 2)">; + +def Normalised : Constraint< + CPred<"true">, + "ensure the values will be within the range [0, cmod).">; + +def EncodedRankedTensorType : Constraint< + CPred<"(dyn_cast($0.getType())).getEncoding()">, + "ensure the operation type is a ranked tensor type with a will be within the range [0, cmod).">; + +def GetEncodedRingModulo : NativeCodeCall< + "IntegerAttr::get(" + " IntegerType::get($_builder.getContext(), 64), " + " (dyn_cast<::mlir::heir::polynomial::RingAttr>((dyn_cast($0.getType())).getEncoding())).coefficientModulus())">; + +def RewriteAddRem : Pattern< + (Arith_RemUIOp:$remOp (Arith_AddIOp:$addOp $lhs, $rhs, $overflow), $cmod), + [ + (ArithExt_SubIfGEOp $addOp, $cmod) + ], + [ + (Normalised $lhs), + (Normalised $rhs) + ] +>; + +def RewriteSubRem : Pattern< + (Arith_RemUIOp:$remOp (Arith_SubIOp:$subOp $lhs, $rhs, $overflow), $cmod), + [ + (ArithExt_SubIfGEOp + (Arith_AddIOp $subOp, $cmod, $overflow), $cmod) + ], + [ + (Normalised $lhs), + (Normalised $rhs) + ] +>; + +def RewriteMulRem : Pattern< + (Arith_RemUIOp:$remOp (Arith_MulIOp:$mulOp $lhs, $rhs, $overflow), $cmod), + [ + (ArithExt_SubIfGEOp + (ArithExt_BarrettReduceOp $mulOp, (DefaultBitWidth $mulOp), (GetEncodedRingModulo $mulOp)), $cmod) + ], + [ + (Normalised $lhs), + (Normalised $rhs), + (EncodedRankedTensorType $mulOp) + ] +>; + +#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_ diff --git a/tests/arith/rewrite-modulo.mlir b/tests/arith/rewrite-modulo.mlir new file mode 100644 index 000000000..43fcb20d6 --- /dev/null +++ b/tests/arith/rewrite-modulo.mlir @@ -0,0 +1,42 @@ +// RUN: heir-opt -convert-remui-to-arith-ext %s | FileCheck %s + +// CHECK-LABEL: @test_add_rewrite +func.func @test_add_rewrite(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { + %cmod_vec = arith.constant dense<17> : tensor<4xi8> + + // CHECK: arith.addi + // CHECK: arith_ext.subifge + %add = arith.addi %lhs, %rhs : tensor<4xi8> + %res = arith.remui %add, %cmod_vec: tensor<4xi8> + + return %res : tensor<4xi8> +} + +// CHECK-LABEL: @test_sub_rewrite +func.func @test_sub_rewrite(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { + %cmod_vec = arith.constant dense<17> : tensor<4xi8> + + // CHECK: arith.subi + // CHECK: arith.addi + // CHECK: arith_ext.subifge + %sub = arith.subi %lhs, %rhs : tensor<4xi8> + %res = arith.remui %sub, %cmod_vec : tensor<4xi8> + + return %res : tensor<4xi8> +} + +#ideal = #_polynomial.polynomial<1 + x**4> +#ring = #_polynomial.ring + +// CHECK-LABEL: @test_mul_rewrite +func.func @test_mul_rewrite(%lhs : tensor<4xi8, #ring>, %rhs : tensor<4xi8, #ring>) -> tensor<4xi8, #ring> { + %cmod_vec = arith.constant dense<17> : tensor<4xi8, #ring> + + // CHECK: arith.muli + // CHECK: arith_ext.barrett_reduce + // CHECK: arith_ext.subifge + %sub = arith.muli %lhs, %rhs : tensor<4xi8, #ring> + %res = arith.remui %sub, %cmod_vec : tensor<4xi8, #ring> + + return %res : tensor<4xi8, #ring> +} diff --git a/tools/BUILD b/tools/BUILD index 9b4d113f0..3b06eeccf 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -42,6 +42,8 @@ cc_binary( "@heir//lib/Conversion/PolynomialToStandard", "@heir//lib/Conversion/SecretToBGV", "@heir//lib/Dialect/ArithExt/IR:Dialect", + "@heir//lib/Dialect/ArithExt/Transforms", + "@heir//lib/Dialect/ArithExt/Transforms:RewriteModulo", "@heir//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Dialect/BGV/Transforms", "@heir//lib/Dialect/BGV/Transforms:AddClientInterface", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index c6a335d64..e68fe5107 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -11,6 +11,8 @@ #include "lib/Conversion/PolynomialToStandard/PolynomialToStandard.h" #include "lib/Conversion/SecretToBGV/SecretToBGV.h" #include "lib/Dialect/ArithExt/IR/ArithExtDialect.h" +#include "lib/Dialect/ArithExt/Transforms/Passes.h" +#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h" #include "lib/Dialect/BGV/IR/BGVDialect.h" #include "lib/Dialect/BGV/Transforms/AddClientInterface.h" #include "lib/Dialect/BGV/Transforms/Passes.h" @@ -498,6 +500,7 @@ int main(int argc, char **argv) { registerAllPasses(); // Custom passes in HEIR + arith_ext::registerArithExtPasses(); bgv::registerBGVPasses(); cggi::registerCGGIPasses(); lwe::registerLWEPasses();