#12815: update mish op (#14270)
mouliraj-mcw authored Oct 30, 2024
1 parent 312d97b commit dec1dc2
Showing 5 changed files with 21 additions and 32 deletions.
3 changes: 1 addition & 2 deletions tests/ttnn/unit_tests/operations/eltwise/test_composite.py
@@ -321,8 +321,7 @@ def test_unary_composite_log1p_ttnn(input_shapes, device):
     ),
 )
 def test_unary_composite_mish_ttnn(input_shapes, device):
-    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
-
+    in_data1, input_tensor1 = data_gen_with_range(input_shapes, -20, 100, device)
     output_tensor = ttnn.mish(input_tensor1)
     golden_function = ttnn.get_golden_function(ttnn.mish)
     golden_tensor = golden_function(in_data1)
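Note on the tightened test range: with the fused kernel introduced below, softplus is computed as log(1 + exp(x)), and once exp(x) falls below single-precision epsilon the add-1 step rounds to exactly 1.0, flattening the output to zero. That presumably motivates the documented lower bound of -20. A minimal NumPy illustration of the rounding effect (an assumption about the rationale, not code from this commit):

# Illustration (assumption, not the device kernel): near x = -20, the
# exp -> add 1 -> log chain underflows at float32 precision.
import numpy as np

x = np.float32(-20.0)
# exp(-20) ~ 2.06e-9 < 2**-24, so 1 + exp(x) rounds to exactly 1.0 in float32
chained = x * np.tanh(np.log(np.float32(1.0) + np.exp(x)))           # -> -0.0
# float64 reference: mish(x) = x * tanh(log1p(exp(x)))
golden = np.float64(x) * np.tanh(np.log1p(np.exp(np.float64(x))))    # ~ -4.1e-08
print(chained, golden)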
@@ -315,26 +315,19 @@ Tensor _log1p(const Tensor& x, const std::optional<MemoryConfig>& output_mem_config
     return result_log1p;
 }
 
-// log[exp[x] + 1] => softplus[x]
+// mish[x] = x*tanh[softplus[x]]
-// use transformation y = x*tanh[softplus[x]] by broadcast
 // Ref: https://krutikabapat.github.io/Swish-Vs-Mish-Latest-Activation-Functions/
-Tensor _mish(const Tensor& x, const std::optional<MemoryConfig>& output_mem_config) {
-    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({x}))};
-    operation::launch_op(
-        [output_mem_config](
-            const std::vector<Tensor>& input_tensors,
-            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
-            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
-            const auto& x = input_tensors.at(0);
-            Tensor sp_x = ttnn::softplus(x, 1.0f, 20.0f, output_mem_config);
-            Tensor tanh_x = ttnn::tanh(sp_x, output_mem_config);
-            sp_x.deallocate();
-            Tensor mish_x = ttnn::multiply(x, tanh_x, std::nullopt, output_mem_config);
-            return {mish_x};
-        },
-        {x},
-        output_tensors);
-    return output_tensors.at(0);
+Tensor ExecuteMish::invoke(const Tensor& x, const std::optional<MemoryConfig>& output_mem_config) {
+    using ttnn::operations::unary::UnaryWithParam;
+    using ttnn::operations::unary::UnaryOpType;
+    std::vector<UnaryWithParam> ops_chain = {
+        UnaryWithParam{UnaryOpType::EXP, 1.0f},
+        UnaryWithParam{UnaryOpType::ADD_UNARY_SFPU, 1.0f},
+        UnaryWithParam{UnaryOpType::LOG},
+        UnaryWithParam{UnaryOpType::TANH} };
+    Tensor result = ttnn::unary_chain(x, ops_chain, output_mem_config);
+    Tensor mish_x = ttnn::multiply(x, result, std::nullopt, output_mem_config);
+    return mish_x;
 }
 
 // multivariate log-gamma function
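The replacement drops the launch_op boilerplate: instead of dispatching softplus and tanh as separate composite calls, ExecuteMish::invoke fuses exp, add-1, log, and tanh into one unary_chain and finishes with a single multiply. The chain computes tanh(log(1 + exp(x))) = tanh(softplus(x)), so the final product is mish(x) = x*tanh(softplus(x)), matching the comment above. A NumPy sketch of that algebra (illustrative only, not the device code):

# NumPy mirror of the fused chain above: EXP -> ADD 1 -> LOG -> TANH,
# then an elementwise multiply with the input, equals mish(x).
import numpy as np

def mish_reference(x):
    # mish[x] = x * tanh[softplus[x]], with softplus[x] = log(1 + exp(x))
    return x * np.tanh(np.log1p(np.exp(x)))

def mish_chained(x):
    t = np.exp(x)    # UnaryOpType::EXP
    t = t + 1.0      # UnaryOpType::ADD_UNARY_SFPU
    t = np.log(t)    # UnaryOpType::LOG
    t = np.tanh(t)   # UnaryOpType::TANH
    return x * t     # ttnn::multiply

x = np.linspace(-20.0, 100.0, 1001)
assert np.allclose(mish_reference(x), mish_chained(x))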
@@ -25,7 +25,6 @@ enum class UnaryCompositeOpType {
     DIGAMMA,
     LGAMMA,
     LOG1P,
-    MISH,
     MULTIGAMMALN,
     SINH,
     SOFTSIGN,
@@ -66,7 +65,6 @@ Tensor _cosh(const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _digamma(const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _lgamma(const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _log1p(const Tensor&, const std::optional<MemoryConfig>&);
-Tensor _mish(const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _multigammaln(const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _sinh(const Tensor&, const std::optional<MemoryConfig>&);
 Tensor _softsign(const Tensor&, const std::optional<MemoryConfig>&);
@@ -192,13 +190,6 @@ struct OpHandler<UnaryCompositeOpType::LOG1P> {
     }
 };
 
-template <>
-struct OpHandler<UnaryCompositeOpType::MISH> {
-    static Tensor handle(const Tensor& t1, const std::optional<MemoryConfig>& mem_cfg ) {
-        return _mish(t1, mem_cfg);
-    }
-};
-
 template <>
 struct OpHandler<UnaryCompositeOpType::MULTIGAMMALN> {
     static Tensor handle(const Tensor& t1, const std::optional<MemoryConfig>& mem_cfg ) {
8 changes: 7 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -153,6 +153,12 @@ struct ExecuteRdiv {
         std::optional<Tensor> optional_output_tensor = std::nullopt);
 };
 
+struct ExecuteMish {
+    static Tensor invoke(
+        const Tensor& input_tensor,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt);
+};
+
 } // namespace unary
 } // namespace operations
 
@@ -240,7 +246,7 @@ constexpr auto log1p = ttnn::register_operation_with_auto_launch_op<
     operations::unary::ExecuteUnaryCompositeOp<operations::unary::UnaryCompositeOpType::LOG1P>>();
 constexpr auto mish = ttnn::register_operation_with_auto_launch_op<
     "ttnn::mish",
-    operations::unary::ExecuteUnaryCompositeOp<operations::unary::UnaryCompositeOpType::MISH>>();
+    operations::unary::ExecuteMish>();
 constexpr auto multigammaln = ttnn::register_operation_with_auto_launch_op<
     "ttnn::multigammaln",
     operations::unary::ExecuteUnaryCompositeOp<operations::unary::UnaryCompositeOpType::MULTIGAMMALN>>();
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1844,7 +1844,7 @@ void py_module(py::module& module) {
     +----------------------------+---------------------------------+-------------------+
     )doc");
 
-    detail::bind_unary_composite(module, ttnn::mish, R"doc(Performs mish function on :attr:`input_tensor`, not supported for grayskull.)doc");
+    detail::bind_unary_composite(module, ttnn::mish, R"doc(Performs mish function on :attr:`input_tensor`, not supported for grayskull.)doc", "[supported range -20 to inf]");
 
     detail::bind_unary_composite(module, ttnn::multigammaln, R"doc(Performs multigammaln function on :attr:`input_tensor`.)doc", "[supported range 1.6 to inf]",
         R"doc(Supported dtypes, layouts, and ranks:
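For completeness, a minimal host-side sketch of calling the rebound op from Python (assumes the usual ttnn from_torch/to_torch flow; the shape and device id are arbitrary and not part of this commit):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Keep inputs inside the documented range [-20, inf).
torch_input = (torch.rand(1, 1, 32, 32) * 120 - 20).to(torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.mish(input_tensor)
result = ttnn.to_torch(output_tensor)

ttnn.close_device(device)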
