Skip to content

Commit

Permalink
Merge pull request #114 from Simple-Robotics/topic/linesearch
Browse files Browse the repository at this point in the history
Minor changes to linesearch
  • Loading branch information
ManifoldFR authored Nov 1, 2024
2 parents f264576 + 84a3c81 commit 9827759
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Remove redundant `#include <cassert>` in `math.hpp`
- linesearch-armijo.hpp : some changes
- linesearch-base.hpp : add reset()
- linesearch : remove getter and setter for options, make struct public
- python : linesearch : expose linesearch classes to Python

## [0.9.0] - 2024-10-14

Expand Down
9 changes: 8 additions & 1 deletion bindings/python/expose-solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ void exposeSolver() {
.value("LDLT_PROXSUITE", LDLTChoice::PROXSUITE)
.export_values();

using LinesearchOptions = Linesearch<Scalar>::Options;
using Linesearch = Linesearch<Scalar>;
using LinesearchOptions = Linesearch::Options;
bp::class_<Linesearch>("Linesearch", bp::no_init)
.def(bp::init<const LinesearchOptions &>(("self"_a, "options")))
.def_readwrite("options", &Linesearch::options_);
bp::class_<ArmijoLinesearch<Scalar>, bp::bases<Linesearch>>(
"ArmijoLinesearch", bp::no_init)
.def(bp::init<const LinesearchOptions &>(("self"_a, "options")));
bp::class_<LinesearchOptions>("LinesearchOptions", "Linesearch options.",
bp::init<>(("self"_a), "Default constructor."))
.def_readwrite("armijo_c1", &LinesearchOptions::armijo_c1)
Expand Down
47 changes: 23 additions & 24 deletions include/proxsuite-nlp/linesearch-armijo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ template <typename Scalar>
class ArmijoLinesearch final : public Linesearch<Scalar> {
public:
using Base = Linesearch<Scalar>;
using Base::options_;
using FunctionSample = typename Base::FunctionSample;
using Polynomial = PolynomialTpl<Scalar>;
using VectorXs = typename math_types<Scalar>::VectorXs;
Expand All @@ -53,8 +54,10 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {

ArmijoLinesearch(const typename Base::Options &options) : Base(options) {}

template <typename Fn>
Scalar run(Fn phi, const Scalar phi0, const Scalar dphi0, Scalar &alpha_try) {
using fun_t = std::function<Scalar(Scalar)>;

Scalar run(fun_t phi, const Scalar phi0, const Scalar dphi0,
Scalar &alpha_try) {
const FunctionSample lower_bound(0., phi0, dphi0);

alpha_try = 1.;
Expand All @@ -69,26 +72,26 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
break;
} catch (const std::runtime_error &e) {
alpha_try *= 0.5;
if (alpha_try <= options().alpha_min) {
alpha_try = options().alpha_min;
if (alpha_try <= options_.alpha_min) {
alpha_try = options_.alpha_min;
break;
}
}
}

if (std::abs(dphi0) < options().dphi_thresh) {
if (std::abs(dphi0) < options_.dphi_thresh) {
return latest.phi;
}

for (std::size_t i = 0; i < options().max_num_steps; i++) {
for (std::size_t i = 0; i < options_.max_num_steps; i++) {

const Scalar dM = latest.phi - phi0;
if (dM <= options().armijo_c1 * alpha_try * dphi0) {
if (dM <= options_.armijo_c1 * alpha_try * dphi0) {
break;
}

// compute next alpha try
LSInterpolation strat = options().interp_type;
LSInterpolation strat = options_.interp_type;
if (strat == LSInterpolation::BISECTION) {
alpha_try *= 0.5;
} else {
Expand Down Expand Up @@ -117,15 +120,15 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
}

alpha_try = this->minimize_interpolant(
strat, options().contraction_min * alpha_try,
options().contraction_max * alpha_try);
strat, options_.contraction_min * alpha_try,
options_.contraction_max * alpha_try);
}

if (std::isnan(alpha_try)) {
// handle NaN case
alpha_try = options().contraction_min * previous.alpha;
alpha_try = options_.contraction_min * previous.alpha;
} else {
alpha_try = std::max(alpha_try, options().alpha_min);
alpha_try = std::max(alpha_try, options_.alpha_min);
}

try {
Expand All @@ -135,11 +138,11 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
continue;
}

if (alpha_try <= options().alpha_min) {
if (alpha_try <= options_.alpha_min) {
break;
}
}
alpha_try = std::max(alpha_try, options().alpha_min);
alpha_try = std::max(alpha_try, options_.alpha_min);
return latest.phi;
}

Expand Down Expand Up @@ -176,15 +179,18 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
const FunctionSample &cand1 = samples[2];
const Scalar &a0 = cand0.alpha;
const Scalar &a1 = cand1.alpha;
Matrix2s alph_mat;
Vector2s coeffs_cubic_interpolant;
/// Solver for the 2x2 linear system
alph_mat(0, 0) = a0 * a0 * a0;
alph_mat(0, 1) = a0 * a0;
alph_mat(1, 0) = a1 * a1 * a1;
alph_mat(1, 1) = a1 * a1;

alph_rhs(0) = cand1.phi - phi0 - dphi0 * a1;
alph_rhs(1) = cand0.phi - phi0 - dphi0 * a0;
Vector2s alph_rhs{cand1.phi - phi0 - dphi0 * a1,
cand0.phi - phi0 - dphi0 * a0};

decomp.compute(alph_mat);
Eigen::HouseholderQR<Matrix2s> decomp(alph_mat);
coeffs_cubic_interpolant = decomp.solve(alph_rhs);

const Scalar c3 = coeffs_cubic_interpolant(0);
Expand Down Expand Up @@ -216,15 +222,8 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
}

protected:
using Base::options;
Polynomial interpolant;
std::vector<FunctionSample> samples; // interpolation samples

Matrix2s alph_mat;
Vector2s alph_rhs;
Vector2s coeffs_cubic_interpolant;
/// Solver for the 2x2 linear system
Eigen::HouseholderQR<Matrix2s> decomp;
};

#ifdef PROXSUITE_NLP_ENABLE_TEMPLATE_INSTANTIATION
Expand Down
4 changes: 2 additions & 2 deletions include/proxsuite-nlp/linesearch-base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ template <typename T> class Linesearch {
FunctionSample(T a, T v, T g) : alpha(a), phi(v), dphi(g), valid(true) {}
};

const Linesearch::Options &options() const { return options_; }
void setOptions(const Linesearch::Options &options) { options_ = options; }

private:
void reset() {}

Linesearch::Options options_;
};

Expand Down

0 comments on commit 9827759

Please sign in to comment.