Skip to content

Commit

Permalink
Merge pull request #204 from Oneflow-Inc/update_resnet_to_onnx
Browse files Browse the repository at this point in the history
Update resnet to onnx
  • Loading branch information
BBuf authored Jun 23, 2021
2 parents cec8f14 + 163d02d commit e64f95a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions Classification/cnns/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,9 @@ onnx_model_dir = 'onnx/model'

**步骤三:调用 flow.onnx.export 方法**

接下来代码中会调用`oneflow_to_onnx()`方法,此方法包含了核心的模型转换方法: `flow.onnx.export()`
接下来代码中会调用`oneflow_to_onnx()`方法,此方法包含了核心的模型转换方法: `oneflow_onnx.oneflow2onnx.util.export_onnx_model()`,更多OneFlow和ONNX模型转换相关的问题请看: [oneflow_convert_tools介绍](https://docs.oneflow.org/extended_topics/oneflow_convert_tools.html)

**flow.onnx.export** 将从 OneFlow 网络得到 ONNX 模型,它的第一个参数是上文所说的专用于推理的 job function,第二个参数是OneFlow模型路径,第三个参数是(转换后)ONNX模型的存放路径
**oneflow_to_onnx** 将从 OneFlow 网络得到 ONNX 模型,它的第一个参数是上文所说的专用于推理的 job function,第二个参数是OneFlow模型路径,第三个参数是(转换后)ONNX模型的存放路径

```python
onnx_model = oneflow_to_onnx(InferenceNet, flow_weights_path, onnx_model_dir, external_data=False)
Expand Down
14 changes: 6 additions & 8 deletions Classification/cnns/resnet_to_onnx.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -29,6 +26,7 @@
from resnet_model import resnet50
import config as configs
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from oneflow_onnx.oneflow2onnx.util import export_onnx_model

parser = configs.get_parser()
args = parser.parse_args()
Expand Down Expand Up @@ -92,12 +90,12 @@ def oneflow_to_onnx(
assert os.path.exists(flow_weights_path) and os.path.isdir(onnx_model_dir)

onnx_model_path = os.path.join(
onnx_model_dir, os.path.basename(flow_weights_path) + ".onnx"
onnx_model_dir, "model.onnx"
)
flow.onnx.export(
export_onnx_model(
job_func,
flow_weights_path,
onnx_model_path,
flow_weight_dir=flow_weights_path,
onnx_model_path=onnx_model_dir,
opset=11,
external_data=external_data,
)
Expand Down Expand Up @@ -132,4 +130,4 @@ def check_equality(
are_equal, onnx_res = check_equality(InferenceNet, onnx_model, image_path)
clsidx_onnx = onnx_res.argmax()
print("Are the results equal? {}".format("Yes" if are_equal else "No"))
print("Class: {}; score: {}".format(clsidx_2_labels[clsidx_onnx], onnx_res.max()))
print("Class: {}; score: {}".format(clsidx_2_labels[clsidx_onnx], onnx_res.max()))

0 comments on commit e64f95a

Please sign in to comment.