From 8fe226eab5f95bcff3b30ee2a93d8c36aec6be26 Mon Sep 17 00:00:00 2001 From: Lei Mao Date: Mon, 30 Jan 2023 20:28:26 -0800 Subject: [PATCH] Rebased and Updated the MR Signed-off-by: Lei Mao --- .../examples/04_modifying_a_model/modify.py | 5 +++-- .../examples/06_removing_nodes/generate.py | 4 +++- .../onnx-graphsurgeon/examples/06_removing_nodes/remove.py | 4 +++- .../07_creating_a_model_with_the_layer_api/generate.py | 3 ++- .../09_shape_operations_with_the_layer_api/generate.py | 6 +++++- .../examples/10_dynamic_batch_size/generate.py | 5 ++++- .../examples/10_dynamic_batch_size/modify.py | 2 +- tools/onnx-graphsurgeon/tests/test_examples.py | 2 ++ 8 files changed, 23 insertions(+), 8 deletions(-) diff --git a/tools/onnx-graphsurgeon/examples/04_modifying_a_model/modify.py b/tools/onnx-graphsurgeon/examples/04_modifying_a_model/modify.py index 5dba5271..4e1fbae3 100644 --- a/tools/onnx-graphsurgeon/examples/04_modifying_a_model/modify.py +++ b/tools/onnx-graphsurgeon/examples/04_modifying_a_model/modify.py @@ -43,6 +43,7 @@ # Therefore, you should only need to sort the graph when you have added new nodes out-of-order. # In this case, the identity node is already in the correct spot (it is the last node, # and was appended to the end of the list), but to be on the safer side, we can sort anyway. -graph.cleanup().toposort() +graph.cleanup(remove_unused_graph_inputs=True).toposort() -onnx.save(gs.export_onnx(graph), "modified.onnx") +model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) +onnx.save(model, "modified.onnx") diff --git a/tools/onnx-graphsurgeon/examples/06_removing_nodes/generate.py b/tools/onnx-graphsurgeon/examples/06_removing_nodes/generate.py index c1bd733a..116a2ffb 100644 --- a/tools/onnx-graphsurgeon/examples/06_removing_nodes/generate.py +++ b/tools/onnx-graphsurgeon/examples/06_removing_nodes/generate.py @@ -37,4 +37,6 @@ ] graph = gs.Graph(nodes=nodes, inputs=[x], outputs=[y]) -onnx.save(gs.export_onnx(graph), "model.onnx") + +model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) +onnx.save(model, "model.onnx") diff --git a/tools/onnx-graphsurgeon/examples/06_removing_nodes/remove.py b/tools/onnx-graphsurgeon/examples/06_removing_nodes/remove.py index c4823fe0..fde66ea0 100644 --- a/tools/onnx-graphsurgeon/examples/06_removing_nodes/remove.py +++ b/tools/onnx-graphsurgeon/examples/06_removing_nodes/remove.py @@ -37,4 +37,6 @@ # Remove the fake node from the graph completely graph.cleanup() -onnx.save(gs.export_onnx(graph), "removed.onnx") + +model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) +onnx.save(model, "removed.onnx") diff --git a/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py b/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py index 627d6682..04f2f65c 100644 --- a/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py +++ b/tools/onnx-graphsurgeon/examples/07_creating_a_model_with_the_layer_api/generate.py @@ -95,4 +95,5 @@ def relu(self, a): for out in graph.outputs: out.dtype = np.float32 -onnx.save(gs.export_onnx(graph), "model.onnx") +model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) +onnx.save(model, "model.onnx") diff --git a/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py b/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py index 33cbba4c..b8247045 100644 --- a/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py +++ b/tools/onnx-graphsurgeon/examples/09_shape_operations_with_the_layer_api/generate.py @@ -74,8 +74,12 @@ def concat(self, inputs, axis=0): # Finally, set up the outputs and export. flattened.name = "flattened" # Rename output tensor to make it easy to find. flattened.dtype = np.float32 # NOTE: We must include dtype information for graph outputs +flattened.shape = (gs.Tensor.DYNAMIC,) partially_flattened.name = "partially_flattened" partially_flattened.dtype = np.float32 +partially_flattened.shape = (gs.Tensor.DYNAMIC, 3, gs.Tensor.DYNAMIC) graph.outputs = [flattened, partially_flattened] -onnx.save(gs.export_onnx(graph), "model.onnx") + +model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) +onnx.save(model, "model.onnx") diff --git a/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/generate.py b/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/generate.py index 679af4a2..718884f4 100644 --- a/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/generate.py +++ b/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/generate.py @@ -24,6 +24,7 @@ ########################################################################################################## # Register functions to simplify the graph building process later on. + @gs.Graph.register() def conv(self, inp, weights, dilations, group, strides): out = self.layer( @@ -49,6 +50,7 @@ def matmul(self, lhs, rhs): out.dtype = lhs.dtype return out + ########################################################################################################## @@ -67,4 +69,5 @@ def matmul(self, lhs, rhs): graph.outputs = [matmul_out] # Save graph -onnx.save(gs.export_onnx(graph), "model.onnx") +model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) +onnx.save(model, "model.onnx") diff --git a/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/modify.py b/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/modify.py index b6bb64b5..5f73d64a 100644 --- a/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/modify.py +++ b/tools/onnx-graphsurgeon/examples/10_dynamic_batch_size/modify.py @@ -24,7 +24,7 @@ # Update input shape for input in graph.inputs: - input.shape[0] = 'N' + input.shape[0] = "N" # Update 'Reshape' nodes (if they exist) reshape_nodes = [node for node in graph.nodes if node.op == "Reshape"] diff --git a/tools/onnx-graphsurgeon/tests/test_examples.py b/tools/onnx-graphsurgeon/tests/test_examples.py index a89266a8..9270b35d 100644 --- a/tools/onnx-graphsurgeon/tests/test_examples.py +++ b/tools/onnx-graphsurgeon/tests/test_examples.py @@ -72,6 +72,8 @@ def ignore_command(cmd): def infer_model(path): model = onnx.load(path) + onnx.checker.check_model(model) + graph = gs.import_onnx(model) feed_dict = {}