-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprototypical_loss.py
76 lines (63 loc) · 3 KB
/
prototypical_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# coding=utf-8
import torch
from torch.nn import functional as F
from torch.nn.modules import Module
import sys
def euclidean_dist(x, y):
    '''
    Compute the pairwise squared euclidean distance between two tensors

    Note that no square root is taken: entry (i, j) of the result is
    sum_k (x[i, k] - y[j, k]) ** 2.
    Args:
    - x: tensor of shape N x D
    - y: tensor of shape M x D
    Returns:
    - tensor of shape N x M holding the squared distances
    Raises:
    - ValueError: if x and y do not have the same size along dimension 1
    '''
    # x: N x D
    # y: M x D
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    if d != y.size(1):
        # ValueError is still an Exception, so callers catching the old
        # bare Exception keep working; the message now says what is wrong.
        raise ValueError(
            'euclidean_dist: feature dimensions differ '
            '(x has {}, y has {})'.format(d, y.size(1))
        )

    # Broadcast both operands to N x M x D so every row of x is paired with
    # every row of y, then reduce over the feature axis.
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)

    return torch.pow(x - y, 2).sum(2)
def prototypical_loss(input, target, n_support):
    '''
    Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py

    Compute the barycentres (prototypes) by averaging the features of the
    first n_support samples of each class in target, then compute the
    squared euclidean distance from each remaining (query) sample to every
    barycentre. The negated distances go through a log-softmax to give, for
    each query sample, the log-probability of belonging to each class.
    The loss is the mean negative log-probability of the true class; the
    accuracy is the fraction of query samples whose most probable class is
    the true one.

    All computation is carried out on CPU copies of input and target.
    Every class is assumed to have the same number of samples in the batch
    (n_support + n_query); n_query is derived from the first class only.

    Args:
    - input: the model output for a batch of samples
    - target: ground truth for the above batch of samples
    - n_support: number of samples to keep in account when computing
      barycentres, for each one of the current classes
    Returns:
    - (loss_val, acc_val): the loss and the accuracy, both scalar tensors
    '''
    target_cpu = target.to('cpu')
    input_cpu = input.to('cpu')

    # Unique class labels present in this batch.
    classes = torch.unique(target_cpu)
    n_classes = len(classes)

    # assuming n_query, n_target constants across classes
    n_query = target_cpu.eq(classes[0].item()).sum().item() - n_support

    # Positions of every sample of each class, in batch order. The first
    # n_support positions are the support set, the rest the query set.
    # assumes target is 1-D, so nonzero() yields one column -- TODO confirm
    class_idxs = [target_cpu.eq(c).nonzero().squeeze(1) for c in classes]

    # One prototype per class: mean of its support features.
    # e.g. 60 x 64 for 60 classes and 64 features
    prototypes = torch.stack(
        [input_cpu[idxs[:n_support]].mean(0) for idxs in class_idxs]
    )

    # Query samples are laid out class by class, so consecutive groups of
    # n_query rows belong to the same class.
    query_idxs = torch.stack(
        [idxs[n_support:] for idxs in class_idxs]
    ).view(-1)
    # Reuse the CPU copy instead of transferring input a second time.
    query_samples = input_cpu[query_idxs]

    # (n_classes * n_query) x n_classes
    dists = euclidean_dist(query_samples, prototypes)

    # n_classes x n_query x n_classes
    log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)

    # Since queries are grouped by class, the true class index of every
    # query in group i is simply i.
    target_inds = torch.arange(0, n_classes)
    target_inds = target_inds.view(n_classes, 1, 1)
    target_inds = target_inds.expand(n_classes, n_query, 1).long()

    # gather picks, for each query, the log-probability of its true class.
    loss_val = -log_p_y.gather(2, target_inds).view(-1).mean()

    # y_hat: n_classes x n_query, predicted class index per query
    _, y_hat = log_p_y.max(2)
    # Squeeze only the trailing singleton axis. A bare squeeze() would also
    # drop the query axis when n_query == 1, making eq() broadcast to
    # n_classes x n_classes and report a wrong accuracy.
    acc_val = y_hat.eq(target_inds.squeeze(2)).float().mean()

    # NOTE(review): prints on every call; kept to preserve existing output.
    print('loss&acc:', (loss_val, acc_val))
    return loss_val, acc_val