Added support for stablehlo.remainder op. (#1126)
Added tests. Tested with ttrt.
kmitrovicTT authored Nov 4, 2024
1 parent e6c60fd commit 51f6356
Showing 13 changed files with 96 additions and 2 deletions.
15 changes: 15 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -420,6 +420,21 @@ def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
}];
}

def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> {
let summary = "Eltwise remainder.";
let description = [{
Performs element-wise remainder of dividend lhs and divisor rhs tensors and produces a
result tensor.

Example:

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "ttir.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
}];
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {

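The example in the TTIR_RemainderOp description above follows truncated (C-style) remainder semantics: the result takes the sign of the dividend. A minimal host-side check of exactly those example values, purely illustrative and not part of this patch, could look like:

```cpp
// Illustrative only: confirms that the example values in the op
// description match C++'s truncated remainder, where the result's sign
// follows the dividend. Not part of the patch.
#include <cassert>
#include <cstdint>

int main() {
  const std::int64_t lhs[] = {17, -17, 17, -17};
  const std::int64_t rhs[] = {3, 3, -3, -3};
  const std::int64_t expected[] = {2, -2, 2, -2};
  for (int i = 0; i < 4; ++i) {
    assert(lhs[i] % rhs[i] == expected[i]);
  }
  return 0;
}
```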
15 changes: 15 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -390,6 +390,21 @@ def TTNN_SubtractOp : TTNN_ElementwiseBinaryOp<"subtract"> {
}];
}

def TTNN_RemainderOp : TTNN_ElementwiseBinaryOp<"remainder"> {
let summary = "Eltwise remainder.";
let description = [{
Performs element-wise remainder of dividend lhs and divisor rhs tensors and produces a
result tensor.

Example:

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "ttnn.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
}];
}

class TTNN_ReductionOp<string mnemonic, list<Trait> traits = []> : TTNN_Op<mnemonic, traits> {
let summary = "Reduction op.";
let description = [{
3 changes: 2 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -89,7 +89,8 @@ enum EltwiseOpType: uint32 {
Log = 28,
Log1p = 29,
Expm1 = 30,
Sign = 31
Sign = 31,
Remainder = 32
}

union EltwiseOpParams {
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -944,6 +944,8 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SubtractOp, mlir::tt::ttir::SubtractOp>>(typeConverter,
ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::RemOp, mlir::tt::ttir::RemainderOp>>(typeConverter, ctx);
}

void addReduceOpsConversionPatterns(MLIRContext *ctx,
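For context, `StableHLOToTTIROpDefaultConversionPattern` registers a one-to-one rewrite from `stablehlo.remainder` to `ttir.remainder`. The sketch below shows the general shape of such a pattern using the upstream MLIR conversion API; it is a conceptual illustration only. The class name and the destination-op builder arguments are assumptions (the real TTIR elementwise ops are DPS-style and carry operand constraints, as the added tests show), not the helper's actual implementation.

```cpp
// Conceptual sketch, not the real StableHLOToTTIROpDefaultConversionPattern.
// The DstOp builder call is simplified: the actual TTIR elementwise ops take
// a DPS output tensor and operand_constraints, which are omitted here.
#include "mlir/Transforms/DialectConversion.h"

template <typename SrcOp, typename DstOp>
struct OneToOneElementwisePattern : mlir::OpConversionPattern<SrcOp> {
  using mlir::OpConversionPattern<SrcOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Map the converted operands and the converted result type one-to-one
    // onto the destination op.
    rewriter.replaceOpWithNewOp<DstOp>(
        op, this->getTypeConverter()->convertType(op.getResult().getType()),
        adaptor.getOperands());
    return mlir::success();
  }
};
```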
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -887,6 +887,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::SinOp, ttnn::SinOp>,
ElementwiseOpConversionPattern<ttir::CosOp, ttnn::CosOp>,
ElementwiseOpConversionPattern<ttir::Expm1Op, ttnn::Expm1Op>,
ElementwiseOpConversionPattern<ttir::RemainderOp, ttnn::RemainderOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
4 changes: 3 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -647,7 +647,9 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::LessThanOp>,
DefaultOpConversionPattern<ttnn::MaximumOp>,
DefaultOpConversionPattern<ttnn::MinimumOp>,
DefaultOpConversionPattern<ttnn::DivOp>>(typeConverter, ctx);
DefaultOpConversionPattern<ttnn::DivOp>,
DefaultOpConversionPattern<ttnn::RemainderOp>>(typeConverter,
ctx);

// Tensor manipulation ops
//
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -359,6 +359,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Log;
} else if constexpr (std::is_same_v<EltwiseOp, Expm1Op>) {
type = ::tt::target::ttnn::EltwiseOpType::Expm1;
} else if constexpr (std::is_same_v<EltwiseOp, RemainderOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Remainder;
} else {
llvm_unreachable("unhandled EltwiseOp");
}
@@ -628,6 +630,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto divOp = dyn_cast<DivOp>(op); divOp) {
return createOperation(cache, createEltwiseOp(cache, divOp), debugString);
}
if (auto remainderOp = dyn_cast<RemainderOp>(op); remainderOp) {
return createOperation(cache, createEltwiseOp(cache, remainderOp),
debugString);
}
if (auto matmulOp = dyn_cast<MatmulOp>(op); matmulOp) {
return createOperation(cache, createOp(cache, matmulOp), debugString);
}
@@ -38,6 +38,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Remainder: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::remainder);
break;
}
default:
throw std::invalid_argument(
"Unsupported Eltwise Binary Composite operation");
@@ -14,6 +14,7 @@ inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Maximum:
case ::tt::target::ttnn::EltwiseOpType::Minimum:
case ::tt::target::ttnn::EltwiseOpType::Remainder:
return true;
default:
return false;
12 changes: 12 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/binary/remainder_op.mlir
@@ -0,0 +1,12 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_remainder attributes {} {
func.func public @test_remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
%0 = stablehlo.remainder %arg0, %arg1 : tensor<32x32xf32>
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[REM:[0-9]+]] = "ttir.remainder"(%arg0, %arg1, %[[EMPTY]]){{.*}} -> tensor<32x32xf32>
return %0 : tensor<32x32xf32>
// CHECK: return %[[REM]] : tensor<32x32xf32>
}
}
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
%0 = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}}
%1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}}
return %1 : tensor<32x32xf32>
// CHECK: return {{.*}} : tensor<32x32xf32, {{.*}}
}
}
14 changes: 14 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir
@@ -0,0 +1,14 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
%0 = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}}
%1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}}
return %1 : tensor<32x32xf32>
// CHECK: return {{.*}} : tensor<32x32xf32, {{.*}}
}
9 changes: 9 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -218,3 +218,12 @@ func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
return %1 : tensor<64x128xf32>
// CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
}

func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
%0 = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}}
%1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}}
return %1 : tensor<32x32xf32>
// CHECK: return {{.*}} : tensor<32x32xf32, {{.*}}
}
