[GPU] Fix lack of support for 3D (and smaller) shapes in TransposeFusion pass (#26440)

### Details:
- Fix lack of support for 3D (and smaller) shapes in TransposeFusion
pass
- Update IncreasePositionIdsPrecision to support both MatMul and Gemm
operations

### Tickets:
 - [CVS-146889](https://jira.devtools.intel.com/browse/CVS-146889)
sshlyapn authored Sep 5, 2024
1 parent 8c9d4be commit c3152d3
Showing 4 changed files with 117 additions and 6 deletions.
@@ -29,8 +29,8 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
using namespace ov::pass::pattern;
using ov::pass::pattern::op::Or;

- auto gemm = wrap_type<ov::intel_gpu::op::Gemm>();
- auto concat = wrap_type<ov::op::v0::Concat>({gemm, gemm});
+ auto gemm_or_matmul = wrap_type<ov::intel_gpu::op::Gemm, ov::op::v0::MatMul>();
+ auto concat = wrap_type<ov::op::v0::Concat>({gemm_or_matmul, gemm_or_matmul});
auto sin = wrap_type<ov::op::v0::Sin>({concat});
auto cos = wrap_type<ov::op::v0::Cos>({concat});

@@ -50,15 +50,15 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

- auto gemm_node = std::dynamic_pointer_cast<ov::intel_gpu::op::Gemm>(pattern_map.at(gemm).get_node_shared_ptr());
+ auto matmul_node = std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(gemm_or_matmul).get_node_shared_ptr());
auto cos_node = std::dynamic_pointer_cast<ov::op::v0::Cos>(pattern_map.at(cos).get_node_shared_ptr());
auto sin_node = std::dynamic_pointer_cast<ov::op::v0::Sin>(pattern_map.at(sin).get_node_shared_ptr());

- if (!gemm_node || transformation_callback(gemm_node))
+ if (!matmul_node || transformation_callback(matmul_node))
return false;

const auto desired_et = ov::element::f32;
- const auto original_et = gemm_node->get_output_element_type(0);
+ const auto original_et = matmul_node->get_output_element_type(0);
if (original_et == desired_et)
return false;

@@ -112,7 +112,7 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
}
};

- bool is_changed = insert_converts_before_if_needed(gemm_node);
+ bool is_changed = insert_converts_before_if_needed(matmul_node);

if (is_changed) {
insert_converts_after_if_needed(cos_node);
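
A note on the callback rewrite above: the pattern now binds either a plain ov::op::v0::MatMul or the plugin's ov::intel_gpu::op::Gemm, and a single dynamic_pointer_cast to the MatMul base accepts both (the pass operating on Gemm nodes implies Gemm is-a MatMul). Below is a dependency-free sketch of that mechanism, using hypothetical stand-in classes rather than the real OpenVINO types:

```cpp
#include <iostream>
#include <memory>

// Hypothetical stand-ins: GemmLike derives from MatMulLike, mirroring the
// relationship the updated callback relies on for the real ops.
struct Node { virtual ~Node() = default; };
struct MatMulLike : Node {};
struct GemmLike : MatMulLike {};

int main() {
    const std::shared_ptr<Node> matched_matmul = std::make_shared<MatMulLike>();
    const std::shared_ptr<Node> matched_gemm = std::make_shared<GemmLike>();

    for (const auto& matched : {matched_matmul, matched_gemm}) {
        // Old behavior: casting to the derived Gemm type rejects plain MatMuls.
        const bool old_accepts = std::dynamic_pointer_cast<GemmLike>(matched) != nullptr;
        // New behavior: casting to the MatMul base accepts both node kinds.
        const bool new_accepts = std::dynamic_pointer_cast<MatMulLike>(matched) != nullptr;
        std::cout << "old cast to Gemm: " << old_accepts
                  << ", new cast to MatMul: " << new_accepts << '\n';
    }
    // prints:
    //   old cast to Gemm: 0, new cast to MatMul: 1
    //   old cast to Gemm: 1, new cast to MatMul: 1
}
```

The old cast to Gemm is what rejected plain-MatMul position-id graphs, such as the one exercised by the new IncreasePositionIdsMatmulWithoutUnsqueeze test below.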
@@ -52,6 +52,13 @@ bool has_optimized_version(const ov::Output<ov::Node>& output, bool supports_immad
{3, 0, 1, 2},
};

+ const auto expected_dims_num = 4;
+ const auto original_dims_num = transpose_order.size();
+ if (original_dims_num < expected_dims_num) {
+     transpose_order.resize(expected_dims_num);
+     std::iota(transpose_order.begin() + original_dims_num, transpose_order.end(), original_dims_num);
+ }

if (!cldnn::one_of(transpose_order, allowed_orders))
return false;
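
The added block above is the whole 3D-and-smaller fix: orders shorter than 4 dims are padded with identity axes before the allow-list check. A minimal standalone sketch of that padding (plain C++; the function name and size_t element type are illustrative, not the plugin's):

```cpp
#include <iostream>
#include <numeric>
#include <vector>

// Pad a transpose order to 4 dims by appending identity axes,
// mirroring the block added to has_optimized_version().
std::vector<size_t> pad_transpose_order(std::vector<size_t> order) {
    const size_t expected_dims_num = 4;
    const size_t original_dims_num = order.size();
    if (original_dims_num < expected_dims_num) {
        order.resize(expected_dims_num);
        // New trailing positions get their own indices, i.e. axes
        // the transpose leaves in place.
        std::iota(order.begin() + original_dims_num, order.end(), original_dims_num);
    }
    return order;
}

int main() {
    for (size_t axis : pad_transpose_order({0, 2, 1}))
        std::cout << axis << ' ';   // 0 2 1 3
    std::cout << '\n';
    for (size_t axis : pad_transpose_order({1, 0}))
        std::cout << axis << ' ';   // 1 0 2 3
    std::cout << '\n';
}
```

Consistent with the tests below, the padded 3D order {0, 2, 1, 3} passes the allow-list check and the graph is fused into a Gemm (TranposeMatmulFusion5), while the padded 2D order {1, 0, 2, 3} does not, so the MatMul plus Transpose pair is left untouched (TranposeMatmulFusion6 and 7).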

@@ -130,3 +130,49 @@ TEST_F(TransformationTestsF, IncreasePositionIdsPrecisionWithUnsqueeze) {
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

TEST_F(TransformationTestsF, IncreasePositionIdsMatmulWithoutUnsqueeze) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16);
auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});

auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_fp, rotary_embd_const);
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{matmul, matmul}, 2);

auto cos = std::make_shared<ov::op::v0::Cos>(concat);
auto sin = std::make_shared<ov::op::v0::Sin>(concat);

auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos, sin}, ov::op::internal::RoPE::Config());

model = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input });
manager.register_pass<IncreasePositionIdsPrecision>();
}
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});

auto input_convert_f32 = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f32);
auto rotary_embd_const_convert_f32 = std::make_shared<ov::op::v0::Convert>(rotary_embd_const, ov::element::f32);

auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_f32, rotary_embd_const_convert_f32);
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{matmul, matmul}, 2);

auto cos = std::make_shared<ov::op::v0::Cos>(concat);
auto sin = std::make_shared<ov::op::v0::Sin>(concat);

auto cos_convert = std::make_shared<ov::op::v0::Convert>(cos, ov::element::f16);
auto sin_convert = std::make_shared<ov::op::v0::Convert>(sin, ov::element::f16);

auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_convert, sin_convert}, ov::op::internal::RoPE::Config());

model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input });
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
@@ -124,6 +124,64 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion4) {
}
}

TEST_F(TransformationTestsF, TranposeMatmulFusion5) {
{
auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 2, 1});
auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);

model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });

const auto supports_immad = false;
manager.register_pass<TransposeFusion>(supports_immad);
}
{
std::vector<int64_t> order_a = {0, 1, 2};
std::vector<int64_t> order_b = {0, 1, 2};
std::vector<int64_t> order_c = {0, 2, 1};
auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
auto gemm = std::make_shared<ov::intel_gpu::op::Gemm>(input_a, input_b, order_a, order_b, order_c, ov::element::undefined);

model_ref = std::make_shared<ov::Model>(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b });
comparator.enable(FunctionsComparator::ATTRIBUTES);
}
}

TEST_F(TransformationTestsF, TranposeMatmulFusion6) {
auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(2));
auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(2));
auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {1, 0});
auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);

model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });

const auto supports_immad = false;
manager.register_pass<TransposeFusion>(supports_immad);

model_ref = model->clone();
comparator.enable(FunctionsComparator::ATTRIBUTES);
}

TEST_F(TransformationTestsF, TranposeMatmulFusion7) {
auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 4});
auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{4, 2});
auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {1, 0});
auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);

model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });

const auto supports_immad = false;
manager.register_pass<TransposeFusion>(supports_immad);

model_ref = model->clone();
comparator.enable(FunctionsComparator::ATTRIBUTES);
}

TEST_F(TransformationTestsF, TranposeMatmulFusion_Illegal_1) {
{
auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{10, 20});
