Skip to content

Commit

Permalink
Add support for irrationals in MathConstants
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierHnt committed Jan 19, 2024
1 parent 317c279 commit 99cfa94
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 64 deletions.
87 changes: 33 additions & 54 deletions src/intervals/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ struct BareInterval{T<:NumTypes}

# need explicit signatures to avoid method ambiguities

global __unsafe_bareinterval(::Type{T}, a::T, b::T) where {S<:Integer,T<:Rational{S}} =
global _unsafe_bareinterval(::Type{T}, a::T, b::T) where {S<:Integer,T<:Rational{S}} =
new{T}(_normalisezero(a), _normalisezero(b))
__unsafe_bareinterval(::Type{T}, a::T, b::T) where {S<:Union{Int8,UInt8},T<:Rational{S}} =
_unsafe_bareinterval(::Type{T}, a::T, b::T) where {S<:Union{Int8,UInt8},T<:Rational{S}} =
new{T}(_normalisezero(a), _normalisezero(b))
__unsafe_bareinterval(::Type{T}, a::T, b::T) where {S<:Union{Int16,UInt16},T<:Rational{S}} =
_unsafe_bareinterval(::Type{T}, a::T, b::T) where {S<:Union{Int16,UInt16},T<:Rational{S}} =
new{T}(_normalisezero(a), _normalisezero(b))

__unsafe_bareinterval(::Type{T}, a::T, b::T) where {T<:AbstractFloat} =
_unsafe_bareinterval(::Type{T}, a::T, b::T) where {T<:AbstractFloat} =
new{T}(_normalisezero(a), _normalisezero(b))
end

Expand All @@ -104,31 +104,8 @@ Internal constructor which assumes that `is_valid_interval(lo, hi) == true`.
Since misuse of this function can deeply corrupt code, its usage is
**strongly discouraged** in favour of [`bareinterval`](@ref).
"""
_unsafe_bareinterval(::Type{T}, a, b) where {T<:NumTypes} = __unsafe_bareinterval(T, a, b)

_unsafe_bareinterval(::Type{T}, ::AbstractIrrational, ::AbstractIrrational) where {T<:NumTypes} = throw(ArgumentError("only irrationals from MathConstants are supported"))
_unsafe_bareinterval(::Type{T}, ::AbstractIrrational, _) where {T<:NumTypes} = throw(ArgumentError("only irrationals from MathConstants are supported"))
_unsafe_bareinterval(::Type{T}, _, ::AbstractIrrational) where {T<:NumTypes} = throw(ArgumentError("only irrationals from MathConstants are supported"))
for irr1 (:(), :(), :(:catalan))
for irr2 (:(), :(), :(:catalan))
@eval begin
function _unsafe_bareinterval(::Type{T}, a::Irrational{$irr1}, b::Irrational{$irr2}) where {T<:NumTypes}
big_x = __unsafe_bareinterval(big(float(T)), a, b)
return __unsafe_bareinterval(T, _inf(big_x), _sup(big_x))
end
end
end
@eval begin
function _unsafe_bareinterval(::Type{T}, a::Irrational{$irr1}, b) where {T<:NumTypes}
big_x = __unsafe_bareinterval(big(float(T)), a, b)
return __unsafe_bareinterval(T, _inf(big_x), _sup(big_x))
end
function _unsafe_bareinterval(::Type{T}, a, b::Irrational{$irr1}) where {T<:NumTypes}
big_x = __unsafe_bareinterval(big(float(T)), a, b)
return __unsafe_bareinterval(T, _inf(big_x), _sup(big_x))
end
end
end
_unsafe_bareinterval(::Type{T}, a, b) where {T<:NumTypes} =
_unsafe_bareinterval(T, _round(T, a, RoundDown), _round(T, b, RoundUp))

_normalisezero(a) = ifelse(iszero(a), zero(a), a)
# used only to construct intervals; needed to avoid `inf` and `sup` normalization
Expand All @@ -138,27 +115,35 @@ _inf(x::Real) = x
_sup(x::Real) = x
#

__unsafe_bareinterval(::Type{T}, a::Rational, b::Rational) where {S<:Integer,T<:Rational{S}} =
__unsafe_bareinterval(T, T(a), T(b))
__unsafe_bareinterval(::Type{T}, a::Rational, b::Rational) where {S<:Union{Int8,UInt8},T<:Rational{S}} =
__unsafe_bareinterval(T, T(a), T(b))
__unsafe_bareinterval(::Type{T}, a::Rational, b::Rational) where {S<:Union{Int16,UInt16},T<:Rational{S}} =
__unsafe_bareinterval(T, T(a), T(b))
__unsafe_bareinterval(::Type{T}, a::Rational, b) where {S<:Integer,T<:Rational{S}} =
__unsafe_bareinterval(T, T(a), rationalize(S, nextfloat(float(S)(b, RoundUp))))
__unsafe_bareinterval(::Type{T}, a, b::Rational) where {S<:Integer,T<:Rational{S}} =
__unsafe_bareinterval(T, rationalize(S, nextfloat(float(S)(a, RoundDown))), T(b))
function __unsafe_bareinterval(::Type{T}, a, b) where {S<:Integer,T<:Rational{S}}
R = float(S)
return __unsafe_bareinterval(T, rationalize(S, prevfloat(R(a, RoundDown))), rationalize(S, nextfloat(R(b, RoundUp))))
_round(::Type{T}, a, r::RoundingMode) where {T<:NumTypes} = __round(T, a, r)
# irrationals
_round(::Type{<:NumTypes}, ::AbstractIrrational, ::RoundingMode) = throw(ArgumentError("only irrationals from MathConstants are supported"))
for irr (:(), :(), :(:catalan)) # irrationals supported by MPFR
@eval _round(::Type{T}, a::Irrational{$irr}, r::RoundingMode) where {T<:NumTypes} =
__round(T, BigFloat(a, r), r)
end
# irrationals not supported by MPFR, use their exact formula ℯ = exp(1), φ = (1+sqrt(5))/2
_round(::Type{T}, ::Irrational{:ℯ}, r::RoundingMode{:Down}) where {T<:NumTypes} =
__round(T, inf(exp(bareinterval(BigFloat, 1))), r)
_round(::Type{T}, ::Irrational{:ℯ}, r::RoundingMode{:Up}) where {T<:NumTypes} =
__round(T, sup(exp(bareinterval(BigFloat, 1))), r)
_round(::Type{T}, ::Irrational{:φ}, r::RoundingMode{:Down}) where {T<:NumTypes} =
__round(T, inf((bareinterval(BigFloat, 1) + sqrt(bareinterval(BigFloat, 5))) / bareinterval(BigFloat, 2)), r)
_round(::Type{T}, ::Irrational{:φ}, r::RoundingMode{:Up}) where {T<:NumTypes} =
__round(T, sup((bareinterval(BigFloat, 1) + sqrt(bareinterval(BigFloat, 5))) / bareinterval(BigFloat, 2)), r)
# floats
__round(::Type{T}, a, r::RoundingMode) where {T<:AbstractFloat} = T(a, r)
# rationals
__round(::Type{T}, a::Rational, ::RoundingMode{:Down}) where {S<:Integer,T<:Rational{S}} = T(a)
__round(::Type{T}, a::Rational, ::RoundingMode{:Up}) where {S<:Integer,T<:Rational{S}} = T(a)
__round(::Type{T}, a, r::RoundingMode{:Down}) where {S<:Integer,T<:Rational{S}} =
rationalize(S, prevfloat(__float(S)(a, r)))
__round(::Type{T}, a, r::RoundingMode{:Up}) where {S<:Integer,T<:Rational{S}} =
rationalize(S, nextfloat(__float(S)(a, r)))
# need the following since `float(Int8) == float(Int16) == Float64`
__unsafe_bareinterval(::Type{T}, a, b) where {S<:Union{Int8,UInt8},T<:Rational{S}} =
__unsafe_bareinterval(T, rationalize(S, prevfloat(Float16(a, RoundDown))), rationalize(S, nextfloat(Float16(b, RoundUp))))
__unsafe_bareinterval(::Type{T}, a, b) where {S<:Union{Int16,UInt16},T<:Rational{S}} =
__unsafe_bareinterval(T, rationalize(S, prevfloat(Float32(a, RoundDown))), rationalize(S, nextfloat(Float32(b, RoundUp))))

__unsafe_bareinterval(::Type{T}, a, b) where {T<:AbstractFloat} = __unsafe_bareinterval(T, T(a, RoundDown), T(b, RoundUp))
__float(::Type{T}) where {T<:Integer} = float(T)
__float(::Type{T}) where {T<:Union{Int8,UInt8}} = Float16
__float(::Type{T}) where {T<:Union{Int16,UInt16}} = Float32

BareInterval{T}(x::BareInterval) where {T<:NumTypes} = convert(BareInterval{T}, x)

Expand Down Expand Up @@ -214,12 +199,6 @@ bareinterval(::Type{T}, a::BareInterval) where {T<:NumTypes} =
bareinterval(::Type{T}, a::Tuple) where {T<:NumTypes} = bareinterval(T, a...)
bareinterval(a::Tuple) = bareinterval(T, a...)

# note: generated functions must be defined after all the methods they use
@generated function bareinterval(::Type{T}, a::AbstractIrrational) where {T<:NumTypes}
x = _unsafe_bareinterval(T, a(), a()) # precompute the interval
return :($x) # set body of the function to return the precomputed result
end

# promotion

Base.promote_rule(::Type{BareInterval{T}}, ::Type{BareInterval{S}}) where {T<:NumTypes,S<:NumTypes} =
Expand Down
8 changes: 8 additions & 0 deletions src/intervals/intervals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ include("interval_operations/set_operations.jl")
export intersect_interval, hull, interiordiff
include("interval_operations/bisect.jl")
export bisect, mince



# Note: generated functions must be defined after all the methods they use
@generated function bareinterval(::Type{T}, a::AbstractIrrational) where {T<:NumTypes}
x = _unsafe_bareinterval(T, a(), a()) # precompute the interval
return :($x) # set body of the function to return the precomputed result
end
11 changes: 1 addition & 10 deletions test/interval_tests/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ end
end

@testset "Irrationals" begin
for irr (MathConstants., MathConstants., MathConstants.:catalan)
for irr (MathConstants., MathConstants., MathConstants.:catalan, MathConstants., MathConstants.:ℯ)
for T (Float16, Float32, Float64, BigFloat)
@test in_interval(irr, interval(T, irr))
if T !== BigFloat
Expand All @@ -95,15 +95,6 @@ end
@test in_interval(irr, interval(Rational{T}, irr))
end
end

for irr (MathConstants., MathConstants.:ℯ)
for T (Float16, Float32, Float64, BigFloat)
@test_throws ArgumentError interval(T, irr)
end
for T [InteractiveUtils.subtypes(Signed) ; InteractiveUtils.subtypes(Unsigned)]
@test_throws ArgumentError interval(Rational{T}, irr)
end
end
end

@testset "Midpoint" begin
Expand Down

0 comments on commit 99cfa94

Please sign in to comment.