AddDim.lua
require 'nn'

-- nn.AddDim inserts a singleton dimension immediately after dimension `dim`
-- of an `ndim`-dimensional input (or in front of all dimensions when dim == 0).
-- Inputs with ndim + 1 dimensions are treated as batches: the extra leading
-- dimension is preserved and the insertion point is shifted accordingly.
local AddDim, parent = torch.class('nn.AddDim', 'nn.Module')

function AddDim:__init(dim, ndim)
   parent.__init(self)
   assert(dim ~= nil, 'dim cannot be nil')
   if ndim ~= nil then
      assert(dim <= ndim, 'dim must be <= ndim')
   end
   self.dim = dim
   self.ndim = ndim
end

function AddDim:updateOutput(input)
   -- Decide where the singleton goes, depending on whether the input
   -- carries an extra leading batch dimension.
   local dim
   if input:nDimension() == self.ndim then
      -- non-batch mode
      dim = self.dim
   elseif input:nDimension() == self.ndim + 1 then
      -- batch mode: shift past the batch dimension
      dim = self.dim + 1
   else
      error('inconsistent tensor size')
   end

   -- Build the target size: copy the input sizes and splice in a 1
   -- right after position `dim` (or at the front when dim == 0).
   local size = {}
   if dim == 0 then
      table.insert(size, 1)
   end
   for i = 1, input:nDimension() do
      table.insert(size, input:size(i))
      if i == dim then
         table.insert(size, 1)
      end
   end

   -- A view shares storage with the input, so no data is copied.
   self.output = input:view(unpack(size))
   return self.output
end

function AddDim:updateGradInput(input, gradOutput)
   -- The forward pass only reshaped the tensor, so the gradient is the
   -- incoming gradient viewed back to the input's original size.
   self.gradInput = gradOutput:view(input:size())
   return self.gradInput
end