Skip to content

Commit

Permalink
first working gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Jul 9, 2024
1 parent 5254983 commit 45edb17
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 79 deletions.
95 changes: 42 additions & 53 deletions src/general.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
calculate_separables_nokw([::Type{AT},] fct, sz::NTuple{N, Int}, offset = sz.÷2 .+1, scale = one(real(eltype(AT))),
factor = one(real(eltype(AT))), args...; dims = 1:N,
args...; dims = 1:N,
all_axes = (similar_arr_type(AT, dims=Val(1)))(undef, sum(sz[[dims...]])), pos=zero(real(eltype(AT))),
kwargs...) where {AT, N}
Expand All @@ -14,7 +14,6 @@ This function is used in `separable_view` and `separable_create`.
+ `sz`: the size of the result array (when applying the one-D axes)
+ `offset`: specifying the center (zero-position) of the result array in one-based coordinates. The default corresponds to the Fourier-center.
+ `scale`: multiplies the index before passing it to `fct`
+ `factor`: multiplies the result of `fct` before storing it in the result array.
+ `args`: further arguments which are passed over to the function `fct`.
+ `dims`: a vector `[]` of valid dimensions. Only these dimensions will be calculated, but they are oriented in ND.
+ `all_axes`: if provided, this memory is used instead of allocating a new one. This can be useful if you want to use the same memory for multiple calculations.
Expand All @@ -40,45 +39,43 @@ julia> gauss_sep = calculate_separables(fct, (6,5), (0.5,1.0), pos = (0.1,0.2))
function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
offset = sz.÷2 .+1,
scale = one(real(eltype(AT))),
factor = one(real(eltype(AT))),
args...; dims = 1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(AT))),
kwargs...) where {AT, N}

RT = real(eltype(AT))
RT = real(float(eltype(AT)))
offset = isnothing(offset) ? sz.÷2 .+1 : RT.(offset)
scale = isnothing(scale) ? one(real(eltype(RT))) : RT.(scale)
factor = isnothing(factor) ? one(real(eltype(RT))) : RT.(factor)
start = 1 .- offset
start = ntuple((d)->1, N) # 1 .- offset

# @show typeof(idc)
dims = [dims...]
valid_sz = sz[dims]
# return ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], offset))
# return out_of_place_assing(nothing, d, fct, pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], pos)), sz[dims[d]], arg_n(dims[d], args, RT))
# allocate one contiguous block of memory to be as cache-efficient as possible and dice it up below
res = ntuple((d) -> reorient((@view all_axes[1+sum(valid_sz[1:d])-sz[dims[d]]:sum(valid_sz[1:d])]), dims[d], Val(N)), lastindex(dims)) # Vector{AT}()

# @show kwarg_n(dims[1], kwargs)
# @show arg_n(dims[1], args)
# @show idc

# @show extra_args
# if isa(factor, Number)
# factor = ntuple((d) -> factor, N)
# end
# @show offset
# @show extra_args
# @show args
# @show kwargs
# @show collect(arg_n(dims[1], args))
# @show (idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)
# res[1][:] .= (factor[1]) .* fct.(idc, sz[dims[1]], arg_n(dims[1], args)...)
# res[1][:] .= fct.(idc, sz[dims[1]], arg_n(dims[1], args)...)
# idc = pick_n(dims[1], scale) .* ((start[dims[1]]:start[dims[1]]+sz[dims[1]]-1) .- pick_n(dims[1], pos))
# res = in_place_assing!(res, 1, factor[1], fct, idc, sz[dims[1]], arg_n(dims[1], args))
# res = in_place_assing!(res, 1, fct, idc, sz[dims[1]], arg_n(dims[1], args))
#push!(res, collect(reorient(fct.(idc, sz[1], arg_n(1, args)...; kwarg_n(1, kwargs)...), 1, Val(N))))
# for d = eachindex(dims)

# toreturn = (in_place_assing!(res, 1, factor[1], fct, pick_n(dims[1], scale) .* ((start[dims[1]]:start[dims[1]]+sz[dims[1]]-1) .- pick_n(dims[1], pos)), sz[dims[1]], arg_n(dims[1], args, RT)),
# in_place_assing!(res, 2, factor[2], fct, pick_n(dims[2], scale) .* ((start[dims[2]]:start[dims[2]]+sz[dims[2]]-1) .- pick_n(dims[2], pos)), sz[dims[2]], arg_n(dims[2], args, RT)))
# toreturn = (in_place_assing!(res, 1, fct, pick_n(dims[1], scale) .* ((start[dims[1]]:start[dims[1]]+sz[dims[1]]-1) .- pick_n(dims[1], pos)), sz[dims[1]], arg_n(dims[1], args, RT)),
# in_place_assing!(res, 2, fct, pick_n(dims[2], scale) .* ((start[dims[2]]:start[dims[2]]+sz[dims[2]]-1) .- pick_n(dims[2], pos)), sz[dims[2]], arg_n(dims[2], args, RT)))

# idc = start[dims[1]]:start[dims[1]]+sz[dims[1]]-1
# @show collect(arg_n(dims[1], args, RT))
Expand All @@ -87,21 +84,14 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
# idc = pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], pos))
# myaxis = collect(fct.(idc,arg_n(d, args)...)) # no need to reorient
# extra_args = kwargs_to_args(defaults, kwarg_n(dims[d], kwargs))
# tmp = let
# if isa(factor[d], Number)
# factor[d]
# else
# @view factor[d][:]
# end
# end
# res[d][:] .= tmp .* fct.(idc, sz[dims[d]], arg_n(dims[d], args)...)
in_place_assing!(res, d, factor, fct, pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], pos)), sz[dims[d]], arg_n(dims[d], args, RT))
# out_of_place_assing(res, d, factor[d], fct, pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], pos)), sz[dims[d]], arg_n(dims[d], args, RT))
# res[d][:] .= fct.(idc, sz[dims[d]], arg_n(dims[d], args)...)
in_place_assing!(res, d, fct, pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], offset)), sz[dims[d]], arg_n(dims[d], args, RT))
# AT(out_of_place_assing(res, d, fct, pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], offset)), sz[dims[d]], arg_n(dims[d], args, RT)))

# LazyArray representation of expression
# push!(res, myaxis)
, N) # Vector{AT}()
# @show toreturn[1]
# @show typeof(toreturn[1])
# end
return toreturn
# return res
Expand All @@ -110,21 +100,21 @@ end

# a special in-place assignment, which gets its own differentiation rule for the reverse mode
# to avoid problems with memory-assignment and AD.
function in_place_assing!(res, d, tmp, fct, idc, sz1d, args_d)
res[d][:] .= tmp .* fct.(idc, sz1d, args_d...)
function in_place_assing!(res, d, fct, idc, sz1d, args_d)
res[d][:] .= fct.(idc, sz1d, args_d...)
return res[d]
end

function out_of_place_assing(res, d, tmp, fct, idc, sz1d, args_d)
println("oop assign!")
return reorient(tmp .* fct.(idc, sz1d, args_d...), Val(d))
function out_of_place_assing(res, d, fct, idc, sz1d, args_d)
# println("oop assign!")
return reorient(fct.(idc, sz1d, args_d...), Val(d))
end

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(in_place_assing!), res, d, tmp, fct, idc, sz1d, args_d)
println("in rrule in_place_assing!")
y = in_place_assing!(res, d, tmp, fct, idc, sz1d, args_d)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(in_place_assing!), res, d, fct, idc, sz1d, args_d)
# println("in rrule in_place_assing!")
y = in_place_assing!(res, d, fct, idc, sz1d, args_d)
# @show collect(y)
_, in_place_assing_pullback = rrule_via_ad(config, out_of_place_assing, res, d, tmp, fct, idc, sz1d, args_d)
_, in_place_assing_pullback = rrule_via_ad(config, out_of_place_assing, res, d, fct, idc, sz1d, args_d)

# function in_place_assing_pullback(dy) # dy is a tuple of arrays.
# println("in in_place_assing_pullback")
Expand All @@ -148,22 +138,20 @@ function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int},
defaults=NamedTuple(), pos=zero(real(eltype(AT))),
offset = sz.÷2 .+1,
scale = one(real(eltype(AT))),
factor = one(real(eltype(AT))),
kwargs...) where {AT, N}

extra_args = kwargs_to_args(defaults, kwargs)
return calculate_separables_nokw(AT, fct, sz, offset, scale, factor, extra_args..., args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
return calculate_separables_nokw(AT, fct, sz, offset, scale, extra_args..., args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
end

function calculate_separables(fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])),
defaults=NamedTuple(), pos=zero(real(eltype(DefaultArrType))),
offset = sz.÷2 .+1,
scale = one(real(eltype(DefaultArrType))),
factor = one(real(eltype(DefaultArrType))),
kwargs...) where {N}
extra_args = kwargs_to_args(defaults, kwargs)
calculate_separables(DefaultArrType, fct, sz, extra_args..., args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, factor=factor, kwargs...)
calculate_separables(DefaultArrType, fct, sz, extra_args..., args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, kwargs...)
end

# define custom adjoint for calculate_separables
Expand All @@ -172,16 +160,16 @@ end

#

# calculate_separables_nokw(AT, fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
# calculate_separables_nokw(AT, fct, sz, offset, scale, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables_nokw), ::Type{AT}, fct, sz::NTuple{N, Int},
# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables_nokw), ::Type{AT}, fct, sz::NTuple{N, Int},
# offset = sz.÷2 .+1, scale=one(real(eltype(AT))), factor = one(real(eltype(AT)),), args...; dims = 1:N,
# offset = sz.÷2 .+1, scale=one(real(eltype(AT))), args...; dims = 1:N,
# all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
# defaults = NamedTuple(), pos=zero(real(eltype(AT))), kwargs...) where {AT, N}

# println("inside calculate_separables_nokw rrule! $(sz), $(dims), $(offset)")
# # forward pass
# y = collect(calculate_broadcasted_nokw(AT, fct, sz, offset, scale, factor, args...;
# y = collect(calculate_broadcasted_nokw(AT, fct, sz, offset, scale, args...;
# dims=dims, all_axes=all_axes, defaults = defaults, pos=pos, kwargs...))

# @show size(y)
Expand Down Expand Up @@ -216,18 +204,16 @@ end
# @show d_offset_fct.(13.3:0.2:15.3, (20,), args...)
# @show offset
# @show scale
# @show factor
# doffset = (factor, factor) .* sum(reshape([dy...], sz) .* collect(calculate_broadcasted_nokw(AT, d_offset_fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_grad_axes, pos=pos, defaults = defaults, kwargs...)))
# doffset = sum(reshape([dy...], sz) .* collect(calculate_broadcasted_nokw(AT, d_offset_fct, sz, offset, scale, args...; dims=dims, all_axes=all_grad_axes, pos=pos, defaults = defaults, kwargs...)))
# @show doffset
# dscale = dy # .* collect(calculate_broadcasted_nokw(AT, d_scale_fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, pos=pos, defaults = defaults, kwargs...))[:]
# dfactor = dy
# # dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dscale = dy # .* collect(calculate_broadcasted_nokw(AT, d_scale_fct, sz, offset, scale, args...; dims=dims, all_axes=all_axes, pos=pos, defaults = defaults, kwargs...))[:]
# # dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...)
# dargs = args;
# # It should return the gradient of the inputs
# # println(doffset)

# # calculate_separables_nokw(AT, fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
# return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), doffset, dscale, dfactor, dargs..., NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
# # calculate_separables_nokw(AT, fct, sz, offset, scale, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
# return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), doffset, dscale, dargs..., NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
# end
# @show "returning from pullback"
# return y, calculate_separables_pullback
Expand Down Expand Up @@ -268,13 +254,13 @@ end
# # multiply by dy?
# # @show d_off_fct(13.3, 20, args...)

# doffset = haskey(kwargs, :offset) ? calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...) : Zygote.NoTangent()
# doffset = haskey(kwargs, :offset) ? calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...) : Zygote.NoTangent()
# # doffset = collect(doffset)
# # @show size(.*(doffset...))
# # @show eltype(.*(doffset...))
# println("calculated doffset $(typeof(doffset))")
# dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...)
# dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...)
# dargs = args;
# dkwargs = (;offset = doffset, scale = dscale, pos = dpos);
# # It should return the gradient of the inputs
Expand Down Expand Up @@ -345,7 +331,6 @@ function calculate_broadcasted(fct, sz::NTuple{N, Int}, args...; dims=1:N,
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
end


# function calculate_sep_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims=1:N,
# all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
# pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {AT, N}
Expand All @@ -364,14 +349,18 @@ function calculate_broadcasted_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {AT, N}
# defaults should be evaluated here and filled into args...
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
res = Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
# @show eltype(collect(res))
return res
end

function calculate_broadcasted_nokw(fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {N}
# defaults should be evaluated here and filled into args...
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
res = Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
# @show eltype(collect(res))
return res
end

# towards a Gaussian that can also be rotated:
Expand Down
16 changes: 8 additions & 8 deletions src/specific.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,20 @@ for F in generate_functions_expr()
) where {TA, N}
fct = $(F[3]) # to assign the function to a symbol

# return calculate_broadcasted_nokw(TA, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)
pos=zero(real(eltype(TA)))
operation=$(F[5])
return calculate_separables_nokw(TA, fct, sz, args...; pos=pos, all_axes=all_axes), operation
return calculate_broadcasted_nokw(TA, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)
# pos=zero(real(eltype(TA)))
# operation=$(F[5])
# return calculate_separables_nokw(TA, fct, sz, args...; pos=pos, all_axes=all_axes), operation
end

@eval function $(Symbol(F[1], :_nokw_sep))(sz::NTuple{N, Int}, args...;
all_axes = (similar_arr_type(Array{$(F[4])}, eltype(Array{$(F[4])}), Val(1)))(undef, sum(sz[[(1:N)...]]))
) where {N}
fct = $(F[3]) # to assign the function to a symbol
# return calculate_broadcasted_nokw(Array{$(F[4])}, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)
pos=zero(real(eltype(DefaultArrType)))
operation=$(F[5])
return calculate_separables_nokw(Array{$(F[4])}, fct, sz, args...; pos=pos, all_axes=all_axes), operation
return calculate_broadcasted_nokw(Array{$(F[4])}, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)
# pos=zero(real(eltype(DefaultArrType)))
# operation=$(F[5])
# return calculate_separables_nokw(Array{$(F[4])}, fct, sz, args...; pos=pos, all_axes=all_axes), operation
end

@eval function $(Symbol(F[1], :_lz))(::Type{TA}, sz::NTuple{N, Int}, args...; kwargs...) where {TA, N}
Expand Down
Loading

0 comments on commit 45edb17

Please sign in to comment.