-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGradLM.py
66 lines (49 loc) · 1.58 KB
/
GradLM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from abc import ABC
import torch
class Function(ABC):
    """Base class for a parametric model optimised by GradLM.

    Stores the input data ``x`` and a 4-element parameter vector drawn
    uniformly at random from [0, 1). Subclasses are expected to override
    the ``value``, ``jacobian`` and ``calc`` hooks; the stubs here return
    ``None``.
    """

    def __init__(self, x):
        # Input the function is evaluated on.
        self.x = x
        # Four randomly initialised parameters.
        self.params = torch.rand(4)

    def value(self):
        """Evaluate the function at the current parameters (subclass hook)."""
        pass

    def jacobian(self):
        """Jacobian w.r.t. the parameters (subclass hook)."""
        pass

    def calc(self):
        """Auxiliary computation hook for subclasses."""
        pass

    def update_params(self, d):
        """Shift the parameter vector by ``d`` (out-of-place; rebinds params)."""
        self.params = torch.add(self.params, d)
class GradLM:
    """Levenberg–Marquardt optimiser with sigmoid-gated damping and step size.

    Fits ``func`` to the target ``y`` by repeatedly solving the damped
    normal equations. Both the damping factor and the accepted step are
    modulated by logistic functions of the change in squared residual
    between consecutive iterations (r1 - r0).
    """

    def __init__(self, y, func, lamda_min=0.1, lambda_max=1, D=1, sigma=1e-5, tol=1e-8, max_iter=100):
        # Target values and the model being fitted.
        self.y = y
        self.func = func
        # Bounds and shape parameters of the logistic damping schedule.
        self.lamda_min = lamda_min
        self.lambda_max = lambda_max
        self.D = D
        self.sigma = sigma
        # Convergence controls: step-norm tolerance and iteration cap.
        self.tol = tol
        self.max_iter = max_iter

    def _qLambda(self, r0, r1):
        """Interpolate the damping factor between its bounds via a logistic gate."""
        spread = self.lambda_max - self.lamda_min
        gate = 1 + self.D * torch.exp(-self.sigma * (r1 - r0))
        return self.lamda_min + spread / gate

    def _qX(self, dx, r0, r1):
        """Sigmoid-scale the raw step ``dx``; shrinks it when the residual grew."""
        return dx / (1 + torch.exp(r0 - r1))

    def step(self, lmda):
        """Solve the damped normal equations for the raw LM update direction."""
        J = self.func.jacobian()
        residual = self.y - self.func.value()
        damped = J.T @ J + lmda * torch.eye(J.shape[-1]).type_as(J)
        return torch.inverse(damped) @ (J.T @ residual)

    def optimize(self):
        """Iterate up to ``max_iter`` gated LM steps; return the fitted ``func``.

        Stops early once the raw step norm drops below ``tol`` (checked after
        the step has already been applied).
        """
        lmda = self.lambda_max
        residual = self.y - self.func.value()
        prev_sq = residual.T @ residual
        curr_sq = prev_sq.clone()
        for _ in range(self.max_iter):
            direction = self.step(lmda)
            # Gate the step by the previous iteration's residual change.
            self.func.update_params(self._qX(direction, prev_sq, curr_sq))
            if direction.norm() < self.tol:
                break
            prev_sq = curr_sq.clone()
            residual = self.y - self.func.value()
            curr_sq = residual.T @ residual
            lmda = self._qLambda(prev_sq, curr_sq)
        return self.func