Skip to content

Commit

Permalink
#14462: Conv2d account for split reader in determine_per_core_conv_block_config
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavle Josipovic authored and pavlejosipovic committed Oct 31, 2024
1 parent 25150e6 commit 0add329
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,23 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_o
}

std::pair<uint32_t, uint32_t> determine_largest_subblock_size(
uint32_t block_height, uint32_t block_width, bool fp32_accum) {
std::vector<std::pair<uint32_t, uint32_t>> subblocks = {
uint32_t block_height, uint32_t block_width, bool fp32_accum, bool split_reader_enabled) {
constexpr std::array<std::pair<uint32_t, uint32_t>, 20> subblocks = {{
{2, 4}, {4, 2}, {1, 8}, {8, 1}, {1, 7}, {7, 1}, {2, 3}, {3, 2}, {1, 6}, {6, 1},
{1, 5}, {5, 1}, {2, 2}, {1, 4}, {4, 1}, {1, 3}, {3, 1}, {1, 2}, {2, 1}, {1, 1},
};
}};

uint32_t subblock_h = 0;
uint32_t subblock_w = 0;
for (auto [subblock_height, subblock_width] : subblocks) {
if (fp32_accum && (subblock_height * subblock_width > 4)) {
continue;
}

if (split_reader_enabled && (block_height / subblock_height) < 2) {
continue;
}

if ((block_height % subblock_height == 0) && (block_width % subblock_width == 0)) {
if (subblock_width != block_width && subblock_height != 1) {
continue;
Expand All @@ -250,7 +256,8 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config(
uint32_t act_block_w_div,
uint32_t window_h,
uint32_t window_w,
bool fp32_accum) {
bool fp32_accum,
bool split_reader_enabled) {

if (act_block_h_override > 0) {
TT_ASSERT(
Expand All @@ -276,8 +283,9 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config(
: grid_size.x;
uint32_t out_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles;
uint32_t weight_block_w_ntiles = conv_op_parallel_config.per_core_out_matrix_width_ntiles;
// With split reader enabled, the chosen subblock must satisfy (act_block_h_ntiles / out_subblock_h_ntiles) >= 2.
auto [out_subblock_h_ntiles, out_subblock_w_ntiles] =
determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum);
determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum, split_reader_enabled);
return {
.act_block_h_ntiles = act_block_h_ntiles,
.act_block_w_ntiles = act_block_w_ntiles,
Expand Down Expand Up @@ -857,7 +865,8 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
conv_config.act_block_w_div,
kernel_size[0],
kernel_size[1],
conv_config.fp32_dest_acc_enabled);
conv_config.fp32_dest_acc_enabled,
conv_config.enable_split_reader);
bool weight_is_on_device = ttnn::is_tensor_on_device_or_multidevice(weight_tensor);
ttnn::Tensor weight_tensor_on_device = weight_tensor;
std::optional<ttnn::Tensor> bias_tensor_on_device = bias_tensor;
Expand Down

0 comments on commit 0add329

Please sign in to comment.