diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 1f0fe786a..befad1e61 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -61,7 +61,7 @@ DEFAULT_PLOT_FUNC(x,y,z) = (x,y,z) # For v0.5.2 bug denseplot = (sol.dense || typeof(sol.prob) <: AbstractDiscreteProblem) && !(typeof(sol) <: AbstractRODESolution) && - !(hasfield(typeof(sol),:interp) && + !(hasfield(typeof(sol),:interp) && typeof(sol.interp) <: SensitivityInterpolation), plotdensity = min(Int(1e5),sol.tslocation==0 ? (typeof(sol.prob) <: AbstractDiscreteProblem ? @@ -95,10 +95,10 @@ DEFAULT_PLOT_FUNC(x,y,z) = (x,y,z) # For v0.5.2 bug if getindex.(int_vars,1) == zeros(length(int_vars)) || getindex.(int_vars,2) == zeros(length(int_vars)) xguide --> "t" end - if length(int_vars[1]) >= 3 && getindex.(int_vars,3) == zeros(length(int_vars)) + if all(x->length(x) >= 3 && x[4] == 0,int_vars) yguide --> "t" end - if length(int_vars[1]) >= 4 && getindex.(int_vars,4) == zeros(length(int_vars)) + if all(x->length(x) >= 4 && x[4] == 0,int_vars) zguide --> "t" end @@ -236,12 +236,7 @@ function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_saf end end - dims = length(int_vars[1]) - for var in int_vars - @assert length(var) == dims - end - # Should check that all have the same dims! - plot_vecs,labels = solplot_vecs_and_labels(dims,int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,strs) + plot_vecs,labels = solplot_vecs_and_labels(int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,strs) end function interpret_vars(vars,sol,syms) @@ -420,12 +415,12 @@ function u_n(timeseries::AbstractArray, n::Int,sol,plott,plot_timeseries) end end -function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,strs) +function solplot_vecs_and_labels(vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,strs) plot_vecs = [] labels = String[] for x in vars tmp = [] - for j in 2:dims + for j in 2:length(x) push!(tmp, u_n(plot_timeseries, x[j],sol,plott,plot_timeseries)) end @@ -440,7 +435,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt end push!(plot_vecs[i],tmp[i]) end - add_labels!(labels,x,dims,sol,strs) + add_labels!(labels,x,length(x),sol,strs) end @@ -449,7 +444,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt analytic_plot_vecs = [] for x in vars tmp = [] - for j in 2:dims + for j in 2:length(x) push!(tmp, u_n(plot_analytic_timeseries, x[j],sol,plott,plot_analytic_timeseries)) end f = x[1] @@ -458,7 +453,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt for i in eachindex(tmp) push!(plot_vecs[i],tmp[i]) end - add_analytic_labels!(labels,x,dims,sol,strs) + add_analytic_labels!(labels,x,length(x),sol,strs) end end plot_vecs = [hcat(x...) for x in plot_vecs] diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index bd345ebff..2dec0346e 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -6,6 +6,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqProblemLibrary = "a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" diff --git a/test/downstream/plot_tests.jl b/test/downstream/plot_tests.jl new file mode 100644 index 000000000..22aea2e32 --- /dev/null +++ b/test/downstream/plot_tests.jl @@ -0,0 +1,14 @@ +using Plots, OrdinaryDiffEq +unicodeplots() + +f(x,p,t) = p*x +prob = ODEProblem(f,[1.0,2.0,3.0],(0.0,1.0),-.2) +sol = solve(prob, Tsit5()) + +f1(t,x,y) = (t,x+y) +f2(t,x) = (t,x) +f3(t,x,y) = (t,x) +plot(sol, vars = [(f1,0,1,2),]) +plot(sol, vars = [(f1,0,1,2), (0,3)]) +plot(sol, vars = [(f1,0,1,2), (f2,0,3)]) +plot(sol, vars = [(f1,0,1,2), (f3,0,3,1)]) diff --git a/test/runtests.jl b/test/runtests.jl index 5c4ef2d02..33177c42e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,7 @@ end if !is_APPVEYOR && GROUP == "Downstream" activate_downstream_env() @time @safetestset "Unitful" begin include("downstream/unitful.jl") end + @time @safetestset "Test Plotting" begin include("downstream/plot_tests.jl") end @time @safetestset "Null Parameters" begin include("downstream/null_params_test.jl") end @time @safetestset "Ensemble Simulations" begin include("downstream/ensemble.jl") end @time @safetestset "Ensemble Analysis" begin include("downstream/ensemble_analysis.jl") end