From 93bdb2603650277e45a4b48db50c9a3f2ef4e567 Mon Sep 17 00:00:00 2001 From: odow Date: Tue, 28 Jan 2025 15:43:21 +1300 Subject: [PATCH] Change test/plugins/threaded.jl to a module --- test/plugins/threaded.jl | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/test/plugins/threaded.jl b/test/plugins/threaded.jl index 17c808769..162ab4fb1 100644 --- a/test/plugins/threaded.jl +++ b/test/plugins/threaded.jl @@ -3,10 +3,27 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -function test_threaded() +module TestThreads + +using SDDP +using Test +import HiGHS + +function runtests() if Threads.nthreads() == 1 - return # Skip this test if running in serial + return # Skip tests if running in serial end + for name in names(@__MODULE__; all = true) + if startswith("$(name)", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + +function test_threaded() if haskey(ENV, "JULIA_NUM_THREADS") num_threads = get(ENV, "JULIA_NUM_THREADS", "0") @test parse(Int, num_threads) == Threads.nthreads() @@ -49,8 +66,7 @@ function test_threaded() SDDP.train(model; iteration_limit = 100, parallel_scheme = SDDP.Threaded()) thread_ids_seen = Set{Int}(log.pid for log in model.most_recent_training_results.log) - min_threads = Threads.nthreads() == 1 ? 1 : 2 - @test min_threads <= length(thread_ids_seen) <= Threads.nthreads() + @test 2 <= length(thread_ids_seen) <= Threads.nthreads() recorder = Dict{Symbol,Function}(:thread_id => sp -> Threads.threadid()) simulations = SDDP.simulate( model, @@ -60,17 +76,11 @@ function test_threaded() ) thread_ids_seen = Set{Int}(data[:thread_id] for sim in simulations for data in sim) - min_threads = Threads.nthreads() > 1 ? 1 : 2 - @test min_threads <= length(thread_ids_seen) <= Threads.nthreads() + @test length(thread_ids_seen) == Threads.nthreads() return end -test_threaded() - function test_threaded_warning() - if Threads.nthreads() == 1 - return # Skip this test if running in serial - end model = SDDP.PolicyGraph( SDDP.UnicyclicGraph(0.95); sense = :Min, @@ -85,4 +95,6 @@ function test_threaded_warning() return end -test_threaded_warning() +end # module + +TestThreads.runtests()