Skip to content

Commit

Permalink
Merge branch 'master' into fix/cos-nan
Browse files Browse the repository at this point in the history
  • Loading branch information
HeJunchao100813 authored Oct 24, 2023
2 parents 82d1ce6 + 3a62f98 commit b44c03d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
10 changes: 8 additions & 2 deletions src/Nncase.Importer/Onnx/Reduce.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ namespace Nncase.Importer
{
public partial class OnnxImporter
{
private Expr VisitReduce(in NodeProto op, ReduceOp reduceOp, float initValue)
private Expr VisitReduce(in NodeProto op, ReduceOp reduceOp, Expr initValue)
{
return ReduceCore(op, reduceOp, initValue, expr => expr);
}

private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, float initValue, Func<Expr, Expr> f)
private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, Expr initValue, Func<Expr, Expr> f)
{
var input = GetInputExpr(op, 0);
Expr axis;
Expand Down Expand Up @@ -51,6 +51,12 @@ private Expr ReduceCore(in NodeProto op, ReduceOp reduceOp, float initValue, Fun
var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.Int32 => F.Tensors.Reduce(reduceOp, f(input), axis, int.MinValue, keepDims),
var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Int64 => F.Tensors.Reduce(reduceOp, f(input), axis, long.MaxValue, keepDims),
var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Int32 => F.Tensors.Reduce(reduceOp, f(input), axis, int.MaxValue, keepDims),
var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.Float32 => F.Tensors.Reduce(reduceOp, f(input), axis, float.MinValue, keepDims),
var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.Float16 => F.Tensors.Reduce(reduceOp, f(input), axis, Half.MinValue, keepDims),
var x when x == ReduceOp.Max && input.CheckedDataType == DataTypes.BFloat16 => F.Tensors.Reduce(reduceOp, f(input), axis, BFloat16.RoundToBFloat16(float.MinValue), keepDims),
var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Float32 => F.Tensors.Reduce(reduceOp, f(input), axis, float.MaxValue, keepDims),
var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.Float16 => F.Tensors.Reduce(reduceOp, f(input), axis, Half.MaxValue, keepDims),
var x when x == ReduceOp.Min && input.CheckedDataType == DataTypes.BFloat16 => F.Tensors.Reduce(reduceOp, f(input), axis, BFloat16.RoundToBFloat16(float.MaxValue), keepDims),
_ => F.Tensors.Reduce(reduceOp, f(input), axis, F.Tensors.Cast(initValue, input.CheckedDataType), keepDims),
};
}
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Importer/Onnx/ReduceWindow2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Nncase.Importer
public partial class OnnxImporter
{
// isGlobal used for GlobalXXXPool
private Expr VisitReduceWindow2D(in NodeProto op, ReduceOp reduceOp, float initValue, bool isGlobal = false)
private Expr VisitReduceWindow2D(in NodeProto op, ReduceOp reduceOp, Expr initValue, bool isGlobal = false)
{
// auto_pad had been DEPRECATED
var input = GetInputExpr(op, 0);
Expand Down
16 changes: 11 additions & 5 deletions tests/importer/onnx_/basic/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@
import numpy as np


def _make_module(in_shape, reduce_op, axes, keepdims, op_version):
def _make_module(in_shape, in_datatype, reduce_op, axes, keepdims, op_version):
inputs = []
outputs = []
initializers = []
attributes_dict = {}
nodes = []

# input
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, in_shape)
input = helper.make_tensor_value_info('input', in_datatype, in_shape)
inputs.append('input')

# output
kd = 1 if keepdims is None else keepdims
data = np.ones(in_shape)
out_shape = np.prod(data, axis=tuple(axes), keepdims=kd).shape
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, out_shape)
output = helper.make_tensor_value_info('output', in_datatype, out_shape)
outputs.append('output')

# axes
Expand Down Expand Up @@ -73,6 +73,11 @@ def _make_module(in_shape, reduce_op, axes, keepdims, op_version):
[1, 3, 16, 16]
]

in_datatypes = [
TensorProto.FLOAT,
TensorProto.FLOAT16
]

reduce_ops = [
'ReduceMax',
'ReduceMean',
Expand Down Expand Up @@ -108,13 +113,14 @@ def _make_module(in_shape, reduce_op, axes, keepdims, op_version):


@pytest.mark.parametrize('in_shape', in_shapes)
@pytest.mark.parametrize('in_datatype', in_datatypes)
@pytest.mark.parametrize('reduce_op', reduce_ops)
@pytest.mark.parametrize('axes', axes_list)
@pytest.mark.parametrize('keepdims', keepdims_lists)
@pytest.mark.parametrize('op_version', op_version_lists)
def test_reduce(in_shape, reduce_op, axes, keepdims, request, op_version):
def test_reduce(in_shape, in_datatype, reduce_op, axes, keepdims, request, op_version):
if len(axes) <= len(in_shape):
model_def = _make_module(in_shape, reduce_op, axes, keepdims, op_version)
model_def = _make_module(in_shape, in_datatype, reduce_op, axes, keepdims, op_version)

runner = OnnxTestRunner(request.node.name)
model_file = runner.from_onnx_helper(model_def)
Expand Down

0 comments on commit b44c03d

Please sign in to comment.