-
Notifications
You must be signed in to change notification settings - Fork 1
/
convert_to_caffe.py
39 lines (33 loc) · 1.13 KB
/
convert_to_caffe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys
import torch
from torch.autograd import Variable
from torchvision.models.alexnet import alexnet
from pth2caffe import pytorch_to_caffe
import sys
import torch.onnx
from vision.ssd.config.fd_config import define_img_size
# Input image size handed to the detector config. This call sits between the
# fd_config import and the model-factory import (vision.ssd.mb_tiny_fd) --
# presumably that module reads the configured size at import time, so keep
# this ordering. TODO confirm against vision/ssd/config/fd_config.
input_img_size = 320
define_img_size(input_img_size)
#from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd
from vision.ssd.mb_tiny_fd import create_mb_tiny_fd
# Base file name for the generated Caffe artifacts, and the detector variant
# that main() should convert ("slim" or "RFB").
name, net_type = "ultra-ssd", "slim"
def main():
    """Convert the pretrained Ultra-Light face detector to Caffe files.

    Builds the network selected by the module-level ``net_type`` ("slim" or
    "RFB") and loads its weights from ``models/pretrained/``. It then traces
    the network on CPU with a constant dummy input via
    ``pytorch_to_caffe.trans_net``. Finally it writes
    ``pth2caffe/models/<name>.prototxt`` and
    ``pth2caffe/models/<name>.caffemodel``, where ``name`` is the
    module-level model name.

    Exits the process with status 1 if ``net_type`` is not recognised.
    """
    import os

    if net_type == 'slim':
        model_path = "models/pretrained/version-slim-320.pth"
        # first argument presumably num_classes (background + face) -- confirm
        net = create_mb_tiny_fd(2, is_test=True)
    elif net_type == 'RFB':
        # The module-level import of this factory is commented out, so this
        # branch used to raise NameError. Import it here instead; main() runs
        # after the module-level define_img_size() call, which mirrors the
        # ordering the file uses for the mb_tiny_fd import.
        from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd
        model_path = "models/pretrained/version-RFB-320.pth"
        net = create_Mb_Tiny_RFB_fd(2, is_test=True)
    else:
        print("unsupport network type.", file=sys.stderr)
        sys.exit(1)

    net.load(model_path)
    net.eval()  # inference mode before tracing
    net.to("cpu")

    # Constant tracing input, presumably (N, C, H, W) = (1, 3, 240, 320);
    # assumed to match the size configured by define_img_size(320) -- TODO confirm.
    dummy_input = torch.ones([1, 3, 240, 320])
    pytorch_to_caffe.trans_net(net, dummy_input, name)

    # Create the output directory so the save calls do not fail on a fresh checkout.
    os.makedirs("pth2caffe/models", exist_ok=True)
    pytorch_to_caffe.save_prototxt('pth2caffe/models/{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('pth2caffe/models/{}.caffemodel'.format(name))
# Only run the conversion when executed as a script, not when imported.
if __name__ == "__main__":
    main()