Skip to content

Commit

Permalink
Merge pull request #799 from SciML/nloptcons
Browse files Browse the repository at this point in the history
Add constraints support for NLopt
  • Loading branch information
Vaibhavdixit02 authored Sep 19, 2024
2 parents 6bb4e7b + 3bf492c commit c692526
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 43 deletions.
3 changes: 2 additions & 1 deletion lib/OptimizationMultistartOptimization/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ Reexport = "1.2"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
OptimizationNLopt= "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ForwardDiff", "ReverseDiff", "Pkg", "Test"]
test = ["ForwardDiff", "OptimizationNLopt", "ReverseDiff", "Pkg", "Test"]
2 changes: 0 additions & 2 deletions lib/OptimizationMultistartOptimization/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using Pkg;
Pkg.develop(path = joinpath(@__DIR__, "../../", "OptimizationNLopt"));
using OptimizationMultistartOptimization, Optimization, ForwardDiff, OptimizationNLopt
using Test, ReverseDiff

Expand Down
6 changes: 4 additions & 2 deletions lib/OptimizationNLopt/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@ version = "0.2.2"
[deps]
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
NLopt = "0.6, 1"
NLopt = "1.1"
Optimization = "3.21"
Reexport = "1.2"
julia = "1"

[extras]
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Zygote"]
test = ["ReverseDiff", "Test", "Zygote"]
97 changes: 70 additions & 27 deletions lib/OptimizationNLopt/src/OptimizationNLopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module OptimizationNLopt
using Reexport
@reexport using NLopt, Optimization
using Optimization.SciMLBase
using Optimization: deduce_retcode

(f::NLopt.Algorithm)() = f

Expand Down Expand Up @@ -63,6 +64,38 @@ function SciMLBase.requiresconsjac(opt::Union{NLopt.Algorithm, NLopt.Opt}) #http
end
end

function SciMLBase.allowsconstraints(opt::NLopt.Algorithm)
str_opt = string(opt)
if occursin("AUGLAG", str_opt) || occursin("CCSA", str_opt) ||
occursin("MMA", str_opt) || occursin("COBYLA", str_opt) ||
occursin("ISRES", str_opt) || occursin("AGS", str_opt) ||
occursin("ORIG_DIRECT", str_opt) || occursin("SLSQP", str_opt)
return true
else
return false
end
end

function SciMLBase.requiresconsjac(opt::NLopt.Algorithm)
str_opt = string(opt)
if occursin("AUGLAG", str_opt) || occursin("CCSA", str_opt) ||
occursin("MMA", str_opt) || occursin("COBYLA", str_opt) ||
occursin("ISRES", str_opt) || occursin("AGS", str_opt) ||
occursin("ORIG_DIRECT", str_opt) || occursin("SLSQP", str_opt)
return true
else
return false
end
end

function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::NLopt.Algorithm,
; cons_tol = 1e-6,
callback = (args...) -> (false),
progress = false, kwargs...)
return OptimizationCache(prob, opt; cons_tol, callback, progress,
kwargs...)
end

function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
Expand Down Expand Up @@ -103,7 +136,9 @@ function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;

# add optimiser options from kwargs
for j in kwargs
eval(Meta.parse("NLopt." * string(j.first) * "!"))(opt, j.second)
if j.first != :cons_tol
eval(Meta.parse("NLopt." * string(j.first) * "!"))(opt, j.second)
end
end

if cache.ub !== nothing
Expand Down Expand Up @@ -132,31 +167,6 @@ function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
return nothing
end

function __nlopt_status_to_ReturnCode(status::Symbol)
if status in Symbol.([
NLopt.SUCCESS,
NLopt.STOPVAL_REACHED,
NLopt.FTOL_REACHED,
NLopt.XTOL_REACHED,
NLopt.ROUNDOFF_LIMITED
])
return ReturnCode.Success
elseif status == Symbol(NLopt.MAXEVAL_REACHED)
return ReturnCode.MaxIters
elseif status == Symbol(NLopt.MAXTIME_REACHED)
return ReturnCode.MaxTime
elseif status in Symbol.([
NLopt.OUT_OF_MEMORY,
NLopt.INVALID_ARGS,
NLopt.FAILURE,
NLopt.FORCED_STOP
])
return ReturnCode.Failure
else
return ReturnCode.Default
end
end

function SciMLBase.__solve(cache::OptimizationCache{
F,
RC,
Expand Down Expand Up @@ -219,6 +229,39 @@ function SciMLBase.__solve(cache::OptimizationCache{
NLopt.min_objective!(opt_setup, fg!)
end

if cache.f.cons !== nothing
eqinds = map((y) -> y[1] == y[2], zip(cache.lcons, cache.ucons))
ineqinds = map((y) -> y[1] != y[2], zip(cache.lcons, cache.ucons))
if sum(ineqinds) > 0
ineqcons = function (res, θ, J)
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
cache.f.cons(cons_cache, θ)
res .= @view(cons_cache[ineqinds])
if length(J) > 0
Jcache = zeros(eltype(J), sum(ineqinds) + sum(eqinds), length(θ))
cache.f.cons_j(Jcache, θ)
J .= @view(Jcache[ineqinds, :])'
end
end
NLopt.inequality_constraint!(
opt_setup, ineqcons, [cache.solver_args.cons_tol for i in 1:sum(ineqinds)])
end
if sum(eqinds) > 0
eqcons = function (res, θ, J)
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
cache.f.cons(cons_cache, θ)
res .= @view(cons_cache[eqinds])
if length(J) > 0
Jcache = zeros(eltype(res), sum(eqinds) + sum(ineqinds), length(θ))
cache.f.cons_j(Jcache, θ)
J .= @view(Jcache[eqinds, :])'
end
end
NLopt.equality_constraint!(
opt_setup, eqcons, [cache.solver_args.cons_tol for i in 1:sum(eqinds)])
end
end

maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)

Expand All @@ -229,7 +272,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
t0 = time()
(minf, minx, ret) = NLopt.optimize(opt_setup, cache.u0)
t1 = time()
retcode = __nlopt_status_to_ReturnCode(ret)
retcode = deduce_retcode(ret)

if retcode == ReturnCode.Failure
@warn "NLopt failed to converge: $(ret)"
Expand Down
56 changes: 49 additions & 7 deletions lib/OptimizationNLopt/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using OptimizationNLopt, Optimization, Zygote
using Test
using OptimizationNLopt, Optimization, Zygote, ReverseDiff
using Test, Random

@testset "OptimizationNLopt.jl" begin
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
Expand All @@ -16,7 +16,7 @@ using Test
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optprob, x0, _p)

sol = solve(prob, NLopt.Opt(:LN_BOBYQA, 2))
sol = solve(prob, NLopt.Opt(:LD_LBFGS, 2))
@test sol.retcode == ReturnCode.Success
@test 10 * sol.objective < l1

Expand All @@ -26,10 +26,6 @@ using Test
@test sol.retcode == ReturnCode.Success
@test 10 * sol.objective < l1

sol = solve(prob, NLopt.Opt(:LD_LBFGS, 2))
@test sol.retcode == ReturnCode.Success
@test 10 * sol.objective < l1

sol = solve(prob, NLopt.Opt(:G_MLSL_LDS, 2), local_method = NLopt.Opt(:LD_LBFGS, 2),
maxiters = 10000)
@test sol.retcode == ReturnCode.MaxIters
Expand Down Expand Up @@ -82,4 +78,50 @@ using Test
#nlopt gives the last best not the one where callback stops
@test sol.objective < 0.8
end

@testset "constrained" begin
cons = (res, x, p) -> res .= [x[1]^2 + x[2]^2 - 1.0]
x0 = zeros(2)
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote();
cons = cons)
prob = OptimizationProblem(optprob, x0, _p, lcons = [0.0], ucons = [0.0])
sol = solve(prob, NLopt.LN_COBYLA())
@test sol.retcode == ReturnCode.Success
@test 10 * sol.objective < l1

Random.seed!(1)
prob = OptimizationProblem(optprob, rand(2), _p,
lcons = [0.0], ucons = [0.0])

sol = solve(prob, NLopt.LD_SLSQP())
@test sol.retcode == ReturnCode.Success
@test 10 * sol.objective < l1

Random.seed!(1)
prob = OptimizationProblem(optprob, rand(2), _p,
lcons = [0.0], ucons = [0.0])
sol = solve(prob, NLopt.AUGLAG(), local_method = NLopt.LD_LBFGS())
# @test sol.retcode == ReturnCode.Success
@test 10 * sol.objective < l1

function con2_c(res, x, p)
res .= [x[1]^2 + x[2]^2 - 1.0, x[2] * sin(x[1]) - x[1] - 2.0]
end

optprob = OptimizationFunction(
rosenbrock, Optimization.AutoForwardDiff(); cons = con2_c)
Random.seed!(1)
prob = OptimizationProblem(
optprob, rand(2), _p, lcons = [0.0, -Inf], ucons = [0.0, 0.0])
sol = solve(prob, NLopt.LD_AUGLAG(), local_method = NLopt.LD_LBFGS())
# @test sol.retcode == ReturnCode.Success
@test 10 * sol.objective < l1

Random.seed!(1)
prob = OptimizationProblem(optprob, rand(2), _p, lcons = [-Inf, -Inf],
ucons = [0.0, 0.0], lb = [-1.0, -1.0], ub = [1.0, 1.0])
sol = solve(prob, NLopt.GN_ISRES(), maxiters = 1000)
@test sol.retcode == ReturnCode.MaxIters
@test 10 * sol.objective < l1
end
end
8 changes: 5 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ const STOP_REASON_MAP = Dict(
r"STOP: XTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
r"STOP: TERMINATION" => ReturnCode.Terminated,
r"Optimization completed" => ReturnCode.Success,
r"Convergence achieved" => ReturnCode.Success
r"Convergence achieved" => ReturnCode.Success,
r"ROUNDOFF_LIMITED" => ReturnCode.Success
)

# Function to deduce ReturnCode from a stop_reason string using the dictionary
Expand All @@ -99,11 +100,12 @@ function deduce_retcode(retcode::Symbol)
return ReturnCode.Default
elseif retcode == :Success || retcode == :EXACT_SOLUTION_LEFT ||
retcode == :FLOATING_POINT_LIMIT || retcode == :true || retcode == :OPTIMAL ||
retcode == :LOCALLY_SOLVED
retcode == :LOCALLY_SOLVED || retcode == :ROUNDOFF_LIMITED || retcode == :SUCCESS
return ReturnCode.Success
elseif retcode == :Terminated
return ReturnCode.Terminated
elseif retcode == :MaxIters || retcode == :MAXITERS_EXCEED
elseif retcode == :MaxIters || retcode == :MAXITERS_EXCEED ||
retcode == :MAXEVAL_REACHED
return ReturnCode.MaxIters
elseif retcode == :MaxTime || retcode == :TIME_LIMIT
return ReturnCode.MaxTime
Expand Down
1 change: 0 additions & 1 deletion test/stdout.txt

This file was deleted.

0 comments on commit c692526

Please sign in to comment.