Skip to content

Commit

Permalink
fix const shape
Browse files Browse the repository at this point in the history
  • Loading branch information
curioyang committed May 11, 2023
1 parent 3358565 commit 3046412
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
7 changes: 7 additions & 0 deletions include/nncase/ir/op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ inline bool is_simple_slice(const axis_t &begin, const axis_t &end, const axis_t
return is_simple_slice;
}

inline shape_t get_instancenorm_const_shape(const shape_t &in_shape)
{
shape_t const_shape(in_shape.size() - 1, 1);
const_shape[0] = in_shape[1];
return const_shape;
}

inline bool is_axis0_squeeze_or_expand_dim_bitcast(const shape_t &in_shape, const shape_t &out_shape)
{
auto in_begin = std::find_if_not(in_shape.begin(), in_shape.end(), [](size_t dim) { return dim == 1; });
Expand Down
4 changes: 2 additions & 2 deletions src/ir/ops/instancenorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ instancenorm::instancenorm(datatype_t input_type, shape_t input_shape, float eps
: epsilon_(epsilon)
{
add_input("input", input_type, input_shape);
add_input("scale", input_type, shape_t { input_shape[1], 1, 1 });
add_input("bias", input_type, shape_t { input_shape[1], 1, 1 });
add_input("scale", input_type, get_instancenorm_const_shape(input_shape));
add_input("bias", input_type, get_instancenorm_const_shape(input_shape));
add_output("output", input_type, input_shape);
}

Expand Down

0 comments on commit 3046412

Please sign in to comment.