Skip to content

Commit

Permalink
Change test/plugins/threaded.jl to a module (#825)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jan 28, 2025
1 parent a926abe commit 3f31f2d
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions test/plugins/threaded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,27 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

function test_threaded()
module TestThreads

using SDDP
using Test
import HiGHS

function runtests()
if Threads.nthreads() == 1
return # Skip this test if running in serial
return # Skip tests if running in serial
end
for name in names(@__MODULE__; all = true)
if startswith("$(name)", "test_")
@testset "$(name)" begin
getfield(@__MODULE__, name)()
end
end
end
return
end

function test_threaded()
if haskey(ENV, "JULIA_NUM_THREADS")
num_threads = get(ENV, "JULIA_NUM_THREADS", "0")
@test parse(Int, num_threads) == Threads.nthreads()
Expand Down Expand Up @@ -49,8 +66,7 @@ function test_threaded()
SDDP.train(model; iteration_limit = 100, parallel_scheme = SDDP.Threaded())
thread_ids_seen =
Set{Int}(log.pid for log in model.most_recent_training_results.log)
min_threads = Threads.nthreads() == 1 ? 1 : 2
@test min_threads <= length(thread_ids_seen) <= Threads.nthreads()
@test 2 <= length(thread_ids_seen) <= Threads.nthreads()
recorder = Dict{Symbol,Function}(:thread_id => sp -> Threads.threadid())
simulations = SDDP.simulate(
model,
Expand All @@ -60,17 +76,11 @@ function test_threaded()
)
thread_ids_seen =
Set{Int}(data[:thread_id] for sim in simulations for data in sim)
min_threads = Threads.nthreads() > 1 ? 1 : 2
@test min_threads <= length(thread_ids_seen) <= Threads.nthreads()
@test length(thread_ids_seen) == Threads.nthreads()
return
end

test_threaded()

function test_threaded_warning()
if Threads.nthreads() == 1
return # Skip this test if running in serial
end
model = SDDP.PolicyGraph(
SDDP.UnicyclicGraph(0.95);
sense = :Min,
Expand All @@ -85,4 +95,6 @@ function test_threaded_warning()
return
end

test_threaded_warning()
end # module

TestThreads.runtests()

0 comments on commit 3f31f2d

Please sign in to comment.