Skip to content

Commit

Permalink
Merge pull request #1108 from mattsignorelli/gtpsa-2
Browse files Browse the repository at this point in the history
Add `ODE_DEFAULT_NORM` overload for GTPSA
  • Loading branch information
ChrisRackauckas authored Jan 6, 2025
2 parents a9f73bb + 0f8481c commit 419a2b5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
24 changes: 21 additions & 3 deletions ext/DiffEqBaseGTPSAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,34 @@ module DiffEqBaseGTPSAExt

if isdefined(Base, :get_extension)
using DiffEqBase
import DiffEqBase: value
import DiffEqBase: value, ODE_DEFAULT_NORM
using GTPSA
else
using ..DiffEqBase
import ..DiffEqBase: value
import ..DiffEqBase: value, ODE_DEFAULT_NORM
using ..GTPSA
end

value(x::TPS) = scalar(x);
value(x::TPS) = scalar(x)
value(::Type{TPS{T}}) where {T} = T

ODE_DEFAULT_NORM(u::TPS, t) = normTPS(u)
ODE_DEFAULT_NORM(f::F, u::TPS, t) where {F} = normTPS(f(u))

function ODE_DEFAULT_NORM(u::AbstractArray{TPS{T}}, t) where {T}
x = zero(real(T))
@inbounds @fastmath for ui in u
x += normTPS(ui)^2
end
Base.FastMath.sqrt_fast(x / max(length(u), 1))
end

function ODE_DEFAULT_NORM(f::F, u::AbstractArray{TPS{T}}, t) where {F, T}
x = zero(real(T))
@inbounds @fastmath for ui in u
x += normTPS(f(ui))^2
end
Base.FastMath.sqrt_fast(x / max(length(u), 1))
end

end
37 changes: 37 additions & 0 deletions test/downstream/gtpsa.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using OrdinaryDiffEq, ForwardDiff, GTPSA, Test

# ODEProblem 1 =======================

f!(du, u, p, t) = du .= p .* u

# Initial variables and parameters
Expand Down Expand Up @@ -37,3 +39,38 @@ for i in 1:3
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
end


# ODEProblem 2 =======================
pdot!(dq, p, q, params, t) = dq .= [0.0, 0.0, 0.0]
qdot!(dp, p, q, params, t) = dp .= [p[1] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
p[2] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
p[3] / sqrt(1 + p[3]^2) - (p[3] + 1)/sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2)]

prob = DynamicalODEProblem(pdot!, qdot!, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], (0.0, 25.0))
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)

desc = Descriptor(6, 2) # 6 variables to 2nd order
dx = vars(desc) # identity map
prob_GTPSA = DynamicalODEProblem(pdot!, qdot!, dx[1:3], dx[4:6], (0.0, 25.0))
sol_GTPSA = solve(prob_GTPSA, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)

@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part

# Compare Jacobian against ForwardDiff
J_FD = ForwardDiff.jacobian(zeros(6)) do t
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
sol.u[end]
end

@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)

# Compare Hessians against ForwardDiff
for i in 1:6
Hi_FD = ForwardDiff.hessian(zeros(6)) do t
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
sol.u[end][i]
end
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
end

0 comments on commit 419a2b5

Please sign in to comment.