forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_onnxscript_no_runtime.py
163 lines (135 loc) · 5.94 KB
/
test_onnxscript_no_runtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Owner(s): ["module: onnx"]
"""Test the support on onnxscript in PyTorch-ONNX converter."""
import io
from typing import List
import onnx
import onnxscript
from onnxscript.onnx_types import FLOAT
import torch
from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils
class TestONNXScriptExport(common_utils.TestCase):
    """Export models whose ATen ops are lowered to onnxscript functions.

    Each test registers onnxscript functions as custom symbolics for ATen ops,
    exports a model with ``torch.onnx.export`` and checks that the functions
    are emitted as local functions in the exported ``ModelProto``. The exported
    models are only inspected structurally; they are never executed.
    """

    # The opset version is pinned to 15 because:
    # 1. ONNX local functions are supported starting with opset 15.
    # 2. onnxscript requires users to choose the opset used inside a local
    #    function explicitly (hence the ``opset15`` imports in the tests).
    opset_version = 15

    def _register_symbolic(self, symbolic_name, symbolic_fn):
        """Register ``symbolic_fn`` for ``symbolic_name`` for this test only.

        Custom symbolic registrations are global to the process, so the
        override is removed again on test cleanup; otherwise it would leak
        into every test that runs afterwards in the same process.
        """
        torch.onnx.register_custom_op_symbolic(
            symbolic_name=symbolic_name,
            symbolic_fn=symbolic_fn,
            opset_version=self.opset_version,
        )
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            symbolic_name,
            self.opset_version,
        )

    def test_onnxscript_registration_with_multiple_models(self):
        from onnxscript.onnx_opset import opset15 as op

        # 1. Register Selu onnxscript function as custom Op
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(X):
            # default value is not supported by onnxscript
            alpha = 1.67326  # auto wrapped as Constants
            gamma = 1.0507
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # setType propagates the input's type to the custom Op's output.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_symbolic("aten::selu", custom_selu)

        # 2. Register layer_norm onnxscript function as custom Op
        @onnxscript.script(custom_opset)
        def layer_norm(
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
        ):
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # Type issue if missing this Op
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g: jit_utils.GraphContext,
            input,
            normalized_shape,
            weight,
            bias,
            eps,
            cudnn_enable,
        ):
            # onnxscript does not support comprehensions, so the reduction axes
            # (the last len(normalized_shape) dims, as negative indices) are
            # computed here and passed to layer_norm as an attribute.
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        self._register_symbolic("aten::layer_norm", custom_layer_norm)

        # 3. Export two models, one per registered custom Op.
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model_selu = torch.nn.SELU()
        selu_onnx = io.BytesIO()
        torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version)

        N, C = 3, 4
        y = torch.randn(N, C)
        model_layer_norm = torch.nn.LayerNorm(C)
        layer_norm_onnx = io.BytesIO()
        torch.onnx.export(
            model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version
        )

        # 4. Each exported model should carry exactly its own onnxscript function.
        selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue()))
        layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue()))
        self.assertEqual(len(selu_proto.functions), 1)
        self.assertEqual(len(layer_norm_proto.functions), 1)
        self.assertEqual(selu_proto.functions[0].name, "Selu")
        self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm")

    def test_loop_registration(self):
        # Control flow exercises _find_onnxscript_op in torch/onnx/utils.py,
        # which recursively walks every node that has a subgraph in the
        # model proto.
        class NestedLoopsModel(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.selu = torch.nn.SELU()

            @torch.jit.script_method
            def forward(self, x):
                y = x
                for i in range(x.size(3)):
                    if i == 0:
                        y = self.selu(x)
                    else:
                        y += i
                return y

        model = NestedLoopsModel()
        inputs = torch.zeros(1, 2, 3, 4)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2)

        @onnxscript.script(custom_opset)
        def Selu(X):
            alpha = 1.6732632423543772848170429916717
            gamma = 1.0507009873554804934193349852946
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # The domain of the Op should be aligned with the onnxscript opset,
            # and the setType call is required for the custom Op to support
            # TorchScript shape/type inference.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_symbolic("aten::selu", custom_selu)

        saved_model = io.BytesIO()
        torch.onnx.export(
            torch.jit.script(model),
            inputs,
            f=saved_model,
            opset_version=self.opset_version,
        )
        # The Selu call sits inside a Loop/If subgraph; its function must still
        # be collected into the top-level model proto.
        loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
        self.assertEqual(len(loop_selu_proto.functions), 1)
        self.assertEqual(loop_selu_proto.functions[0].name, "Selu")