Skip to content

Commit

Permalink
#0: Handle 0 volume case in layout change code (#14578)
Browse files Browse the repository at this point in the history
#0: handle 0 volume case in layout change code
  • Loading branch information
ayerofieiev-tt authored Nov 1, 2024
1 parent 5c31e3e commit c42a534
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions tt_metal/common/test_tiles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ std::vector<T> convert_to_tile_layout(
std::optional<tt::stl::Span<const uint32_t>> face_shape = std::nullopt) {
ZoneScoped;
std::vector<T> result;
if(data.size() == 0) {
return result;
}

result.reserve(data.size());
auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT;
auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH;
Expand Down Expand Up @@ -86,6 +90,9 @@ std::vector<T> convert_to_flat_layout(
std::optional<tt::stl::Span<const uint32_t>> face_shape = std::nullopt) {
ZoneScoped;
std::vector<T> result;
if(data.size() == 0) {
return result;
}
result.reserve(data.size());
auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT;
auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH;
Expand Down Expand Up @@ -123,9 +130,13 @@ inline std::vector<T> untilize_nchw(const BufferType<T>& in, tt::stl::Span<const
auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT;
auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH;

std::vector<T> result;
if(in.size() == 0) {
return result;
}

TT_ASSERT(shape[shape.size() - 2] % tile_H == 0 && shape[shape.size() - 1] % tile_W == 0);

std::vector<T> result;
// Untilize into row major
int H = shape[shape.size() - 2], W = shape[shape.size() - 1];
auto batch_size = 1;
Expand Down Expand Up @@ -164,6 +175,11 @@ inline std::uint32_t round_up_to_tile(int val, int tile_val) { return (val + til
template <typename T, template <typename...> typename BufferType>
inline std::vector<T> tilize_nchw(const BufferType<T>& in_rowmajor, tt::stl::Span<const uint32_t> shape, std::optional<tt::stl::Span<const uint32_t>> tile_shape = std::nullopt) {
ZoneScoped;
std::vector<T> tilized_result;
if(in_rowmajor.size() == 0) {
return tilized_result;
}

int H = shape[shape.size() - 2], W = shape[shape.size() - 1];
auto batch_size = 1;
for (int i = 0; i < shape.size() - 2; i++) {
Expand All @@ -174,7 +190,6 @@ inline std::vector<T> tilize_nchw(const BufferType<T>& in_rowmajor, tt::stl::Spa
auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH;
int OH = round_up_to_tile(H, tile_H);
int OW = round_up_to_tile(W, tile_W);
std::vector<T> tilized_result;
tilized_result.resize(batch_size * OH * OW);
std::fill(tilized_result.begin(), tilized_result.end(), 0);
int out_index = 0;
Expand Down Expand Up @@ -230,6 +245,10 @@ inline std::vector<T> convert_layout(
std::optional<tt::stl::Span<const uint32_t>> tile_shape = std::nullopt,
std::optional<const tt::stl::Span<const uint32_t>> face_shape = std::nullopt) {
ZoneScoped;
if(inp.size() == 0) {
return std::vector<T>();
}

switch (inL) {
case tests::utils::TensorLayoutType::TILED_SWIZZLED:
if (outL == tests::utils::TensorLayoutType::TILED_NFACES) {
Expand Down

0 comments on commit c42a534

Please sign in to comment.