Skip to content

Commit

Permalink
Fix pr comment
Browse files Browse the repository at this point in the history
  • Loading branch information
ddilbazTT committed Nov 7, 2024
1 parent b833e79 commit f406619
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 41 deletions.
30 changes: 24 additions & 6 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ struct GatherToEmbeddingConversionPattern

LogicalResult checkBasicLegality(ttir::GatherOp op,
PatternRewriter &rewriter) const {
// variables for embedding pattern matching checks
auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
auto shape = outputType.getShape();
auto startIndices = op.getStartIndices(); // start indices of the gather op
Expand All @@ -419,42 +420,54 @@ struct GatherToEmbeddingConversionPattern
auto offsetDims = op.getOffsetDims();
auto collapsedSliceDims =
op.getCollapsedSliceDims(); // collapsed slice dims of the gather op

if (shape.size() > 1) {
auto hiddenDim = shape[shape.size() - 1];
assert(sliceSizes.size() > 1 &&
"sliceSizes should have at least 2 elements");
// check if sliceSizes is [1, hiddenDim]
if (sliceSizes[0] != 1 || sliceSizes[1] != hiddenDim) {
return rewriter.notifyMatchFailure(op, "Did not satisfy sliceSizes");
}
}
if (offsetDims.size() != 1 &&
std::vector<int64_t>(offsetDims.begin(), offsetDims.end()) !=
std::vector<int64_t>{2}) {

// check if offsetDims is [2]
if (std::vector<int64_t>(offsetDims.begin(), offsetDims.end()) !=
std::vector<int64_t>{2}) {
return rewriter.notifyMatchFailure(op, "Did not satisfy offsetDims");
}
if (collapsedSliceDims.size() != 1 ||
std::vector<int64_t>(collapsedSliceDims.begin(),

// check if collapsedSliceDims is [0]
if (std::vector<int64_t>(collapsedSliceDims.begin(),
collapsedSliceDims.end()) !=
std::vector<int64_t>{0}) {
std::vector<int64_t>{0}) {
return rewriter.notifyMatchFailure(op,
"Did not satisfy collapsedSliceDims");
}

// check if startIndices and output have same shape, if not, check if
// reshape is possible can reshape startIndices to remove the last dimension
// if it is 1
if (shape.size() == startIndicesType.getShape().size() &&
startIndicesType.getShape()[shape.size() - 1] != 1) {
return rewriter.notifyMatchFailure(op,
"Did not satisfy startIndicesType");
}

return success();
}
ttir::ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc,
Value input,
::llvm::ArrayRef<int64_t> shapei64,
::mlir::ArrayAttr operandConstraints) const {

// reshape start indices (input) to remove the last dimension
auto ty = mlir::cast<RankedTensorType>(input.getType());
auto output = rewriter.create<tensor::EmptyOp>(
loc, llvm::ArrayRef<int64_t>(shapei64), ty.getElementType());
std::vector<int32_t> shapei32(shapei64.begin(), shapei64.end());
auto shape_attr = rewriter.getI32ArrayAttr(shapei32);

return rewriter.create<ttir::ReshapeOp>(
loc, mlir::RankedTensorType::get(shapei64, ty.getElementType()), input,
output, shape_attr, operandConstraints);
Expand All @@ -479,13 +492,16 @@ struct GatherToEmbeddingConversionPattern
// before gather/ embedding op
std::vector<int64_t> newShapeI64(startIndicesType.getShape().begin(),
startIndicesType.getShape().end() - 1);

ttir::ReshapeOp reshapeOp =
createReshapeOp(rewriter, op.getLoc(), startIndices, newShapeI64,
op.getOperandConstraints());

assert(reshapeOp && "Failed to create reshape op");
reshapeOp->moveBefore(op);
input = reshapeOp.getResult();
}

ttir::EmbeddingOp embeddingOp = rewriter.create<ttir::EmbeddingOp>(
op.getLoc(), op.getResult().getType(),
input, // input - start indices
Expand All @@ -495,8 +511,10 @@ struct GatherToEmbeddingConversionPattern
SmallVector<Attribute>(op.getNumOperands() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

assert(embeddingOp != nullptr && "Failed to create embedding op");
rewriter.replaceOp(op, embeddingOp);

return success();
}
};
Expand Down
18 changes: 0 additions & 18 deletions test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,16 @@ module attributes {} {
%0 = tensor.empty() : tensor<1x32x1024xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.gather"(%operand, %start_indices, %0) {
// Specify which dimensions in the output shape correspond to the slice dimensions
offset_dims = array<i64: 2>,

// Specify which dimensions should be collapsed/removed from the slice shape
collapsed_slice_dims = array<i64: 0>,

// Specify which dimensions in operand represent batches
operand_batching_dims = array<i64: 0>,

// Specify which dimensions in start_indices represent batches
start_indices_batching_dims = array<i64: 0>,

// Map from index vector components to input dimensions
start_index_map = array<i64: 0>,

// Which dimension in start_indices contains the gather indices
index_vector_dim = 1 : si64,

// Size of the slice to gather for each dimension
slice_sizes = array<i64: 1, 1024>,

// Whether indices are guaranteed to be sorted
indices_are_sorted = false,

// Any constraints on the operands (implementation specific)
operand_constraints = [#any_device, #any_device, #any_device]
} : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>

return %1 : tensor<1x32x1024xf32>
}
}
17 changes: 0 additions & 17 deletions test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,14 @@ module attributes {} {
%0 = tensor.empty() : tensor<1x32x1024xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"(%start_indices, %operand, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<32000x1024xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
%1 = "ttir.gather"(%operand, %start_indices, %0) {
// Specify which dimensions in the output shape correspond to the slice dimensions
offset_dims = array<i64: 2>,

// Specify which dimensions should be collapsed/removed from the slice shape
collapsed_slice_dims = array<i64: 0>,

// Specify which dimensions in operand represent batches
operand_batching_dims = array<i64: 0>,

// Specify which dimensions in start_indices represent batches
start_indices_batching_dims = array<i64: 0>,

// Map from index vector components to input dimensions
start_index_map = array<i64: 0>,

// Which dimension in start_indices contains the gather indices
index_vector_dim = 1 : si64,

// Size of the slice to gather for each dimension
slice_sizes = array<i64: 1, 1024>,

// Whether indices are guaranteed to be sorted
indices_are_sorted = false,

// Any constraints on the operands (implementation specific)
operand_constraints = [#any_device, #any_device, #any_device]
} : (tensor<32000x1024xbf16>, tensor<1x32xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
return %1 : tensor<1x32x1024xbf16>
Expand Down

0 comments on commit f406619

Please sign in to comment.