#13967: Support for clip_bw (#14093)
* #13967: Support for clip_bw

* #13967: Update

* #13967: Update code
VirdhatchaniKN authored Nov 2, 2024
1 parent a2c5123 commit 7321dd7
Showing 6 changed files with 81 additions and 5 deletions.
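
For orientation, here is a minimal usage sketch of the op this commit adds, mirroring the new unit test further down. It assumes a single attached device opened via ttnn.open_device (the tests receive `device` from a pytest fixture instead); the tensor helpers come from the test utilities imported in the diff.

import torch
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc

# Assumed device setup; the unit tests get `device` from a fixture instead.
device = ttnn.open_device(device_id=0)

# Random input and incoming gradient, generated the same way the tests do.
in_data, input_tensor = data_gen_with_range(torch.Size([1, 1, 32, 32]), -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(torch.Size([1, 1, 32, 32]), -10, -1, device, True)

# Backward of clip with bounds [-10, 10]; min/max are accepted positionally per the new binding.
tt_grad = ttnn.clip_bw(grad_tensor, input_tensor, -10.0, 10.0)

# Compare against the attached golden function (torch.clamp backward).
golden = ttnn.get_golden_function(ttnn.clip_bw)(grad_data, in_data, -10.0, 10.0)
assert compare_pcc(tt_grad, golden)

ttnn.close_device(device)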
@@ -36,11 +36,9 @@ def test_unary_bw_clamp(input_shapes, min, max, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)
     if min is None and max is None:
-        with pytest.raises(RuntimeError, match="Only one of 'min' or 'max' can be None. Please provide one value"):
-            ttnn.clamp_bw(grad_tensor, input_tensor, min=min, max=max)
-        assert True
+        pytest.xfail("Only one of 'min' or 'max' can be None. Please provide one value")
     else:
-        tt_output_tensor_on_device = ttnn.clamp_bw(grad_tensor, input_tensor, min=min, max=max)
+        tt_output_tensor_on_device = ttnn.clamp_bw(grad_tensor, input_tensor, min, max)
         golden_function = ttnn.get_golden_function(ttnn.clamp_bw)
         golden_tensor = golden_function(grad_data, in_data, min, max)
         comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
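The hunk above swaps an asserted-exception check for an expected-failure marker. A small standalone pytest sketch (hypothetical test names) of the difference between the two patterns:

import pytest

def test_with_raises():
    # Passes only if the exception is actually raised by the body.
    with pytest.raises(ValueError):
        int("not a number")

def test_with_xfail():
    # Stops executing the test here and records it as xfailed
    # rather than passed, without asserting on a specific exception.
    pytest.xfail("this combination is not supported")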
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize(
"min, max",
[
(-10.0, 10.0),
(10.0, -10.0),
(1, -1),
(0, 0),
(-1.0, None),
(None, 1.0),
(None, None),
(-0.5, None),
(None, -0.5),
(1.0, 0.0),
(0.0, 1.0),
],
)
def test_unary_bw_clip(input_shapes, min, max, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)
if min is None and max is None:
pytest.xfail("Only one of 'min' or 'max' can be None. Please provide one value")
else:
tt_output_tensor_on_device = ttnn.clip_bw(grad_tensor, input_tensor, min, max)
golden_function = ttnn.get_golden_function(ttnn.clip_bw)
golden_tensor = golden_function(grad_data, in_data, min, max)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
@@ -50,6 +50,11 @@ std::vector<Tensor> ExecuteUnaryBackwardClamp::invoke(
return grad_tensor;
}

std::vector<Tensor> ExecuteUnaryBackwardClip::invoke(
const Tensor& grad, const Tensor& input, std::optional<float> min, std::optional<float> max, const std::optional<MemoryConfig>& output_mem_config) {
return ExecuteUnaryBackwardClamp::invoke(grad, input, min, max, output_mem_config);
}

// Hardtanh
// result: torch.where((input <= min) | (input >= max), 0.0, grad)
std::vector<Tensor> ExecuteUnaryBackwardHardtanh::invoke(
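Because ExecuteUnaryBackwardClip::invoke simply forwards to the clamp implementation, clip_bw and clamp_bw are expected to produce identical gradients. A minimal equivalence check in the style of the unit tests — the helper name is hypothetical, and a `device` handle is assumed:

import torch
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc

def check_clip_matches_clamp(device, min_val=-1.0, max_val=1.0):
    # clip_bw delegates to clamp_bw on the C++ side, so both should match the same golden output.
    in_data, input_tensor = data_gen_with_range(torch.Size([1, 1, 32, 32]), -100, -1, device, True)
    grad_data, grad_tensor = data_gen_with_range(torch.Size([1, 1, 32, 32]), -10, -1, device, True)
    clip_grad = ttnn.clip_bw(grad_tensor, input_tensor, min_val, max_val)
    clamp_grad = ttnn.clamp_bw(grad_tensor, input_tensor, min_val, max_val)
    golden = ttnn.get_golden_function(ttnn.clamp_bw)(grad_data, in_data, min_val, max_val)
    return compare_pcc(clip_grad, golden) and compare_pcc(clamp_grad, golden)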
10 changes: 10 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -460,6 +460,15 @@ struct ExecuteUnaryBackwardClamp {
const std::optional<MemoryConfig> &memory_config = std::nullopt);
};

struct ExecuteUnaryBackwardClip {
static std::vector<Tensor> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
std::optional<float> min = std::nullopt,
std::optional<float> max = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt);
};

struct ExecuteUnaryBackwardRdiv {
static std::vector<Tensor> invoke(
const Tensor &grad_tensor_arg,
@@ -669,6 +678,7 @@ constexpr auto hardsigmoid_bw = ttnn::register_operation<"ttnn::hardsigmoid_bw",
constexpr auto cos_bw = ttnn::register_operation<"ttnn::cos_bw", operations::unary_backward::ExecuteUnaryBackwardCos>();
constexpr auto acosh_bw = ttnn::register_operation<"ttnn::acosh_bw", operations::unary_backward::ExecuteUnaryBackwardAcosh>();
constexpr auto clamp_bw = ttnn::register_operation<"ttnn::clamp_bw", operations::unary_backward::ExecuteUnaryBackwardClamp>();
constexpr auto clip_bw = ttnn::register_operation<"ttnn::clip_bw", operations::unary_backward::ExecuteUnaryBackwardClip>();
constexpr auto rdiv_bw = ttnn::register_operation<"ttnn::rdiv_bw", operations::unary_backward::ExecuteUnaryBackwardRdiv>();
constexpr auto gelu_bw = ttnn::register_operation<"ttnn::gelu_bw", operations::unary_backward::ExecuteUnaryBackwardGelu>();
constexpr auto repeat_bw = ttnn::register_operation<"ttnn::repeat_bw", operations::unary_backward::ExecuteUnaryBackwardRepeat>();
@@ -529,9 +529,9 @@ void bind_unary_backward_optional_float_params_with_default(
             },
             py::arg("grad_tensor"),
             py::arg("input_tensor"),
-            py::kw_only(),
             py::arg(parameter_name_a.c_str()) = parameter_a_value,
             py::arg(parameter_name_b.c_str()) = parameter_b_value,
+            py::kw_only(),
             py::arg("memory_config") = std::nullopt});
 }

@@ -1096,6 +1096,17 @@ void py_module(py::module& module) {
std::nullopt,
R"doc(Performs backward operations for clamp value on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`. Only one of 'min' or 'max' value can be None.)doc");

detail::bind_unary_backward_optional_float_params_with_default(
module,
ttnn::clip_bw,
"min",
"Minimum value",
std::nullopt,
"max",
"Maximum value",
std::nullopt,
R"doc(Performs backward operations for clip on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`. Only one of 'min' or 'max' value can be None.)doc");

detail::bind_unary_backward_two_float_with_default(
module,
ttnn::hardtanh_bw,
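With py::kw_only() moved after the min/max arguments in the hunk above, both bounds can be passed positionally or by keyword, and either one may be omitted (only passing both as None is rejected), while memory_config stays keyword-only. A brief illustration of the accepted call shapes — not runnable standalone, since grad_tensor and input_tensor are assumed to already exist on a device:

# Both bounds, positional — matches the new tests.
out = ttnn.clip_bw(grad_tensor, input_tensor, -10.0, 10.0)

# One-sided clip by keyword; the other bound defaults to None.
out = ttnn.clip_bw(grad_tensor, input_tensor, min=-0.5)
out = ttnn.clip_bw(grad_tensor, input_tensor, max=1.0)

# memory_config remains keyword-only.
out = ttnn.clip_bw(grad_tensor, input_tensor, -1.0, 1.0, memory_config=None)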
7 changes: 7 additions & 0 deletions ttnn/ttnn/operations/unary_backward.py
@@ -228,6 +228,13 @@ def _golden_function_backward_with_reverse_string(
),
)

ttnn.attach_golden_function(
ttnn.clip_bw,
golden_function=lambda grad, input, a, b, *args, **kwargs: _golden_function_unary_backward_with_two_float(
torch.clamp, grad, input, a, b, *args, **kwargs
),
)


def _golden_function_abs_cmplx(grad_tensor, input_tensor, *args, **kwargs):
import torch
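The golden function attached above routes through torch.clamp, so the reference gradient is the one PyTorch autograd produces for a clamp. A small torch-only sketch of that reference (no ttnn device required):

import torch

x = torch.tensor([-2.0, -0.3, 0.4, 3.0], requires_grad=True)
grad_in = torch.ones_like(x)

y = torch.clamp(x, min=-0.5, max=0.5)
y.backward(grad_in)

# The gradient passes through where x lies between the bounds and is zeroed where the clamp saturates.
print(x.grad)  # tensor([0., 1., 1., 0.])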
