diff --git a/modules/cpu/src/runtime/cmodel/include/clamp.h b/modules/cpu/src/runtime/cmodel/include/clamp.h index de3a5a1ad9..bebf734c06 100644 --- a/modules/cpu/src/runtime/cmodel/include/clamp.h +++ b/modules/cpu/src/runtime/cmodel/include/clamp.h @@ -16,15 +16,53 @@ #include "../../gsl-lite.hpp" #include #include +#ifdef __riscv_vector +#include +#endif using namespace nncase::runtime::cpu; namespace kernels { namespace { +#ifdef __riscv_vector template -void clamp_impl(const T *input, T min, T max, T *output, - gsl::span in_shape, - gsl::span in_strides, - gsl::span out_strides) { +void clamp_rvv_impl(const T *input, T min, T max, T *output, + gsl::span in_shape, + gsl::span in_strides, + gsl::span out_strides) { + auto [new_in_shape, new_in_stride] = to_nd(in_shape, in_strides, 5); + auto [new_out_shape, new_out_stride] = to_nd(in_shape, out_strides, 5); + for (size_t n = 0; n < new_in_shape[0]; ++n) { + for (size_t c = 0; c < new_in_shape[1]; ++c) { + for (size_t h = 0; h < new_in_shape[2]; ++h) { + for (size_t w = 0; w < new_in_shape[3]; ++w) { + const T *in_ptr = input + n * new_in_stride[0] + + c * new_in_stride[1] + h * new_in_stride[2] + + w * new_in_stride[3]; + T *out_ptr = output + n * new_out_stride[0] + + c * new_out_stride[1] + h * new_out_stride[2] + + w * new_out_stride[3]; + size_t vl; + for (size_t i = new_in_shape[4]; i > 0; i -= vl) { + vl = vsetvl_e32m8(i); + vfloat32m8_t vx = vle32_v_f32m8(in_ptr, vl); + vx = vfmax_vf_f32m8(vx, min, vl); + vx = vfmin_vf_f32m8(vx, max, vl); + vse32_v_f32m8(out_ptr, vx, vl); + in_ptr += vl; + out_ptr += vl; + } + } + } + } + } + return; +} +#else +template +void clamp_native_impl(const T *input, T min, T max, T *output, + gsl::span in_shape, + gsl::span in_strides, + gsl::span out_strides) { return apply(in_shape, [&](gsl::span index) -> void { const auto v = input[offset(index, in_strides)]; output[offset(index, out_strides)] = static_cast( @@ -33,12 +71,18 @@ void clamp_impl(const T *input, T min, T max, T *output, return; }); } +#endif } // namespace template void clamp(const T *input, T *output, T min, T max, gsl::span in_shape, gsl::span in_strides, gsl::span out_strides) { - clamp_impl(input, min, max, output, in_shape, in_strides, out_strides); +#ifdef __riscv_vector + clamp_rvv_impl(input, min, max, output, in_shape, in_strides, out_strides); +#else + clamp_native_impl(input, min, max, output, in_shape, in_strides, + out_strides); +#endif } } // namespace kernels \ No newline at end of file