Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ForwardDiff rules #434

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft

Add ForwardDiff rules #434

wants to merge 9 commits into from

Conversation

sharanry
Copy link
Member

@sharanry sharanry commented Nov 13, 2023

  • Fix ambiquities
  • Refactor code to make more efficient
  • Cleanup debugging statements

Copy link

codecov bot commented Nov 13, 2023

Codecov Report

Attention: 48 lines in your changes are missing coverage. Please review.

Comparison is base (9aaf9b3) 63.96% compared to head (829a914) 27.07%.

Files Patch % Lines
ext/LinearSolveForwardDiff.jl 0.00% 46 Missing ⚠️
src/common.jl 0.00% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #434       +/-   ##
===========================================
- Coverage   63.96%   27.07%   -36.89%     
===========================================
  Files          27       28        +1     
  Lines        2106     2135       +29     
===========================================
- Hits         1347      578      -769     
- Misses        759     1557      +798     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

cacheval = eltype(cache.cacheval.factors) <: Dual ? begin
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info)
end : cache.cacheval
cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being forced to remake cache in order to solve the non-dual version. Is there some other way we can replace Dual Array with a regular array?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you want to hook into init. In theory in init what you can do is un-dual the user inputs that are dual, but tag the cache in such a way that in solve! you end up doing two (or number of chunk size + 1) solves and reconstruct the resulting dual numbers in the output.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or rather, it's just one solve! call but in a batched form.

Comment on lines +20 to +25
res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy
dresus = reduce(hcat, map(dAs, dbs) do dA, db
cache2.b = db - dA * res.u
dres = LinearSolve.solve!(cache2, alg, kwargs...)
deepcopy(dres.u)
end)
Copy link
Member Author

@sharanry sharanry Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needing to deepcopy the results of the solves as they are being overwritten by subsequent solves when reusing the cache.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if you hook into init and do a single batched solve then this is handled.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any documentation on how to do batched solves? I am unable to find how to do this anywhere. The possi bly closest thing I could find was https://discourse.julialang.org/t/batched-lu-solves-or-factorizations-with-sparse-matrices/106019/2 -- however, couldn't find the right function call.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just A\B matrix instead of A\b vector

Copy link
Member Author

@sharanry sharanry Dec 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not entirely sure what you mean in the context of LinearSolve.jl.

n = 4
A = rand(n, n)
B = rand(n, n)

A \ B  # works

mapreduce(hcat, eachcol(B)) do b
    A \ b
end # works

mapreduce(hcat, eachcol(B)) do b
    prob = LinearProblem(A, b)
    sol = solve(prob)
    sol.u
end # works

begin
    prob = LinearProblem(A, B)
    sol = solve(prob)  # errors
    sol.u
end

Error:

ERROR: MethodError: no method matching ldiv!(::Vector{Float64}, ::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, ::Matrix{Float64})

Closest candidates are:
  ldiv!(::Any, ::Sparspak.SpkSparseSolver.SparseSolver{IT, FT}, ::Any) where {IT, FT}
   @ Sparspak ~/.julia/packages/Sparspak/oqBYl/src/SparseCSCInterface/SparseCSCInterface.jl:263
  ldiv!(::Any, ::LinearSolve.InvPreconditioner, ::Any)
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:30
  ldiv!(::Any, ::LinearSolve.ComposePreconditioner, ::Any)
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:17
  ...

Stacktrace:
 [1] _ldiv!(x::Vector{Float64}, A::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, b::Matrix{Float64})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/factorization.jl:11
 [2] macro expansion
   @ ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:135 [inlined]
 [3] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
 [4] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
 [5] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:218
 [6] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:217
 [7] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:214
 [8] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:211
 [9] top-level scope
   @ REPL[24]:3

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@avik-pal I thought you handled something with this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@avik-pal A ping on this. Is there another way to do this if we do not yet have batch dispatch?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not for this case, but a case where A and b are both batched. Here you will have to see how Base handles it, there are special LAPACK routines for these

Comment on lines +85 to +93
function SciMLBase.remake(cache::LinearCache;
A::TA=cache.A, b::TB=cache.b, u::TU=cache.u, p::TP=cache.p, alg::Talg=cache.alg,
cacheval::Tc=cache.cacheval, isfresh::Bool=cache.isfresh, Pl::Tl=cache.Pl, Pr::Tr=cache.Pr,
abstol::Ttol=cache.abstol, reltol::Ttol=cache.reltol, maxiters::Int=cache.maxiters,
verbose::Bool=cache.verbose, assumptions::OperatorAssumptions{issq}=cache.assumptions) where {TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}
LinearCache{TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}(A,b,u,p,alg,cacheval,isfresh,Pl,Pr,abstol,reltol,
maxiters,verbose,assumptions)
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check if there is a way to avoid redefining this by providing a better constructor for LinearCache.

Comment on lines 37 to 41
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]
end
dbs = [zero(cache.b) for _=1:P]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to find a way to allocate less if possible.

@sharanry
Copy link
Member Author

sharanry commented Nov 19, 2023

Still taking a look at performance improvements.

Figured out the method dispatch ambiguities for all methods Krylov:

Stack Trace
ERROR: LoadError: MethodError: no method matching solve!(::Krylov.GmresSolver{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}}, ::Matrix{Float64}, ::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}; M::LinearAlgebra.UniformScaling{Bool}, N::LinearAlgebra.UniformScaling{Bool}, restart::Bool, atol::ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, rtol::ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, itmax::Int64, verbose::Int64, ldiv::Bool, history::Bool)

Closest candidates are:
  solve!(::Krylov.GpmrSolver{T, FC, S}, ::Any, ::Any, ::AbstractVector{FC}, ::AbstractVector{FC}; C, D, E, F, ldiv, gsp, λ, μ, reorthogonalization, atol, rtol, itmax, timemax, verbose, history, callback, iostream) where {T<:AbstractFloat, FC<:Union{Complex{T}, T}, S<:AbstractVector{FC}} got unsupported keyword arguments "M", "N", "restart"
   @ Krylov ~/.julia/packages/Krylov/jLgPS/src/krylov_solve.jl:46
  solve!(::Krylov.GpmrSolver{T, FC, S}, ::Any, ::Any, ::AbstractVector{FC}, ::AbstractVector{FC}, ::AbstractVector, ::AbstractVector; C, D, E, F, ldiv, gsp, λ, μ, reorthogonalization, atol, rtol, itmax, timemax, verbose, history, callback, iostream) where {T<:AbstractFloat, FC<:Union{Complex{T}, T}, S<:AbstractVector{FC}} got unsupported keyword arguments "M", "N", "restart"
   @ Krylov ~/.julia/packages/Krylov/jLgPS/src/krylov_solve.jl:59
  solve!(::Krylov.CrmrSolver{T, FC, S}, ::Any, ::AbstractVector{FC}; N, ldiv, λ, atol, rtol, itmax, timemax, verbose, history, callback, iostream) where {T<:AbstractFloat, FC<:Union{Complex{T}, T}, S<:AbstractVector{FC}} got unsupported keyword arguments "M", "restart"
   @ Krylov ~/.julia/packages/Krylov/jLgPS/src/krylov_solve.jl:46
  ...

Stacktrace:
  [1] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, SciMLBase.NullParameters, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Krylov.GmresSolver{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}}, IdentityOperator, IdentityOperator, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Bool}, alg::KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ LinearSolve ~/code/enzyme_playground/LS_FD/src/iterative_wrappers.jl:256
  [2] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, SciMLBase.NullParameters, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Krylov.GmresSolver{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}}, IdentityOperator, IdentityOperator, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Bool}, alg::KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}})
    @ LinearSolve ~/code/enzyme_playground/LS_FD/src/iterative_wrappers.jl:225
  [3] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, SciMLBase.NullParameters, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Krylov.GmresSolver{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}}, IdentityOperator, IdentityOperator, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Bool}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:218
  [4] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, SciMLBase.NullParameters, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Krylov.GmresSolver{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}}, IdentityOperator, IdentityOperator, ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}, Bool})
    @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:217
  [5] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:214
  [6] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}})
    @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:211
  [7] (::var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}})(b::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}})
    @ Main ~/code/enzyme_playground/LS_FD/test/forwarddiff.jl:24
  [8] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
  [9] vector_mode_gradient(f::var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:89
 [10] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}}, ::Val{true})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:0
 [11] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#fb#31"{KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, Float64}, Float64, 4}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [12] gradient(f::Function, x::Vector{Float64})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [13] top-level scope
    @ ~/code/enzyme_playground/LS_FD/test/forwarddiff.jl:41
 [14] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [15] top-level scope
    @ REPL[20]:1
in expression starting at /Users/sharan/code/enzyme_playground/LS_FD/test/forwarddiff.jl:14

Comment on lines +5 to +7
isdefined(Base, :get_extension) ?
(import ForwardDiff; using ForwardDiff: Dual) :
(import ..ForwardDiff; using ..ForwardDiff: Dual)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only 1.9+ is supported now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand. What do you mean?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically you dont need to do this anymore, just the first import line works

@ChrisRackauckas
Copy link
Member

See SciML/NonlinearSolve.jl#340. It should be somewhat similar, in that init should build an extended cache.

@ChrisRackauckas
Copy link
Member

Note SciML/SciMLBase.jl#558 as a downstream test case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants