Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix thread safety of RegularizedForwardPass #806

Merged
merged 5 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1135,9 +1135,8 @@ function train(
# FIXME(odow): Threaded is broken for objective states
parallel_scheme = Serial()
end
if forward_pass isa AlternativeForwardPass ||
forward_pass isa RegularizedForwardPass
# FIXME(odow): Threaded is broken for these forward passes
if forward_pass isa AlternativeForwardPass
# FIXME(odow): Threaded is broken for this forward pass
parallel_scheme = Serial()
end
if log_frequency <= 0
Expand Down
46 changes: 37 additions & 9 deletions src/plugins/forward_passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,14 +324,16 @@ mutable struct RegularizedForwardPass{T<:AbstractForwardPass} <:
AbstractForwardPass
forward_pass::T
trial_centre::Dict{Symbol,Float64}
old_bounds::Dict{Symbol,Tuple{Float64,Float64}}
ρ::Float64

function RegularizedForwardPass(;
rho::Float64 = 0.05,
forward_pass::AbstractForwardPass = DefaultForwardPass(),
)
centre = Dict{Symbol,Float64}()
return new{typeof(forward_pass)}(forward_pass, centre, rho)
old_bounds = Dict{Symbol,Tuple{Float64,Float64}}()
return new{typeof(forward_pass)}(forward_pass, centre, old_bounds, rho)
end
end

Expand All @@ -353,20 +355,46 @@ function forward_pass(
"not deterministic",
)
end
old_bounds = Dict{Symbol,Tuple{Float64,Float64}}()
for (k, v) in node.states
if has_lower_bound(v.out) && has_upper_bound(v.out)
old_bounds[k] = (l, u) = (lower_bound(v.out), upper_bound(v.out))
# It's safe to lock this node while we modify the forward pass and the node
# because there is only one node in the first stage and it is deterministic.
lock(node.lock) do
for (k, v) in node.states
if !(has_lower_bound(v.out) && has_upper_bound(v.out))
continue # Not a finitely bounded state. Ignore for now
end
if !haskey(fp.old_bounds, k)
fp.old_bounds[k] = (lower_bound(v.out), upper_bound(v.out))
end
l, u = fp.old_bounds[k]
x = get(fp.trial_centre, k, model.initial_root_state[k])
set_lower_bound(v.out, max(l, x - fp.ρ * (u - l)))
set_upper_bound(v.out, min(u, x + fp.ρ * (u - l)))
end
return
end
pass = forward_pass(model, options, fp.forward_pass)
for (k, (l, u)) in old_bounds
fp.trial_centre[k] = pass.sampled_states[1][k]
set_lower_bound(node.states[k].out, l)
set_upper_bound(node.states[k].out, u)
# We're locking the node to reset the variable bounds back to their default.
# There are some potential scheduling issues to be aware of:
#
# * Thread A might have obtained the lock, modified the bounds to the trial
# centre, released the lock, and then entered `forward_pass`.
# * But before it can re-obtain a lock on the first node to solve the
# problem, Thread B has come along below and modified the bounds back to
# their original value.
# * This means that Thread A is solving the unregularized problem, but it
# doesn't really matter because it doesn't change the validity; the
# regularization is just a performance optimization.
# * It might also be that we "skip" some of the starting trial points,
# because Thread A sets a trial centre, then Thread B comes along and
# sets a new one before A can start the forward pass. Again, this doesn't
# really matter; this regularization is just a performance optimization.
lock(node.lock) do
for (k, (l, u)) in fp.old_bounds
fp.trial_centre[k] = pass.sampled_states[1][k]
set_lower_bound(node.states[k].out, l)
set_upper_bound(node.states[k].out, u)
end
odow marked this conversation as resolved.
Show resolved Hide resolved
return
end
return pass
end
Loading