From 80fc3f83e7436d36ba5b3bb41824279709dd28a3 Mon Sep 17 00:00:00 2001
From: KalaivaniMCW
Date: Mon, 4 Nov 2024 18:01:12 +0000
Subject: [PATCH] #14466: cleanup unary composite

---
 .../operations/eltwise/test_ternary.py       | 14 ++++++--------
 .../eltwise/ternary/ternary_composite_op.cpp |  8 ++------
 .../unary/device/unary_composite_op.cpp      | 21 +++++++--------------
 3 files changed, 15 insertions(+), 28 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_ternary.py b/tests/ttnn/unit_tests/operations/eltwise/test_ternary.py
index b81eed40ded..10397e95c2f 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_ternary.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_ternary.py
@@ -11,10 +11,6 @@
 from tests.ttnn.utils_for_testing import assert_with_pcc
 
 
-def torch_mac(input, tensor1, tensor2):
-    return torch.add(torch.mul(input, tensor1), tensor2)
-
-
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 def test_mac_all_tensors(device, h, w):
@@ -23,7 +19,9 @@ def test_mac_all_tensors(device, h, w):
     torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_input_tensor1 = torch.rand((h, w), dtype=torch.bfloat16)
     torch_input_tensor2 = torch.rand((h, w), dtype=torch.bfloat16)
-    torch_output_tensor = torch_mac(torch_input_tensor, torch_input_tensor1, torch_input_tensor2)
+
+    golden_fn = ttnn.get_golden_function(ttnn.mac)
+    torch_output_tensor = golden_fn(torch_input_tensor, torch_input_tensor1, torch_input_tensor2)
 
     input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
     input_tensor = ttnn.to_device(input_tensor, device)
@@ -49,9 +47,9 @@ def test_mac_tensor_with_2_scalaras(device, h, w, scalar1, scalar2):
     torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_input_tensor1 = scalar1
     torch_input_tensor2 = scalar2
-    torch_output_tensor = torch.unsqueeze(
-        torch.unsqueeze(torch_mac(torch_input_tensor, torch_input_tensor1, torch_input_tensor2), 0), 0
-    )
+
+    golden_fn = ttnn.get_golden_function(ttnn.mac)
+    torch_output_tensor = golden_fn(torch_input_tensor, torch_input_tensor1, torch_input_tensor2)
 
     input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
     input_tensor = ttnn.to_device(input_tensor, device)
diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp
index cf41c7f0e80..65887f3cd80 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp
@@ -106,13 +106,9 @@ Tensor _mac(const Tensor& a, const Tensor& b, const Tensor& c, const std::option
     return ttnn::add(ttnn::multiply(a, b), c);
 }
 
+// y = a * b + c
 Tensor _mac_overload(const Tensor& a, float b, float c, const std::optional<MemoryConfig>& output_mem_config) {
-    Tensor t_b = ttnn::operations::creation::create_scalar(b, a.get_dtype(), Layout::TILE, a.device());
-    Tensor t_c = ttnn::operations::creation::create_scalar(c, a.get_dtype(), Layout::TILE, a.device());
-    Tensor return_tensor = _mac(a, t_b, t_c, output_mem_config);
-    t_b.deallocate();
-    t_c.deallocate();
-    return return_tensor;
+    return ttnn::add(ttnn::multiply(a, b, std::nullopt, output_mem_config), c, std::nullopt, output_mem_config);
 }
 
 }  // namespace ttnn::operations::ternary
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
index 03b6435ee12..bf0c1b2ece8 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
@@ -173,8 +173,7 @@ Tensor _cosh(const Tensor& input_a, const std::optional<MemoryConfig>& output_me
     Tensor nr_term = ttnn::add(e_pos_x, e_neg_x, std::nullopt, output_mem_config);
     e_pos_x.deallocate();
     e_neg_x.deallocate();
-    Tensor scalar = ttnn::full_like(input_a, 0.5f);
-    return ttnn::multiply(nr_term, scalar, std::nullopt, output_mem_config);
+    return ttnn::multiply(nr_term, 0.5f, std::nullopt, output_mem_config);
 }
 
 // TODO: In future will uplift the op once the floor and tan has supported.
@@ -294,12 +293,10 @@ Tensor _lgamma(const Tensor& x, const std::optional<MemoryConfig>& output_mem_c
         result = ttnn::subtract(result, t, std::nullopt, output_mem_config);
         {
             {
-                Tensor t_one = ttnn::full_like(x, 1.0f);
-                result = ttnn::where(ttnn::eq(x, t_one, std::nullopt, output_mem_config), 0.0f, result);
+                result = ttnn::where(ttnn::eq(x, 1.0f, std::nullopt, output_mem_config), 0.0f, result);
             }
             {
-                Tensor t_two = ttnn::full_like(x, 2.0f);
-                result = ttnn::where(ttnn::eq(x, t_two, std::nullopt, output_mem_config), 0.0f, result);
+                result = ttnn::where(ttnn::eq(x, 2.0f, std::nullopt, output_mem_config), 0.0f, result);
             }
         }
     }
@@ -309,8 +306,7 @@ Tensor _lgamma(const Tensor& x, const std::optional<MemoryConfig>& output_mem_c
 // log1p 1
 // use transformation y = log(1.0 + x) by broadcast
 Tensor _log1p(const Tensor& x, const std::optional<MemoryConfig>& output_mem_config) {
-    Tensor t_one = ttnn::full_like(x, 1.0f);
-    Tensor x_1 = ttnn::add(t_one, x, std::nullopt, output_mem_config);
+    Tensor x_1 = ttnn::add(x, 1.0f, std::nullopt, output_mem_config);
     Tensor result_log1p = ttnn::log(x_1, output_mem_config);
     return result_log1p;
 }
@@ -350,8 +346,7 @@ Tensor _sinh(const Tensor& input_a, const std::optional<MemoryConfig>& output_me
     Tensor nr_term = ttnn::subtract(e_pos_x, e_neg_x, std::nullopt, output_mem_config);
     e_pos_x.deallocate();
     e_neg_x.deallocate();
-    Tensor scalar = ttnn::full_like(input_a, 0.5f);
-    return ttnn::multiply(nr_term, scalar, std::nullopt, output_mem_config);
+    return ttnn::multiply(nr_term, 0.5f, std::nullopt, output_mem_config);
 }
 
 // Function: softsign
@@ -372,8 +367,8 @@ Tensor _swish(const Tensor& a, const std::optional<MemoryConfig>& output_mem_con
 
 Tensor ExecuteTrunc::invoke(uint8_t queue_id, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> output_tensor) {
     auto arch = input.device()->arch();
-    output_tensor = output_tensor.value_or(ttnn::empty_like(input));
     TT_FATAL(arch != tt::ARCH::GRAYSKULL, "Op is not supported on Grayskull");
+    output_tensor = output_tensor.value_or(ttnn::empty_like(input));
     Tensor floor_res = ttnn::floor(queue_id, input, output_mem_config);
     ttnn::where(queue_id, ttnn::ne(queue_id, input, floor_res), ttnn::add(queue_id, floor_res, 1.0f, std::nullopt, output_mem_config), floor_res, output_mem_config, output_tensor);
     ttnn::where(queue_id, ttnn::gtz(queue_id, input, output_mem_config), floor_res, output_tensor.value(), output_mem_config, output_tensor);
@@ -449,9 +444,7 @@ Tensor _normalize(const Tensor& y, const std::optional<MemoryConfig>& output_mem
 // PyTorch version:
 // hard sigmoid(x) = { x <= -3: 0, x >= +3: +3, x/6 + 0.5 otherwise}
 Tensor _hardsigmoid(const Tensor& a, float value_1, float value_2, const std::optional<MemoryConfig>& output_mem_config) {
-    Tensor a_t = ttnn::full_like(a,value_1);
-    Tensor b_t = ttnn::full_like(a,value_2);
-    Tensor a_mac = ttnn::mac(a, a_t, b_t);  // multiply and add.
+    Tensor a_mac = ttnn::mac(a, value_1, value_2);  // multiply and add.
     Tensor a_clip = relu_max(a_mac, 1.0f);
     return a_clip;
 }
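
Note: the updated tests take their reference output from the golden function
registered for ttnn.mac instead of a local torch_mac helper, so the reference
stays in sync with the op definition (y = a * b + c). A minimal sketch of that
pattern, assuming a ttnn device handle named `device` is available; the scalar
arguments rely on the scalar overload of ttnn.mac that the C++ _mac_overload
above implements:

    import torch
    import ttnn

    torch_input = torch.rand((64, 128), dtype=torch.bfloat16)

    # Reference output from the golden function registered for ttnn.mac.
    golden_fn = ttnn.get_golden_function(ttnn.mac)
    torch_expected = golden_fn(torch_input, 2.0, 3.0)  # scalar b and c

    # Device output using the same scalar arguments (`device` is assumed).
    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.mac(input_tensor, 2.0, 3.0)
    actual = ttnn.to_torch(output_tensor)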