Skip to content

Commit

Permalink
Update softmax
Browse files Browse the repository at this point in the history
1. remove max calc
2. use exp2
  • Loading branch information
liuzm6217-jianan committed Aug 29, 2023
1 parent 1dce12c commit 969e8d8
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions src/Native/src/kernels/stackvm/optimized/riscv64/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,17 @@ result<void> optimized_softmax_impl(const float *input, float *output,
float *ptr_output_vl = ptr_output;

// max
float max = std::numeric_limits<float>::lowest();
while (n) {
auto vl = vsetvl_e32m8(n);
auto v = vle32_v_f32m8(ptr_input_vl, vl);
auto s = vfmv_s_f_f32m1(vundefined_f32m1(), max, vl);

s = vfredmax_vs_f32m8_f32m1(s, v, s, vl);
max = vfmv_f_s_f32m1_f32(s);
ptr_input_vl += vl;
n -= vl;
}
// float max = std::numeric_limits<float>::lowest();
// while (n) {
// auto vl = vsetvl_e32m8(n);
// auto v = vle32_v_f32m8(ptr_input_vl, vl);
// auto s = vfmv_s_f_f32m1(vundefined_f32m1(), max, vl);

// s = vfredmax_vs_f32m8_f32m1(s, v, s, vl);
// max = vfmv_f_s_f32m1_f32(s);
// ptr_input_vl += vl;
// n -= vl;
// }

// exp((x - max) * beta) and sum(exp)
float sum = 0.f;
Expand All @@ -268,9 +268,13 @@ result<void> optimized_softmax_impl(const float *input, float *output,
auto v_in = vle32_v_f32m8(ptr_input_vl, vl);
auto s = vfmv_s_f_f32m1(vundefined_f32m1(), sum, vl);

auto v_out = exp_ps(
vfmul_vf_f32m8(vfsub_vf_f32m8(v_in, max, vl), beta, vl),
vl);
// auto v_out = exp_ps(
// vfmul_vf_f32m8(vfsub_vf_f32m8(v_in, max, vl), beta, vl),
// vl);
// auto v_out = exp_ps(v_in, vl);
// printf("---------call softmax rvv ------------\n");
auto v_out = exp_ps2(v_in, vl);

s = vfredosum_vs_f32m8_f32m1(s, v_out, s, vl);

vse32_v_f32m8(ptr_output_vl, v_out, vl);
Expand Down Expand Up @@ -406,7 +410,6 @@ result<void> optimized::softmax(const T *input, T *output,
#if __riscv_vector
return optimized_softmax_impl(input, output, in_shape, axis, beta);
#endif

return stackvm::reference::softmax(input, output, in_shape, in_strides,
out_strides, axis, beta);
}

0 comments on commit 969e8d8

Please sign in to comment.