Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix broadcasting into buffers #1488

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open

fix broadcasting into buffers #1488

wants to merge 12 commits into from

Conversation

lxvm
Copy link
Contributor

@lxvm lxvm commented Jan 1, 2024

Hi,

I'm using Zygote as an AD backend in Integrals.jl and while I was writing tests I noticed I couldn't assign a number to length-1 Buffer using broadcasting. I think this is because the method signature for the pullback on materialize! is too restrictive, since copyto! allows for arbitrary iterators on the rhs of buf .= itr. I also added a test for a MWE.

PR Checklist

  • Tests are added
  • Documentation, if applicable

Update: A more complete MWE brings up a second issue:

MWE 2
julia> using Zygote
julia> gradient(1) do p
           b = Zygote.Buffer([1,2,3])
           b .= p
           return sum(copy(b))
         end
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}})(::Vector{Float64})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:121
  (::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Complex})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:192
  (::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Number})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zoCjl/src/projection.jl:193
  ...

Stacktrace:
 [1] _project
   @ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:189 [inlined]
 [2] map(f::typeof(Zygote._project), t::Tuple{Int64}, s::Tuple{Vector{Float64}})
   @ Base ./tuple.jl:318
 [3] gradient(::Function, ::Int64, ::Vararg{Int64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:98
 [4] top-level scope
   @ REPL[11]:1

Update 2: I added an adjoint for copyto! that fixes the MWE, however I'll try to add a generic adjoint for copyto!(buffer, itr) next

Update 3: I started about this the wrong way and the manual has details here on how to bypass broadcasting machinery, so an adjoint for Base.materialize! will have to be discarded and has to be replaced by an adjoint for copyto!(buffer, broadcasted)

Update 4: I finished writing an adjoint and added a test for broadcasted assignment to a buffer from a generator. I'll happily incorporate any feedback and improvements.

@lxvm
Copy link
Contributor Author

lxvm commented Jan 2, 2024

Fixes #254

@lxvm lxvm changed the title fix broadcasting buffers with generic iterators fix broadcasting into buffers Jan 3, 2024
Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed this PR on call. It looks mostly good, but there are some outstanding points:

  • remove extraneous comments
  • name anonymous functions
  • can you share the error that occurs when @non_differentiable Base.Broadcast.Broadcasted(::Nothing) is removed?

@lxvm
Copy link
Contributor Author

lxvm commented Jan 16, 2024

Thank you for the review! I will address the first two points soon, and for the last one I don't seem to be able to reproduce any error, so I'll remove it. Perhaps it came up due to inadvertently breaking something while debugging.

I also wanted to ask, since I removed @adjoint! copyto!(x::Buffer, y::Any) if we should restore a generic method, and if that should use indexing or iteration to get the elements of y?

@lxvm
Copy link
Contributor Author

lxvm commented Jan 16, 2024

I also blindly removed this line

grad .= eltype(grad) <: Number ? 0 : nothing

Should I restore it?

@ToucheSir
Copy link
Member

I also wanted to ask, since I removed @adjoint! copyto!(x::Buffer, y::Any) if we should restore a generic method, and if that should use indexing or iteration to get the elements of y?

That would be best. What worries me more is that removing that copyto! rule didn't lead to any failures. We probably need to add in tests to fill that gap. As for indexing vs iteration, the rule probably needs to rely on iteration because the copyto! overload for Buffer is so generic.

@lxvm
Copy link
Contributor Author

lxvm commented Jan 18, 2024

As for indexing vs iteration, the rule probably needs to rely on iteration

Zygote doesn't yet have an adjoint for collect(itr), so instead I've just reinstated the previous adjoint for copyto!(buffer, array). Presumably very few people write f!(out, args...) in a situation where a buffer would be substituted for out, although this is what I encountered in Integrals.jl. As for broadcasting, before this pr there has been an adjoint for Base.materialize!(buffer, array) that called the one for copyto!.

I've gone back and added some broken tests to point out where adjoints are missing for collect(itr) and copyto!(buffer, itr)

Having done a fair amount of work on this, I'm not happy with the implementation yet. I think a lot of shortcomings of missing adjoints could be fixed by the following approach:

  • define the adjoint for copyto!(buffer, x) in terms of copyto!(buffer, collect(x))
  • The adjoint for collect(x) should use a trait IsIndexable(x) to determine whether it can use ∇map to compute the adjoint. (Types like AbstractArray, Number, Broadcasted that define getindex would be indexable.)
  • If IsIndexable(x) is false, then assume x is iterable and use a (currently non-existent) collect pullback for generic iterators. (e.g. types from Base.Iterators and Generators would be iterable)

Would this approach be sound? Any ideas?

@ToucheSir
Copy link
Member

I would prefer to even ditch the trait and just check for a set of known good types like the ones you listed to determine if x can be indexed. Otherwise the overall plan sounds reasonable.

@lxvm
Copy link
Contributor Author

lxvm commented Feb 3, 2024

I actually don't have time to work on the improvements I to this pr I suggested, but in order to wrap up the changes I made there are two points to address:

  1. the copyto!(buffer, src) adjoint had this line grad .= eltype(grad) <: Number ? 0 : nothing and I removed it but no tests failed. What does this do exactly?
  2. The copyto! adjoint w.r.t. buffer is always nothing, but I suppose this is the intended behavior of the buffer? I would have tried the following but it seems prohibited:
using Zygote
Zygote.gradient(collect(1:10)) do x
    b = Zygote.Buffer(x)
    tmp1 = sum(copy(b))
    copyto!(b, fill(30))
    tmp2 = sum(copy(b))
    copyto!(b, [2i for i in 1:5])
    tmp3 = sum(copy(b))
    return tmp1 + tmp2 + tmp3
end # ERROR: Buffer is frozen

@ToucheSir
Copy link
Member

ToucheSir commented Feb 4, 2024

  • the copyto!(buffer, src) adjoint had this line grad .= eltype(grad) <: Number ? 0 : nothing and I removed it but no tests failed. What does this do exactly?

I think it's for the case where you have a buffer of non-numbers. Examples would be a buffer of differentiable structs, or a buffer of arrays. That neither of these cases were tested is bad, but also not uncommon for Zygote (which historically has poor test coverage in general).

  • The copyto! adjoint w.r.t. buffer is always nothing, but I suppose this is the intended behavior of the buffer? I would have tried the following but it seems prohibited: ...

My understanding is that differentiable arguments which are mutated should not have gradients returned for correctness reasons. Instead, a copy is kept in the mutable gradient cache managed by grad_mut and returned at either the top level or the point at which the mutable value was constructed.

@lxvm
Copy link
Contributor Author

lxvm commented Feb 12, 2024

Thanks! I added back the zeroing out of grad and kept the nothing gradient of the buffer. I hope this is enough to complete the pr.

src/lib/array.jl Outdated Show resolved Hide resolved
@lxvm
Copy link
Contributor Author

lxvm commented Feb 13, 2024

I added tests for Iterators.take adjoints and rebased on the main branch

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for implementing a pretty tricky feature, @lxvm !

@lxvm
Copy link
Contributor Author

lxvm commented Feb 14, 2024

I said that this pr would fix #254 and I just wanted to say that its MWE is slightly broken

using Zygote
f = (du, u, p, t) -> du .= 0
(y, p, λ, t) = ([1.02634, 0.909691], [1.5, 1.0, 3.0, 1.0], [0.973655, 1.09031], 10.0)
_dy, back = Zygote.pullback(y) do u
  out_ = Zygote.Buffer(u)
  f(out_, u, p, t)
  copy(out_)
end
dλ[:] = vec(back(λ)[1]) # ERROR: MethodError: no method matching vec(::Nothing)

The good news is that the gradient through the overwritten buffer is nothing, but the MWE wasn't written to handle that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants