From fca5ed37084224ab88014ed6a8d59f1a3094a2b8 Mon Sep 17 00:00:00 2001 From: TFLM-bot Date: Thu, 10 Oct 2024 14:02:36 +0000 Subject: [PATCH] Sync from upstream TF. --- .../lite/kernels/internal/reference/batch_matmul.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h index 767ad6ab0af..3e3929e08d8 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, const float* scaling_factors, const int32_t* input_offset, int32_t* row_sums, const RuntimeShape& output_shape, float* output_data, - bool* compute_row_sums) { + bool* compute_row_sums, + const float* per_channel_scales) { const RuntimeShape extended_lhs_shape = RuntimeShape::ExtendedShape(5, lhs_shape); const RuntimeShape extended_rhs_shape = @@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, int32_t row_sum = woff_ptr2[i]; total -= row_sum * batch_offset; int idx = lhs_rows * j + i; - out_ptr[idx] += batch_scaling_factor * total; + float ss = batch_scaling_factor; + if (per_channel_scales) { + ss *= per_channel_scales[j]; + } + out_ptr[idx] += ss * total; } } }