prototypical_batch_sampler.py
# coding=utf-8
import numpy as np
import torch


# Note: in this file "iteration" is used interchangeably with "episode".
class PrototypicalBatchSampler(object):
    '''
    PrototypicalBatchSampler: yields a batch of indexes at each iteration (episode).

    Indexes are computed taking 'classes_per_it' and 'num_samples' into account:
    at every iteration the batch indexes refer to 'num_support' + 'num_query' samples
    for 'classes_per_it' randomly selected classes.

    __len__ returns the number of episodes per epoch (i.e. 'self.iterations').
    '''

def __init__(self, labels, classes_per_it, num_samples, iterations):
        '''
        Initialize the PrototypicalBatchSampler object.

        Args:
        - labels: an iterable containing all the labels for the current dataset;
          sample indexes will be inferred from this iterable.
        - classes_per_it: number of random classes per iteration
        - num_samples: number of samples per class per iteration (support + query)
        - iterations: number of iterations (episodes) per epoch
        '''
super(PrototypicalBatchSampler, self).__init__()
self.labels = labels
self.classes_per_it = classes_per_it
self.sample_per_class = num_samples
self.iterations = iterations
        # unique classes in the dataset and the number of samples for each class
        self.classes, self.counts = np.unique(self.labels, return_counts=True)
self.classes = torch.LongTensor(self.classes)
        # create a matrix, indexes, of dim: classes X max(elements per class),
        # fill it with nans;
        # for every class c, fill the relative row with the indices of the samples belonging to c;
        # in numel_per_class we store the number of samples for each class/row
        self.idxs = range(len(self.labels))
        self.indexes = np.full((len(self.classes), max(self.counts)), np.nan)
        self.indexes = torch.Tensor(self.indexes)
self.numel_per_class = torch.zeros_like(self.classes)
for idx, label in enumerate(self.labels):
label_idx = np.argwhere(self.classes == label).item()
self.indexes[label_idx, np.where(np.isnan(self.indexes[label_idx]))[0][0]] = idx
self.numel_per_class[label_idx] += 1

    def __iter__(self):
'''
yield a batch of indexes
'''
spc = self.sample_per_class
cpi = self.classes_per_it
for it in range(self.iterations):
batch_size = spc * cpi
batch = torch.LongTensor(batch_size)
            c_idxs = torch.randperm(len(self.classes))[:cpi]  # torch.randperm(n): random permutation of the integers 0..n-1
for i, c in enumerate(self.classes[c_idxs]):
                s = slice(i * spc, (i + 1) * spc)  # slice object selecting this class's positions in the batch
                # FIXME: replace with torch.argwhere once it is available
                label_idx = torch.arange(len(self.classes)).long()[self.classes == c].item()
                # note: label_idx is equal to c in value, except that c is a 0-dim tensor
                sample_idxs = torch.randperm(self.numel_per_class[label_idx].item())[:spc]
batch[s] = self.indexes[label_idx][sample_idxs]
batch = batch[torch.randperm(len(batch))]
            # print('batch:', batch)  # debug output, disabled
yield batch

    def __len__(self):
'''
returns the number of iterations (episodes) per epoch
'''
return self.iterations
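

# Minimal usage sketch: how a sampler like this is typically plugged into a
# torch.utils.data.DataLoader via the `batch_sampler` argument. The toy dataset,
# its label layout and the hyper-parameter values below are illustrative
# assumptions, not values prescribed by this file.
if __name__ == '__main__':
    import torch.utils.data as data

    class ToyDataset(data.Dataset):
        '''Hypothetical dataset: 10 classes, 20 samples per class, random features.'''
        def __init__(self):
            self.labels = np.repeat(np.arange(10), 20)   # 200 integer labels
            self.x = torch.randn(len(self.labels), 64)   # fake 64-d features

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            return self.x[idx], int(self.labels[idx])

    dataset = ToyDataset()
    sampler = PrototypicalBatchSampler(labels=dataset.labels,
                                       classes_per_it=5,
                                       num_samples=10,   # support + query per class
                                       iterations=100)
    loader = data.DataLoader(dataset, batch_sampler=sampler)

    x, y = next(iter(loader))
    # each episode batch holds classes_per_it * num_samples = 50 samples
    print(x.shape, y.shape)  # expected: torch.Size([50, 64]) torch.Size([50])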