diff --git a/include/nncase/ir/op_utils.h b/include/nncase/ir/op_utils.h index d4eaf65d7f..0f379661d5 100644 --- a/include/nncase/ir/op_utils.h +++ b/include/nncase/ir/op_utils.h @@ -385,6 +385,13 @@ inline bool is_simple_slice(const axis_t &begin, const axis_t &end, const axis_t return is_simple_slice; } +inline shape_t get_instancenorm_const_shape(const shape_t &in_shape) +{ + shape_t const_shape(in_shape.size() - 1, 1); + const_shape[0] = in_shape[1]; + return const_shape; +} + inline bool is_axis0_squeeze_or_expand_dim_bitcast(const shape_t &in_shape, const shape_t &out_shape) { auto in_begin = std::find_if_not(in_shape.begin(), in_shape.end(), [](size_t dim) { return dim == 1; }); diff --git a/src/ir/ops/instancenorm.cpp b/src/ir/ops/instancenorm.cpp index a6803fd246..7d20bb0add 100644 --- a/src/ir/ops/instancenorm.cpp +++ b/src/ir/ops/instancenorm.cpp @@ -23,8 +23,8 @@ instancenorm::instancenorm(datatype_t input_type, shape_t input_shape, float eps : epsilon_(epsilon) { add_input("input", input_type, input_shape); - add_input("scale", input_type, shape_t { input_shape[1], 1, 1 }); - add_input("bias", input_type, shape_t { input_shape[1], 1, 1 }); + add_input("scale", input_type, get_instancenorm_const_shape(input_shape)); + add_input("bias", input_type, get_instancenorm_const_shape(input_shape)); add_output("output", input_type, input_shape); }