Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Elide stack generation outside of looping control flow #1195

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

ToucheSir
Copy link
Member

This PR ports @Keno's work on #78 to 2022 Zygote.

Because IRTools and base Julia have slightly different IR representations, some tweaks were necessary for the core algorithm:

  1. Instead of inserting phi nodes, we need to add block arguments. This is a bit more tedious because it requires updating multiple blocks.
  2. On the bright side, we don't need to calculate an iterated dominance frontier for each block. Whether any savings from that are wiped away from calling IRTools.dominators I'm not sure.
  3. Blocks are iterated over in reverse order. This allows us to iteratively narrow down the number of unaccounted alpha vars. Although forward_stacks! now theoretically runs in O(blocks * alphas) instead of O(max(blocks, alphas)) now, in practice the vast majority of alphas will be eliminated very quickly (if not in the first loop iteration).

Performance Comparison

using Zygote, BenchmarkTools

function qux(a, b, x) # Simple control flow
   aa = a ? sin(x) : cos(x)
   bb = b ? sech(aa) : tanh(aa)
   return bb
end

foldminus(xs) = Base.afoldl(-, xs...) # afoldl is very branch-heavy

xs = ntuple(identity, 16)
julia> @time gradient(qux, true, false, 1.0);
  0.146199 seconds (60.84 k allocations: 3.519 MiB, 99.73% compilation time) # 0.6.37
  0.135723 seconds (52.94 k allocations: 3.086 MiB, 99.86% compilation time) # This PR

julia> @btime gradient(qux, true, false, 1.0);
  3.378 μs (46 allocations: 1.31 KiB)
  3.044 μs (35 allocations: 720 bytes)

julia> @time gradient(foldminus, xs);
  4.785566 seconds (11.53 M allocations: 616.818 MiB, 2.59% gc time, 99.97% compilation time)
  4.428252 seconds (11.97 M allocations: 660.290 MiB, 3.03% gc time, 99.97% compilation time)

julia> @btime gradient(foldminus, $xs);
  111.256 μs (506 allocations: 20.30 KiB)
  151.316 ns (8 allocations: 848 bytes)

The afoldl example is particularly interesting because of how that function is defined. Despite the presence of a loop at the end, not requiring stacks for the block of conditionals is significantly faster. This could have immediate downstream impact for code like FluxML/Flux.jl#1809 (comment).

Next Steps

The Zygote test suite passes locally for me, so if CI + downstream is green then I think this should be a drop-in replacement for the current compiler code path. Per the comments, more optimizations may be possible for aspects such as calculating self-reachability. After looking through a bunch of IRTools code, there's probably a lot of low hanging fruit to optimize there as well.

@ToucheSir ToucheSir changed the title Elide stack generation outside of non-looping control flow Elide stack generation outside of looping control flow Apr 5, 2022
@CarloLucibello
Copy link
Member

Wow

@DhairyaLGandhi DhairyaLGandhi requested a review from Keno April 5, 2022 07:36
@MikeInnes
Copy link
Member

Awesome, really nice work @ToucheSir. If this is based on @Keno's original code it probably makes sense to add a co-author to the commit? (Alternatively you could treat this as an update to his branch, but that might be a hassle.)

I may be able to help with review if I get some time (but please don't wait up if someone else gets there first).

@ToucheSir
Copy link
Member Author

Thanks @MikeInnes! Treating this as a branch update is a little beyond my ability since the original PR was filed before the IRTools transition, but I've now tagged the commit with co-authorship info.

@jlperla
Copy link

jlperla commented May 3, 2022

Trying to track references in issues, the guess is that this is the solution to TuringLang/Turing.jl#1754 or am I missing something?

If so, is this PR sufficiently solid that it can be checked (on julia 1.7) or should I wait until it is merged?

@DhairyaLGandhi
Copy link
Member

Please do check this. It may not make too much difference in the compilation but it should help with control flow heavy code. Besides it's a good idea to test against Turing in general. We should add that to our downstream tests if we can get a subsection of the testset that sufficiently checks for Zygote correctness.

@ToucheSir
Copy link
Member Author

Friendly bump on this :)

@torfjelde
Copy link
Contributor

I just came across this, and I'll that this is huge for anything that uses DIstributions.jl (which we do in Turing.jl) due to the amount of if-statements in StatsFuns.jl/Distributions.jl. I've literally shaved off days of runtime for certain large models with Zygote by spending a grueling amount of effort tracking down if-statements in StatsFuns.jl and removing them.

I'm currently trying to do some benchmarks to see exactly what sort of effect it has on both runtime and compile time for our use-cases.

@CarloLucibello
Copy link
Member

@ToucheSir would you rebase?

@ToucheSir ToucheSir force-pushed the bc/stack-elision branch 3 times, most recently from 36453ca to 143f929 Compare November 11, 2022 07:01
@torfjelde
Copy link
Contributor

So it unfortuantely seems to significantly increase compilation time (and memory usage) in the example in TuringLang/Turing.jl#1754. For 15 tilde-statements, it blows out my 32GB mem laptop using this PR while the memory overhead for the current release (I haven't tested against master) has a minimal memory usage (it still takes ages to compile).

@torfjelde
Copy link
Contributor

Regarding the increase in compile-time, you can also observe this for the currently running tets, e.g. DiffEqFlux.jl/Layers. Atm it has been running for ~6hrs, while in the previously merged PR it seems to have only taken ~20mins: https://github.com/FluxML/Zygote.jl/actions/runs/3260268471/jobs/5353708714

@ToucheSir
Copy link
Member Author

ToucheSir commented Nov 11, 2022

2/4 failures on nightly and all failures on stable+LTS should be squashed now. The remaining 2 nightly ones are because of a missing rule and have been reported at JuliaDiff/ChainRules.jl#684.

e.g. DiffEqFlux.jl/Layers. Atm it has been running for ~6hrs, while in the previously merged PR it seems to have only taken ~20mins:

This one has been mysteriously timing out before this PR as well. I'll have another look at TuringLang/Turing.jl#1754 though. Last I checked (around the time of #1195 (comment)) the changes here didn't make a difference to latency, so perhaps the compiler has become smarter since...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants