
Something wrong with empty (Named-)Tuples and generators #1294

Open
axsk opened this issue Aug 24, 2022 · 17 comments
Labels
piracy A bug caused by a third-party committing piracy


@axsk

axsk commented Aug 24, 2022

In a project of mine, I want to take derivatives of a neural SDE solution (computed by the custom wrapper msolve) with respect to the Lux NN parameters:

function logvar(prob; ps=prob.p, n=100)  # calling this method works
    sum( msolve(prob, ps=ps) for i in 1:n)
end

Zygote.gradient(ps->logvar(prob, ps=ps, n=n), prob.p)[1] # this doesn't

fails with

MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
Stacktrace: [...]
[3] accum(x::NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, y::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:27
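The core of this error can be reproduced in plain Base Julia, without Zygote: Base defines no + method between an empty Tuple and an empty NamedTuple (written (;) elsewhere in this thread), which is the pair that the accumulation step ends up adding here. A minimal check, as a sketch:

```julia
# Adding an empty Tuple to an empty NamedTuple has no method in Base,
# mirroring the `+(::Tuple{}, ::NamedTuple{(), Tuple{}})` error above.
err = try
    () + NamedTuple()
catch e
    e
end

println(typeof(err))
```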

After following the suggestion of @ToucheSir in #1290 and replacing the generator with sum(_ -> msolve(prob, ps=ps), 1:n), the error changes to

MethodError: no method matching +(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})

I hotfixed this with

import Base.+
+(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}) = (data=(;), itr=nothing)

and the code runs through.

Searching for occurrences of (:data, :itr), I could only find

dps = (data = Base.setindex(data, Δ, k), itr = nothing)

and the corresponding function below it.
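The quoted line relies on Base.setindex (no !), the non-mutating update that returns a copy of a tuple with one element replaced. A small sketch of the kind of value it builds; data, Δ and k below are made-up values for illustration, not taken from Zygote:

```julia
# Hypothetical inputs: zero gradients for a 3-element tuple, and an
# incoming gradient Δ for element k.
data = (0.0, 0.0, 0.0)
Δ = 1.5
k = 2

# `Base.setindex` returns a new tuple; `data` itself is left unchanged.
dps = (data = Base.setindex(data, Δ, k), itr = nothing)
println(dps)
```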

I have no clue how this all fits together, but many thanks to @mcabbott and @ToucheSir for helping me find the fix.
Feel free to correct the issue title, and let me know if I can be of any further help fixing this (regarding the Zygote internals, I am quite out of my depth though).

@axsk
Author

axsk commented Aug 24, 2022

Some background:

function msolve(prob; ps=prob.p, dt=0.01, salg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(), noisemixing=true))
    prob = remake(prob, p=ps)
    s = solve(prob, EM(), sensealg=salg, dt=dt)
    s[end][end]
end

It essentially takes an SDE problem prob whose right-hand side is parametrized by a Lux.Chain with parameters ps=prob.p, and returns a ::Float64 from the end of the solution.
I am not sure which roles Lux, StochasticDiffEq, or SciMLSensitivity play in this problem, and will try to reduce it to an MWE when time allows (unless someone spots the problem right away :>)

@ToucheSir
Member

While you're working on a MWE, can you provide the full message and stacktrace of the latest error along with the code to run it? The gist in the linked issue appears to be out of date.

@axsk
Author

axsk commented Aug 24, 2022

I think I narrowed it down to the remake call. When using solve(..., p=ps) instead of remake, everything works out:

  • no need for the strange Base.+ hotfix
  • can use generators

I believe the above methods (msolve, logvar) should suffice to reproduce the problem with any simple SDE problem with some parameter dependence (or maybe even an ODE with a corresponding solver?). Unfortunately it's past 1am and I'm past my sworn bedtime for today, so I'll report more tomorrow.

@axsk
Author

axsk commented Aug 24, 2022

The current fixed code is here. To reproduce the error, uncomment the remake line and run test().

@mcabbott
Member

Haven't run this, but sometimes Zygote is confused by re-using the name prob. Does it happen with e.g. prob2 = remake(prob, p=ps)?

@ToucheSir
Member

remake looks...complicated: https://github.com/SciML/SciMLBase.jl/blob/9b361d6a3ea81a9e24ce14aab5768cea6986cdfe/src/remake.jl#L45. After testing Michael's suggestion, I would also be curious what you get from wrapping that remake call in Zygote.@showgrad.

@axsk
Author

axsk commented Aug 25, 2022

The problem persists when binding to prob2, and @showgrad returns nothing:

julia> test()
(prob2 = remake(prob, p = ps)) = nothing
(prob2 = remake(prob, p = ps)) = nothing
(prob2 = remake(prob, p = ps)) = nothing
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})

@mcabbott
Member

Without working out where these come from, they should both probably be nothing. You could try asking _project to standardise them for you:

julia> methods(Zygote._project)
# 2 methods for generic function "_project" from Zygote:
 [1] _project(x::AbstractArray, dx::Tuple)
     @ ~/.julia/packages/Zygote/xGkZ5/src/compiler/chainrules.jl:188
 [2] _project(x, dx)
     @ ~/.julia/packages/Zygote/xGkZ5/src/compiler/chainrules.jl:183

julia> Zygote._project(x, dx::NamedTuple{()}) = nothing  # shouldn't introduce ambiguities

Or perhaps adding methods to wrap_chainrules_output or something here https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/chainrules.jl

These might be worth doing anyway (JuliaDiff/ChainRulesCore.jl#565 is something similar). If _project works, you could perhaps ask it to print x too, for clues as to what object is creating these.

@ToucheSir
Member

Sorry, I meant to wrap around just the remake(...) call and not the whole prob2 = remake(...) assignment. Wrapping the statement will always get you nothing unless it's used as a nested expression.

@axsk
Author

axsk commented Aug 29, 2022

I think I distilled it into an MWE:

using Zygote
using StochasticDiffEq, SciMLSensitivity
import Lux

function mwe()
    x0 = rand(1)
    p0 = rand(1)

    drift(du,u,p,t) = (du .= 1)
    noise(du,u,p,t) = (du .= 1)

    prob = SDEProblem(drift, noise, x0, 1., p0)
    sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
    Zygote.gradient(p0) do p
        sum(Zygote.@showgrad(solve(remake(prob, p=p), EM(), dt=.1, sensealg=sensealg)[end][1]) for i in 1:3)
    end
end

With @showgrad in the correct position, this now returns

julia> mwe()

(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.09612757465640165], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.06807801678762193], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [0.16420559144402358], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})

I was surprised to find that the import Lux line is necessary for the problem to occur, even though Lux is not used anywhere.
Without that import, there is no error.

@mcabbott
Member

That sounds a bit like piracy, which is bad.

Does there seem to be a fix involving sending () or (;) to nothing, maybe using Zygote._project as above? That would be OK even if we never find the origin.
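One way to picture the suggested fix, sketched in Base Julia only: before accumulating, map empty Tuple and empty NamedTuple gradients to nothing, which acts as the additive identity. The helpers normalize_empty and accum_sketch are hypothetical names for illustration, not Zygote functions, and Zygote's real accum handles many more cases.

```julia
# Hypothetical helpers (not Zygote's API): treat empty (Named)Tuples as "no gradient".
normalize_empty(x) = x
normalize_empty(::Tuple{}) = nothing
normalize_empty(::NamedTuple{()}) = nothing

# Accumulation in the spirit of Zygote.accum, with `nothing` as the additive identity.
accum_sketch(x, y) = x + y
accum_sketch(::Nothing, y) = y
accum_sketch(x, ::Nothing) = x
accum_sketch(::Nothing, ::Nothing) = nothing

# NamedTuples with the same keys are accumulated field by field, normalizing
# empty containers first so that `() + (;)` is never attempted.
function accum_sketch(x::NamedTuple{K}, y::NamedTuple{K}) where {K}
    vals = map((a, b) -> accum_sketch(normalize_empty(a), normalize_empty(b)), values(x), values(y))
    return NamedTuple{K}(vals)
end

# The two gradient shapes from the original error message.
x = (data = (), itr = nothing)
y = (data = NamedTuple(), itr = nothing)
println(accum_sketch(x, y))
```

Unlike the Base.+ hotfix earlier in the thread, this keeps the normalization inside code that owns the accumulation function instead of adding a method to Base.+ for Base types.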

Also, what Julia version does this MWE work on? (Failed to install everything on nightly.)

Precompiling project...
  ✗ Cassette
  ✗ SciMLSensitivity
  13 dependencies successfully precompiled in 248 seconds. 140 already precompiled.
  2 dependencies errored. To see a full report either run `import Pkg; Pkg.precompile()` or load the packages
[ Info: Precompiling StochasticDiffEq [789caeaf-c7a9-5a7d-9973-96adeb23e2a0]
[ Info: Precompiling SciMLSensitivity [1ed8b502-d754-442c-8d5d-10ac956f44a1]
Internal error: encountered unexpected error in runtime:
AssertionError(msg="argextype only works on argument-position values")
argextype at ./compiler/optimize.jl:320

@axsk
Author

axsk commented Aug 29, 2022

I tried

Zygote._project(x, dx::NamedTuple{()}) = nothing
Zygote._project(x, dx::NamedTuple{(), Tuple{}}) = nothing
Zygote._project(x, dx::Tuple{}) = nothing

all without effect.
I am running it on Julia 1.8 with the newest versions of the packages.

Edit:

Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}}) = nothing

seems to fix it.
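A side note on the three _project signatures tried in this comment: in Julia's type system the first one already subsumes the second, because NamedTuple{()} leaves the field-types parameter free, while Tuple{} is an unrelated type that needs its own method. A small Base-only check, where the function kind is a hypothetical illustration:

```julia
# `NamedTuple{()}` is a UnionAll over the value-type parameter, so it
# already covers the concrete `NamedTuple{(), Tuple{}}`.
println(NamedTuple{(), Tuple{}} <: NamedTuple{()})

# Hypothetical dispatch helper showing which method each empty value hits.
kind(x) = :other
kind(::NamedTuple{()}) = :empty_namedtuple
kind(::Tuple{}) = :empty_tuple

println(kind(NamedTuple()), " ", kind(()), " ", kind((a = 1,)))
```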

@axsk
Author

axsk commented Aug 29, 2022

Looking into Lux.jl I found:

# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
    return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end

https://github.com/avik-pal/Lux.jl/blob/11ac3e476161eedea23194b31e48e8d128950e00/src/autodiff.jl#L92

It's piracy and touches the problematic accum, but I am not using ComponentArray in the MWE.
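To illustrate why merely loading a package can change behavior even when none of its code is called, here is a toy example of type piracy in plain Julia. The modules Owner and Pirate are made up: Owner stands in for the package that owns a generic function, and Pirate for a third-party package that adds a method to that function for a type it does not own either. This is not Lux's actual code.

```julia
# Stands in for the package that owns the generic function `accum`.
module Owner
accum(x, y) = x + y
end

# Before any third-party code is loaded, two empty NamedTuples cannot be accumulated.
before = try
    Owner.accum(NamedTuple(), NamedTuple())
catch e
    e
end

# Stands in for a third-party package: it owns neither `Owner.accum` nor
# Base's `NamedTuple`, so adding this method is type piracy.
module Pirate
import ..Owner
Owner.accum(::NamedTuple{()}, ::NamedTuple{()}) = nothing
end

# Loading `Pirate` has changed what `Owner.accum` does for every caller.
after = Owner.accum(NamedTuple(), NamedTuple())
println(typeof(before), " => ", after)
```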

@ToucheSir
Member

Does adding that definition into your own code without importing Lux also break things?

Regarding the definition that seems to fix things, Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}}) = nothing:

Do you mind tweaking the definition to this and pasting the stacktrace it generates here?

function Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}})
  display(stacktrace())
  println()
end
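Since println() returns nothing, the debugging definition above still returns nothing, so it keeps acting as the fix while printing one trace per call. The same pattern in plain Julia, as a sketch; debug_hook is a hypothetical stand-in for a temporarily overridden hook such as Zygote.wrap_chainrules_output:

```julia
# Hypothetical hook: when the suspicious empty NamedTuple shows up,
# report where it came from and return `nothing`.
function debug_hook(x::NamedTuple{()})
    frames = stacktrace()   # Vector{Base.StackTraces.StackFrame} for the current call stack
    println(length(frames), "-element stacktrace, innermost frame: ", first(frames))
    return nothing          # `display(...); println()` also ends by returning `nothing`
end

result = debug_hook(NamedTuple())
```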

@axsk
Author

axsk commented Aug 29, 2022

Adding that definition without importing Lux does not break it, so I guess that's not the problem.

Here are the stacktraces you asked for. One should probably start at the end, since the problem only occurs after the 3rd iteration.

59-element Vector{Base.StackTraces.StackFrame}:
wrap_chainrules_output at REPL[5]:2 [inlined]
map at tuple.jl:223 [inlined]
wrap_chainrules_output at chainrules.jl:106 [inlined]
ZBack at chainrules.jl:206 [inlined]
Pullback at namedtuple.jl:280 [inlined]
(::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0
Pullback at remake.jl:32 [inlined]
(::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at remake.jl:28 [inlined]
(::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at none:0 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:95 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:62 [inlined]
(::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:48 [inlined]
(::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:44 [inlined]
(::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:162 [inlined]
Pullback at reduce.jl:162 [inlined]
(::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
⋮
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
(::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45
gradient(f::Function, args::Vector{Float64}) at interface.jl:97
mwe() at mwe1294.jl:17
top-level scope at mwe1294.jl:21
eval at boot.jl:368 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428
_include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488
include(fname::String) at client.jl:476
top-level scope at REPL[6]:1
top-level scope at initialization.jl:52
eval at boot.jl:368 [inlined]
eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151
repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247
start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232
run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369
run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355
(::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419
#invokelatest#2 at essentials.jl:729 [inlined]
invokelatest at essentials.jl:726 [inlined]
run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404
exec_options(opts::Base.JLOptions) at client.jl:318
_start() at client.jl:522

57-element Vector{Base.StackTraces.StackFrame}:
wrap_chainrules_output at REPL[5]:2 [inlined]
map at tuple.jl:223 [inlined]
wrap_chainrules_output at chainrules.jl:106 [inlined]
(::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206
Pullback at remake.jl:32 [inlined]
(::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at remake.jl:28 [inlined]
(::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at none:0 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:95 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:62 [inlined]
(::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:48 [inlined]
(::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:44 [inlined]
(::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:162 [inlined]
Pullback at reduce.jl:162 [inlined]
(::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
(::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
⋮
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
(::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45
gradient(f::Function, args::Vector{Float64}) at interface.jl:97
mwe() at mwe1294.jl:17
top-level scope at mwe1294.jl:21
eval at boot.jl:368 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428
_include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488
include(fname::String) at client.jl:476
top-level scope at REPL[6]:1
top-level scope at initialization.jl:52
eval at boot.jl:368 [inlined]
eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151
repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247
start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232
run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369
run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355
(::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419
#invokelatest#2 at essentials.jl:729 [inlined]
invokelatest at essentials.jl:726 [inlined]
run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404
exec_options(opts::Base.JLOptions) at client.jl:318
_start() at client.jl:522

59-element Vector{Base.StackTraces.StackFrame}:
wrap_chainrules_output at REPL[5]:2 [inlined]
map at tuple.jl:223 [inlined]
wrap_chainrules_output at chainrules.jl:106 [inlined]
ZBack at chainrules.jl:206 [inlined]
Pullback at namedtuple.jl:280 [inlined]
(::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0
Pullback at remake.jl:32 [inlined]
(::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at remake.jl:28 [inlined]
(::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at none:0 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:95 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:62 [inlined]
(::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:48 [inlined]
(::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:44 [inlined]
(::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:162 [inlined]
Pullback at reduce.jl:162 [inlined]
(::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
⋮
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
(::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45
gradient(f::Function, args::Vector{Float64}) at interface.jl:97
mwe() at mwe1294.jl:17
top-level scope at mwe1294.jl:21
eval at boot.jl:368 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428
_include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488
include(fname::String) at client.jl:476
top-level scope at REPL[6]:1
top-level scope at initialization.jl:52
eval at boot.jl:368 [inlined]
eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151
repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247
start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232
run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369
run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355
(::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419
#invokelatest#2 at essentials.jl:729 [inlined]
invokelatest at essentials.jl:726 [inlined]
run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404
exec_options(opts::Base.JLOptions) at client.jl:318
_start() at client.jl:522

57-element Vector{Base.StackTraces.StackFrame}:
wrap_chainrules_output at REPL[5]:2 [inlined]
map at tuple.jl:223 [inlined]
wrap_chainrules_output at chainrules.jl:106 [inlined]
(::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206
Pullback at remake.jl:32 [inlined]
(::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at remake.jl:28 [inlined]
(::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at none:0 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:95 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:62 [inlined]
(::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:48 [inlined]
(::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:44 [inlined]
(::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:162 [inlined]
Pullback at reduce.jl:162 [inlined]
(::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
(::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
⋮
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
(::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45
gradient(f::Function, args::Vector{Float64}) at interface.jl:97
mwe() at mwe1294.jl:17
top-level scope at mwe1294.jl:21
eval at boot.jl:368 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428
_include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488
include(fname::String) at client.jl:476
top-level scope at REPL[6]:1
top-level scope at initialization.jl:52
eval at boot.jl:368 [inlined]
eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151
repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247
start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232
run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369
run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355
(::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419
#invokelatest#2 at essentials.jl:729 [inlined]
invokelatest at essentials.jl:726 [inlined]
run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404
exec_options(opts::Base.JLOptions) at client.jl:318
_start() at client.jl:522

59-element Vector{Base.StackTraces.StackFrame}:
wrap_chainrules_output at REPL[5]:2 [inlined]
map at tuple.jl:223 [inlined]
wrap_chainrules_output at chainrules.jl:106 [inlined]
ZBack at chainrules.jl:206 [inlined]
Pullback at namedtuple.jl:280 [inlined]
(::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0
Pullback at remake.jl:32 [inlined]
(::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at remake.jl:28 [inlined]
(::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at none:0 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:95 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:58 [inlined]
(::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:48 [inlined]
(::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:44 [inlined]
(::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:162 [inlined]
Pullback at reduce.jl:162 [inlined]
(::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
⋮
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
(::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45
gradient(f::Function, args::Vector{Float64}) at interface.jl:97
mwe() at mwe1294.jl:17
top-level scope at mwe1294.jl:21
eval at boot.jl:368 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428
_include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488
include(fname::String) at client.jl:476
top-level scope at REPL[6]:1
top-level scope at initialization.jl:52
eval at boot.jl:368 [inlined]
eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151
repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247
start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232
run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369
run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355
(::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419
#invokelatest#2 at essentials.jl:729 [inlined]
invokelatest at essentials.jl:726 [inlined]
run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404
exec_options(opts::Base.JLOptions) at client.jl:318
_start() at client.jl:522

57-element Vector{Base.StackTraces.StackFrame}:
wrap_chainrules_output at REPL[5]:2 [inlined]
map at tuple.jl:223 [inlined]
wrap_chainrules_output at chainrules.jl:106 [inlined]
(::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206
Pullback at remake.jl:32 [inlined]
(::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at remake.jl:28 [inlined]
(::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0
Pullback at none:0 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:95 [inlined]
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:58 [inlined]
(::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:48 [inlined]
(::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:44 [inlined]
(::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:162 [inlined]
Pullback at reduce.jl:162 [inlined]
(::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
(::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0
Pullback at reduce.jl:294 [inlined]
⋮
(::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0
(::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45
gradient(f::Function, args::Vector{Float64}) at interface.jl:97
mwe() at mwe1294.jl:17
top-level scope at mwe1294.jl:21
eval at boot.jl:368 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428
_include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488
include(fname::String) at client.jl:476
top-level scope at REPL[6]:1
top-level scope at initialization.jl:52
eval at boot.jl:368 [inlined]
eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151
repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247
start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232
run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369
run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355
(::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419
#invokelatest#2 at essentials.jl:729 [inlined]
invokelatest at essentials.jl:726 [inlined]
run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404
exec_options(opts::Base.JLOptions) at client.jl:318
_start() at client.jl:522

@ToucheSir
Member

Thanks! The last stacktrace includes https://github.com/avik-pal/Lux.jl/blob/11ac3e476161eedea23194b31e48e8d128950e00/src/autodiff.jl#L53-L63, which is very much piracy. Is that the last stacktrace printed before the error? If so, can you check whether that rrule overload is what breaks things?
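The frames Lux.var"#merge_pullback#157"{(), (...)} in the traces above come from an rrule for merge whose first argument has no keys. As a hedged illustration of how such a pullback can hand back an empty NamedTuple in that case (plausibly the value that later meets () in +), here is a toy pullback in Base Julia. merge_pullback_sketch is hypothetical, not the Lux or ChainRules implementation, and it simply drops keys of a that are overwritten by b.

```julia
# Hypothetical pullback for `merge(a, b)` on NamedTuples.
# Keys present in `b` take their value from `b`, so their cotangents go to `b`;
# keys only in `a` send their cotangents back to `a`.
function merge_pullback_sketch(a::NamedTuple{Ka}, b::NamedTuple{Kb}, dy::NamedTuple) where {Ka, Kb}
    ka_only = Tuple(k for k in Ka if !(k in Kb))
    da = NamedTuple{ka_only}(map(k -> getfield(dy, k), ka_only))
    db = NamedTuple{Kb}(map(k -> getfield(dy, k), Kb))
    return da, db
end

# With an empty `a`, the cotangent for `a` is itself an empty NamedTuple.
a = NamedTuple()
b = (u0 = [1.0], p = [2.0])
dy = (u0 = [0.1], p = [0.2])
da, db = merge_pullback_sketch(a, b, dy)
println(da, " ", db)
```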

@axsk
Author

axsk commented Aug 31, 2022

After removing the lines in question, the test runs through 🎷

@mcabbott mcabbott added the piracy A bug caused by a third-party committing piracy label Aug 31, 2022