-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathserialize_net.lua
48 lines (44 loc) · 1.15 KB
/
serialize_net.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
require 'torch'
require 'nn'
require 'optim'
require 'paths'
assert(pcall(function () mat = require('fb.mattorch') end) or pcall(function() mat = require('matio') end), 'no mat IO interface available')
-- Script options.
-- Every default below can be overridden by exporting an environment variable
-- with the same name (e.g. `gpu=0 epoch=100 th serialize_net.lua`).
-- An override that tonumber() accepts becomes a number; any other override
-- is kept as the raw string. Each resolved option is printed.
-- `opt` is deliberately global: the rest of the script reads it.
opt = {
  gpu = 1,
  name = 'shapenet101',
  data_dir = '/data/jjliu/checkpoints/',
  checkpointf = 'checkpoints_64chair100o_vaen2',
  epoch = 1450,
  ext = 'net_G',
  genEpoch = 1450,
}
for key, default in pairs(opt) do
  local fromEnv = os.getenv(key)
  local value = default
  if fromEnv ~= nil then
    value = tonumber(fromEnv) or fromEnv
  end
  -- Reassigning an existing key while traversing with pairs is allowed.
  opt[key] = value
  print(key .. ': ' .. value)
end
-- GPU mode: load the CUDA backends (same order as before) and select the
-- device whose index is given by opt.gpu.
if opt.gpu > 0 then
  for _, pkg in ipairs({ 'cunn', 'cudnn', 'cutorch' }) do
    require(pkg)
  end
  cutorch.setDevice(opt.gpu)
  -- NOTE(review): presumably makes a serialized DataParallelTable come back
  -- on a single GPU when loaded; confirm against cunn's DataParallelTable.
  nn.DataParallelTable.deserializeNGPUs = 1
end
-- Load a checkpointed network and strip GPU-specific parts from it.
-- When running in GPU mode, convert cudnn modules to nn and cast to float.
-- Then write the result back over the same checkpoint file.
--
-- Checkpoint file name layout:
--   ext == 'net_P' : <name>_<genEpoch>_<epoch>_<ext>.t7
--   otherwise      : <name>_<epoch>_<ext>.t7
local fname
if opt.ext == 'net_P' then
  fname = opt.name .. '_' .. opt.genEpoch .. '_' .. opt.epoch .. '_' .. opt.ext .. '.t7'
else
  fname = opt.name .. '_' .. opt.epoch .. '_' .. opt.ext .. '.t7'
end

-- paths.concat inserts separators itself, so data_dir (which may come from
-- the environment) no longer has to end with '/'.
local fpath = paths.concat(opt.data_dir, opt.checkpointf, fname)
local net = torch.load(fpath)
print(net)

-- Unwrap a DataParallelTable and keep its first child module
-- (presumably the wrapped network; confirm for multi-module tables).
if torch.type(net):find('DataParallelTable', 1, true) then
  net = net:get(1)
end

if opt.gpu > 0 then
  -- Replace cudnn modules with their nn equivalents, then cast to float.
  net = cudnn.convert(net, nn)
  net = net:float()
end
print(net)

-- The checkpoint is overwritten in place. Save to a sibling temp file first
-- and rename it over the original, so an interrupted save cannot leave a
-- truncated file where the only copy of the checkpoint used to be.
local tmpPath = fpath .. '.tmp'
torch.save(tmpPath, net)
assert(os.rename(tmpPath, fpath))