-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
56 lines (50 loc) · 2.08 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
from bunch import Bunch
from loguru import logger
# from ruamel.yaml import safe_load
from torch.utils.data import DataLoader
import models
# from dataset import vessel_dataset
from dataset_edm import vessel_dataset
# from trainer import Trainer
from trainer_dgt import Trainer
from utils import losses
from utils.helpers import get_instance, seed_torch
import yaml
def main(CFG, data_path, batch_size, with_val=False):
    """Assemble the data loaders, model and loss, then run the trainer.

    Args:
        CFG: configuration object forwarded to ``get_instance`` and ``Trainer``.
        data_path: dataset location forwarded to ``vessel_dataset``.
        batch_size: batch size handed to every ``DataLoader``.
        with_val: when truthy, both datasets are built with ``split=0.9`` and
            a second dataset created with ``is_val=True`` feeds a non-shuffled
            validation loader; otherwise the trainer gets ``val_loader=None``.
    """
    seed_torch()

    # Stays None unless a validation split is requested below.
    val_loader = None
    if not with_val:
        train_dataset = vessel_dataset(data_path, mode="training")
    else:
        train_dataset = vessel_dataset(data_path, mode="training", split=0.9)
        held_out_dataset = vessel_dataset(
            data_path, mode="training", split=0.9, is_val=True
        )
        val_loader = DataLoader(
            held_out_dataset,
            batch_size,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            drop_last=False,
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size,
        shuffle=True,
        num_workers=16,
        pin_memory=True,
        drop_last=True,
    )
    patch_count = len(train_dataset)
    logger.info('The patch number of train is %d' % patch_count)

    # Model and loss are both resolved by name from CFG.
    model = get_instance(models, 'model', CFG)
    # logger.info(f'\n{model}\n')
    loss = get_instance(losses, 'loss', CFG)

    trainer_kwargs = dict(
        model=model,
        loss=loss,
        CFG=CFG,
        train_loader=train_loader,
        val_loader=val_loader,
    )
    Trainer(**trainer_kwargs).train()
def build_arg_parser():
    """Create the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--dataset_path`` (str),
        ``--batch_size`` (int, default 512) and the ``--val`` flag.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-dp', '--dataset_path', default="/home/lwt/data_pro/vessel/DRIVE", type=str,
                        help='the path of dataset')
    # type=int is required: without it a value given on the command line is
    # kept as a str and handed unchanged to DataLoader as the batch size.
    parser.add_argument('-bs', '--batch_size', default=512, type=int,
                        help='batch_size for trianing and validation')
    parser.add_argument("--val", help="split training data for validation",
                        required=False, default=False, action="store_true")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    # The configuration is read from config.yaml in the current working
    # directory and wrapped in Bunch before being passed to main().
    with open('config.yaml', encoding='utf-8') as file:
        CFG = Bunch(yaml.safe_load(file))
        # CFG = yaml.safe_load(file)  # (original author's note: "is a list type")

    main(CFG, args.dataset_path, args.batch_size, args.val)