fix #119 fix #114 fix #112 #122

Open · wants to merge 1 commit into base: master
README.md (3 additions & 2 deletions)
@@ -11,6 +11,7 @@ OR
## How to use
* Basic usage
```python
+import torch
from torchvision.models import resnet50
from thop import profile
model = resnet50()
@@ -24,7 +25,7 @@ OR
# your definition
def count_your_model(model, x, y):
    # your rule here

+# Note that your rule should count only this module's own ops and params, excluding its submodules' ops and params
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ),
                       custom_ops={YourModule: count_your_model})
@@ -92,4 +93,4 @@ inception_v3 | 27.16 | 5.75

</td>
</tr>
-</p>
+</p>
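
The note added to the README is worth illustrating: a custom rule should record only the work the module itself does, because profile() visits submodules separately and sums everything up afterwards. Below is a minimal sketch of such a rule for a hypothetical MyLinear module; the module name and its MAC formula are assumptions for illustration, while the profile(..., custom_ops=...) call follows the README usage above.

```python
import torch
import torch.nn as nn
from thop import profile

class MyLinear(nn.Module):
    """Hypothetical custom layer: a bare matmul with no submodules."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return x @ self.weight.t()

def count_my_linear(m, x, y):
    # Count only this module's own MACs: one multiply-accumulate per
    # output element per input feature. Any submodules would be counted
    # by their own hooks, so they are deliberately not counted here.
    m.total_ops += torch.DoubleTensor([int(y.numel() * m.weight.shape[1])])

model = MyLinear(128, 64)
input = torch.randn(1, 128)
macs, params = profile(model, inputs=(input,),
                       custom_ops={MyLinear: count_my_linear})
```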
thop/profile.py (11 additions & 2 deletions)
@@ -191,7 +191,16 @@ def add_hooks(m: nn.Module):
model(*inputs)

def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
-    total_ops, total_params = 0, 0
+    """
+    Accumulate ops and params through a depth-first search.
+    Each module's totals consist of two parts:
+    1) the ops and params of its submodules;
+    2) the module's own ops and params, i.e. everything except 1).
+    :param module: the module to count
+    :param prefix: indentation prefix used for debug printing
+    :return: total_ops, total_params
+    """
+    total_ops, total_params = module.total_ops, module.total_params
    for m in module.children():
        # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"):  # and len(list(m.children())) > 0:
        #     m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
@@ -204,7 +213,7 @@ def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
        total_ops += m_ops
        total_params += m_params
    # print(prefix, module._get_name(), (total_ops.item(), total_params.item()))
-    return total_ops, total_params
+    return total_ops.item(), total_params.item()

total_ops, total_params = dfs_count(model)

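
The change to dfs_count is the core of the fix: totals now start from the module's own counters instead of zero, so a parent's own ops and params are no longer dropped. A standalone sketch of the same traversal (not thop's exact code), assuming every module already carries its own non-recursive total_ops/total_params buffers as set up by profile()'s hooks:

```python
import torch
import torch.nn as nn

def dfs_count_sketch(module: nn.Module):
    # Start from this module's own (non-recursive) counters; clone so the
    # in-place additions below do not mutate the module's buffers.
    total_ops = module.total_ops.clone()
    total_params = module.total_params.clone()
    for child in module.children():
        # Add each direct child's subtree totals; grandchildren are
        # covered by the recursion.
        child_ops, child_params = dfs_count_sketch(child)
        total_ops += child_ops
        total_params += child_params
    return total_ops, total_params
```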
thop/vision/basic_hooks.py (3 additions & 2 deletions)
@@ -10,9 +10,10 @@

def count_parameters(m, x, y):
    total_params = 0
-    for p in m.parameters():
+    for p in m.parameters(recurse=False):
        total_params += torch.DoubleTensor([p.numel()])
-    m.total_params[0] = total_params
+    if type(total_params) != int:
+        m.total_params += total_params


def zero_ops(m, x, y):
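
recurse=False is what keeps count_parameters consistent with the DFS above: each hook records only the parameters the module owns directly, and dfs_count sums them up the tree. With the default recurse=True, a parent would re-count every descendant's parameters. A quick check of the difference:

```python
import torch.nn as nn

parent = nn.Sequential(nn.Linear(4, 2))

# Parameters owned directly by the container: none.
own = sum(p.numel() for p in parent.parameters(recurse=False))
# Parameters including all submodules: the Linear's weight (8) + bias (2).
total = sum(p.numel() for p in parent.parameters())

print(own, total)  # 0 10
```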