diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common.hpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common.hpp new file mode 100644 index 00000000000..c811dfccdc3 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common.hpp @@ -0,0 +1,170 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void fill_cb_with_value(uint32_t cb_id, uint32_t value) { + cb_reserve_back(cb_id, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_id)); + for (int j = 0; j < 1024; j++) { + ptr[j] = uint16_t(value >> 16); + } + cb_push_back(cb_id, 1); +} + +void generate_mask_h_w(uint32_t cb_mask_h_w, uint32_t mask_h, uint32_t mask_w, uint32_t single_tile_size = 2048) { + union { + float f; + uint32_t u; + } one; + one.f = 1.0f; + union { + float f; + uint32_t u; + } zero; + zero.f = 0.0f; + + const auto u16_one = uint16_t(one.u >> 16); + const auto u16_zero = uint16_t(zero.u >> 16); + + cb_reserve_back(cb_mask_h_w, 2); + + // mask_h + // first tile ptr + auto mask_h_ptr = reinterpret_cast(get_write_ptr(cb_mask_h_w)); + for (uint32_t w = 0; w < 16; w++) { + // sub tile 0 + { + uint32_t mask_h_0 = mask_h; + if (mask_h_0 >= 16) { + mask_h_0 = 16; + } + uint32_t h = 0; + for (; h < mask_h_0; h++) { + mask_h_ptr[h * 16 + w] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w] = u16_zero; + } + } + + // sub tile 1 + { + uint32_t mask_h_0 = mask_h; + if (mask_h_0 >= 16) { + mask_h_0 = 16; + } + uint32_t h = 0; + for (; h < mask_h_0; h++) { + mask_h_ptr[h * 16 + w + 256] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w + 256] = u16_zero; + } + } + + // sub tile 2 + { + uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16; + uint32_t h = 0; + for (; h < mask_h_1; h++) { + mask_h_ptr[h * 16 + w + 512] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w + 512] = u16_zero; + } + } + + // sub tile 3 + { + uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16; + uint32_t h = 0; + for (; h < mask_h_1; h++) { + mask_h_ptr[h * 16 + w + 768] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w + 768] = u16_zero; + } + } + } + + // mask_w + // second tile ptr + auto mask_w_ptr = reinterpret_cast(get_write_ptr(cb_mask_h_w) + single_tile_size); + for (uint32_t h = 0; h < 16; h++) { + // sub tile 0 + { + uint32_t mask_w_0 = mask_w; + if (mask_w_0 >= 16) { + mask_w_0 = 16; + } + uint32_t w = 0; + for (; w < mask_w_0; w++) { + mask_w_ptr[h * 16 + w] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w] = u16_zero; + } + } + + // sub tile 1 + { + uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16; + uint32_t w = 0; + for (; w < mask_w_1; w++) { + mask_w_ptr[h * 16 + w + 256] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w + 256] = u16_zero; + } + } + + // sub tile 2 + { + uint32_t mask_w_0 = mask_w; + if (mask_w_0 >= 16) { + mask_w_0 = 16; + } + uint32_t w = 0; + for (; w < mask_w_0; w++) { + mask_w_ptr[h * 16 + w + 512] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w + 512] = u16_zero; + } + } + + // sub tile 3 + { + uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16; + uint32_t w = 0; + for (; w < mask_w_1; w++) { + mask_w_ptr[h * 16 + w + 768] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w + 768] = u16_zero; + } + } + } + + cb_push_back(cb_mask_h_w, 2); +} + +void generate_mask_h_w_if_needed(uint32_t cb_mask_h_w, uint32_t origin_h, uint32_t origin_w) { + constexpr uint32_t TILE_H = 32; + constexpr uint32_t TILE_W = 32; + + const bool do_mask_h = (origin_h % TILE_H) != 0; + const uint32_t mask_h = do_mask_h ? (origin_h % TILE_H) : TILE_H; + + const bool do_mask_w = (origin_w % TILE_W) != 0; + const uint32_t mask_w = do_mask_w ? (origin_w % TILE_W) : TILE_W; + + if (do_mask_h || do_mask_w) { + const uint32_t mask_tile_bytes = get_tile_size(cb_mask_h_w); + generate_mask_h_w(cb_mask_h_w, mask_h, mask_w, mask_tile_bytes); + } +} diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common_ckernels.hpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common_ckernels.hpp new file mode 100644 index 00000000000..b76b17ba272 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common_ckernels.hpp @@ -0,0 +1,107 @@ +/* + * SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "compute_kernel_api.h" +#include "compute_kernel_api/bcast.h" +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/eltwise_unary/exp.h" +#include "compute_kernel_api/eltwise_unary/recip.h" +#include "compute_kernel_api/mask.h" +#include "compute_kernel_api/reduce.h" +#include "compute_kernel_api/tile_move_copy.h" + +ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } +ALWI void REL() { release_dst(tt::DstMode::Half); } + +namespace ckernel { + +ALWI void power_tile_to_cb( + std::uint8_t cb_x, + std::uint8_t cb_xpow, + std::uint8_t cb_logx, + std::uint8_t cb_decimal, + std::uint8_t cb_exp_lxmd, + std::uint8_t cb_correct_xpow, + uint32_t p, + bool p_is_negative) { + constexpr uint32_t onetile = 1; + constexpr uint32_t dst0 = 0; + + // x^p + ACQ(); + cb_wait_front(cb_x, onetile); + cb_reserve_back(cb_xpow, onetile); + + copy_tile_init(); + copy_tile(cb_x, 0, dst0); + + power_tile_init(); + power_tile(dst0, p); + + if (p_is_negative) { + recip_tile_init(); + recip_tile(dst0); + } + + pack_tile(dst0, cb_xpow); + + cb_push_back(cb_xpow, onetile); + REL(); + // We don't pop cb_x here. + + // log(x) + ACQ(); + cb_reserve_back(cb_logx, onetile); + + copy_tile_init(); + copy_tile(cb_x, 0, dst0); + + log_tile_init(); + log_tile(dst0); + + pack_tile(dst0, cb_logx); + + cb_pop_front(cb_x, onetile); + cb_push_back(cb_logx, onetile); + REL(); + + // exp(log(x) * decimal) + ACQ(); + cb_wait_front(cb_logx, onetile); + cb_reserve_back(cb_exp_lxmd, onetile); + + mul_tiles_init(); + mul_tiles(cb_logx, cb_decimal, 0, 0, dst0); + + exp_tile_init(); + exp_tile(dst0); + + pack_tile(dst0, cb_exp_lxmd); + + cb_pop_front(cb_logx, onetile); + cb_push_back(cb_exp_lxmd, onetile); + REL(); + + // x^p * exp(log(x) * decimal)(==(x + decimal)^p) + ACQ(); + cb_wait_front(cb_xpow, onetile); + cb_wait_front(cb_exp_lxmd, onetile); + cb_reserve_back(cb_correct_xpow, onetile); + + mul_tiles_init(); + mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0); + + pack_tile(dst0, cb_correct_xpow); + + cb_pop_front(cb_xpow, onetile); + cb_pop_front(cb_exp_lxmd, onetile); + cb_push_back(cb_correct_xpow, onetile); + REL(); +} + +} // namespace ckernel diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/moreh_clip_grad_norm_step1_kernel.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/moreh_clip_grad_norm_step1_kernel.cpp index cd757a9083b..0f60d1b6454 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/moreh_clip_grad_norm_step1_kernel.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/moreh_clip_grad_norm_step1_kernel.cpp @@ -2,20 +2,9 @@ // // SPDX-License-Identifier: Apache-2.0 -#include +#include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common_ckernels.hpp" -#include "compute_kernel_api.h" -#include "compute_kernel_api/eltwise_binary.h" -#include "compute_kernel_api/eltwise_unary/exp.h" -#include "compute_kernel_api/eltwise_unary/recip.h" -#include "compute_kernel_api/mask.h" -#include "compute_kernel_api/reduce.h" -#include "compute_kernel_api/tile_move_copy.h" - -ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } -ALWI void REL() { release_dst(tt::DstMode::Half); } - -inline bool need_to_do_mask_h(uint32_t tile_idx, uint32_t ht, uint32_t wt) { return (((tile_idx / wt) + 1) % ht) == 0; } +ALWI bool need_to_do_mask_h(uint32_t tile_idx, uint32_t ht, uint32_t wt) { return (((tile_idx / wt) + 1) % ht) == 0; } namespace NAMESPACE { void MAIN { @@ -101,80 +90,8 @@ void MAIN { cb_push_back(cb_xabs, onetile); REL(); - // Compute cb_logx - // log(|x|) - ACQ(); - cb_wait_front(cb_xabs, onetile); - cb_reserve_back(cb_logx, onetile); - - copy_tile_init(); - copy_tile(cb_xabs, 0, dst0); - - log_tile_init(); - log_tile(dst0); - - pack_tile(dst0, cb_logx); - - cb_push_back(cb_logx, onetile); - REL(); - // We don't pop cb_xabs here. - - // Compute cb_exp_lxmd - // exp(log(|x|) * decimal) - ACQ(); - cb_wait_front(cb_logx, onetile); - cb_reserve_back(cb_exp_lxmd, onetile); - - mul_tiles_init(); - mul_tiles(cb_logx, cb_decimal, 0, 0, dst0); - - exp_tile_init(); - exp_tile(dst0); - - pack_tile(dst0, cb_exp_lxmd); - - cb_pop_front(cb_logx, onetile); - cb_push_back(cb_exp_lxmd, onetile); - REL(); - - // Compute cb_xpow - // |x|^p - ACQ(); - cb_reserve_back(cb_xpow, onetile); - - copy_tile_init(); - copy_tile(cb_xabs, 0, dst0); - - power_tile_init(); - power_tile(dst0, p); - - if (p_is_negative) { - recip_tile_init(); - recip_tile(dst0); - } - - pack_tile(dst0, cb_xpow); - - cb_pop_front(cb_xabs, onetile); - cb_push_back(cb_xpow, onetile); - REL(); - - // Compute cb_correct_xpow - // |x|^p * exp(log(|x|) * decimal) - ACQ(); - cb_wait_front(cb_xpow, onetile); - cb_wait_front(cb_exp_lxmd, onetile); - cb_reserve_back(cb_correct_xpow, onetile); - - mul_tiles_init(); - mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0); - - pack_tile(dst0, cb_correct_xpow); - - cb_pop_front(cb_xpow, onetile); - cb_pop_front(cb_exp_lxmd, onetile); - cb_push_back(cb_correct_xpow, onetile); - REL(); + // |x + decimal|^p + power_tile_to_cb(cb_xabs, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_correct_xpow, p, p_is_negative); if (tile_idx == 0) { ACQ(); @@ -206,11 +123,6 @@ void MAIN { REL(); } } - cb_pop_front(cb_decimal, onetile); - cb_pop_front(cb_one, onetile); - if (do_mask_h || do_mask_w) { - cb_pop_front(cb_mask_h_w, 2); - } // Compute cb_y ACQ(); @@ -226,5 +138,12 @@ void MAIN { cb_pop_front(cb_xpowadd, onetile); cb_push_back(cb_y, onetile); REL(); + + cb_pop_front(cb_decimal, onetile); + cb_pop_front(cb_one, onetile); + if (do_mask_h || do_mask_w) { + cb_pop_front(cb_mask_h_w, 2); + } + } // void MAIN } // namespace NAMESPACE diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/reader_moreh_clip_grad_norm_step1.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/reader_moreh_clip_grad_norm_step1.cpp index 370fb3f5ebf..59bfe304d92 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/reader_moreh_clip_grad_norm_step1.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/reader_moreh_clip_grad_norm_step1.cpp @@ -2,10 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include - -#include "dataflow_api.h" -#include "tt_eager/tt_dnn/op_library/moreh_layernorm_backward/kernels/utils.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common.hpp" void kernel_main() { int i{0}; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/moreh_clip_grad_norm_step2_kernel.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/moreh_clip_grad_norm_step2_kernel.cpp index fde93033908..bc09992705c 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/moreh_clip_grad_norm_step2_kernel.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/moreh_clip_grad_norm_step2_kernel.cpp @@ -2,16 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include - -#include "compute_kernel_api.h" -#include "compute_kernel_api/eltwise_binary.h" -#include "compute_kernel_api/eltwise_unary/exp.h" -#include "compute_kernel_api/eltwise_unary/recip.h" -#include "compute_kernel_api/tile_move_copy.h" - -ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } -ALWI void REL() { release_dst(tt::DstMode::Half); } +#include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common_ckernels.hpp" namespace NAMESPACE { void MAIN { @@ -79,80 +70,8 @@ void MAIN { } } - // Compute cb_xpow // x^p - ACQ(); - cb_wait_front(cb_x, onetile); - cb_reserve_back(cb_xpow, onetile); - - copy_tile_init(); - copy_tile(cb_x, 0, dst0); - - power_tile_init(); - power_tile(dst0, p); - - if (p_is_negative) { - recip_tile_init(); - recip_tile(dst0); - } - - pack_tile(dst0, cb_xpow); - - cb_push_back(cb_xpow, onetile); - REL(); - // We don't pop cb_x here. - - // Compute cb_logx - // log(x) - ACQ(); - cb_reserve_back(cb_logx, onetile); - - copy_tile_init(); - copy_tile(cb_x, 0, dst0); - - log_tile_init(); - log_tile(dst0); - - pack_tile(dst0, cb_logx); - - cb_pop_front(cb_x, onetile); - cb_push_back(cb_logx, onetile); - REL(); - - // Compute cb_exp_lxmd - // exp(log(x) * decimal) - ACQ(); - cb_wait_front(cb_logx, onetile); - cb_reserve_back(cb_exp_lxmd, onetile); - - mul_tiles_init(); - mul_tiles(cb_logx, cb_decimal, 0, 0, dst0); - - exp_tile_init(); - exp_tile(dst0); - - pack_tile(dst0, cb_exp_lxmd); - - cb_pop_front(cb_logx, onetile); - cb_pop_front(cb_decimal, onetile); - cb_push_back(cb_exp_lxmd, onetile); - REL(); - - // Compute cb_y - // x^p * exp(log(x) * decimal) - ACQ(); - cb_wait_front(cb_xpow, onetile); - cb_wait_front(cb_exp_lxmd, onetile); - cb_reserve_back(cb_y, onetile); - - mul_tiles_init(); - mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0); - - pack_tile(dst0, cb_y); + power_tile_to_cb(cb_x, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_y, p, p_is_negative); - cb_pop_front(cb_xpow, onetile); - cb_pop_front(cb_exp_lxmd, onetile); - cb_push_back(cb_y, onetile); - REL(); } // void MAIN } // namespace NAMESPACE diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/reader_moreh_clip_grad_norm_step2.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/reader_moreh_clip_grad_norm_step2.cpp index 1cb8a63c1bf..7475b8a22c0 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/reader_moreh_clip_grad_norm_step2.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/reader_moreh_clip_grad_norm_step2.cpp @@ -2,10 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include - -#include "dataflow_api.h" -#include "tt_eager/tt_dnn/op_library/moreh_layernorm_backward/kernels/utils.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common.hpp" void kernel_main() { int i{0}; diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/moreh_clip_grad_norm_step3_kernel.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/moreh_clip_grad_norm_step3_kernel.cpp index b4d01bbd86f..7ac1e659369 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/moreh_clip_grad_norm_step3_kernel.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/moreh_clip_grad_norm_step3_kernel.cpp @@ -2,14 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include - -#include "compute_kernel_api.h" -#include "compute_kernel_api/bcast.h" -#include "compute_kernel_api/eltwise_binary.h" - -ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } -ALWI void REL() { release_dst(tt::DstMode::Half); } +#include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common_ckernels.hpp" namespace NAMESPACE { void MAIN { diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/reader_moreh_clip_grad_norm_step3.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/reader_moreh_clip_grad_norm_step3.cpp index b34a4e8b665..584f9135102 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/reader_moreh_clip_grad_norm_step3.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/reader_moreh_clip_grad_norm_step3.cpp @@ -2,9 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include - -#include "dataflow_api.h" +#include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common.hpp" void kernel_main() { int i{0};