# dann.py (forked from KaiyangZhou/Dassl.pytorch)

import numpy as np
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad


@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
    """Domain-Adversarial Neural Networks (DANN).

    Reference:
        Ganin et al. Domain-adversarial training of neural networks. JMLR 2016.
        https://arxiv.org/abs/1505.07818
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.build_critic()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
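
    # The domain critic is a small MLP discriminator ending in a single
    # linear unit; it is trained with BCE to separate source features from
    # target features.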
    def build_critic(self):
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
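        # Registering the critic lets Dassl's generic training utilities
        # (model_backward_and_update, update_lr) step its optimizer and
        # scheduler together with the backbone's.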
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)
        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
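        # Domain labels for the binary critic: source samples are 1, target
        # (unlabeled) samples are 0.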
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
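
        # Gradient-reversal coefficient follows the schedule from the DANN
        # paper: lmda = 2 / (1 + exp(-10 * p)) - 1, where p is the fraction
        # of training completed; it ramps from 0 to 1 so the adversarial
        # signal stays weak early in training.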
        global_step = self.batch_idx + self.epoch * self.num_batches
        progress = global_step / (self.max_epoch * self.num_batches)
        lmda = 2 / (1 + np.exp(-10 * progress)) - 1

        logit_x, feat_x = self.model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        loss_x = self.ce(logit_x, label_x)
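
        # ReverseGrad is an identity on the forward pass but flips (and
        # scales by lmda) gradients on the backward pass, so minimizing the
        # critic's domain loss pushes the backbone toward domain-invariant
        # features.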
        feat_x = self.revgrad(feat_x, grad_scaling=lmda)
        feat_u = self.revgrad(feat_u, grad_scaling=lmda)
        output_xd = self.critic(feat_x)
        output_ud = self.critic(feat_u)
        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)

        loss = loss_x + loss_d
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
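
# Minimal usage sketch (an assumption, not part of the original file): with
# Dassl installed, a registered trainer is normally instantiated from a
# config via the registry, e.g.
#
#     from dassl.engine import build_trainer
#     trainer = build_trainer(cfg)  # requires cfg.TRAINER.NAME == "DANN"
#     trainer.train()
#
# Here cfg is a Dassl config node with the dataset, model, and OPTIM fields
# already populated (see Dassl's train.py script for the standard setup).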