Skip to content

Commit

Permalink
fixes for FMIBase
Browse files Browse the repository at this point in the history
  • Loading branch information
ThummeTo committed May 28, 2024
1 parent 6ca59c0 commit 1e4f7c1
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 228 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
[compat]
FMIBase = "1.0.0"
ForwardDiffChainRules = "0.2.0"
SciMLSensitivity = "7.0 - 7.58"
SciMLSensitivity = "7.0 - 7.59"
julia = "1.6"
61 changes: 40 additions & 21 deletions src/sense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import FMIBase: eval!, invalidate!, check_invalidate!
using FMIBase: getDirectionalDerivative!, getAdjointDerivative!
using FMIBase: setContinuousStates, setInputs, setReal, setTime, setReal, getReal!, getEventIndicators!
using FMIBase: setContinuousStates, setInputs, setReal, setTime, setReal, getReal!, getEventIndicators!, getRealType

# in FMI2 and FMI3 we can use fmi2GetDirectionalDerivative for JVP-computations
function jvp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
function jvp!(c::FMUInstance, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)

jac = getfield(c, mtxCache)
if isnothing(jac)
Expand All @@ -23,14 +23,14 @@ function jvp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
jac.x_refs = ∂x_refs

if c.fmu.executionConfig.JVPBuiltInDerivatives && providesDirectionalDerivatives(c.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
getDirectionalDerivative!(c, ∂f_refs, ∂x_refs, jac.jvp, seed)
getDirectionalDerivative!(c, ∂f_refs, ∂x_refs, seed, jac.jvp)
return jac.jvp
else
return jvp!(jac, x, seed)
end
end

function gvp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
function gvp!(c::FMUInstance, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)

grad = getfield(c, mtxCache)
if isnothing(grad)
Expand All @@ -45,15 +45,15 @@ function gvp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
grad.x_refs = ∂x_refs

if c.fmu.executionConfig.JVPBuiltInDerivatives && providesDirectionalDerivatives(c.fmu) && !isa(grad.f_refs, Tuple) && !isa(grad.x_refs, Symbol)
getDirectionalDerivative!(c, ∂f_refs, ∂x_refs, grad.gvp, [seed])
getDirectionalDerivative!(c, ∂f_refs, ∂x_refs, [seed], grad.gvp)
return grad.gvp
else
return gvp!(grad, x, seed)
end
end

# in FMI2 there is no helper for VJP-computations (but in FMI3) ...
function vjp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
function vjp!(c::FMUInstance, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)

jac = getfield(c, mtxCache)
if isnothing(jac)
Expand All @@ -68,14 +68,14 @@ function vjp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
jac.x_refs = ∂x_refs

if c.fmu.executionConfig.VJPBuiltInDerivatives && providesAdjointDerivatives(c.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
getAdjointDerivative!(c, ∂f_refs, ∂x_refs, jac.vjp, seed)
getAdjointDerivative!(c, ∂f_refs, ∂x_refs, seed, jac.vjp)
return jac.vjp
else
return vjp!(jac, x, seed)
end
end

function vgp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
function vgp!(c::FMUInstance, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)

grad = getfield(c, mtxCache)
if isnothing(grad)
Expand All @@ -90,7 +90,7 @@ function vgp!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, x, seed)
grad.x_refs = ∂x_refs

if c.fmu.executionConfig.VJPBuiltInDerivatives && providesAdjointDerivatives(c.fmu) && !isa(grad.f_refs, Tuple) && !isa(grad.x_refs, Symbol)
getAdjointDerivative!(c, ∂f_refs, ∂x_refs, grad.vgp, [seed])
getAdjointDerivative!(c, ∂f_refs, ∂x_refs, [seed], grad.vgp)
return grad.vgp
else
return vgp!(grad, x, seed)
Expand Down Expand Up @@ -308,7 +308,7 @@ function ChainRulesCore.rrule(::typeof(FMIBase.eval!),
ec_idcs,
t)

@assert !isa(cRef, FMU2Component) "Wrong dispatched!"
@assert !isa(cRef, FMUInstance) "Wrong dispatched!"

@debug "rrule start: $((cRef, dx, dx_refs, y, y_refs, x, u, u_refs, p, p_refs, ec, ec_idcs, t))"

Expand Down Expand Up @@ -998,9 +998,9 @@ end

# FiniteDiff Jacobians

abstract type FMU2Sensitivities end
abstract type FMUSensitivities end

mutable struct FMUJacobian{C, T, F} <: FMU2Sensitivities
mutable struct FMUJacobian{C, T, F} <: FMUSensitivities
valid::Bool
colored::Bool
component::C
Expand Down Expand Up @@ -1065,7 +1065,7 @@ mutable struct FMUJacobian{C, T, F} <: FMU2Sensitivities

end

mutable struct FMUGradient{C, T, F} <: FMU2Sensitivities
mutable struct FMUGradient{C, T, F} <: FMUSensitivities
valid::Bool
colored::Bool
component::C
Expand Down Expand Up @@ -1152,12 +1152,12 @@ function f_∂v_∂t(jac::FMUGradient, dx, x)
return dx
end

function FMIBase.invalidate!(sens::FMU2Sensitivities)
function FMIBase.invalidate!(sens::FMUSensitivities)
sens.valid = false
return nothing
end

function FMIBase.check_invalidate!(vrs, sens::FMU2Sensitivities)
function FMIBase.check_invalidate!(vrs, sens::FMUSensitivities)
if !sens.valid
return
end
Expand All @@ -1175,22 +1175,41 @@ function FMIBase.check_invalidate!(vrs, sens::FMU2Sensitivities)
return nothing
end

function uncolor!(jac::FMU2Sensitivities)
function uncolor!(jac::FMUSensitivities)
jac.colored = false
return nothing
end

function onehot(c::FMUInstance, len::Integer, i::Integer) # [ToDo] this could be solved without allocations
ret = zeros(getRealType(c), len)
ret[i] = 1.0
return ret
end

function validate!(jac::FMUJacobian, x::AbstractVector)

rows = length(jac.f_refs)
cols = length(jac.x_refs)

if jac.component.fmu.executionConfig.sensitivity_strategy == :FMIDirectionalDerivative && providesDirectionalDerivatives(jac.component.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
# ToDo: use directional derivatives with sparsitiy information!
# ToDo: Optimize allocation (ones)
for i in 1:length(jac.x_refs)
getDirectionalDerivative!(jac.component, jac.f_refs, [jac.x_refs[i]], view(jac.mtx, 1:length(jac.f_refs), i), ones(length(jac.f_refs)))
# ToDo: Optimize allocation (onehot)
# [Note] Jacobian is sampled column by column
for i in 1:cols
getDirectionalDerivative!(jac.component, jac.f_refs, jac.x_refs, onehot(jac.component, cols, i), view(jac.mtx, 1:rows, i))
end
elseif jac.component.fmu.executionConfig.sensitivity_strategy == :FMIAdjointDerivative && providesAdjointDerivatives(jac.component.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
# ToDo: use directional derivatives with sparsitiy information!
# ToDo: Optimize allocation (onehot)
# [Note] Jacobian is sampled row by row
for i in 1:rows
getAdjointDerivative!(jac.component, jac.f_refs, jac.x_refs, onehot(jac.component, rows, i), view(jac.mtx, 1:cols, i))
end
else #if jac.component.fmu.executionConfig.sensitivity_strategy == :FiniteDiff
# cache = FiniteDiff.JacobianCache(x)
FiniteDiff.finite_difference_jacobian!(jac.mtx, (_x, _dx) -> (jac.f(jac, _x, _dx)), x) # , cache)
# else
# @assert false "Unknown sensitivity strategy `$(jac.component.fmu.executionConfig.sensitivity_strategy)`."
end

jac.validations += 1
Expand All @@ -1202,7 +1221,7 @@ function validate!(grad::FMUGradient, x::Real)

if grad.component.fmu.executionConfig.sensitivity_strategy == :FMIDirectionalDerivative && providesDirectionalDerivatives(grad.component.fmu) && !isa(grad.f_refs, Tuple) && !isa(grad.x_refs, Symbol)
# ToDo: use directional derivatives with sparsitiy information!
getDirectionalDerivative!(grad.component, grad.f_refs, grad.x_refs, grad.vec, ones(length(jac.f_refs)))
getDirectionalDerivative!(grad.component, grad.f_refs, grad.x_refs, ones(length(jac.f_refs)), grad.vec)
else #if grad.component.fmu.executionConfig.sensitivity_strategy == :FiniteDiff
# cache = FiniteDiff.GradientCache(x)
FiniteDiff.finite_difference_gradient!(grad.vec, (_x, _dx) -> (grad.f(grad, _x, _dx)), x) # , cache)
Expand All @@ -1213,7 +1232,7 @@ function validate!(grad::FMUGradient, x::Real)
return nothing
end

function color!(sens::FMU2Sensitivities)
function color!(sens::FMUSensitivities)
# ToDo
# colors = SparseDiffTools.matrix_colors(sparsejac)

Expand Down
Loading

0 comments on commit 1e4f7c1

Please sign in to comment.