Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Dirichlet rand overflows #1702 #1886

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
53 changes: 46 additions & 7 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,57 @@ end

# Fill `x` in place with one draw from the Dirichlet distribution `d`.
# NOTE(review): this hunk appears to interleave removed and added diff lines without
# +/- markers (the signature is closed twice and a trailing `lmul!` follows the
# `if`/`else`), so it is not valid Julia as shown — confirm against the real file.
function _rand!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = rand(rng, Gamma(αi))
x::AbstractVector{E}) where {E<:Real}

# When at least one concentration parameter is >= 0.5, draw the Gamma variates
# directly and rescale them by the inverse of their sum.
if any(a -> a >= .5, d.alpha)
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = rand(rng, Gamma(αi))
end

return lmul!(inv(sum(x)), x)
else
# Sample in log-space to lower underflow risk
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = _logrand(rng, Gamma(αi))
end

# Every log-space draw is infinite: there is nothing left to normalise.
if all(isinf, x)
# Final fallback, parameters likely deeply subnormal
# Distribution behavior approaches categorical as Σα -> 0
p = copy(d.alpha)
p .*= floatmax(eltype(p)) # rescale to non-subnormal
# Return a one-hot vector whose index is drawn with probabilities p / sum(p).
x .= zero(E)
x[rand(rng, Categorical(inv(sum(p)) .* p))] = one(E)
return x
end

# presumably maps the log-space draws back onto the simplex in place — verify
# against the definition of `softmax!`
return softmax!(x)
end
lmul!(inv(sum(x)), x) # this returns x
end

# Fill `x` in place with one draw from a Dirichlet distribution whose concentration
# vector is a `FillArrays.AbstractFill`, i.e. a single value shared by every component.
# NOTE(review): this hunk appears to interleave removed and added diff lines without
# +/- markers (the signature is closed twice), so it is not valid Julia as shown —
# confirm against the real file.
function _rand!(rng::AbstractRNG,
d::Dirichlet{T,<:FillArrays.AbstractFill{T}},
x::AbstractVector{<:Real}) where {T<:Real}
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
lmul!(inv(sum(x)), x) # this returns x
x::AbstractVector{E}) where {T<:Real, E<:Real}

# A shared concentration of at least 0.5 is sampled directly and rescaled by the
# inverse of the sum.
if FillArrays.getindex_value(d.alpha) >= 0.5
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
return lmul!(inv(sum(x)), x)
else
# Sample in log-space to lower underflow risk
_logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)

# Every log-space draw is infinite: there is nothing left to normalise.
if all(isinf, x)
# Final fallback, parameters likely deeply subnormal
# Distribution behavior approaches categorical as Σα -> 0
n = length(d.alpha)
# All components share one parameter, so the one-hot index is drawn uniformly.
p = Fill(inv(n), n)
x .= zero(E)
x[rand(rng, Categorical(p))] = one(E)
return x
end

# presumably maps the log-space draws back onto the simplex in place — verify
# against the definition of `softmax!`
return softmax!(x)
end
end

#######################################
Expand Down
1 change: 1 addition & 0 deletions src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ for fname in ["aliastable.jl",
"poisson.jl",
"exponential.jl",
"gamma.jl",
"expgamma.jl",
"multinomial.jl",
"vonmises.jl",
"vonmisesfisher.jl",
Expand Down
86 changes: 86 additions & 0 deletions src/samplers/expgamma.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
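# These samplers return the logarithm of a Gamma draw. They are used to bypass
# subnormals when sampling from Gamma distributions with small shape parameters.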

# Inverse Power sampler
# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1
# Sampler for the logarithm of a Gamma variate. Its `rand` method takes the log of a
# draw from `s` and adds `nia` times a standard exponential draw.
struct ExpGammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous}
s::S #sampler for Gamma(1+shape,scale)
nia::T #-1/shape (the constructor stores -inv(shape(d)))
end

# Build the log-space inverse-power sampler for `d`, drawing the inner
# Gamma(1 + shape, scale) variate with `GammaMTSampler`.
ExpGammaIPSampler(d::Gamma) = ExpGammaIPSampler(d, GammaMTSampler)

# Build the log-space inverse-power sampler for `d`, using sampler type `S` for the
# inner Gamma(1 + shape, scale) variate.
function ExpGammaIPSampler(d::Gamma, ::Type{S}) where {S<:Sampleable}
    shape_d = shape(d)
    sampler = S(Gamma{partype(d)}(1 + shape_d, scale(d)))
    # Must construct the type defined in this file: only then does the log-space
    # `rand(rng, ::ExpGammaIPSampler)` method get dispatched on the result.
    return ExpGammaIPSampler(sampler, -inv(shape_d))
end

# Draw log(X) for the Gamma variate X described by `s`:
# log(Y * U^(1/shape)) = log(Y) - E / shape, with Y drawn from the inner sampler
# and E a standard exponential variate.
function rand(rng::AbstractRNG, s::ExpGammaIPSampler)
    # The inner draw comes first so the RNG stream is consumed in a fixed order.
    log_inner = log(rand(rng, s.s))
    expo = randexp(rng, typeof(log_inner))
    return muladd(s.nia, expo, log_inner)
end


# Small Shape sampler
# From Liu, C., Martin, R., and Syring, N. (2015) for when shape < 0.3
# Precomputed constants for the small-shape log-space Gamma sampler.
# Field meanings follow the positional order used by the `ExpGammaSSSampler(d::Gamma)`
# constructor.
struct ExpGammaSSSampler{T<:Real} <: Sampleable{Univariate,Continuous}
α::T # shape(d)
θ::T # scale(d)
λ::T # inv(α) - 1
ω::T # α / e / (1 - α)
ωω::T # inv(ω + 1)
end

# Precompute the constants of the small-shape sampler from the shape and scale of `d`,
# promoting all five values to a common type before storing them.
function ExpGammaSSSampler(d::Gamma)
    shp = shape(d)
    λ = inv(shp) - 1
    ω = shp / MathConstants.e / (1 - shp)
    consts = promote(shp, scale(d), λ, ω, inv(ω + 1))
    return ExpGammaSSSampler(consts...)
end

# Draw log(X) for X ~ Gamma(α, θ) with the small-shape rejection sampler of
# Liu, Martin, and Syring (2015).
#
# The loop proposes `z`, accepts it when `h / η` exceeds a uniform draw, and maps the
# accepted value to the log scale as `-z / α`. The scale is applied additively as
# `log(θ)`, because log(θ * X) = log(θ) + log(X).
#
# The return annotation converts the result to `Float64`.
function rand(rng::AbstractRNG, s::ExpGammaSSSampler{T})::Float64 where T
    flT = float(T)
    # Loop-invariant: the scale enters on the log scale, not as raw `θ`.
    logθ = log(s.θ)
    while true
        U = rand(rng, flT)
        # Two-piece proposal: one branch when U <= ωω, the other from a fresh uniform.
        z = (U <= s.ωω) ? -log(U / s.ωω) : log(rand(rng, flT)) / s.λ
        # Target kernel `h` and its piecewise envelope `η`, evaluated at the proposal.
        h = exp(-z - exp(-z / s.α))
        η = z >= zero(T) ? exp(-z) : s.ω * s.λ * exp(s.λ * z)
        if h / η > rand(rng, flT)
            return logθ - z / s.α
        end
    end
end


# Choose the log-space sampler for `d`: the small-shape sampler when shape(d) < 0.3,
# the inverse-power sampler otherwise.
_logsampler(d::Gamma) = shape(d) < 0.3 ? ExpGammaSSSampler(d) : ExpGammaIPSampler(d)

# Draw a single log-space sample for `d`, using the small-shape sampler when
# shape(d) < 0.3 and the inverse-power sampler otherwise.
function _logrand(rng::AbstractRNG, d::Gamma)
    shape(d) < 0.3 && return rand(rng, ExpGammaSSSampler(d))
    return rand(rng, ExpGammaIPSampler(d))
end

# Fill `A` in place with independent log-space samples for `d` and return `A`.
#
# The sampler depends only on `d`, so it is built once outside the loop instead of
# once per element. The small-shape sampler is used when shape(d) < 0.3, the
# inverse-power sampler otherwise.
function _logrand!(rng::AbstractRNG, d::Gamma, A::AbstractArray{<:Real})
    if shape(d) < 0.3
        ss_sampler = ExpGammaSSSampler(d)
        @inbounds for i in eachindex(A)
            A[i] = rand(rng, ss_sampler)
        end
    else
        ip_sampler = ExpGammaIPSampler(d)
        @inbounds for i in eachindex(A)
            A[i] = rand(rng, ip_sampler)
        end
    end
    # Hand the mutated array back, as is conventional for `!` functions.
    return A
end
26 changes: 26 additions & 0 deletions test/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,29 @@ end
end
end
end

@testset "Dirichlet rand Inf and NaN (#1702)" begin
    # Small concentration parameters: sample moments are compared with the
    # moments reported by the distribution itself.
    small_param_dists = [
        Dirichlet([8e-5, 1e-5, 2e-5]),
        Dirichlet([8e-4, 1e-4, 2e-4]),
        Dirichlet([4.5e-5, 8e-5]),
        Dirichlet([6e-5, 2e-5, 3e-5, 4e-5, 5e-5]),
        Dirichlet(FillArrays.Fill(1e-5, 5))
    ]
    for dist in small_param_dists
        draws = rand(dist, 10^6)
        @test mean(draws, dims = 2) ≈ mean(dist) atol=0.01
        @test var(draws, dims = 2) ≈ var(dist) atol=0.01
    end

    # Subnormal params cause mean(d) to error, so each distribution is paired
    # with its expected mean written out explicitly.
    subnormal_cases = [
        (Dirichlet([5e-310, 5e-310, 5e-310]), [1/3, 1/3, 1/3]),
        (Dirichlet(FillArrays.Fill(5e-310, 3)), [1/3, 1/3, 1/3]),
        (Dirichlet([5e-321, 1e-321, 4e-321]), [.5, .1, .4]),
        (Dirichlet([1e-321, 2e-321, 3e-321, 4e-321]), [.1, .2, .3, .4]),
        (Dirichlet(FillArrays.Fill(1e-321, 4)), [.25, .25, .25, .25])
    ]
    for (dist, expected_mean) in subnormal_cases
        draws = rand(dist, 10^6)
        @test mean(draws, dims = 2) ≈ expected_mean atol=0.01
    end
end
Loading