# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for running FFN inference."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import functools
import json
import logging
import os
import threading
import time
import numpy as np
from numpy.lib.stride_tricks import as_strided

from scipy.special import expit
from scipy.special import logit
from skimage import transform

import tensorflow as tf

from tensorflow import gfile
from . import align
from . import executor
from . import inference_pb2
from . import inference_utils
from . import movement
from . import seed
from . import storage
from .inference_utils import Counters
from .inference_utils import TimedIter
from .inference_utils import timer_counter
from . import segmentation
from ..training.import_util import import_symbol
from ..utils import ortho_plane_visualization
from ..utils import bounding_box

# Milliseconds per second; converts time.time() deltas into the *-ms counters.
MSEC_IN_SEC = 1000
# Upper bound on repeated FFN evaluations at a single position while waiting
# for the prediction to become self-consistent (see Canvas.update_at).
MAX_SELF_CONSISTENT_ITERS = 32


# Visualization.
# ---------------------------------------------------------------------------
class DynamicImage(object):
  """Shows a PIL image in an IPython frontend, replacing the previous output."""

  def UpdateFromPIL(self, new_img):
    """Clears the current cell output and displays `new_img` as a PNG.

    Args:
      new_img: PIL image to display
    """
    from io import BytesIO
    from IPython import display
    display.clear_output(wait=True)
    png_buffer = BytesIO()
    new_img.save(png_buffer, format='png')
    png_bytes = png_buffer.getvalue()
    display.display(display.Image(png_bytes))


def _cmap_rgb1(drw):
  """Maps an intensity array to an RGB uint8 image (gnuplot-style palette).

  Channels are red = sqrt(v), green = v**3 and blue = sin(pi * v), each scaled
  by 250 and stacked depth-wise with np.dstack. Inputs are expected to lie in
  [0, 1] (visualize_state passes expit() output); values outside that range
  may not survive the uint8 cast meaningfully.

  Args:
    drw: ndarray of intensities

  Returns:
    uint8 ndarray with the three channels stacked along the depth axis
  """
  red = np.sqrt(drw)
  green = np.power(drw, 3)
  blue = np.sin(np.pi * drw)
  rgb = np.dstack((red, green, blue))
  return (250.0 * rgb).astype(np.uint8)


def visualize_state(seed_logits, pos, movement_policy, dynimage):
  """Visualizes the inference state.

  Renders orthogonal planes through `pos` of the current object mask and,
  when the movement policy exposes its score grid as an ndarray, the same
  planes of the upsampled score grid next to it. The FoV center is marked in
  red.

  Args:
    seed_logits: ndarray (z, y, x) with the current predicted mask (logits)
    pos: current FoV position within 'seed' as z, y, x
    movement_policy: movement policy object; its `scored_coords` and `deltas`
        (z, y, x) are read to render the score grid (the grid is presumably
        sampled every `deltas` voxels -- TODO confirm)
    dynimage: DynamicImage object which is to be updated with the
        state visualization
  """
  from PIL import Image

  planes = ortho_plane_visualization.cut_ortho_planes(
      seed_logits, center=pos, cross_hair=True)
  to_vis = ortho_plane_visualization.concat_ortho_planes(planes)

  if isinstance(movement_policy.scored_coords, np.ndarray):
    scores = movement_policy.scored_coords
    # Upsample the grid: np.repeat copies every score `delta` times along
    # its axis, producing a dense array without zero-stride view tricks.
    zf, yf, xf = movement_policy.deltas
    scores_up = np.repeat(scores, zf, axis=0)
    scores_up = np.repeat(scores_up, yf, axis=1)
    scores_up = np.repeat(scores_up, xf, axis=2)
    # TODO(mkillinger) might need padding in some cases, if crashes: fix.
    # The grid might be too large, cut it to be symmetrical
    cut = (np.array(scores_up.shape) - np.array(seed_logits.shape)) // 2
    sh = seed_logits.shape
    scores_up = scores_up[cut[0]:cut[0] + sh[0],
                          cut[1]:cut[1] + sh[1],
                          cut[2]:cut[2] + sh[2]]
    grid_planes = ortho_plane_visualization.cut_ortho_planes(
        scores_up, center=pos, cross_hair=True)
    grid_view = ortho_plane_visualization.concat_ortho_planes(grid_planes)
    grid_view *= 4  # Looks better this way
    to_vis = np.concatenate((to_vis, grid_view), axis=1)

  val = _cmap_rgb1(expit(to_vis))
  y, x = pos[1:]

  # Mark seed in the xy plane. The lower bound is clamped at 0: for a seed on
  # the y == 0 or x == 0 border, a start index of -1 would be interpreted
  # relative to the end of the axis and the marker would be lost.
  y_lo = max(y - 1, 0)
  x_lo = max(x - 1, 0)
  val[y_lo:y + 2, x_lo:x + 2, 0] = 255
  val[y_lo:y + 2, x_lo:x + 2, 1:] = 0

  vis = Image.fromarray(val)
  dynimage.UpdateFromPIL(vis)


# Self-prediction halting
# ---------------------------------------------------------------------------
# Verbosity levels accepted by the halt signaler factories below.
HALT_SILENT = 0  # never log
PRINT_HALTS = 1  # log only the positions at which a halt was signaled
HALT_VERBOSE = 2  # log every evaluated position

# is_halt: callable deciding whether to halt at a position;
# extra_fetches: names of model attributes whose values is_halt expects in
# its `fetches` argument.
HaltInfo = namedtuple('HaltInfo', 'is_halt extra_fetches')


def no_halt(verbosity=HALT_SILENT, log_function=logging.info):
  """Returns a HaltInfo whose signaler never halts inference.

  Args:
    verbosity: when HALT_VERBOSE, every evaluated position is logged together
        with its fetches; any other value disables logging
    log_function: callable receiving a single preformatted string

  Returns:
    HaltInfo that requests no extra fetches
  """
  if verbosity != HALT_VERBOSE:
    def _halt_signaler(*unused_args, **unused_kwargs):
      return False
    return HaltInfo(_halt_signaler, [])

  def _halt_signaler_verbose(fetches, pos, **unused_kwargs):
    message = '%s, %s' % (pos, fetches)
    log_function(message)
    return False

  return HaltInfo(_halt_signaler_verbose, [])


def self_prediction_halt(
    threshold, orig_threshold=None, verbosity=HALT_SILENT,
    log_function=logging.info):
  """HaltInfo based on FFN self-predictions.

  Args:
    threshold: inference is halted when the model's first 'self_prediction'
        output exceeds this value
    orig_threshold: (optional) threshold used instead of `threshold` when the
        evaluated position equals the position where the object was started
    verbosity: HALT_SILENT, PRINT_HALTS (log halted positions only) or
        HALT_VERBOSE (log every evaluated position)
    log_function: callable receiving a single preformatted string

  Returns:
    HaltInfo whose signaler requests the 'self_prediction' extra fetch
  """

  def _halt_signaler(fetches, pos, orig_pos, counters, **unused_kwargs):
    """Returns true if FFN prediction should be halted."""
    # np.array_equal compares tuples, lists and ndarrays alike; a plain `==`
    # on ndarray positions would produce an array with an ambiguous truth
    # value.
    if orig_threshold is not None and np.array_equal(pos, orig_pos):
      t = orig_threshold
    else:
      t = threshold

    # [0] is by convention the total incorrect proportion prediction.
    halt = fetches['self_prediction'][0] > t

    if halt:
      counters['halts'].Increment()

    if verbosity == HALT_VERBOSE or (
        halt and verbosity == PRINT_HALTS):
      log_function('%s, %s' % (pos, fetches))

    return halt

  # Add self_prediction to the extra_fetches.
  return HaltInfo(_halt_signaler, ['self_prediction'])

# ---------------------------------------------------------------------------


# TODO(mjanusz): Add support for sparse inference.
class Canvas(object):
  """Tracks state of the inference progress and results within a subvolume."""

  def __init__(self,
               model,
               tf_executor,
               image,
               options,
               counters=None,
               restrictor=None,
               movement_policy_fn=None,
               halt_signaler=None,
               keep_history=False,
               checkpoint_path=None,
               checkpoint_interval_sec=0,
               corner_zyx=None):
    """Initializes the canvas.

    Args:
      model: FFNModel object
      tf_executor: Executor object to use for inference
      image: 3d ndarray-like of shape (z, y, x)
      options: InferenceOptions proto; a private copy is kept in which
          init_activation, pad_value, move_threshold and segment_threshold
          are converted with logit()
      counters: (optional) counter container, where __getitem__ returns a
          counter compatible with the MR Counter API
      restrictor: (optional) a MovementRestrictor object which can exclude
          some areas of the data from the segmentation process
      movement_policy_fn: callable taking the Canvas object as its
          only argument and returning a movement policy object
          (see movement.BaseMovementPolicy)
      halt_signaler: (optional) HaltInfo object determining early stopping
          policy; defaults to no_halt(), i.e. inference is never halted
      keep_history: whether to maintain a record of locations visited by the
          FFN, together with any associated metadata; note that this data is
          kept only for the object currently being segmented
      checkpoint_path: (optional) path at which to save a checkpoint file
      checkpoint_interval_sec: how often to save a checkpoint file (in
          seconds); if <= 0, no checkpoints are saved
      corner_zyx: 3 element array-like indicating the spatial corner of the
          image in (z, y, x)
    """
    self.model = model
    self.image = image
    self.executor = tf_executor
    self._exec_client_id = None

    self.options = inference_pb2.InferenceOptions()
    self.options.CopyFrom(options)
    # Convert thresholds, etc. to logit values for efficient inference.
    for attr in ('init_activation', 'pad_value', 'move_threshold',
                 'segment_threshold'):
      setattr(self.options, attr, logit(getattr(self.options, attr)))

    # Build the default per instance instead of once at class definition.
    if halt_signaler is None:
      halt_signaler = no_halt()
    self.halt_signaler = halt_signaler

    self.counters = counters if counters is not None else Counters()
    self.checkpoint_interval_sec = checkpoint_interval_sec
    self.checkpoint_path = checkpoint_path
    self.checkpoint_last = time.time()

    self._keep_history = keep_history
    self.corner_zyx = corner_zyx
    self.shape = image.shape

    if restrictor is None:
      self.restrictor = movement.MovementRestrictor()
    else:
      self.restrictor = restrictor

    # Cast to array to ensure we can do elementwise expressions later.
    # All of these are in zyx order.
    self._pred_size = np.array(model.pred_mask_size[::-1])
    self._input_seed_size = np.array(model.input_seed_size[::-1])
    self._input_image_size = np.array(model.input_image_size[::-1])
    self.margin = self._input_image_size // 2

    self._pred_delta = (self._input_seed_size - self._pred_size) // 2
    assert np.all(self._pred_delta >= 0)

    # Current working area. This represents an object probability map
    # in logit form, and is fed directly as the mask input to the FFN
    # model.
    self.seed = np.zeros(self.shape, dtype=np.float32)
    self.segmentation = np.zeros(self.shape, dtype=np.int32)
    self.seg_prob = np.zeros(self.shape, dtype=np.uint8)

    # When an initial segmentation is provided, maps the global ID space
    # to locally used IDs.
    self.global_to_local_ids = {}

    self.seed_policy = None
    self._seed_policy_state = None

    # Maximum segment ID already assigned.
    self._max_id = 0

    # Maps of segment id -> ..
    self.origins = {}  # seed location
    self.overlaps = {}  # (ids, number overlapping voxels)

    # Whether to always create a new seed in segment_at.
    self.reset_seed_per_segment = True

    if movement_policy_fn is None:
      # The model.deltas are (for now) in xyz order and must be swapped to zyx.
      self.movement_policy = movement.FaceMaxMovementPolicy(
          self, deltas=model.deltas[::-1],
          score_threshold=self.options.move_threshold)
    else:
      self.movement_policy = movement_policy_fn(self)

    self.reset_state((0, 0, 0))
    self.t_last_predict = None

  @staticmethod
  def _box_slices(start, end):
    """Returns a tuple of slices selecting the box [start, end) per axis.

    NumPy requires a tuple (not a list) for multidimensional basic indexing;
    the tuple form also yields a view, so chained in-place assignment such as
    `arr[sel][mask] = v` writes through to `arr`.
    """
    return tuple(slice(s, e) for s, e in zip(start, end))

  def _register_client(self):
    """Registers this canvas with the executor, unless already registered."""
    if self._exec_client_id is None:
      self._exec_client_id = self.executor.start_client()
      logging.info('Registered as client %d.', self._exec_client_id)

  def _deregister_client(self):
    """Deregisters this canvas from the executor, if currently registered."""
    # getattr() keeps this safe when invoked from __del__ on an object whose
    # __init__ raised before the attribute was assigned.
    client_id = getattr(self, '_exec_client_id', None)
    if client_id is not None:
      logging.info('Deregistering client %d', client_id)
      self.executor.finish_client(client_id)
      self._exec_client_id = None

  def __del__(self):
    # Note that the presence of this method will cause a memory leak in
    # case the Canvas object is part of a reference cycle. Use weakref.proxy
    # where such cycles are really needed.
    self._deregister_client()

  def local_id(self, segment_id):
    """Maps a global segment ID to its local ID (identity if unmapped)."""
    return self.global_to_local_ids.get(segment_id, segment_id)

  def reset_state(self, start_pos):
    """Resets per-object state before segmenting from `start_pos` (z, y, x)."""
    # Resetting the movement_policy is currently necessary to update the
    # policy's bitmask for whether a position is already segmented (the
    # canvas updates the segmented mask only between calls to segment_at
    # and therefore the policy does not update this mask for every call.).
    self.movement_policy.reset_state(start_pos)
    self.history = []
    self.history_deleted = []

    self._min_pos = np.array(start_pos)
    self._max_pos = np.array(start_pos)
    self._register_client()

  def is_valid_pos(self, pos, ignore_move_threshold=False):
    """Returns True if segmentation should be attempted at the given position.

    Args:
      pos: position to check as (z, y, x)
      ignore_move_threshold: (boolean) when starting a new segment at pos the
          move threshold can and must be ignored.

    Returns:
      Boolean indicating whether to run FFN inference at the given position.
    """

    if not ignore_move_threshold:
      if self.seed[pos] < self.options.move_threshold:
        self.counters['skip_threshold'].Increment()
        logging.debug('.. seed value below threshold.')
        return False

    # Not enough image context?
    np_pos = np.array(pos)
    low = np_pos - self.margin
    high = np_pos + self.margin

    if np.any(low < 0) or np.any(high >= self.shape):
      self.counters['skip_invalid_pos'].Increment()
      logging.debug('.. too close to border: %r', pos)
      return False

    # Location already segmented?
    if self.segmentation[pos] > 0:
      self.counters['skip_invalid_pos'].Increment()
      logging.debug('.. segmentation already active: %r', pos)
      return False

    return True

  def predict(self, pos, logit_seed, extra_fetches):
    """Runs a single step of FFN prediction.

    Args:
      pos: (z, y, x) position of the center of the FoV
      logit_seed: current seed to feed to the model as input, z, y, x ndarray
      extra_fetches: dict of additional fetches to retrieve, can be empty;
          note that a 'logits' entry is added to it in place

    Returns:
      tuple of:
        (logistic prediction, logits)
        dict of additional fetches corresponding to extra_fetches
    """
    with timer_counter(self.counters, 'predict'):
      # Top-left corner of the FoV.
      start = np.array(pos) - self.margin
      end = start + self._input_image_size
      img = self.image[self._box_slices(start, end)]

      # Record the amount of time spent on non-prediction tasks.
      if self.t_last_predict is not None:
        delta_t = time.time() - self.t_last_predict
        self.counters['inference-not-predict-ms'].IncrementBy(
            delta_t * MSEC_IN_SEC)

      extra_fetches['logits'] = self.model.logits
      with timer_counter(self.counters, 'inference'):
        fetches = self.executor.predict(self._exec_client_id,
                                        logit_seed, img, extra_fetches)

      self.t_last_predict = time.time()

    logits = fetches.pop('logits')
    prob = expit(logits)
    return (prob[..., 0], logits[..., 0]), fetches

  def update_at(self, pos, start_pos):
    """Updates object mask prediction at a specific position.

    Note that depending on the settings of the canvas, the update might involve
    more than 1 inference run of the FFN.

    Args:
      pos: (z, y, x) position of the center of the FoV
      start_pos: (z, y, x) position from which the segmentation of the current
          object has started

    Returns:
      ndarray of the predicted mask in logit space
    """
    with timer_counter(self.counters, 'update_at'):
      off = self._input_seed_size // 2  # zyx

      start = np.array(pos) - off
      end = start + self._input_seed_size
      logit_seed = np.array(self.seed[self._box_slices(start, end)])
      # Voxels never predicted so far are NaN; feed them as pad_value.
      init_prediction = np.isnan(logit_seed)
      logit_seed[init_prediction] = np.float32(self.options.pad_value)

      extra_fetches = {f: getattr(self.model, f) for f
                       in self.halt_signaler.extra_fetches}

      prob_seed = expit(logit_seed)
      for _ in range(MAX_SELF_CONSISTENT_ITERS):
        (prob, logits), fetches = self.predict(pos, logit_seed,
                                               extra_fetches=extra_fetches)
        if self.options.consistency_threshold <= 0:
          break

        diff = np.average(np.abs(prob_seed - prob))
        if diff < self.options.consistency_threshold:
          break

        # NOTE(review): feeding the prediction back as the next seed assumes
        # pred_mask_size == input_seed_size for this model -- confirm.
        prob_seed, logit_seed = prob, logits

      if self.halt_signaler.is_halt(fetches=fetches, pos=pos,
                                    orig_pos=start_pos,
                                    counters=self.counters):
        logits[:] = np.float32(self.options.pad_value)

      start += self._pred_delta
      end = start + self._pred_size
      sel = self._box_slices(start, end)

      # Bias towards oversegmentation by making it impossible to reverse
      # disconnectedness predictions in the course of inference.
      if self.options.disco_seed_threshold >= 0:
        th_max = logit(0.5)
        old_seed = self.seed[sel]

        if self._keep_history:
          self.history_deleted.append(
              np.sum((old_seed >= logit(0.8)) & (logits < th_max)))

        if (np.mean(logits >= self.options.move_threshold) >
            self.options.disco_seed_threshold):
          # Because (x > NaN) is always False, this mask excludes positions that
          # were previously uninitialized (i.e. set to NaN in old_seed).
          try:
            old_err = np.seterr(invalid='ignore')
            mask = ((old_seed < th_max) & (logits > old_seed))
          finally:
            np.seterr(**old_err)
          logits[mask] = old_seed[mask]

      # Update working space.
      self.seed[sel] = logits

    return logits

  def init_seed(self, pos):
    """Reinitializes the object mask with a seed.

    All voxels are set to NaN (uninitialized) except `pos`, which receives
    the init_activation logit.

    Args:
      pos: position at which to place the seed (z, y, x)
    """
    self.seed[...] = np.nan
    self.seed[pos] = self.options.init_activation

  def segment_at(self, start_pos, dynamic_image=None,
                 vis_update_every=10,
                 vis_fixed_z=False):
    """Runs FFN segmentation starting from a specific point.

    Args:
      start_pos: location at which to run segmentation as (z, y, x)
      dynamic_image: optional DynamicImage object which to update with
          a visualization of the segmentation state
      vis_update_every: number of FFN iterations between subsequent
          updates of the dynamic image
      vis_fixed_z: if True, the z position used for visualization is
          fixed at the starting value specified in `pos`. Otherwise,
          the current FoV of the FFN is used to determine what to
          visualize.

    Returns:
      number of iterations performed
    """
    if self.reset_seed_per_segment:
      self.init_seed(start_pos)
    # This includes a reset of the movement policy, see comment in method body.
    self.reset_state(start_pos)

    num_iters = 0

    if not self.movement_policy:
      # Add first element with arbitrary priority 1 (it will be consumed
      # right away anyway).
      item = (self.movement_policy.score_threshold * 2, start_pos)
      self.movement_policy.append(item)

    with timer_counter(self.counters, 'segment_at-loop'):
      for pos in self.movement_policy:
        # Terminate early if the seed got too weak.
        if self.seed[start_pos] < self.options.move_threshold:
          self.counters['seed_got_too_weak'].Increment()
          break

        if not self.restrictor.is_valid_pos(pos):
          self.counters['skip_restriced_pos'].Increment()
          continue

        pred = self.update_at(pos, start_pos)
        self._min_pos = np.minimum(self._min_pos, pos)
        self._max_pos = np.maximum(self._max_pos, pos)
        num_iters += 1

        with timer_counter(self.counters, 'movement_policy'):
          self.movement_policy.update(pred, pos)

        with timer_counter(self.counters, 'segment_at-overhead'):
          if self._keep_history:
            self.history.append(pos)

          if dynamic_image is not None and num_iters % vis_update_every == 0:
            vis_pos = pos if not vis_fixed_z else (start_pos[0], pos[1],
                                                   pos[2])
            visualize_state(self.seed, vis_pos, self.movement_policy,
                            dynamic_image)

          assert np.all(pred.shape == self._pred_size)

          self._maybe_save_checkpoint()

    return num_iters

  def log_info(self, string, *args, **kwargs):
    """Logs at INFO level, prefixed with the executor client ID.

    The ID is formatted with %s so that messages are still rendered after the
    canvas has deregistered from the executor (client ID None).
    """
    logging.info('[cl %s] ' + string, self._exec_client_id,
                 *args, **kwargs)

  def segment_all(self, seed_policy=seed.PolicyPeaks):
    """Segments the input image.

    Segmentation is attempted from all valid starting points provided by
    the seed policy.

    Args:
      seed_policy: callable taking the image and the canvas object as arguments
          and returning an iterator over proposed seed point.
    """
    self.seed_policy = seed_policy(self)
    if self._seed_policy_state is not None:
      self.seed_policy.set_state(self._seed_policy_state)
      self._seed_policy_state = None

    with timer_counter(self.counters, 'segment_all'):
      mbd = self.options.min_boundary_dist
      mbd = np.array([mbd.z, mbd.y, mbd.x])

      for pos in TimedIter(self.seed_policy, self.counters, 'seed-policy'):
        # When starting a new segment the move_threshold on the probability
        # should be ignored when determining if the position is valid.
        if not (self.is_valid_pos(pos, ignore_move_threshold=True)
                and self.restrictor.is_valid_pos(pos)
                and self.restrictor.is_valid_seed(pos)):
          continue

        self._maybe_save_checkpoint()

        # Too close to an existing segment? The lower corner is clamped at 0
        # so that a negative index does not wrap around to the far end.
        low = np.maximum(np.array(pos) - mbd, 0)
        high = np.array(pos) + mbd + 1
        sel = self._box_slices(low, high)
        if np.any(self.segmentation[sel] > 0):
          logging.debug('Too close to existing segment.')
          self.segmentation[pos] = -1
          continue

        self.log_info('Starting segmentation at %r (zyx)', pos)

        # Try segmentation.
        seg_start = time.time()
        num_iters = self.segment_at(pos)
        t_seg = time.time() - seg_start

        # Check if segmentation was successful.
        if num_iters <= 0:
          self.counters['invalid-other-time-ms'].IncrementBy(t_seg *
                                                             MSEC_IN_SEC)
          self.log_info('Failed: num iters was %d', num_iters)
          continue

        # Original seed too weak?
        if self.seed[pos] < self.options.move_threshold:
          # Mark this location as excluded.
          if self.segmentation[pos] == 0:
            self.segmentation[pos] = -1
          self.log_info('Failed: weak seed')
          self.counters['invalid-weak-time-ms'].IncrementBy(t_seg * MSEC_IN_SEC)
          continue

        # Restrict probability map -> segment processing to a bounding box
        # covering the area that was actually changed by the FFN. In case the
        # segment is going to be rejected due to small size, this can
        # significantly reduce processing time.
        sel = tuple(slice(max(s, 0), e + 1) for s, e in zip(
            self._min_pos - self._pred_size // 2,
            self._max_pos + self._pred_size // 2))

        # We only allow creation of new segments in areas that are currently
        # empty.
        mask = self.seed[sel] >= self.options.segment_threshold
        raw_segmented_voxels = np.sum(mask)

        # Record existing segment IDs overlapped by the newly added object.
        overlapped_ids, counts = np.unique(self.segmentation[sel][mask],
                                           return_counts=True)
        valid = overlapped_ids > 0
        overlapped_ids = overlapped_ids[valid]
        counts = counts[valid]

        mask &= self.segmentation[sel] <= 0
        actual_segmented_voxels = np.sum(mask)

        # Segment too small?
        if actual_segmented_voxels < self.options.min_segment_size:
          if self.segmentation[pos] == 0:
            self.segmentation[pos] = -1
          self.log_info('Failed: too small: %d', actual_segmented_voxels)
          self.counters['invalid-small-time-ms'].IncrementBy(t_seg *
                                                             MSEC_IN_SEC)
          continue

        self.counters['voxels-segmented'].IncrementBy(actual_segmented_voxels)
        self.counters['voxels-overlapping'].IncrementBy(
            raw_segmented_voxels - actual_segmented_voxels)

        # Find the next free ID to assign.
        self._max_id += 1
        while self._max_id in self.origins:
          self._max_id += 1

        # `sel` is a tuple of slices, so these are views and the masked
        # assignments write through to the full arrays.
        self.segmentation[sel][mask] = self._max_id
        self.seg_prob[sel][mask] = storage.quantize_probability(
            expit(self.seed[sel][mask]))

        self.log_info('Created supervoxel:%d  seed(zyx):%s  size:%d  iters:%d',
                      self._max_id, pos,
                      actual_segmented_voxels, num_iters)

        self.overlaps[self._max_id] = np.array([overlapped_ids, counts])

        # Record information about how a given supervoxel was created.
        self.origins[self._max_id] = storage.OriginInfo(pos, num_iters, t_seg)
        self.counters['valid-time-ms'].IncrementBy(t_seg * MSEC_IN_SEC)

    self.log_info('Segmentation done.')

    # It is important to deregister ourselves when the segmentation is complete.
    # This matters particularly if less than a full batch of subvolumes remains
    # to be segmented. Without the deregistration, the executor will wait to
    # fill the full batch (no longer possible) instead of proceeding with
    # inference.
    self._deregister_client()

  def init_segmentation_from_volume(self, volume, corner, end,
                                    align_and_crop=None):
    """Initializes segmentation from an existing VolumeStore.

    This is useful to start inference with an existing segmentation.
    The segmentation does not need to be generated with an FFN.

    Args:
      volume: volume object, as returned by storage.decorated_volume.
      corner: location at which to read data as (z, y, x)
      end: location at which to stop reading data as (z, y, x)
      align_and_crop: callable to align & crop a 3d segmentation subvolume
    """
    self.log_info('Loading initial segmentation from (zyx) %r:%r',
                  corner, end)

    init_seg = volume[:,  #
                      corner[0]:end[0],  #
                      corner[1]:end[1],  #
                      corner[2]:end[2]]

    init_seg, global_to_local = segmentation.make_labels_contiguous(init_seg)
    init_seg = init_seg[0, ...]

    self.global_to_local_ids = dict(global_to_local)

    self.log_info('Segmentation loaded, shape: %r. Canvas segmentation is %r',
                  init_seg.shape, self.segmentation.shape)
    if align_and_crop is not None:
      init_seg = align_and_crop(init_seg)
      self.log_info('Segmentation cropped to: %r', init_seg.shape)

    self.segmentation[:] = init_seg
    self.seg_prob[self.segmentation > 0] = storage.quantize_probability(
        np.array([1.0]))
    # Plain int, so that later ID arithmetic and dict keys are Python ints.
    self._max_id = int(np.max(self.segmentation))
    self.log_info('Max restored ID is: %d.', self._max_id)

  def restore_checkpoint(self, path):
    """Restores state from the checkpoint at `path`.

    Args:
      path: path of a checkpoint file written by save_checkpoint()
    """
    self.log_info('Restoring inference checkpoint: %s', path)
    with gfile.Open(path, 'rb') as f:
      # The checkpoint stores Python objects (e.g. the `origins` and
      # `overlaps` dicts) as pickled object arrays, which np.load refuses to
      # read unless allow_pickle=True. Unpickling can execute arbitrary code:
      # only restore checkpoints from trusted locations.
      data = np.load(f, allow_pickle=True)

      self.segmentation[:] = data['segmentation']
      self.seed[:] = data['seed']
      self.seg_prob[:] = data['seg_qprob']
      self.history_deleted = list(data['history_deleted'])
      self.history = list(data['history'])
      self.origins = data['origins'].item()
      if 'overlaps' in data:
        self.overlaps = data['overlaps'].item()

      segmented_voxels = np.sum(self.segmentation != 0)
      self.counters['voxels-segmented'].Set(segmented_voxels)
      self._max_id = int(np.max(self.segmentation))

      self.movement_policy.restore_state(data['movement_policy'])

      seed_policy_state = data['seed_policy_state']
      # When restoring the state of a previously unused Canvas, the seed
      # policy will not be defined. We just save the seed policy state here
      # for future use in .segment_all().
      self._seed_policy_state = seed_policy_state

      self.counters.loads(data['counters'].item())

    self.log_info('Inference checkpoint restored.')

  def save_checkpoint(self, path):
    """Saves an inference checkpoint to `path` (compressed .npz format).

    Args:
      path: destination file path; its parent directory is created if needed
    """
    self.log_info('Saving inference checkpoint to %s.', path)
    with timer_counter(self.counters, 'save_checkpoint'):
      checkpoint_dir = os.path.dirname(path)
      # dirname() is empty for a bare file name; nothing to create then.
      if checkpoint_dir:
        gfile.MakeDirs(checkpoint_dir)
      with storage.atomic_file(path) as fd:
        seed_policy_state = None
        if self.seed_policy is not None:
          seed_policy_state = self.seed_policy.get_state()

        np.savez_compressed(fd,
                            movement_policy=self.movement_policy.get_state(),
                            segmentation=self.segmentation,
                            seg_qprob=self.seg_prob,
                            seed=self.seed,
                            origins=self.origins,
                            overlaps=self.overlaps,
                            history=np.array(self.history),
                            history_deleted=np.array(self.history_deleted),
                            seed_policy_state=seed_policy_state,
                            counters=self.counters.dumps())
    self.log_info('Inference checkpoint saved.')

  def _maybe_save_checkpoint(self):
    """Attempts to save a checkpoint.

    A checkpoint is only saved if the canvas is configured to keep checkpoints
    and if sufficient time has passed since the last one was saved.
    """
    if self.checkpoint_path is None or self.checkpoint_interval_sec <= 0:
      return

    if time.time() - self.checkpoint_last < self.checkpoint_interval_sec:
      return

    self.save_checkpoint(self.checkpoint_path)
    self.checkpoint_last = time.time()


class Runner(object):
  """Helper for managing FFN inference runs.

  Takes care of initializing the FFN model and any related functionality
  (e.g. movement policies), as well as input/output of the FFN inference
  data (loading inputs, saving segmentations).
  """

  # Sentinel returned by `make_restrictor` when the masks exclude the whole
  # subvolume (or all of its seeds), meaning no inference should be run.
  ALL_MASKED = 1

  def __init__(self):
    self.counters = inference_utils.Counters()
    self.executor = None

  def __del__(self):
    self.stop_executor()

  def stop_executor(self):
    """Shuts down the executor.

    No-op when no executor is active. This also covers the case where
    `__init__` did not get far enough to set `self.executor`, which matters
    because this method is called from `__del__`.
    """
    active = getattr(self, 'executor', None)
    if active is not None:
      active.stop_server()
      self.executor = None

  def _load_model_checkpoint(self, checkpoint_path):
    """Restores the inference model from a training checkpoint.

    Args:
      checkpoint_path: the string path to the checkpoint file to load
    """
    with timer_counter(self.counters, 'restore-tf-checkpoint'):
      logging.info('Loading checkpoint.')
      self.model.saver.restore(self.session, checkpoint_path)
      logging.info('Checkpoint loaded.')

  def start(self, request, batch_size=1, exec_cls=None, session=None):
    """Opens input volumes and initializes the FFN.

    Args:
      request: inference request proto; it must have a non-empty
          `segmentation_output_dir`
      batch_size: batch size passed to the model constructor and executor
      exec_cls: executor class to instantiate; defaults to
          `executor.ThreadingBatchExecutor`
      session: TF session to use; when None, the default graph is reset and
          a new session is created

    Raises:
      NotImplementedError: if an alignment type other than NO_ALIGNMENT is
          requested.
    """
    self.request = request
    assert self.request.segmentation_output_dir

    logging.debug('Received request:\n%s', request)

    if not gfile.Exists(request.segmentation_output_dir):
      gfile.MakeDirs(request.segmentation_output_dir)

    with timer_counter(self.counters, 'volstore-open'):
      # Disabling cache compression can improve access times by 20-30%
      # as of Aug 2016.
      self._image_volume = storage.decorated_volume(
          request.image, cache_max_bytes=int(1e8),
          cache_compression=False)
      assert self._image_volume is not None

      if request.HasField('init_segmentation'):
        self.init_seg_volume = storage.decorated_volume(
            request.init_segmentation, cache_max_bytes=int(1e8))
      else:
        self.init_seg_volume = None

      alignment_options = request.alignment_options
      null_alignment = inference_pb2.AlignmentOptions.NO_ALIGNMENT
      if not alignment_options or alignment_options.type == null_alignment:
        self._aligner = align.Aligner()
      else:
        type_name = inference_pb2.AlignmentOptions.AlignType.Name(
            alignment_options.type)
        error_string = 'Alignment for type %s is not implemented' % type_name
        logging.error(error_string)
        raise NotImplementedError(error_string)

      def _open_or_none(settings):
        """Opens the volume described by `settings`, or None if unset."""
        if settings.WhichOneof('volume_path') is None:
          return None
        return storage.decorated_volume(
            settings, cache_max_bytes=int(1e7), cache_compression=False)

      # Cache of mask volumes, filled lazily by storage.build_mask().
      self._mask_volumes = {}
      # Opened exactly once (this used to be duplicated, opening it twice).
      self._shift_mask_volume = _open_or_none(request.shift_mask)

    if request.reference_histogram:
      with gfile.Open(request.reference_histogram, 'rb') as f:
        data = np.load(f)
        self._reference_lut = data['lut']
    else:
      self._reference_lut = None

    self.stop_executor()

    if session is None:
      config = tf.ConfigProto()
      tf.reset_default_graph()
      session = tf.Session(config=config)
    self.session = session
    logging.info('Available TF devices: %r', self.session.list_devices())

    # Initialize the FFN model.
    model_class = import_symbol(request.model_name)
    if request.model_args:
      args = json.loads(request.model_args)
    else:
      args = {}

    args['batch_size'] = batch_size
    self.model = model_class(**args)

    if exec_cls is None:
      exec_cls = executor.ThreadingBatchExecutor

    self.executor = exec_cls(
        self.model, self.session, self.counters, batch_size)
    self.movement_policy_fn = movement.get_policy_fn(request, self.model)

    self.saver = tf.train.Saver()
    self._load_model_checkpoint(request.model_checkpoint_path)

    self.executor.start_server()

  def make_restrictor(self, corner, subvol_size, image, alignment):
    """Builds a MovementRestrictor object.

    Args:
      corner: start of the (aligned) subvolume (z, y, x)
      subvol_size: size of the (aligned) subvolume (z, y, x)
      image: image data of the subvolume, forwarded to storage.build_mask
      alignment: Alignment object used for the subvolume

    Returns:
      `self.ALL_MASKED` if the masks exclude everything or all seeds,
      a `movement.MovementRestrictor` if any restriction applies, and None
      otherwise.
    """
    kwargs = {}

    if self.request.masks:
      with timer_counter(self.counters, 'load-mask'):
        final_mask = storage.build_mask(self.request.masks,
                                        corner, subvol_size,
                                        self._mask_volumes,
                                        image, alignment)

        if np.all(final_mask):
          logging.info('Everything masked.')
          return self.ALL_MASKED

        kwargs['mask'] = final_mask

    if self.request.seed_masks:
      with timer_counter(self.counters, 'load-seed-mask'):
        seed_mask = storage.build_mask(self.request.seed_masks,
                                       corner, subvol_size,
                                       self._mask_volumes,
                                       image, alignment)

        if np.all(seed_mask):
          logging.info('All seeds masked.')
          return self.ALL_MASKED

        kwargs['seed_mask'] = seed_mask

    if self._shift_mask_volume:
      with timer_counter(self.counters, 'load-shift-mask'):
        s = self.request.shift_mask_scale
        # The shift mask is downsampled by `s` in y and x only.
        shift_corner = np.array(corner) // (1, s, s)
        # Ceiling division, so the downsampled box covers the full subvolume.
        shift_size = -(-np.array(subvol_size) // (1, s, s))

        shift_alignment = alignment.rescaled(
            np.array((1.0, 1.0, 1.0)) / (1, s, s))
        src_corner, src_size = shift_alignment.expand_bounds(
            shift_corner, shift_size, forward=False)
        src_corner, src_size = storage.clip_subvolume_to_bounds(
            src_corner, src_size, self._shift_mask_volume)
        src_end = src_corner + src_size

        expanded_shift_mask = self._shift_mask_volume[
            0:2,  #
            src_corner[0]:src_end[0],  #
            src_corner[1]:src_end[1],  #
            src_corner[2]:src_end[2]]
        shift_mask = np.array([
            shift_alignment.align_and_crop(
                src_corner, expanded_shift_mask[i], shift_corner, shift_size)
            for i in range(2)])
        shift_mask = alignment.transform_shift_mask(corner, s, shift_mask)

        if self.request.HasField('shift_mask_fov'):
          shift_mask_fov = bounding_box.BoundingBox(
              start=self.request.shift_mask_fov.start,
              size=self.request.shift_mask_fov.size)
        else:
          shift_mask_diameter = np.array(self.model.input_image_size)
          shift_mask_fov = bounding_box.BoundingBox(
              start=-(shift_mask_diameter // 2), size=shift_mask_diameter)

        kwargs.update({
            'shift_mask': shift_mask,
            'shift_mask_fov': shift_mask_fov,
            'shift_mask_scale': self.request.shift_mask_scale,
            'shift_mask_threshold': self.request.shift_mask_threshold})

    if kwargs:
      return movement.MovementRestrictor(**kwargs)
    else:
      return None

  def make_canvas(self, corner, subvol_size, **canvas_kwargs):
    """Builds the Canvas object for inference on a subvolume.

    Args:
      corner: start of the subvolume (z, y, x)
      subvol_size: size of the subvolume (z, y, x)
      **canvas_kwargs: passed to Canvas

    Returns:
      A tuple of:
        Canvas object
        Alignment object
      or (None, None) if the subvolume is fully masked or its histogram
      could not be matched.
    """
    subvol_counters = self.counters.get_sub_counters()
    with timer_counter(subvol_counters, 'load-image'):
      logging.info('Process subvolume: %r', corner)

      # A Subvolume with bounds defined by (src_size, src_corner) is guaranteed
      # to result in no missing data when aligned to (dst_size, dst_corner).
      # Likewise, one defined by (dst_size, dst_corner) is guaranteed to result
      # in no missing data when reverse-aligned to (corner, subvol_size).
      alignment = self._aligner.generate_alignment(corner, subvol_size)

      # Bounding box for the aligned destination subvolume.
      dst_corner, dst_size = alignment.expand_bounds(
          corner, subvol_size, forward=True)
      # Bounding box for the pre-aligned imageset to be fetched from the volume.
      src_corner, src_size = alignment.expand_bounds(
          dst_corner, dst_size, forward=False)
      # Ensure that the request bounds don't extend beyond volume bounds.
      src_corner, src_size = storage.clip_subvolume_to_bounds(
          src_corner, src_size, self._image_volume)

      logging.info('Requested bounds are %r + %r', corner, subvol_size)
      logging.info('Destination bounds are %r + %r', dst_corner, dst_size)
      logging.info('Fetch bounds are %r + %r', src_corner, src_size)

      # Fetch the image from the volume using the src bounding box.
      def get_data_3d(volume, bbox):
        """Reads `bbox` from `volume`, taking channel 0 for 4-d volumes."""
        slc = bbox.to_slice()
        if volume.ndim == 4:
          slc = np.index_exp[0:1] + slc
        data = volume[slc]
        if data.ndim == 4:
          data = data.squeeze(axis=0)
        return data
      # BoundingBox takes (x, y, z) order, hence the reversals.
      src_bbox = bounding_box.BoundingBox(
          start=src_corner[::-1], size=src_size[::-1])
      src_image = get_data_3d(self._image_volume, src_bbox)
      logging.info('Fetched image of size %r prior to transform',
                   src_image.shape)

      def align_and_crop(image):
        return alignment.align_and_crop(src_corner, image, dst_corner, dst_size,
                                        forward=True)

      # Align and crop to the dst bounding box.
      image = align_and_crop(src_image)
      # image now has corner dst_corner and size dst_size.

      logging.info('Image data loaded, shape: %r.', image.shape)

    restrictor = self.make_restrictor(dst_corner, dst_size, image, alignment)
    # Bail out before the (comparatively expensive) histogram matching and
    # normalization, whose results would be discarded anyway.
    if restrictor == self.ALL_MASKED:
      return None, None

    try:
      if self._reference_lut is not None:
        if self.request.histogram_masks:
          histogram_mask = storage.build_mask(self.request.histogram_masks,
                                              dst_corner, dst_size,
                                              self._mask_volumes,
                                              image, alignment)
        else:
          histogram_mask = None

        inference_utils.match_histogram(image, self._reference_lut,
                                        mask=histogram_mask)
    except ValueError as e:
      # This can happen if the subvolume is relatively small because of tiling
      # done by CLAHE. For now we just ignore these subvolumes.
      # TODO(mjanusz): Handle these cases by reducing the number of tiles.
      logging.info('Could not match histogram: %r', e)
      return None, None

    image = (image.astype(np.float32) -
             self.request.image_mean) / self.request.image_stddev

    if self.request.HasField('self_prediction'):
      halt_signaler = self_prediction_halt(
          self.request.self_prediction.threshold,
          orig_threshold=self.request.self_prediction.orig_threshold,
          verbosity=PRINT_HALTS)
    else:
      halt_signaler = no_halt()

    canvas = Canvas(
        self.model,
        self.executor,
        image,
        self.request.inference_options,
        counters=subvol_counters,
        restrictor=restrictor,
        movement_policy_fn=self.movement_policy_fn,
        halt_signaler=halt_signaler,
        checkpoint_path=storage.checkpoint_path(
            self.request.segmentation_output_dir, corner),
        checkpoint_interval_sec=self.request.checkpoint_interval,
        corner_zyx=dst_corner,
        **canvas_kwargs)

    if self.request.HasField('init_segmentation'):
      canvas.init_segmentation_from_volume(self.init_seg_volume, src_corner,
                                           src_bbox.end[::-1], align_and_crop)
    return canvas, alignment

  def get_seed_policy(self, corner, subvol_size):
    """Get seed policy generating callable.

    Args:
      corner: the original corner of the requested subvolume, before any
          modification e.g. dynamic alignment.
      subvol_size: the original requested size.

    Returns:
      A callable for generating seed policies.
    """
    policy_cls = getattr(seed, self.request.seed_policy)
    kwargs = {'corner': corner, 'subvol_size': subvol_size}
    if self.request.seed_policy_args:
      kwargs.update(json.loads(self.request.seed_policy_args))
    return functools.partial(policy_cls, **kwargs)

  def save_segmentation(self, canvas, alignment, target_path, prob_path):
    """Saves segmentation to a file.

    Note that negative marker values in `canvas.segmentation` are reset to 0
    in place.

    Args:
      canvas: Canvas object containing the segmentation
      alignment: the local Alignment used with the canvas, or None
      target_path: path to the file where the segmentation should
          be saved
      prob_path: path to the file where the segmentation probability
          map should be saved
    """
    def unalign_image(im3d):
      if alignment is None:
        return im3d
      return alignment.align_and_crop(
          canvas.corner_zyx,
          im3d,
          alignment.corner,
          alignment.size,
          forward=False)

    def unalign_origins(origins, canvas_corner):
      # Without an alignment the origins are already in the right frame;
      # previously this path crashed on `None.transform`.
      if alignment is None:
        return dict(origins)
      out_origins = dict()
      for key, value in origins.items():
        # Transform in global coordinates, then make canvas-relative again.
        zyx = np.array(value.start_zyx) + canvas_corner
        zyx = alignment.transform(zyx[:, np.newaxis], forward=False).squeeze()
        zyx -= canvas_corner
        out_origins[key] = value._replace(start_zyx=tuple(zyx))
      return out_origins

    # Remove markers.
    canvas.segmentation[canvas.segmentation < 0] = 0

    # Save segmentation results. Reduce # of bits per item if possible.
    storage.save_subvolume(
        unalign_image(canvas.segmentation),
        unalign_origins(canvas.origins, np.array(canvas.corner_zyx)),
        target_path,
        request=self.request.SerializeToString(),
        counters=canvas.counters.dumps(),
        overlaps=canvas.overlaps)

    # Save probability map separately. This has to happen after the
    # segmentation is saved, as `save_subvolume` will create any necessary
    # directories.
    prob = unalign_image(canvas.seg_prob)
    with storage.atomic_file(prob_path) as fd:
      np.savez_compressed(fd, qprob=prob)

  def run(self, corner, subvol_size, reset_counters=True):
    """Runs FFN inference over a subvolume.

    Args:
      corner: start of the subvolume (z, y, x)
      subvol_size: size of the subvolume (z, y, x)
      reset_counters: whether to reset the counters

    Returns:
      Canvas object with the segmentation or None if the canvas could not
      be created or the segmentation subvolume already exists.
    """
    if reset_counters:
      self.counters.reset()

    seg_path = storage.segmentation_path(
        self.request.segmentation_output_dir, corner)
    prob_path = storage.object_prob_path(
        self.request.segmentation_output_dir, corner)
    cpoint_path = storage.checkpoint_path(
        self.request.segmentation_output_dir, corner)

    if gfile.Exists(seg_path):
      return None

    canvas, alignment = self.make_canvas(corner, subvol_size)
    if canvas is None:
      return None

    # Resume from an earlier, interrupted run of this subvolume.
    if gfile.Exists(cpoint_path):
      canvas.restore_checkpoint(cpoint_path)

    if self.request.alignment_options.save_raw:
      image_path = storage.subvolume_path(self.request.segmentation_output_dir,
                                          corner, 'align')
      with storage.atomic_file(image_path) as fd:
        np.savez_compressed(fd, im=canvas.image)

    canvas.segment_all(seed_policy=self.get_seed_policy(corner, subvol_size))
    self.save_segmentation(canvas, alignment, seg_path, prob_path)

    # Attempt to remove the checkpoint file now that we no longer need it.
    # Best effort: the file may never have been written, so a failure is
    # logged instead of raised. KeyboardInterrupt/SystemExit still propagate.
    try:
      gfile.Remove(cpoint_path)
    except Exception as e:  # pylint: disable=broad-except
      logging.debug('Could not remove checkpoint %s: %r', cpoint_path, e)

    return canvas