Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix broadcasting into buffers #1488

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
18 changes: 17 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)

# With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
_tryaxes(x) = (s = Base.IteratorSize(x); s isa Base.HasShape ? axes(x) : s isa Base.HasLength ? (Base.OneTo(length(x)),) : throw(ArgumentError("iterator size must be finite")))
_tryaxes(x::AbstractArray) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_tryaxes(x::Number) = x
_restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax)
Expand Down Expand Up @@ -319,6 +320,21 @@ end
collect(z), collect_zip_pullback
end

takefunc(itr, dy) = _restore(dy, _tryaxes(itr))

@adjoint function Iterators.take(itr, n)
take_pullback(::AbstractArray{Nothing}) = nothing
take_pullback(dy::NamedTuple{(:xs,:n)}) = (dy.xs, dy.n)
take_pullback(dy::NamedTuple{(:n,:xs)}) = (dy.xs, dy.n)
take_pullback(dy::AbstractArray) = (takefunc(itr, dy), nothing)
Iterators.take(itr, n), take_pullback
end

@adjoint function Base.collect(t::Iterators.Take)
collect_take_pullback(dy) = ((xs=takefunc(t.xs, dy), n=nothing),)
collect(t), collect_take_pullback
end

# Reductions
@adjoint function sum(xs::AbstractArray; dims = :)
if dims === (:)
Expand Down
44 changes: 34 additions & 10 deletions src/lib/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ grad_mut(cx::Context, b::Buffer{T}, ::Type{S}=Union{}) where {T<:Number, S<:Numb
@non_differentiable Buffer(::Any...)

@adjoint function getindex(b::Buffer, i...)
b[i...], function (Δ)
function getindex_buffer_pullback(Δ)
grad = grad_mut(__context__, b, eltype(Δ))
grad[i...] = accum(grad[i...], Δ)
return
end
b[i...], getindex_buffer_pullback
end

@adjoint! function setindex!(b::Buffer, v, i...)
setindex!(b, v, i...), function (_)
function setindex!_buffer_pullback(_)
grad = grad_mut(__context__, b)
v̄ = grad[i...]
zero = eltype(grad) <: Number ? 0 : nothing
Expand All @@ -26,26 +27,49 @@ end
end
(nothing, v̄, map(_->nothing, i)...)
end
setindex!(b, v, i...), setindex!_buffer_pullback
end

@adjoint! function copyto!(b::Buffer, xs)
copyto!(b, xs), function (_)
@adjoint! function copyto!(b::Buffer, src::AbstractArray)
function copyto!_buffer_array_pullback(_)
grad = grad_mut(__context__, b)
x̄s = copy(grad)
grad .= eltype(grad) <: Number ? 0 : nothing
return (nothing, x̄s)
xs = copy(grad)
grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
return (nothing, xs)
end
copyto!(b, src), copyto!_buffer_array_pullback
end

@adjoint! function copyto!(b::Buffer, bc::Base.Broadcast.Broadcasted)
xs, map_pullback = ∇map(__context__, i -> bc[i], eachindex(bc))
function copyto!_buffer_broadcast_pullback(_)
grad = grad_mut(__context__, b)
d, = map_pullback(reshape(first(grad, length(xs)), size(xs)))
grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
return (nothing, d.bc)
end
copyto!(b, xs), copyto!_buffer_broadcast_pullback
end

function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, g::Base.Generator)
xs, collect_pullback = _pullback(cx, collect, g)
function copyto!_buffer_generator_pullback(_)
grad = grad_mut(cx, b)
_, dg = collect_pullback(reshape(first(grad, length(xs)), size(xs)))
grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
return (nothing, nothing, dg)
end
copyto!(b, xs), copyto!_buffer_generator_pullback
end

@adjoint! function push!(b::Buffer, x)
push!(b, x), function (y)
function push!_buffer_pullback(_)
grad = grad_mut(__context__, b)
return (nothing, pop!(grad))
end
push!(b, x), push!_buffer_pullback
end

_pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::AbstractArray) =
_pullback(cx, copyto!, b, x)

@adjoint function copy(b::Buffer)
res = copy(b)
Expand Down
6 changes: 4 additions & 2 deletions src/tools/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ function Base.deleteat!(b::Buffer, i)
return b
end

@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes,
Base.eachindex, Base.stride, Base.strides, Base.findfirst,
@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes,
Base.eachindex, Base.stride, Base.strides, Base.findfirst,
Base.keys

Base.IteratorSize(::Type{<:Buffer{<:Any, A}}) where {A} = Base.IteratorSize(A)
Expand All @@ -84,3 +84,5 @@ function Base.iterate(b::Buffer, state=(eachindex(b),))
y === nothing && return nothing
b[y[1]], (state[1], tail(y)...)
end

Base.BroadcastStyle(::Type{Buffer{T,A}}) where {T,A} = Base.BroadcastStyle(A)
51 changes: 49 additions & 2 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,23 @@ end
@test gradient(x -> sum(inv, collect(view(x', 1,:))), ones(2,2)) == ([-1 0; -1 0],)

@test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],)

# adjoint of generators is available and should support generic arrays and iterators
# generator of array
@test gradient(p -> sum(collect(p*i for i in [1.0, 2.0, 3.0])), 2.0) == (6.0,)
# generator of iterator with HasShape
@test gradient(p -> sum(collect(p*i for (i,) in zip([1.0, 2.0, 3.0]))), 2.0) == (6.0,)
# generator of iterator with HasLength
@test gradient(p -> sum(collect(p*i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0) == (6.0,)
@test gradient(p -> sum(collect(p*i for i in Iterators.take(p*[1.0, 2.0, 3.0], 2))), 2.0) == (12.0,)
# generator 0-d behavior handled incorrectly
@test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0)
@test_broken gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0)

# adjoints for iterators
@test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 4))), 1.0) == (10.0,)
@test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) == (15.0,)
@test_broken gradient(sum∘collect, 1.0) == (1.0,) # broken since no generic adjoint
end

@test gradtest(x -> reverse(x), rand(17))
Expand Down Expand Up @@ -523,7 +540,7 @@ end
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))

@test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]

# issue 1224, second order
f1244(w, x) = sum(maximum((w * x).^2, dims=1))
g1244(w, x) = sum(gradient(f1244, w, x)[2].^2)
Expand Down Expand Up @@ -1538,6 +1555,36 @@ using Zygote: Buffer
return sum(copy(b))
end == ([2,2,2],)

@test gradient([1, 2, 3]) do xs
b = Zygote.Buffer(xs)
b .= 2
return sum(copy(b))
end == (nothing,)

@test gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
b .= (p*i for i in eachindex(b))
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2 + 2*3 + 3*4

@test gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
copyto!(b, [p*i for i in eachindex(b)])
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2 + 2*3 + 3*4

@test gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
copyto!(b, (p*i for i in eachindex(b)))
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2 + 2*3 + 3*4

@test_broken gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
copyto!(b, p)
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2

@test gradient(2) do x
b = Zygote.Buffer([])
push!(b, x)
Expand Down Expand Up @@ -1701,7 +1748,7 @@ end
end

@testset "FillArrays" begin

@test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1])
@test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing
@test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing
Expand Down
27 changes: 27 additions & 0 deletions test/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ end
@test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0))
end

@testset "adjoints of Iterators.take" begin
y, back = _pullback(Iterators.take, 1:5, 3)
@test back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 0.0, 0.0], nothing)
@test back([nothing for i in 1:3]) === nothing

@test gradient(x -> sum([2y for y in Iterators.take(x, 4)]), [1,2,3,4])[1] ≈ [2, 2, 2, 2]
@test gradient(x -> sum(2y for y in Iterators.take(x, 4)), [1,2,3,4])[1] ≈ [2, 2, 2, 2]

for p in (1.0, fill(1.0), [1.0])
@test gradient(p_ -> sum(map(prod, Iterators.take(p_, 1))), p) == (p,)
@test gradient(p_ -> sum(x for x in Iterators.take(p_, 1)), p) == (p,)
end

y, back = _pullback(Iterators.take, ones(2, 2), 3)
@test @inferred back(collect(y)) == (nothing, [1.0 1.0; 1.0 0.0], nothing)
end

@testset "collect" begin
@testset "Dict" begin
d = Dict(1 => 5, 2 => 6)
Expand Down Expand Up @@ -97,6 +114,16 @@ end
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),)
end


@testset "Iterators.Take" begin
z = Iterators.take(1:3, 2)
g = gradient(z -> sum(collect(z)), z)[1]
@test g == (xs=[1.0, 1.0, 0.0], n=nothing)

@test gradient(x -> sum(broadcast(prod, Iterators.take(x,2))), ones(4)) == ([1.0,1.0,0.0,0.0],)
@test gradient(x -> sum(broadcast(prod, Iterators.take(x.^2,2))), ones(4)) == (2*[1.0,1.0,0.0,0.0],)
end
end

@testset "dictionary comprehension" begin
Expand Down