Skip to content

Commit

Permalink
Simplify init_cache dispatching
Browse files Browse the repository at this point in the history
  • Loading branch information
PerezHz committed Dec 23, 2024
1 parent 1c6c474 commit 5a77506
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 22 deletions.
24 changes: 12 additions & 12 deletions src/integrator/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,25 +125,25 @@ end

# init_cache

function init_cache(cachetype::Type{ScalarCache}, dense::Val{D}, t0::T, x0::U, maxsteps::Int, order::Int) where {D, U, T}
function init_cache(dense::Val{D}, t0::T, x0::U, maxsteps::Int, order::Int) where {D, U, T}
# Initialize the Taylor1 expansions
t, x = init_expansions(t0, x0, order)
# Initialize cache
return cachetype(
return ScalarCache(
Array{T}(undef, maxsteps + 1),
Array{U}(undef, maxsteps + 1),
init_psol(dense, maxsteps, 1, x),
t,
x)
end

function init_cache(cachetype::Type{ScalarCache}, ::Val{false}, trange::AbstractVector{T}, x0::U, maxsteps::Int, order::Int) where {U, T}
function init_cache(::Val{false}, trange::AbstractVector{T}, x0::U, maxsteps::Int, order::Int) where {U, T}
# Initialize the Taylor1 expansions
t0 = trange[1]
t, x = init_expansions(t0, x0, order)
# Initialize cache
nn = length(trange)
cache = cachetype(
cache = ScalarCache(
trange,
Array{U}(undef, nn),
init_psol(Val(false), maxsteps, 1, x),
Expand All @@ -153,12 +153,12 @@ function init_cache(cachetype::Type{ScalarCache}, ::Val{false}, trange::Abstract
return cache
end

function init_cache(cachetype::Type{VectorCache}, dense::Val{D}, t0::T, q0::Vector{U}, maxsteps::Int, order::Int) where {D, U, T}
function init_cache(dense::Val{D}, t0::T, q0::Vector{U}, maxsteps::Int, order::Int) where {D, U, T}
# Initialize the vector of Taylor1 expansions
t, x, dx = init_expansions(t0, q0, order)
# Initialize cache
dof = length(q0)
return cachetype(
return VectorCache(
Array{T}(undef, maxsteps + 1),
Array{U}(undef, dof, maxsteps + 1),
init_psol(dense, maxsteps, dof, x),
Expand All @@ -168,14 +168,14 @@ function init_cache(cachetype::Type{VectorCache}, dense::Val{D}, t0::T, q0::Vect
dx)
end

function init_cache(cachetype::Type{VectorTRangeCache}, ::Val{false}, trange::AbstractVector{T}, q0::Vector{U}, maxsteps::Int, order::Int) where {U, T}
function init_cache(::Val{false}, trange::AbstractVector{T}, q0::Vector{U}, maxsteps::Int, order::Int) where {U, T}
# Initialize the vector of Taylor1 expansions
t0 = trange[1]
t, x, dx = init_expansions(t0, q0, order)
# Initialize cache
nn = length(trange)
dof = length(q0)
cache = cachetype(
cache = VectorTRangeCache(
trange,
Array{U}(undef, dof, nn),
init_psol(Val(false), maxsteps, dof, x),
Expand All @@ -192,7 +192,7 @@ function init_cache(cachetype::Type{VectorTRangeCache}, ::Val{false}, trange::Ab
return cache
end

function init_cache(cachetype::Type{LyapunovSpectrumCache}, dense, t0::T, q0::Vector{U}, maxsteps::Int, order::Int) where {U, T}
function init_cache_lyap(t0::T, q0::Vector{U}, maxsteps::Int, order::Int) where {U, T}
# Initialize the vector of Taylor1 expansions
dof = length(q0)
jt = Matrix{U}(I, dof, dof)
Expand All @@ -201,7 +201,7 @@ function init_cache(cachetype::Type{LyapunovSpectrumCache}, dense, t0::T, q0::Ve
# Initialize cache
nx0 = length(x0)
dvars = Array{TaylorN{Taylor1{U}}}(undef, dof)
cache = cachetype(
cache = LyapunovSpectrumCache(
Array{T}(undef, maxsteps+1),
Array{U}(undef, dof, maxsteps+1),
nothing,
Expand All @@ -228,7 +228,7 @@ function init_cache(cachetype::Type{LyapunovSpectrumCache}, dense, t0::T, q0::Ve
return cache
end

function init_cache(cachetype::Type{LyapunovSpectrumTRangeCache}, dense, trange::AbstractVector{T}, q0::Vector{U}, maxsteps::Int, order::Int) where {U, T}
function init_cache_lyap(trange::AbstractVector{T}, q0::Vector{U}, maxsteps::Int, order::Int) where {U, T}
# Initialize the vector of Taylor1 expansions
t0 = trange[1]
dof = length(q0)
Expand All @@ -239,7 +239,7 @@ function init_cache(cachetype::Type{LyapunovSpectrumTRangeCache}, dense, trange:
nn = length(trange)
nx0 = length(x0)
dvars = Array{TaylorN{Taylor1{U}}}(undef, dof)
cache = cachetype(
cache = LyapunovSpectrumTRangeCache(
trange,
Array{U}(undef, dof, nn),
nothing,
Expand Down
8 changes: 4 additions & 4 deletions src/integrator/taylorinteg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function taylorinteg(f, x0::U, t0::T, tmax::T, order::Int, abstol::T, params = n
maxsteps::Int=500, parse_eqs::Bool=true, dense::Bool=true) where {T<:Real, U<:Number}

# Allocation
cache = init_cache(ScalarCache, Val(dense), t0, x0, maxsteps, order)
cache = init_cache(Val(dense), t0, x0, maxsteps, order)

# Determine if specialized jetcoeffs! method exists
parse_eqs, rv = _determine_parsing!(parse_eqs, f, cache.t, cache.x, params)
Expand Down Expand Up @@ -101,7 +101,7 @@ function taylorinteg(f!, q0::Vector{U}, t0::T, tmax::T, order::Int, abstol::T, p
maxsteps::Int=500, parse_eqs::Bool=true, dense::Bool=true) where {T<:Real, U<:Number}

# Allocation
cache = init_cache(VectorCache, Val(dense), t0, q0, maxsteps, order)
cache = init_cache(Val(dense), t0, q0, maxsteps, order)

# Determine if specialized jetcoeffs! method exists
parse_eqs, rv = _determine_parsing!(parse_eqs, f!, cache.t, cache.x, cache.dx, params)
Expand Down Expand Up @@ -247,7 +247,7 @@ function taylorinteg(f, x0::U, trange::AbstractVector{T},
issorted(trange, rev=true)) "`trange` or `reverse(trange)` must be sorted"

# Allocation
cache = init_cache(ScalarCache, Val(false), trange, x0, maxsteps, order)
cache = init_cache(Val(false), trange, x0, maxsteps, order)

# Determine if specialized jetcoeffs! method exists
parse_eqs, rv = _determine_parsing!(parse_eqs, f, cache.t, cache.x, params)
Expand Down Expand Up @@ -308,7 +308,7 @@ function taylorinteg(f!, q0::Vector{U}, trange::AbstractVector{T},
issorted(trange, rev=true)) "`trange` or `reverse(trange)` must be sorted"

# Allocation
cache = init_cache(VectorTRangeCache, Val(false), trange, q0, maxsteps, order)
cache = init_cache(Val(false), trange, q0, maxsteps, order)

# Determine if specialized jetcoeffs! method exists
parse_eqs, rv = _determine_parsing!(parse_eqs, f!, cache.t, cache.x, cache.dx, params)
Expand Down
4 changes: 2 additions & 2 deletions src/lyapunovspectrum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ function lyap_taylorinteg(f!, q0::Array{U,1}, t0::T, tmax::T,
dof = length(q0)

# Allocation
cache = init_cache(LyapunovSpectrumCache, Val(false), t0, q0, maxsteps, order)
cache = init_cache_lyap(t0, q0, maxsteps, order)

# If user does not provide Jacobian, check number of TaylorN variables and initialize _dv
if isa(jacobianfunc!, Nothing)
Expand Down Expand Up @@ -344,7 +344,7 @@ function lyap_taylorinteg(f!, q0::Array{U,1}, trange::AbstractVector{T},
issorted(trange, rev=true)) "`trange` or `reverse(trange)` must be sorted"

# Allocation
cache = init_cache(LyapunovSpectrumTRangeCache, Val(false), trange, q0, maxsteps, order)
cache = init_cache_lyap(trange, q0, maxsteps, order)

# If user does not provide Jacobian, check number of TaylorN variables and initialize _dv
if isnothing(jacobianfunc!)
Expand Down
4 changes: 2 additions & 2 deletions src/rootfinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ function taylorinteg(f!, g, q0::Array{U,1}, t0::T, tmax::T,
@assert order eventorder "`eventorder` must be less than or equal to `order`"

# Allocation
cache = init_cache(VectorCache, Val(dense), t0, q0, maxsteps, order)
cache = init_cache(Val(dense), t0, q0, maxsteps, order)

# Determine if specialized jetcoeffs! method exists
parse_eqs, rv = _determine_parsing!(parse_eqs, f!, cache.t, cache.x, cache.dx, params)
Expand Down Expand Up @@ -296,7 +296,7 @@ function taylorinteg(f!, g, q0::Array{U,1}, trange::AbstractVector{T},
issorted(trange, rev=true)) "`trange` or `reverse(trange)` must be sorted"

# Allocation
cache = init_cache(VectorTRangeCache, Val(false), trange, q0, maxsteps, order)
cache = init_cache(Val(false), trange, q0, maxsteps, order)

# Determine if specialized jetcoeffs! method exists
parse_eqs, rv = _determine_parsing!(parse_eqs, f!, cache.t, cache.x, cache.dx, params)
Expand Down
4 changes: 2 additions & 2 deletions test/taylorize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1331,8 +1331,8 @@ import Logging: Warn
_T = eltype(_t0)
_U = eltype(_q0)
# Allocation
_cache_true = TaylorIntegration.init_cache(TaylorIntegration.VectorCache, Val(true), _t0, _q0, _maxsteps, _order)
_cache_false = TaylorIntegration.init_cache(TaylorIntegration.VectorCache, Val(false), _t0, _q0, _maxsteps, _order)
_cache_true = TaylorIntegration.init_cache(Val(true), _t0, _q0, _maxsteps, _order)
_cache_false = TaylorIntegration.init_cache(Val(false), _t0, _q0, _maxsteps, _order)
# Determine if specialized jetcoeffs! method exists
__parse_eqs, __rv = TaylorIntegration._determine_parsing!(true, kepler1!, _cache_true.t, _cache_true.x, _cache_true.dx, _params);
solTN = @inferred TaylorIntegration.taylorinteg!(Val(true), kepler1!, _q0, _t0, _tmax, _abstol, __rv, _cache_true, _params; parse_eqs=__parse_eqs, maxsteps=_maxsteps)
Expand Down

0 comments on commit 5a77506

Please sign in to comment.