Skip to content

Commit

Permalink
Update convstack model definition and extend optimizer config flags.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600455981
  • Loading branch information
mjanusz authored and copybara-github committed Jan 22, 2024
1 parent 76f00a6 commit 5f65949
Show file tree
Hide file tree
Showing 11 changed files with 1,044 additions and 517 deletions.
42 changes: 26 additions & 16 deletions ffn/inference/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def __init__(self, model, session, counters, batch_size):
self.active_clients = 0

# Cache input/output sizes.
self._input_seed_size = np.array(model.input_seed_size[::-1]).tolist()
self._input_image_size = np.array(model.input_image_size[::-1]).tolist()
self._pred_size = np.array(model.pred_mask_size[::-1]).tolist()
self._input_seed_size = np.array(model.info.input_seed_size[::-1]).tolist()
self._input_image_size = np.array(
model.info.input_image_size[::-1]
).tolist()
self._pred_size = np.array(model.info.pred_mask_size[::-1]).tolist()

self._initialize_model()

Expand Down Expand Up @@ -111,8 +113,9 @@ class ThreadingBatchExecutor(BatchExecutor):
"""

def __init__(self, model, session, counters, batch_size, expected_clients=1):
super(ThreadingBatchExecutor, self).__init__(model, session, counters,
batch_size)
super(ThreadingBatchExecutor, self).__init__(
model, session, counters, batch_size
)
self._lock = threading.Lock()
self.outputs = {} # Will be populated by Queues as clients register.
# Used by clients to communicate with the executor. The protocol is
Expand All @@ -131,10 +134,12 @@ def __init__(self, model, session, counters, batch_size, expected_clients=1):
self.expected_clients = expected_clients

# Arrays fed to TF.
self.input_seed = np.zeros([batch_size] + self._input_seed_size + [1],
dtype=np.float32)
self.input_image = np.zeros([batch_size] + self._input_image_size + [1],
dtype=np.float32)
self.input_seed = np.zeros(
[batch_size] + self._input_seed_size + [1], dtype=np.float32
)
self.input_image = np.zeros(
[batch_size] + self._input_image_size + [1], dtype=np.float32
)
self.th_executor = None

def start_server(self):
Expand All @@ -146,7 +151,8 @@ def start_server(self):
"""
if self.th_executor is None:
self.th_executor = threading.Thread(
target=self._run_executor_log_exceptions)
target=self._run_executor_log_exceptions
)
self.th_executor.start()

def stop_server(self):
Expand All @@ -166,8 +172,10 @@ def _run_executor(self):

with timer_counter(self.counters, 'executor-input'):
ready = []
while (len(ready) < min(self.active_clients, self.batch_size) or
not self.active_clients):
while (
len(ready) < min(self.active_clients, self.batch_size)
or not self.active_clients
):
try:
data = self.input_queue.get(timeout=5)
except queue.Empty:
Expand Down Expand Up @@ -201,9 +209,12 @@ def _schedule_batch(self, client_ids, fetches):
with timer_counter(self.counters, 'executor-inference'):
try:
ret = self.session.run(
fetches, {
fetches,
{
self.model.input_seed: self.input_seed,
self.model.input_patches: self.input_image})
self.model.input_patches: self.input_image,
},
)
except Exception as e: # pylint:disable=broad-except
logging.exception(e)
# If calling TF didn't work (faulty hardware, misconfiguration, etc),
Expand All @@ -215,8 +226,7 @@ def _schedule_batch(self, client_ids, fetches):
with self._lock:
for i, client_id in enumerate(client_ids):
try:
self.outputs[client_id].put(
{k: v[i, ...] for k, v in ret.items()})
self.outputs[client_id].put({k: v[i, ...] for k, v in ret.items()})
except KeyError:
# This could happen if a client unregistered itself
# while inference was running.
Expand Down
28 changes: 16 additions & 12 deletions ffn/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from scipy.special import logit
import tensorflow.compat.v1 as tf
from tensorflow.io import gfile
from ..training import model as ffn_model
from ..training.import_util import import_symbol
from ..utils import bounding_box
from ..utils import ortho_plane_visualization
Expand All @@ -48,7 +49,7 @@

# Visualization.
# ---------------------------------------------------------------------------
class DynamicImage(object):
class DynamicImage:
def UpdateFromPIL(self, new_img):
from io import BytesIO
from IPython import display
Expand Down Expand Up @@ -172,11 +173,11 @@ def _halt_signaler(fetches, pos, orig_pos, counters, **unused_kwargs):


# TODO(mjanusz): Add support for sparse inference.
class Canvas(object):
class Canvas:
"""Tracks state of the inference progress and results within a subvolume."""

def __init__(self,
model,
model: ffn_model.FFNModel,
tf_executor,
image,
options,
Expand Down Expand Up @@ -242,9 +243,9 @@ def __init__(self,

# Cast to array to ensure we can do elementwise expressions later.
# All of these are in zyx order.
self._pred_size = np.array(model.pred_mask_size[::-1])
self._input_seed_size = np.array(model.input_seed_size[::-1])
self._input_image_size = np.array(model.input_image_size[::-1])
self._pred_size = np.array(model.info.pred_mask_size[::-1])
self._input_seed_size = np.array(model.info.input_seed_size[::-1])
self._input_image_size = np.array(model.info.input_image_size[::-1])
self.margin = self._input_image_size // 2

self._pred_delta = (self._input_seed_size - self._pred_size) // 2
Expand Down Expand Up @@ -277,7 +278,7 @@ def __init__(self,
if movement_policy_fn is None:
# The model.deltas are (for now) in xyz order and must be swapped to zyx.
self.movement_policy = movement.FaceMaxMovementPolicy(
self, deltas=model.deltas[::-1],
self, deltas=model.info.deltas[::-1],
score_threshold=self.options.move_threshold)
else:
self.movement_policy = movement_policy_fn(self)
Expand Down Expand Up @@ -789,7 +790,7 @@ def _maybe_save_checkpoint(self):
self.checkpoint_last = time.time()


class Runner(object):
class Runner:
"""Helper for managing FFN inference runs.
Takes care of initializing the FFN model and any related functionality
Expand All @@ -799,6 +800,9 @@ class Runner(object):

ALL_MASKED = 1

request: inference_pb2.InferenceRequest
executor: executor.BatchExecutor

def __init__(self):
self.counters = inference_utils.Counters()
self.executor = None
Expand All @@ -823,7 +827,8 @@ def _load_model_checkpoint(self, checkpoint_path):
"""
with timer_counter(self.counters, 'restore-tf-checkpoint'):
logging.info('Loading checkpoint.')
self.model.saver.restore(self.session, checkpoint_path)
saver = tf.train.Saver()
saver.restore(self.session, checkpoint_path)
logging.info('Checkpoint loaded.')

def start(self, request, batch_size=1, exec_cls=None, session=None):
Expand Down Expand Up @@ -908,9 +913,8 @@ def _open_or_none(settings):

self.executor = exec_cls(
self.model, self.session, self.counters, batch_size)
self.movement_policy_fn = movement.get_policy_fn(request, self.model)
self.movement_policy_fn = movement.get_policy_fn(request, self.model.info)

self.saver = tf.train.Saver()
self._load_model_checkpoint(request.model_checkpoint_path)

self.executor.start_server()
Expand Down Expand Up @@ -975,7 +979,7 @@ def make_restrictor(self, corner, subvol_size, image, alignment):
start=self.request.shift_mask_fov.start,
size=self.request.shift_mask_fov.size)
else:
shift_mask_diameter = np.array(self.model.input_image_size)
shift_mask_diameter = np.array(self.model.info.input_image_size)
shift_mask_fov = bounding_box.BoundingBox(
start=-(shift_mask_diameter // 2), size=shift_mask_diameter)

Expand Down
Loading

0 comments on commit 5f65949

Please sign in to comment.