-
Notifications
You must be signed in to change notification settings - Fork 90
/
function.py
431 lines (371 loc) · 16.5 KB
/
function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
import argparse
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.transforms import AsDiscrete
from PIL import Image
from skimage import io
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from tensorboardX import SummaryWriter
#from dataset import *
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import cfg
import models.sam.utils.transforms as samtrans
import pytorch_ssim
#from models.discriminatorlayer import discriminator
from conf import settings
from utils import *
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
# --- Module-level setup ------------------------------------------------------
# NOTE(review): everything below executes at import time. cfg.parse_args()
# presumably parses sys.argv, and the .cuda() call below needs a CUDA device,
# so merely importing this module has side effects. Confirm callers expect it.
args = cfg.parse_args()
# Module-wide device from the parsed config. train_sam / validation_sam each
# build their own local GPUdevice from args.gpu_device as well.
GPUdevice = torch.device('cuda', args.gpu_device)
# Positive-class weight of 2 for the BCE loss below (positive targets weigh 2x).
pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2
# Loss used by train_sam / validation_sam when args.thd is false.
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Random integers in [1, 10] with shape (args.b, 7); not referenced in this chunk.
seed = torch.randint(1,11,(args.b,7))
# Let cuDNN benchmark and pick convolution algorithms (helps with fixed input sizes).
torch.backends.cudnn.benchmark = True
# NOTE(review): the globals below are not referenced by the functions visible
# in this chunk; they may be used elsewhere — confirm before removing.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
# One-hot conversion over 14 channels for labels and argmax'ed predictions.
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
# NOTE(review): DiceMetric is not among the explicit imports; presumably it is
# provided by `from utils import *` — confirm.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# Bookkeeping containers; names suggest best-validation tracking and history.
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
def train_sam(args, net: nn.Module, optimizer, train_loader,
          epoch, writer, schedulers=None, vis = 50):
    """Run one training epoch of a SAM-family model over ``train_loader``.

    For each batch: build point prompts, encode the image with
    ``net.image_encoder`` (with gradients), encode the prompt with
    ``net.prompt_encoder`` (without gradients), decode masks with
    ``net.mask_decoder``, compute the loss and take one optimizer step.

    Args:
        args: parsed config. Fields read here: ``gpu_device``, ``thd``,
            ``mod``, ``net``, ``image_size``, ``out_size``,
            ``multimask_output`` and ``path_helper['sample_path']``.
        net: model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``.
        optimizer: optimizer stepped once per batch.
        train_loader: iterable of dict batches with ``'image'``, ``'label'``
            and ``'image_meta_dict'``; ``'pt'`` / ``'p_label'`` are used when
            present and ``args.thd`` is false.
        epoch: epoch index (progress-bar label and sample file names).
        writer: not used in this function; kept for interface compatibility.
        schedulers: not used in this function; kept for interface compatibility.
        vis: save a visualisation every ``vis`` batches; falsy disables it.

    Returns:
        The loss tensor of the last processed batch.
    """
    epoch_loss = 0
    ind = 0
    net.train()
    optimizer.zero_grad()
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    # Which image-encoder parameters are trainable does not change between
    # batches, so configure it once instead of on every iteration. This matters
    # for AdaLoRA: RankAllocator takes warm-up / total-step schedule and beta
    # (moving-average) arguments, i.e. it is meant to carry state across steps;
    # re-creating it every batch (as before) threw that state away.
    # NOTE(review): it is still re-created once per epoch and `ind` restarts at
    # 1 each epoch; a truly global allocator/step would have to live outside
    # this function.
    rankallocator = None
    if args.mod == 'sam_adpt':
        # Train only the adapter layers of the image encoder.
        for n, value in net.image_encoder.named_parameters():
            value.requires_grad = "Adapter" in n
    elif args.mod == 'sam_lora' or args.mod == 'sam_adalora':
        from models.common import loralib as lora
        lora.mark_only_lora_as_trainable(net.image_encoder)
        if args.mod == 'sam_adalora':
            rankallocator = lora.RankAllocator(
                net.image_encoder, lora_r=4, target_rank=8,
                init_warmup=500, final_warmup=1500, mask_interval=10,
                total_step=3000, beta1=0.85, beta2=0.85,
            )
    else:
        for n, value in net.image_encoder.named_parameters():
            value.requires_grad = True

    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            imgs = pack['image'].to(dtype = torch.float32, device = GPUdevice)
            masks = pack['label'].to(dtype = torch.float32, device = GPUdevice)
            name = pack['image_meta_dict']['filename_or_obj']

            # Generate click prompts when the batch carries none, and always in
            # 3D mode (where stored prompts were overwritten anyway). Same rule
            # as validation_sam; avoids calling generate_click_prompt twice.
            if 'pt' not in pack or args.thd:
                imgs, pt, masks = generate_click_prompt(imgs, masks)
            else:
                pt = pack['pt']
                point_labels = pack['p_label']

            if args.thd:
                # Fold the depth axis into the batch: each slice becomes a 2D sample.
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                imgs = imgs.repeat(1,3,1,1)
                point_labels = torch.ones(imgs.size(0))
                imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)

            showp = pt
            ind += 1
            # NOTE(review): imgs is (B, C, H, W) here, so `w` receives H and `h`
            # receives W. They are only consumed by transform_prompt(..., h, w)
            # for efficient_sam; names kept to preserve behaviour — confirm the
            # intended axis order (irrelevant for square inputs).
            _, _, w, h = imgs.size()

            # A leading label of -1 means "no point prompt".
            if point_labels.flatten()[0] != -1:
                point_coords = pt
                coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                if len(point_labels.shape) == 1:  # only one point prompt
                    coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
                pt = (coords_torch, labels_torch)

            imge = net.image_encoder(imgs)
            with torch.no_grad():
                if args.net == 'sam' or args.net == 'mobile_sam':
                    se, de = net.prompt_encoder(
                        points=pt,
                        boxes=None,
                        masks=None,
                    )
                elif args.net == "efficient_sam":
                    coords_torch, labels_torch = transform_prompt(coords_torch, labels_torch, h, w)
                    se = net.prompt_encoder(
                        coords=coords_torch,
                        labels=labels_torch,
                    )

            if args.net == 'sam':
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=(args.multimask_output > 1),
                )
            elif args.net == 'mobile_sam':
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )
            elif args.net == "efficient_sam":
                se = se.view(se.shape[0], 1, se.shape[1], se.shape[2])
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    multimask_output=False,
                )

            # Resize to the ordered output size
            pred = F.interpolate(pred, size=(args.out_size, args.out_size))

            loss = lossfunc(pred, masks)
            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()

            if args.mod == 'sam_adalora':
                # AdaLoRA: orthogonality regulariser, then let the allocator
                # update importance scores and mask ranks for this step.
                (loss + lora.compute_orth_regu(net, regu_weight=0.1)).backward()
                optimizer.step()
                rankallocator.update_and_mask(net, ind)
            else:
                loss.backward()
                optimizer.step()
            optimizer.zero_grad()

            if vis and ind % vis == 0:
                namecat = 'Train'
                for na in name[:2]:
                    namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

            pbar.update()

    return loss
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    """Evaluate a SAM-family model on ``val_loader`` without updating it.

    Each batch is optionally split along its last axis into chunks of
    ``args.evl_chunk`` (the whole axis when unset). Each chunk is encoded and
    decoded under ``torch.no_grad()``, its loss is accumulated, and its
    ``eval_seg`` metrics are summed.

    Args:
        args: parsed config. Fields read here: ``gpu_device``, ``thd``,
            ``evl_chunk``, ``net``, ``image_size``, ``out_size``,
            ``multimask_output``, ``vis`` and ``path_helper['sample_path']``.
        val_loader: iterable of dict batches with ``'image'``, ``'label'`` and
            ``'image_meta_dict'``; ``'pt'`` / ``'p_label'`` are used when
            present and ``args.thd`` is false.
        epoch: epoch index, used in sample file names.
        net: model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``.
        clean_dir: not used in this function; kept for interface compatibility.

    Returns:
        ``(mean_loss, mean_metrics)``: the summed loss and the element-wise
        summed ``eval_seg`` results, each divided by the number of evaluated
        chunks. This equals ``len(val_loader)`` when ``evl_chunk`` is unset.
    """
    net.eval()

    mask_type = torch.float32
    # Running sums of eval_seg results; zip() below truncates to the shorter
    # of this tuple and eval_seg's output.
    mix_res = (0,) * args.multimask_output * 2
    tot = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    # Number of (batch, chunk) evaluations actually run. It drives the
    # visualisation cadence and is the divisor for the returned averages.
    # Previously the divisor was derived from the last batch only.
    n_chunks = 0

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=len(val_loader), desc='Validation round', unit='batch', leave=False) as pbar:
        for pack in val_loader:
            imgsw = pack['image'].to(dtype = torch.float32, device = GPUdevice)
            masksw = pack['label'].to(dtype = torch.float32, device = GPUdevice)
            name = pack['image_meta_dict']['filename_or_obj']

            if 'pt' not in pack or args.thd:
                imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
            else:
                ptw = pack['pt']
                point_labels = pack['p_label']

            # Walk the last axis in steps of evl_ch; a trailing partial chunk
            # is skipped, exactly as the original while-loop did.
            depth = imgsw.size(-1)
            evl_ch = int(args.evl_chunk) if args.evl_chunk else int(depth)

            for buoy in range(0, depth - evl_ch + 1, evl_ch):
                if args.thd:
                    pt = ptw[:, :, buoy: buoy + evl_ch]
                else:
                    pt = ptw
                imgs = imgsw[..., buoy:buoy + evl_ch]
                masks = masksw[..., buoy:buoy + evl_ch]

                if args.thd:
                    # Fold the depth axis into the batch: each slice becomes a 2D sample.
                    pt = rearrange(pt, 'b n d -> (b d) n')
                    imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                    masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                    imgs = imgs.repeat(1,3,1,1)
                    point_labels = torch.ones(imgs.size(0))
                    imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
                    masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)

                showp = pt
                n_chunks += 1
                # NOTE(review): imgs is (B, C, H, W), so `w` receives H and `h`
                # receives W; only transform_prompt (efficient_sam) uses them.
                _, _, w, h = imgs.size()

                # A leading label of -1 means "no point prompt".
                if point_labels.flatten()[0] != -1:
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    if len(point_labels.shape) == 1:  # only one point prompt
                        coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
                    pt = (coords_torch, labels_torch)

                imgs = imgs.to(dtype = mask_type, device = GPUdevice)

                with torch.no_grad():
                    imge = net.image_encoder(imgs)
                    if args.net == 'sam' or args.net == 'mobile_sam':
                        se, de = net.prompt_encoder(
                            points=pt,
                            boxes=None,
                            masks=None,
                        )
                    elif args.net == "efficient_sam":
                        coords_torch, labels_torch = transform_prompt(coords_torch, labels_torch, h, w)
                        se = net.prompt_encoder(
                            coords=coords_torch,
                            labels=labels_torch,
                        )

                    if args.net == 'sam':
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            dense_prompt_embeddings=de,
                            multimask_output=(args.multimask_output > 1),
                        )
                    elif args.net == 'mobile_sam':
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            dense_prompt_embeddings=de,
                            multimask_output=False,
                        )
                    elif args.net == "efficient_sam":
                        se = se.view(se.shape[0], 1, se.shape[1], se.shape[2])
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            multimask_output=False,
                        )

                    # Resize to the ordered output size
                    pred = F.interpolate(pred, size=(args.out_size, args.out_size))
                    tot += lossfunc(pred, masks)

                    # Guard against vis == 0 / None (modulo by zero), like train_sam.
                    if args.vis and n_chunks % args.vis == 0:
                        namecat = 'Test'
                        for na in name[:2]:
                            img_name = na.split('/')[-1].split('.')[0]
                            namecat = namecat + img_name + '+'
                        vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

                    temp = eval_seg(pred, masks, threshold)
                    mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()

    # Avoid ZeroDivisionError when nothing was evaluated (empty loader, or
    # evl_chunk larger than every volume's depth).
    n_val = max(n_chunks, 1)
    return tot / n_val, tuple([a / n_val for a in mix_res])
def transform_prompt(coord, label, h, w, max_num_points=6):
    """Convert point prompts into the fixed-size, padded layout used for EfficientSAM.

    The first two axes of ``coord`` and ``label`` are swapped and a singleton
    "query" axis is inserted. Coordinates are rescaled via
    ``get_rescaled_pts``. Each prompt is then truncated or padded with -1 to
    exactly ``max_num_points`` points.

    Args:
        coord: 3-D point tensor ``(num_pts, batch, 2)``; after
            ``transpose(0, 1).unsqueeze(1)`` it becomes ``(batch, 1, num_pts, 2)``.
        label: point labels ``(num_pts, batch)``, laid out like ``coord``.
        h, w: forwarded to ``get_rescaled_pts`` as ``input_h`` / ``input_w``.
        max_num_points: number of points per prompt handed to the decoder.
            Extra points are dropped and missing ones padded with -1. Defaults
            to 6, the previously hard-coded value.

    Returns:
        ``(points, labels)`` with shapes ``(batch, max_num_points, 2)`` and
        ``(batch, max_num_points)``.
    """
    coord = coord.transpose(0, 1).unsqueeze(1)
    label = label.transpose(0, 1).unsqueeze(1)
    batch_size, max_num_queries, num_pts, _ = coord.shape

    rescaled_batched_points = get_rescaled_pts(coord, h, w)

    if num_pts > max_num_points:
        rescaled_batched_points = rescaled_batched_points[:, :, :max_num_points, :]
        label = label[:, :, :max_num_points]
    elif num_pts < max_num_points:
        # -1 marks padded ("no point") entries in both coordinates and labels.
        pad = max_num_points - num_pts
        rescaled_batched_points = F.pad(rescaled_batched_points, (0, 0, 0, pad), value=-1.0)
        label = F.pad(label, (0, pad), value=-1.0)

    rescaled_batched_points = rescaled_batched_points.reshape(
        batch_size * max_num_queries, max_num_points, 2
    )
    label = label.reshape(batch_size * max_num_queries, max_num_points)
    return rescaled_batched_points, label
def get_rescaled_pts(batched_points: torch.Tensor, input_h: int, input_w: int, target_size: int = 1024):
    """Rescale point coordinates from an ``input_h`` x ``input_w`` frame to ``target_size``.

    Channel 0 of the last axis is scaled by ``target_size / input_w`` and
    channel 1 by ``target_size / input_h``. Negative coordinates (padding /
    "no point") are mapped to -1.0 instead of being rescaled.

    Args:
        batched_points: tensor whose last axis has size 2 (x, y).
        input_h: height of the frame the points are expressed in.
        input_w: width of the frame the points are expressed in.
        target_size: side length of the destination frame. Defaults to 1024,
            the previously hard-coded value.

    Returns:
        Tensor with the same leading shape as ``batched_points``.
    """
    xs = batched_points[..., 0]
    ys = batched_points[..., 1]
    return torch.stack(
        [
            torch.where(xs >= 0, xs * target_size / input_w, -1.0),
            torch.where(ys >= 0, ys * target_size / input_h, -1.0),
        ],
        dim=-1,
    )