From 41f79daa2b27eb4fae11b5de60c5cd110a71daf9 Mon Sep 17 00:00:00 2001 From: odow Date: Tue, 28 Jan 2025 14:48:05 +1300 Subject: [PATCH 1/2] Throw warning if fewer nodes than threads --- src/plugins/parallel_schemes.jl | 12 +++++++++++- test/plugins/threaded.jl | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/plugins/parallel_schemes.jl b/src/plugins/parallel_schemes.jl index 2e7c8d38b..1cf0c23b5 100644 --- a/src/plugins/parallel_schemes.jl +++ b/src/plugins/parallel_schemes.jl @@ -369,9 +369,19 @@ function master_loop( model::PolicyGraph{T}, options::Options, ) where {T} + max_threads = Threads.nthreads() + num_nodes = length(model.nodes) + if num_nodes < max_threads + @warn( + "There are fewer nodes in the graph ($num_nodes) than there are " * + "threads available ($max_threads). Limiting the number of " * + "threads to $num_nodes." + ) + max_threads = num_nodes + end _initialize_solver(model; throw_error = false) keep_iterating, status = true, nothing - @sync for _ in 1:Threads.nthreads() + @sync for _ in 1:max_threads Threads.@spawn begin try # This access of `keep_iterating` is not thread-safe, but it diff --git a/test/plugins/threaded.jl b/test/plugins/threaded.jl index 04869b8b5..7c03edf73 100644 --- a/test/plugins/threaded.jl +++ b/test/plugins/threaded.jl @@ -66,3 +66,26 @@ function test_threaded() end test_threaded() + +function test_threaded_warning() + if Threads.nthreads() == 1 + return # Skip this test if running in serial + end + model = SDDP.PolicyGraph( + SDDP.UnicyclicGraph(0.95); + sense = :Min, + lower_bound = 0.0, + optimizer = HiGHS.Optimizer, + ) do sp, t + @variable(sp, 0 <= x <= 1, SDDP.State, initial_value = 1) + @stageobjective(sp, x.out) + return + end + @test_logs( + (:warn,), + SDDP.train(model; parallel_scheme = SDDP.Threaded()), + ) + return +end + +test_threaded_warning() From d683510dee52c731f455a48d82cefbf461bd586a Mon Sep 17 00:00:00 2001 From: odow Date: Tue, 28 Jan 2025 15:00:58 +1300 Subject: [PATCH 2/2] Fix format --- test/plugins/threaded.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/plugins/threaded.jl b/test/plugins/threaded.jl index 7c03edf73..17c808769 100644 --- a/test/plugins/threaded.jl +++ b/test/plugins/threaded.jl @@ -81,10 +81,7 @@ function test_threaded_warning() @stageobjective(sp, x.out) return end - @test_logs( - (:warn,), - SDDP.train(model; parallel_scheme = SDDP.Threaded()), - ) + @test_logs((:warn,), SDDP.train(model; parallel_scheme = SDDP.Threaded())) return end