Skip to content

Commit

Permalink
Add support for clamp and clip op.
Browse files Browse the repository at this point in the history
* Add end-to-end implementation of the ops.
* Add stablehlo to ttir conversion for clamp op.
  • Loading branch information
mmanzoorTT committed Oct 30, 2024
1 parent 22a06f2 commit d74020f
Show file tree
Hide file tree
Showing 19 changed files with 442 additions and 0 deletions.
54 changes: 54 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,60 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {
let hasVerifier = 1;
}

def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
    let summary = "Clamp op.";
    let description = [{
      Clamp tensor values to a specified range.

      Example:
        min: 2.000000e+00
        input: [[0, 1, 2, 3, 4, 5, 6, 7]]
        max: 5.000000e+00

        "ttir.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
            -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
    }];

    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         F32Attr:$min,
                         F32Attr:$max,
                         TT_OperandConstraintArrayAttr:$operand_constraints);

    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let results = (outs AnyRankedTensor:$result);
}

def TTIR_ClipOp : TTIR_DPSOp<"clip"> {
    let summary = "Clip op.";
    let description = [{
      Clip tensor values to a specified range.

      Example:
        min: 2.000000e+00
        input: [[0, 1, 2, 3, 4, 5, 6, 7]]
        max: 5.000000e+00

        "ttir.clip"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
            -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
    }];

    let arguments = (ins AnyRankedTensor:$input,
                         AnyRankedTensor:$output,
                         F32Attr:$min,
                         F32Attr:$max,
                         TT_OperandConstraintArrayAttr:$operand_constraints);

    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let results = (outs AnyRankedTensor:$result);
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
Expand Down
42 changes: 42 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,48 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
let hasVerifier = 1;
}

def TTNN_ClampOp : TTNN_Op<"clamp"> {
    let summary = "Clamp op.";
    let description = [{
      Clamp tensor values to a specified range.

      Example:
        min: 2.000000e+00
        input: [[0, 1, 2, 3, 4, 5, 6, 7]]
        max: 5.000000e+00

        "ttnn.clamp"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
            -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
    }];

    let arguments = (ins F32Attr:$min,
                         AnyRankedTensor:$input,
                         F32Attr:$max);

    let results = (outs AnyRankedTensor:$result);
}

def TTNN_ClipOp : TTNN_Op<"clip"> {
    let summary = "Clip op.";
    let description = [{
      Clip tensor values to a specified range.

      Example:
        min: 2.000000e+00
        input: [[0, 1, 2, 3, 4, 5, 6, 7]]
        max: 5.000000e+00

        "ttnn.clip"(%arg0) <{max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}>
            -> %out = [[2, 2, 2, 3, 4, 5, 5, 5]]
    }];

    let arguments = (ins F32Attr:$min,
                         AnyRankedTensor:$input,
                         F32Attr:$max);

    let results = (outs AnyRankedTensor:$result);
}

// Note: NoMemoryEffect is used to indicate that operation can be removed if it is not used.
// Removal of this operation is done by the dead code elimination pass (RemoveDeadValuesPass).
def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
Expand Down
16 changes: 16 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,20 @@ table MatmulOp {
}
// ANCHOR_END: adding_an_op_matmul_fbs

// Serialized ttnn.clamp: clamps each element of `in` into [min, max],
// writing the result to `out`.
table ClampOp {
  in: tt.target.TensorRef;   // input tensor
  out: tt.target.TensorRef;  // output tensor
  min: float;                // lower bound (scalar)
  max: float;                // upper bound (scalar)
}

// Serialized ttnn.clip: clips each element of `in` into [min, max],
// writing the result to `out`.
table ClipOp {
  in: tt.target.TensorRef;   // input tensor
  out: tt.target.TensorRef;  // output tensor
  min: float;                // lower bound (scalar)
  max: float;                // upper bound (scalar)
}

table Conv2dOp {
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
Expand Down Expand Up @@ -222,6 +236,8 @@ union OpType {
EmbeddingOp,
SoftmaxOp,
TransposeOp,
ClipOp,
ClampOp,
Conv2dOp,
ConcatOp,
ReshapeOp,
Expand Down
74 changes: 74 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,74 @@ class StableHLOToTTIRSliceOpConversionPattern
}
};

// Lowers stablehlo.clamp to TTIR. When both bounds are splat constants the
// op maps onto a single ttir.clamp with scalar f32 attributes; otherwise it
// decomposes into minimum(maximum(min, x), max), which also handles
// tensor-valued bounds.
class StableHLOToTTIROpClampOpConversionPattern
    : public OpConversionPattern<mlir::stablehlo::ClampOp> {

  using OpConversionPattern<mlir::stablehlo::ClampOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(mlir::stablehlo::ClampOp srcOp,
                  mlir::stablehlo::ClampOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType outputType = mlir::cast<RankedTensorType>(
        this->getTypeConverter()->convertType(srcOp.getResult().getType()));
    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

    // The same constraint array is needed by every TTIR op created below;
    // build it once instead of three times (one constraint per source
    // operand plus one for the DPS output).
    ArrayAttr operandConstraints = rewriter.getArrayAttr(
        SmallVector<Attribute>(adaptor.getOperands().size() + 1,
                               rewriter.getAttr<OperandConstraintAttr>(
                                   OperandConstraint::AnyDeviceTile)));

    Value min = adaptor.getMin();
    Value max = adaptor.getMax();
    Operation *minDefiningOp = min.getDefiningOp();
    Operation *maxDefiningOp = max.getDefiningOp();
    // Fast path: both bounds come from ttir.constant splats, so the bounds
    // collapse to scalar attributes on a single ttir.clamp.
    if (minDefiningOp && maxDefiningOp &&
        isa<mlir::tt::ttir::ConstantOp>(minDefiningOp) &&
        isa<mlir::tt::ttir::ConstantOp>(maxDefiningOp)) {
      mlir::ElementsAttr minValAttr =
          mlir::cast<mlir::tt::ttir::ConstantOp>(minDefiningOp).getValueAttr();
      mlir::ElementsAttr maxValAttr =
          mlir::cast<mlir::tt::ttir::ConstantOp>(maxDefiningOp).getValueAttr();
      if (minValAttr.isSplat() && maxValAttr.isSplat()) {
        // NOTE(review): getSplatValue<int>/<float> assumes 32-bit element
        // types; other integer/float widths would need APInt/APFloat
        // accessors — confirm against the dialect's allowed element types.
        float minValue =
            minValAttr.getElementType().isInteger()
                ? static_cast<float>(minValAttr.getSplatValue<int>())
                : minValAttr.getSplatValue<float>();
        float maxValue =
            maxValAttr.getElementType().isInteger()
                ? static_cast<float>(maxValAttr.getSplatValue<int>())
                : maxValAttr.getSplatValue<float>();
        rewriter.replaceOpWithNewOp<mlir::tt::ttir::ClampOp>(
            srcOp,
            this->getTypeConverter()->convertType(outputTensor.getType()),
            Value(adaptor.getOperand()), Value(outputTensor),
            rewriter.getF32FloatAttr(minValue),
            rewriter.getF32FloatAttr(maxValue), operandConstraints);

        return success();
      }
    }

    // General path: clamp(x, min, max) == minimum(maximum(min, x), max).
    ttir::MaximumOp maximumOp = rewriter.create<mlir::tt::ttir::MaximumOp>(
        srcOp->getLoc(), min, adaptor.getOperand(), outputTensor,
        operandConstraints);

    tensor::EmptyOp finalOutputTensor = rewriter.create<tensor::EmptyOp>(
        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
    rewriter.replaceOpWithNewOp<mlir::tt::ttir::MinimumOp>(
        srcOp, maximumOp->getResult(0), max, finalOutputTensor,
        operandConstraints);
    return success();
  }
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -981,6 +1049,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
}

// Registers the stablehlo.clamp -> ttir lowering pattern.
void addClampOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
                                 TypeConverter &typeConverter) {
  patterns.add<StableHLOToTTIROpClampOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
Expand All @@ -1002,6 +1075,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addLogicalOpConversionPattern(ctx, patterns, typeConverter);
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addClampOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
18 changes: 18 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,22 @@ class TransposeOpConversionPattern
}
};

// Shared TTIR->TTNN lowering for clamp-style ops (ttir.clamp / ttir.clip).
// The TTNN ops drop the DPS output operand and take (min, input, max) — the
// same order as the TTNN ODS argument list.
template <typename TTIROpTy, typename TTNNOpTy,
          typename OpAdaptor = typename TTIROpTy::Adaptor>
class ClampOpConversionPattern : public OpConversionPattern<TTIROpTy> {
public:
  using OpConversionPattern<TTIROpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the result type through the dialect type converter, then swap
    // the op in place for its TTNN counterpart.
    auto resultType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<TTNNOpTy>(op, resultType, adaptor.getMin(),
                                          adaptor.getInput(),
                                          adaptor.getMax());
    return success();
  }
};

class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
public:
using OpConversionPattern<ttir::ConcatOp>::OpConversionPattern;
Expand Down Expand Up @@ -842,6 +858,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
TypecastOpConversionPattern,
ClampOpConversionPattern<ttir::ClampOp, ttnn::ClampOp>,
ClampOpConversionPattern<ttir::ClipOp, ttnn::ClipOp>,
ConcatOpConversionPattern,
ReshapeOpConversionPattern,
SliceOpConversionPattern,
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>,
DefaultOpConversionPattern<ttnn::CbrtOp>,
DefaultOpConversionPattern<ttnn::ClampOp>,
DefaultOpConversionPattern<ttnn::ClipOp>,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
DefaultOpConversionPattern<ttnn::NegOp>,
DefaultOpConversionPattern<ttnn::ReluOp>,
Expand Down
30 changes: 30 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ createOp(FlatbufferObjectCache &cache, MatmulOp op) {
}
// ANCHOR_END: adding_an_op_matmul_serialize_to_binary

// Serializes a ttnn.clamp op into its flatbuffer record: resolve the input
// tensor through any DPS chain, extract the scalar bounds, and materialize
// the output tensor ref.
::flatbuffers::Offset<::tt::target::ttnn::ClampOp>
createOp(FlatbufferObjectCache &cache, ClampOp op) {
  auto input =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
  // Explicit `float` (not auto) to match the ClipOp serializer below and
  // make the APFloat -> float conversion visible.
  float minValue = op.getMin().convertToFloat();
  float maxValue = op.getMax().convertToFloat();
  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                  kHostAllocatedAddress, kHostAllocatedSize);
  return ::tt::target::ttnn::CreateClampOp(*cache.fbb, input, output, minValue,
                                           maxValue);
}

// Serializes a ttnn.clip op into its flatbuffer record: resolve the input
// tensor through any DPS chain, materialize the output tensor ref, and
// embed the scalar bounds.
::flatbuffers::Offset<::tt::target::ttnn::ClipOp>
createOp(FlatbufferObjectCache &cache, ClipOp op) {
  auto in =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
  auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                               kHostAllocatedAddress, kHostAllocatedSize);
  return ::tt::target::ttnn::CreateClipOp(*cache.fbb, in, out,
                                          op.getMin().convertToFloat(),
                                          op.getMax().convertToFloat());
}

::flatbuffers::Offset<::tt::target::ttnn::Conv2dOp>
createOp(FlatbufferObjectCache &cache, Conv2dOp op) {
auto in0 =
Expand Down Expand Up @@ -629,6 +653,12 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createTransposeOp(cache, transposeOp),
debugString);
}
if (auto clampOp = dyn_cast<ClampOp>(op); clampOp) {
return createOperation(cache, createOp(cache, clampOp), debugString);
}
if (auto clipOp = dyn_cast<ClipOp>(op); clipOp) {
return createOperation(cache, createOp(cache, clipOp), debugString);
}
if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
return createOperation(cache, createOp(cache, conv2dOp), debugString);
}
Expand Down
2 changes: 2 additions & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
${CMAKE_CURRENT_SOURCE_DIR}/clamp/clamp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/clamp/clip.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp
Expand Down
24 changes: 24 additions & 0 deletions runtime/lib/ttnn/operations/clamp/clamp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "clamp.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttmlir/Target/TTNN/program_generated.h"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::clamp {
// Executes a serialized ClampOp: fetch the input tensor from the pool,
// clamp every element into [min, max], and publish the result under the
// output tensor's global id.
void run(const ::tt::target::ttnn::ClampOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
  ::tt::tt_metal::MemoryConfig outMemCfg = utils::createMemoryConfig(op->out());

  ::ttnn::Tensor result = ::ttnn::clamp(in, op->min(), op->max(), outMemCfg);
  tensorPool.insert_or_assign(op->out()->global_id(), result);
}
} // namespace tt::runtime::ttnn::operations::clamp
16 changes: 16 additions & 0 deletions runtime/lib/ttnn/operations/clamp/clamp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_CLAMP_H
#define TTNN_RUNTIME_CLAMP_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::clamp {
// Runtime dispatch entry for the serialized ClampOp; implemented in clamp.cpp.
void run(const ::tt::target::ttnn::ClampOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::clamp

#endif // TTNN_RUNTIME_CLAMP_H
24 changes: 24 additions & 0 deletions runtime/lib/ttnn/operations/clamp/clip.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "clip.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttmlir/Target/TTNN/program_generated.h"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::clip {
// Executes a serialized ClipOp: fetch the input tensor from the pool, clip
// every element into [min, max], and publish the result under the output
// tensor's global id.
void run(const ::tt::target::ttnn::ClipOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
  ::tt::tt_metal::MemoryConfig outMemCfg = utils::createMemoryConfig(op->out());

  ::ttnn::Tensor result = ::ttnn::clip(in, op->min(), op->max(), outMemCfg);
  tensorPool.insert_or_assign(op->out()->global_id(), result);
}
} // namespace tt::runtime::ttnn::operations::clip
Loading

0 comments on commit d74020f

Please sign in to comment.