-
Notifications
You must be signed in to change notification settings - Fork 63
/
train_2d.py
124 lines (89 loc) · 4.06 KB
/
train_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# train.py
#!/usr/bin/env python3
""" train network using pytorch
Jiayuan Zhu
"""
import os
import time
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
#from dataset import *
from torch.utils.data import DataLoader
import cfg
import func_2d.function as function
from conf import settings
#from models.discriminatorlayer import discriminator
from func_2d.dataset import *
from func_2d.utils import *
def main():
# use bfloat16 for the entire work
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
args = cfg.parse_args()
GPUdevice = torch.device('cuda', args.gpu_device)
net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)
# optimisation
optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
'''load pretrained model'''
args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)
'''segmentation data'''
transform_train = transforms.Compose([
transforms.Resize((args.image_size,args.image_size)),
transforms.ToTensor(),
])
transform_test = transforms.Compose([
transforms.Resize((args.image_size, args.image_size)),
transforms.ToTensor(),
])
# example of REFUGE dataset
if args.dataset == 'REFUGE':
'''REFUGE data'''
refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_train, mode = 'Training')
refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_test, mode = 'Test')
nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=2, pin_memory=True)
nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=2, pin_memory=True)
'''end'''
'''checkpoint path and tensorboard'''
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
#use tensorboard
if not os.path.exists(settings.LOG_DIR):
os.mkdir(settings.LOG_DIR)
writer = SummaryWriter(log_dir=os.path.join(
settings.LOG_DIR, args.net, settings.TIME_NOW))
#create checkpoint folder to save model
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')
'''begain training'''
best_tol = 1e4
best_dice = 0.0
for epoch in range(settings.EPOCH):
if epoch == 0:
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
# training
net.train()
time_start = time.time()
loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer)
logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
time_end = time.time()
print('time_for_training ', time_end - time_start)
# validation
net.eval()
if epoch % args.val_freq == 0 or epoch == settings.EPOCH-1:
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
if edice > best_dice:
best_dice = edice
torch.save({'model': net.state_dict(), 'parameter': net._parameters}, os.path.join(args.path_helper['ckpt_path'], 'latest_epoch.pth'))
writer.close()
if __name__ == '__main__':
main()