-- Transpose2.lua (forked from jfsantos/seq2seq)
local Transpose2, parent = torch.class('nn.Transpose2', 'nn.Module')

-- nn.Transpose2: applies a sequence of dimension transpositions.
--   nn.Transpose2({1,4},{1,3})        -- swap dims 1 and 4, then 1 and 3
--   nn.Transpose2({1,4},{1,3}, nDim)  -- trailing number = expected nDim of a
--                                     -- single (non-batch) input; used to
--                                     -- detect mini-batch inputs
function Transpose2:__init(...)
   parent.__init(self)
   self.permutations = {}
   local args = {...}
   -- Tables are dimension pairs to swap; a bare number sets self.nDim.
   for i = 1, #args do
      local a = args[i]
      if type(a) == 'table' then
         self.permutations[#self.permutations + 1] = a
      else
         self.nDim = a
      end
   end
end
-- Forward pass: apply each configured permutation in order (as views),
-- then copy the final transposed view into self.output so the output
-- owns contiguous storage.
function Transpose2:updateOutput(input)
   -- Offset of 1 when a leading mini-batch dimension is present.
   local offset
   if self.nDim == nil or input:nDimension() == self.nDim then
      offset = 0
   elseif input:nDimension() == self.nDim + 1 then
      offset = 1
   else
      error('inconsistent tensor size')
   end
   local view = input
   for _, pair in ipairs(self.permutations) do
      view = view:transpose(pair[1] + offset, pair[2] + offset)
   end
   self.output:resizeAs(view):copy(view)
   return self.output
end
-- Backward pass: undo the forward permutations. A transpose is its own
-- inverse, so applying the same swaps to gradOutput in REVERSE order
-- maps the gradient back into the input's layout.
function Transpose2:updateGradInput(input, gradOutput)
   -- Offset of 1 when a leading mini-batch dimension is present
   -- (same detection rule as updateOutput).
   local offset
   if self.nDim == nil or input:nDimension() == self.nDim then
      offset = 0
   elseif input:nDimension() == self.nDim + 1 then
      offset = 1
   else
      error('inconsistent tensor size')
   end
   local grad = gradOutput
   local i = #self.permutations
   while i >= 1 do
      local pair = self.permutations[i]
      grad = grad:transpose(pair[1] + offset, pair[2] + offset)
      i = i - 1
   end
   self.gradInput:resizeAs(grad):copy(grad)
   return self.gradInput
end