diff --git a/src/algorithm.jl b/src/algorithm.jl index f10092b2e..e63956983 100644 --- a/src/algorithm.jl +++ b/src/algorithm.jl @@ -1135,9 +1135,8 @@ function train( # FIXME(odow): Threaded is broken for objective states parallel_scheme = Serial() end - if forward_pass isa AlternativeForwardPass || - forward_pass isa RegularizedForwardPass - # FIXME(odow): Threaded is broken for these forward passes + if forward_pass isa AlternativeForwardPass + # FIXME(odow): Threaded is broken for this forward pass parallel_scheme = Serial() end if log_frequency <= 0 diff --git a/src/plugins/forward_passes.jl b/src/plugins/forward_passes.jl index 1dfb6f460..ef06368a0 100644 --- a/src/plugins/forward_passes.jl +++ b/src/plugins/forward_passes.jl @@ -324,6 +324,7 @@ mutable struct RegularizedForwardPass{T<:AbstractForwardPass} <: AbstractForwardPass forward_pass::T trial_centre::Dict{Symbol,Float64} + old_bounds::Dict{Symbol,Tuple{Float64,Float64}} ρ::Float64 function RegularizedForwardPass(; @@ -331,7 +332,8 @@ mutable struct RegularizedForwardPass{T<:AbstractForwardPass} <: forward_pass::AbstractForwardPass = DefaultForwardPass(), ) centre = Dict{Symbol,Float64}() - return new{typeof(forward_pass)}(forward_pass, centre, rho) + old_bounds = Dict{Symbol,Tuple{Float64,Float64}}() + return new{typeof(forward_pass)}(forward_pass, centre, old_bounds, rho) end end @@ -353,20 +355,46 @@ function forward_pass( "not deterministic", ) end - old_bounds = Dict{Symbol,Tuple{Float64,Float64}}() - for (k, v) in node.states - if has_lower_bound(v.out) && has_upper_bound(v.out) - old_bounds[k] = (l, u) = (lower_bound(v.out), upper_bound(v.out)) + # It's safe to lock this node while we modify the forward pass and the node + # because there is only one node in the first stage and it is deterministic. + lock(node.lock) do + for (k, v) in node.states + if !(has_lower_bound(v.out) && has_upper_bound(v.out)) + continue # Not a finitely bounded state. Ignore for now + end + if !haskey(fp.old_bounds, k) + fp.old_bounds[k] = (lower_bound(v.out), upper_bound(v.out)) + end + l, u = fp.old_bounds[k] x = get(fp.trial_centre, k, model.initial_root_state[k]) set_lower_bound(v.out, max(l, x - fp.ρ * (u - l))) set_upper_bound(v.out, min(u, x + fp.ρ * (u - l))) end + return end pass = forward_pass(model, options, fp.forward_pass) - for (k, (l, u)) in old_bounds - fp.trial_centre[k] = pass.sampled_states[1][k] - set_lower_bound(node.states[k].out, l) - set_upper_bound(node.states[k].out, u) + # We're locking the node to reset the variable bounds back to their default. + # There are some potential scheduling issues to be aware of: + # + # * Thread A might have obtained the lock, modified the bounds to the trial + # centre, released the lock, and then entered `forward_pass` + # * But before it can re-obtain a lock on the first node to solve the + # problem, Thread B has come along below and modified the bounds back to + # their original value. + # * This means that Thread A is solving the unregularized problem, but it + # doesn't really matter because it doesn't change the validity; this is a + # performance optimization. + # * It might also be that we "skip" some of the starting trial points, + # because Thread A sets a trial centre, then thread B comes along and + # sets a new one before A can start the forward pass. Again, this doesn't + # really matter; this regularization is just a performance optimization. + lock(node.lock) do + for (k, (l, u)) in fp.old_bounds + fp.trial_centre[k] = pass.sampled_states[1][k] + set_lower_bound(node.states[k].out, l) + set_upper_bound(node.states[k].out, u) + end + return end return pass end