Skip to content

Commit

Permalink
#14316: Refactoring moreh_helper function (#14317)
Browse files Browse the repository at this point in the history
* #14316: refactoring moreh_helper
  • Loading branch information
hschoi4448 authored Nov 1, 2024
1 parent 3a24131 commit 5de2817
Show file tree
Hide file tree
Showing 81 changed files with 478 additions and 501 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ void MorehClipGradNormStep1::validate(
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors) const {
for (const auto &input : input_tensors) {
check_tensor(input, "moreh_clip_grad_norm_step1", "input");
ttnn::operations::check_tensor(input, "moreh_clip_grad_norm_step1", "input");
}

const auto &tmp_pow_sum = optional_input_tensors.at(0).value();
check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step1", "tmp_pow_sum");
ttnn::operations::check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step1", "tmp_pow_sum");
};

std::vector<ttnn::SimpleShape> MorehClipGradNormStep1::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
Expand Down Expand Up @@ -99,10 +99,10 @@ void moreh_clip_grad_norm_step1(const std::vector<Tensor> &inputs, float norm_ty

void MorehClipGradNormStep2::validate(const std::vector<Tensor> &input_tensors) const {
const auto &tmp_pow_sum = input_tensors.at(0);
check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step2", "tmp_pow_sum");
ttnn::operations::check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step2", "tmp_pow_sum");

const auto &total_norm = input_tensors.at(1);
check_tensor(total_norm, "moreh_clip_grad_norm_step2", "total_norm");
ttnn::operations::check_tensor(total_norm, "moreh_clip_grad_norm_step2", "total_norm");
}

std::vector<ttnn::SimpleShape> MorehClipGradNormStep2::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
Expand Down Expand Up @@ -139,11 +139,11 @@ void MorehClipGradNormStep3::validate(
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors) const {
for (const auto &input : input_tensors) {
check_tensor(input, "moreh_clip_grad_norm_step3", "input");
ttnn::operations::check_tensor(input, "moreh_clip_grad_norm_step3", "input");
}

const auto &clip_coef_clamped = optional_input_tensors.at(0).value();
check_tensor(clip_coef_clamped, "moreh_clip_grad_norm_step3", "clip_coef_clamped");
ttnn::operations::check_tensor(clip_coef_clamped, "moreh_clip_grad_norm_step3", "clip_coef_clamped");
}

std::vector<ttnn::SimpleShape> MorehClipGradNormStep3::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(

const auto cb_data_format = tt_metal::datatype_to_dataformat_converter(tmp_pow_sum.get_dtype());

CreateCircularBuffer(
ttnn::operations::CreateCircularBuffer(
program,
core_group_1,
cb_data_format,
Expand Down Expand Up @@ -112,8 +112,8 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/"
"writer_moreh_clip_grad_norm_step1.cpp";

const auto reader_kernels_id = CreateReadKernel(program, reader_kernel_file, core_group_1);
const auto writer_kernels_id = CreateWriteKernel(program, writer_kernel_file, core_group_1);
const auto reader_kernels_id = ttnn::operations::CreateReadKernel(program, reader_kernel_file, core_group_1);
const auto writer_kernels_id = ttnn::operations::CreateWriteKernel(program, writer_kernel_file, core_group_1);

////////////////////////////////////////////////////////////////////////////
// ComputeKernel SetUp
Expand All @@ -127,7 +127,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
"moreh_clip_grad_norm_step1_kernel.cpp";

const auto compute_kernels_id =
CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1}, compute_defines);
ttnn::operations::CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1}, compute_defines);

////////////////////////////////////////////////////////////////////////////
// RuntimeArgs SetUp
Expand All @@ -146,7 +146,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
// reader
const std::array reader_runtime_args{
input_addr,
static_cast<uint32_t>(is_dram(input)),
static_cast<uint32_t>(ttnn::operations::is_dram(input)),
num_tiles,
*reinterpret_cast<uint32_t*>(&decimal),
origin_h,
Expand All @@ -155,7 +155,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(

// writer
const std::array writer_runtime_args{
output_addr, static_cast<uint32_t>(is_dram(tmp_pow_sum)), tile_offset};
output_addr, static_cast<uint32_t>(ttnn::operations::is_dram(tmp_pow_sum)), tile_offset};
SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args);

// compute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(

const auto cb_data_format = tt_metal::datatype_to_dataformat_converter(total_norm.get_dtype());

CreateCircularBuffer(
ttnn::operations::CreateCircularBuffer(
program,
single_core,
cb_data_format,
Expand All @@ -82,8 +82,8 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/"
"writer_moreh_clip_grad_norm_step2.cpp";

const auto reader_kernels_id = CreateReadKernel(program, reader_kernel_file, single_core);
const auto writer_kernels_id = CreateWriteKernel(program, writer_kernel_file, single_core);
const auto reader_kernels_id = ttnn::operations::CreateReadKernel(program, reader_kernel_file, single_core);
const auto writer_kernels_id = ttnn::operations::CreateWriteKernel(program, writer_kernel_file, single_core);

////////////////////////////////////////////////////////////////////////////
// ComputeKernel SetUp
Expand All @@ -92,7 +92,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/"
"moreh_clip_grad_norm_step2_kernel.cpp";

const auto compute_kernels_id = CreateComputeKernel(program, compute_kernel_file, {single_core, num_tiles});
const auto compute_kernels_id = ttnn::operations::CreateComputeKernel(program, compute_kernel_file, {single_core, num_tiles});

////////////////////////////////////////////////////////////////////////////
// RuntimeArgs SetUp
Expand All @@ -102,11 +102,11 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(

// reader
const std::array reader_runtime_args{
input_addr, static_cast<uint32_t>(is_dram(tmp_pow_sum)), num_tiles, *reinterpret_cast<uint32_t*>(&decimal)};
input_addr, static_cast<uint32_t>(ttnn::operations::is_dram(tmp_pow_sum)), num_tiles, *reinterpret_cast<uint32_t*>(&decimal)};
SetRuntimeArgs(program, reader_kernels_id, single_core, reader_runtime_args);

// writer
const std::array writer_runtime_args{output_addr, static_cast<uint32_t>(is_dram(total_norm))};
const std::array writer_runtime_args{output_addr, static_cast<uint32_t>(ttnn::operations::is_dram(total_norm))};
SetRuntimeArgs(program, writer_kernels_id, single_core, writer_runtime_args);

// compute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(

const auto cb_data_format = tt_metal::datatype_to_dataformat_converter(inputs.at(0).get_dtype());

CreateCircularBuffer(
ttnn::operations::CreateCircularBuffer(
program,
core_group_1,
cb_data_format,
Expand All @@ -82,8 +82,8 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/"
"writer_moreh_clip_grad_norm_step3.cpp";

const auto reader_kernels_id = CreateReadKernel(program, reader_kernel_file, core_group_1);
const auto writer_kernels_id = CreateWriteKernel(program, writer_kernel_file, core_group_1);
const auto reader_kernels_id = ttnn::operations::CreateReadKernel(program, reader_kernel_file, core_group_1);
const auto writer_kernels_id = ttnn::operations::CreateWriteKernel(program, writer_kernel_file, core_group_1);

////////////////////////////////////////////////////////////////////////////
// ComputeKernel SetUp
Expand All @@ -93,7 +93,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
"moreh_clip_grad_norm_step3_kernel.cpp";

const auto compute_kernels_id =
CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1});
ttnn::operations::CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1});

////////////////////////////////////////////////////////////////////////////
// RuntimeArgs SetUp
Expand All @@ -109,14 +109,14 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
// reader
const std::array reader_runtime_args{
input_addr,
static_cast<uint32_t>(is_dram(input)),
static_cast<uint32_t>(ttnn::operations::is_dram(input)),
clip_coef_clamped_addr,
static_cast<uint32_t>(is_dram(clip_coef_clamped)),
static_cast<uint32_t>(ttnn::operations::is_dram(clip_coef_clamped)),
num_tiles};
SetRuntimeArgs(program, reader_kernels_id, core, reader_runtime_args);

// writer
const std::array writer_runtime_args{input_addr, static_cast<uint32_t>(is_dram(input)), num_tiles};
const std::array writer_runtime_args{input_addr, static_cast<uint32_t>(ttnn::operations::is_dram(input)), num_tiles};
SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args);

// compute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ void FastReduceNCDeviceOperation::validate_with_output_tensors(
auto& output = output_tensors.at(0);

// validate tensor
tt::operations::primary::check_tensor(input, "FastReduceNC", "input", {DataType::BFLOAT16, DataType::BFLOAT8_B});
tt::operations::primary::check_tensor(output, "FastReduceNC", "output", {DataType::BFLOAT16, DataType::BFLOAT8_B});
check_tensor(input, "FastReduceNC", "input", {DataType::BFLOAT16, DataType::BFLOAT8_B});
check_tensor(output, "FastReduceNC", "output", {DataType::BFLOAT16, DataType::BFLOAT8_B});

// validate input dim
const auto input_rank = input.get_logical_shape().rank();
Expand Down
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/operations/full/device/full_program_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ FullOperation::ProgramFactory::cached_program_t FullOperation::ProgramFactory::c
auto grid = tensor_args.any.device()->compute_with_storage_grid_size();
auto num_tiles = output.volume() / TILE_HW;
auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
split_work_to_cores(grid, num_tiles);
tt::tt_metal::split_work_to_cores(grid, num_tiles);

tt::DataFormat data_format = tt::tt_metal::datatype_to_dataformat_converter(dtype);
uint32_t single_tile_size = tt::tt_metal::detail::TileSize(data_format);
Expand All @@ -40,7 +40,7 @@ FullOperation::ProgramFactory::cached_program_t FullOperation::ProgramFactory::c

// Create circular buffer
auto cb_index = tt::CB::c_intermed0;
tt::operations::primary::CreateCircularBuffer(
CreateCircularBuffer(
program,
all_cores,
data_format,
Expand All @@ -57,7 +57,7 @@ FullOperation::ProgramFactory::cached_program_t FullOperation::ProgramFactory::c
default: break;
}

auto writer_id = tt::operations::primary::CreateWriteKernel(
auto writer_id = CreateWriteKernel(
program,
"ttnn/cpp/ttnn/operations/full/device/kernels/writer_full.cpp",
all_cores,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,31 @@ void MorehAdamOperation::validate_inputs(
auto& exp_avg_in = tensor_args.exp_avg_in;
auto& exp_avg_sq_in = tensor_args.exp_avg_sq_in;

tt::operations::primary::check_tensor(params_in, "moreh_adam", "params_in");
tt::operations::primary::check_tensor(grad, "moreh_adam", "grad");
tt::operations::primary::check_tensor(exp_avg_in, "moreh_adam", "exp_avg_in");
tt::operations::primary::check_tensor(exp_avg_sq_in, "moreh_adam", "exp_avg_sq_in");
check_tensor(params_in, "moreh_adam", "params_in");
check_tensor(grad, "moreh_adam", "grad");
check_tensor(exp_avg_in, "moreh_adam", "exp_avg_in");
check_tensor(exp_avg_sq_in, "moreh_adam", "exp_avg_sq_in");

if (tensor_args.max_exp_avg_sq_in) {
tt::operations::primary::check_tensor(*tensor_args.max_exp_avg_sq_in, "moreh_adam", "max_exp_avg_sq_in");
check_tensor(*tensor_args.max_exp_avg_sq_in, "moreh_adam", "max_exp_avg_sq_in");
}

const auto& params_out = tensor_args.output_tensors.at(0);

if (params_out.has_value()) {
tt::operations::primary::check_tensor(params_out.value(), "moreh_adam", "params_out");
check_tensor(params_out.value(), "moreh_adam", "params_out");
}

if (tensor_args.output_tensors.at(1).has_value()) {
tt::operations::primary::check_tensor(tensor_args.output_tensors.at(1).value(), "moreh_adam", "exp_avg_out");
check_tensor(tensor_args.output_tensors.at(1).value(), "moreh_adam", "exp_avg_out");
}

if (tensor_args.output_tensors.at(2).has_value()) {
tt::operations::primary::check_tensor(tensor_args.output_tensors.at(2).value(), "moreh_adam", "exp_avg_sq_out");
check_tensor(tensor_args.output_tensors.at(2).value(), "moreh_adam", "exp_avg_sq_out");
}

if (tensor_args.output_tensors.at(3).has_value()) {
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.output_tensors.at(3).value(), "moreh_adam", "max_exp_avg_sq_out");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
////////////////////////////////////////////////////////////////////////////
auto data_format = tt::tt_metal::datatype_to_dataformat_converter(param_in.get_dtype());
auto intermed_cb_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format;
tt::operations::primary::CreateCircularBuffer(
CreateCircularBuffer(
program,
all_cores,
data_format,
Expand Down Expand Up @@ -94,17 +94,17 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
////////////////////////////////////////////////////////////////////////////

const std::vector<uint32_t> reader_compile_time_args{
static_cast<uint32_t>(tt::operations::primary::is_dram(param_in)),
static_cast<uint32_t>(tt::operations::primary::is_dram(grad)),
static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_in)),
static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_sq_in)),
static_cast<uint32_t>(tt::operations::primary::is_dram(max_exp_avg_sq_in))};
static_cast<uint32_t>(is_dram(param_in)),
static_cast<uint32_t>(is_dram(grad)),
static_cast<uint32_t>(is_dram(exp_avg_in)),
static_cast<uint32_t>(is_dram(exp_avg_sq_in)),
static_cast<uint32_t>(is_dram(max_exp_avg_sq_in))};

const std::vector<uint32_t> writer_compile_time_args{
static_cast<uint32_t>(tt::operations::primary::is_dram(param_out)),
static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_out)),
static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_sq_out)),
static_cast<uint32_t>(tt::operations::primary::is_dram(max_exp_avg_sq_out.value()))};
static_cast<uint32_t>(is_dram(param_out)),
static_cast<uint32_t>(is_dram(exp_avg_out)),
static_cast<uint32_t>(is_dram(exp_avg_sq_out)),
static_cast<uint32_t>(is_dram(max_exp_avg_sq_out.value()))};

const auto reader_kernel_file =
"ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/"
Expand All @@ -120,9 +120,9 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
if (fp32_dest_acc_en) {
data_movement_defines["FP32_DEST_ACC_EN"] = "1";
}
const auto reader_kernel_id = tt::operations::primary::CreateReadKernel(
const auto reader_kernel_id = CreateReadKernel(
program, reader_kernel_file, all_cores, reader_compile_time_args, data_movement_defines);
const auto writer_kernel_id = tt::operations::primary::CreateWriteKernel(
const auto writer_kernel_id = CreateWriteKernel(
program, writer_kernel_file, all_cores, writer_compile_time_args, data_movement_defines);

////////////////////////////////////////////////////////////////////////////
Expand All @@ -143,7 +143,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
"ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/"
"moreh_adam.cpp";

auto compute_kernel_1_id = tt ::operations::primary::CreateComputeKernel(
auto compute_kernel_1_id = CreateComputeKernel(
program,
compute_kernel_file,
{core_group_1, num_tiles_per_core_group_1, compute_args_group_1},
Expand All @@ -155,7 +155,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
if (!core_group_2.ranges().empty()) {
const std::vector<uint32_t> compute_args_group_2{num_tiles_per_core_group_2};

compute_kernel_2_id = tt::operations::primary::CreateComputeKernel(
compute_kernel_2_id = CreateComputeKernel(
program,
compute_kernel_file,
{core_group_2, num_tiles_per_core_group_2, compute_args_group_2},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,40 @@ MorehAdamWDeviceOperation::program_factory_t MorehAdamWDeviceOperation::select_p

void MorehAdamWDeviceOperation::validate_inputs(
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.param_in, "moreh_adamw", "param_in", {DataType::BFLOAT16, DataType::BFLOAT8_B});
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.grad, "moreh_adamw", "grad", {DataType::BFLOAT16, DataType::BFLOAT8_B});
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.exp_avg_in, "moreh_adamw", "exp_avg_in", {DataType::BFLOAT16, DataType::BFLOAT8_B});
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.exp_avg_sq_in, "moreh_adamw", "exp_avg_sq_in", {DataType::BFLOAT16, DataType::BFLOAT8_B});

if (tensor_args.max_exp_avg_sq_in.has_value()) {
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.max_exp_avg_sq_in.value(),
"moreh_adamw",
"max_exp_avg_sq_in",
{DataType::BFLOAT16, DataType::BFLOAT8_B});
}

if (tensor_args.param_out.has_value()) {
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.param_out.value(), "moreh_adamw", "param_out", {DataType::BFLOAT16, DataType::BFLOAT8_B});
}
if (tensor_args.exp_avg_out.has_value()) {
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.exp_avg_out.value(), "moreh_adamw", "exp_avg_out", {DataType::BFLOAT16, DataType::BFLOAT8_B});
}
if (tensor_args.exp_avg_sq_out.has_value()) {
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.exp_avg_sq_out.value(),
"moreh_adamw",
"exp_avg_sq_out",
{DataType::BFLOAT16, DataType::BFLOAT8_B});
}
if (tensor_args.max_exp_avg_sq_out.has_value()) {
tt::operations::primary::check_tensor(
check_tensor(
tensor_args.max_exp_avg_sq_out.value(),
"moreh_adamw",
"max_exp_avg_sq_out",
Expand Down
Loading

0 comments on commit 5de2817

Please sign in to comment.