Skip to content

Commit

Permalink
Ensure correct rounding for irrationals (#617)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierHnt authored Jan 19, 2024
1 parent 946493a commit 1d744c3
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 58 deletions.
83 changes: 31 additions & 52 deletions src/intervals/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ Internal constructor which assumes that `is_valid_interval(lo, hi) == true`.
Since misuse of this function can deeply corrupt code, its usage is
**strongly discouraged** in favour of [`bareinterval`](@ref).
"""
_unsafe_bareinterval
_unsafe_bareinterval(::Type{T}, a, b) where {T<:NumTypes} =
_unsafe_bareinterval(T, _round(T, a, RoundDown), _round(T, b, RoundUp))

_normalisezero(a) = ifelse(iszero(a), zero(a), a)
# used only to construct intervals; needed to avoid `inf` and `sup` normalization
Expand All @@ -114,51 +115,35 @@ _inf(x::Real) = x
_sup(x::Real) = x
#

_unsafe_bareinterval(::Type{T}, a::Rational, b::Rational) where {S<:Integer,T<:Rational{S}} =
_unsafe_bareinterval(T, T(a), T(b))
_unsafe_bareinterval(::Type{T}, a::Rational, b::Rational) where {S<:Union{Int8,UInt8},T<:Rational{S}} =
_unsafe_bareinterval(T, T(a), T(b))
_unsafe_bareinterval(::Type{T}, a::Rational, b::Rational) where {S<:Union{Int16,UInt16},T<:Rational{S}} =
_unsafe_bareinterval(T, T(a), T(b))
_unsafe_bareinterval(::Type{T}, a::Rational, b) where {S<:Integer,T<:Rational{S}} =
_unsafe_bareinterval(T, T(a), rationalize(S, nextfloat(float(S)(b, RoundUp))))
_unsafe_bareinterval(::Type{T}, a, b::Rational) where {S<:Integer,T<:Rational{S}} =
_unsafe_bareinterval(T, rationalize(S, nextfloat(float(S)(a, RoundDown))), T(b))
function _unsafe_bareinterval(::Type{T}, a, b) where {S<:Integer,T<:Rational{S}}
R = float(S)
return _unsafe_bareinterval(T, rationalize(S, prevfloat(R(a, RoundDown))), rationalize(S, nextfloat(R(b, RoundUp))))
_round(::Type{T}, a, r::RoundingMode) where {T<:NumTypes} = __round(T, a, r)
# irrationals
_round(::Type{<:NumTypes}, ::AbstractIrrational, ::RoundingMode) = throw(ArgumentError("only irrationals from MathConstants are supported"))
for irr (:(), :(), :(:catalan)) # irrationals supported by MPFR
@eval _round(::Type{T}, a::Irrational{$irr}, r::RoundingMode) where {T<:NumTypes} =
__round(T, BigFloat(a, r), r)
end
# irrationals not supported by MPFR, use their exact formula ℯ = exp(1), φ = (1+sqrt(5))/2
_round(::Type{T}, ::Irrational{:ℯ}, r::RoundingMode{:Down}) where {T<:NumTypes} =
__round(T, inf(exp(bareinterval(BigFloat, 1))), r)
_round(::Type{T}, ::Irrational{:ℯ}, r::RoundingMode{:Up}) where {T<:NumTypes} =
__round(T, sup(exp(bareinterval(BigFloat, 1))), r)
_round(::Type{T}, ::Irrational{:φ}, r::RoundingMode{:Down}) where {T<:NumTypes} =
__round(T, inf((bareinterval(BigFloat, 1) + sqrt(bareinterval(BigFloat, 5))) / bareinterval(BigFloat, 2)), r)
_round(::Type{T}, ::Irrational{:φ}, r::RoundingMode{:Up}) where {T<:NumTypes} =
__round(T, sup((bareinterval(BigFloat, 1) + sqrt(bareinterval(BigFloat, 5))) / bareinterval(BigFloat, 2)), r)
# floats
__round(::Type{T}, a, r::RoundingMode) where {T<:AbstractFloat} = T(a, r)
# rationals
__round(::Type{T}, a::Rational, ::RoundingMode{:Down}) where {S<:Integer,T<:Rational{S}} = T(a)
__round(::Type{T}, a::Rational, ::RoundingMode{:Up}) where {S<:Integer,T<:Rational{S}} = T(a)
__round(::Type{T}, a, r::RoundingMode{:Down}) where {S<:Integer,T<:Rational{S}} =
rationalize(S, prevfloat(__float(S)(a, r)))
__round(::Type{T}, a, r::RoundingMode{:Up}) where {S<:Integer,T<:Rational{S}} =
rationalize(S, nextfloat(__float(S)(a, r)))
# need the following since `float(Int8) == float(Int16) == Float64`
_unsafe_bareinterval(::Type{T}, a, b) where {S<:Union{Int8,UInt8},T<:Rational{S}} =
_unsafe_bareinterval(T, rationalize(S, prevfloat(Float16(a, RoundDown))), rationalize(S, nextfloat(Float16(b, RoundUp))))
_unsafe_bareinterval(::Type{T}, a, b) where {S<:Union{Int16,UInt16},T<:Rational{S}} =
_unsafe_bareinterval(T, rationalize(S, prevfloat(Float32(a, RoundDown))), rationalize(S, nextfloat(Float32(b, RoundUp))))

_unsafe_bareinterval(::Type{T}, a, b) where {T<:AbstractFloat} = _unsafe_bareinterval(T, T(a, RoundDown), T(b, RoundUp))

# by-pass the absence of `BigFloat(..., ROUNDING_MODE)` (cf. base/irrationals.jl)
# for some irrationals defined in MathConstants (cf. base/mathconstants.jl)
for sym (:(:ℯ), :())
@eval begin
_unsafe_bareinterval(::Type{BigFloat}, a::Irrational{:ℯ}, b::Irrational{$sym}) =
_unsafe_bareinterval(BigFloat, BigFloat(Float64(a, RoundDown), RoundDown), BigFloat(Float64(b, RoundUp), RoundUp))
_unsafe_bareinterval(::Type{BigFloat}, a::Irrational{:φ}, b::Irrational{$sym}) =
_unsafe_bareinterval(BigFloat, BigFloat(Float64(a, RoundDown), RoundDown), BigFloat(Float64(b, RoundUp), RoundUp))
_unsafe_bareinterval(::Type{BigFloat}, a::Irrational{$sym}, b) =
_unsafe_bareinterval(BigFloat, BigFloat(Float64(a, RoundDown), RoundDown), BigFloat(b, RoundUp))
_unsafe_bareinterval(::Type{BigFloat}, a, b::Irrational{$sym}) =
_unsafe_bareinterval(BigFloat, BigFloat(a, RoundDown), BigFloat(Float64(b, RoundUp), RoundUp))

_unsafe_bareinterval(::Type{Rational{BigInt}}, a::Irrational{:ℯ}, b::Irrational{$sym}) =
_unsafe_bareinterval(Rational{BigInt}, BigFloat(Float64(a, RoundDown), RoundDown), BigFloat(Float64(b, RoundUp), RoundUp))
_unsafe_bareinterval(::Type{Rational{BigInt}}, a::Irrational{:φ}, b::Irrational{$sym}) =
_unsafe_bareinterval(Rational{BigInt}, BigFloat(Float64(a, RoundDown), RoundDown), BigFloat(Float64(b, RoundUp), RoundUp))
_unsafe_bareinterval(::Type{Rational{BigInt}}, a::Irrational{$sym}, b) =
_unsafe_bareinterval(Rational{BigInt}, BigFloat(Float64(a, RoundDown), RoundDown), BigFloat(b, RoundUp))
_unsafe_bareinterval(::Type{Rational{BigInt}}, a, b::Irrational{$sym}) =
_unsafe_bareinterval(Rational{BigInt}, BigFloat(a, RoundDown), BigFloat(Float64(b, RoundUp), RoundUp))
end
end
__float(::Type{T}) where {T<:Integer} = float(T)
__float(::Type{T}) where {T<:Union{Int8,UInt8}} = Float16
__float(::Type{T}) where {T<:Union{Int16,UInt16}} = Float32

BareInterval{T}(x::BareInterval) where {T<:NumTypes} = convert(BareInterval{T}, x)

Expand Down Expand Up @@ -214,12 +199,6 @@ bareinterval(::Type{T}, a::BareInterval) where {T<:NumTypes} =
bareinterval(::Type{T}, a::Tuple) where {T<:NumTypes} = bareinterval(T, a...)
bareinterval(a::Tuple) = bareinterval(T, a...)

# note: generated functions must be defined after all the methods they use
@generated function bareinterval(::Type{T}, a::AbstractIrrational) where {T<:NumTypes}
res = _unsafe_bareinterval(T, a(), a()) # precompute the interval
return :($res) # set body of the function to return the precomputed result
end

# promotion

Base.promote_rule(::Type{BareInterval{T}}, ::Type{BareInterval{S}}) where {T<:NumTypes,S<:NumTypes} =
Expand Down Expand Up @@ -312,8 +291,8 @@ Internal constructor which assumes that `bareinterval` and its decoration
_unsafe_interval

# used only to construct intervals
_inf(a::Interval) = a.bareinterval.lo
_sup(a::Interval) = a.bareinterval.hi
_inf(x::Interval) = x.bareinterval.lo
_sup(x::Interval) = x.bareinterval.hi
#

function bareinterval(x::Interval)
Expand Down
6 changes: 2 additions & 4 deletions src/intervals/flavor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ Some flavors `F` include:
julia> IntervalArithmetic.default_flavor()
IntervalArithmetic.Flavor{:set_based}()
julia> isempty_interval(bareinterval(Inf, Inf))
┌ Warning: invalid interval, empty interval is returned
└ @ IntervalArithmetic ~/work/IntervalArithmetic.jl/IntervalArithmetic.jl/src/intervals/construction.jl:202
true
julia> IntervalArithmetic.is_valid_interval(Inf, Inf)
false
julia> isempty_interval(bareinterval(0)/bareinterval(0))
true
Expand Down
8 changes: 8 additions & 0 deletions src/intervals/intervals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ include("interval_operations/set_operations.jl")
export intersect_interval, hull, interiordiff
include("interval_operations/bisect.jl")
export bisect, mince



# Note: generated functions must be defined after all the methods they use
@generated function bareinterval(::Type{T}, a::AbstractIrrational) where {T<:NumTypes}
x = _unsafe_bareinterval(T, a(), a()) # precompute the interval
return :($x) # set body of the function to return the precomputed result
end
4 changes: 2 additions & 2 deletions test/interval_tests/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ end
end

@testset "Irrationals" begin
for irr (MathConstants.:e, MathConstants., MathConstants.:π, MathConstants., MathConstants.:ℯ)
for irr (MathConstants.:π, MathConstants., MathConstants.:catalan, MathConstants., MathConstants.:ℯ)
for T (Float16, Float32, Float64, BigFloat)
@test in_interval(irr, interval(T, irr))
if T !== BigFloat
Expand Down Expand Up @@ -157,7 +157,7 @@ end
@test_throws DomainError convert(Interval{Float64}, interval(1+im))
end

@testset "Propagation of `isguaranteed" begin
@testset "Propagation of `isguaranteed`" begin
@test !isguaranteed(interval(convert(Interval{Float64}, 0), interval(convert(Interval{Float64}, 1))))
@test !isguaranteed(interval(0, convert(Interval{Float64}, 1)))
@test !isguaranteed(interval(convert(Interval{Float64}, 0), 1))
Expand Down

0 comments on commit 1d744c3

Please sign in to comment.