-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Gather op is lowered into embedding. Used TTIR pass from 38a4a46. Used embedding fixes from e798a17. Blocked by tt-metal issue 14584.
- Loading branch information
Showing
7 changed files
with
260 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// REQUIRES: stablehlo | ||
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s | ||
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile> | ||
module @jit_gather attributes {} { | ||
func.func public @test_gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> { | ||
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32> | ||
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]] | ||
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]] | ||
return %0 : tensor<1x32x1024xf32> | ||
} | ||
func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> { | ||
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<448x384xf32>, tensor<1x2x1xi32>) -> tensor<1x2x384xf32> | ||
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]] | ||
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]] | ||
return %0 : tensor<1x2x384xf32> | ||
} | ||
|
||
func.func public @test_gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xf32> { | ||
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<51864x384xf32>, tensor<1x2xi32>) -> tensor<1x2x384xf32> | ||
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]] | ||
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]] | ||
return %0 : tensor<1x2x384xf32> | ||
} | ||
|
||
} |
21 changes: 21 additions & 0 deletions
21
test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s | ||
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile> | ||
module attributes {} { | ||
func.func @forward(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> { | ||
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] | ||
%0 = tensor.empty() : tensor<1x32x1024xf32> | ||
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] | ||
%1 = "ttir.gather"(%operand, %start_indices, %0) { | ||
offset_dims = array<i64: 2>, | ||
collapsed_slice_dims = array<i64: 0>, | ||
operand_batching_dims = array<i64: 0>, | ||
start_indices_batching_dims = array<i64: 0>, | ||
start_index_map = array<i64: 0>, | ||
index_vector_dim = 1 : si64, | ||
slice_sizes = array<i64: 1, 1024>, | ||
indices_are_sorted = false, | ||
operand_constraints = [#any_device, #any_device, #any_device] | ||
} : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32> | ||
return %1 : tensor<1x32x1024xf32> | ||
} | ||
} |
24 changes: 24 additions & 0 deletions
24
test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir | ||
// RUN: FileCheck %s --input-file=%t.mlir | ||
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn | ||
// XFAIL: * | ||
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile> | ||
module attributes {} { | ||
func.func @forward(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xbf16>) -> tensor<1x32x1024xbf16> { | ||
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] | ||
%0 = tensor.empty() : tensor<1x32x1024xbf16> | ||
// CHECK: %[[C:.*]] = "ttnn.embedding"(%start_indices, %operand, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<32000x1024xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16> | ||
%1 = "ttir.gather"(%operand, %start_indices, %0) { | ||
offset_dims = array<i64: 2>, | ||
collapsed_slice_dims = array<i64: 0>, | ||
operand_batching_dims = array<i64: 0>, | ||
start_indices_batching_dims = array<i64: 0>, | ||
start_index_map = array<i64: 0>, | ||
index_vector_dim = 1 : si64, | ||
slice_sizes = array<i64: 1, 1024>, | ||
indices_are_sorted = false, | ||
operand_constraints = [#any_device, #any_device, #any_device] | ||
} : (tensor<32000x1024xbf16>, tensor<1x32xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16> | ||
return %1 : tensor<1x32x1024xbf16> | ||
} | ||
} |