Skip to content

Commit

Permalink
[CPU] fix QKVProjection fake-matching for accuracy with bert models (#…
Browse files Browse the repository at this point in the history
…26417)

### Details:
- prevent QKVProjFusion from mistakenly fusing patterns with integer
input (happens in bert-INT8 models with smooth-quant)


### Tickets:
 - *CVS-151213*
  • Loading branch information
usstq authored Sep 5, 2024
1 parent 0bfa578 commit ad3f51b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void QKVProjectionNode::validate_and_infer_types() {
const auto& ishape = get_input_partial_shape(0);
const auto& itype = get_input_element_type(0);
NODE_VALIDATION_CHECK(this, ishape.rank().is_static() && ishape.rank() == 3, "feature shape rank must be 3");
NODE_VALIDATION_CHECK(this, itype.is_real(), "feature data type must be real");

set_output_size(3);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ ov::intel_cpu::QKVProjFusion::QKVProjFusion() {
return false;
}

if (!src.get_element_type().is_real()) {
return false;
}

OutputVector args = {src};
OutputVector outputs;
size_t hidden_size = 0;
Expand Down

0 comments on commit ad3f51b

Please sign in to comment.