8 changes: 4 additions & 4 deletions lib/OptimizationManopt/Project.toml
@@ -14,10 +14,10 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
LinearAlgebra = "1.10"
ManifoldDiff = "0.3.10"
Manifolds = "0.9.18"
ManifoldsBase = "0.15.10"
Manopt = "0.4.63"
ManifoldDiff = "0.4"
Manifolds = "0.10"
ManifoldsBase = "1"
Manopt = "0.5"
Optimization = "4.4"
Reexport = "1.2"
julia = "1.10"
163 changes: 34 additions & 129 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
@@ -65,20 +65,14 @@ function call_manopt_optimizer(
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
stepsize::Stepsize = ArmijoLinesearch(M),
kwargs...)
opts = gradient_descent(M,
opts = Manopt.gradient_descent(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
stepsize,
stopping_criterion,
kwargs...)
# we unwrap DebugOptions here
return_state = true, # return the (full, decorated) solver state
kwargs...
)
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
end
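An aside, not part of the diff: the pattern used in all of these wrappers is to request the full decorated solver state (`return_state = true`) and then extract the result from it with `Manopt.get_solver_result`. A minimal standalone sketch of that pattern with Manopt directly, using a made-up cost on the sphere:

```julia
using Manifolds, Manopt, LinearAlgebra

M = Sphere(2)
p_star = [1.0, 0.0, 0.0]
f(M, p) = 0.5 * norm(p - p_star)^2        # cost written in the embedding R^3
grad_f(M, p) = project(M, p, p - p_star)  # Riemannian gradient: project the Euclidean one

st = gradient_descent(M, f, grad_f, [0.0, 0.0, 1.0]; return_state = true)
q = Manopt.get_solver_result(st)          # minimizer extracted from the returned state
```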
@@ -90,13 +84,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMea
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
kwargs...)
opts = NelderMead(M,
loss;
return_state = true,
stopping_criterion,
kwargs...)
opts = NelderMead(M, loss; return_state = true, kwargs...)
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
end
@@ -109,19 +98,14 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
stepsize::Stepsize = ArmijoLinesearch(M),
kwargs...)
opts = conjugate_gradient_descent(M,
opts = Manopt.conjugate_gradient_descent(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
stepsize,
stopping_criterion,
kwargs...)
kwargs...
)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
@@ -135,25 +119,10 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
population_size::Int = 100,
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
kwargs...)
initial_population = vcat([x0], [rand(M) for _ in 1:(population_size - 1)])
opts = particle_swarm(M,
loss;
x0 = initial_population,
n = population_size,
return_state = true,
retraction_method,
inverse_retraction_method,
vector_transport_method,
stopping_criterion,
kwargs...)
# we unwrap DebugOptions here
swarm = [x0, [rand(M) for _ in 1:(population_size - 1)]...]
opts = particle_swarm(M, loss, swarm; return_state = true, kwargs...)
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
end
@@ -167,27 +136,9 @@ function call_manopt_optimizer(M::Manopt.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
stepsize = WolfePowellLinesearch(M;
retraction_method = retraction_method,
vector_transport_method = vector_transport_method,
linesearch_stopsize = 1e-12),
kwargs...
)
opts = quasi_Newton(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
retraction_method,
vector_transport_method,
stepsize,
stopping_criterion,
kwargs...)
opts = quasi_Newton(M, loss, gradF, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
@@ -200,18 +151,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
basis = Manopt.DefaultOrthonormalBasis(),
kwargs...)
opt = cma_es(M,
loss,
x0;
return_state = true,
stopping_criterion,
kwargs...)
opt = cma_es(M, loss, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -224,21 +165,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
kwargs...)
opt = convex_bundle_method!(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
retraction_method,
vector_transport_method,
stopping_criterion,
kwargs...)
opt = convex_bundle_method(M, loss, gradF, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -252,21 +180,13 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
gradF,
x0;
hessF = nothing,
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
kwargs...)
opt = adaptive_regularization_with_cubics(M,
loss,
gradF,
hessF,
x0;
return_state = true,
evaluation,
retraction_method,
stopping_criterion,
kwargs...)
# we unwrap DebugOptions here

opt = if isnothing(hessF)
adaptive_regularization_with_cubics(M, loss, gradF, x0; return_state = true, kwargs...)
else
adaptive_regularization_with_cubics(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
end
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
end
@@ -279,20 +199,12 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
gradF,
x0;
hessF = nothing,
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
kwargs...)
opt = trust_regions(M,
loss,
gradF,
hessF,
x0;
return_state = true,
evaluation,
retraction = retraction_method,
stopping_criterion,
kwargs...)
opt = if isnothing(hessF)
trust_regions(M, loss, gradF, x0; return_state = true, kwargs...)
else
trust_regions(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
end
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -305,21 +217,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
stepsize::Stepsize = DecreasingStepsize(; length = 2.0, shift = 2),
kwargs...)
opt = Frank_Wolfe_method(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
retraction_method,
stopping_criterion,
stepsize,
kwargs...)
opt = Frank_Wolfe_method(M, loss, gradF, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
Expand All @@ -332,20 +231,22 @@ function SciMLBase.requiresgradient(opt::Union{
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
true
end
# TODO: WHY? they both still accept not passing it
function SciMLBase.requireshessian(opt::Union{
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
true
end

Comment on lines +234 to 239 (PR author):
How is this function defined, and what is it for?

The current definition here is not correct: both ARC and TR can compute their own (actually quite good) approximation of the Hessian, similar to what QN does. So they do not need a Hessian, although the exact one of course performs a bit better than the approximate one.
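One possible adjustment in the spirit of this remark, sketched only (whether the rest of Optimization.jl tolerates `false` here is an assumption, and this is not a change made in the PR):

```julia
# Sketch: ARC and TR fall back to an internal Hessian approximation when no Hessian
# is supplied, so the trait could arguably mirror quasi-Newton and return `false`.
function SciMLBase.requireshessian(opt::Union{
        AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
    false
end
```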

function build_loss(f::OptimizationFunction, prob, cb)
function (::AbstractManifold, θ)
return function (::AbstractManifold, θ)
x = f.f(θ, prob.p)
cb(x, θ)
__x = first(x)
return prob.sense === Optimization.MaxSense ? -__x : __x
end
end

#TODO: What does the “true” mean here?
function build_gradF(f::OptimizationFunction{true})
function g(M::AbstractManifold, G, θ)
f.grad(G, θ)
@@ -356,6 +257,7 @@ function build_gradF(f::OptimizationFunction{true})
f.grad(G, θ)
return riemannian_gradient(M, θ, G)
end
return g
end
Comment (PR author):
Where can I find more information about this? Especially what is the parameter {true}?
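An aside, not part of the diff: the `{true}` parameter is SciML's in-place (`iip`) flag of `OptimizationFunction{iip}`, which matches the mutating call `f.grad(G, θ)` above. The conversion itself goes through `ManifoldDiff.riemannian_gradient`, whose signature the code already uses; a minimal standalone sketch on the sphere, with made-up numbers:

```julia
using Manifolds, ManifoldDiff, LinearAlgebra

M = Sphere(2)
θ = [0.0, 0.0, 1.0]               # a point on the sphere
G = [1.0, 2.0, 3.0]               # Euclidean gradient of some cost at θ
X = riemannian_gradient(M, θ, G)  # converted Riemannian gradient (a tangent vector at θ)
@assert abs(dot(X, θ)) < 1e-12    # tangent vectors at θ are orthogonal to θ on the sphere
```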


function build_hessF(f::OptimizationFunction{true})
@@ -373,6 +275,7 @@ function build_hessF(f::OptimizationFunction{true})
f.grad(G, θ)
return riemannian_Hessian(M, θ, G, H, X)
end
return h
end

function SciMLBase.__solve(cache::OptimizationCache{
@@ -395,8 +298,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
LC,
UC,
S,
O <:
AbstractManoptOptimizer,
O <: AbstractManoptOptimizer,
D,
P,
C
@@ -418,6 +320,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
u = θ,
p = cache.p,
objective = x[1])
#TODO: What is this callback for?
cb_call = cache.callback(opt_state, x...)
Comment on lines +323 to 324 (PR author):

What is this callback, what is it used for, and why is this here?
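An aside, not part of the diff: judging from this code, `cache.callback` is the user callback passed to `solve`. It receives the `Optimization.OptimizationState` built above together with the current objective value and must return a Boolean `halt` flag, which is exactly what the check below enforces. A hypothetical user-side sketch:

```julia
# Stop early once the objective is small enough (hypothetical user code).
my_callback(state, obj) = obj < 1e-8
# sol = solve(prob, opt; callback = my_callback)  # `prob` and `opt` set up elsewhere
```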

if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
@@ -448,10 +351,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
stopping_criterion = Manopt.StopAfterIteration(500)
end

# TODO: With the new keyword warnings we can not just always pass down hessF!
opt_res = call_manopt_optimizer(manifold, cache.opt, _loss, gradF, cache.u0;
solver_kwarg..., stopping_criterion = stopping_criterion, hessF)
Comment on lines +354 to 356 (PR author):

Since Manopt 0.5.22, passing a Hessian as a keyword argument to a solver that does not accept Hessians will in most cases warn, if not error. So this has to be reworked. How does Optimization.jl usually handle options that are only used by some solvers?
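One conceivable rework, sketched only (the trait check and the NamedTuple splat are assumptions, not part of this PR): forward `hessF` only to solvers that actually consume a Hessian.

```julia
# Only Hessian-based solvers receive `hessF`; the others never see the keyword,
# avoiding Manopt 0.5.22's warnings/errors about unused keyword arguments.
hess_kw = cache.opt isa Union{AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer} ?
    (; hessF) : (;)
opt_res = call_manopt_optimizer(manifold, cache.opt, _loss, gradF, cache.u0;
    solver_kwarg..., stopping_criterion = stopping_criterion, hess_kw...)
```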


asc = get_stopping_criterion(opt_res.options)
# TODO: Switch to `has_converged` once that was released.
opt_ret = Manopt.indicates_convergence(asc) ? ReturnCode.Success : ReturnCode.Failure

return SciMLBase.build_solution(cache,
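For orientation, appended here as an aside rather than as part of the PR: a rough end-to-end sketch of how the updated package is used. The solver type name and the `manifold` keyword of `OptimizationProblem` follow the OptimizationManopt documentation; the cost, data, and AD backend are made up for illustration.

```julia
using Optimization, OptimizationManopt, Manifolds, ForwardDiff

M = Sphere(2)                        # optimize over the unit sphere in R^3
f(x, p) = sum(abs2, x .- p)          # embedded squared distance to p (made-up cost)
optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]; manifold = M)
sol = solve(prob, OptimizationManopt.GradientDescentOptimizer())
```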