-
Notifications
You must be signed in to change notification settings - Fork 1
/
conditioned_modules.py
70 lines (61 loc) · 1.99 KB
/
conditioned_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
class MaxPoolStride1(nn.Module):
    """Stride-1 max pooling that keeps the input's spatial size.

    Before pooling, the input is padded on the right and bottom by
    ``kernel_size - 1`` using edge replication. With stride 1, an input of
    spatial size ``(H, W)`` therefore yields an output of size ``(H, W)``.

    Args:
        kernel_size: side length of the square pooling window. The default of
            2 reproduces the previously hard-coded behaviour exactly.

    Raises:
        ValueError: if ``kernel_size`` is smaller than 1.
    """

    def __init__(self, kernel_size=2):
        super(MaxPoolStride1, self).__init__()
        if kernel_size < 1:
            raise ValueError(
                'kernel_size must be >= 1, got {}'.format(kernel_size))
        self.kernel_size = kernel_size

    def forward(self, x):
        pad = self.kernel_size - 1
        if pad:
            # F.pad order for the last two dims is (left, right, top, bottom):
            # pad only right/bottom. Every replicated cell copies an edge value
            # that lies inside the same pooling window, so padding never
            # changes a maximum.
            x = F.pad(x, (0, pad, 0, pad), mode='replicate')
        return F.max_pool2d(x, self.kernel_size, stride=1)
class Upsample(nn.Module):
    """Nearest-neighbour upsampling by an integer factor along H and W.

    Each input pixel is copied into a ``stride x stride`` patch, so a
    ``(B, C, H, W)`` input becomes ``(B, C, H * stride, W * stride)``.
    """

    def __init__(self, stride=2):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        factor = self.stride
        assert x.dim() == 4
        batch, channels, height, width = x.size()
        # Add a singleton axis after H and after W, then broadcast both to
        # `factor`. expand() does not copy; contiguous() materialises the
        # repeated values before they are folded back into H and W.
        tiled = x[:, :, :, None, :, None].expand(
            batch, channels, height, factor, width, factor)
        return tiled.contiguous().view(
            batch, channels, height * factor, width * factor)
class Reorg(nn.Module):
    """Space-to-depth rearrangement with a square block of side ``stride``.

    A ``(B, C, H, W)`` input becomes ``(B, stride*stride*C, H/stride,
    W/stride)``. Output channel ``(i * stride + j) * C + c`` at position
    ``(h, w)`` holds input ``[c, h * stride + i, w * stride + j]``. H and W must
    be divisible by ``stride`` (this is checked with ``assert``).
    """

    def __init__(self, stride=2):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        s = self.stride
        assert x.dim() == 4
        batch, channels, height, width = x.size()
        assert height % s == 0
        assert width % s == 0
        out_h, out_w = height // s, width // s
        # Split H -> (out_h, i) and W -> (out_w, j), then move the in-block
        # offsets (i, j) in front of the channel axis. Flattening
        # (i, j, c) then gives channel index (i * s + j) * C + c.
        blocks = x.view(batch, channels, out_h, s, out_w, s)
        blocks = blocks.permute(0, 3, 5, 1, 2, 4).contiguous()
        return blocks.view(batch, s * s * channels, out_h, out_w)
class GlobalAvgPool2d(nn.Module):
    """Average every channel over its full spatial plane.

    Maps an ``(N, C, H, W)`` tensor to ``(N, C)`` by average-pooling with a
    window that covers the whole ``H x W`` extent.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        n_batch, n_chan = x.size(0), x.size(1)
        window = (x.size(2), x.size(3))
        pooled = F.avg_pool2d(x, window)  # shape (N, C, 1, 1)
        return pooled.view(n_batch, n_chan)
class EmptyModule(nn.Module):
    """Pass-through placeholder: ``forward`` hands back its input untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity: the very same object is returned, not a copy.
        return x