# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=line-too-long
r"""Binary for training and evaluating a model."""
# pylint: enable=line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math
import os
import time

import numpy as np
import tensorflow as tf
from typing import Dict, List, Tuple

# pylint: disable=g-bad-import-order
from isl import augment
from isl import controller
from isl import data_provider
from isl import infer
from isl import util
from isl.models import concordance
from isl.models import model_util

# Short aliases for the TensorFlow 1.x / tf.contrib namespaces used below.
slim = tf.contrib.slim
metrics = tf.contrib.metrics
app = tf.app
logging = tf.logging
flags = tf.flags
gfile = tf.gfile
lt = tf.contrib.labeled_tensor

# Valid values for the --mode flag: what this binary does when run.
MODE_TRAIN = 'TRAIN'
MODE_EVAL_TRAIN = 'EVAL_TRAIN'
MODE_EVAL_EVAL = 'EVAL_EVAL'
MODE_EXPORT = 'EXPORT'
flags.DEFINE_string('mode', MODE_TRAIN, 'What this binary will do.')

# Valid values for the --metric flag: what evaluation computes/displays.
METRIC_LOSS = 'LOSS'
METRIC_JITTER_STITCH = 'JITTER_STITCH'
METRIC_STITCH = 'STITCH'
METRIC_INFER_FULL = 'INFER_FULL'
flags.DEFINE_string('metric', METRIC_LOSS, 'What this binary will display.')

# Cluster, filesystem, checkpointing, and evaluation-cadence flags.
flags.DEFINE_string('master', 'local',
                    'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('base_directory', '/tmp/minception/',
                    'Directory where model checkpoints are written.')
flags.DEFINE_string('export_directory', '/tmp/minception_export/',
                    'Directory where exported model is written.')
flags.DEFINE_integer(
    'save_summaries_secs', 180,
    'The frequency with which summaries are saved, in seconds.')
flags.DEFINE_integer('save_interval_secs', 180,
                     'The frequency with which the model is saved, in seconds.')
flags.DEFINE_integer('eval_interval_secs', 15,
                     'The frequency, in seconds, with which evaluation is run.')
flags.DEFINE_integer('eval_delay_secs', 0,
                     'The time to wait before starting evaluations.')
flags.DEFINE_integer(
    'metric_num_examples', 1 << 10,
    'The number of examples to use when computing tf.slim metrics.')
flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')
flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')
flags.DEFINE_string(
    'restore_directory', '',
    'If provided, the directory from which to restore a model checkpoint for '
    'training or exporting.')

# Valid values for the --optimizer flag.
OPTIMIZER_MOMENTUM = 'MOMENTUM'
OPTIMIZER_ADAGRAD = 'ADAGRAD'
OPTIMIZER_ADAM = 'ADAM'
flags.DEFINE_string('optimizer', 'ADAM', 'The train optimizer.')

flags.DEFINE_float('learning_rate', 1e-4, 'The learning rate.')

flags.DEFINE_integer(
    'learning_decay_steps', 1 << 12,
    'The learning decay steps, used by the MOMENTUM optimizer.')

# Input-source flags. --read_pngs selects between reading PNG folders and
# reading a sharded RecordIO/SSTable dataset; the remaining flags in this
# section apply to one input mode or the other, as noted.
flags.DEFINE_bool(
    'read_pngs', False,
    'Whether to read the input images from a provided folder rather than '
    'from a RecordIO or SSTable.')

# Parameters for when read_pngs is True.
flags.DEFINE_string('dataset_train_directory', None,
                    'If read_pngs, the directory containing the train dataset.')
flags.DEFINE_string(
    'dataset_eval_directory', None,
    'If read_pngs, the directory containing the evaluation dataset.')

# Parameters for when read_pngs is False.
flags.DEFINE_string(
    'dataset_pattern', None,
    'If not read_pngs, format string giving the dataset location. '
    'It will be subdivided into train and eval sets.')
flags.DEFINE_integer('dataset_num_shards', 1024,
                     'If not read_pngs, the number of shards in the dataset.')
flags.DEFINE_bool(
    'is_recordio', True,
    'If not read_pngs, whether the dataset is stored as a RecordIO, '
    'else an SSTable.')
flags.DEFINE_integer(
    'data_batch_size', 4,
    'If not read_pngs, batch size for first part of preprocessing.')
flags.DEFINE_integer(
    'data_batch_num_threads', 4,
    'If not read_pngs, number of threads loading data from disk.')
flags.DEFINE_integer(
    'data_batch_capacity', 8,
    'If not read_pngs, batch capacity for threads loading data from disk.')

# Cropping, inference, augmentation, and preprocessing flags.
flags.DEFINE_integer('loss_crop_size', 520, 'Image crop size for training.')
# NOTE(review): help text is blank upstream; this value is used as the patch
# stride when computing the loss metric (see parameters()).
flags.DEFINE_integer('loss_patch_stride', 256, '')
flags.DEFINE_integer('stitch_crop_size', 500, 'Image crop size for stitching.')
flags.DEFINE_integer(
    'infer_size', 16,
    'The number of inferences to do in parallel in each row x column dimension.'
    ' For example, a size of 16 will do 16 x 16 = 256 inferences in parallel.')
flags.DEFINE_bool('infer_continuously', False,
                  'Whether to run inference in a while loop.')
flags.DEFINE_string('infer_channel_whitelist', None,
                    'If provided, the channels to whitelist for inference.')
flags.DEFINE_bool('infer_simplify_error_panels', True,
                  'Whether to simplify the error panels.')

flags.DEFINE_float('augment_offset_std', 0.0,
                   'Augmentation noise corruption parameter.')
flags.DEFINE_float('augment_multiplier_std', 0.0,
                   'Augmentation noise corruption parameter.')
flags.DEFINE_float('augment_noise_std', 0.0,
                   'Augmentation noise corruption parameter.')

flags.DEFINE_integer('preprocess_batch_size', 16, 'Batch size for the model.')
flags.DEFINE_integer(
    'preprocess_shuffle_batch_num_threads', 16,
    'Number of threads doing the second half of preprocessing during training.')
flags.DEFINE_integer('preprocess_batch_capacity', 64,
                     'Batch capacity for second half of preprocessing.')

flags.DEFINE_bool(
    'train_on_full_dataset', False,
    'If true, train on the full dataset, not the subset used for training in '
    'the training / evaluation split. Useful for getting the last bit of '
    'performance out of a model we trust.')

# Valid values for the --loss flag.
LOSS_CROSS_ENTROPY = 'CROSS_ENTROPY'
LOSS_RANKED_PROBABILITY_SCORE = 'RANKED_PROBABILITY_SCORE'
flags.DEFINE_string('loss', LOSS_CROSS_ENTROPY, 'The loss to use.')

# Valid values for the --model flag.
MODEL_CONCORDANCE = 'CONCORDANCE'
flags.DEFINE_string('model', MODEL_CONCORDANCE, 'The network model to use.')

flags.DEFINE_integer('base_depth', 400, 'Model parameter.')

flags.DEFINE_bool(
    'restore_logits', True,
    'Whether to restore the heads when resuming training. Set to False if you '
    'want to add or remove heads but restore the rest of the network.')
flags.DEFINE_bool(
    'restore_inputs', True,
    'Whether to restore the input layers when resuming training. Set to False '
    'if you are restoring from another model checkpoint where the input '
    'dimension does not match the input dimension of this network.  For '
    'instance, the DeepLab pretrained models only have three channels for RGB.')

flags.DEFINE_integer('num_z_values', 26,
                     'Number of z depths to use from input.')

FLAGS = flags.FLAGS


def get_z_values() -> List[float]:
  """Returns the z-values the model will take as input.

  The values are evenly spaced in [0, 1]; their count is controlled by
  --num_z_values, and each is rounded to four decimal places.
  """
  z_values = [round(z, 4) for z in np.linspace(0.0, 1.0, FLAGS.num_z_values)]
  logging.info('z_values: %r', z_values)
  return z_values


# Channel names accepted as model inputs (see data_parameters()).
INPUT_CHANNEL_VALUES = [
    'BRIGHTFIELD',
    'PHASE_CONTRAST',
    'DIC',
]
# Z-value labels for prediction targets.
TARGET_Z_VALUES = ['MAXPROJECT']
# Channel names the model predicts.
TARGET_CHANNEL_VALUES = [
    'DAPI_CONFOCAL',
    'DAPI_WIDEFIELD',
    'CELLMASK_CONFOCAL',
    'TUJ1_WIDEFIELD',
    'NFH_CONFOCAL',
    'MAP2_CONFOCAL',
    'ISLET_WIDEFIELD',
    'DEAD_CONFOCAL',
]

# The size of the extracted input patches.
CONCORDANCE_EXTRACT_PATCH_SIZE = 250
# The size of the model output.
CONCORDANCE_STITCH_PATCH_SIZE = 8
# The stride to use when stitching.
# It is almost always equal to STITCH_PATCH_SIZE except when debugging.
CONCORDANCE_STITCH_STRIDE = 8

# The number of classes into which to bin pixel values.
NUM_CLASSES = 256


def data_parameters() -> data_provider.DataParameters:
  """Creates the DataParameters from the command-line flags.

  Returns:
    A DataParameters describing how to read, pad, and crop the input data
    for the current mode, metric, and model.

  Raises:
    NotImplementedError: If --model is unsupported.
  """
  # The crop size depends only on the metric, not the input format, so
  # compute it once here instead of duplicating the branch in both arms below.
  if FLAGS.metric == METRIC_LOSS:
    crop_size = FLAGS.loss_crop_size
  else:
    crop_size = FLAGS.stitch_crop_size

  if FLAGS.read_pngs:
    if FLAGS.mode == MODE_TRAIN or FLAGS.mode == MODE_EVAL_TRAIN:
      directory = FLAGS.dataset_train_directory
    else:
      directory = FLAGS.dataset_eval_directory

    io_parameters = data_provider.ReadPNGsParameters(directory, None, None,
                                                     crop_size)
  else:
    # Use an eighth of the dataset for validation.
    if FLAGS.mode == MODE_TRAIN or FLAGS.mode == MODE_EVAL_TRAIN:
      dataset = [
          FLAGS.dataset_pattern % i
          for i in range(FLAGS.dataset_num_shards)
          if (i % 8 != 0) or FLAGS.train_on_full_dataset
      ]
    else:
      dataset = [
          FLAGS.dataset_pattern % i
          for i in range(FLAGS.dataset_num_shards)
          if i % 8 == 0
      ]

    if FLAGS.model == MODEL_CONCORDANCE:
      extract_patch_size = CONCORDANCE_EXTRACT_PATCH_SIZE
      stitch_patch_size = CONCORDANCE_STITCH_PATCH_SIZE
    else:
      raise NotImplementedError('Unsupported model: %s' % FLAGS.model)

    if FLAGS.mode == MODE_EXPORT:
      # Any padding will be done by the C++ caller.
      pad_width = 0
    else:
      pad_width = (extract_patch_size - stitch_patch_size) // 2

    io_parameters = data_provider.ReadTableParameters(
        dataset,
        FLAGS.is_recordio,
        util.BatchParameters(FLAGS.data_batch_size,
                             FLAGS.data_batch_num_threads,
                             FLAGS.data_batch_capacity),
        # Do non-deterministic data fetching, to increase the variety of what we
        # see in the visualizer.
        False,
        pad_width,
        crop_size)

  z_values = get_z_values()
  return data_provider.DataParameters(io_parameters, z_values,
                                      INPUT_CHANNEL_VALUES, TARGET_Z_VALUES,
                                      TARGET_CHANNEL_VALUES)


def parameters() -> controller.GetInputTargetAndPredictedParameters:
  """Creates the network parameters for the given inputs and flags.

  Returns:
    A GetInputTargetAndPredictedParameters containing network parameters for the
    given mode, metric, and other flags.

  Raises:
    NotImplementedError: If --model or --loss is unsupported.
  """
  if FLAGS.metric == METRIC_LOSS:
    stride = FLAGS.loss_patch_stride
    shuffle = True
  else:
    if FLAGS.model == MODEL_CONCORDANCE:
      stride = CONCORDANCE_STITCH_STRIDE
    else:
      raise NotImplementedError('Unsupported model: %s' % FLAGS.model)
    # Shuffling breaks stitching.
    shuffle = False

  if FLAGS.mode == MODE_TRAIN:
    is_train = True
  else:
    is_train = False

  if FLAGS.model == MODEL_CONCORDANCE:
    core_model = functools.partial(concordance.core, FLAGS.base_depth)
    add_head = functools.partial(model_util.add_head, is_residual_conv=True)
    extract_patch_size = CONCORDANCE_EXTRACT_PATCH_SIZE
    stitch_patch_size = CONCORDANCE_STITCH_PATCH_SIZE
  else:
    raise NotImplementedError('Unsupported model: %s' % FLAGS.model)

  dp = data_parameters()

  if shuffle:
    preprocess_num_threads = FLAGS.preprocess_shuffle_batch_num_threads
  else:
    # Thread racing is an additional source of shuffling, so we can only
    # use 1 thread per queue.
    preprocess_num_threads = 1

  # Augment during training, and when jitter-stitching for evaluation.
  if is_train or FLAGS.metric == METRIC_JITTER_STITCH:
    ap = augment.AugmentParameters(FLAGS.augment_offset_std,
                                   FLAGS.augment_multiplier_std,
                                   FLAGS.augment_noise_std)
  else:
    ap = None

  if FLAGS.metric == METRIC_INFER_FULL:
    bp = None
  else:
    bp = util.BatchParameters(FLAGS.preprocess_batch_size,
                              preprocess_num_threads,
                              FLAGS.preprocess_batch_capacity)

  if FLAGS.loss == LOSS_CROSS_ENTROPY:
    loss = util.softmax_cross_entropy
  elif FLAGS.loss == LOSS_RANKED_PROBABILITY_SCORE:
    loss = util.ranked_probability_score
  else:
    # Raise instead of logging.fatal: this matches the other unsupported-flag
    # branches above, and guarantees `loss` can never be left unbound if the
    # logging call does not terminate the process.
    raise NotImplementedError('Unsupported loss: %s' % FLAGS.loss)

  return controller.GetInputTargetAndPredictedParameters(
      dp, ap, extract_patch_size, stride, stitch_patch_size, bp, core_model,
      add_head, shuffle, NUM_CLASSES, loss, is_train)


def train_directory() -> str:
  """Returns the directory where the training data is written.

  All training output lives in the 'train' subdirectory of --base_directory.
  """
  return os.path.join(FLAGS.base_directory, 'train')


def output_directory() -> str:
  """Returns the output directory for the current invocation of this binary.

  Training writes to the train directory; evaluation writes to a directory
  named after the dataset split being evaluated and the active metric.
  """
  if FLAGS.mode == MODE_TRAIN:
    return train_directory()

  # Evaluation: the split determines the prefix, the metric the suffix.
  prefix = 'eval_train_' if FLAGS.mode == MODE_EVAL_TRAIN else 'eval_eval_'

  if FLAGS.metric == METRIC_INFER_FULL:
    suffix = 'infer'
  elif FLAGS.metric == METRIC_LOSS:
    suffix = 'loss_' + FLAGS.loss
  elif FLAGS.metric == METRIC_JITTER_STITCH:
    suffix = 'jitter_stitch'
  else:
    suffix = 'stitch'

  return os.path.join(FLAGS.base_directory, prefix + suffix)


def total_loss(
    gitapp: controller.GetInputTargetAndPredictedParameters,
) -> Tuple[tf.Tensor, Dict[str, lt.LabeledTensor], Dict[str, lt.LabeledTensor]]:
  """Gets the total weighted training loss.

  Returns:
    A tuple of the scalar total-loss op, the per-name input losses, and the
    per-name target losses.
  """
  input_loss_lts, target_loss_lts = controller.setup_losses(gitapp)

  def _mean_loss(loss_lts: Dict[str, lt.LabeledTensor]) -> tf.Tensor:
    # Arithmetic mean of the underlying loss tensors.
    tensors = [loss_lt.tensor for loss_lt in loss_lts.values()]
    return tf.add_n(tensors) / float(len(tensors))

  # Give the input loss the same weight as the target loss.
  input_weight = 0.5
  total_loss_op = (input_weight * _mean_loss(input_loss_lts) +
                   (1 - input_weight) * _mean_loss(target_loss_lts))
  tf.summary.scalar('total_loss', total_loss_op)

  return total_loss_op, input_loss_lts, target_loss_lts


def log_entry_points(g: tf.Graph):
  """Logs the names of all graph operations tagged as entry points."""
  entry_point_names = [
      op.name for op in g.get_operations() if 'entry_point' in op.name
  ]
  logging.info('Entry points: %s', entry_point_names)


def train(gitapp: controller.GetInputTargetAndPredictedParameters):
  """Trains a model, checkpointing and summarizing to output_directory().

  Builds the loss graph, selects the optimizer from --optimizer, optionally
  restores from --restore_directory, and runs slim's training loop until
  externally stopped.
  """
  graph = tf.Graph()
  with graph.as_default():
    loss_op, _, _ = total_loss(gitapp)

    if FLAGS.optimizer == OPTIMIZER_MOMENTUM:
      # TODO(ericmc): We may want to do weight decay with the other
      # optimizers, too.
      decayed_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.variables.get_global_step(),
          FLAGS.learning_decay_steps,
          0.999,
          staircase=False)
      tf.summary.scalar('learning_rate', decayed_rate)
      opt = tf.train.MomentumOptimizer(decayed_rate, 0.875)
    elif FLAGS.optimizer == OPTIMIZER_ADAGRAD:
      opt = tf.train.AdagradOptimizer(FLAGS.learning_rate)
    elif FLAGS.optimizer == OPTIMIZER_ADAM:
      opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    else:
      raise NotImplementedError('Unsupported optimizer: %s' % FLAGS.optimizer)

    # Set up training.
    train_op = slim.learning.create_train_op(
        loss_op, opt, summarize_gradients=True)

    init_fn = None
    if FLAGS.restore_directory:
      init_fn = util.restore_model(FLAGS.restore_directory,
                                   FLAGS.restore_logits)
    else:
      logging.info('Training a new model.')

    num_variables, _ = slim.model_analyzer.analyze_vars(
        slim.get_variables(), print_info=True)
    logging.info('Total number of variables: %d', num_variables)

    log_entry_points(graph)

    slim.learning.train(
        train_op=train_op,
        logdir=output_directory(),
        master=FLAGS.master,
        is_chief=FLAGS.task == 0,
        number_of_steps=None,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        init_fn=init_fn,
        saver=tf.train.Saver(keep_checkpoint_every_n_hours=2.0))


def eval_loss(gitapp: controller.GetInputTargetAndPredictedParameters):
  """Runs the evaluation loop, streaming mean losses over eval batches.

  Reads checkpoints from the train directory and writes scalar summaries for
  the total loss and each per-name input/target loss.
  """
  g = tf.Graph()
  with g.as_default():
    total_loss_op, input_loss_lts, target_loss_lts = total_loss(gitapp)

    metric_names = ['total_loss']
    metric_values = [total_loss_op]
    for name, loss_lt in dict(input_loss_lts, **target_loss_lts).items():
      metric_names.append(name)
      metric_values.append(loss_lt.tensor)
    metric_names = ['metric/' + n for n in metric_names]
    metric_values = [metrics.streaming_mean(v) for v in metric_values]

    names_to_values, names_to_updates = metrics.aggregate_metric_map(
        dict(zip(metric_names, metric_values)))

    # Use `items` (not the Python 2-only `iteritems`) so this runs under
    # Python 3.
    for name, value in names_to_values.items():
      slim.summaries.add_scalar_summary(value, name, print_summary=True)

    log_entry_points(g)

    num_batches = FLAGS.metric_num_examples // gitapp.bp.size

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=train_directory(),
        logdir=output_directory(),
        num_evals=num_batches,
        # Materialize the dict view as a list; Python 3 views are not lists.
        eval_op=list(names_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs)


def eval_stitch(gitapp: controller.GetInputTargetAndPredictedParameters):
  """Runs the evaluation loop, writing stitched error-panel summaries."""
  graph = tf.Graph()
  with graph.as_default():
    controller.setup_stitch(gitapp)

    all_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)

    def _summary_named(substring):
      # First summary op whose name contains the given substring.
      return next(op for op in all_summaries if substring in op.name)

    # Merge the summaries to keep the graph state in sync.
    merged_summary_op = tf.summary.merge([
        _summary_named('input_error_panel'),
        _summary_named('target_error_panel'),
    ])

    log_entry_points(graph)

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        num_evals=0,
        checkpoint_dir=train_directory(),
        logdir=output_directory(),
        summary_op=merged_summary_op,
        eval_interval_secs=FLAGS.eval_interval_secs)


def export(gitapp: controller.GetInputTargetAndPredictedParameters):
  """Exports the model graph and checkpoint for serving.

  Restores weights from --restore_directory and writes the exported model to
  --export_directory.
  """
  graph = tf.Graph()
  with graph.as_default():
    # Export is only supported for the stitch metric.
    assert FLAGS.metric == METRIC_STITCH

    controller.setup_stitch(gitapp)

    log_entry_points(graph)

    # Every op tagged as an entry point becomes part of the serving signature.
    signature_map = {
        op.name: op
        for op in graph.get_operations()
        if 'entry_point' in op.name
    }

    logging.info('Exporting checkpoint at %s to %s', FLAGS.restore_directory,
                 FLAGS.export_directory)
    slim.export_for_serving(
        graph,
        checkpoint_dir=FLAGS.restore_directory,
        export_dir=FLAGS.export_directory,
        generic_signature_tensor_map=signature_map)


def infer_single_image(gitapp: controller.GetInputTargetAndPredictedParameters):
  """Predicts the labels for a single image."""
  destination = output_directory()
  if not gfile.Exists(destination):
    gfile.MakeDirs(destination)

  # --infer_channel_whitelist is a comma-separated list of channel names.
  if FLAGS.infer_channel_whitelist is None:
    channel_whitelist = None
  else:
    channel_whitelist = FLAGS.infer_channel_whitelist.split(',')

  keep_running = True
  while keep_running:
    infer.infer(
        gitapp=gitapp,
        restore_directory=FLAGS.restore_directory or train_directory(),
        output_directory=destination,
        extract_patch_size=CONCORDANCE_EXTRACT_PATCH_SIZE,
        stitch_stride=CONCORDANCE_STITCH_STRIDE,
        infer_size=FLAGS.infer_size,
        channel_whitelist=channel_whitelist,
        simplify_error_panels=FLAGS.infer_simplify_error_panels,
    )
    # Run exactly once unless --infer_continuously is set.
    keep_running = FLAGS.infer_continuously


def main(_):
  """Dispatches to train / eval / export / infer based on mode and metric."""
  logging.set_verbosity("INFO")
  # Training always optimizes the loss metric.
  if FLAGS.mode == MODE_TRAIN:
    assert FLAGS.metric == METRIC_LOSS

  # Only the chief task creates the base directory.
  if FLAGS.task == 0 and not gfile.Exists(FLAGS.base_directory):
    gfile.MakeDirs(FLAGS.base_directory)

  def _delay_evaluation():
    # Give training a head start before evaluation begins.
    logging.info('Sleeping %d seconds before beginning evaluation',
                 FLAGS.eval_delay_secs)
    time.sleep(FLAGS.eval_delay_secs)

  gitapp = parameters()
  if FLAGS.metric == METRIC_INFER_FULL:
    infer_single_image(gitapp)
  elif FLAGS.mode == MODE_TRAIN:
    train(gitapp)
  elif FLAGS.mode == MODE_EXPORT:
    export(gitapp)
  elif FLAGS.metric == METRIC_LOSS:
    _delay_evaluation()
    eval_loss(gitapp)
  elif FLAGS.metric in (METRIC_JITTER_STITCH, METRIC_STITCH):
    _delay_evaluation()
    eval_stitch(gitapp)
  else:
    raise NotImplementedError


if __name__ == '__main__':
  app.run()