diff --git a/thop/onnx_profile.py b/thop/onnx_profile.py index 10da68c..4109046 100644 --- a/thop/onnx_profile.py +++ b/thop/onnx_profile.py @@ -66,6 +66,8 @@ def calculate_macs(self, model: onnx.ModelProto) -> torch.DoubleTensor: input = model.graph.input output = model.graph.output name2dims = self.create_dict(weight, input, output) + for v in model.graph.value_info: + name2dims[v.name] = np.array([i.dim_value for i in v.type.tensor_type.shape.dim]) macs = 0 for n in nodes: macs_adding, out_size, outname = self.nodes_counter(name2dims, n)