Skip to content

Commit

Permalink
Fix problems related to unthunk_tangent for IdDict
Browse files Browse the repository at this point in the history
Co-authored-by: Brian Chen <[email protected]>
  • Loading branch information
oschulz and ToucheSir committed Sep 23, 2022
1 parent 89b0ad9 commit 5f028a4
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
@inline unthunk_tangent(x::AbstractArray{<:AbstractThunk}) = map(unthunk_tangent, x)
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
function ChainRulesCore.rrule(::typeof(unthunk_tangent), d::IdDict)
unthunk_iddict_pullback(_) = (NoTangent(), ChainRulesCore.@not_implemented "unthunking IdDict")
return d, unthunk_iddict_pullback
end
@non_differentiable unthunk_tangent(::IdDict)


Expand Down

0 comments on commit 5f028a4

Please sign in to comment.