# mcd.py -- forked from KaiyangZhou/Dassl.pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
    """Maximum Classifier Discrepancy.

    https://arxiv.org/abs/1712.02560.

    The trainer holds one feature extractor ``F`` and two linear classifiers
    ``C1`` and ``C2`` that both consume the features produced by ``F``.
    Every call to ``forward_backward`` runs three update steps (A, B, C);
    see that method for details.
    """

    def __init__(self, cfg):
        """Initialise the trainer from ``cfg``.

        Args:
            cfg: configuration node; ``cfg.TRAINER.MCD.N_STEP_F`` gives the
                number of step-C updates applied to ``F`` per iteration.
        """
        super().__init__(cfg)
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F

    def _setup_and_register(self, name, model):
        """Move ``model`` to the device, build its optimizer/scheduler and register it.

        Args:
            name: key under which the model is registered.
            model: the ``nn.Module`` to prepare.

        Returns:
            The ``(optimizer, lr_scheduler)`` pair built from ``cfg.OPTIM``.
        """
        cfg = self.cfg
        model.to(self.device)
        print("# params: {:,}".format(count_num_param(model)))
        optim = build_optimizer(model, cfg.OPTIM)
        sched = build_lr_scheduler(optim, cfg.OPTIM)
        self.register_model(name, model, optim, sched)
        return optim, sched

    def build_model(self):
        """Create ``F``, ``C1`` and ``C2`` and register each one.

        Every network gets its own optimizer and learning-rate scheduler,
        all built from the same ``cfg.OPTIM`` settings. The resulting
        objects are stored as ``self.<net>``, ``self.optim_<net>`` and
        ``self.sched_<net>``.
        """
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.optim_F, self.sched_F = self._setup_and_register("F", self.F)
        # Both classifiers take the feature vector emitted by F as input.
        fdim = self.F.fdim

        print("Building C1")
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.optim_C1, self.sched_C1 = self._setup_and_register("C1", self.C1)

        print("Building C2")
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.optim_C2, self.sched_C2 = self._setup_and_register("C2", self.C2)

    def _labeled_loss(self, feat_x, label_x):
        """Return the sum of the cross-entropy losses of C1 and C2 on ``feat_x``."""
        loss_x1 = F.cross_entropy(self.C1(feat_x), label_x)
        loss_x2 = F.cross_entropy(self.C2(feat_x), label_x)
        return loss_x1 + loss_x2

    def _unlabeled_discrepancy(self, feat_u):
        """Return the discrepancy between the softmax outputs of C1 and C2 on ``feat_u``."""
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        return self.discrepancy(pred_u1, pred_u2)

    def forward_backward(self, batch_x, batch_u):
        """Run one training iteration (steps A, B and C).

        Args:
            batch_x: labeled batch, forwarded to ``parse_batch_train``.
            batch_u: unlabeled batch, forwarded to ``parse_batch_train``.

        Returns:
            dict mapping ``"loss_step_A"`` and ``"loss_step_B"`` to Python
            floats, plus ``"loss_step_C"`` (the loss of the last step-C
            update) whenever ``n_step_F`` is at least 1.
        """
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed

        # Step A: classification loss of both classifiers on the labeled
        # batch. No model names are passed to the update call
        # (NOTE(review): presumably this updates every registered model --
        # confirm against the base trainer).
        loss_step_A = self._labeled_loss(self.F(input_x), label_x)
        self.model_backward_and_update(loss_step_A)

        # Step B: only C1 and C2 are updated. Features are computed under
        # no_grad so no gradient reaches F. Subtracting the discrepancy
        # means that minimising this loss *increases* the disagreement of
        # the two classifiers on the unlabeled batch while keeping them
        # accurate on the labeled batch.
        with torch.no_grad():
            feat_x = self.F(input_x)
        loss_x = self._labeled_loss(feat_x, label_x)

        with torch.no_grad():
            feat_u = self.F(input_u)
        loss_dis = self._unlabeled_discrepancy(feat_u)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ["C1", "C2"])

        # Step C: only F is updated, n_step_F times, to reduce the
        # discrepancy of the two classifiers on the unlabeled batch.
        # Start from None so that n_step_F == 0 (empty loop) does not leave
        # the name unbound when the summary is built below.
        loss_step_C = None
        for _ in range(self.n_step_F):
            loss_step_C = self._unlabeled_discrepancy(self.F(input_u))
            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
        }
        if loss_step_C is not None:
            loss_summary["loss_step_C"] = loss_step_C.item()

        # Learning rates are stepped once the last batch index is reached.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def discrepancy(self, y1, y2):
        """Return the mean absolute element-wise difference between ``y1`` and ``y2``."""
        return (y1 - y2).abs().mean()

    def model_inference(self, input):
        """Return the output of ``C1`` applied to ``F(input)``; ``C2`` is not used here."""
        feat = self.F(input)
        return self.C1(feat)