Skip to content

Commit

Permalink
add instancenorm
Browse files Browse the repository at this point in the history
  • Loading branch information
curioyang committed May 10, 2023
1 parent b57f35b commit 81563a6
Show file tree
Hide file tree
Showing 31 changed files with 686 additions and 16 deletions.
16 changes: 15 additions & 1 deletion include/nncase/codegen/stackvm/op_writer.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:54 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/5/9 下午5:18:43 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1543,6 +1543,19 @@ struct op_writer<nncase::runtime::stackvm::tensor_gather_elements_op_t>
}
};

// Serializer for the tensor instance-normalization instruction.
// The fields are written in a fixed order: opcode (as uint8_t), funct (as uint16_t),
// datatype (as uint8_t), input_shape, epsilon. The matching op_reader specialization
// reads them back in exactly this order, so the sequence must not be rearranged.
template <>
struct op_writer<nncase::runtime::stackvm::tensor_instance_normalization_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_instance_normalization_op_t &op, binary_writer &writer) const
{
// Instruction header: opcode and tensor function id.
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
// Operands, written with their declared member types.
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.input_shape);
writer.write(op.epsilon);
}
};

class NNCASE_API op_builder
{
public:
Expand Down Expand Up @@ -1684,6 +1697,7 @@ class NNCASE_API op_builder
void tensor_layer_normalization_(datatype_t datatype, uint8_t input_shape, int32_t axis, float epsilon);
void tensor_compress_(uint8_t input_shape_src, uint8_t condition_shape_src, float axis);
void tensor_gather_elements_(uint8_t input_shape_src, uint8_t indices_shape_src, int32_t axis);
// Emits a tensor INSTANCE_NORMALIZATION instruction carrying the given datatype, input_shape operand and epsilon.
// NOTE(review): input_shape looks like a shape-register index rather than a shape itself — confirm against the caller.
void tensor_instance_normalization_(datatype_t datatype, uint8_t input_shape, float epsilon);

private:
section_writer &writer_;
Expand Down
3 changes: 2 additions & 1 deletion include/nncase/ir/opcode.def
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ DEFINE_NEUTRAL_OPCODE(roi_align, RoiAlign, 0x126)
DEFINE_NEUTRAL_OPCODE(compare, Compare, 0x127)
DEFINE_NEUTRAL_OPCODE(softmax, Softmax, 0x128)
DEFINE_NEUTRAL_OPCODE(gru, GRU, 0x129)
DEFINE_NEUTRAL_OPCODE(tflite_detection_postprocess, TfliteDetectionPostprocess, 0x12A)
DEFINE_NEUTRAL_OPCODE(tflite_detection_postprocess, TfliteDetectionPostprocess, 0x12A)
DEFINE_NEUTRAL_OPCODE(layernorm, LayerNormalization, 0x12B)
DEFINE_NEUTRAL_OPCODE(compress, Compress, 0x12C)
DEFINE_NEUTRAL_OPCODE(gather_elements, GatherElements, 0x12D)
DEFINE_NEUTRAL_OPCODE(instancenorm, InstanceNormalization, 0x12E)
39 changes: 39 additions & 0 deletions include/nncase/ir/ops/instancenorm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include "nncase/ir/connectors.h"

namespace nncase::ir
{
// IR node for the instance-normalization operator.
//
// Connector layout, as exposed by the accessors below:
//   input 0  -> input
//   input 1  -> scale
//   input 2  -> bias
//   output 0 -> output
class NNCASE_API instancenorm : public node
{
public:
    DEFINE_NODE_OPCODE(op_instancenorm);

    // Builds the node for the given element type, input shape and epsilon.
    // The connectors themselves are created in the out-of-line definition.
    instancenorm(datatype_t input_type, shape_t input_shape, float epsilon);

    // Named views over the positional connectors.
    input_connector &input() { return input_at(0); }
    input_connector &scale() { return input_at(1); }
    input_connector &bias() { return input_at(2); }
    output_connector &output() { return output_at(0); }

    // Epsilon value this node was constructed with.
    float epsilon() const noexcept { return epsilon_; }

protected:
    // Compares node-specific properties with another node.
    bool properties_equal(node &other) const override;

private:
    float epsilon_;
};
}
3 changes: 3 additions & 0 deletions include/nncase/kernels/cpu/optimized/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ template <typename T>
NNCASE_API result<void> sigmoid(const T *input, T *output, const runtime_shape_t &in_shape, const runtime_shape_t &in_strides,
const runtime_shape_t &out_strides) noexcept;

// Instance-normalization kernel declaration for element type T.
// Takes input and output buffers, scale and bias buffers, the input shape and an epsilon,
// and reports the outcome through result<void>.
// NOTE(review): the normalization formula and the expected scale/bias layout are not
// visible in this header — confirm against the definition.
template <typename T>
NNCASE_API result<void> instancenorm(const T *input, T *output, T *scale, T *bias, const runtime_shape_t &in_shape, float epsilon) noexcept;

template <typename T>
NNCASE_API result<void> layernorm(const T *input, T *output, T *scale, T *bias, const runtime_shape_t &in_shape, int32_t axis, float epsilon) noexcept;

Expand Down
5 changes: 5 additions & 0 deletions include/nncase/kernels/cpu/reference/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ NNCASE_API result<void>
gather_elements(const TI *input, const TK *indices, TI *output, const runtime_shape_t &in_shape,
const runtime_shape_t &indices_shape, const int axis) noexcept;

// Instance-normalization kernel declaration for element type T.
// Takes input and output buffers, scale and bias buffers, the input shape and an epsilon,
// and reports the outcome through result<void>.
// NOTE(review): the normalization formula and the expected scale/bias layout are not
// visible in this header — confirm against the definition.
template <typename T>
NNCASE_API result<void>
instancenorm(const T *input, T *output, T *scale, T *bias, const runtime_shape_t &in_shape,
float epsilon) noexcept;

template <typename T>
NNCASE_API result<void>
layernorm(const T *input, T *output, T *scale, T *bias, const runtime_shape_t &in_shape, int32_t axis,
Expand Down
3 changes: 3 additions & 0 deletions include/nncase/kernels/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ template <typename TI, typename TK>
NNCASE_API result<void> gather_elements(const TI *input, const TK *indices, TI *output, const runtime_shape_t &in_shape,
const runtime_shape_t &indices_shape, const int axis) noexcept;

// Instance-normalization kernel declaration for element type T.
// Takes input and output buffers, scale and bias buffers, the input shape and an epsilon,
// and reports the outcome through result<void>.
// NOTE(review): presumably this is the front-end that forwards to a backend-specific
// kernel — confirm against the definition.
template <typename T>
NNCASE_API result<void> instancenorm(const T *input, T *output, T *scale, T *bias, const runtime_shape_t &in_shape, float epsilon) noexcept;

template <typename T>
NNCASE_API result<void> layernorm(const T *input, T *output, T *scale, T *bias, const runtime_shape_t &in_shape, int32_t axis, float epsilon) noexcept;

Expand Down
18 changes: 17 additions & 1 deletion include/nncase/runtime/stackvm/op_reader.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:53 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/5/9 下午5:18:43 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1813,6 +1813,21 @@ struct op_reader<tensor_gather_elements_op_t>
}
};

// Deserializer for the tensor instance-normalization instruction.
// Fields are read unaligned, in this order: opcode (uint8_t), funct (uint16_t),
// datatype (uint8_t), input_shape (uint8_t), epsilon (float). The order is the
// instruction's encoding and must stay in step with the corresponding op_writer.
template <>
struct op_reader<tensor_instance_normalization_op_t>
{
tensor_instance_normalization_op_t operator()(span_reader &reader) const
{
// Start from the default_init constructor, which sets no fields; every member is filled below.
tensor_instance_normalization_op_t op(default_init);
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
op.datatype = static_cast<datatype_t>(reader.read_unaligned<uint8_t>());
op.input_shape = reader.read_unaligned<uint8_t>();
op.epsilon = reader.read_unaligned<float>();
return op;
}
};

class NNCASE_API op_visitor
{
public:
Expand Down Expand Up @@ -1959,6 +1974,7 @@ class NNCASE_API op_visitor
virtual result<void> visit(NNCASE_UNUSED const tensor_layer_normalization_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_compress_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_gather_elements_op_t &op) noexcept { return ok(); }
// Default handler for instance-normalization instructions: ignores the op and returns ok().
virtual result<void> visit(NNCASE_UNUSED const tensor_instance_normalization_op_t &op) noexcept { return ok(); }

protected:
bool interrupted_;
Expand Down
18 changes: 17 additions & 1 deletion include/nncase/runtime/stackvm/opcode.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:53 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/5/9 下午5:18:43 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -166,6 +166,7 @@ enum class tensor_function_t
LAYER_NORMALIZATION = 0x0029,
COMPRESS = 0x002A,
GATHER_ELEMENTS = 0x002B,
INSTANCE_NORMALIZATION = 0x002C,
};

// Instructions
Expand Down Expand Up @@ -1958,4 +1959,19 @@ struct tensor_gather_elements_op_t
}
};

// In-memory form of the tensor instance-normalization instruction.
struct tensor_instance_normalization_op_t
{
// Instruction header.
opcode_t opcode;
tensor_function_t funct;
// Operands.
datatype_t datatype;
// NOTE(review): presumably the index of a shape register loaded beforehand — confirm against the emitter.
uint8_t input_shape;
float epsilon;

// Leaves every field unset; intended for callers that fill the fields themselves afterwards.
tensor_instance_normalization_op_t(default_init_t) noexcept { }
// Builds a complete instruction: opcode is TENSOR and funct is INSTANCE_NORMALIZATION.
explicit tensor_instance_normalization_op_t(datatype_t datatype, uint8_t input_shape, float epsilon) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::INSTANCE_NORMALIZATION), datatype(datatype), input_shape(input_shape), epsilon(epsilon)
{
}
};

END_NS_NNCASE_RT_MODULE
29 changes: 29 additions & 0 deletions include/nncase/transforms/neutral/fold_instancenorm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../transform.h"

namespace nncase::ir::transforms
{
// Graph transform for instancenorm folding.
// NOTE(review): judging by the name, this rewrites a matched subgraph into a single
// instancenorm node; the matching rules live in the out-of-line definitions — confirm there.
class NNCASE_API fold_instancenorm_transform : public transform
{
public:
    // Applies the transform to the match held in `context`.
    void process(transform_context &context) override;

protected:
    // Decides whether `node` starts a match for this transform.
    bool on_try_match(ir::node &node, transform_context &context) override;

    // This transform always opts out of the self-contained check.
    bool skip_self_contained_check() const noexcept override { return true; }
};
}
3 changes: 2 additions & 1 deletion src/codegen/stackvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ set(SRCS module_builder.cpp
ops/trilu.cpp
ops/tflite_detection_postprocess.cpp
ops/unary.cpp
ops/layernorm.cpp)
ops/layernorm.cpp
ops/instancenorm.cpp)

add_library(codegen_stackvm OBJECT ${SRCS})
target_link_libraries(codegen_stackvm PUBLIC ir schedule)
Expand Down
1 change: 1 addition & 0 deletions src/codegen/stackvm/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <nncase/ir/ops/gather_nd.h>
#include <nncase/ir/ops/gru.h>
#include <nncase/ir/ops/hardmax.h>
#include <nncase/ir/ops/instancenorm.h>
#include <nncase/ir/ops/layernorm.h>
#include <nncase/ir/ops/matmul.h>
#include <nncase/ir/ops/onehot.h>
Expand Down
7 changes: 6 additions & 1 deletion src/codegen/stackvm/op_writer.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:54 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/5/9 下午5:18:43 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -697,3 +697,8 @@ void op_builder::tensor_gather_elements_(uint8_t input_shape_src, uint8_t indice
{
op_writer<tensor_gather_elements_op_t>()(tensor_gather_elements_op_t(input_shape_src, indices_shape_src, axis), writer_);
}

// Builds an instance-normalization instruction from the arguments and serializes it
// into this builder's writer.
void op_builder::tensor_instance_normalization_(datatype_t datatype, uint8_t input_shape, float epsilon)
{
    const tensor_instance_normalization_op_t instruction(datatype, input_shape, epsilon);
    op_writer<tensor_instance_normalization_op_t>()(instruction, writer_);
}
3 changes: 2 additions & 1 deletion src/codegen/stackvm/ops.def
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ DEFINE_OP(transpose)
DEFINE_OP(trilu)
DEFINE_OP(tflite_detection_postprocess)
DEFINE_OP(unary)
DEFINE_OP(layernorm)
DEFINE_OP(layernorm)
DEFINE_OP(instancenorm)
37 changes: 37 additions & 0 deletions src/codegen/stackvm/ops/instancenorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../module_builder.h"

using namespace nncase;
using namespace nncase::codegen;
using namespace nncase::codegen::stackvm;
using namespace nncase::ir;

// Emits the stackvm instruction sequence for an instancenorm node:
// the four buffer addresses (input, scale, bias, output, in that order),
// then the input shape, then the instance-normalization instruction itself.
void stackvm_module_builder::emit(instancenorm &node, stackvm_op_builder &builder)
{
    // Shape slot shared by stshape and the instruction's input_shape operand.
    constexpr uint8_t input_shape_slot = 0;

    auto &in_alloc = allocation(node.input());
    auto &scale_alloc = allocation(node.scale());
    auto &bias_alloc = allocation(node.bias());
    auto &out_alloc = allocation(node.output());

    // Order matters: the operands are emitted as input, scale, bias, output.
    builder.lea_buffer(in_alloc);
    builder.lea_buffer(scale_alloc);
    builder.lea_buffer(bias_alloc);
    builder.lea_buffer(out_alloc);

    builder.stshape(input_shape_slot, in_alloc.shape);

    builder.tensor_instance_normalization_(node.output().type(), input_shape_slot, node.epsilon());
}
22 changes: 22 additions & 0 deletions src/evaluator/ops/neutral/neutral_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <nncase/ir/ops/gather_nd.h>
#include <nncase/ir/ops/gru.h>
#include <nncase/ir/ops/hardmax.h>
#include <nncase/ir/ops/instancenorm.h>
#include <nncase/ir/ops/layernorm.h>
#include <nncase/ir/ops/matmul.h>
#include <nncase/ir/ops/onehot.h>
Expand Down Expand Up @@ -886,6 +887,27 @@ void register_neutral_evaluators()
}
});

// Evaluator for instancenorm nodes: runs the instancenorm kernel on the node's
// input, scale and bias buffers and stores the result in its output buffer.
// Only dt_float32 outputs are handled; any other type is reported on stderr and skipped.
register_evaluator(op_instancenorm, [](ir::node &node, function_evaluate_context &context) {
    auto &rnode = static_cast<instancenorm &>(node);

    auto input = context.memory_at(rnode.input());
    auto scale = context.memory_at(rnode.scale());
    auto bias = context.memory_at(rnode.bias());
    auto output = context.memory_at(rnode.output());

    auto output_type = rnode.output().type();
    switch (output_type)
    {
    case dt_float32:
        kernels::instancenorm(input.buffer().as_span<float>().data(), output.buffer().as_span<float>().data(),
            scale.buffer().as_span<float>().data(), bias.buffer().as_span<float>().data(), input.shape(),
            rnode.epsilon())
            .unwrap_or_throw();
        break;
    default:
        // The message names this operator (it previously said "layernorm").
        std::cerr << "unsupported dtype for instancenorm: " + std::string(datatype_names(output_type));
    }
});

register_evaluator(op_layernorm, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<layernorm &>(node);

Expand Down
5 changes: 3 additions & 2 deletions src/ir/ops/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cmake_minimum_required (VERSION 3.13)
cmake_minimum_required(VERSION 3.13)

target_sources(ir PRIVATE
call.cpp
Expand Down Expand Up @@ -51,4 +51,5 @@ target_sources(ir PRIVATE
gather_elements.cpp
layernorm.cpp
compress.cpp
)
instancenorm.cpp
)
35 changes: 35 additions & 0 deletions src/ir/ops/instancenorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <nncase/ir/op_utils.h>
#include <nncase/ir/ops/instancenorm.h>
#include <xtensor/xarray.hpp>

using namespace nncase;
using namespace nncase::ir;

// Creates the node's connectors: "input", "scale" and "bias" inputs and one "output".
// All connectors use input_type; input and output share input_shape, while scale and
// bias both get the shape { input_shape[1], 1, 1 }.
// NOTE(review): input_shape[1] is read without a rank check — callers presumably always
// pass a shape with at least two dimensions; confirm.
instancenorm::instancenorm(datatype_t input_type, shape_t input_shape, float epsilon)
    : epsilon_(epsilon)
{
    // Shared by the scale and bias connectors.
    const shape_t param_shape { input_shape[1], 1, 1 };

    add_input("input", input_type, input_shape);
    add_input("scale", input_type, param_shape);
    add_input("bias", input_type, param_shape);
    add_output("output", input_type, input_shape);
}

// Two instancenorm nodes have equal properties when their epsilon values compare
// exactly equal. `other` is cast to instancenorm without a runtime type check.
bool instancenorm::properties_equal(node &other) const
{
    const auto &rhs = static_cast<const instancenorm &>(other);
    return epsilon() == rhs.epsilon();
}
3 changes: 2 additions & 1 deletion src/kernels/cpu/optimized/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ set(SRCS convolution.cpp
${ARCH}/softmax.cpp
${ARCH}/layernorm.cpp
${ARCH}/ternary.cpp
${ARCH}/reduce.cpp)
${ARCH}/reduce.cpp
${ARCH}/instancenorm.cpp)
target_sources(kernels PRIVATE ${SRCS})
Loading

0 comments on commit 81563a6

Please sign in to comment.