Add a test showing how to implement scipy's 'soft_l1' loss in `minimize.least_squares`

PiperOrigin-RevId: 718649580
Change-Id: I5cec0e2850c805301407c9eb543ab3ba69d42e79
yuvaltassa authored and copybara-github committed Jan 23, 2025
1 parent 21d9790 commit f9fd2e4
Showing 1 changed file with 28 additions and 0 deletions.
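
For context: scipy defines 'soft_l1' as rho(z) = 2 * (sqrt(1 + z) - 1) applied to z = r**2, and its reported cost is 0.5 * sum(rho(r**2)) = sum(sqrt(1 + r**2) - 1), which matches the value() implemented by the new test below. A minimal sketch of the equivalent scipy call on the same Rosenbrock residual (illustrative only, not part of the commit; scipy expects a 1-D x):

import numpy as np
import scipy.optimize

def residual(x):
  # Rosenbrock residuals: both are zero at the minimum x = (1, 1).
  return np.array([1 - x[0], 10 * (x[1] - x[0] ** 2)])

result = scipy.optimize.least_squares(
    residual, x0=np.array([0.0, 0.0]), loss='soft_l1'
)
# result.x is approximately (1.0, 1.0).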
28 changes: 28 additions & 0 deletions python/mujoco/minimize_test.py
@@ -292,6 +292,34 @@ def grad_hess(self, r, proj):
        check_derivatives=True,
    )

  def test_soft_l1_norm(self) -> None:
    def residual(x):
      return np.stack([1 - x[0, :], 10 * (x[1, :] - x[0, :] ** 2)])

    class SoftL1(minimize.Norm):
      """Implementation of the loss called 'soft_l1' in scipy least_squares."""

      def value(self, r):
        return np.sum(np.sqrt(r**2 + 1) - 1)

      def grad_hess(self, r, proj):
        s = np.sqrt(r**2 + 1)
        y_r = r / s
        grad = proj.T @ y_r
        y_rr = (1 - y_r**2) / s
        hess = proj.T @ (y_rr * proj)
        return grad, hess

    out = io.StringIO()
    x0 = np.array((0.0, 0.0))
    x, _ = minimize.least_squares(
        x0, residual, norm=SoftL1(), output=out, check_derivatives=True
    )
    expected_x = np.array((1.0, 1.0))
    np.testing.assert_array_almost_equal(x, expected_x)
    self.assertIn('User-provided norm gradient matches', out.getvalue())
    self.assertIn('User-provided norm Hessian matches', out.getvalue())


if __name__ == '__main__':
  absltest.main()
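
A note on the derivatives used above: with s = sqrt(r**2 + 1), the per-element derivative of value() is r/s (y_r) and the second derivative is (1 - (r/s)**2)/s = 1/s**3 (y_rr); proj then plays the role of the residual Jacobian in the Gauss-Newton-style expressions grad = proj.T @ y_r and hess = proj.T @ (y_rr * proj). A quick standalone finite-difference check of those per-element formulas (a sketch, not part of the commit):

import numpy as np

def value(r):
  return np.sum(np.sqrt(r**2 + 1) - 1)

def y_r(r):
  # Analytic first derivative, as in SoftL1.grad_hess.
  return r / np.sqrt(r**2 + 1)

def y_rr(r):
  # Analytic second derivative; equals 1 / s**3.
  s = np.sqrt(r**2 + 1)
  return (1 - (r / s) ** 2) / s

r = np.array([0.3, -1.7, 2.5])
eps = 1e-6
for i in range(len(r)):
  e = np.zeros_like(r)
  e[i] = eps
  # Central differences of value() and y_r() along coordinate i.
  fd_grad = (value(r + e) - value(r - e)) / (2 * eps)
  np.testing.assert_allclose(y_r(r)[i], fd_grad, rtol=1e-6)
  fd_hess = (y_r(r + e)[i] - y_r(r - e)[i]) / (2 * eps)
  np.testing.assert_allclose(y_rr(r)[i], fd_hess, rtol=1e-5)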
