fast reduction for reducemean (#8976)
centwang authored Sep 8, 2021
1 parent 1c872f9 commit b7b42e0
Showing 7 changed files with 130 additions and 41 deletions.
@@ -92,7 +92,7 @@ ApplicableMatrixReduction get_applicable_matrix_reduction(
     const cudnnReduceTensorOp_t cudnn_reduce_op,
     const std::vector<int64_t>& dims, const std::vector<int64_t>& original_axes,
     int& m_out, int& n_out) {
-  if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD) {
+  if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD && cudnn_reduce_op != CUDNN_REDUCE_TENSOR_AVG) {
    return ApplicableMatrixReduction::None;
  }
 
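The check above is what admits ReduceMean (CUDNN_REDUCE_TENSOR_AVG) into the fast path: the path only fires when the reduced axes form one contiguous leading or trailing block, so the tensor can be viewed as an m x n matrix. A simplified host-side sketch of that classification, using hypothetical names (classify, leading_reduced, trailing_reduced); the real get_applicable_matrix_reduction also normalizes axes and handles edge cases:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

enum class MatrixReduction { None, Rows, Columns };

// Rows: the leading axes are reduced away, leaving n outputs (sum over m).
// Columns: the trailing axes are reduced away, leaving m outputs (sum over n).
MatrixReduction classify(const std::vector<int64_t>& dims, size_t leading_reduced,
                         size_t trailing_reduced, int64_t& m, int64_t& n) {
  auto prod = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1}, std::multiplies<int64_t>());
  };
  if (leading_reduced > 0 && trailing_reduced == 0) {
    m = prod(dims.begin(), dims.begin() + leading_reduced);  // reduced extent
    n = prod(dims.begin() + leading_reduced, dims.end());    // kept extent
    return MatrixReduction::Rows;
  }
  if (trailing_reduced > 0 && leading_reduced == 0) {
    m = prod(dims.begin(), dims.end() - trailing_reduced);   // kept extent
    n = prod(dims.end() - trailing_reduced, dims.end());     // reduced extent
    return MatrixReduction::Columns;
  }
  return MatrixReduction::None;
}

Once AVG is admitted, the existing SUM kernels are reused and the mean is recovered with one extra elementwise division (the UnaryDiv added below).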
27 changes: 27 additions & 0 deletions onnxruntime/core/providers/cuda/reduction/reduction_functions.cu
@@ -12,6 +12,7 @@
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/shared_inc/cuda_utils.h"
 #include "core/providers/cuda/reduction/reduction_utils.cuh"
+#include "core/providers/cuda/cu_inc/unary_elementwise_impl.cuh"
 
 namespace onnxruntime {
 namespace cuda {
@@ -458,6 +459,32 @@ Status call_reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* outp
 }
 }  // namespace detail
 
+template <typename T>
+struct OP_Div {
+  __device__ __inline__ T operator()(const T& a) const {
+    return a / v_;
+  }
+
+  OP_Div(T v) : v_(v) {}
+
+  T v_;
+};
+
+template <typename T>
+void UnaryDiv(cudaStream_t stream, const T* input, T* output, T denominator, size_t count) {
+  UnaryElementWiseImpl(stream, input, output, OP_Div<T>(denominator), count);
+}
+
+#define INSTANTIATE_UNARY_DIV(T) \
+  template void UnaryDiv<T>(cudaStream_t stream, const T* input, T* output, T denominator, size_t count)
+INSTANTIATE_UNARY_DIV(half);
+INSTANTIATE_UNARY_DIV(float);
+INSTANTIATE_UNARY_DIV(double);
+#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+INSTANTIATE_UNARY_DIV(nv_bfloat16);
+#endif
+#undef INSTANTIATE_UNARY_DIV
+
 template <typename TIn, typename TOut>
 Status reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* output, int m, int n, bool reset_initial_output) {
   using TBuf = AccumulationType_t<TIn>;
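UnaryDiv is the only new kernel the AVG path needs: a single elementwise pass that divides the summed output by a scalar. Its semantics as a plain CPU loop, a sketch only (unary_div_reference is an illustrative name; the committed version launches one CUDA kernel through UnaryElementWiseImpl):

#include <cstddef>

template <typename T>
void unary_div_reference(const T* input, T* output, T denominator, size_t count) {
  // output may alias input; the fast path below uses it in place on the
  // already-reduced output tensor.
  for (size_t i = 0; i < count; ++i) output[i] = input[i] / denominator;
}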
@@ -103,5 +103,9 @@ Status reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* output, i
 template <typename TIn, typename TOut>
 Status reduce_matrix_columns(cudaStream_t stream, const TIn* input, TOut* output, int m, int n, void* buffer, size_t buffer_size);
 
+/** Apply unary elementwise division. */
+template <typename T>
+void UnaryDiv(cudaStream_t stream, const T* input, T* output, T denominator, size_t count);
+
 }  // namespace cuda
 }  // namespace onnxruntime
63 changes: 44 additions & 19 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -455,27 +455,52 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   // Block of fast matrix reduction.
   if (fast_reduction) {
     int m{}, n{};
-    const auto applicable_matrix_reduction = get_applicable_matrix_reduction(
-        cudnn_reduce_op, input_shape.GetDims(), axes, m, n);
-    switch (applicable_matrix_reduction) {
-      case ApplicableMatrixReduction::Rows: {
-        return reduce_matrix_rows(
-            stream,
-            reinterpret_cast<const CudaT*>(input.template Data<T>()),
-            reinterpret_cast<CudaT*>(output.template MutableData<T>()),
-            m, n);
-      }
-      case ApplicableMatrixReduction::Columns: {
-        const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<CudaT>(m, n);
-        auto buffer = cuda_ep.GetScratchBuffer<void>(buffer_size_bytes);
-        return reduce_matrix_columns(
-            stream,
-            reinterpret_cast<const CudaT*>(input.template Data<T>()),
-            reinterpret_cast<CudaT*>(output.template MutableData<T>()),
-            m, n, buffer.get(), buffer_size_bytes);
-      }
-      default:
-        break;
-    }
+    const auto applicable_matrix_reduction =
+        get_applicable_matrix_reduction(cudnn_reduce_op, input_shape.GetDims(), axes, m, n);
+    if (applicable_matrix_reduction != ApplicableMatrixReduction::None) {
+      IAllocatorUniquePtr<T> input_data_buffer(nullptr, [](T*) {});
+      const CudaT* input_data = reinterpret_cast<const CudaT*>(input.template Data<T>());
+      if (calculate_sqt) {
+        input_data_buffer = cuda_ep.GetScratchBuffer<T>(input_count);
+        input_data = reinterpret_cast<CudaT*>(input_data_buffer.get());
+        fast_divmod tmp_div;
+        Impl_Mul<CudaT>(stream, static_cast<int32_t>(SimpleBroadcast::NoBroadcast), nullptr,
+                        reinterpret_cast<const CudaT*>(input.template Data<T>()), nullptr,
+                        reinterpret_cast<const CudaT*>(input.template Data<T>()), nullptr, tmp_div, tmp_div,
+                        reinterpret_cast<CudaT*>(input_data_buffer.get()), input_count);
+        input_data = reinterpret_cast<const CudaT*>(input_data_buffer.get());
+      }
+
+      switch (applicable_matrix_reduction) {
+        case ApplicableMatrixReduction::Rows: {
+          ORT_RETURN_IF_ERROR(reduce_matrix_rows(
+              stream, input_data, reinterpret_cast<CudaT*>(output.template MutableData<T>()), m, n));
+        } break;
+        case ApplicableMatrixReduction::Columns: {
+          const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<CudaT>(m, n);
+          auto buffer = cuda_ep.GetScratchBuffer<void>(buffer_size_bytes);
+          ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data,
+                                                    reinterpret_cast<CudaT*>(output.template MutableData<T>()), m, n,
+                                                    buffer.get(), buffer_size_bytes));
+        } break;
+        default: {
+          ORT_ENFORCE(false, "Invalid matrix reduction type.");
+        }
+      }
+
+      if (calculate_log) {
+        Impl_Log<CudaT>(stream, reinterpret_cast<const CudaT*>(output.template Data<T>()),
+                        reinterpret_cast<CudaT*>(output.template MutableData<T>()), output_count);
+      } else if (cudnn_reduce_op == CUDNN_REDUCE_TENSOR_AVG) {
+        float denominator_float = applicable_matrix_reduction == ApplicableMatrixReduction::Rows
+                                      ? static_cast<float>(m)
+                                      : static_cast<float>(n);
+        CudaT denominator = ToCudaType<T>::FromFloat(denominator_float);
+        UnaryDiv(stream, reinterpret_cast<const CudaT*>(output.template Data<T>()),
+                 reinterpret_cast<CudaT*>(output.template MutableData<T>()), denominator, output_count);
+      }
+
+      return Status::OK();
+    }
   }
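Putting the pieces together, a host-side walk-through of what this fast path computes for a hypothetical ReduceMean over the trailing axis of a 2x3 tensor, which is a Columns reduction, so each of the m = 2 row sums is divided by n = 3 (a sketch of the data flow, not of the CUDA kernels):

#include <cstdio>

int main() {
  const int m = 2, n = 3;
  const float input[m * n] = {1, 2, 3, 4, 5, 6};
  float output[m] = {};
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) output[i] += input[i * n + j];  // matrix SUM (reduce_matrix_columns)
    output[i] /= static_cast<float>(n);                         // UnaryDiv step for AVG
  }
  std::printf("%f %f\n", output[0], output[1]);  // prints 2.000000 5.000000
  return 0;
}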
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/cuda/reduction/reduction_ops.h
@@ -138,7 +138,9 @@ class ReduceMax final : public ReduceKernel<true> {
 template <typename T>
 class ReduceMean final : public ReduceKernel<true> {
  public:
-  ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {}
+  ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {
+    fast_reduction_ = true;
+  }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
     return ComputeImpl<T>(ctx, CUDNN_REDUCE_TENSOR_AVG);
@@ -182,6 +184,7 @@ class ReduceLogSum final : public ReduceKernel<true> {
  public:
   ReduceLogSum(const OpKernelInfo& info) : ReduceKernel<true>(info) {
     ReduceKernel<true>::calculate_log_ = true;
+    fast_reduction_ = true;
   }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
@@ -194,6 +197,7 @@ class ReduceSumSquare final : public ReduceKernel<true> {
  public:
   ReduceSumSquare(const OpKernelInfo& info) : ReduceKernel<true>(info) {
     ReduceKernel<true>::calculate_sqt_ = true;
+    fast_reduction_ = true;
   }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
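Setting fast_reduction_ in these constructors is what routes the three ops through the matrix path; it is valid because each op decomposes around a plain SUM. A scalar-level sketch of those decompositions (plain_sum and the reduce_* helpers are illustrative names, not ORT APIs):

#include <cmath>
#include <cstddef>

float plain_sum(const float* x, size_t count) {
  float s = 0.0f;
  for (size_t i = 0; i < count; ++i) s += x[i];
  return s;
}

// ReduceSumSquare: square first (calculate_sqt_ -> Impl_Mul(x, x)), then SUM.
float reduce_sum_square(const float* x, float* scratch, size_t count) {
  for (size_t i = 0; i < count; ++i) scratch[i] = x[i] * x[i];
  return plain_sum(scratch, count);
}

// ReduceLogSum: SUM first, then log (calculate_log_ -> Impl_Log).
float reduce_log_sum(const float* x, size_t count) {
  return std::log(plain_sum(x, count));
}

// ReduceMean: SUM first, then divide by the reduced size (UnaryDiv).
float reduce_mean(const float* x, size_t count) {
  return plain_sum(x, count) / static_cast<float>(count);
}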
63 changes: 44 additions & 19 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
@@ -445,27 +445,52 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr
   // Block of fast matrix reduction.
   if (fast_reduction) {
     int m{}, n{};
-    const auto applicable_matrix_reduction = get_applicable_matrix_reduction(
-        miopen_reduce_op, input_shape.GetDims(), axes, m, n);
-    switch (applicable_matrix_reduction) {
-      case ApplicableMatrixReduction::Rows: {
-        return reduce_matrix_rows(
-            stream,
-            reinterpret_cast<const HipT*>(input.template Data<T>()),
-            reinterpret_cast<HipT*>(output.template MutableData<T>()),
-            m, n);
-      }
-      case ApplicableMatrixReduction::Columns: {
-        const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<HipT>(m, n);
-        auto buffer = rocm_ep.GetScratchBuffer<void>(buffer_size_bytes);
-        return reduce_matrix_columns(
-            stream,
-            reinterpret_cast<const HipT*>(input.template Data<T>()),
-            reinterpret_cast<HipT*>(output.template MutableData<T>()),
-            m, n, buffer.get(), buffer_size_bytes);
-      }
-      default:
-        break;
-    }
+    const auto applicable_matrix_reduction =
+        get_applicable_matrix_reduction(miopen_reduce_op, input_shape.GetDims(), axes, m, n);
+    if (applicable_matrix_reduction != ApplicableMatrixReduction::None) {
+      IAllocatorUniquePtr<T> input_data_buffer(nullptr, [](T*) {});
+      const HipT* input_data = reinterpret_cast<const HipT*>(input.template Data<T>());
+      if (calculate_sqt) {
+        input_data_buffer = rocm_ep.GetScratchBuffer<T>(input_count);
+        input_data = reinterpret_cast<HipT*>(input_data_buffer.get());
+        fast_divmod tmp_div;
+        Impl_Mul<HipT>(stream, static_cast<int32_t>(SimpleBroadcast::NoBroadcast), nullptr,
+                       reinterpret_cast<const HipT*>(input.template Data<T>()), nullptr,
+                       reinterpret_cast<const HipT*>(input.template Data<T>()), nullptr, tmp_div, tmp_div,
+                       reinterpret_cast<HipT*>(input_data_buffer.get()), input_count);
+        input_data = reinterpret_cast<const HipT*>(input_data_buffer.get());
+      }
+
+      switch (applicable_matrix_reduction) {
+        case ApplicableMatrixReduction::Rows: {
+          ORT_RETURN_IF_ERROR(reduce_matrix_rows(
+              stream, input_data, reinterpret_cast<HipT*>(output.template MutableData<T>()), m, n));
+        } break;
+        case ApplicableMatrixReduction::Columns: {
+          const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<HipT>(m, n);
+          auto buffer = rocm_ep.GetScratchBuffer<void>(buffer_size_bytes);
+          ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data,
+                                                    reinterpret_cast<HipT*>(output.template MutableData<T>()), m, n,
+                                                    buffer.get(), buffer_size_bytes));
+        } break;
+        default: {
+          ORT_ENFORCE(false, "Invalid matrix reduction type.");
+        }
+      }
+
+      if (calculate_log) {
+        Impl_Log<HipT>(stream, reinterpret_cast<const HipT*>(output.template Data<T>()),
+                       reinterpret_cast<HipT*>(output.template MutableData<T>()), output_count);
+      } else if (miopen_reduce_op == MIOPEN_REDUCE_TENSOR_AVG) {
+        float denominator_float = applicable_matrix_reduction == ApplicableMatrixReduction::Rows
+                                      ? static_cast<float>(m)
+                                      : static_cast<float>(n);
+        HipT denominator = ToHipType<T>::FromFloat(denominator_float);
+        UnaryDiv(stream, reinterpret_cast<const HipT*>(output.template Data<T>()),
+                 reinterpret_cast<HipT*>(output.template MutableData<T>()), denominator, output_count);
+      }
+
+      return Status::OK();
+    }
   }
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/rocm/reduction/reduction_ops.h
@@ -140,7 +140,9 @@ class ReduceMax final : public ReduceKernel<true> {
 template <typename T>
 class ReduceMean final : public ReduceKernel<true> {
  public:
-  ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {}
+  ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {
+    fast_reduction_ = true;
+  }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
     return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_AVG);
@@ -184,6 +186,7 @@ class ReduceLogSum final : public ReduceKernel<true> {
  public:
   ReduceLogSum(const OpKernelInfo& info) : ReduceKernel<true>(info) {
     ReduceKernel<true>::calculate_log_ = true;
+    fast_reduction_ = true;
   }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
@@ -196,6 +199,7 @@ class ReduceSumSquare final : public ReduceKernel<true> {
  public:
   ReduceSumSquare(const OpKernelInfo& info) : ReduceKernel<true>(info) {
     ReduceKernel<true>::calculate_sqt_ = true;
+    fast_reduction_ = true;
   }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
