Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce ONNX Checker and Shape Inference Whenever Needed for Examples #2619

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
# Therefore, you should only need to sort the graph when you have added new nodes out-of-order.
# In this case, the identity node is already in the correct spot (it is the last node,
# and was appended to the end of the list), but to be on the safer side, we can sort anyway.
graph.cleanup().toposort()
graph.cleanup(remove_unused_graph_inputs=True).toposort()

onnx.save(gs.export_onnx(graph), "modified.onnx")
model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
onnx.save(model, "modified.onnx")
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@
]

graph = gs.Graph(nodes=nodes, inputs=[x], outputs=[y])
onnx.save(gs.export_onnx(graph), "model.onnx")

model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
onnx.save(model, "model.onnx")
4 changes: 3 additions & 1 deletion tools/onnx-graphsurgeon/examples/06_removing_nodes/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@

# Remove the fake node from the graph completely
graph.cleanup()
onnx.save(gs.export_onnx(graph), "removed.onnx")

model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
onnx.save(model, "removed.onnx")
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,5 @@ def relu(self, a):
for out in graph.outputs:
out.dtype = np.float32

onnx.save(gs.export_onnx(graph), "model.onnx")
model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
onnx.save(model, "model.onnx")
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ def concat(self, inputs, axis=0):
# Finally, set up the outputs and export.
flattened.name = "flattened" # Rename output tensor to make it easy to find.
flattened.dtype = np.float32 # NOTE: We must include dtype information for graph outputs
flattened.shape = (gs.Tensor.DYNAMIC,)
partially_flattened.name = "partially_flattened"
partially_flattened.dtype = np.float32
partially_flattened.shape = (gs.Tensor.DYNAMIC, 3, gs.Tensor.DYNAMIC)

graph.outputs = [flattened, partially_flattened]
onnx.save(gs.export_onnx(graph), "model.onnx")

model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
onnx.save(model, "model.onnx")
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
##########################################################################################################
# Register functions to simplify the graph building process later on.


@gs.Graph.register()
def conv(self, inp, weights, dilations, group, strides):
out = self.layer(
Expand All @@ -49,6 +50,7 @@ def matmul(self, lhs, rhs):
out.dtype = lhs.dtype
return out


##########################################################################################################


Expand All @@ -67,4 +69,5 @@ def matmul(self, lhs, rhs):
graph.outputs = [matmul_out]

# Save graph
onnx.save(gs.export_onnx(graph), "model.onnx")
model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
onnx.save(model, "model.onnx")
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# Update input shape
for input in graph.inputs:
input.shape[0] = 'N'
input.shape[0] = "N"

# Update 'Reshape' nodes (if they exist)
reshape_nodes = [node for node in graph.nodes if node.op == "Reshape"]
Expand Down
2 changes: 2 additions & 0 deletions tools/onnx-graphsurgeon/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def ignore_command(cmd):

def infer_model(path):
model = onnx.load(path)
onnx.checker.check_model(model)

graph = gs.import_onnx(model)

feed_dict = {}
Expand Down