diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 381e3750f..9effabdf5 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -850,6 +850,33 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {
   let hasVerifier = 1;
 }
 
+def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
+  let summary = "Clamp op.";
+  let description = [{
+    Clamp tensor values to a specified range.
+
+    Example:
+    min: 2.000000e+00
+    input: [[0, 1, 2, 3, 4, 5, 6, 7]]
+    max: 5.000000e+00
+
+    "ttir.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
+    -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       F32Attr:$min,
+                       F32Attr:$max,
+                       TT_OperandConstraintArrayAttr:$operand_constraints);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let results = (outs AnyRankedTensor:$result);
+}
+
 def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
                                            AllShapesMatch<["value", "result"]>]> {
   let summary = "Constant op.";
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 91cb51cca..39c235642 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -638,6 +638,27 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
   let hasVerifier = 1;
 }
 
+def TTNN_ClampOp : TTNN_Op<"clamp"> {
+  let summary = "Clamp op.";
+  let description = [{
+    Clamp tensor values to a specified range.
+
+    Example:
+    min: 2.000000e+00
+    input: [[0, 1, 2, 3, 4, 5, 6, 7]]
+    max: 5.000000e+00
+
+    "ttnn.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
+    -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
+  }];
+
+  let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
+                       F32Attr:$min,
+                       F32Attr:$max);
+
+  let results = (outs Variadic<AnyRankedTensor>:$result);
+}
+
 // Note: NoMemoryEffect is used to indicate that operation can be removed if it is not used.
 // Removal of this operation is done by the dead code elimination pass (RemoveDeadValuesPass).
 def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
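The new TTIR and TTNN clamp ops carry the bounds as scalar F32 attributes. As a standalone reference for the documented example (a minimal sketch, not part of this patch), the elementwise semantics are out[i] = min(max(in[i], min), max):

```cpp
// Standalone reference for the op description above (not part of this patch):
// clamp with min = 2.0 and max = 5.0 maps [0..7] to [2, 2, 2, 3, 4, 5, 5, 5].
#include <algorithm>
#include <cassert>
#include <vector>

std::vector<float> clampReference(const std::vector<float> &input,
                                  float minVal, float maxVal) {
  std::vector<float> out(input.size());
  for (size_t i = 0; i < input.size(); ++i)
    out[i] = std::min(std::max(input[i], minVal), maxVal);
  return out;
}

int main() {
  std::vector<float> in = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<float> expected = {2, 2, 2, 3, 4, 5, 5, 5};
  assert(clampReference(in, 2.0f, 5.0f) == expected);
  return 0;
}
```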
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 1918fa035..390a5d366 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -3,6 +3,11 @@ include "Common/debug_info.fbs";
 
 namespace tt.target.ttnn;
 
+table ClampOpParams {
+  min: float;
+  max: float;
+}
+
 table GetDeviceOp {
   mesh: Dim2d;
   chip_ids: [uint32];
@@ -93,10 +98,11 @@ enum EltwiseOpType: uint32 {
   Remainder = 32,
   IsFinite = 33,
   Floor = 34,
+  Clamp = 35,
 }
 
 union EltwiseOpParams {
-
+  ClampOpParams,
 }
 
 table EltwiseOp {
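ClampOpParams is the first member of the previously empty EltwiseOpParams union, so a Clamp EltwiseOp now carries its bounds inline in the flatbuffer. A minimal sketch (assuming the FlatBuffers-generated header for program.fbs is included) of reading them back through the generated accessors that the runtime change later in this patch relies on:

```cpp
// Sketch only: uses the generated accessors `type()` and
// `params_as_ClampOpParams()` that the runtime change below also uses.
#include <utility>

std::pair<float, float>
readClampBounds(const ::tt::target::ttnn::EltwiseOp *op) {
  if (op->type() == ::tt::target::ttnn::EltwiseOpType::Clamp) {
    const auto *params = op->params_as_ClampOpParams();
    return {params->min(), params->max()};
  }
  // Every other eltwise op carries EltwiseOpParams::NONE in this patch.
  return {0.0f, 0.0f};
}
```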
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 4672e144b..5dc4b50ed 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -894,6 +894,74 @@ class StableHLOToTTIRSliceOpConversionPattern
   }
 };
 
+class StableHLOToTTIROpClampOpConversionPattern
+    : public OpConversionPattern<mlir::stablehlo::ClampOp> {
+
+  using OpConversionPattern<mlir::stablehlo::ClampOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(mlir::stablehlo::ClampOp srcOp,
+                  mlir::stablehlo::ClampOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType outputType = mlir::cast<RankedTensorType>(
+        this->getTypeConverter()->convertType(srcOp.getResult().getType()));
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+    Value min = adaptor.getMin();
+    Value max = adaptor.getMax();
+    Operation *minDefiningOp = min.getDefiningOp();
+    Operation *maxDefiningOp = max.getDefiningOp();
+    if (minDefiningOp && maxDefiningOp &&
+        isa<mlir::tt::ttir::ConstantOp>(minDefiningOp) &&
+        isa<mlir::tt::ttir::ConstantOp>(maxDefiningOp)) {
+      mlir::ElementsAttr minValAttr =
+          mlir::cast<mlir::tt::ttir::ConstantOp>(minDefiningOp).getValueAttr();
+      mlir::ElementsAttr maxValAttr =
+          mlir::cast<mlir::tt::ttir::ConstantOp>(maxDefiningOp).getValueAttr();
+      if (minValAttr.isSplat() && maxValAttr.isSplat()) {
+        float minValue =
+            minValAttr.getElementType().isInteger()
+                ? static_cast<float>(minValAttr.getSplatValue<int>())
+                : minValAttr.getSplatValue<float>();
+        float maxValue =
+            maxValAttr.getElementType().isInteger()
+                ? static_cast<float>(maxValAttr.getSplatValue<int>())
+                : maxValAttr.getSplatValue<float>();
+        rewriter.replaceOpWithNewOp<mlir::tt::ttir::ClampOp>(
+            srcOp,
+            this->getTypeConverter()->convertType(outputTensor.getType()),
+            Value(adaptor.getOperand()), Value(outputTensor),
+            rewriter.getF32FloatAttr(minValue),
+            rewriter.getF32FloatAttr(maxValue),
+            rewriter.getArrayAttr(
+                SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                       rewriter.getAttr<OperandConstraintAttr>(
+                                           OperandConstraint::AnyDeviceTile))));
+
+        return success();
+      }
+    }
+
+    ttir::MaximumOp maximumOp = rewriter.create<mlir::tt::ttir::MaximumOp>(
+        srcOp->getLoc(), min, adaptor.getOperand(), outputTensor,
+        rewriter.getArrayAttr(
+            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+
+    tensor::EmptyOp finalOutputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+    rewriter.replaceOpWithNewOp<mlir::tt::ttir::MinimumOp>(
+        srcOp, maximumOp->getResult(0), max, finalOutputTensor,
+        rewriter.getArrayAttr(
+            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+    return success();
+  }
+};
+
 void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
@@ -1036,6 +1104,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
   patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
 }
 
+void addClampOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
+                                 TypeConverter &typeConverter) {
+  patterns.add<StableHLOToTTIROpClampOpConversionPattern>(typeConverter, ctx);
+}
+
 } // namespace
 
 namespace mlir::tt {
@@ -1057,6 +1130,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
   addReshapeOpConversionPattern(ctx, patterns, typeConverter);
   addLogicalOpConversionPattern(ctx, patterns, typeConverter);
   addSliceOpConversionPattern(ctx, patterns, typeConverter);
+  addClampOpConversionPattern(ctx, patterns, typeConverter);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index f18d4034b..73afaacda 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -352,6 +352,20 @@ class TransposeOpConversionPattern
   }
 };
 
+class ClampOpConversionPattern : public OpConversionPattern<ttir::ClampOp> {
+public:
+  using OpConversionPattern<ttir::ClampOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::ClampOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<ttnn::ClampOp>(
+        op, this->getTypeConverter()->convertType(op.getType()),
+        adaptor.getInput(), adaptor.getMin(), adaptor.getMax());
+    return success();
+  }
+};
+
 class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
 public:
   using OpConversionPattern<ttir::ConcatOp>::OpConversionPattern;
@@ -898,6 +912,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            SoftmaxOpConversionPattern,
            TransposeOpConversionPattern,
            TypecastOpConversionPattern,
+           ClampOpConversionPattern,
            ConcatOpConversionPattern,
            ReshapeOpConversionPattern,
            SliceOpConversionPattern,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 0a04c53a8..3a5ac51a2 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -616,6 +616,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   //
   patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>,
                DefaultOpConversionPattern<ttnn::CbrtOp>,
+               DefaultOpConversionPattern<ttnn::ClampOp>,
                DefaultOpConversionPattern<ttnn::FloorOp>,
                DefaultOpConversionPattern<ttnn::IsFiniteOp>,
                DefaultOpConversionPattern<ttnn::LogicalNotOp>,
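The StableHLO pattern above takes two paths: when both bounds are splat constants it folds them into a single ttir.clamp with scalar attributes, and otherwise it decomposes into ttir.maximum followed by ttir.minimum. A standalone numeric check (plain C++, not compiler code) that the decomposition agrees with a fused clamp, including per-element bounds:

```cpp
// Checks the identity behind the fallback path above:
// clamp(x, lo, hi) == min(max(x, lo), hi), elementwise, whenever lo <= hi.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<float> x = {-3.0f, 0.0f, 2.5f, 7.0f, 11.0f};
  std::vector<float> lo = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
  std::vector<float> hi = {10.0f, 10.0f, 1.0f, 10.0f, 10.0f};
  for (size_t i = 0; i < x.size(); ++i) {
    float fused = std::clamp(x[i], lo[i], hi[i]);
    float decomposed = std::min(std::max(x[i], lo[i]), hi[i]);
    assert(fused == decomposed);
  }
  return 0;
}
```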
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index f5f178f41..9876f913a 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -290,6 +290,40 @@ createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
                                               op.getDim(), op.getNumLinks());
 }
 
+::flatbuffers::Offset<::tt::target::ttnn::ClampOpParams>
+createEltwiseOpParams(FlatbufferObjectCache &cache, ClampOp op) {
+  auto min = op.getMin().convertToFloat();
+  auto max = op.getMax().convertToFloat();
+  return ::tt::target::ttnn::CreateClampOpParams(*cache.fbb, min, max);
+}
+
+template <typename EltwiseOp>
+::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
+createNonDPSEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
+  ::tt::target::ttnn::EltwiseOpType type;
+  ::tt::target::ttnn::EltwiseOpParams paramsType =
+      ::tt::target::ttnn::EltwiseOpParams::NONE;
+  ::flatbuffers::Offset<void> params = 0;
+  if constexpr (std::is_same_v<EltwiseOp, ClampOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::Clamp;
+    paramsType = ::tt::target::ttnn::EltwiseOpParams::ClampOpParams;
+    params = createEltwiseOpParams(cache, op).Union();
+  } else {
+    llvm_unreachable("unhandled non-DPS EltwiseOp");
+  }
+
+  std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> ins;
+  for (auto input : op.getInputs()) {
+    ins.push_back(
+        cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(input)));
+  }
+  assert(op.getResults().size() == 1);
+  auto out = cache.getOrCreate(op.getResults().front(), tensorValueToFlatbuffer,
+                               kHostAllocatedAddress, kHostAllocatedSize);
+  return ::tt::target::ttnn::CreateEltwiseOpDirect(*cache.fbb, type, &ins, out,
+                                                   paramsType, params);
+}
+
 template <typename EltwiseOp>
 ::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
 createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
@@ -670,6 +704,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createTransposeOp(cache, transposeOp),
                            debugString);
   }
+  if (auto clampOp = dyn_cast<ClampOp>(op); clampOp) {
+    return createOperation(cache, createNonDPSEltwiseOp(cache, clampOp),
+                           debugString);
+  }
   if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
     return createOperation(cache, createOp(cache, conv2dOp), debugString);
   }
diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
index 78b23ce0e..6152ed482 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
+++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
@@ -26,6 +26,23 @@ static void runEltwiseUnaryCompositeOP(
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 
+static void runEltwiseUnaryCompositeClampOP(
+    const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
+    std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float, float,
+                                 const ::tt::tt_metal::MemoryConfig &)>
+        ttnnOp) {
+  ::ttnn::Tensor *in = nullptr;
+  getEltwiseUnaryOPInputTensor(op, tensorPool, &in);
+
+  float min = op->params_as_ClampOpParams()->min();
+  float max = op->params_as_ClampOpParams()->max();
+  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
+      utils::createMemoryConfig(op->out());
+  ::ttnn::Tensor out = ttnnOp(*in, min, max, outputMemoryConfig);
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+  return;
+}
+
 void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   switch (op->type()) {
@@ -33,6 +50,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
     runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
     break;
   }
+  case ::tt::target::ttnn::EltwiseOpType::Clamp: {
+    runEltwiseUnaryCompositeClampOP(op, tensorPool, ::ttnn::clamp);
+    break;
+  }
   case ::tt::target::ttnn::EltwiseOpType::Log1p: {
     runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::log1p);
     break;
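createNonDPSEltwiseOp serializes the min/max attributes into ClampOpParams, and runEltwiseUnaryCompositeClampOP decodes them before invoking ::ttnn::clamp through the (tensor, min, max, memory config) signature shown above. A standalone model with stand-in types (hypothetical names, not the ttnn API) of why clamp needs its own runner rather than the generic unary composite runner:

```cpp
// Stand-in types only (FakeTensor, FakeMemoryConfig are hypothetical): the
// generic unary runner forwards (tensor, memcfg), while the clamp runner must
// also thread through the min/max floats decoded from ClampOpParams.
#include <functional>
#include <iostream>
#include <vector>

struct FakeTensor { std::vector<float> data; };
struct FakeMemoryConfig {};

using ClampFn = std::function<FakeTensor(const FakeTensor &, float, float,
                                         const FakeMemoryConfig &)>;

FakeTensor runClamp(const FakeTensor &in, float min, float max, ClampFn op) {
  return op(in, min, max, FakeMemoryConfig{});
}

int main() {
  FakeTensor t{{0.0f, 1.0f, 6.0f}};
  FakeTensor out = runClamp(
      t, 2.0f, 5.0f,
      [](const FakeTensor &x, float lo, float hi, const FakeMemoryConfig &) {
        FakeTensor r = x;
        for (float &v : r.data)
          v = v < lo ? lo : (v > hi ? hi : v);
        return r;
      });
  for (float v : out.data)
    std::cout << v << ' '; // prints: 2 2 5
  std::cout << '\n';
  return 0;
}
```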
diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
index d40f32ffe..cd11ea191 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
+++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
@@ -14,6 +14,8 @@ inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
   switch (op->type()) {
   case ::tt::target::ttnn::EltwiseOpType::Cbrt:
     return true;
+  case ::tt::target::ttnn::EltwiseOpType::Clamp:
+    return true;
   case ::tt::target::ttnn::EltwiseOpType::Log1p:
     return true;
   default:
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
new file mode 100644
index 000000000..4bad0199f
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
@@ -0,0 +1,26 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+module @jit_transpose attributes {} {
+  func.func public @test_clamp_constant(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
+    %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
+    // CHECK: %[[EMPTY:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
+    // CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
+    // CHECK-SAME: max = 3.000000e+00 : f32, min = 2.000000e+00 : f32,
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.clamp %cst, %arg0, %cst_0 : tensor<4xf32>
+    return %0 : tensor<4xf32>
+  }
+
+  func.func public @test_clamp_tensor(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
+    // CHECK: %[[EMPTY0:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
+    // CHECK: %[[MAX:.*]] = "ttir.maximum"(%arg1, %arg0, %[[EMPTY0]])
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    // CHECK: %[[EMPTY1:.*]] = tensor.empty() : [[TENSOR]]
+    // CHECK: %[[MIN:.*]] = "ttir.minimum"(%[[MAX]], %arg2, %[[EMPTY1]])
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.clamp %arg1, %arg0, %arg2 : tensor<4xf32>
+    // CHECK: return %[[MIN]] : [[TENSOR]]
+    return %0 : tensor<4xf32>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/simple_clamp.mlir b/test/ttmlir/Dialect/TTNN/simple_clamp.mlir
new file mode 100644
index 000000000..2fccafb3d
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/simple_clamp.mlir
@@ -0,0 +1,14 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
+module attributes {} {
+  func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+    %0 = tensor.empty() : tensor<64x128xbf16>
+    // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
+    // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
+    // CHECK: = "ttnn.clamp"(%[[LAYOUT]])
+    // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
+    // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
+    %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    return %1 : tensor<64x128xbf16>
+  }
+}
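The constant-bounds test expects the splat constants to surface as the f32 min/max attributes matched by the CHECK-SAME lines, whether the splats were integer or floating point. A standalone illustration (not compiler code) of the integer-vs-float splat handling in the conversion pattern:

```cpp
// Mirrors the conversion logic above in plain C++:
//   isInteger() ? static_cast<float>(getSplatValue<int>()) : getSplatValue<float>()
#include <cassert>
#include <variant>

float toF32Attr(const std::variant<int, float> &splat) {
  if (std::holds_alternative<int>(splat))
    return static_cast<float>(std::get<int>(splat));
  return std::get<float>(splat);
}

int main() {
  assert(toF32Attr(2) == 2.0f);    // integer splat -> 2.000000e+00 : f32
  assert(toF32Attr(3.0f) == 3.0f); // float splat   -> 3.000000e+00 : f32
  return 0;
}
```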
diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir
new file mode 100644
index 000000000..facaf3499
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir
@@ -0,0 +1,17 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
+
+func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+  %0 = tensor.empty() : tensor<64x128xbf16>
+  // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
+  // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
+  // CHECK: = "ttnn.clamp"(%[[LAYOUT]])
+  // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
+  // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
+  %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+  return %1 : tensor<64x128xbf16>
+}
diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
index cdf0a0374..5772c94c3 100644
--- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -19,6 +19,17 @@ func.func @ceil(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> {
   return %1 : tensor<32x32xf32>
 }
 
+func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+  %0 = tensor.empty() : tensor<64x128xbf16>
+  // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
+  // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
+  // CHECK: = "ttnn.clamp"(%[[LAYOUT]])
+  // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
+  // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
+  %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+  return %1 : tensor<64x128xbf16>
+}
+
 func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
   // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
   %0 = tensor.empty() : tensor<32x96xf32>