Skip to content

Commit

Permalink
Last fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Jun 11, 2024
1 parent 9612596 commit ba817fa
Showing 1 changed file with 41 additions and 86 deletions.
127 changes: 41 additions & 86 deletions Burgers_tutorial/burgers_tutorial.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
# v0.19.40
# v0.19.42

#> [frontmatter]
#> title = "02 - Burgers Equation"
Expand Down Expand Up @@ -385,7 +385,7 @@ end
# ╔═╡ e14d9a22-4d55-4e66-8a63-8fdccd0c6d27
md"""
# Running the Burgers Model
We are ready to play with our Burgers model. Initially we use $100 \times 100$ grid points and only $10$ time steps.
We are ready to play with our Burgers model. We use $100 \times 100$ grid points and only $10$ time steps.
"""

# ╔═╡ 48b5e47b-e090-401d-ad7a-4898874b5117
Expand Down Expand Up @@ -444,16 +444,16 @@ To get a nicely colored plot, we visualize the velocity magnitude $|v|^2 = u^2 +
"""

# ╔═╡ bc294229-2b02-4b19-8f1b-71348793b323
function velocity_magnitude(burgers)
function velocity_magnitude_sq(burgers)
burgers.nextu[2:end-1, 2:end-1] .^ 2 + burgers.nextv[2:end-1, 2:end-1] .^ 2
end

# ╔═╡ f9f6b93f-74b4-49bb-91ed-71b14c87fb3a
surface(
range(-3, 3, length=burgers.nx-2),
range(-3, 3, length=burgers.ny-2),
velocity_magnitude(burgers);
axis=(type=Axis3,),
velocity_magnitude_sq(burgers);
axis=(type=Axis3, azimuth = -pi/4,)
)

# ╔═╡ b8201b3f-5187-4fdd-86a2-feb4e1d4f05b
Expand Down Expand Up @@ -482,20 +482,12 @@ md"""
Loading the tool 🔨
"""

# ╔═╡ 8112d5d6-151c-4b40-ac85-450786dff438
surface(
range(-3, 3, length=burgers.nx-2),
range(-3, 3, length=burgers.ny-2),
velocity_magnitude(burgers);
axis=(type=Axis3,),
)

# ╔═╡ c81d0aff-fdaa-40c2-b78b-a89143bf401d
md""" We define the _adjoint velocity magnitude_ as $(dJ/du)^2 + (dJ/dv)^2$. Not that in the implementation we have to access the `last` field (input) as opposed to the `next` field (output) when computing the velocity magnitude."""
md""" We define the adjoint velocity magnitude squared as $(dJ/du)^2 + (dJ/dv)^2$. Not that in the implementation we have to access the `last` field (input) as opposed to the `next` field (output) when computing the velocity magnitude."""

# ╔═╡ f061c0c1-5955-4e1b-888c-922dbae316b8
function adjoint_velocity_magnitude(burgers)
burgers.lastu[2:end-1, 2:end-1] .^ 2 + burgers.lastv[2:end-1, 2:end-1] .^ 2
function adjoint_velocity_magnitude_sq(dburgers)
dburgers.lastu[2:end-1, 2:end-1] .^ 2 + dburgers.lastv[2:end-1, 2:end-1] .^ 2
end

# ╔═╡ 3a4e997d-b002-4129-a69a-90bfe285f824
Expand Down Expand Up @@ -544,8 +536,8 @@ begin
surface(
range(-3, 3, length=burgers_hd.nx-2),
range(-3, 3, length=burgers_hd.ny-2),
velocity_magnitude(burgers_hd);
axis=(type=Axis3,),
velocity_magnitude_sq(burgers_hd);
axis=(type=Axis3, azimuth = -pi/4,)
)
end

Expand All @@ -559,66 +551,47 @@ end
surface(
range(-3, 3, length=dburgers.nx-2),
range(-3, 3, length=dburgers.ny-2),
adjoint_velocity_magnitude(dburgers);
axis=(type=Axis3,),
adjoint_velocity_magnitude_sq(dburgers);
axis=(type=Axis3, azimuth = -pi/4,),
)

# ╔═╡ d17ba6af-0ccc-4de5-9de5-9cefd4afa87c
md"""
We use the [Revolve](https://doi.org/10.1145/347837.347846) algorithm and set it up for `tsteps=10` timesteps and with a limit of 2 checkpoints.
We use the [Revolve](https://doi.org/10.1145/347837.347846) algorithm and set it up for `tsteps=1000` timesteps and with a limit of 10 checkpoints.
"""

# ╔═╡ 6989b054-e102-4349-bbb4-f45fabfa4d3e
revolve = Revolve{Burgers}(tsteps, 2; verbose = 1)
revolve = Revolve{Burgers}(1000, 10; verbose = 1)

# ╔═╡ 9092bf93-a2a5-46b5-9798-4adbadacf3f0
md"""
Instead of $10$ forward steps, we now do $10$ extra forward steps, leading to an overhead of $20/10=2$.
Instead of $1,000$ forward steps, we now do $3,636$ forward steps, leading to an overhead of $3.6$.
"""

# ╔═╡ c2bbb137-50f0-4a5e-b68e-e35b665e06e1
md"""
Start the differentiation using Enzyme. The checkpointing scheme is passed as a Const or "non-active" variable.
"""

# ╔═╡ 42a682f0-6fc4-44ae-9c3f-1434a69f5ff6
begin
reset!(revolve)
autodiff(ReverseWithPrimal, final_energy!, Active, Duplicated(burgers, dburgers), Const(revolve))
end

# ╔═╡ 3db86c1b-6a9a-4855-8385-a0c2a8f073b6
md"""
We then plot the adjoint velocity magnitude.
"""

# ╔═╡ 5beb4b8b-2c08-440a-a275-12fba8bbd852
surface(
range(-3, 3, length=burgers.nx-2),
range(-3, 3, length=burgers.ny-2),
adjoint_velocity_magnitude(dburgers);
axis=(type=Axis3,),
)

# ╔═╡ caaef553-35cc-4984-9dcd-bc057bb93cf2
md"""
Let's do the high-resolution example
"""

# ╔═╡ b06b03fb-9be6-4077-8640-3a5dae4d7980
begin
set_bc!(burgers_hd)
set_ic!(burgers_hd)
reset!(revolve)
dburgers_hd = Enzyme.make_zero(deepcopy(burgers_hd))
revolve_hd = Revolve{Burgers}(1000, 10; verbose = 1)
autodiff(ReverseWithPrimal, final_energy!, Active, Duplicated(burgers_hd, dburgers_hd), Const(revolve_hd))
autodiff(ReverseWithPrimal, final_energy!, Active, Duplicated(burgers_hd, dburgers_hd), Const(revolve))
end

# ╔═╡ 9c5f3dbe-598c-4160-875f-51de499aba05
surface(
range(-3, 3, length=burgers_hd.nx-2),
range(-3, 3, length=burgers_hd.ny-2),
adjoint_velocity_magnitude(dburgers_hd);
adjoint_velocity_magnitude_sq(dburgers_hd);
axis=(type=Axis3,),
)

Expand Down Expand Up @@ -672,7 +645,7 @@ PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
[compat]
Adapt = "~4.0.4"
CairoMakie = "~0.12.2"
Checkpointing = "~0.9.3"
Checkpointing = "~0.9.4"
Enzyme = "~0.12.12"
HypertextLiteral = "~0.9.5"
KernelAbstractions = "~0.9.20"
Expand All @@ -685,7 +658,7 @@ PLUTO_MANIFEST_TOML_CONTENTS = """
julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "14708f6ea11412bb8ccaba2800ba809319433f13"
project_hash = "971aba4f718439fd9b0f2a4346450ffe9d29e256"
[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -779,12 +752,6 @@ version = "0.5.0"
[[deps.CRC32c]]
uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"
[[deps.CRlibm]]
deps = ["CRlibm_jll"]
git-tree-sha1 = "32abd86e3c2025db5172aa182b982debed519834"
uuid = "96374032-68de-5a5b-8d9e-752f78720389"
version = "1.0.1"
[[deps.CRlibm_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "e329286945d0cfc04456972ea732551869af1cfc"
Expand Down Expand Up @@ -827,9 +794,9 @@ weakdeps = ["SparseArrays"]
[[deps.Checkpointing]]
deps = ["ChainRulesCore", "DataStructures", "Enzyme", "HDF5", "LinearAlgebra", "Serialization"]
git-tree-sha1 = "fbdfab0de3acb9095942c1f381d43a8ea6fced78"
git-tree-sha1 = "fc899d226991468fad8ec922168b3536b1ee5026"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
version = "0.9.3"
version = "0.9.4"
[[deps.ColorBrewer]]
deps = ["Colors", "JSON", "Test"]
Expand Down Expand Up @@ -997,11 +964,6 @@ git-tree-sha1 = "e3ece7b5fb991252abd138a2978e970063fc1412"
uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef"
version = "0.0.121+0"
[[deps.ErrorfreeArithmetic]]
git-tree-sha1 = "d6863c556f1142a061532e79f611aa46be201686"
uuid = "90fa49ef-747e-5e6f-a989-263ba693cf1a"
version = "0.5.2"
[[deps.ExactPredicates]]
deps = ["IntervalArithmetic", "Random", "StaticArrays"]
git-tree-sha1 = "b3f2ff58735b5f024c392fde763f29b057e4b025"
Expand Down Expand Up @@ -1042,12 +1004,6 @@ git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea"
uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a"
version = "3.3.10+0"
[[deps.FastRounding]]
deps = ["ErrorfreeArithmetic", "LinearAlgebra"]
git-tree-sha1 = "6344aa18f654196be82e62816935225b3b9abe44"
uuid = "fa42c844-2597-5d31-933b-ebd51ab2693f"
version = "0.3.1"
[[deps.FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322"
Expand Down Expand Up @@ -1296,22 +1252,36 @@ weakdeps = ["Unitful"]
InterpolationsUnitfulExt = "Unitful"
[[deps.IntervalArithmetic]]
deps = ["CRlibm", "EnumX", "FastRounding", "LinearAlgebra", "Markdown", "Random", "RecipesBase", "RoundingEmulator", "SetRounding", "StaticArrays"]
git-tree-sha1 = "f59e639916283c1d2e106d2b00910b50f4dab76c"
deps = ["CRlibm_jll", "MacroTools", "RoundingEmulator"]
git-tree-sha1 = "433b0bb201cd76cb087b017e49244f10394ebe9c"
uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
version = "0.21.2"
version = "0.22.14"
[deps.IntervalArithmetic.extensions]
IntervalArithmeticDiffRulesExt = "DiffRules"
IntervalArithmeticForwardDiffExt = "ForwardDiff"
IntervalArithmeticRecipesBaseExt = "RecipesBase"
[deps.IntervalArithmetic.weakdeps]
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
[[deps.IntervalSets]]
git-tree-sha1 = "dba9ddf07f77f60450fe5d2e2beb9854d9a49bd0"
uuid = "8197267c-284f-5f27-9208-e0e47529a953"
version = "0.7.10"
weakdeps = ["Random", "RecipesBase", "Statistics"]
[deps.IntervalSets.extensions]
IntervalSetsRandomExt = "Random"
IntervalSetsRecipesBaseExt = "RecipesBase"
IntervalSetsStatisticsExt = "Statistics"
[deps.IntervalSets.weakdeps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
Expand Down Expand Up @@ -1831,12 +1801,6 @@ weakdeps = ["FixedPointNumbers"]
[deps.Ratios.extensions]
RatiosFixedPointNumbersExt = "FixedPointNumbers"
[[deps.RecipesBase]]
deps = ["PrecompileTools"]
git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff"
uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
version = "1.3.4"
[[deps.Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -1890,11 +1854,6 @@ version = "1.2.1"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[deps.SetRounding]]
git-tree-sha1 = "d7a25e439d07a17b7cdf97eecee504c50fedf5f6"
uuid = "3cc68bcd-71a2-5612-b932-767ffbe40ab0"
version = "0.2.1"
[[deps.ShaderAbstractions]]
deps = ["ColorTypes", "FixedPointNumbers", "GeometryBasics", "LinearAlgebra", "Observables", "StaticArrays", "StructArrays", "Tables"]
git-tree-sha1 = "79123bc60c5507f035e6d1d9e563bb2971954ec8"
Expand Down Expand Up @@ -2339,7 +2298,7 @@ version = "3.5.0+0"
# ╟─12dfa763-b307-4f76-9a0d-5b24e6130da9
# ╠═64b6eca3-8b8a-470c-b496-8207d88fb99c
# ╠═96a54adb-015c-4243-a99d-9cda695f2a4d
# ╟─e14d9a22-4d55-4e66-8a63-8fdccd0c6d27
# ╠═e14d9a22-4d55-4e66-8a63-8fdccd0c6d27
# ╠═48b5e47b-e090-401d-ad7a-4898874b5117
# ╟─3843e17d-7eb0-4b15-a1d8-ca7e0eaefd8d
# ╠═de82b316-f8b9-479d-acac-dd18768e1e43
Expand All @@ -2363,7 +2322,6 @@ version = "3.5.0+0"
# ╠═bd0d9801-9d0d-4bee-beea-131439c262cc
# ╟─af620e89-c59b-4d69-89ca-020d5d958810
# ╠═06f23f52-0fcc-4707-b06c-80f4529b506d
# ╠═8112d5d6-151c-4b40-ac85-450786dff438
# ╠═f152117a-578c-451f-86ce-4acd1a669bfd
# ╟─c81d0aff-fdaa-40c2-b78b-a89143bf401d
# ╠═f061c0c1-5955-4e1b-888c-922dbae316b8
Expand All @@ -2378,10 +2336,7 @@ version = "3.5.0+0"
# ╠═6989b054-e102-4349-bbb4-f45fabfa4d3e
# ╟─9092bf93-a2a5-46b5-9798-4adbadacf3f0
# ╟─c2bbb137-50f0-4a5e-b68e-e35b665e06e1
# ╠═42a682f0-6fc4-44ae-9c3f-1434a69f5ff6
# ╟─3db86c1b-6a9a-4855-8385-a0c2a8f073b6
# ╠═5beb4b8b-2c08-440a-a275-12fba8bbd852
# ╟─caaef553-35cc-4984-9dcd-bc057bb93cf2
# ╠═b06b03fb-9be6-4077-8640-3a5dae4d7980
# ╠═9c5f3dbe-598c-4160-875f-51de499aba05
# ╟─871a7d07-2417-4e73-a9b9-5d7916edb167
Expand Down

0 comments on commit ba817fa

Please sign in to comment.