From 9e786f3153ee3bea2cd23a2ef5b675902d64f253 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 27 Nov 2024 15:25:11 +1300 Subject: [PATCH 1/5] Fix thread safety of RegularizedForwardPass --- src/plugins/forward_passes.jl | 45 ++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/src/plugins/forward_passes.jl b/src/plugins/forward_passes.jl index 1dfb6f460..99ca3598e 100644 --- a/src/plugins/forward_passes.jl +++ b/src/plugins/forward_passes.jl @@ -324,6 +324,7 @@ mutable struct RegularizedForwardPass{T<:AbstractForwardPass} <: AbstractForwardPass forward_pass::T trial_centre::Dict{Symbol,Float64} + old_bounds::Dict{Symbol,Tuple{Float64,Float64}} ρ::Float64 function RegularizedForwardPass(; @@ -331,7 +332,8 @@ mutable struct RegularizedForwardPass{T<:AbstractForwardPass} <: forward_pass::AbstractForwardPass = DefaultForwardPass(), ) centre = Dict{Symbol,Float64}() - return new{typeof(forward_pass)}(forward_pass, centre, rho) + old_bounds = Dict{Symbol,Tuple{Float64,Float64}}() + return new{typeof(forward_pass)}(forward_pass, centre, old_bounds, rho) end end @@ -353,20 +355,45 @@ function forward_pass( "not deterministic", ) end - old_bounds = Dict{Symbol,Tuple{Float64,Float64}}() - for (k, v) in node.states - if has_lower_bound(v.out) && has_upper_bound(v.out) - old_bounds[k] = (l, u) = (lower_bound(v.out), upper_bound(v.out)) + # It's safe to lock this node while we modify the forward pass and the node + # because there is only one node in the first stage and it is deterministic. + lock(node.lock) do + for (k, v) in node.states + if !(has_lower_bound(v.out) && has_upper_bound(v.out)) + continue # Not a finitely bounded state. Ignore for now + end + if !haskey(fp.old_bounds, k) + old_bounds[k] = (lower_bound(v.out), upper_bound(v.out)) + end + l, u = old_bounds[k] x = get(fp.trial_centre, k, model.initial_root_state[k]) set_lower_bound(v.out, max(l, x - fp.ρ * (u - l))) set_upper_bound(v.out, min(u, x + fp.ρ * (u - l))) end + return end pass = forward_pass(model, options, fp.forward_pass) - for (k, (l, u)) in old_bounds - fp.trial_centre[k] = pass.sampled_states[1][k] - set_lower_bound(node.states[k].out, l) - set_upper_bound(node.states[k].out, u) + # We're locking the node to reset the variable bounds back to their default. + # There are some potential scheduling issues to be aware of: + # + # * Thread A might have obtained the lock, modified the bounds to the trial + # centre, released the lock, and then entered `forward_pass` + # * But before it can re-obtain a lock on the first node to solve the + # problem, Thread B has come along below and modified the bounds back to + # their original value. + # * This means that Thread A is solving the unregularized problem, but it + # doesn't really matter because it doesn't change the validity; this is a + # performance optimization. + # * It might also be that we "skip" some of the starting tirla points, + # because Thread A sets a trial centre, then thread B comes along and + # sets a new one before A can start the forward pass. Again, this doesn't + # really matter; this regularization is just a performance optimization. + lock(node.lock) do + for (k, (l, u)) in fp.old_bounds + fp.trial_centre[k] = pass.sampled_states[1][k] + set_lower_bound(node.states[k].out, l) + set_upper_bound(node.states[k].out, u) + end end return pass end From f12a61bf9d173c3aa5c2ae2a97a4902b08054a9e Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 27 Nov 2024 15:31:09 +1300 Subject: [PATCH 2/5] Update src/plugins/forward_passes.jl --- src/plugins/forward_passes.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/plugins/forward_passes.jl b/src/plugins/forward_passes.jl index 99ca3598e..66631d38f 100644 --- a/src/plugins/forward_passes.jl +++ b/src/plugins/forward_passes.jl @@ -394,6 +394,7 @@ function forward_pass( set_lower_bound(node.states[k].out, l) set_upper_bound(node.states[k].out, u) end + return end return pass end From f852c1361e13b86c7780b837a47047be9b7fd7ed Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 27 Nov 2024 15:39:17 +1300 Subject: [PATCH 3/5] Fix --- src/plugins/forward_passes.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plugins/forward_passes.jl b/src/plugins/forward_passes.jl index 66631d38f..69af539ad 100644 --- a/src/plugins/forward_passes.jl +++ b/src/plugins/forward_passes.jl @@ -363,9 +363,9 @@ function forward_pass( continue # Not a finitely bounded state. Ignore for now end if !haskey(fp.old_bounds, k) - old_bounds[k] = (lower_bound(v.out), upper_bound(v.out)) + fp.old_bounds[k] = (lower_bound(v.out), upper_bound(v.out)) end - l, u = old_bounds[k] + l, u = fp.old_bounds[k] x = get(fp.trial_centre, k, model.initial_root_state[k]) set_lower_bound(v.out, max(l, x - fp.ρ * (u - l))) set_upper_bound(v.out, min(u, x + fp.ρ * (u - l))) From 265b4e5fe7c5af0eef6ba7fdcaff553274d444f2 Mon Sep 17 00:00:00 2001 From: Felix Schmidt <113511577+FSchmidtDIW@users.noreply.github.com> Date: Wed, 27 Nov 2024 20:14:02 +0100 Subject: [PATCH 4/5] Drop exception for regularized forward pass in algorithm.jl (#807) --- src/algorithm.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/algorithm.jl b/src/algorithm.jl index f10092b2e..e63956983 100644 --- a/src/algorithm.jl +++ b/src/algorithm.jl @@ -1135,9 +1135,8 @@ function train( # FIXME(odow): Threaded is broken for objective states parallel_scheme = Serial() end - if forward_pass isa AlternativeForwardPass || - forward_pass isa RegularizedForwardPass - # FIXME(odow): Threaded is broken for these forward passes + if forward_pass isa AlternativeForwardPass + # FIXME(odow): Threaded is broken for this forward pass parallel_scheme = Serial() end if log_frequency <= 0 From e4764eae420490e4dae0621cfe46275d4437bf98 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Thu, 28 Nov 2024 08:14:58 +1300 Subject: [PATCH 5/5] Update src/plugins/forward_passes.jl --- src/plugins/forward_passes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/forward_passes.jl b/src/plugins/forward_passes.jl index 69af539ad..ef06368a0 100644 --- a/src/plugins/forward_passes.jl +++ b/src/plugins/forward_passes.jl @@ -384,7 +384,7 @@ function forward_pass( # * This means that Thread A is solving the unregularized problem, but it # doesn't really matter because it doesn't change the validity; this is a # performance optimization. - # * It might also be that we "skip" some of the starting tirla points, + # * It might also be that we "skip" some of the starting trial points, # because Thread A sets a trial centre, then thread B comes along and # sets a new one before A can start the forward pass. Again, this doesn't # really matter; this regularization is just a performance optimization.