Skip to content

Commit

Permalink
arith_ext: add transforms for arith.remui patterns
Browse files Browse the repository at this point in the history
  - We can reduce runtime computation by avoiding computing remui and
    instead use a Barrett reduction or SubIfGE operation when possible
  • Loading branch information
inbelic committed May 31, 2024
1 parent 266d8fd commit 4ff9835
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 0 deletions.
78 changes: 78 additions & 0 deletions lib/Dialect/ArithExt/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = [
"Passes.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ArithExt/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "RewriteModulo",
srcs = ["RewriteModulo.cpp"],
hdrs = [
"RewriteModulo.h",
],
deps = [
":rewrite_modulo_inc_gen",
":pass_inc_gen",
"@heir//lib/Dialect/ArithExt/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TransformUtils",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ArithExt",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"ArithExtPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

gentbl_cc_library(
name = "rewrite_modulo_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"RewriteModulo.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "RewriteModulo.td",
deps = [
"@heir//lib/DRR",
"@heir//lib/Dialect/ArithExt/IR:ops_inc_gen",
"@heir//lib/Dialect/ArithExt/IR:td_files",
"@heir//lib/Dialect/Polynomial/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
],
)
18 changes: 18 additions & 0 deletions lib/Dialect/ArithExt/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_

#include "lib/Dialect/ArithExt/IR/ArithExtDialect.h"
#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h"

namespace mlir {
namespace heir {
namespace arith_ext {

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc"

} // namespace arith_ext
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_
15 changes: 15 additions & 0 deletions lib/Dialect/ArithExt/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_

include "mlir/Pass/PassBase.td"

def RewriteModulo : Pass<"convert-remui-to-arith-ext"> {
let summary = "Rewrites arith remui patterns to their arith_ext equivalents";
let description = [{
Applies a rewrite pattern to convert the following arith patterns to their
equivalent using the arith_ext operations.
}];
let dependentDialects = ["mlir::heir::arith_ext::ArithExtDialect"];
}

#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_
32 changes: 32 additions & 0 deletions lib/Dialect/ArithExt/Transforms/RewriteModulo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h"

#include "lib/Dialect/ArithExt/IR/ArithExtOps.h"
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace arith_ext {

#define GEN_PASS_DEF_REWRITEMODULO
#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc"

namespace rewrites {
// In an inner namespace to avoid conflicts with canonicalization patterns
#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.cpp.inc"
} // namespace rewrites

struct RewriteModulo : impl::RewriteModuloBase<RewriteModulo> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
rewrites::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace arith_ext
} // namespace heir
} // namespace mlir
18 changes: 18 additions & 0 deletions lib/Dialect/ArithExt/Transforms/RewriteModulo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_

#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace arith_ext {

#define GEN_PASS_DECL_REWRITEMODULO
#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc"

} // namespace arith_ext
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_
64 changes: 64 additions & 0 deletions lib/Dialect/ArithExt/Transforms/RewriteModulo.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_

include "lib/DRR/Utils.td"
include "lib/Dialect/Polynomial/IR/PolynomialAttributes.td"
include "lib/Dialect/ArithExt/IR/ArithExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/IR/PatternBase.td"

def DefaultBitWidth : NativeCodeCall<
"IntegerAttr::get("
" IntegerType::get($_builder.getContext(), 64), "
" (dyn_cast<RankedTensorType>($0.getType())).getElementTypeBitWidth() / 2)">;

def Normalised : Constraint<
CPred<"true">,
"ensure the values will be within the range [0, cmod).">;

def EncodedRankedTensorType : Constraint<
CPred<"(dyn_cast<RankedTensorType>($0.getType())).getEncoding()">,
"ensure the operation type is a ranked tensor type with a will be within the range [0, cmod).">;

def GetEncodedRingModulo : NativeCodeCall<
"IntegerAttr::get("
" IntegerType::get($_builder.getContext(), 64), "
" (dyn_cast<::mlir::heir::polynomial::RingAttr>((dyn_cast<RankedTensorType>($0.getType())).getEncoding())).coefficientModulus())">;

def RewriteAddRem : Pattern<
(Arith_RemUIOp:$remOp (Arith_AddIOp:$addOp $lhs, $rhs, $overflow), $cmod),
[
(ArithExt_SubIfGEOp $addOp, $cmod)
],
[
(Normalised $lhs),
(Normalised $rhs)
]
>;

def RewriteSubRem : Pattern<
(Arith_RemUIOp:$remOp (Arith_SubIOp:$subOp $lhs, $rhs, $overflow), $cmod),
[
(ArithExt_SubIfGEOp
(Arith_AddIOp $subOp, $cmod, $overflow), $cmod)
],
[
(Normalised $lhs),
(Normalised $rhs)
]
>;

def RewriteMulRem : Pattern<
(Arith_RemUIOp:$remOp (Arith_MulIOp:$mulOp $lhs, $rhs, $overflow), $cmod),
[
(ArithExt_SubIfGEOp
(ArithExt_BarrettReduceOp $mulOp, (DefaultBitWidth $mulOp), (GetEncodedRingModulo $mulOp)), $cmod)
],
[
(Normalised $lhs),
(Normalised $rhs),
(EncodedRankedTensorType $mulOp)
]
>;

#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_
42 changes: 42 additions & 0 deletions tests/arith/rewrite-modulo.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: heir-opt -convert-remui-to-arith-ext %s | FileCheck %s

// CHECK-LABEL: @test_add_rewrite
func.func @test_add_rewrite(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
%cmod_vec = arith.constant dense<17> : tensor<4xi8>

// CHECK: arith.addi
// CHECK: arith_ext.subifge
%add = arith.addi %lhs, %rhs : tensor<4xi8>
%res = arith.remui %add, %cmod_vec: tensor<4xi8>

return %res : tensor<4xi8>
}

// CHECK-LABEL: @test_sub_rewrite
func.func @test_sub_rewrite(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
%cmod_vec = arith.constant dense<17> : tensor<4xi8>

// CHECK: arith.subi
// CHECK: arith.addi
// CHECK: arith_ext.subifge
%sub = arith.subi %lhs, %rhs : tensor<4xi8>
%res = arith.remui %sub, %cmod_vec : tensor<4xi8>

return %res : tensor<4xi8>
}

#ideal = #_polynomial.polynomial<1 + x**4>
#ring = #_polynomial.ring<cmod=17, ideal=#ideal, root=2>

// CHECK-LABEL: @test_mul_rewrite
func.func @test_mul_rewrite(%lhs : tensor<4xi8, #ring>, %rhs : tensor<4xi8, #ring>) -> tensor<4xi8, #ring> {
%cmod_vec = arith.constant dense<17> : tensor<4xi8, #ring>

// CHECK: arith.muli
// CHECK: arith_ext.barrett_reduce
// CHECK: arith_ext.subifge
%sub = arith.muli %lhs, %rhs : tensor<4xi8, #ring>
%res = arith.remui %sub, %cmod_vec : tensor<4xi8, #ring>

return %res : tensor<4xi8, #ring>
}
2 changes: 2 additions & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ cc_binary(
"@heir//lib/Conversion/PolynomialToStandard",
"@heir//lib/Conversion/SecretToBGV",
"@heir//lib/Dialect/ArithExt/IR:Dialect",
"@heir//lib/Dialect/ArithExt/Transforms",
"@heir//lib/Dialect/ArithExt/Transforms:RewriteModulo",
"@heir//lib/Dialect/BGV/IR:Dialect",
"@heir//lib/Dialect/BGV/Transforms",
"@heir//lib/Dialect/BGV/Transforms:AddClientInterface",
Expand Down
3 changes: 3 additions & 0 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "lib/Conversion/PolynomialToStandard/PolynomialToStandard.h"
#include "lib/Conversion/SecretToBGV/SecretToBGV.h"
#include "lib/Dialect/ArithExt/IR/ArithExtDialect.h"
#include "lib/Dialect/ArithExt/Transforms/Passes.h"
#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h"
#include "lib/Dialect/BGV/IR/BGVDialect.h"
#include "lib/Dialect/BGV/Transforms/AddClientInterface.h"
#include "lib/Dialect/BGV/Transforms/Passes.h"
Expand Down Expand Up @@ -498,6 +500,7 @@ int main(int argc, char **argv) {
registerAllPasses();

// Custom passes in HEIR
arith_ext::registerArithExtPasses();
bgv::registerBGVPasses();
cggi::registerCGGIPasses();
lwe::registerLWEPasses();
Expand Down

0 comments on commit 4ff9835

Please sign in to comment.