-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
arith_ext: add transforms for arith.remui patterns
- We can reduce runtime computation by avoiding computing remui and instead use a Barrett reduction or SubIfGE operation when possible
- Loading branch information
Showing
9 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "Transforms", | ||
hdrs = [ | ||
"Passes.h", | ||
], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ArithExt/IR:Dialect", | ||
"@llvm-project//mlir:IR", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "RewriteModulo", | ||
srcs = ["RewriteModulo.cpp"], | ||
hdrs = [ | ||
"RewriteModulo.h", | ||
], | ||
deps = [ | ||
":rewrite_modulo_inc_gen", | ||
":pass_inc_gen", | ||
"@heir//lib/Dialect/ArithExt/IR:Dialect", | ||
"@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:TransformUtils", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=ArithExt", | ||
], | ||
"Passes.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"ArithExtPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "Passes.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "rewrite_modulo_inc_gen", | ||
tbl_outs = [ | ||
( | ||
["-gen-rewriters"], | ||
"RewriteModulo.cpp.inc", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "RewriteModulo.td", | ||
deps = [ | ||
"@heir//lib/DRR", | ||
"@heir//lib/Dialect/ArithExt/IR:ops_inc_gen", | ||
"@heir//lib/Dialect/ArithExt/IR:td_files", | ||
"@heir//lib/Dialect/Polynomial/IR:td_files", | ||
"@llvm-project//mlir:ArithOpsTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_ | ||
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_ | ||
|
||
#include "lib/Dialect/ArithExt/IR/ArithExtDialect.h" | ||
#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace arith_ext { | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace arith_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_ | ||
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def RewriteModulo : Pass<"convert-remui-to-arith-ext"> { | ||
let summary = "Rewrites arith remui patterns to their arith_ext equivalents"; | ||
let description = [{ | ||
Applies a rewrite pattern to convert the following arith patterns to their | ||
equivalent using the arith_ext operations. | ||
}]; | ||
let dependentDialects = ["mlir::heir::arith_ext::ArithExtDialect"]; | ||
} | ||
|
||
#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_PASSES_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.h" | ||
|
||
#include "lib/Dialect/ArithExt/IR/ArithExtOps.h" | ||
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h" | ||
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace arith_ext { | ||
|
||
#define GEN_PASS_DEF_REWRITEMODULO | ||
#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc" | ||
|
||
namespace rewrites { | ||
// In an inner namespace to avoid conflicts with canonicalization patterns | ||
#include "lib/Dialect/ArithExt/Transforms/RewriteModulo.cpp.inc" | ||
} // namespace rewrites | ||
|
||
struct RewriteModulo : impl::RewriteModuloBase<RewriteModulo> { | ||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
rewrites::populateWithGenerated(patterns); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace arith_ext | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_ | ||
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_ | ||
|
||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace arith_ext { | ||
|
||
#define GEN_PASS_DECL_REWRITEMODULO | ||
#include "lib/Dialect/ArithExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace arith_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#ifndef LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_ | ||
#define LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_ | ||
|
||
include "lib/DRR/Utils.td" | ||
include "lib/Dialect/Polynomial/IR/PolynomialAttributes.td" | ||
include "lib/Dialect/ArithExt/IR/ArithExtOps.td" | ||
include "mlir/Dialect/Arith/IR/ArithOps.td" | ||
include "mlir/IR/PatternBase.td" | ||
|
||
def DefaultBitWidth : NativeCodeCall< | ||
"IntegerAttr::get(" | ||
" IntegerType::get($_builder.getContext(), 64), " | ||
" (dyn_cast<RankedTensorType>($0.getType())).getElementTypeBitWidth() / 2)">; | ||
|
||
def Normalised : Constraint< | ||
CPred<"true">, | ||
"ensure the values will be within the range [0, cmod).">; | ||
|
||
def EncodedRankedTensorType : Constraint< | ||
CPred<"(dyn_cast<RankedTensorType>($0.getType())).getEncoding()">, | ||
"ensure the operation type is a ranked tensor type with a will be within the range [0, cmod).">; | ||
|
||
def GetEncodedRingModulo : NativeCodeCall< | ||
"IntegerAttr::get(" | ||
" IntegerType::get($_builder.getContext(), 64), " | ||
" (dyn_cast<::mlir::heir::polynomial::RingAttr>((dyn_cast<RankedTensorType>($0.getType())).getEncoding())).coefficientModulus())">; | ||
|
||
def RewriteAddRem : Pattern< | ||
(Arith_RemUIOp:$remOp (Arith_AddIOp:$addOp $lhs, $rhs, $overflow), $cmod), | ||
[ | ||
(ArithExt_SubIfGEOp $addOp, $cmod) | ||
], | ||
[ | ||
(Normalised $lhs), | ||
(Normalised $rhs) | ||
] | ||
>; | ||
|
||
def RewriteSubRem : Pattern< | ||
(Arith_RemUIOp:$remOp (Arith_SubIOp:$subOp $lhs, $rhs, $overflow), $cmod), | ||
[ | ||
(ArithExt_SubIfGEOp | ||
(Arith_AddIOp $subOp, $cmod, $overflow), $cmod) | ||
], | ||
[ | ||
(Normalised $lhs), | ||
(Normalised $rhs) | ||
] | ||
>; | ||
|
||
def RewriteMulRem : Pattern< | ||
(Arith_RemUIOp:$remOp (Arith_MulIOp:$mulOp $lhs, $rhs, $overflow), $cmod), | ||
[ | ||
(ArithExt_SubIfGEOp | ||
(ArithExt_BarrettReduceOp $mulOp, (DefaultBitWidth $mulOp), (GetEncodedRingModulo $mulOp)), $cmod) | ||
], | ||
[ | ||
(Normalised $lhs), | ||
(Normalised $rhs), | ||
(EncodedRankedTensorType $mulOp) | ||
] | ||
>; | ||
|
||
#endif // LIB_DIALECT_ARITHEXT_TRANSFORMS_REWRITEMODULO_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// RUN: heir-opt -convert-remui-to-arith-ext %s | FileCheck %s | ||
|
||
// CHECK-LABEL: @test_add_rewrite | ||
func.func @test_add_rewrite(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { | ||
%cmod_vec = arith.constant dense<17> : tensor<4xi8> | ||
|
||
// CHECK: arith.addi | ||
// CHECK: arith_ext.subifge | ||
%add = arith.addi %lhs, %rhs : tensor<4xi8> | ||
%res = arith.remui %add, %cmod_vec: tensor<4xi8> | ||
|
||
return %res : tensor<4xi8> | ||
} | ||
|
||
// CHECK-LABEL: @test_sub_rewrite | ||
func.func @test_sub_rewrite(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { | ||
%cmod_vec = arith.constant dense<17> : tensor<4xi8> | ||
|
||
// CHECK: arith.subi | ||
// CHECK: arith.addi | ||
// CHECK: arith_ext.subifge | ||
%sub = arith.subi %lhs, %rhs : tensor<4xi8> | ||
%res = arith.remui %sub, %cmod_vec : tensor<4xi8> | ||
|
||
return %res : tensor<4xi8> | ||
} | ||
|
||
#ideal = #_polynomial.polynomial<1 + x**4> | ||
#ring = #_polynomial.ring<cmod=17, ideal=#ideal, root=2> | ||
|
||
// CHECK-LABEL: @test_mul_rewrite | ||
func.func @test_mul_rewrite(%lhs : tensor<4xi8, #ring>, %rhs : tensor<4xi8, #ring>) -> tensor<4xi8, #ring> { | ||
%cmod_vec = arith.constant dense<17> : tensor<4xi8, #ring> | ||
|
||
// CHECK: arith.muli | ||
// CHECK: arith_ext.barrett_reduce | ||
// CHECK: arith_ext.subifge | ||
%sub = arith.muli %lhs, %rhs : tensor<4xi8, #ring> | ||
%res = arith.remui %sub, %cmod_vec : tensor<4xi8, #ring> | ||
|
||
return %res : tensor<4xi8, #ring> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters