Skip to content

Commit

Permalink
Merge branch 'master' into kernel_bug_fixv2
Browse files Browse the repository at this point in the history
  • Loading branch information
lerenhua authored Aug 28, 2023
2 parents 7e6e57d + aaf0565 commit 17b2039
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 8 deletions.
6 changes: 3 additions & 3 deletions tests/kernels/test_broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class BroadCastTest : public KernelTest,
void SetUp() override {
READY_SUBCASE()

auto typecode = GetDataType("lhs_type");
auto l_shape = GetShapeArray("lhs_shape");
auto r_shape = GetShapeArray("rhs_shape");
auto typecode = GetDataType("lhs_type");

input =
hrt::create(typecode, r_shape, host_runtime_tensor::pool_cpu_only)
Expand Down Expand Up @@ -219,12 +219,12 @@ TEST_P(BroadCastTest, BroadCast) {

int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_type, i)
FOR_LOOP(lhs_shape, j)
FOR_LOOP(rhs_shape, k)
SPLIT_ELEMENT(lhs_type, i)
FOR_LOOP(lhs_type, i)
SPLIT_ELEMENT(lhs_shape, j)
SPLIT_ELEMENT(rhs_shape, k)
SPLIT_ELEMENT(lhs_type, i)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class CastTest : public KernelTest,

auto typecode_input = GetDataType("lhs_type");
auto typecode_output = GetDataType("rhs_type");
auto l_shape = GetShapeArray("i_shape");
auto l_shape = GetShapeArray("lhs_shape");

input = hrt::create(typecode_input, l_shape,
host_runtime_tensor::pool_cpu_only)
Expand Down
47 changes: 46 additions & 1 deletion tests/kernels/test_celu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,56 @@ class CeluTest : public KernelTest,

alpha = hrt::create(typecode, {1}, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
init_tensor(alpha);
init_tensor_alpha(alpha);
}

void TearDown() override {}

virtual void init_tensor_alpha(runtime::runtime_tensor &tensor) {
auto dtype = tensor.datatype();
switch (dtype) {
case dt_float16: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(-5.0f, 5.0f);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<half>(tensor, index) = static_cast<half>(dis(gen));
return ok();
});
break;
}
case dt_float32: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(-5.0f, 5.0f);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<float>(tensor, index) = static_cast<float>(dis(gen));
return ok();
});
break;
}
case dt_bfloat16: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(-5.0f, 5.0);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<bfloat16>(tensor, index) =
static_cast<bfloat16>(dis(gen));
return ok();
});
break;
}
default: {
}
}
}

protected:
runtime_tensor input;
runtime_tensor alpha;
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_celu.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"lhs_shape":[[1], [1, 2], [1, 3, 16, 16], [16, 16], [1, 16], [1, 3, 16, 1], []],
"lhs_shape":[[1], [1, 2], [1, 3, 16, 16], [16, 16], [1, 16], [1, 3, 16, 1], [1, 3, 16], []],
"lhs_type":["dt_float32"]
}
4 changes: 2 additions & 2 deletions tests/kernels/test_clamp.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"lhs_shape":[[1, 3, 16, 16], [1], [1, 3], [8, 8], [1, 3, 8], [16, 16], [16], []],
"lhs_type":["dt_float32"],
"min": [-1, -2, -3, -4, -5, -6],
"max": [0, 1, 2, 3, 4, 5, 6]
"min": [-1.0, -2.0, -3.0, -4.0, -5.0, -6.0],
"max": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
}

0 comments on commit 17b2039

Please sign in to comment.