Skip to content

Commit

Permalink
Enable all RunMatMulTest test cases to support FP16 (#22440)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->


### Motivation and Context
increase FP16 test coverage for all related EPs
  • Loading branch information
mszhanyi authored Oct 16, 2024
1 parent af00a20 commit 2b8fc55
Showing 1 changed file with 36 additions and 52 deletions.
88 changes: 36 additions & 52 deletions onnxruntime/test/providers/cpu/math/matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,128 +38,125 @@ template <typename T>
std::vector<MatMulTestData<T>> GenerateTestCases() {
std::vector<MatMulTestData<T>> test_cases;

auto real_expected_vals = [](const std::vector<int32_t>& expected_vals) {
if constexpr (std::is_same_v<T, int32_t>) {
return expected_vals;
} else if constexpr (std::is_same_v<T, MLFloat16>) {
std::vector<MLFloat16> expected_vals_fp16(expected_vals.size());
std::transform(expected_vals.begin(), expected_vals.end(), expected_vals_fp16.begin(),
[](int32_t num) { return MLFloat16(float(num)); });
return expected_vals_fp16;
} else {
std::vector<T> real_expected_vals(expected_vals.size());
std::transform(expected_vals.begin(), expected_vals.end(), real_expected_vals.begin(),
[](int32_t num) { return static_cast<T>(num); });
return real_expected_vals;
}
};

test_cases.push_back(
{"test padding and broadcast A > B",
{3, 1, 1, 2},
{2, 2, 2},
{3, 2, 1, 2},
{2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}});
real_expected_vals({2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55})});

test_cases.push_back(
{"test padding and broadcast B > A",
{2, 3, 2},
{3, 2, 2, 1},
{3, 2, 3, 1},
{1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}});
real_expected_vals({1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221})});

test_cases.push_back(
{"test left 1D",
{2},
{3, 2, 1},
{3, 1},
{1, 3, 5}});
real_expected_vals({1, 3, 5})});

test_cases.push_back(
{"test right 1D",
{3, 1, 2},
{2},
{3, 1},
{1, 3, 5}});
real_expected_vals({1, 3, 5})});

test_cases.push_back(
{"test left 1D right 2D",
{2},
{2, 3},
{3},
{3, 4, 5}});
real_expected_vals({3, 4, 5})});

test_cases.push_back(
{"test scalar output",
{3},
{3},
{},
{5}});
real_expected_vals({5})});

test_cases.push_back(
{"test 2D",
{3, 4},
{4, 3},
{3, 3},
{42, 48, 54, 114, 136, 158, 186, 224, 262}});
real_expected_vals({42, 48, 54, 114, 136, 158, 186, 224, 262})});

test_cases.push_back(
{"test 2D special",
{2, 2, 3},
{3, 4},
{2, 2, 4},
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})});

test_cases.push_back(
{"test 2D special 2",
{2, 2, 3},
{1, 3, 4},
{2, 2, 4},
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})});

test_cases.push_back(
{"test 2D special 3",
{2, 6},
{1, 1, 6, 1},
{1, 1, 2, 1},
{55, 145}});
real_expected_vals({55, 145})});

test_cases.push_back(
{"test 2D empty input",
{3, 4},
{4, 0},
{3, 0},
{}});
real_expected_vals({})});

test_cases.push_back(
{"test 3D batch",
{3, 1, 3},
{3, 3, 2},
{3, 1, 2},
{
real_expected_vals({
// clang-format off
10, 13,
100, 112,
298, 319,
// clang-format on
}});
})});

test_cases.push_back(
{"test 4D batch",
{2, 2, 1, 3},
{2, 2, 3, 2},
{2, 2, 1, 2},
{
real_expected_vals({
// clang-format off
10, 13,
100, 112,
298, 319,
604, 634,
// clang-format on
}});

return test_cases;
}

template <>
std::vector<MatMulTestData<MLFloat16>> GenerateTestCases() {
std::vector<MatMulTestData<MLFloat16>> test_cases;

// test 2D expected_vals
std::vector<int64_t> expected_vals = {42, 48, 54, 114, 136, 158, 186, 224, 262};
std::vector<MLFloat16> expected_vals_fp16(expected_vals.size());
std::transform(expected_vals.begin(), expected_vals.end(), expected_vals_fp16.begin(),
[](int64_t num) { return MLFloat16(float(num)); });
test_cases.push_back(
{"test 2D MLfloat16",
{3, 4},
{4, 3},
{3, 3},
expected_vals_fp16});
})});

return test_cases;
}
Expand Down Expand Up @@ -209,19 +206,12 @@ TEST(MathOpTest, MatMulFloatType) {
GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
}
RunMatMulTest<float>(7, false, false);
}

// Runs the float MatMul tests with Matrix B as an initializer (constant).
// XNNPACK only registers a MatMul kernel when B is constant, so this variant
// is what actually exercises that EP.
TEST(MathOpTest, MatMulFloatType_ConstantB) {
  // TODO: Unskip when fixed #41968513
  if (DefaultDmlExecutionProvider().get() != nullptr) {
    GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
  }
  // Note. Xnnpack only supports matmul when Matrix B is constant
  RunMatMulTest<float>(7, false, true);
}

#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK)
TEST(MathOpTest, MatMulFloat16_ConstantB) {
TEST(MathOpTest, MatMulFloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
Expand All @@ -233,22 +223,16 @@ TEST(MathOpTest, MatMulFloat16_ConstantB) {
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
}
RunMatMulTest<MLFloat16>(7, false, true);
RunMatMulTest<MLFloat16>(14, false, false);
// Note. Xnnpack only supports matmul when Matrix B is constant
RunMatMulTest<MLFloat16>(14, false, true);
}
#endif

// Runs the shared MatMul test cases with double elements (opset 7 model).
TEST(MathOpTest, MatMulDoubleType) {
  RunMatMulTest<double>(7);
}

// Same float MatMul coverage but with B provided as an initializer
// (is_b_constant = true), which takes the constant-B code paths in EPs.
TEST(MathOpTest, MatMulFloatTypeInitializer) {
  // TODO: Unskip when fixed #41968513
  if (DefaultDmlExecutionProvider().get() != nullptr) {
    GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
  }
  RunMatMulTest<float>(7, false, true);
}

// Runs the shared MatMul test cases with int32 elements (opset 9, where
// integer MatMul support was added).
TEST(MathOpTest, MatMulInt32Type) {
  RunMatMulTest<int32_t>(9);
}
Expand Down

0 comments on commit 2b8fc55

Please sign in to comment.