forked from researchmm/STTN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
92 lines (77 loc) · 3.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import json
import argparse
import datetime
import numpy as np
from shutil import copyfile
import torch
import torch.multiprocessing as mp
from core.trainer import Trainer
from core.dist import (
get_world_size,
get_local_rank,
get_global_rank,
get_master_ip,
)
# Command-line interface. `args` is parsed at import time on purpose:
# main_worker reads this module-level global (args.config, args.exam),
# including in worker processes started via mp.spawn.
parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str,
                    help='path to the JSON training config file')
parser.add_argument('-m', '--model', default='sttn', type=str,
                    help='model name; stored as config["model"] and used as the '
                         'prefix of the output folder name')
parser.add_argument('-p', '--port', default='23455', type=str,
                    help='TCP port for the distributed init_method')
parser.add_argument('-e', '--exam', action='store_true',
                    help='construct the Trainer with debug=True')
args = parser.parse_args()
def main_worker(rank, config):
    """Set up one training process and run the STTN Trainer.

    Args:
        rank: process index. ``mp.spawn`` passes 0..nprocs-1 here. It is used
            as both ``local_rank`` and ``global_rank`` unless ``config``
            already contains a ``'local_rank'`` entry.
        config: dict loaded from the JSON config file. It is mutated in place:
            ``local_rank``/``global_rank`` (when absent) and ``device`` are
            written, and ``save_dir`` is suffixed with
            ``<model>_<config file stem>``.

    Also reads the module-level ``args`` (``args.config``, ``args.exam``).
    """
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    print("rank = ", rank)

    if config['distributed']:
        print("DISTRIBUTED!!!!!")
        torch.cuda.set_device(int(config['local_rank']))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch',
        )
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    # Each run writes into <save_dir>/<model>_<config file stem>,
    # e.g. "sttn_youtube-vos" with the default CLI arguments.
    config_stem = os.path.basename(args.config).split('.')[0]
    config['save_dir'] = os.path.join(
        config['save_dir'], '{}_{}'.format(config['model'], config_stem))

    # Device selection. An explicit "device_int" GPU index in the config
    # takes precedence. Without it, fall back to the GPU matching local_rank,
    # or to the CPU when CUDA is unavailable. Previously a config lacking
    # "device_int" raised KeyError here.
    if "device_int" in config:
        config["device"] = torch.device("cuda:{}".format(config["device_int"]))
    elif torch.cuda.is_available():
        config["device"] = torch.device("cuda:{}".format(config["local_rank"]))
    else:
        config["device"] = torch.device("cpu")
    print("Device: ", config["device"])

    # Only a single process creates the output folder and snapshots the
    # config file into it: the sole process when not distributed, otherwise
    # global rank 0.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        # basename handles both '/' and the platform separator.
        config_path = os.path.join(
            config['save_dir'], os.path.basename(config['config']))
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer = Trainer(config, debug=args.exam)
    trainer.train()
if __name__ == "__main__":
    # Load the JSON config. `with` closes the file handle, which the previous
    # json.load(open(...)) leaked.
    with open(args.config, encoding="utf-8") as config_file:
        config = json.load(config_file)
    config['model'] = args.model
    config['config'] = args.config

    # Distributed training is disabled in this fork: the script always runs
    # a single, non-distributed process. The previous distributed
    # configuration is kept below for reference:
    # config['world_size'] = get_world_size()
    # config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
    # config['distributed'] = True if config['world_size'] > 1 else False
    config["world_size"] = 1
    config["distributed"] = False
    print(json.dumps(config, indent=4))

    if get_master_ip() == "127.0.0.1":
        # Launch the worker process(es) manually. mp.spawn passes each
        # process index (0..world_size-1) as main_worker's `rank`.
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # Previous externally launched (openmpi) path, disabled:
        # config['local_rank'] = get_local_rank()
        # config['global_rank'] = get_global_rank()
        # main_worker(-1, config)
        # With world_size forced to 1, the only valid rank is 0. Passing 1
        # made global_rank out of range for world_size=1.
        main_worker(0, config)