Skip to content

Commit

Permalink
Elide stack generation outside of non-looping control flow
Browse files Browse the repository at this point in the history
Co-authored-by: Keno Fischer <[email protected]>
  • Loading branch information
ToucheSir and Keno committed Nov 11, 2022
1 parent d39ab59 commit 7bdfe94
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 41 deletions.
127 changes: 104 additions & 23 deletions src/compiler/emit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,130 @@ xtuple(xs...) = xcall(:tuple, xs...)

concrete(T::DataType) = T
concrete(::Type{Type{T}}) where T = typeof(T)
concrete(T) = Any
concrete(@nospecialize _) = Any

runonce(b) = b.id in (1, length(b.ir.blocks))

# TODO use a more efficient algorithm such as Johnson (1975)
# https://epubs.siam.org/doi/abs/10.1137/0204007
self_reaching(cfg, bid, visited = BitSet()) = reaches(cfg, bid, bid, visited)
function reaches(cfg, from, to, visited)
for succ in cfg[from]
if succ === to
return true
elseif succ visited
push!(visited, succ)
if reaches(cfg, succ, to, visited)
return true
end
end
end
return false
end

function forward_stacks!(adj, F)
stks, recs = [], []
stks, recs = Tuple{Int, Alpha, Bool}[], Variable[]
pr = adj.primal
for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id))
if runonce(b)
push!(recs, Variable(α))
else
stk = pushfirst!(pr, xstack(Any))
push!(recs, stk)
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
blks = blocks(pr)
last_block = length(blks)
cfg = IRTools.CFG(pr)
cfgᵀ = cfg'
doms = IRTools.dominators(cfg)

reaching_visited = BitSet()
in_loop = map(1:last_block) do b
empty!(reaching_visited)
self_reaching(cfg, b, reaching_visited)
end
alphavars = Dict{Alpha, Variable}()
alpha_blocks ==> b.id for b in blks for α in alphauses(block(adj.adjoint, b.id))]
for b in Iterators.reverse(blks)
filter!(alpha_blocks) do (α, bid)
if b.id in doms[bid]
# If a block dominates this block, α is guaranteed to be present here
αvar = Variable(α)
for br in branches(b)
map!(a -> a === α ? αvar : a, br.args, br.args)
end
push!(recs, b.id === last_block ? αvar : alphavars[α])
push!(stks, (bid, α, false))
elseif in_loop[bid]
# This block is in a loop, so we're forced to insert stacks
# Note: all alphas in loops will have stacks after the first iteration
stk = pushfirst!(pr, xstack(Any))
push!(recs, stk)
push!(block(pr, bid), xcall(Zygote, :_push!, stk, Variable(α)))
push!(stks, (bid, α, true))
else
# Fallback case, propagate alpha back through the CFG
argvar = nothing
if b.id > 1
# Need to make sure all predecessors have a branch to add arguments to
IRTools.explicitbranch!(b)
argvar = argument!(b, insert=false)
end
if b.id === last_block
# This alpha has been threaded all the way through to the exit block
alphavars[α] = argvar
end
for br in branches(b)
map!(a -> a === α ? argvar : a, br.args, br.args)
end
for pred in cfgᵀ[b.id]
pred >= b.id && continue # TODO is this needed?
pred_branches = branches(block(pr, pred))
idx = findfirst(br -> br.block === b.id, pred_branches)
if idx === nothing
throw(error("Predecessor $pred of block $(b.id) has no branch to $(b.id)"))
end
branch_here = pred_branches[idx]
push!(branch_here.args, α)
end
# We're not done with this alpha yet, revisit in predecessors
return true
end
return false
end
# Prune any alphas that don't exist on this path through the CFG
for br in branches(b)
map!(a -> a isa Alpha ? nothing : a, br.args, br.args)
end
push!(stks, (b.id, alpha(α)))
end
args = arguments(pr)[3:end]
@assert isempty(alpha_blocks)

rec = push!(pr, xtuple(recs...))
# Pullback{F,Any} reduces specialisation
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
# P = Pullback{F,Any} # reduce specialisation
rec = push!(pr, Expr(:call, P, rec))
ret = xtuple(pr.blocks[end].branches[end].args[1], rec)
ret = push!(pr, ret)
pr.blocks[end].branches[end].args[1] = ret
return pr, stks
end

# Helps constrain pullback function type in the backwards pass
# If we had the type, we could make this a PiNode
notnothing(::Nothing) = error()
notnothing(x) = x

function reverse_stacks!(adj, stks)
ir = adj.adjoint
entry = blocks(ir)[end]
blcks = blocks(ir)
entry = blcks[end]
self = argument!(entry, at = 1)
t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t)))
repl = Dict()
runonce(b) = b.id in (1, length(ir.blocks))
for b in blocks(ir)
for (i, (b′, α)) in enumerate(stks)
t = pushfirst!(entry, xcall(:getfield, self, QuoteNode(:t)))
repl = Dict{Alpha,Variable}()
for b in blcks
for (i, (b′, α, use_stack)) in enumerate(stks)
b.id == b′ || continue
if runonce(b)
val = insertafter!(ir, t, xcall(:getindex, t, i))
else
stk = push!(entry, xcall(:getindex, t, i))
stk = push!(entry, xcall(Zygote, :Stack, stk))
# i.e. recs[i] from forward_stacks!
val = insertafter!(ir, t, xcall(:getindex, t, i))
if use_stack
stk = push!(entry, xcall(Zygote, :Stack, val))
val = pushfirst!(b, xcall(:pop!, stk))
elseif !runonce(b)
# The first and last blocks always run, so this check is redundant there
val = pushfirst!(b, xcall(Zygote, :notnothing, val))
end
repl[α] = val
end
Expand All @@ -87,6 +167,7 @@ end

function stacks!(adj, T)
forw, stks = forward_stacks!(adj, T)
IRTools.domorder!(forw)
back = reverse_stacks!(adj, stks)
permute!(back, length(back.blocks):-1:1)
IRTools.domorder!(back)
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ end
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
forw = varargs!(meta, forw, 3)
# IRTools.verify(forw)
# verify(forw)
forw = slots!(pis!(inlineable!(forw)))
# be ready to swap to using chainrule if one is declared
cr_edge != nothing && edge!(meta, cr_edge)
cr_edge !== nothing && edge!(meta, cr_edge)
return update!(meta.code, forw)
end

Expand All @@ -53,7 +53,7 @@ end
end
meta, _, back = g
argnames!(meta, Symbol("#self#"), )
# IRTools.verify(back)
# verify(back)
back = slots!(inlineable!(back))
return update!(meta.code, back)
end
6 changes: 6 additions & 0 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ function accum_global(cx::Context, ref, x̄)
return
end

# Needed for nested AD
function _pullback(::typeof(accum_global), cx::Context, ref, x̄)
accum_global_pullback(_) = nothing
return accum_global(cx, ref, x̄), accum_global_pullback
end

unwrap(x) = x

@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
Expand Down
62 changes: 49 additions & 13 deletions test/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Zygote, Test
using Zygote, IRTools, Test
using Zygote: pullback, @adjoint, Context

macro test_inferred(ex)
Expand All @@ -18,24 +18,22 @@ end

bad(x) = x
@adjoint bad(x) = x, Δ -> error("bad")
bad_adjoint_line = @__LINE__() - 1 # source location of above

function badly(x)
x = x + 1
x = bad(x)
return x
end
bad_pullback_line = @__LINE__() - 3 # should match source location of Pullback

y, back = pullback(badly, 2)
@test y == 3
@test_throws Exception back(1)
bt = try back(1) catch e stacktrace(catch_backtrace()) end

@test trace_contains(bt, nothing, "compiler.jl", 20)
if VERSION >= v"1.6-"
@test_broken trace_contains(bt, :badly, "compiler.jl", 24)
else
@test trace_contains(bt, :badly, "compiler.jl", 24)
end
bt = try back(1) catch e stacktrace(catch_backtrace()) end
@test trace_contains(bt, nothing, "compiler.jl", bad_adjoint_line)
@test trace_contains(bt, nothing, "compiler.jl", bad_pullback_line)

# Type inference checks

Expand All @@ -58,10 +56,9 @@ y, back = @test_inferred pullback(f, 5)
y, back = @test_inferred pullback(Core._apply, +, (1, 2, 3))
@test_inferred back(1)

# TODO fix bcast inference
# bcast(x) = x .* 5
# y, back = @test_inferred pullback(bcast, [1,2,3])
# @test_inferred back([1,1,1])
bcast(x) = x .* 5
y, back = @test_inferred pullback(bcast, [1,2,3])
@test_inferred back([1,1,1])

foo = let a = 4
x -> x*a
Expand Down Expand Up @@ -91,6 +88,45 @@ struct Funky
y
end

@testset "stack elision" begin
function isstackfree(T)
_, forw, back = Zygote._generate_pullback_via_decomposition(T)
for (_, stmt) in forw
expr = stmt.expr
expr.head == :call && first(expr.args) == GlobalRef(Zygote, :_push!) && return false
end
for (_, stmt) in back
expr = stmt.expr
expr.head == :call && first(expr.args) == GlobalRef(Zygote, :Stack) && return false
end
return true
end

function knockoff_pow(x, n)
n == 0 && return 1
n == 1 && return x
n == 2 && return x * x
n == 3 && return x * x * x
return x ^ n
end

function roundabout_trig(x, fancy_sin, fancy_cos, fancy_tan)
if fancy_tan
s = fancy_sin ? inv(csc(x)) : sin(x)
c = fancy_cos ? inv(sec(x)) : cos(x)
s += 0
c *= 1
return s / c
else
return tan(x)
end
end

@test !isstackfree(Tuple{typeof(pow), Int, Int})
@test isstackfree(Tuple{typeof(knockoff_pow), Int, Int})
@test isstackfree(Tuple{typeof(roundabout_trig), Float64, Bool, Bool, Bool})
end

@testset "issue #851" begin
f = Funky(1, 1);
function Base.getproperty(f::Funky, i::Symbol)
Expand Down Expand Up @@ -128,7 +164,7 @@ end
d_two = Zygote.pullback(two_svds, X)[2](Δoutput)
d_one = Zygote.pullback(one_svd, X)[2](Δoutput)
@test d_one == d_two
end
end

# this test fails if adjoint for literal_getproperty is added
# https://github.com/FluxML/Zygote.jl/issues/922#issuecomment-804128905
Expand Down
2 changes: 1 addition & 1 deletion test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ end == (2,)
global_param = 3

@testset "Global Params" begin
cx = Zygote.Context()
cx = Zygote.Context{true}(nothing) # only makes sense with implicit params
y, back = Zygote._pullback(cx, x -> x*global_param, 2)
@test y == 6
@test back(1) == (nothing, 3)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Zygote, Test
using Zygote: gradient, ZygoteRuleConfig
using CUDA
using CUDA: has_cuda
using LinearAlgebra

@testset "all" begin # Overall testset ensures it keeps running after failure

Expand Down
1 change: 0 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using LinearAlgebra
using ForwardDiff
using Zygote: hessian_dual, hessian_reverse

Expand Down

0 comments on commit 7bdfe94

Please sign in to comment.