diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index e2b97cfb8..ebd8aef5e 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -713,6 +713,60 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {
   let hasVerifier = 1;
 }
 
+def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
+  let summary = "Clamp op.";
+  let description = [{
+    Clamp tensor values to a specified range.
+
+    Example:
+      min: 2.000000e+00
+      input: [[0, 1, 2, 3, 4, 5, 6, 7]]
+      max: 5.000000e+00
+
+    "ttir.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
+        -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       F32Attr:$min,
+                       F32Attr:$max,
+                       TT_OperandConstraintArrayAttr:$operand_constraints);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let results = (outs AnyRankedTensor:$result);
+}
+
+def TTIR_ClipOp : TTIR_DPSOp<"clip"> {
+  let summary = "Clip op.";
+  let description = [{
+    Clip tensor values to a specified range.
+
+    Example:
+      min: 2.000000e+00
+      input: [[0, 1, 2, 3, 4, 5, 6, 7]]
+      max: 5.000000e+00
+
+    "ttir.clip"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
+        -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       F32Attr:$min,
+                       F32Attr:$max,
+                       TT_OperandConstraintArrayAttr:$operand_constraints);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let results = (outs AnyRankedTensor:$result);
+}
+
 def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
                                            AllShapesMatch<["value", "result"]>]> {
   let summary = "Constant op.";
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 42518c6e7..2642e497b 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -576,6 +576,48 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
   let hasVerifier = 1;
 }
 
+def TTNN_ClampOp : TTNN_Op<"clamp"> {
+  let summary = "Clamp op.";
+  let description = [{
+    Clamp tensor values to a specified range.
+
+    Example:
+      min: 2.000000e+00
+      input: [[0, 1, 2, 3, 4, 5, 6, 7]]
+      max: 5.000000e+00
+
+    "ttnn.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
+        -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
+  }];
+
+  let arguments = (ins F32Attr:$min,
+                       AnyRankedTensor:$input,
+                       F32Attr:$max);
+
+  let results = (outs AnyRankedTensor:$result);
+}
+
+def TTNN_ClipOp : TTNN_Op<"clip"> {
+  let summary = "Clip op.";
+  let description = [{
+    Clip tensor values to a specified range.
+
+    Example:
+      min: 2.000000e+00
+      input: [[0, 1, 2, 3, 4, 5, 6, 7]]
+      max: 5.000000e+00
+
+    "ttnn.clip"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
+        -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
+  }];
+
+  let arguments = (ins F32Attr:$min,
+                       AnyRankedTensor:$input,
+                       F32Attr:$max);
+
+  let results = (outs AnyRankedTensor:$result);
+}
+
 // Note: NoMemoryEffect is used to indicate that operation can be removed if it is not used.
 // Removal of this operation is done by the dead code elimination pass (RemoveDeadValuesPass).
 def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 56c80410d..4166ed329 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -155,6 +155,20 @@ table MatmulOp {
 }
 // ANCHOR_END: adding_an_op_matmul_fbs
 
+table ClampOp {
+  in: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  min: float;
+  max: float;
+}
+
+table ClipOp {
+  in: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  min: float;
+  max: float;
+}
+
 table Conv2dOp {
   input: tt.target.TensorRef;
   weight: tt.target.TensorRef;
@@ -222,6 +236,8 @@ union OpType {
   EmbeddingOp,
   SoftmaxOp,
   TransposeOp,
+  ClipOp,
+  ClampOp,
   Conv2dOp,
   ConcatOp,
   ReshapeOp,
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index d41650933..ba61b0c2c 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -852,6 +852,74 @@ class StableHLOToTTIRSliceOpConversionPattern
   }
 };
 
+class StableHLOToTTIROpClampOpConversionPattern
+    : public OpConversionPattern<mlir::stablehlo::ClampOp> {
+
+  using OpConversionPattern<mlir::stablehlo::ClampOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(mlir::stablehlo::ClampOp srcOp,
+                  mlir::stablehlo::ClampOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType outputType = mlir::cast<RankedTensorType>(
+        this->getTypeConverter()->convertType(srcOp.getResult().getType()));
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+    Value min = adaptor.getMin();
+    Value max = adaptor.getMax();
+    Operation *minDefiningOp = min.getDefiningOp();
+    Operation *maxDefiningOp = max.getDefiningOp();
+    if (minDefiningOp && maxDefiningOp &&
+        isa<mlir::stablehlo::ConstantOp>(minDefiningOp) &&
+        isa<mlir::stablehlo::ConstantOp>(maxDefiningOp)) {
+      mlir::ElementsAttr minValAttr =
+          mlir::cast<mlir::stablehlo::ConstantOp>(minDefiningOp).getValueAttr();
+      mlir::ElementsAttr maxValAttr =
+          mlir::cast<mlir::stablehlo::ConstantOp>(maxDefiningOp).getValueAttr();
+      if (minValAttr.isSplat() && maxValAttr.isSplat()) {
+        float minValue =
+            minValAttr.getElementType().isInteger()
+                ? static_cast<float>(minValAttr.getSplatValue<int>())
+                : minValAttr.getSplatValue<float>();
+        float maxValue =
+            maxValAttr.getElementType().isInteger()
+                ? static_cast<float>(maxValAttr.getSplatValue<int>())
+                : maxValAttr.getSplatValue<float>();
+        rewriter.replaceOpWithNewOp<mlir::tt::ttir::ClampOp>(
+            srcOp,
+            this->getTypeConverter()->convertType(outputTensor.getType()),
+            Value(adaptor.getOperand()), Value(outputTensor),
+            rewriter.getF32FloatAttr(minValue),
+            rewriter.getF32FloatAttr(maxValue),
+            rewriter.getArrayAttr(
+                SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                       rewriter.getAttr<OperandConstraintAttr>(
+                                           OperandConstraint::AnyDeviceTile))));
+
+        return success();
+      }
+    }
+
+    ttir::MaximumOp maximumOp = rewriter.create<mlir::tt::ttir::MaximumOp>(
+        srcOp->getLoc(), min, adaptor.getOperand(), outputTensor,
+        rewriter.getArrayAttr(
+            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+
+    tensor::EmptyOp finalOutputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+    rewriter.replaceOpWithNewOp<mlir::tt::ttir::MinimumOp>(
+        srcOp, maximumOp->getResult(0), max, finalOutputTensor,
+        rewriter.getArrayAttr(
+            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+    return success();
+  }
+};
+
 void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
@@ -981,6 +1049,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
   patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
 }
 
+void addClampOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
+                                 TypeConverter &typeConverter) {
+  patterns.add<StableHLOToTTIROpClampOpConversionPattern>(typeConverter, ctx);
+}
+
 } // namespace
 
 namespace mlir::tt {
@@ -1002,6 +1075,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
   addReshapeOpConversionPattern(ctx, patterns, typeConverter);
   addLogicalOpConversionPattern(ctx, patterns, typeConverter);
   addSliceOpConversionPattern(ctx, patterns, typeConverter);
+  addClampOpConversionPattern(ctx, patterns, typeConverter);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 812d07b41..045b9e419 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -351,6 +351,22 @@ class TransposeOpConversionPattern
   }
 };
 
+template <typename TTIROpTy, typename TTNNOpTy,
+          typename OpAdaptor = typename TTIROpTy::Adaptor>
+class ClampOpConversionPattern : public OpConversionPattern<TTIROpTy> {
+public:
+  using OpConversionPattern<TTIROpTy>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<TTNNOpTy>(
+        op, this->getTypeConverter()->convertType(op.getType()),
+        adaptor.getMin(), adaptor.getInput(), adaptor.getMax());
+    return success();
+  }
+};
+
 class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
 public:
   using OpConversionPattern<ttir::ConcatOp>::OpConversionPattern;
@@ -842,6 +858,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            SoftmaxOpConversionPattern,
            TransposeOpConversionPattern,
            TypecastOpConversionPattern,
+           ClampOpConversionPattern<ttir::ClampOp, ttnn::ClampOp>,
+           ClampOpConversionPattern<ttir::ClipOp, ttnn::ClipOp>,
            ConcatOpConversionPattern,
            ReshapeOpConversionPattern,
            SliceOpConversionPattern,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 4655bc0c3..d171e3d37 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -616,6 +616,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   //
   patterns.add,
                DefaultOpConversionPattern,
+               DefaultOpConversionPattern<ttnn::ClampOp>,
+               DefaultOpConversionPattern<ttnn::ClipOp>,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 39352d1bb..ec238320f 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -256,6 +256,30 @@ createOp(FlatbufferObjectCache &cache, MatmulOp op) {
 }
 // ANCHOR_END: adding_an_op_matmul_serialize_to_binary
 
+::flatbuffers::Offset<::tt::target::ttnn::ClampOp>
+createOp(FlatbufferObjectCache &cache, ClampOp op) {
+  auto input =
+      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  auto minValue = op.getMin().convertToFloat();
+  auto maxValue = op.getMax().convertToFloat();
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedAddress, kHostAllocatedSize);
+  return ::tt::target::ttnn::CreateClampOp(*cache.fbb, input, output, minValue,
+                                           maxValue);
+}
+
+::flatbuffers::Offset<::tt::target::ttnn::ClipOp>
+createOp(FlatbufferObjectCache &cache, ClipOp op) {
+  auto input =
+      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  float minValue = op.getMin().convertToFloat();
+  float maxValue = op.getMax().convertToFloat();
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedAddress, kHostAllocatedSize);
+  return ::tt::target::ttnn::CreateClipOp(*cache.fbb, input, output, minValue,
+                                          maxValue);
+}
+
 ::flatbuffers::Offset<::tt::target::ttnn::Conv2dOp>
 createOp(FlatbufferObjectCache &cache, Conv2dOp op) {
   auto in0 =
@@ -629,6 +653,12 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createTransposeOp(cache, transposeOp),
                            debugString);
   }
+  if (auto clampOp = dyn_cast<ClampOp>(op); clampOp) {
+    return createOperation(cache, createOp(cache, clampOp), debugString);
+  }
+  if (auto clipOp = dyn_cast<ClipOp>(op); clipOp) {
+    return createOperation(cache, createOp(cache, clipOp), debugString);
+  }
   if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
     return createOperation(cache, createOp(cache, conv2dOp), debugString);
   }
diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt
index f557d318b..c0afa828d 100644
--- a/runtime/lib/ttnn/operations/CMakeLists.txt
+++ b/runtime/lib/ttnn/operations/CMakeLists.txt
@@ -1,6 +1,8 @@
 set(TTNN_OPS_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/clamp/clamp.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/clamp/clip.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp
diff --git a/runtime/lib/ttnn/operations/clamp/clamp.cpp b/runtime/lib/ttnn/operations/clamp/clamp.cpp
new file mode 100644
index 000000000..0a07eba51
--- /dev/null
+++ b/runtime/lib/ttnn/operations/clamp/clamp.cpp
@@ -0,0 +1,24 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "clamp.h"
+#include "tt/runtime/detail/logger.h"
+#include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/operations/utils.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+#include "ttnn/operations/eltwise/unary/unary_composite.hpp"
+
+namespace tt::runtime::ttnn::operations::clamp {
+void run(const ::tt::target::ttnn::ClampOp *op, ProgramContext &context) {
+  ProgramTensorPool &tensorPool = context.getTensorPool();
+  const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
+  const float minValue = op->min();
+  const float maxValue = op->max();
+  ::tt::tt_metal::MemoryConfig memoryConfig =
+      utils::createMemoryConfig(op->out());
+
+  ::ttnn::Tensor out = ::ttnn::clamp(input, minValue, maxValue, memoryConfig);
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+} // namespace tt::runtime::ttnn::operations::clamp
diff --git a/runtime/lib/ttnn/operations/clamp/clamp.h b/runtime/lib/ttnn/operations/clamp/clamp.h
new file mode 100644
index 000000000..d19c71bd5
--- /dev/null
+++ b/runtime/lib/ttnn/operations/clamp/clamp.h
@@ -0,0 +1,16 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTNN_RUNTIME_CLAMP_H
+#define TTNN_RUNTIME_CLAMP_H
+
+#include "tt/runtime/ttnn/types.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+
+namespace tt::runtime::ttnn::operations::clamp {
+void run(const ::tt::target::ttnn::ClampOp *op, ProgramContext &context);
+
+} // namespace tt::runtime::ttnn::operations::clamp
+
+#endif // TTNN_RUNTIME_CLAMP_H
diff --git a/runtime/lib/ttnn/operations/clamp/clip.cpp b/runtime/lib/ttnn/operations/clamp/clip.cpp
new file mode 100644
index 000000000..235a28c30
--- /dev/null
+++ b/runtime/lib/ttnn/operations/clamp/clip.cpp
@@ -0,0 +1,24 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "clip.h"
+#include "tt/runtime/detail/logger.h"
+#include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/operations/utils.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+#include "ttnn/operations/eltwise/unary/unary_composite.hpp"
+
+namespace tt::runtime::ttnn::operations::clip {
+void run(const ::tt::target::ttnn::ClipOp *op, ProgramContext &context) {
+  ProgramTensorPool &tensorPool = context.getTensorPool();
+  const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
+  const float minValue = op->min();
+  const float maxValue = op->max();
+  ::tt::tt_metal::MemoryConfig memoryConfig =
+      utils::createMemoryConfig(op->out());
+
+  ::ttnn::Tensor out = ::ttnn::clip(input, minValue, maxValue, memoryConfig);
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+} // namespace tt::runtime::ttnn::operations::clip
diff --git a/runtime/lib/ttnn/operations/clamp/clip.h b/runtime/lib/ttnn/operations/clamp/clip.h
new file mode 100644
index 000000000..30ad0fed9
--- /dev/null
+++ b/runtime/lib/ttnn/operations/clamp/clip.h
@@ -0,0 +1,16 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTNN_RUNTIME_CLIP_H
+#define TTNN_RUNTIME_CLIP_H
+
+#include "tt/runtime/ttnn/types.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+
+namespace tt::runtime::ttnn::operations::clip {
+void run(const ::tt::target::ttnn::ClipOp *op, ProgramContext &context);
+
+} // namespace tt::runtime::ttnn::operations::clip
+
+#endif // TTNN_RUNTIME_CLIP_H
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index ab5d651e9..57d87d4cc 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -2,6 +2,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 #include "operations/ccl/all_gather.h"
+#include "operations/clamp/clamp.h"
+#include "operations/clamp/clip.h"
 #include "operations/context/get_device.h"
 #include "operations/conv/conv2d.h"
 #include "operations/creation/empty.h"
@@ -104,6 +106,12 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
   case ::tt::target::ttnn::OpType::TransposeOp: {
     return operations::data_movement::run(op->type_as_TransposeOp(), context);
   }
+  case ::tt::target::ttnn::OpType::ClampOp: {
+    return operations::clamp::run(op->type_as_ClampOp(), context);
+  }
+  case ::tt::target::ttnn::OpType::ClipOp: {
+    return operations::clip::run(op->type_as_ClipOp(), context);
+  }
   case ::tt::target::ttnn::OpType::ConcatOp: {
     return operations::data_movement::run(op->type_as_ConcatOp(), context);
   }
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
new file mode 100644
index 000000000..4bad0199f
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
@@ -0,0 +1,26 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+module @jit_transpose attributes {} {
+  func.func public @test_clamp_constant(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
+    %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
+    // CHECK: %[[EMPTY:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
+    // CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
+    // CHECK-SAME: max = 3.000000e+00 : f32, min = 2.000000e+00 : f32,
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.clamp %cst, %arg0, %cst_0 : tensor<4xf32>
+    return %0 : tensor<4xf32>
+  }
+
+  func.func public @test_clamp_tensor(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
+    // CHECK: %[[EMPTY0:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
+    // CHECK: %[[MAX:.*]] = "ttir.maximum"(%arg1, %arg0, %[[EMPTY0]])
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    // CHECK: %[[EMPTY1:.*]] = tensor.empty() : [[TENSOR]]
+    // CHECK: %[[MIN:.*]] = "ttir.minimum"(%[[MAX]], %arg2, %[[EMPTY1]])
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.clamp %arg1, %arg0, %arg2 : tensor<4xf32>
+    // CHECK: return %[[MIN]] : [[TENSOR]]
+    return %0 : tensor<4xf32>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/clamp/simple_clamp.mlir b/test/ttmlir/Dialect/TTNN/clamp/simple_clamp.mlir
new file mode 100644
index 000000000..2fccafb3d
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/clamp/simple_clamp.mlir
@@ -0,0 +1,14 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+    %0 = tensor.empty() : tensor<64x128xbf16>
+    // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
+    // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
+    // CHECK: = "ttnn.clamp"(%[[LAYOUT]])
+    // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
+    // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
+    %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    return %1 : tensor<64x128xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/clamp/simple_clip.mlir b/test/ttmlir/Dialect/TTNN/clamp/simple_clip.mlir
new file mode 100644
index 000000000..d5d60c077
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/clamp/simple_clip.mlir
@@ -0,0 +1,14 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @clip(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+    %0 = tensor.empty() : tensor<64x128xbf16>
"ttnn.to_device"(%arg0, + // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]]) + // CHECK: = "ttnn.clip"(%[[LAYOUT]]) + // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} + // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]] + %1 = "ttir.clip"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir new file mode 100644 index 000000000..facaf3499 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir @@ -0,0 +1,17 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device = #tt.operand_constraint +#any_device_tile = #tt.operand_constraint + +func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0, + // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]]) + // CHECK: = "ttnn.clamp"(%[[LAYOUT]]) + // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} + // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]] + %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clip.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clip.mlir new file mode 100644 index 000000000..cdbd59fb5 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clip.mlir @@ -0,0 +1,17 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device = #tt.operand_constraint +#any_device_tile = #tt.operand_constraint + +func.func @clip(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0, + // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]]) + // CHECK: = "ttnn.clip"(%[[LAYOUT]]) + // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} + // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]] + %1 = "ttir.clip"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/simple_clamp.mlir b/test/ttmlir/Silicon/TTNN/simple_clamp.mlir new file mode 100644 index 000000000..8e08aa619 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/simple_clamp.mlir @@ -0,0 +1,28 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device = 
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
+
+func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+  %0 = tensor.empty() : tensor<64x128xbf16>
+  // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
+  // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
+  // CHECK: = "ttnn.clamp"(%[[LAYOUT]])
+  // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
+  // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
+  %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+  return %1 : tensor<64x128xbf16>
+}
+
+func.func @clip(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+  %0 = tensor.empty() : tensor<64x128xbf16>
+  // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0,
+  // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]])
+  // CHECK: = "ttnn.clip"(%[[LAYOUT]])
+  // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
+  // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #layout{{[0-9]+}}>) -> [[TENSOR]]
+  %1 = "ttir.clip"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+  return %1 : tensor<64x128xbf16>
+}