diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index d26a3f6c0..3d3281c7b 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -704,11 +704,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
   let hasVerifier = 1;

   let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+}
+
+def TTIR_GatherOp: TTIR_DPSOp<"gather"> {
  let summary = "Gather operation.";
+  let description = [{
+    Collects slices of the input tensor at the locations given by start_indices, following stablehlo.gather semantics.
+  }];
+  let arguments = (ins AnyRankedTensor:$input,          // the StableHLO gather `operand`
+                   AnyRankedTensor:$start_indices,      // indices of the slices to collect
+                   AnyRankedTensor:$output,             // DPS output (the StableHLO `result`)
+                   DenseI64ArrayAttr:$offset_dims,      // output dims that index into the gathered slice
+                   DenseI64ArrayAttr:$collapsed_slice_dims,        // slice dims collapsed away in the output
+                   DenseI64ArrayAttr:$operand_batching_dims,       // batching dims of the operand
+                   DenseI64ArrayAttr:$start_indices_batching_dims, // batching dims of start_indices
+                   DenseI64ArrayAttr:$start_index_map,  // operand dims addressed by each index vector
+                   SI64Attr:$index_vector_dim,          // dim of start_indices holding the index vector
+                   DenseI64ArrayAttr:$slice_sizes,      // extent of the slice taken per index
+                   BoolAttr:$indices_are_sorted,        // caller's hint that indices are sorted
+                   TT_OperandConstraintArrayAttr:$operand_constraints);
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
 }

-
 def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
   let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
   let description = [{
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 4672e144b..ad6ff293d 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -894,6 +894,48 @@ class StableHLOToTTIRSliceOpConversionPattern
   }
 };

+class StableHLOToTTIRGatherOpConversionPattern
+    : public OpConversionPattern<mlir::stablehlo::GatherOp> {
+  using OpConversionPattern<mlir::stablehlo::GatherOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(mlir::stablehlo::GatherOp srcOp,
+                  mlir::stablehlo::GatherOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Convert the result type with the pattern's type converter.
+    auto outputType = mlir::cast<RankedTensorType>(
+        getTypeConverter()->convertType(srcOp.getResult().getType()));
+    // Create an empty DPS output tensor with the converted shape.
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+    auto dimensionNumbers = srcOp.getDimensionNumbers();
+    rewriter.replaceOpWithNewOp<mlir::tt::ttir::GatherOp>(
+        srcOp,                  // the original operation to replace
+        outputType,             // result type
+        srcOp.getOperands()[0], // input tensor
+        srcOp.getOperands()[1], // start indices
+        Value(outputTensor),    // output tensor
+        dimensionNumbers.getOffsetDims(),               // offset_dims
+        dimensionNumbers.getCollapsedSliceDims(),       // collapsed_slice_dims
+        dimensionNumbers.getOperandBatchingDims(),      // operand_batching_dims
+        dimensionNumbers.getStartIndicesBatchingDims(), // start_indices_batching_dims
+        dimensionNumbers.getStartIndexMap(),            // start_index_map
+        dimensionNumbers.getIndexVectorDim(),           // index_vector_dim
+        srcOp.getSliceSizesAttr(),                      // slice_sizes
+        false,                  // indices_are_sorted, conservatively false
+        rewriter.getArrayAttr(  // operand constraints
+            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+    return success();
+  }
+};
+
 void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
@@ -1036,6 +1078,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
   patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
 }

+void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
+                                  TypeConverter &typeConverter) {
+  patterns.add<StableHLOToTTIRGatherOpConversionPattern>(typeConverter, ctx);
+}
+
 } // namespace

 namespace mlir::tt {
@@ -1057,6 +1104,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
   addReshapeOpConversionPattern(ctx, patterns, typeConverter);
   addLogicalOpConversionPattern(ctx, patterns, typeConverter);
   addSliceOpConversionPattern(ctx, patterns, typeConverter);
+  addGatherOpConversionPattern(ctx, patterns, typeConverter);
 }

 } // namespace mlir::tt
diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
index f91664616..697e6ed13 100644
--- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
+++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
@@ -401,6 +401,123 @@ struct ConvolutionToConv2dPattern
         adaptor.getOperandConstraints());

     rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// Gather Pattern Matching
+//===----------------------------------------------------------------------===//
+
+struct GatherToEmbeddingConversionPattern
+    : public OpConversionPattern<ttir::GatherOp> {
+  using OpConversionPattern<ttir::GatherOp>::OpConversionPattern;
+
+  LogicalResult checkBasicLegality(ttir::GatherOp op,
+                                   PatternRewriter &rewriter) const {
+    // Gather properties consulted by the embedding pattern-matching checks.
+    auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
+    auto shape = outputType.getShape();
+    auto startIndices = op.getStartIndices();
+    auto startIndicesType =
+        mlir::cast<RankedTensorType>(startIndices.getType());
+    auto sliceSizes = op.getSliceSizes();
+    auto offsetDims = op.getOffsetDims();
+    auto collapsedSliceDims = op.getCollapsedSliceDims();
+
+    if (shape.size() > 1) {
+      auto hiddenDim = shape[shape.size() - 1];
+      assert(sliceSizes.size() > 1 &&
+             "sliceSizes should have at least 2 elements");
+      // The gather must take whole rows, i.e. sliceSizes == [1, hiddenDim].
+      if (sliceSizes[0] != 1 || sliceSizes[1] != hiddenDim) {
+        return rewriter.notifyMatchFailure(op, "Did not satisfy sliceSizes");
+      }
+    }
+
+    // offsetDims must be exactly [2].
+    if (std::vector<int64_t>(offsetDims.begin(), offsetDims.end()) !=
+        std::vector<int64_t>{2}) {
+      return rewriter.notifyMatchFailure(op, "Did not satisfy offsetDims");
+    }
+
+    // collapsedSliceDims must be exactly [0].
+    if (std::vector<int64_t>(collapsedSliceDims.begin(),
+                             collapsedSliceDims.end()) !=
+        std::vector<int64_t>{0}) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Did not satisfy collapsedSliceDims");
+    }
+
+    // If startIndices has the same rank as the output, its trailing dimension
+    // must be 1 so that the reshape inserted below can remove it.
+    if (shape.size() == startIndicesType.getShape().size() &&
+        startIndicesType.getShape()[shape.size() - 1] != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Did not satisfy startIndicesType");
+    }
+
+    return success();
+  }
+
+  ttir::ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc,
+                                  Value input,
+                                  ::llvm::ArrayRef<int64_t> shapei64,
+                                  ::mlir::ArrayAttr operandConstraints) const {
+    // Reshape the start indices to drop their trailing unit dimension.
+    auto ty = mlir::cast<RankedTensorType>(input.getType());
+    auto output = rewriter.create<tensor::EmptyOp>(
+        loc, llvm::ArrayRef<int64_t>(shapei64), ty.getElementType());
+    std::vector<int32_t> shapei32(shapei64.begin(), shapei64.end());
+    auto shape_attr = rewriter.getI32ArrayAttr(shapei32);
+
+    return rewriter.create<ttir::ReshapeOp>(
+        loc, mlir::RankedTensorType::get(shapei64, ty.getElementType()), input,
+        output, shape_attr, operandConstraints);
+  }
+
+  LogicalResult
+  matchAndRewrite(ttir::GatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    LogicalResult err = checkBasicLegality(op, rewriter);
+    if (not err.succeeded()) {
+      return err;
+    }
+
+    auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
+    auto shape = outputType.getShape();
+    auto startIndices = op.getStartIndices();
+    auto startIndicesType =
+        mlir::cast<RankedTensorType>(startIndices.getType());
+
+    ::mlir::Value input = op.getStartIndices();
+    if (shape.size() == startIndicesType.getShape().size() &&
+        startIndicesType.getShape()[shape.size() - 1] == 1) {
+      // The start indices carry a trailing unit dimension: insert a reshape
+      // that removes it before the indices feed the embedding op.
+      std::vector<int64_t> newShapeI64(startIndicesType.getShape().begin(),
+                                       startIndicesType.getShape().end() - 1);
+
+      ttir::ReshapeOp reshapeOp =
+          createReshapeOp(rewriter, op.getLoc(), startIndices, newShapeI64,
+                          op.getOperandConstraints());
+
+      assert(reshapeOp && "Failed to create reshape op");
+      reshapeOp->moveBefore(op);
+      input = reshapeOp.getResult();
+    }
+
+    ttir::EmbeddingOp embeddingOp = rewriter.create<ttir::EmbeddingOp>(
+        op.getLoc(), op.getResult().getType(),
+        input,               // input - start indices
+        op.getOperands()[0], // weight - the table being gathered from
+        op.getOutput(),
+        rewriter.getArrayAttr( // operand constraints
+            SmallVector<Attribute>(op.getNumOperands() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+
+    assert(embeddingOp != nullptr && "Failed to create embedding op");
+    rewriter.replaceOp(op, embeddingOp);
     return success();
   }
 };
@@ -411,6 +528,7 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
                                              TypeConverter &typeConverter) {
   patterns.add<PoolingToPool2dPattern>(typeConverter, ctx);
   patterns.add<ConvolutionToConv2dPattern>(typeConverter, ctx);
+  patterns.add<GatherToEmbeddingConversionPattern>(typeConverter, ctx);
 }

 } // namespace mlir::tt
diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
index e621e6b28..997b6f629 100644
--- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
+++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
@@ -48,6 +48,7 @@ struct TTIRToTTIRDecompositionPass
     // These are the ops we intend to remove entirely with this pass
     target.addIllegalOp<ttir::ConvolutionOp>();
     target.addIllegalOp<ttir::PoolingOp>();
+    target.addIllegalOp<ttir::GatherOp>();

     TypeConverter typeConverter;
     // All types map 1:1.
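For reference, a minimal sketch of what GatherToEmbeddingConversionPattern is meant to produce (editor's illustration, not part of the patch): when the start indices carry a trailing unit dimension, the pattern first emits a ttir.reshape that drops it, and the gather itself becomes a ttir.embedding row lookup. SSA names are hypothetical and the printed op forms are assumed to match the tests below; shapes are taken from test_gather_1.

  // Before: ttir.gather with 1x2x1 start indices over a 448x384 table.
  // After the decomposition (sketch):
  %0 = tensor.empty() : tensor<1x2xi32>
  // Drop the trailing unit dimension of the start indices.
  %1 = "ttir.reshape"(%start_indices, %0) {shape = [1 : i32, 2 : i32], operand_constraints = [#any_device, #any_device]} : (tensor<1x2x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
  // Each index now selects one 384-wide row of the operand; the embedding
  // writes into the gather's original DPS output tensor.
  %2 = "ttir.embedding"(%1, %operand, %output) {operand_constraints = [#any_device, #any_device, #any_device]} : (tensor<1x2xi32>, tensor<448x384xf32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32>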
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
new file mode 100644
index 000000000..ba29d123e
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
@@ -0,0 +1,25 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module @jit_gather attributes {} {
+  func.func public @test_gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
+    return %0 : tensor<1x32x1024xf32>
+  }
+  func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> {
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<448x384xf32>, tensor<1x2x1xi32>) -> tensor<1x2x384xf32>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
+    return %0 : tensor<1x2x384xf32>
+  }
+
+  func.func public @test_gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xf32> {
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<51864x384xf32>, tensor<1x2xi32>) -> tensor<1x2x384xf32>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
+    return %0 : tensor<1x2x384xf32>
+  }
+
+}
diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
new file mode 100644
index 000000000..238b8f3e8
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
@@ -0,0 +1,21 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @forward(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    %0 = tensor.empty() : tensor<1x32x1024xf32>
+    // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
+    %1 = "ttir.gather"(%operand, %start_indices, %0) {
+        offset_dims = array<i64: 2>,
+        collapsed_slice_dims = array<i64: 0>,
+        operand_batching_dims = array<i64>,
+        start_indices_batching_dims = array<i64>,
+        start_index_map = array<i64: 0>,
+        index_vector_dim = 1 : si64,
+        slice_sizes = array<i64: 1, 1024>,
+        indices_are_sorted = false,
+        operand_constraints = [#any_device, #any_device, #any_device]
+    } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
+    return %1 : tensor<1x32x1024xf32>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir b/test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir
new file mode 100644
index 000000000..52f417e3c
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir
@@ -0,0 +1,24 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// XFAIL: *
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @forward(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xbf16>) -> tensor<1x32x1024xbf16> {
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    %0 = tensor.empty() : tensor<1x32x1024xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.embedding"(%start_indices, %operand, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<32000x1024xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
+    %1 = "ttir.gather"(%operand, %start_indices, %0) {
+        offset_dims = array<i64: 2>,
+        collapsed_slice_dims = array<i64: 0>,
+        operand_batching_dims = array<i64>,
+        start_indices_batching_dims = array<i64>,
+        start_index_map = array<i64: 0>,
+        index_vector_dim = 1 : si64,
+        slice_sizes = array<i64: 1, 1024>,
+        indices_are_sorted = false,
+        operand_constraints = [#any_device, #any_device, #any_device]
+    } : (tensor<32000x1024xbf16>, tensor<1x32xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
+    return %1 : tensor<1x32x1024xbf16>
+  }
+}
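A closing note on how the pieces line up (editor's sketch, not part of the patch): the StableHLO conversion is a one-to-one flattening of stablehlo.gather's dimension_numbers into discrete ttir.gather attributes, with slice_sizes and indices_are_sorted carried alongside. Using the values from test_gather_0 above:

  // stablehlo.gather bundles the configuration into one attribute:
  dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, slice_sizes = array<i64: 1, 1024>
  // ttir.gather exposes the same fields as individual attributes, which is
  // what GatherToEmbeddingConversionPattern later pattern-matches on:
  offset_dims = array<i64: 2>, collapsed_slice_dims = array<i64: 0>, start_index_map = array<i64: 0>, index_vector_dim = 2 : si64, slice_sizes = array<i64: 1, 1024>, indices_are_sorted = false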