Skip to content

Commit

Permalink
Add Support for MatMulInteger (#2072)
Browse files Browse the repository at this point in the history
* Add Support for MatMulInteger

MatMulInteger was supported in ONNX opset v10 (not checked in proposed change, the error can be addressed on save), this specific type combination is support in TensorFlow, but the node type not identified and handled properly here.

Handles #2071

Signed-off-by: Gregory Morse <[email protected]>

* Update math.py

Signed-off-by: Gregory Morse <[email protected]>

* Update support_status.md

Signed-off-by: Gregory Morse <[email protected]>

* Update test_backend.py

Signed-off-by: Gregory Morse <[email protected]>

Signed-off-by: Gregory Morse <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
  • Loading branch information
GregoryMorse and fatcat-z authored Nov 4, 2022
1 parent 48e9015 commit 2c1db54
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
1 change: 1 addition & 0 deletions support_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
| AvgPool3D | 1 ~ 17 |
| BatchMatMul | 1 ~ 17 |
| BatchMatMulV2 | 1 ~ 17 |
| BatchMatMulV3 | 1 ~ 17 |
| BatchToSpaceND | 1 ~ 17 |
| BiasAdd | 1 ~ 17 |
| BiasAddV1 | 1 ~ 17 |
Expand Down
9 changes: 9 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,15 @@ def func(x, y):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_val}, rtol=1e-5)

@check_tf_min_version("2.6")
def test_matmulinteger(self):
x_val = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
y_val = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
def func(x, y):
x_ = tf.matmul(x, y, output_type=tf.int32)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})

@check_onnxruntime_incompatibility("Sub")
def test_sub(self):
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
Expand Down
18 changes: 15 additions & 3 deletions tf2onnx/onnx_opset/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,13 @@ def version_1(cls, ctx, node, **kwargs):
name=op_name, shapes=shapes, dtypes=dtypes)


@tf_op(["MatMul", "BatchMatMul", "BatchMatMulV2"])
@tf_op(["MatMul", "BatchMatMul", "BatchMatMulV2", "BatchMatMulV3"])
class MatMul:
@classmethod
def version_1(cls, ctx, node, **kwargs):
# tensorflow allows transpose and conjugated. If found, insert the required transpose.
# We could use Gemm as well but tensorflow does not pass bias in matmul.
node.type = "MatMul"
if node.type != "MatMulInteger": node.type = "MatMul"

attrs = ["transpose_a", "transpose_b", "adjoint_a", "adjoint_b", "adj_x", "adj_y"]
attrs_val = [node.get_attr(attr) for attr in attrs]
Expand Down Expand Up @@ -408,7 +408,19 @@ def version_1(cls, ctx, node, **kwargs):
val = node.get_attr(i)
if val is not None and val.i != 0:
raise ValueError(node.type + " attribute " + i + " is not supported")

@classmethod
def version_10(cls, ctx, node, **kwargs):
if (ctx.get_dtype(node.input[0]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
ctx.get_dtype(node.input[1]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
ctx.get_dtype(node.output[0]) == onnx_pb.TensorProto.INT32):
node.type = "MatMulInteger"
zpdata_a = np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0])))
zero_point_node_a = ctx.make_const(utils.make_name("zero_point_a"), zpdata_a)
zpdata_b = np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1])))
zero_point_node_b = ctx.make_const(utils.make_name("zero_point_b"), zpdata_b)
ctx.replace_inputs(node, [node.input[0], node.input[1],
zero_point_node_a.output[0], zero_point_node_b.output[0]])
cls.version_1(ctx, node, **kwargs)

@tf_op("Erf")
class Erf:
Expand Down

0 comments on commit 2c1db54

Please sign in to comment.