diff --git a/test/estimation.jl b/test/estimation.jl index fad934b..4f18f93 100644 --- a/test/estimation.jl +++ b/test/estimation.jl @@ -3,6 +3,7 @@ using Distributions: Uniform, Normal using Statistics: mean using LinearAlgebra: norm using PERK: perk, GaussianKernel, EuclideanKernel, GaussianRFF +using StableRNGs: StableRNG function test_perk_1() @@ -21,7 +22,7 @@ function test_perk_1() xhat = perk(rng, y, T, xDists, noiseDist, signalModels, kernel, ρ) error_rel = abs(xhat - xtrue) / xtrue - return isapprox(error_rel, 0.042677934398487306, atol = 1e-7) + @test isapprox(error_rel, 0.042677934398487306, atol = 1e-7) end @@ -47,7 +48,8 @@ function test_perk_2() end error_rel_avg = sum(error_rel) / length(error_rel) - return isapprox(error_rel_avg, 0.043934535569840415, atol = 1e-7) + ref = VERSION < v"1.11" ? 0.043934535569840415 : 0.04400107950561938 + @test isapprox(error_rel_avg, ref, atol = 1e-7) end @@ -70,7 +72,7 @@ function test_perk_3() xhat = perk(rng, y, T, xDists, noiseDist, signalModels, kernel, ρ) error_rel = abs(xhat[] - xtrue) / xtrue - return isapprox(error_rel, 0.07054474887124002, atol = 1e-7) + @test isapprox(error_rel, 0.07054474887124002, atol = 1e-7) end @@ -96,7 +98,9 @@ function test_perk_4() end error_rel_avg = sum(error_rel) / length(error_rel) - return isapprox(error_rel_avg, 0.05827088471817421, atol = 1e-7) + ref = VERSION < v"1.11" ? 0.05827088471817421 : 0.05822451811335079 + + @test isapprox(error_rel_avg, ref, atol = 1e-7) end @@ -120,7 +124,7 @@ function test_perk_5() xhat = perk(rng, y, ν, T, xDists, νDists, noiseDist, signalModels, kernel, ρ) error_rel = abs(xhat[] - xtrue) / xtrue - return isapprox(error_rel, 0.0057175117436099755, atol = 1e-7) + @test isapprox(error_rel, 0.0057175117436099755, atol = 1e-7) end @@ -146,7 +150,7 @@ function test_perk_6() signalModels, kernel, ρ) error_rel = norm(xhat .- xtrue) / (sqrt(N) * xtrue) - return isapprox(error_rel, 0.0009325470666789215, atol = 1e-7) + @test isapprox(error_rel, 0.0009325470666789215, atol = 1e-7) end @@ -168,7 +172,7 @@ function test_perk_7() xhat = perk(rng, y, ν, T, xDists, νDists, noiseDist, signalModels, kernel, ρ) error_rel = abs(xhat[] - xtrue) / xtrue - return isapprox(error_rel, 0.1083022379521612, atol = 1e-7) + @test isapprox(error_rel, 0.1083022379521612, atol = 1e-7) end @@ -190,7 +194,7 @@ function test_perk_8() xhat = perk(rng, y, ν, T, xDists, νDists, noiseDist, signalModels, kernel, ρ) error_rel = abs(xhat[] - xtrue) / xtrue - return isapprox(error_rel, 0.1083022379521612, atol = 1e-7) + @test isapprox(error_rel, 0.1083022379521612, atol = 1e-7) end @@ -210,7 +214,7 @@ function test_perk_9() xhat = perk(rng, y, T, xDists, noiseDist, signalModels, kernel, ρ) error_rel = abs(xhat[] - xtrue) / xtrue - return isapprox(error_rel, 0.09566274665119057, atol = 1e-6) + @test isapprox(error_rel, 0.09566274665119057, atol = 1e-6) end @@ -233,7 +237,7 @@ function test_perk_10() xhat = perk(rng, y, ν, T, xDists, νDists, noiseDist, signalModels, kernel, ρ) error_rel = norm(xhat .- xtrue) / (sqrt(N) * xtrue) - return isapprox(error_rel, 0.060384227201893494, atol = 1e-7) + @test isapprox(error_rel, 0.060384227201893494, atol = 1e-7) end @@ -256,7 +260,7 @@ function test_perk_11() xhat = perk(rng, y, ν, T, xDists, νDists, noiseDist, signalModels, kernel, ρ) error_rel = norm(xhat .- xtrue) / (sqrt(N) * xtrue) - return isapprox(error_rel, 0.060384227201893494, atol = 1e-7) + @test isapprox(error_rel, 0.060384227201893494, atol = 1e-7) end @@ -279,24 +283,24 @@ function test_perk_12() xhat = perk(rng, y, ν, T, xDists, νDists, noiseDist, signalModels, kernel, ρ) error_rel = norm(xhat .- xtrue) / (sqrt(N) * xtrue) - return isapprox(error_rel, 0.06038422720189351, atol = 1e-7) + @test isapprox(error_rel, 0.06038422720189351, atol = 1e-7) end @testset "PERK" begin - @test test_perk_1() - @test test_perk_2() - @test test_perk_3() - @test test_perk_4() - @test test_perk_5() - @test test_perk_6() - @test test_perk_7() - @test test_perk_8() - @test test_perk_9() - @test test_perk_10() - @test test_perk_11() - @test test_perk_12() + test_perk_1() + test_perk_2() + test_perk_3() + test_perk_4() + test_perk_5() + test_perk_6() + test_perk_7() + test_perk_8() + test_perk_9() + test_perk_10() + test_perk_11() + test_perk_12() end