From ad3f51bcd5188ed27e544a14e342574d465bdd58 Mon Sep 17 00:00:00 2001 From: Tingqian Li Date: Thu, 5 Sep 2024 09:24:49 +0800 Subject: [PATCH] [CPU] fix QKVProjection fake-matching for accuracy with bert models (#26417) ### Details: - prevent QKVProjFusion from mistakenly fusing patterns with integer input (happens in bert-INT8 models with smooth-quant) ### Tickets: - *CVS-151213* --- .../src/transformations/cpu_opset/x64/op/qkv_proj.cpp | 1 + .../transformations/cpu_opset/x64/pass/qkv_proj_fusion.cpp | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/qkv_proj.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/qkv_proj.cpp index 11d122b3815d8d..cb9cc543c40a1e 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/qkv_proj.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/qkv_proj.cpp @@ -16,6 +16,7 @@ void QKVProjectionNode::validate_and_infer_types() { const auto& ishape = get_input_partial_shape(0); const auto& itype = get_input_element_type(0); NODE_VALIDATION_CHECK(this, ishape.rank().is_static() && ishape.rank() == 3, "feature shape rank must be 3"); + NODE_VALIDATION_CHECK(this, itype.is_real(), "feature data type must be real"); set_output_size(3); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/qkv_proj_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/qkv_proj_fusion.cpp index 4b9540b06ccafd..11ef6ca09ff4a0 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/qkv_proj_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/qkv_proj_fusion.cpp @@ -52,6 +52,10 @@ ov::intel_cpu::QKVProjFusion::QKVProjFusion() { return false; } + if (!src.get_element_type().is_real()) { + return false; + } + OutputVector args = {src}; OutputVector outputs; size_t hidden_size = 0;