Skip to content

Commit

Permalink
#3662: Use power_tile_to_cb
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwan100 authored and seunghwan100 committed Jan 22, 2024
1 parent 7477d8f commit 7b9aae9
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 194 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>

#include "dataflow_api.h"

/**
 * Fills one page of the circular buffer `cb_id` with a repeated 16-bit value.
 *
 * Only the upper 16 bits of `value` are stored (e.g. the bfloat16-style
 * truncation of a float32 bit pattern). Exactly 1024 uint16_t slots are
 * written, then the page is pushed to the consumer.
 *
 * @param cb_id  Circular buffer to reserve, fill and push one page on.
 * @param value  32-bit pattern whose high half is replicated.
 */
void fill_cb_with_value(uint32_t cb_id, uint32_t value) {
    constexpr int kElemsToFill = 1024;

    cb_reserve_back(cb_id, 1);

    const uint16_t upper_half = uint16_t(value >> 16);
    uint16_t *cursor = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id));
    uint16_t *const stop = cursor + kElemsToFill;
    while (cursor != stop) {
        *cursor++ = upper_half;
    }

    cb_push_back(cb_id, 1);
}

/**
 * Writes two 16-bit mask tiles into `cb_mask_h_w` and pushes them.
 *
 * Each tile is written as four faces of 16x16 elements, every face occupying
 * 256 consecutive uint16_t slots (element offset = face * 256 + h * 16 + w).
 * Faces 0/1 cover rows 0..15 and faces 2/3 cover rows 16..31; faces 0/2 cover
 * columns 0..15 and faces 1/3 cover columns 16..31.
 *
 *  - Tile 0 (mask_h): element is 1.0 when its row index is < mask_h, else 0.0.
 *  - Tile 1 (mask_w): element is 1.0 when its column index is < mask_w, else 0.0.
 *
 * The values stored are the upper 16 bits of the float32 encodings of 1.0f
 * and 0.0f.
 *
 * @param cb_mask_h_w       Circular buffer receiving the two tiles.
 * @param mask_h            Number of valid rows; values above 32 are treated as 32.
 * @param mask_w            Number of valid columns; values above 32 are treated as 32.
 * @param single_tile_size  Byte offset from the first tile to the second tile.
 */
void generate_mask_h_w(uint32_t cb_mask_h_w, uint32_t mask_h, uint32_t mask_w, uint32_t single_tile_size = 2048) {
    constexpr uint32_t FACE_DIM = 16;
    constexpr uint32_t FACE_ELEMS = FACE_DIM * FACE_DIM;
    constexpr uint32_t NUM_FACES = 4;

    union {
        float f;
        uint32_t u;
    } one;
    one.f = 1.0f;
    union {
        float f;
        uint32_t u;
    } zero;
    zero.f = 0.0f;

    const auto u16_one = uint16_t(one.u >> 16);
    const auto u16_zero = uint16_t(zero.u >> 16);

    // Split each extent into the part that falls in the first 16 indices and
    // the part that falls in the next 16. Both parts are clamped to the face
    // size so an out-of-range mask can never index past the reserved tiles.
    const uint32_t rows_top = (mask_h < FACE_DIM) ? mask_h : FACE_DIM;
    const uint32_t rows_rest = mask_h - rows_top;
    const uint32_t rows_bottom = (rows_rest < FACE_DIM) ? rows_rest : FACE_DIM;

    const uint32_t cols_left = (mask_w < FACE_DIM) ? mask_w : FACE_DIM;
    const uint32_t cols_rest = mask_w - cols_left;
    const uint32_t cols_right = (cols_rest < FACE_DIM) ? cols_rest : FACE_DIM;

    cb_reserve_back(cb_mask_h_w, 2);

    // First tile holds mask_h, second tile (single_tile_size bytes later) holds mask_w.
    const auto write_addr = get_write_ptr(cb_mask_h_w);
    auto mask_h_ptr = reinterpret_cast<uint16_t *>(write_addr);
    auto mask_w_ptr = reinterpret_cast<uint16_t *>(write_addr + single_tile_size);

    for (uint32_t face = 0; face < NUM_FACES; face++) {
        // Faces 0 and 1 are the upper half of the tile, faces 2 and 3 the lower half.
        const uint32_t valid_rows = (face < 2) ? rows_top : rows_bottom;
        // Faces 0 and 2 are the left half of the tile, faces 1 and 3 the right half.
        const uint32_t valid_cols = (face % 2 == 0) ? cols_left : cols_right;
        const uint32_t face_base = face * FACE_ELEMS;

        for (uint32_t h = 0; h < FACE_DIM; h++) {
            const uint16_t row_value = (h < valid_rows) ? u16_one : u16_zero;
            const uint32_t row_base = face_base + h * FACE_DIM;
            for (uint32_t w = 0; w < FACE_DIM; w++) {
                mask_h_ptr[row_base + w] = row_value;
                mask_w_ptr[row_base + w] = (w < valid_cols) ? u16_one : u16_zero;
            }
        }
    }

    cb_push_back(cb_mask_h_w, 2);
}

/**
 * Generates the height/width mask tiles only when at least one of the
 * original extents is not a multiple of the 32x32 tile size.
 *
 * For an extent that is already tile-aligned the corresponding mask count is
 * the full tile dimension (32); otherwise it is the remainder modulo 32.
 * When both extents are aligned nothing is written to the circular buffer.
 *
 * @param cb_mask_h_w  Circular buffer that receives the two mask tiles.
 * @param origin_h     Original (unpadded) height.
 * @param origin_w     Original (unpadded) width.
 */
void generate_mask_h_w_if_needed(uint32_t cb_mask_h_w, uint32_t origin_h, uint32_t origin_w) {
    constexpr uint32_t TILE_H = 32;
    constexpr uint32_t TILE_W = 32;

    const uint32_t h_remainder = origin_h % TILE_H;
    const uint32_t w_remainder = origin_w % TILE_W;

    // Both extents are tile-aligned: no mask is required.
    if (h_remainder == 0 && w_remainder == 0) {
        return;
    }

    const uint32_t mask_h = (h_remainder == 0) ? TILE_H : h_remainder;
    const uint32_t mask_w = (w_remainder == 0) ? TILE_W : w_remainder;

    // The tile size of the CB gives the byte offset between the two mask tiles.
    generate_mask_h_w(cb_mask_h_w, mask_h, mask_w, get_tile_size(cb_mask_h_w));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
*
* SPDX-License-Identifier: Apache-2.0
*/

#include <cstdint>

#include "compute_kernel_api.h"
#include "compute_kernel_api/bcast.h"
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/eltwise_unary/exp.h"
#include "compute_kernel_api/eltwise_unary/recip.h"
#include "compute_kernel_api/mask.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/tile_move_copy.h"

// Shorthand: acquire the destination registers in half mode.
ALWI void ACQ() {
    acquire_dst(tt::DstMode::Half);
}
// Shorthand: release the destination registers acquired in half mode.
ALWI void REL() {
    release_dst(tt::DstMode::Half);
}

namespace ckernel {

/**
 * Computes a power of one tile from `cb_x` in four stages and pushes the
 * result to `cb_correct_xpow`.
 *
 * Stages (each bracketed by ACQ()/REL()):
 *   1. cb_xpow     <- power_tile(x, p), followed by a reciprocal when
 *                     `p_is_negative` is set.
 *   2. cb_logx     <- log(x); the tile is popped from `cb_x` here.
 *   3. cb_exp_lxmd <- exp(log(x) * decimal), i.e. x^decimal.
 *   4. cb_correct_xpow <- cb_xpow * cb_exp_lxmd.
 *
 * One tile is consumed from `cb_x`; `cb_xpow`, `cb_logx` and `cb_exp_lxmd`
 * are used as intermediates (each is pushed and popped once).
 * `cb_decimal` is read at tile index 0 but is neither waited on nor popped
 * by this function.
 * NOTE(review): presumably the caller has already made the cb_decimal tile
 * available and it holds the fractional part of the exponent -- confirm.
 *
 * @param cb_x             Input tile x.
 * @param cb_xpow          Intermediate CB for the power_tile result.
 * @param cb_logx          Intermediate CB for log(x).
 * @param cb_decimal       CB whose tile 0 is multiplied with log(x).
 * @param cb_exp_lxmd      Intermediate CB for exp(log(x) * decimal).
 * @param cb_correct_xpow  Output CB receiving the final product.
 * @param p                Exponent passed to power_tile.
 * @param p_is_negative    When true, the power_tile result is inverted with recip_tile.
 */
ALWI void power_tile_to_cb(
std::uint8_t cb_x,
std::uint8_t cb_xpow,
std::uint8_t cb_logx,
std::uint8_t cb_decimal,
std::uint8_t cb_exp_lxmd,
std::uint8_t cb_correct_xpow,
uint32_t p,
bool p_is_negative) {
constexpr uint32_t onetile = 1;
constexpr uint32_t dst0 = 0;

// x^p
ACQ();
cb_wait_front(cb_x, onetile);
cb_reserve_back(cb_xpow, onetile);

copy_tile_init();
copy_tile(cb_x, 0, dst0);

power_tile_init();
power_tile(dst0, p);

// For a negative exponent, invert the result: 1 / x^p.
if (p_is_negative) {
recip_tile_init();
recip_tile(dst0);
}

pack_tile(dst0, cb_xpow);

cb_push_back(cb_xpow, onetile);
REL();
// We don't pop cb_x here: the same input tile is read again for log(x) below.

// log(x)
ACQ();
cb_reserve_back(cb_logx, onetile);

copy_tile_init();
copy_tile(cb_x, 0, dst0);

log_tile_init();
log_tile(dst0);

pack_tile(dst0, cb_logx);

// cb_x has now been read for the last time, so release it.
cb_pop_front(cb_x, onetile);
cb_push_back(cb_logx, onetile);
REL();

// exp(log(x) * decimal) == x^decimal
ACQ();
cb_wait_front(cb_logx, onetile);
cb_reserve_back(cb_exp_lxmd, onetile);

mul_tiles_init();
mul_tiles(cb_logx, cb_decimal, 0, 0, dst0);

exp_tile_init();
exp_tile(dst0);

pack_tile(dst0, cb_exp_lxmd);

cb_pop_front(cb_logx, onetile);
cb_push_back(cb_exp_lxmd, onetile);
REL();

// cb_xpow * exp(log(x) * decimal) == cb_xpow * x^decimal.
// With cb_xpow holding x^p (or x^-p when p_is_negative) this yields
// x raised to the sum of the two exponents, not (x + decimal)^p.
ACQ();
cb_wait_front(cb_xpow, onetile);
cb_wait_front(cb_exp_lxmd, onetile);
cb_reserve_back(cb_correct_xpow, onetile);

mul_tiles_init();
mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0);

pack_tile(dst0, cb_correct_xpow);

cb_pop_front(cb_xpow, onetile);
cb_pop_front(cb_exp_lxmd, onetile);
cb_push_back(cb_correct_xpow, onetile);
REL();
}

} // namespace ckernel
Loading

0 comments on commit 7b9aae9

Please sign in to comment.