Skip to content

Commit

Permalink
#8117: Move global_avg_pool2d to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt committed May 18, 2024
1 parent 87a78d6 commit a60d17c
Show file tree
Hide file tree
Showing 12 changed files with 136 additions and 36 deletions.
4 changes: 4 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "transformer.hpp"
#include "normalization.hpp"
#include "kv_cache.hpp"
#include "pool.hpp"

namespace py = pybind11;

Expand Down Expand Up @@ -50,6 +51,9 @@ void py_module(py::module& module) {

auto m_kv_cache = module.def_submodule("kv_cache", "KV cache operations");
kv_cache::py_module(m_kv_cache);

auto m_pool = module.def_submodule("pool", "pool operations");
pool::py_module(m_pool);
}

} // namespace operations
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/binary.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/ccl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/ccl.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/core.hpp"

namespace py = pybind11;
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/kv_cache.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/normalization.hpp"

namespace py = pybind11;
Expand Down
68 changes: 68 additions & 0 deletions ttnn/cpp/pybind11/operations/pool.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/pool.hpp"
#include "ttnn/types.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace pool {

namespace detail {

void bind_global_avg_pool2d(py::module& module) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` by performing a 2D adaptive average pooling over an input signal composed of several input planes. This operation computes the average of all elements in each channel across the entire spatial dimensions.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor` (ttnn.Tensor): The input tensor to be pooled. Typically of shape (batch_size, channels, height, width).
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
* :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor
Returns:
ttnn.Tensor: The tensor with the averaged values. The output tensor shape is (batch_size, channels, 1, 1).
Example::
>>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
ttnn::operations::pool::global_avg_pool2d.name(),
ttnn::operations::pool::global_avg_pool2d.python_fully_qualified_name());

bind_registered_operation(
module,
ttnn::operations::pool::global_avg_pool2d,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("dtype") = std::nullopt});
}

} // namespace detail

void py_module(py::module& module) {
detail::bind_global_avg_pool2d(module);
}

} // namespace pool
} // namespace operations
} // namespace ttnn
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/transformer.hpp"

namespace py = pybind11;
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/unary.hpp"
#include "ttnn/types.hpp"

Expand Down
47 changes: 47 additions & 0 deletions ttnn/cpp/ttnn/operations/pool.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"
#include "ttnn/operations/core.hpp"
#include "tt_eager/tt_dnn/op_library/pool/average_pool.hpp"

namespace ttnn {
namespace operations {
namespace pool {

namespace detail {
inline const std::array<ttnn::TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{
4, // min rank
4, // max rank
{ttnn::bfloat16, ttnn::bfloat8_b, ttnn::uint16, ttnn::uint32},
{ttnn::TILE_LAYOUT},
true, // can_be_on_device
false, // can_be_on_cpu
false, // can_be_scalar
false // is_optional}
}};
}
} // namespace details

struct GlobalAveragePool2D {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

template <typename... Args>
static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
return std::make_tuple(input_tensor);
}

static Tensor execute(const Tensor& input, const std::optional<MemoryConfig>& memory_config_arg = std::nullopt, const std::optional<DataType>& output_dtype = std::nullopt) {
auto memory_config = memory_config_arg.value_or(input.memory_config());
auto result = tt::tt_metal::average_pool_2d(input, memory_config, output_dtype);
return result;
}
};
constexpr auto global_avg_pool2d = ttnn::register_operation<ttnn::operations::pool::GlobalAveragePool2D>("ttnn::pool::global_avg_pool2d");
} // namespace pool
} // namespace operations
} // namespace ttnn
2 changes: 1 addition & 1 deletion ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def manage_config(name, value):
from ttnn.operations import transformer
from ttnn.operations import kv_cache
from ttnn.operations.conv2d import Conv2d
from ttnn.operations.maxpool2d import (
from ttnn.operations.pool import (
MaxPool2d,
global_avg_pool2d,
)
37 changes: 9 additions & 28 deletions ttnn/ttnn/operations/maxpool2d.py → ttnn/ttnn/operations/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@

import tt_lib as ttl

import sys
import ttnn

from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_max_pool import (
TTPyMaxPool,
SlidingWindowOpParams,
)

THIS_MODULE = sys.modules[__name__]

__all__ = []


class MaxPool2d:
r"""
Expand Down Expand Up @@ -117,7 +122,7 @@ def copy_output_from_device(self, output: ttnn.Tensor):
## Average Pooling


def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor):
def _golden_function(input_tensor: ttnn.Tensor):
import torch

input_tensor = ttnn.from_device(input_tensor)
Expand All @@ -128,32 +133,8 @@ def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor):
return torch.nn.functional.global_avg_pool2d(input_tensor, output_size)


def _global_avg_pool2d_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
ttnn.validate_input_tensor(
operation_name,
input_tensor,
ranks=(4,),
dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32),
layouts=(ttnn.TILE_LAYOUT,),
can_be_on_device=True,
can_be_on_cpu=False,
)


@ttnn.register_operation(
name="ttnn.global_avg_pool2d",
validate_input_tensors=_global_avg_pool2d_validate_input_tensors,
golden_function=_torch_global_avg_pool2d,
global_avg_pool2d = ttnn.register_operation(golden_function=_golden_function)(
ttnn._ttnn.operations.pool.global_avg_pool2d
)
def global_avg_pool2d(input_tensor: ttnn.Tensor, memory_config: ttnn.MemoryConfig = None) -> ttnn.Tensor:
r"""
Applies a 2D adaptive average pooling over an input signal composed of several input planes.

Arguments:
* :attr: input_tensor: the input tensor
"""
if memory_config is None:
output = ttl.tensor.average_pool_2d(input_tensor)
else:
output = ttl.tensor.average_pool_2d(input_tensor, memory_config)
return output
__all__ = []

0 comments on commit a60d17c

Please sign in to comment.