Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Jul 21, 2024
1 parent a50eb9f commit 83a3412
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 12 deletions.
10 changes: 5 additions & 5 deletions src/specific.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,17 @@ function generate_functions_expr()
:((f, x, sz, sigma) -> .-x./sigma^2 .* f),
:((f, x, sz, sigma) -> x.^2 ./sigma^3 .* f)
),
(:(normal), (sigma=1.0,), :((x,sz, sigma) -> exp.(.- x.^2 ./(2*sigma^2)) ./ (sqrt(eltype(x)(2pi))*sigma)), Float32, *,
(:(normal), (sigma=1.0,), :((x,sz, sigma) -> exp.(.- x.^2 ./(2*sigma^2)) ./ (sqrt(eltype(x)(2pi))*abs(sigma))), Float32, *,
:((f, x, sz, sigma) -> .-x./sigma^2 .* f),
:((f, x, sz, sigma) -> (x.^2 ./sigma^3 .+ 1/sigma) .* f)
:((f, x, sz, sigma) -> (x.^2 ./sigma^3 .- inv(sigma)) .* f)
),
(:(sinc), NamedTuple(), :((x,sz) -> sinc.(x)), Float32, *,
:((f, x, sz) -> ifelse.(x .== zero(eltype(x)), zeros(eltype(x), size(x)), (cospi.(x) .- f)./x))
),
# the value "nothing" means that this default argument will not be handed over. But this works only for the last argument!
(:(exp_ikx), (shift_by=nothing,), :((x,sz, shift_by=sz÷2) -> cis.(x.*(-eltype(x)(2pi)*shift_by/sz))), ComplexF32, *,
:((f, x, sz, shift_by) -> (-eltype(x)(2pi)*shift_by/sz) .* f),
:((f, x, sz, shift_by) -> (-eltype(x)(2pi)/sz) .*x .* f)
:((f, x, sz, shift_by) -> (-1im*eltype(x)(2pi)*shift_by/sz) .* f),
:((f, x, sz, shift_by) -> (-1im*eltype(x)(2pi)/sz) .*x .* f)
),
(:(ramp), (slope=0,), :((x,sz, slope) -> slope.*x), Float32, +,
:((f, x, sz, slope) -> slope),
Expand Down Expand Up @@ -101,7 +101,7 @@ for F in generate_functions_expr()
# println("pb")
# @show dy
# @show $(F[6])(y, x, sz, args...; kwargs...)
mydx = dy .* $(F[6])(y, x, sz, args...; kwargs...)
mydx = conj.(dy) .* $(F[6])(y, x, sz, args...; kwargs...)
# targ = ntuple(d -> begin
# mydarg = F[6+d]
# dy .* $(mydarg)(y, x, sz, args...; kwargs...)
Expand Down
56 changes: 49 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using Test
using IndexFunArrays
using SeparableFunctions
using FiniteDifferences
using Zygote
using Random

function test_fct(T, fcts, sz, args...; kwargs...)
ifa, fct = fcts
Expand Down Expand Up @@ -165,9 +168,26 @@ end
@test maximum(abs.(res4 .- res5)) < 1e-6
end

function test_gradient(T, fct, sz, args...; kwargs...)
RT = real(T)
Random.seed!(1234)
dat = rand(T, sz...)
off0 = rand(RT, length(sz))
sca0 = rand(RT, length(sz))
args = ntuple((d)->RT.(args[d]), length(args))
loss = (off, sca, args...) -> sum(abs2.(fct(sz, off, sca, args..., kwargs...) .- dat))
# @show loss(off0, sca0, args...)
g = gradient(loss, off0, sca0, args...)
gn = grad(central_fdm(5, 1), loss, off0, sca0, args...) # 5th order method, 1st derivative
for r in 1:length(gn)
@test eltype(g[r]) == RT
# @show g[r]
# @show gn[r]
@test all(isapprox.(g[r], gn[r], rtol=1e-1))
end
end

@testset "gradient tests" begin
using FiniteDifferences
using Zygote
rng = collect(1:0.1:2)
sz = length(rng)

Expand All @@ -179,16 +199,38 @@ end
@test g[1] gn[1]
@test g[2] gn[2]

sz = (11, 22)
test_gradient(Float32, gaussian_nokw_sep, (11,22), (2.2, -0.8))
test_gradient(Float64, gaussian_nokw_sep, (6, 22, 7), 2.0)
test_gradient(Float32, gaussian_nokw_sep, (6,), 4.2f0)

test_gradient(Float64, normal_nokw_sep, (6, 22, 7), (2.0, -3.1, 1.2))

test_gradient(Float32, sinc_nokw_sep, (22, 11))
test_gradient(Float32, ramp_nokw_sep, (22, 11), (1.0, 2.0))
test_gradient(Float32, rr2_nokw_sep, (22, 11))
test_gradient(ComplexF32, exp_ikx_nokw_sep, (22, 11), (1.0, 2.0))

sz = (3,)
loss2 = (off, sca, shift_by) -> sum(imag.(exp_ikx_nokw_sep(sz, off, sca, shift_by)))
shift_by0 = 0.7
sca0 = 1.4
off0 = 0.3
loss2(off0, sca0, shift_by0)
g = gradient(loss2, off0, sca0, shift_by0)
gn = grad(central_fdm(5, 1), loss2, off0, sca0, shift_by0) # 5th order method, 1st derivative
@test all(isapprox.(g[1], gn[1], atol=5e-3))
@test all(isapprox.(g[2], gn[2], atol=1e-2))

sz = (11, 22, 7)
loss2 = (off, sca, sigma) -> sum(gaussian_nokw_sep(sz, off, sca, sigma))
sigma0 = 2.0
sca0 = (0.9, 1.2)
off0 = (1.1, 2.2)
sca0 = (0.9, 1.2, 0.4)
off0 = (0.9, 1.2, 0.4)
loss2(off0, sca0, sigma0)
g = gradient(loss2, off0, sca0, sigma0)
gn = grad(central_fdm(5, 1), loss2, off0, sca0, sigma0) # 5th order method, 1st derivative
@test all(isapprox.(g[1], gn[1], atol=2e-3))
@test all(isapprox.(g[2], gn[2], atol=2e-3))
@test all(isapprox.(g[1], gn[1], atol=5e-3))
@test all(isapprox.(g[2], gn[2], atol=1e-2))

end

Expand Down

0 comments on commit 83a3412

Please sign in to comment.