-
Notifications
You must be signed in to change notification settings - Fork 0
/
SampleWeighter.lua
57 lines (44 loc) · 1.47 KB
/
SampleWeighter.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
require 'nn'
-----------
-- Weights backpropagation of different samples in batch by specified factors.
-- SampleWeighter: an nn.Module that weights the backpropagated gradient of
-- each sample in a batch by externally supplied per-sample factors.
local SampleWeighter, parent = torch.class('myrock.SampleWeighter', 'nn.Module')

--- Constructor. Factors start as an empty tensor; callers must populate
-- them through setFactors() before the backward pass uses them.
function SampleWeighter:__init()
   parent.__init(self)
   self.factors = torch.Tensor() --set externally
end
--- Forward pass: identity. The module passes the input through untouched;
-- all of its work happens in the backward pass.
function SampleWeighter:updateOutput(input)
   self.output = input
   return input
end
--- Backward pass. In training mode, multiplies each sample's gradient by its
-- factor (set via setFactors) and then rescales the whole gradient so its L1
-- norm matches the pre-weighting value (keeps the gradient "energy" constant).
-- In evaluation mode the gradient passes through unchanged.
-- NOTE: gradOutput is modified in place; gradInput aliases gradOutput.
function SampleWeighter:updateGradInput(input, gradOutput)
   if self.train then
      -- For 2D (batch x features) gradients, reshape factors to a column so
      -- they broadcast across the feature dimension via expandAs.
      local factors = gradOutput:dim()==2 and self.factors:view(-1,1) or self.factors
      local L1orig = gradOutput:norm(1) + 1e-6 -- epsilon guards the division below
      gradOutput:cmul(factors:expandAs(gradOutput))
      local L1new = gradOutput:norm(1) + 1e-6
      gradOutput:mul(L1orig/L1new) --keep the L1-norm of gradient ("its energy") for consistency
   end
   -- Single assignment here; the original redundantly assigned gradInput both
   -- inside the train branch and again unconditionally.
   self.gradInput = gradOutput
   return self.gradInput
end
--- Stores a private copy of the per-sample weighting factors that the next
-- backward pass will apply (one factor per sample in the batch).
function SampleWeighter:setFactors(factors)
   self.factors:resizeAs(factors):copy(factors)
end
-----TEST
--[[
local mytest = {}
local OFFmytest = {}
local tester = torch.Tester()
function mytest.test1()
local input = torch.Tensor{1,2,3,0,4,0}
local output = torch.Tensor{10,20,30,0,40,0}
local module = myrock.SampleWeighter()
module:training()
module:setFactors(torch.Tensor{0,0.1,0.1,0,1,0})
module:forward(input)
module:backward(input, output)
print(module.gradInput)
end
tester:add(mytest)
tester:run() --]]