forked from val-iisc/EffSSL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_swav.py
366 lines (318 loc) · 14.1 KB
/
main_swav.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import argparse
import math
import os
import shutil
import time
from logging import getLogger
import numpy as np
import torch
from torch.cuda.amp import GradScaler
from torch.cuda.amp.autocast_mode import autocast
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
from lars import LARSWrapper
from src.utils import (
bool_flag,
initialize_exp,
restart_from_checkpoint,
fix_random_seeds,
AverageMeter,
init_distributed_mode,
)
from src.multicropdataset import MultiCropDataset
import src.resnet50 as resnet_models
logger = getLogger()


def _build_parser():
    """Create the command-line parser for SwAV pre-training.

    Options are registered in five sections (data, SwAV-specific,
    optimization, distributed, miscellaneous) in the order they appear in
    ``--help``.
    """
    p = argparse.ArgumentParser(description="Implementation of SwAV")
    add = p.add_argument

    # ---- data parameters ----
    add("--data_path", type=str, default="/path/to/imagenet",
        help="path to dataset repository")
    add("--nmb_crops", type=int, default=[2], nargs="+",
        help="list of number of crops (example: [2, 6])")
    add("--size_crops", type=int, default=[224], nargs="+",
        help="crops resolutions (example: [224, 96])")
    add("--min_scale_crops", type=float, default=[0.14], nargs="+",
        help="argument in RandomResizedCrop (example: [0.14, 0.05])")
    add("--max_scale_crops", type=float, default=[1], nargs="+",
        help="argument in RandomResizedCrop (example: [1., 0.14])")

    # ---- swav specific params ----
    add("--crops_for_assign", type=int, nargs="+", default=[0, 1],
        help="list of crops id used for computing assignments")
    add("--temperature", default=0.1, type=float,
        help="temperature parameter in training loss")
    add("--epsilon", default=0.05, type=float,
        help="regularization parameter for Sinkhorn-Knopp algorithm")
    add("--sinkhorn_iterations", default=3, type=int,
        help="number of iterations in Sinkhorn-Knopp algorithm")
    add("--feat_dim", default=128, type=int,
        help="feature dimension")
    add("--nmb_prototypes", default=3000, type=int,
        help="number of prototypes")
    add("--queue_length", type=int, default=0,
        help="length of the queue (0 for no queue)")
    add("--epoch_queue_starts", type=int, default=15,
        help="from this epoch, we start using a queue")

    # ---- optim parameters ----
    add("--epochs", default=100, type=int,
        help="number of total epochs to run")
    add("--batch_size", default=64, type=int,
        help="batch size per gpu, i.e. how many unique instances per gpu")
    add("--base_lr", default=4.8, type=float, help="base learning rate")
    add("--final_lr", type=float, default=0, help="final learning rate")
    add("--freeze_prototypes_niters", default=313, type=int,
        help="freeze the prototypes during this many iterations from the start")
    add("--wd", default=1e-6, type=float, help="weight decay")
    add("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
    add("--start_warmup", default=0, type=float,
        help="initial warmup learning rate")

    # ---- dist parameters ----
    add("--dist_url", default="env://", type=str,
        help="url used to set up distributed\n"
             "training; see https://pytorch.org/docs/stable/distributed.html")
    add("--world_size", default=-1, type=int,
        help="\n"
             "number of processes: it is set automatically and\n"
             "should not be passed as argument")
    add("--rank", default=0, type=int,
        help="rank of this process:\n"
             "it is set automatically and should not be passed as argument")
    add("--local_rank", default=0, type=int,
        help="this argument is not used and should be ignored")

    # ---- other parameters ----
    add("--arch", default="resnet50", type=str, help="convnet architecture")
    add("--hidden_mlp", default=2048, type=int,
        help="hidden layer dimension in projection head")
    add("--workers", default=10, type=int,
        help="number of data loading workers")
    add("--checkpoint_freq", type=int, default=25,
        help="Save the model periodically")
    add("--use_fp16", type=bool_flag, default=True,
        help="whether to train with mixed precision or not")
    add("--sync_bn", type=str, default="pytorch", help="synchronize bn")
    add("--dump_path", type=str, default=".",
        help="experiment dump path for checkpoints and log")
    add("--seed", type=int, default=31, help="seed")
    return p


parser = _build_parser()
def main():
    """Parse CLI arguments, build data/model/optimizer and run SwAV pre-training.

    Side effects:
        * sets the module-level ``args`` global, which ``train`` and
          ``distributed_sinkhorn`` read;
        * tries to resume model/optimizer/GradScaler state from
          ``<dump_path>/checkpoint.pth.tar`` via ``restart_from_checkpoint``
          and the per-rank queue from ``<dump_path>/queue<rank>.pth`` if that
          file exists;
        * after every epoch, rank 0 rewrites the checkpoint (and copies it to
          ``ckp-<epoch>.pth`` every ``checkpoint_freq`` epochs), and every rank
          saves its own queue when one is in use.
    """
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer: SGD with momentum, wrapped by the local LARSWrapper
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARSWrapper(optimizer, eta=0.001, clip=False, exclude_bias_n_norm=True)

    # Per-iteration learning-rate schedule: linear warmup from start_warmup to
    # base_lr, then cosine decay from base_lr to final_lr.
    niter_per_ep = len(train_loader)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, niter_per_ep * args.warmup_epochs)
    total_cosine_iters = niter_per_ep * (args.epochs - args.warmup_epochs)
    iters = np.arange(total_cosine_iters)
    cosine_lr_schedule = np.array([
        args.final_lr + 0.5 * (args.base_lr - args.final_lr)
        * (1 + math.cos(math.pi * t / total_cosine_iters))
        for t in iters
    ])
    lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    if args.rank == 0:
        logger.info("Building optimizer done.")

    scaler = GradScaler(enabled=args.use_fp16)

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on]
    )

    # optionally resume from a checkpoint; the GradScaler is restored too so a
    # resumed mixed-precision run keeps its calibrated loss scale (checkpoints
    # without a "scaler" entry simply leave the fresh scaler untouched)
    checkpoint_path = os.path.join(args.dump_path, "checkpoint.pth.tar")
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        checkpoint_path,
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        scaler=scaler,
    )
    start_epoch = to_restore["epoch"]

    # build the queue (each rank keeps and persists its own shard)
    queue = None
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    if os.path.isfile(queue_path):
        queue = torch.load(queue_path)["queue"]
    # round the queue length down to a multiple of the global batch size
    # (batch_size * world_size) so each rank's shard shifts by whole batches
    args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):
        # train the network for one epoch
        if args.rank == 0:
            logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, scaler, epoch, lr_schedule, queue)
        training_stats.update(scores)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            # A disabled GradScaler has an empty state dict, and loading an
            # empty dict into an enabled scaler raises; only store real state.
            if scaler.is_enabled():
                save_dict["scaler"] = scaler.state_dict()
            torch.save(save_dict, checkpoint_path)
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    checkpoint_path,
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
                )
        if queue is not None:
            torch.save({"queue": queue}, queue_path)
def train(train_loader, model, optimizer, scaler, epoch, lr_schedule, queue):
    """Train ``model`` with the SwAV swapped-prediction loss for one epoch.

    Args:
        train_loader: loader whose batches ``inputs`` are passed straight to
            ``model``; ``inputs[0].size(0)`` is used as the per-GPU batch size
            (presumably a list of crop batches, one per view -- confirm
            against MultiCropDataset).
        model: wrapper (DistributedDataParallel in ``main``) whose ``module``
            exposes ``prototypes.weight``; called as
            ``embedding, output = model(inputs)``.
        optimizer: optimizer whose ``param_groups`` learning rates are
            overwritten from ``lr_schedule`` every iteration.
        scaler: GradScaler driving the (optionally mixed-precision)
            backward pass and optimizer step.
        epoch: current epoch index, used to compute the global iteration.
        lr_schedule: per-iteration learning rates indexed by global iteration.
        queue: ``None`` or a tensor indexed as ``queue[i]`` for each entry of
            ``args.crops_for_assign``; shifted and refilled in place with this
            batch's embeddings.

    Returns:
        ``((epoch, average_loss), queue)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()
    # becomes True once any queue shard is full; from then on the queue is
    # always prepended to the scores used for assignment
    use_the_queue = False

    end = time.time()
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate
        iteration = epoch * len(train_loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]

        # normalize the prototypes: L2-normalize each row of the prototype
        # weight in place before the forward pass
        with torch.no_grad():
            w = model.module.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            model.module.prototypes.weight.copy_(w)

        # ============ multi-res forward passes ... ============
        # NOTE(review): the autocast scope was reconstructed from a copy with
        # stripped indentation; the loss is kept inside it because
        # distributed_sinkhorn's exp(out / epsilon) presumably relies on
        # autocast's fp32 promotion when `output` is fp16 -- confirm.
        with autocast(enabled=args.use_fp16):
            embedding, output = model(inputs)
            embedding = embedding.detach()
            bs = inputs[0].size(0)

            # ============ swav loss ... ============
            loss = 0
            for i, crop_id in enumerate(args.crops_for_assign):
                with torch.no_grad():
                    # scores of crop `crop_id` (rows bs*crop_id .. bs*(crop_id+1))
                    out = output[bs * crop_id: bs * (crop_id + 1)].detach()

                    # time to use the queue
                    if queue is not None:
                        # start once the last slot of this shard is non-zero
                        # (i.e. it has been filled), then keep using it
                        if use_the_queue or not torch.all(queue[i, -1, :] == 0):
                            use_the_queue = True
                            # prepend scores of queued embeddings vs prototypes
                            out = torch.cat((torch.mm(
                                queue[i],
                                model.module.prototypes.weight.t()
                            ), out))
                        # fill the queue: shift old entries back by bs and
                        # write this batch's embeddings at the front
                        queue[i, bs:] = queue[i, :-bs].clone()
                        queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs]

                    # get assignments; keep only the last bs rows, i.e. the
                    # current batch (queued rows were prepended above)
                    q = distributed_sinkhorn(out)[-bs:]

                # cluster assignment prediction: every other crop v must
                # predict the codes q computed from crop `crop_id`
                subloss = 0
                for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id):
                    x = output[bs * v: bs * (v + 1)] / args.temperature
                    subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1))
                loss += subloss / (np.sum(args.nmb_crops) - 1)
            loss /= len(args.crops_for_assign)

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # cancel gradients for the prototypes during the first
        # freeze_prototypes_niters iterations so the optimizer receives no
        # gradient for them
        if iteration < args.freeze_prototypes_niters:
            for name, p in model.named_parameters():
                if "prototypes" in name:
                    p.grad = None
        scaler.step(optimizer)
        scaler.update()

        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if args.rank ==0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    lr=optimizer.optim.param_groups[0]["lr"],
                )
            )
    return (epoch, losses.avg), queue
def _all_reduce_if_distributed(tensor):
    """Sum ``tensor`` in place across processes when a process group is active.

    Without an initialized ``torch.distributed`` process group the tensor is
    returned unchanged, which is the correct reduction for a single process.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor)
    return tensor


@torch.no_grad()
def distributed_sinkhorn(out, epsilon=None, sinkhorn_iterations=None, world_size=None):
    """Compute balanced soft cluster assignments with Sinkhorn-Knopp.

    Args:
        out: (B_local, K) scores between this process's samples and the K
            prototypes; may be fp16 when produced under autocast.
        epsilon: entropic regularization; defaults to ``args.epsilon``.
        sinkhorn_iterations: number of row/column normalization rounds;
            defaults to ``args.sinkhorn_iterations``.
        world_size: number of processes whose samples are assigned jointly;
            defaults to ``args.world_size``.

    Returns:
        float32 tensor of shape (B_local, K) whose rows sum to 1.
    """
    if epsilon is None:
        epsilon = args.epsilon
    if sinkhorn_iterations is None:
        sinkhorn_iterations = args.sinkhorn_iterations
    if world_size is None:
        world_size = args.world_size

    # Work in float32: with epsilon=0.05 a score of 1 gives exp(20) ~ 4.9e8,
    # far beyond fp16's max (65504). Without this cast, an fp16 `out` only
    # survives if the caller happens to be inside an autocast region that
    # promotes exp to float32.
    Q = torch.exp(out.float() / epsilon).t()  # Q is K-by-B for consistency with notations from our paper
    B = Q.shape[1] * world_size  # number of samples to assign
    K = Q.shape[0]  # how many prototypes

    # make the matrix sum to 1
    sum_Q = _all_reduce_if_distributed(torch.sum(Q))
    Q /= sum_Q

    for _ in range(sinkhorn_iterations):
        # normalize each row: total weight per prototype must be 1/K
        sum_of_rows = _all_reduce_if_distributed(torch.sum(Q, dim=1, keepdim=True))
        Q /= sum_of_rows
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()
if __name__ == "__main__":
    # script entry point: parse the command line and run training via main()
    main()