Skip to content

Commit

Permalink
Don't export set operations on Domain (for now)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kolaru committed Feb 10, 2025
1 parent aa65113 commit 8982c71
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
31 changes: 16 additions & 15 deletions ext/IntervalArithmeticForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module IntervalArithmeticForwardDiffExt

using IntervalArithmetic, ForwardDiff
using IntervalArithmetic: isempty_domain, overlap_domain, intersect_domain, in_domain, leftof
using ForwardDiff: Dual, Partials, , value, partials

ForwardDiff.can_dual(::Type{ExactReal}) = true
Expand Down Expand Up @@ -99,34 +100,34 @@ end

function (piecewise::Piecewise)(dual::Dual{T, <:Interval}) where {T}
X = value(dual)
set = Domain(X)
if !IntervalArithmetic.overlapdomain(set, piecewise)
input_domain = Domain(X)
if !overlap_domain(input_domain, piecewise)
return Dual{T}(emptyinterval(X), emptyinterval(X) .* partials(dual))
end

if !IntervalArithmetic.indomain(set, piecewise)
if !in_domain(input_domain, piecewise)
dec = trv
elseif any(in(set), discontinuities(piecewise, 1))
elseif any(x -> in_domain(x, input_domain), discontinuities(piecewise, 1))
dec = def
else
dec = com
end

dual_results = []
for (subdomain, f) in pieces(piecewise)
subset = intersect(set, subdomain)
isempty(subset) && continue
sub_X = interval(inf(subset), sup(subset), decoration(X))
dual_piece_outputs = []
for (piece_domain, f) in pieces(piecewise)
piece_input = intersect_domain(input_domain, piece_domain)
isempty_domain(piece_input) && continue
sub_X = interval(inf(piece_input), sup(piece_input), decoration(X))
sub_dual = Dual{T}(sub_X, partials(dual))
push!(dual_results, f(sub_dual))
push!(dual_piece_outputs, f(sub_dual))
end

results = value.(dual_results)
dec = min(dec, minimum(decoration.(results)))
primal = IntervalArithmetic.setdecoration(reduce(hull, results), dec)
piece_outputs = value.(dual_piece_outputs)
dec = min(dec, minimum(decoration.(piece_outputs)))
primal = IntervalArithmetic.setdecoration(reduce(hull, piece_outputs), dec)

dresults = partials.(dual_results)
partial = map(zip(dresults...)) do pp
doutputs = partials.(dual_piece_outputs)
partial = map(zip(doutputs...)) do pp
pdec = min(dec, minimum(decoration.(pp)))
return IntervalArithmetic.setdecoration(reduce(hull, pp), pdec)
end
Expand Down
34 changes: 17 additions & 17 deletions src/piecewise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,17 @@ function leftof(d1::Domain, d2::Domain)
return val1 < val2
end

Base.in(x::Real, domain::Domain) = rightof(x, lowerbound(domain)) && leftof(x, upperbound(domain))
in_domain(x::Real, domain::Domain) = rightof(x, lowerbound(domain)) && leftof(x, upperbound(domain))

function Base.intersect(d1::Domain, d2::Domain)
function intersect_domain(d1::Domain, d2::Domain)
left = max(lowerbound(d1), lowerbound(d2))
right = min(upperbound(d1), upperbound(d2))

left > right && return Domain()
return Domain(left, right)
end

function Base.isempty(domain::Domain)
function isempty_domain(domain::Domain)
lo, lobound = lowerbound(domain)
hi, hibound = upperbound(domain)

Expand Down Expand Up @@ -258,7 +258,7 @@ function Base.show(io::IO, ::MIME"text/plain", piecewise::Piecewise)
end
end

function indomain(domain, piecewise)
function in_domain(domain, piecewise)
rightof(upperbound(domain), upperbound(domains(piecewise)[end])) && return false

# This relies on the fact that domains are ordered
Expand All @@ -283,34 +283,34 @@ function indomain(domain, piecewise)
return true
end

overlapdomain(domain, piecewise) = any(!isempty, intersect.(Ref(domain), domains(piecewise)))
overlap_domain(domain, piecewise) = any(!isempty_domain, intersect_domain.(Ref(domain), domains(piecewise)))

function (piecewise::Piecewise)(X::Interval{T}) where {T}
set = Domain(X)
!overlapdomain(set, piecewise) && return emptyinterval(T)
input_domain = Domain(X)
!overlap_domain(input_domain, piecewise) && return emptyinterval(T)

if !indomain(set, piecewise)
if !in_domain(input_domain, piecewise)
dec = trv
elseif any(in(set), discontinuities(piecewise))
elseif any(x -> in_domain(x, input_domain), discontinuities(piecewise))
dec = def
else
dec = com
end

results = Interval{T}[]
for (domain, f) in pieces(piecewise)
subset = intersect(set, domain)
isempty(subset) && continue
push!(results, f(interval(inf(subset), sup(subset), decoration(X))))
piece_outputs = Interval{T}[]
for (piece_domain, f) in pieces(piecewise)
piece_input = intersect_domain(input_domain, piece_domain)
isempty_domain(piece_input) && continue
push!(piece_outputs, f(interval(inf(piece_input), sup(piece_input), decoration(X))))
end

dec = min(dec, minimum(decoration.(results)))
return IntervalArithmetic.setdecoration(reduce(hull, results), dec)
dec = min(dec, minimum(decoration.(piece_outputs)))
return IntervalArithmetic.setdecoration(reduce(hull, piece_outputs), dec)
end

function (piecewise::Piecewise)(x::Real)
for (domain, f) in pieces(piecewise)
(x in domain) && return f(x)
(in_domain(x, domain)) && return f(x)
end
throw(DomainError(x, "piecewise function was called outside of its domain $(domain_string(piecewise))"))
end
26 changes: 13 additions & 13 deletions test/interval_tests/piecewise.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using IntervalArithmetic: leftof
using IntervalArithmetic: leftof, intersect_domain, isempty_domain

@testset "Domain" begin
@testset "Construction" begin
Expand Down Expand Up @@ -29,25 +29,25 @@ using IntervalArithmetic: leftof
@test leftof(d2, d4)
end

@testset "intersect" begin
@testset "intersect_domain" begin
d1 = Domain{:closed, :open}(0, 10)
d2 = Domain{:closed, :open}(2, 15)
d3 = Domain{:open, :closed}(4, 7)
d4 = Domain{:open, :closed}(-20, 3)

@test intersect(d1, d2) == Domain{:closed, :open}(2, 10)
@test intersect(d1, d3) == Domain{:open, :closed}(4, 7)
@test intersect(d1, d4) == Domain{:closed, :closed}(0, 3)
@test intersect(d2, d3) == Domain{:open, :closed}(4, 7)
@test intersect(d2, d4) == Domain{:closed, :closed}(2, 3)
@test intersect(d3, d4) == Domain()
@test intersect_domain(d1, d2) == Domain{:closed, :open}(2, 10)
@test intersect_domain(d1, d3) == Domain{:open, :closed}(4, 7)
@test intersect_domain(d1, d4) == Domain{:closed, :closed}(0, 3)
@test intersect_domain(d2, d3) == Domain{:open, :closed}(4, 7)
@test intersect_domain(d2, d4) == Domain{:closed, :closed}(2, 3)
@test intersect_domain(d3, d4) == Domain()
end

@testset "isempty" begin
@test isempty(Domain{:open, :open}(1, 1))
@test !isempty(Domain{:closed, :closed}(1, 1))
@test !isempty(Domain{:open, :open}(1, 2))
@test isempty(Domain())
@testset "isempty_domain" begin
@test isempty_domain(Domain{:open, :open}(1, 1))
@test !isempty_domain(Domain{:closed, :closed}(1, 1))
@test !isempty_domain(Domain{:open, :open}(1, 2))
@test isempty_domain(Domain())
end
end

Expand Down

0 comments on commit 8982c71

Please sign in to comment.