Gradient fails for Dict constructed with a generator or a vector of pairs #1293

Open
Kolaru opened this issue Aug 24, 2022 · 3 comments · May be fixed by #1335
Labels: dictionary; enhancement (New feature or request); needs adjoint (missing rule); up for grabs (anyone is welcome to contribute with a PR to fix the issue)

Comments

Kolaru commented Aug 24, 2022

When taking the gradient of a function that constructs a Dict from a generator or a vector of pairs, Zygote fails because a try/catch block is hit somewhere in the constructor.

For example, with a generator:

julia> using Zygote

julia> function f(x)
              a = Dict(c => x for c in 1:3)
              return a[1]
          end
f (generic function with 1 method)

julia> Zygote.gradient(f, 2.0)
ERROR: Compiling Tuple{Type{Dict}, Base.Generator{UnitRange{Int64}, var"#3#4"{Float64}}}: try/catch is not supported.

It works, however, if I splat the vector:

julia> function f(x)
              a = Dict([c => x for c in 1:3]...)
              return a[1]
          end
f (generic function with 1 method)

julia> Zygote.gradient(f, 2.0)
(1.0,)
ToucheSir added the enhancement, needs adjoint, up for grabs, and dictionary labels Aug 24, 2022

ToucheSir (Member) commented

Can you run this on the latest version of Zygote and post the full stacktrace? It seems like you're on an older version, and just the error message is not enough for us to work with. Thanks!

Kolaru (Author) commented Sep 11, 2022

On Zygote master with Julia 1.7.1, I get:

julia> Zygote.gradient(f, 2.0)
ERROR: Compiling Tuple{Type{Dict}, Base.Generator{UnitRange{Int64}, var"#3#4"{Float64}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\reverse.jl:121
  [3] #Primal#23
    @ C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\reverse.jl:330
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\emit.jl:101
  [6] #s2770#1068
    @ C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface2.jl:28 [inlined]
  [7] var"#s2770#1068"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote .\none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:580
  [9] _pullback
    @ .\REPL[6]:2 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(f), args::Float64)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface2.jl:0
 [11] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface.jl:44
 [12] pullback
    @ C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface.jl:42 [inlined]
 [13] gradient(f::Function, args::Float64)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface.jl:96
 [14] top-level scope
    @ REPL[7]:1

ToucheSir (Member) commented

Thanks. It looks like Dict(::Generator) hits the generic Dict(kv) constructor, which accepts any iterable. That constructor wraps its body in a try/catch, so we'd have to add a rule for it. The code path it takes is somewhat tricky, though, so if anyone wants to try this, I'd recommend dispatching only on Generator to start.
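To see the dispatch difference without involving Zygote, one can ask Julia which method each call form reaches. This is a minimal sketch using only Base; the variable names are illustrative, and the printed method locations vary between Julia versions.

```julia
# Sketch (Base Julia only): which Dict method does each call form dispatch to?
x = 2.0
g  = (c => x for c in 1:3)   # Base.Generator yielding Int => Float64 pairs
ps = [c => x for c in 1:3]   # Vector{Pair{Int,Float64}}

m_gen    = which(Dict, Tuple{typeof(g)})       # generic Dict(kv): the try/catch method
m_vector = which(Dict, Tuple{typeof(ps)})      # expected to be the same generic Dict(kv)
m_splat  = which(Dict, map(typeof, Tuple(ps))) # Dict(ps::Pair{K,V}...): no try/catch

# Print where each method is defined (exact text depends on the Julia version).
foreach(println, (m_gen, m_vector, m_splat))

# All three forms build equal dictionaries; only the dispatch path differs.
@assert Dict(g) == Dict(ps) == Dict(ps...)
```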

In the meantime, if you know your key and value types up front, another workaround is to use the typed constructor instead:

a = Dict{Int,Float64}(c => x for c in 1:3)  # value type matches x::Float64 from the example above

This dispatches to Dict{K,V}(kv), which bypasses the method with the try/catch and should be slightly faster to boot.
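A minimal sketch putting both workarounds side by side, assuming a Zygote version comparable to the one in this thread. The function names f_typed and f_splat are hypothetical, and Float64 is assumed as the value type because x is a Float64 in the original example.

```julia
using Zygote

# Typed constructor: Dict{K,V}(itr) fills the dict in a plain loop,
# so the try/catch in the untyped Dict(kv) method is never reached.
function f_typed(x)
    a = Dict{Int,Float64}(c => x for c in 1:3)
    return a[1]
end

# The splatting workaround from the original report, for comparison.
function f_splat(x)
    a = Dict([c => x for c in 1:3]...)
    return a[1]
end

@show Zygote.gradient(f_typed, 2.5)
@show Zygote.gradient(f_splat, 2.5)
```

Pinning the value type to Float64 (rather than Int) also avoids an InexactError when x is not a whole number.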
