Skip to content

Commit

Permalink
Add support for clamp op.
Browse files Browse the repository at this point in the history
* Add end-to-end implementation of the ops.
* Add stablehlo to ttir conversion for clamp op.
  • Loading branch information
mmanzoorTT committed Nov 5, 2024
1 parent 6bad515 commit 669e8a0
Show file tree
Hide file tree
Showing 13 changed files with 274 additions and 1 deletion.
27 changes: 27 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,33 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {
let hasVerifier = 1;
}

def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
let summary = "Clamp op.";
let description = [{
Clamp tensor values to a specified range.

Example:
min: 2.000000+00
input: [[0, 1, 2, 3, 4, 5, 6, 7]]
max: 5.000000+00

"ttir.clamp"(%arg0) <{max = 2.000000e+00 : f32, min = 5.000000e+00 : f32}>
-> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
F32Attr:$min,
F32Attr:$max,
TT_OperandConstraintArrayAttr:$operand_constraints);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let results = (outs AnyRankedTensor:$result);
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
Expand Down
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,27 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
let hasVerifier = 1;
}

def TTNN_ClampOp : TTNN_Op<"clamp"> {
let summary = "Clamp op.";
let description = [{
Clamp tensor values to a specified range.

Example:
min: 2.000000+00
input: [[0, 1, 2, 3, 4, 5, 6, 7]]
max: 5.000000+00

"ttnn.clamp"(%arg0) <{max = 2.000000e+00 : f32, min = 5.000000e+00 : f32}>
-> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
}];

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
F32Attr:$min,
F32Attr:$max);

let results = (outs Variadic<AnyRankedTensor>:$result);
}

// Note: NoMemoryEffect is used to indicate that operation can be removed if it is not used.
// Removal of this operation is done by the dead code elimination pass (RemoveDeadValuesPass).
def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
Expand Down
8 changes: 7 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ include "Common/debug_info.fbs";

namespace tt.target.ttnn;

table ClampOpParams {
min: float;
max: float;
}

table GetDeviceOp {
mesh: Dim2d;
chip_ids: [uint32];
Expand Down Expand Up @@ -93,10 +98,11 @@ enum EltwiseOpType: uint32 {
Remainder = 32,
IsFinite = 33,
Floor = 34,
Clamp = 35,
}

union EltwiseOpParams {

ClampOpParams,
}

table EltwiseOp {
Expand Down
74 changes: 74 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,74 @@ class StableHLOToTTIRSliceOpConversionPattern
}
};

class StableHLOToTTIROpClampOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ClampOp> {

using OpConversionPattern<mlir::stablehlo::ClampOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::ClampOp srcOp,
mlir::stablehlo::ClampOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType outputType = mlir::cast<RankedTensorType>(
this->getTypeConverter()->convertType(srcOp.getResult().getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
Value min = adaptor.getMin();
Value max = adaptor.getMax();
Operation *minDefiningOp = min.getDefiningOp();
Operation *maxDefiningOp = max.getDefiningOp();
if (minDefiningOp && maxDefiningOp &&
isa<mlir::tt::ttir::ConstantOp>(minDefiningOp) &&
isa<mlir::tt::ttir::ConstantOp>(maxDefiningOp)) {
mlir::ElementsAttr minValAttr =
mlir::cast<mlir::tt::ttir::ConstantOp>(minDefiningOp).getValueAttr();
mlir::ElementsAttr maxValAttr =
mlir::cast<mlir::tt::ttir::ConstantOp>(maxDefiningOp).getValueAttr();
if (minValAttr.isSplat() && maxValAttr.isSplat()) {
float minValue =
minValAttr.getElementType().isInteger()
? static_cast<float>(minValAttr.getSplatValue<int>())
: minValAttr.getSplatValue<float>();
float maxValue =
maxValAttr.getElementType().isInteger()
? static_cast<float>(maxValAttr.getSplatValue<int>())
: maxValAttr.getSplatValue<float>();
rewriter.replaceOpWithNewOp<mlir::tt::ttir::ClampOp>(
srcOp,
this->getTypeConverter()->convertType(outputTensor.getType()),
Value(adaptor.getOperand()), Value(outputTensor),
rewriter.getF32FloatAttr(minValue),
rewriter.getF32FloatAttr(maxValue),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

return success();
}
}

ttir::MaximumOp maximumOp = rewriter.create<mlir::tt::ttir::MaximumOp>(
srcOp->getLoc(), min, adaptor.getOperand(), outputTensor,
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

tensor::EmptyOp finalOutputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
rewriter.replaceOpWithNewOp<mlir::tt::ttir::MinimumOp>(
srcOp, maximumOp->getResult(0), max, finalOutputTensor,
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1036,6 +1104,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
}

void addClampOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIROpClampOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
Expand All @@ -1057,6 +1130,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addLogicalOpConversionPattern(ctx, patterns, typeConverter);
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addClampOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
15 changes: 15 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,20 @@ class TransposeOpConversionPattern
}
};

class ClampOpConversionPattern : public OpConversionPattern<ttir::ClampOp> {
public:
using OpConversionPattern<ttir::ClampOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ClampOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::ClampOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getMin(), adaptor.getMax());
return success();
}
};

class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
public:
using OpConversionPattern<ttir::ConcatOp>::OpConversionPattern;
Expand Down Expand Up @@ -898,6 +912,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
TypecastOpConversionPattern,
ClampOpConversionPattern,
ConcatOpConversionPattern,
ReshapeOpConversionPattern,
SliceOpConversionPattern,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>,
DefaultOpConversionPattern<ttnn::CbrtOp>,
DefaultOpConversionPattern<ttnn::ClampOp>,
DefaultOpConversionPattern<ttnn::FloorOp>,
DefaultOpConversionPattern<ttnn::IsFiniteOp>,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
Expand Down
38 changes: 38 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,40 @@ createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
op.getDim(), op.getNumLinks());
}

::flatbuffers::Offset<::tt::target::ttnn::ClampOpParams>
createEltwiseOpParams(FlatbufferObjectCache &cache, ClampOp op) {
auto min = op.getMin().convertToFloat();
auto max = op.getMax().convertToFloat();
return ::tt::target::ttnn::CreateClampOpParams(*cache.fbb, min, max);
}

template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createNonDPSEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
::tt::target::ttnn::EltwiseOpType type;
::tt::target::ttnn::EltwiseOpParams paramsType =
::tt::target::ttnn::EltwiseOpParams::NONE;
::flatbuffers::Offset<void> params = 0;
if constexpr (std::is_same_v<EltwiseOp, ClampOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Clamp;
paramsType = ::tt::target::ttnn::EltwiseOpParams::ClampOpParams;
params = createEltwiseOpParams(cache, op).Union();
} else {
llvm_unreachable("unhandled non-DPS EltwiseOp");
}

std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> ins;
for (auto input : op.getInputs()) {
ins.push_back(
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(input)));
}
assert(op.getResults().size() == 1);
auto out = cache.getOrCreate(op.getResults().front(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
return ::tt::target::ttnn::CreateEltwiseOpDirect(*cache.fbb, type, &ins, out,
paramsType, params);
}

template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
Expand Down Expand Up @@ -670,6 +704,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createTransposeOp(cache, transposeOp),
debugString);
}
if (auto clampOp = dyn_cast<ClampOp>(op); clampOp) {
return createOperation(cache, createNonDPSEltwiseOp(cache, clampOp),
debugString);
}
if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
return createOperation(cache, createOp(cache, conv2dOp), debugString);
}
Expand Down
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,34 @@ static void runEltwiseUnaryCompositeOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryCompositeClampOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float, float,
const ::tt::tt_metal::MemoryConfig &)>
ttnnOp) {
::ttnn::Tensor *in = nullptr;
getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

float min = op->params_as_ClampOpParams()->min();
float max = op->params_as_ClampOpParams()->max();
::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());
::ttnn::Tensor out = ttnnOp(*in, min, max, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
return;
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Clamp: {
runEltwiseUnaryCompositeClampOP(op, tensorPool, ::ttnn::clamp);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Log1p: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::log1p);
break;
Expand Down
2 changes: 2 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Cbrt:
return true;
case ::tt::target::ttnn::EltwiseOpType::Clamp:
return true;
case ::tt::target::ttnn::EltwiseOpType::Log1p:
return true;
default:
Expand Down
26 changes: 26 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_transpose attributes {} {
func.func public @test_clamp_constant(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
%cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
// CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
// CHECK-SAME: max = 3.000000e+00 : f32, min = 2.000000e+00 : f32,
// CHECK-SMAE: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%0 = stablehlo.clamp %cst, %arg0, %cst_0 : tensor<4xf32>
return %0 : tensor<4xf32>
}

func.func public @test_clamp_tensor(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: %[[EMPTY0:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
// CHECK: %[[MAX:.*]] = "ttir.maximum"(%arg1, %arg0, %[[EMPTY0]])
// CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
// CHECK: %[[EMPTY1:.*]] = tensor.empty() : [[TENSOR]]
// CHECK: %[[MIN:.*]] = "ttir.minimum"(%[[MAX]], %arg2, %[[EMPTY1]])
// CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%0 = stablehlo.clamp %arg1, %arg0, %arg2 : tensor<4xf32>
// CHECK: return %[[MIN]] : [[TENSOR]]
return %0 : tensor<4xf32>
}
}
14 changes: 14 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_clamp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
// CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
// CHECK: = "ttnn.clamp"(%[[LAYOUT]])
// CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
// CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
%1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %1 : tensor<64x128xbf16>
}
}
17 changes: 17 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
// CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
// CHECK: = "ttnn.clamp"(%[[LAYOUT]])
// CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
// CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
%1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %1 : tensor<64x128xbf16>
}
11 changes: 11 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ func.func @ceil(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> {
return %1 : tensor<32x32xf32>
}

func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
// CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
// CHECK: = "ttnn.clamp"(%[[LAYOUT]])
// CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
// CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
%1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %1 : tensor<64x128xbf16>
}

func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x96xf32>
Expand Down

0 comments on commit 669e8a0

Please sign in to comment.