Skip to content

Commit

Permalink
gradient working
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Jul 17, 2024
1 parent bee6d1c commit 8765137
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 7 deletions.
1 change: 1 addition & 0 deletions performance_tests/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
SeparableFunctions = "c8c7ead4-852c-491e-a42d-3d43bc74259e"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
13 changes: 13 additions & 0 deletions performance_tests/test_performance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,16 @@ a = rand(sz...);
@time c = collect(exp.(1f0im .* sqrt.(max.(0f0, 100.0f0 .- (i*i+j*j))) .* 1f0) for i=1:sz[1], j=1:sz[2]); # 0.11 sec
@time c = map((x) -> exp.(1f0im .* sqrt.(max.(0f0, 100.0f0 .- x)) .* 1f0), .+(myrr2...)); # 0.06

################
#### some gradient tests
using Zygote

# Problem size and uniformly random reference data of that size.
sz = (10,10)
dat = rand(sz...)
# Scalar loss: sum of squared differences between `gaussian_nokw_sep` evaluated with
# `off` as its second argument and the reference data `dat`.
# NOTE(review): the positional arguments after `sz` are presumably offset, scale and
# sigma; confirm against the signature of `gaussian_nokw_sep`.
loss = (off) -> sum(abs2.(gaussian_nokw_sep(sz, off, 1.0, (1.0,1.0)) .- dat))
# Point at which the loss and its gradient are evaluated.
mystart = (1.1,2.2)
# Plain forward evaluation of the loss at the starting point.
loss(mystart)


# Gradient of the loss with respect to `mystart`.
g = gradient(loss, mystart);

103 changes: 96 additions & 7 deletions src/general.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,20 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
RT = real(float(eltype(AT)))
offset = isnothing(offset) ? sz2 .+1 : RT.(offset)
scale = isnothing(scale) ? one(real(eltype(RT))) : RT.(scale)
start = ntuple((d)->1, N) # 1 .- offset
# start = ntuple((d)->1, N) # 1 .- offset

# allocate one contiguous block of memory to be as cache-efficient as possible and dice it up below
res = ntuple((d) -> reorient((@view all_axes[1+sum(sz[1:d])-sz[d]:sum(sz[1:d])]), d, Val(N)), N) # Vector{AT}()

toreturn = ntuple((d) ->
in_place_assing!(res, d, fct, pick_n(d, scale) .* ((start[d]:start[d]+sz[d]-1) .- pick_n(d, offset)), sz[d], arg_n(d, args, RT))
, N) # Vector{AT}()
# @show typeof(toreturn[1])
# end
in_place_assing!(res, d, fct, get_1d_ids(d, sz, offset, scale), sz[d], arg_n(d, args, RT))
, N)
return toreturn
# return res
end

# Return the 1D coordinates along dimension `d`: the indices `1:sz[d]`, shifted by the
# `d`-th offset entry and then multiplied by the `d`-th scale entry. `pick_n` selects
# the entry belonging to dimension `d` from `offset` and `scale`.
function get_1d_ids(d, sz, offset, scale)
    factor = pick_n(d, scale)
    shift = pick_n(d, offset)
    return factor .* ((1:sz[d]) .- shift)
end


# a special in-place assignment, which gets its own differentiation rule for the reverse mode
# to avoid problems with memory-assignment and AD.
Expand All @@ -73,9 +73,26 @@ end
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(in_place_assing!), res, d, fct, idc, sz1d, args_d)
# println("in rrule in_place_assing!")
y = in_place_assing!(res, d, fct, idc, sz1d, args_d)
# @show d
# @show size(y)
# @show collect(y)
_, in_place_assing_pullback = rrule_via_ad(config, out_of_place_assing, res, d, fct, idc, sz1d, args_d)

function debug_dummy(dy)
println("in debug_dummy") # sz is (10, 20)
# @show dy # NoTangent()
# @show size(dy) # 1st calls: (1, 20) 2nd call: (10, 1)
myres = in_place_assing_pullback(dy)
# @show myres[1] # NoTangent()
# @show myres[2] # NoTangent()
# @show myres[3] # NoTangent()
# @show myres[4] # NoTangent()
# @show myres[5] #
# @show size(myres[5]) # 1st calls: (20,) 2nd call: (10,)
# @show myres[6] # 0.0
return myres
end

# function in_place_assing_pullback(dy) # dy is a tuple of arrays.
# println("in in_place_assing_pullback")

Expand All @@ -88,7 +105,7 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(in_
# return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), each_deriv, NoTangent(), NoTangent()
# # return NoTangent(), each_deriv, NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
# end
return y, in_place_assing_pullback # in_place_assing_pullback
return y, in_place_assing_pullback # in_place_assing_pullback # in_place_assing_pullback
end


Expand Down Expand Up @@ -216,11 +233,83 @@ function calculate_broadcasted_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz)),
operation = *, defaults = nothing, kwargs...) where {AT, N}
# defaults should be evaluated here and filled into args...
res = Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(AT, fct, sz, args...; all_axes=all_axes, kwargs...)...))
res = Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw_hook(AT, fct, sz, args...; all_axes=all_axes, kwargs...)...))
# @show eltype(collect(res))
return res
end

# Thin forwarding wrapper around `calculate_separables_nokw`. It passes all positional
# and keyword arguments through unchanged and exists as a separate function so that a
# dedicated reverse-mode differentiation rule can be attached to it.
function calculate_separables_nokw_hook(::Type{AT}, fct, sz::NTuple{N, Int}, args...; kwargs...) where {AT, N}
    separables = calculate_separables_nokw(AT, fct, sz, args...; kwargs...)
    return separables
end

# function calculate_separables_nokw_hook2()
# end

# Make the code run with Tuples and Vectors alike: shape `val` like the reference
# argument. If the reference argument is an `AbstractArray`, the elements of `val` are
# splatted into a new `Vector`; for any other reference argument `val` is returned
# unchanged. The choice is made by dispatch rather than by a run-time type test.
optional_convert(::AbstractArray, val) = [val...]
optional_convert(::Any, val) = val

# Reverse-mode rule for `calculate_separables_nokw_hook`.
#
# The primal result is computed by `calculate_separables_nokw`. The pullback is assembled
# from one AD pullback per dimension of the broadcasted 1D function `fct`, combined with
# the chain rule for the coordinate x = scale * (ids_raw - offset):
#     dx/doffset = -scale        dx/dscale = ids_raw - offset
#
# Arguments:
#   config  - rule configuration supporting reverse mode (used for `rrule_via_ad`)
#   AT      - array type; only `eltype(AT)` is used in this rule
#   fct     - 1D function called as `fct(x, sz_d, further_args...)`,
#             e.g. (x, sz, sigma) -> exp(-x^2 / (2 * sigma^2))
#   sz      - size tuple with one entry per dimension
#   args    - `args[1]` is the offset, `args[2]` the scale, `args[3:end]` further
#             arguments forwarded to `fct`
#   kwargs  - forwarded to `calculate_separables_nokw`
#
# Returns the primal value and a pullback yielding
# `(NoTangent(), NoTangent(), NoTangent(), NoTangent(), doffset, dscale, dargs...)`,
# i.e. tangents for the hook function, `AT`, `fct`, `sz`, offset, scale and each
# further argument.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables_nokw_hook), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; kwargs...) where {AT, N}
    # The primal is evaluated exactly once. No AD pass over the whole of
    # `calculate_separables_nokw` is made, since the pullback below does not need one.
    y = calculate_separables_nokw(AT, fct, sz, args...; kwargs...)

    # Per-dimension coordinates with offset and scale applied ...
    ids = get_1d_ids.(1:N, Ref(sz), Ref(args[1]), Ref(args[2]))
    # ... and with only the offset applied (scale one), i.e. ids_raw - offset.
    ids_offset_only = get_1d_ids.(1:N, Ref(sz), Ref(args[1]), one(eltype(AT)))

    # Wrap `fct` once per dimension so that it is broadcast over its coordinate vector.
    fcts = ntuple((d) -> ((x, n, extra...) -> fct.(x, n, extra...)), N)

    # One (value, pullback) pair per dimension. Broadcasting pairs the d-th wrapped
    # function with the d-th coordinates, the d-th size and the d-th entry of every
    # further argument; singleton arguments are broadcast automatically.
    calculate_fct_value_pullbacks = rrule_via_ad.(config, fcts, ids, sz, args[3:end]...)

    function calculate_separables_nokw_hook_pullback(dy)
        # Entry 2 of each per-dimension pullback result is the tangent with respect to
        # the coordinates x (entry 1 belongs to the wrapped function itself).
        dx_dids = ntuple((d)->calculate_fct_value_pullbacks[d][2](one(eltype(AT)))[2], N)

        # Offset: contract dy with df/dx and multiply by dx/doffset = -scale.
        doffset_each = optional_convert(args[1],
            ntuple((d) -> - (pick_n(d, args[2]) * (dy[d][:]' * dx_dids[d]))[1], N))
        # An offset of length one is shared by all dimensions: sum the contributions.
        doffset = length(args[1]) == 1 ? sum(doffset_each) : doffset_each

        # Scale: contract dy with df/dx times dx/dscale = ids_raw - offset.
        dscale_each = optional_convert(args[2],
            ntuple((d) -> ((@view dy[d][:])' * (ids_offset_only[d] .* dx_dids[d]))[1], N))
        # A scale of length one is shared by all dimensions: sum the contributions.
        dscale = (length(args[2]) == 1) ? sum(dscale_each) : dscale_each

        # Further arguments: entry 3 + argno of the per-dimension pullback result
        # (after the wrapped function, the coordinates and the size), seeded with dy.
        dargs = ntuple((argno) -> optional_convert(args[2+argno],
            ntuple((d) -> calculate_fct_value_pullbacks[d][2](@view dy[d][:])[3+argno], N)),
            length(args)-2)

        return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), doffset, dscale, dargs...)
    end

    return y, calculate_separables_nokw_hook_pullback
end


function calculate_broadcasted_nokw(fct, sz::NTuple{N, Int}, args...;
all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz)),
operation = *, defaults = nothing, kwargs...) where {N}
Expand Down

0 comments on commit 8765137

Please sign in to comment.