Gather op implementation [#1015]
The gather op is lowered into embedding when it matches the embedding access pattern. Uses the TTIR decomposition pass from 38a4a46 and the embedding fixes from e798a17. Runtime execution is blocked by tt-metal issue 14584 (hence the XFAIL on the Silicon test).
ddilbazTT committed Nov 7, 2024
1 parent a145ead commit 7dcf489
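
In TTIR terms, a row-slice gather is rewritten into an embedding lookup, with a reshape inserted when the start indices carry a trailing unit dimension. A minimal before/after sketch, using shapes from the tests in this commit (attributes elided; see the test files for the exact form):

  %1 = "ttir.gather"(%operand, %start_indices, %0) {...}
      : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
  // after decomposition: the same result computed as an embedding lookup
  %1 = "ttir.embedding"(%start_indices, %operand, %0) {...}
      : (tensor<1x32xi32>, tensor<32000x1024xf32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>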
Showing 7 changed files with 260 additions and 1 deletion.
24 changes: 23 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -704,11 +704,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
let hasVerifier = 1;

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}

def TTIR_GatherOp: TTIR_DPSOp<"gather"> {
let summary = "Gather operation.";
let description = [{
  Collects slices from the input tensor at the positions given by
  `start_indices`, following the semantics of `stablehlo.gather`; the
  attribute names mirror the StableHLO gather dimension numbers.
}];
let arguments = (ins AnyRankedTensor:$input, // the StableHLO `operand`
AnyRankedTensor:$start_indices,
AnyRankedTensor:$output,
DenseI64ArrayAttr:$offset_dims,
DenseI64ArrayAttr:$collapsed_slice_dims,
DenseI64ArrayAttr:$operand_batching_dims,
DenseI64ArrayAttr:$start_indices_batching_dims,
DenseI64ArrayAttr:$start_index_map,
SI64Attr:$index_vector_dim,
DenseI64ArrayAttr:$slice_sizes,
BoolAttr:$indices_are_sorted,
TT_OperandConstraintArrayAttr:$operand_constraints);
let results = (outs AnyRankedTensor:$result);
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}


def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
48 changes: 48 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -894,6 +894,48 @@ class StableHLOToTTIRSliceOpConversionPattern
}
};

class StableHLOToTTIRGatherOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::GatherOp> {
using OpConversionPattern<mlir::stablehlo::GatherOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::GatherOp srcOp,
mlir::stablehlo::GatherOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Convert the result type of the original op via the type converter
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));
// Create an empty output tensor with the computed shape
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
auto dimensionNumbers = srcOp.getDimensionNumbers();
rewriter.replaceOpWithNewOp<mlir::tt::ttir::GatherOp>(
srcOp, // The original operation to replace
outputType, // Result type
srcOp.getOperands()[0], // Input tensor
srcOp.getOperands()[1], // Start indices
Value(outputTensor), // Output tensor
dimensionNumbers.getOffsetDims(), // offset_dims attribute
dimensionNumbers
.getCollapsedSliceDims(), // collapsed_slice_dims attribute
dimensionNumbers
.getOperandBatchingDims(), // operand_batching_dims attribute
dimensionNumbers
.getStartIndicesBatchingDims(), // start_indices_batching_dims
// attribute
dimensionNumbers.getStartIndexMap(), // start_index_map attribute
dimensionNumbers.getIndexVectorDim(), // index_vector_dim attribute
srcOp.getSliceSizesAttr(), // slice_sizes attribute
false, // indices_are_sorted (the source op's flag is not forwarded)
rewriter.getArrayAttr( // operand constraints
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
@@ -1036,6 +1078,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
}

void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRGatherOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
@@ -1057,6 +1104,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addLogicalOpConversionPattern(ctx, patterns, typeConverter);
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addGatherOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
118 changes: 118 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
@@ -401,6 +401,123 @@ struct ConvolutionToConv2dPattern
adaptor.getOperandConstraints());

rewriter.replaceOp(op, output);
return success();
}
};
//===----------------------------------------------------------------------===//
// Gather Pattern Matching
//===----------------------------------------------------------------------===//

struct GatherToEmbeddingConversionPattern
: public OpConversionPattern<ttir::GatherOp> {
using OpConversionPattern<ttir::GatherOp>::OpConversionPattern;

LogicalResult checkBasicLegality(ttir::GatherOp op,
PatternRewriter &rewriter) const {
// variables for embedding pattern matching checks
auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
auto shape = outputType.getShape();
auto startIndices = op.getStartIndices(); // start indices of the gather op
auto startIndicesType =
mlir::cast<RankedTensorType>(startIndices.getType());
auto sliceSizes = op.getSliceSizes(); // slice sizes of the gather op
auto offsetDims = op.getOffsetDims();
auto collapsedSliceDims =
op.getCollapsedSliceDims(); // collapsed slice dims of the gather op

if (shape.size() > 1) {
auto hiddenDim = shape[shape.size() - 1];
assert(sliceSizes.size() > 1 &&
"sliceSizes should have at least 2 elements");
// check if sliceSizes is [1, hiddenDim]
if (sliceSizes[0] != 1 || sliceSizes[1] != hiddenDim) {
return rewriter.notifyMatchFailure(op, "Did not satisfy sliceSizes");
}
}

// check if offsetDims is [2]
if (std::vector<int64_t>(offsetDims.begin(), offsetDims.end()) !=
std::vector<int64_t>{2}) {
return rewriter.notifyMatchFailure(op, "Did not satisfy offsetDims");
}

// check if collapsedSliceDims is [0]
if (std::vector<int64_t>(collapsedSliceDims.begin(),
collapsedSliceDims.end()) !=
std::vector<int64_t>{0}) {
return rewriter.notifyMatchFailure(op,
"Did not satisfy collapsedSliceDims");
}

// If startIndices has the same rank as the output, its last dimension
// must be 1 so that a reshape can drop it; otherwise the pattern does
// not apply.
if (shape.size() == startIndicesType.getShape().size() &&
startIndicesType.getShape()[shape.size() - 1] != 1) {
return rewriter.notifyMatchFailure(op,
"Did not satisfy startIndicesType");
}

return success();
}
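// Worked example (shapes from the tests in this commit) that passes all
// of the checks above:
//   operand:       tensor<32000x1024xf32>  (embedding table)
//   start_indices: tensor<1x32xi32>        (row ids)
//   output:        tensor<1x32x1024xf32>
//   slice_sizes = [1, 1024], offset_dims = [2], collapsed_slice_dims = [0]
// Each index selects one full 1024-wide row, i.e. an embedding lookup.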
ttir::ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc,
Value input,
::llvm::ArrayRef<int64_t> shapei64,
::mlir::ArrayAttr operandConstraints) const {

// reshape start indices (input) to remove the last dimension
auto ty = mlir::cast<RankedTensorType>(input.getType());
auto output = rewriter.create<tensor::EmptyOp>(
loc, llvm::ArrayRef<int64_t>(shapei64), ty.getElementType());
std::vector<int32_t> shapei32(shapei64.begin(), shapei64.end());
auto shape_attr = rewriter.getI32ArrayAttr(shapei32);

return rewriter.create<ttir::ReshapeOp>(
loc, mlir::RankedTensorType::get(shapei64, ty.getElementType()), input,
output, shape_attr, operandConstraints);
}
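// Example (illustrative, from test_gather_1 below): start indices of type
// tensor<1x2x1xi32> are reshaped to tensor<1x2xi32>, dropping the trailing
// unit dimension so they can feed the embedding lowering directly.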
LogicalResult
matchAndRewrite(ttir::GatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult err = checkBasicLegality(op, rewriter);
if (failed(err)) {
return err;
}
auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
auto shape = outputType.getShape();
auto startIndices = op.getStartIndices(); // start indices of the gather op
auto startIndicesType =
mlir::cast<RankedTensorType>(startIndices.getType());
::mlir::Value input = op.getStartIndices();
if (shape.size() == startIndicesType.getShape().size() &&
startIndicesType.getShape()[shape.size() - 1] == 1) {
// Start indices carry a trailing unit dimension: insert a reshape to
// drop it before lowering to the embedding op.
std::vector<int64_t> newShapeI64(startIndicesType.getShape().begin(),
startIndicesType.getShape().end() - 1);

ttir::ReshapeOp reshapeOp =
createReshapeOp(rewriter, op.getLoc(), startIndices, newShapeI64,
op.getOperandConstraints());

assert(reshapeOp && "Failed to create reshape op");
reshapeOp->moveBefore(op);
input = reshapeOp.getResult();
}

ttir::EmbeddingOp embeddingOp = rewriter.create<ttir::EmbeddingOp>(
op.getLoc(), op.getResult().getType(),
input, // input - start indices
op.getOperands()[0], // weight - input tensor
op.getOutput(),
rewriter.getArrayAttr( // operand constraints
SmallVector<Attribute>(op.getNumOperands() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

assert(embeddingOp != nullptr && "Failed to create embedding op");
rewriter.replaceOp(op, embeddingOp);

return success();
}
@@ -411,6 +528,7 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
TypeConverter &typeConverter) {
patterns.add<IndexToSliceConversionPattern>(typeConverter, ctx);
patterns.add<ConvolutionToConv2dPattern>(typeConverter, ctx);
patterns.add<GatherToEmbeddingConversionPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
@@ -48,6 +48,7 @@ struct TTIRToTTIRDecompositionPass
// These are the ops we intend to remove entirely with this pass
target.addIllegalOp<ttir::IndexOp>();
target.addIllegalOp<ttir::ConvolutionOp>();
target.addIllegalOp<ttir::GatherOp>();

TypeConverter typeConverter;
// All types map 1:1.
25 changes: 25 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
@@ -0,0 +1,25 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_gather attributes {} {
func.func public @test_gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
return %0 : tensor<1x32x1024xf32>
}
func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<448x384xf32>, tensor<1x2x1xi32>) -> tensor<1x2x384xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
return %0 : tensor<1x2x384xf32>
}

func.func public @test_gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<51864x384xf32>, tensor<1x2xi32>) -> tensor<1x2x384xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
return %0 : tensor<1x2x384xf32>
}

}
21 changes: 21 additions & 0 deletions test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
@@ -0,0 +1,21 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x32x1024xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.gather"(%operand, %start_indices, %0) {
offset_dims = array<i64: 2>,
collapsed_slice_dims = array<i64: 0>,
operand_batching_dims = array<i64: 0>,
start_indices_batching_dims = array<i64: 0>,
start_index_map = array<i64: 0>,
index_vector_dim = 1 : si64,
slice_sizes = array<i64: 1, 1024>,
indices_are_sorted = false,
operand_constraints = [#any_device, #any_device, #any_device]
} : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
return %1 : tensor<1x32x1024xf32>
}
}
24 changes: 24 additions & 0 deletions test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir
@@ -0,0 +1,24 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// XFAIL: *
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xbf16>) -> tensor<1x32x1024xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x32x1024xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"(%start_indices, %operand, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<32000x1024xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
%1 = "ttir.gather"(%operand, %start_indices, %0) {
offset_dims = array<i64: 2>,
collapsed_slice_dims = array<i64: 0>,
operand_batching_dims = array<i64: 0>,
start_indices_batching_dims = array<i64: 0>,
start_index_map = array<i64: 0>,
index_vector_dim = 1 : si64,
slice_sizes = array<i64: 1, 1024>,
indices_are_sorted = false,
operand_constraints = [#any_device, #any_device, #any_device]
} : (tensor<32000x1024xbf16>, tensor<1x32xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
return %1 : tensor<1x32x1024xbf16>
}
}
