forked from noahcao/Pixel2Mesh
-
Notifications
You must be signed in to change notification settings - Fork 0
/
modelsize_estimate.py
38 lines (30 loc) · 1.35 KB
/
modelsize_estimate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import numpy as np
def modelsize(model, input, type_size=4):
    """Print a rough estimate of a model's parameter and activation memory.

    Three lines are printed, all in units of 10**6 bytes: the size of the
    parameters, the size of the layer outputs produced by one forward sweep
    ("without backward"), and twice that amount ("with backward").

    The sweep feeds ``input`` through the leaf modules of ``model`` one after
    another, in the order ``model.modules()`` yields them, passing each
    output on as the next input.  It does not call ``model.forward``, so the
    figure is only meaningful when that order matches the real data flow
    (e.g. plain, possibly nested, ``nn.Sequential`` stacks).  Modules are
    called in whatever train/eval mode they are currently in.

    Args:
        model: The ``nn.Module`` to measure.
        input: Example input tensor; it is copied and never modified.
        type_size: Bytes per element (4 corresponds to float32).

    Returns:
        None.  The results are only printed.
    """
    name = model._get_name()

    # numel() returns plain Python ints, so the totals are exact and do not
    # depend on NumPy's platform-specific default integer width.
    para = sum(p.numel() for p in model.parameters())
    print('Model {} : params: {:4f}M'.format(name, para * type_size / 1000 / 1000))

    # Only run modules without children.  Running a container as well would
    # count its output in addition to its last child's, and would then feed
    # the container's output back into its own first child.
    containers = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
    layers = [
        m for m in model.modules()
        if next(m.children(), None) is None and not isinstance(m, containers)
    ]

    total_nums = 0
    # Only shapes are needed, so no autograd graph should be recorded.
    with torch.no_grad():
        # detach() first: calling requires_grad_(False) on the clone of a
        # tensor that requires grad raises, because the clone is not a leaf.
        input_ = input.detach().clone()
        for m in layers:
            # An in-place ReLU overwrites its input and allocates nothing new.
            if isinstance(m, nn.ReLU) and m.inplace:
                continue
            input_ = m(input_)
            total_nums += input_.numel()

    print('Model {} : intermedite variables: {:3f} M (without backward)'
          .format(name, total_nums * type_size / 1000 / 1000))
    print('Model {} : intermedite variables: {:3f} M (with backward)'
          .format(name, total_nums * type_size*2 / 1000 / 1000))