-
Notifications
You must be signed in to change notification settings - Fork 25
/
data.py
122 lines (102 loc) · 3.7 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import copy
import math
from torchvision import datasets, transforms
from torchvision.transforms import ImageOps
from torch.utils.data import ConcatDataset
def _permutate_image_pixels(image, permutation):
if permutation is None:
return image
c, h, w = image.size()
image = image.view(-1, c)
image = image[permutation, :]
return image.view(c, h, w)
def _colorize_grayscale_image(image):
    """Colorize ``image`` with ``ImageOps.colorize`` using a black-to-white ramp.

    The darkest input value maps to ``(0, 0, 0)`` and the brightest to
    ``(255, 255, 255)``; the result is whatever ``ImageOps.colorize`` returns
    (presumably a 3-channel image - confirm against the Pillow docs).
    """
    black = (0, 0, 0)
    white = (255, 255, 255)
    return ImageOps.colorize(image, black, white)
def get_dataset(name, train=True, permutation=None, capacity=None):
    """Build a registered dataset, optionally permuted and padded by repetition.

    Args:
        name: Key into ``TRAIN_DATASETS`` / ``TEST_DATASETS``.
        train: Select the training registry when true, the test registry
            otherwise.
        permutation: Passed to ``_permutate_image_pixels`` for every sample;
            ``None`` leaves samples unpermuted.
        capacity: When given and larger than the dataset length, the dataset
            is repeated until the total length reaches at least ``capacity``.

    Returns:
        The dataset itself, or a ``ConcatDataset`` of deep copies of it when
        it had to be repeated to reach ``capacity``.

    Raises:
        KeyError: If ``name`` is not in the selected registry.
    """
    registry = TRAIN_DATASETS if train else TEST_DATASETS
    dataset = registry[name]()

    # Append the permutation step after the dataset's own pipeline.
    base_transform = dataset.transform
    permute_step = transforms.Lambda(
        lambda x: _permutate_image_pixels(x, permutation)
    )
    dataset.transform = transforms.Compose([base_transform, permute_step])

    needs_repeating = capacity is not None and len(dataset) < capacity
    if not needs_repeating:
        return dataset

    # Enough whole copies to cover the requested capacity.
    n_copies = math.ceil(capacity / len(dataset))
    return ConcatDataset([copy.deepcopy(dataset) for _ in range(n_copies)])
# Per-dataset transform pipelines. Train and test names are bound to the SAME
# list object, exactly as a chained assignment would do.

# MNIST: pad by 2 pixels per side (matches the 'size': 32 config entry).
_MNIST_TRAIN_TRANSFORMS = [
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.Pad(2),
    transforms.ToTensor(),
]
_MNIST_TEST_TRANSFORMS = _MNIST_TRAIN_TRANSFORMS

# Colorized MNIST: same as above with a colorize step before padding.
_MNIST_COLORIZED_TRAIN_TRANSFORMS = [
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.Lambda(lambda img: _colorize_grayscale_image(img)),
    transforms.Pad(2),
    transforms.ToTensor(),
]
_MNIST_COLORIZED_TEST_TRANSFORMS = _MNIST_COLORIZED_TRAIN_TRANSFORMS

# CIFAR: tensor conversion only.
_CIFAR_TRAIN_TRANSFORMS = [transforms.ToTensor()]
_CIFAR_TEST_TRANSFORMS = _CIFAR_TRAIN_TRANSFORMS

# SVHN: tensor conversion only for images.
_SVHN_TRAIN_TRANSFORMS = [transforms.ToTensor()]
_SVHN_TEST_TRANSFORMS = _SVHN_TRAIN_TRANSFORMS

# SVHN targets are reduced modulo 10.
_SVHN_TARGET_TRANSFORMS = [
    transforms.Lambda(lambda label: label % 10),
]
# Registry of zero-argument factories for the training splits. Each call
# constructs a fresh dataset (downloading it if necessary) with a freshly
# composed transform pipeline.
TRAIN_DATASETS = {
    'mnist': lambda: datasets.MNIST(
        root='./datasets/mnist',
        train=True,
        download=True,
        transform=transforms.Compose(_MNIST_TRAIN_TRANSFORMS),
    ),
    'mnist-color': lambda: datasets.MNIST(
        root='./datasets/mnist',
        train=True,
        download=True,
        transform=transforms.Compose(_MNIST_COLORIZED_TRAIN_TRANSFORMS),
    ),
    'cifar10': lambda: datasets.CIFAR10(
        root='./datasets/cifar10',
        train=True,
        download=True,
        transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS),
    ),
    'cifar100': lambda: datasets.CIFAR100(
        root='./datasets/cifar100',
        train=True,
        download=True,
        transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS),
    ),
    'svhn': lambda: datasets.SVHN(
        root='./datasets/svhn',
        split='train',
        download=True,
        transform=transforms.Compose(_SVHN_TRAIN_TRANSFORMS),
        target_transform=transforms.Compose(_SVHN_TARGET_TRANSFORMS),
    ),
}
# Registry of zero-argument factories for the test splits. Each call
# constructs a fresh dataset with a freshly composed transform pipeline.
#
# Every entry passes download=True so a test split can be requested on a
# fresh checkout without first instantiating the matching training split.
TEST_DATASETS = {
    'mnist': lambda: datasets.MNIST(
        './datasets/mnist', train=False, download=True,
        transform=transforms.Compose(_MNIST_TEST_TRANSFORMS)
    ),
    'mnist-color': lambda: datasets.MNIST(
        './datasets/mnist', train=False, download=True,
        transform=transforms.Compose(_MNIST_COLORIZED_TEST_TRANSFORMS)
    ),
    'cifar10': lambda: datasets.CIFAR10(
        './datasets/cifar10', train=False, download=True,
        transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS)
    ),
    'cifar100': lambda: datasets.CIFAR100(
        './datasets/cifar100', train=False, download=True,
        transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS)
    ),
    'svhn': lambda: datasets.SVHN(
        './datasets/svhn', split='test', download=True,
        transform=transforms.Compose(_SVHN_TEST_TRANSFORMS),
        target_transform=transforms.Compose(_SVHN_TARGET_TRANSFORMS),
    ),
}
# Static per-dataset metadata: image side length, channel count and number of
# target classes, keyed by the same names as the dataset registries.
DATASET_CONFIGS = {
    'mnist': dict(size=32, channels=1, classes=10),
    'mnist-color': dict(size=32, channels=3, classes=10),
    'cifar10': dict(size=32, channels=3, classes=10),
    'cifar100': dict(size=32, channels=3, classes=100),
    'svhn': dict(size=32, channels=3, classes=10),
}