
Fast (reverse-mode AD) hybrid ODE sensitivities - a collection of requirements (with MWE) #863

Open
ThummeTo opened this issue Aug 11, 2023 · 4 comments


@ThummeTo
Contributor

ThummeTo commented Aug 11, 2023

Dear @frankschae,

as promised, I tried to summarize the requirements (in the form of MWEs) that are needed to train over arbitrary FMUs using FMI.jl (or, probably soon, FMISensitivity.jl).
Both examples are very simple NeuralODEs; the MWEs don't need to train over FMUs.

The requirements / MWEs are (in order of priority, most important first):

  • MWE 1 of 2: train hybrid ODEs with Discretize-then-Optimize methods and reverse-mode AD with in-place ODE function evaluations, e.g. ReverseDiff.gradient(...) and sensealg=ReverseDiffAdjoint() (this would allow fast training on any FMU solution)
  • MWE 2 of 2: train hybrid ODEs with Optimize-then-Discretize methods and reverse-mode AD with in-place ODE function evaluations, e.g. ReverseDiff.gradient(...) and sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()) (this would allow even faster training on solutions of FMUs that support "checkpointing", namely fmiXGetFMUState and fmiXSetFMUState)
  • Not a real issue, because I can fix it myself: multiple state events per time instant (e.g. if identical event conditions are defined, which is unfortunately more common than it should be for FMUs)

Both MWEs run into the problem that the computed gradient contains NaNs, which leads to NaNs in the parameters and, later, NaNs during ANN inference.

Some additional info:

  • ForwardDiff currently works but is, of course, quite slow for medium and large parameter counts (> 100)
  • Zygote still requires the out-of-place interface for ODE function evaluation, which is quite slow for large systems with many states (that's why we prefer ReverseDiff here)

Please don't hesitate to involve me if there is anything I can do to support.
For example, we could open a PR with tests based on the MWEs and/or examples for the documentation.
If there is something unclear I can post more information/code or similar.

If we get this working, it will be a significant improvement for training ML models that include FMUs (and hybrid ODEs in general).

Thank you very much & best regards,
ThummeTo

PS: "Unfortunately" I am on vacation for the next three weeks :-)

--------------- MWE ---------------

using SciMLSensitivity
using Flux
using DifferentialEquations
using DiffEqCallbacks
import SciMLSensitivity.SciMLBase: RightRootFind
import SciMLSensitivity: ReverseDiff, ForwardDiff, FakeIntegrator
import Random

Random.seed!(1234)

net = Chain(Dense(2, 16, tanh),
            Dense(16, 2, tanh))

x0 = [1.0f0, 1.0f0]
tspan = (0.0f0, 3.0f0)
saveat = tspan[1]:0.1:tspan[end]
data = sin.(saveat)

params, re = Flux.destructure(net)
initial_params = copy(params)

function fx(dx, x, p, t)
    dx[:] = re(p)(x)
end

ff = ODEFunction{true}(fx, tgrad=nothing)
prob = ODEProblem{true}(ff, x0, tspan, params)

function condition(out, x, t, integrator)
    out[1] = cos(x[1])
    out[2] = sin(x[1])
end

function affect!(integrator, idx)
    u_new = x0
    integrator.u .= u_new
end

eventCb = VectorContinuousCallback(condition,
                                   affect!,
                                   2;
                                   rootfind=RightRootFind, save_positions=(false, false))

function loss(p; sensealg=nothing)
    sol = solve(prob; p=p, callback=CallbackSet(eventCb), sensealg=sensealg, saveat=saveat)
    
    # Under ReverseDiff, solve returns an Array instead of an ODESolution object!
    vals = sol[1,:] 

    solsize = size(sol)
    if solsize != (length(x0), length(saveat))
        @error "Step failed with solsize = $(solsize)!"
        return Inf
    end

    return Flux.Losses.mse(data, vals)
end

# loss function for Discretize-then-Optimize (DtO) and Optimize-then-Discretize (OtD)
loss_DtO = (p) -> loss(p; sensealg=ReverseDiffAdjoint())
loss_OtD = (p) -> loss(p; sensealg=InterpolatingAdjoint(;autojacvec=ReverseDiffVJP()))

# check simple gradients for both loss functions
for loss in (loss_DtO, loss_OtD)
    grad_fd = ForwardDiff.gradient(loss, params, 
        ForwardDiff.GradientConfig(loss, params, ForwardDiff.Chunk{32}()))
    grad_rd = ReverseDiff.gradient(loss, params)          
    
    # small deviations are ok, so this is good for now!
    @info "$(loss) max deviation between ForwardDiff and ReverseDiff: $(maximum(abs.(grad_fd .- grad_rd)))"
end

#### 
# do some training steps
for loss in (loss_DtO, loss_OtD)

    # reset params and optimizer state (so every sensealg has the same "chance")
    params[:] = initial_params[:]
    optim = Adam(1e-5)

    # a very simple custom train loop, that checks the gradient before applying it
    for i in 1:500

        # get the gradient of the loss currently being trained
        g = ReverseDiff.gradient(loss, params)

        # check if NaNs are in there
        if any(isnan.(g)) 
            @error "\tGradient NaN at step $(i) for loss $(loss), exiting!"
            break
        end
        
        # apply optimization step, update parameters
        step = Flux.Optimise.apply!(optim, params, g)
        params .-= step
    end
end

----------- MWE OUTPUT -------------

[ Info: #13 max deviation between ForwardDiff and ReverseDiff: 5.048493233239526e-6
[ Info: #15 max deviation between ForwardDiff and ReverseDiff: 3.90012033929521e-6

┌ Error:        Gradient NaN at step 16 for loss #13, exiting!
└ @ Main c:\Users\...:90
┌ Error:        Gradient NaN at step 13 for loss #15, exiting!
└ @ Main c:\Users\....jl:90
@ChrisRackauckas
Member

ChrisRackauckas commented Aug 12, 2023

I think this is the kind of case we want to get Enzyme ready for.

sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())

Why not sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)) here?

Note that without true,

function fx(dx, x, p, t)
    dx[:] = re(p)(x)
end

This in-place wrapper around the out-of-place call re(p)(x) will, of course, be slower than fx(x, p, t) = re(p)(x) because of the scalarizing.

Also one major improvement is to use Lux instead, or for small neural networks use SimpleChains (with static arrays)
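A minimal sketch of the two suggestions combined (out-of-place right-hand side plus a compiled ReverseDiff tape for the VJP). It assumes Zygote as the outer AD (which works with out-of-place functions), Tsit5 as the solver, and a plain sum-of-squares loss without the event callback; fx_oop, prob_oop and loss_oop are hypothetical names, not part of the MWE.

```julia
using SciMLSensitivity, OrdinaryDiffEq, Flux, Zygote

# Same network shape as in the MWE above.
net = Chain(Dense(2, 16, tanh), Dense(16, 2, tanh))
params, re = Flux.destructure(net)

# Out-of-place right-hand side: returns a new vector instead of writing into dx.
fx_oop(x, p, t) = re(p)(x)

x0 = [1.0f0, 1.0f0]
prob_oop = ODEProblem{false}(fx_oop, x0, (0.0f0, 3.0f0), params)

function loss_oop(p)
    # ReverseDiffVJP(true) compiles the VJP tape once; this is only valid because
    # fx_oop has no value-dependent branching (Dense layers with tanh).
    sol = solve(prob_oop, Tsit5(); p=p, saveat=0.1f0,
                sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP(true)))
    return sum(abs2, Array(sol))
end

# Zygote triggers the adjoint (OtD) sensitivity defined by sensealg.
g = Zygote.gradient(loss_oop, params)[1]
```

Leaving out the callback keeps the sketch focused on the VJP choice; the event handling discussed in this issue is a separate concern.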

@ThummeTo
Contributor Author

Thanks for the reply! Yep, ReverseDiffVJP(true) is a good point. To be honest, I wasn't sure whether it's allowed here, because of the "no branching" requirement for precompiled tapes.

Migration to Lux is also on the to-do-list :-)

And I am super-curious what progress Enzyme is making (after the big steps in the last months/weeks). I will keep checking for that.
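The "no branching" caveat for compiled tapes can be shown with ReverseDiff alone. In this sketch, f is a hypothetical toy function with a value-dependent branch. A tape compiled at one input replays the branch it recorded, even at an input where the other branch applies. An uncompiled ReverseDiff.gradient call re-records and takes the correct branch.

```julia
using ReverseDiff

# Hypothetical toy function with a value-dependent branch.
f(x) = x[1] > 0 ? x[1]^2 + x[2]^2 : x[1] + x[2]

# Record and compile the tape at a point where the first branch is taken.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

x = [-1.0, 2.0]                        # here the second branch would be correct
g_tape = similar(x)
ReverseDiff.gradient!(g_tape, tape, x) # replays x[1]^2 + x[2]^2, i.e. 2 .* x
g_fresh = ReverseDiff.gradient(f, x)   # re-records at x: gradient of x[1] + x[2]

println(g_tape)   # [-2.0, 4.0]  (stale branch)
println(g_fresh)  # [1.0, 1.0]   (correct branch)
```

The Dense/tanh right-hand side in the MWE contains no such branch.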

@ThummeTo
Contributor Author

ThummeTo commented Oct 7, 2023

Very good news: DtO works in the current release(s) if you specify a solver by hand. Sensitivities are computed correctly and without numerical instabilities/NaNs. Thank you very much @ChrisRackauckas and @frankschae. However, the provided MWE as it is (without a solver specified) still fails because of the linked DiffEqBase issue.
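A minimal, self-contained sketch of the DtO setup with a hand-picked solver. It assumes a hypothetical scalar decay ODE with a reset event in place of the NeuralODE, and it uses Tsit5() as the explicit solver; the comment does not say which solver was used. f!, reset_condition, reset_state! and loss_dto are illustrative names.

```julia
using SciMLSensitivity, OrdinaryDiffEq, ReverseDiff

# Hypothetical toy hybrid ODE: exponential decay du/dt = -p[1] * u (in-place form).
function f!(du, u, p, t)
    du .= -p[1] .* u
    return nothing
end

# State event: when u[1] falls through 0.5, reset the state to 1.0.
reset_condition(u, t, integrator) = u[1] - 0.5
function reset_state!(integrator)
    integrator.u .= 1.0
end
cb = ContinuousCallback(reset_condition, reset_state!; save_positions=(false, false))

prob = ODEProblem{true}(f!, [1.0], (0.0, 3.0), [1.0])

function loss_dto(p)
    # Solver passed explicitly; ReverseDiffAdjoint differentiates through the solver steps (DtO).
    sol = solve(prob, Tsit5(); p=p, callback=cb, saveat=0.1,
                sensealg=ReverseDiffAdjoint())
    return sum(abs2, sol[1, :])
end

g = ReverseDiff.gradient(loss_dto, [1.0])
```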

Current progress:

Single event at the same time instant:

  • Discretize-then-Optimize methods and reverse-mode AD with in-place ODE function evaluations: working
  • Optimize-then-Discretize methods and reverse-mode AD with in-place ODE function evaluations: working

Multiple events (multiple zero-crossing event conditions) at the same time instant:

  • Discretize-then-Optimize methods and reverse-mode AD with in-place ODE function evaluations: working
  • Optimize-then-Discretize methods and reverse-mode AD with in-place ODE function evaluations: not yet working

So the only thing remaining is the adjoint sensitivity problem for multiple zero-crossing event conditions.
In my application in particular, this is not that important, because solving FMUs backwards in time is not supported by design and causes additional overhead ...

So again, thank you very much!

PS: Are there plans for this last feature in the near future? If not, we could close this issue from my side, and I can open another issue to keep track of that last feature (in case someone searches for it or similar).

@ChrisRackauckas
Member

We plan to just keep going until everything is supported.
