# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base estimator defining TCN training, test, and inference."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import ABCMeta
from abc import abstractmethod
import os
import numpy as np
import six
import data_providers
import preprocessing
from utils import util
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.python.training import session_run_hook

tf.app.flags.DEFINE_integer(
    'tf_random_seed', 0, 'Random seed.')
FLAGS = tf.app.flags.FLAGS


class InitFromPretrainedCheckpointHook(session_run_hook.SessionRunHook):
  """Hook that can init graph from a pretrained checkpoint."""

  def __init__(self, pretrained_checkpoint_dir):
    """Initializes a `InitFromPretrainedCheckpointHook`.

    Args:
      pretrained_checkpoint_dir: String, directory containing the pretrained
        checkpoint.

    Raises:
      ValueError: If pretrained_checkpoint_dir is invalid.
    """
    if pretrained_checkpoint_dir is None:
      raise ValueError('pretrained_checkpoint_dir must be specified.')
    self._pretrained_checkpoint_dir = pretrained_checkpoint_dir

  def begin(self):
    checkpoint_reader = tf.contrib.framework.load_checkpoint(
        self._pretrained_checkpoint_dir)
    variable_shape_map = checkpoint_reader.get_variable_to_shape_map()

    exclude_scopes = 'logits/,final_layer/,aux_'
    # Skip restoring global_step so that fine-tuning starts from step 0.
    exclusions = ['global_step']
    if exclude_scopes:
      exclusions.extend([scope.strip() for scope in exclude_scopes.split(',')])

    variable_to_restore = tf.contrib.framework.get_model_variables()

    # Variable filtering by given exclude_scopes.
    filtered_variables_to_restore = {}
    for v in variable_to_restore:
      excluded = False
      for exclusion in exclusions:
        if v.name.startswith(exclusion):
          excluded = True
          break
      if not excluded:
        var_name = v.name.split(':')[0]
        filtered_variables_to_restore[var_name] = v

    # Final filter by checking shape matching and skipping variables that
    # are not in the checkpoint.
    final_variables_to_restore = {}
    for var_name, var_tensor in filtered_variables_to_restore.items():
      if var_name not in variable_shape_map:
        # Fall back to the moving average version of the variable.
        var_name = var_name + '/ExponentialMovingAverage'
        if var_name not in variable_shape_map:
          # Skip variables not in the checkpoint.
          tf.logging.info(
              'Skip init [%s] because it is not in ckpt.', var_name)
          continue

      if not var_tensor.get_shape().is_compatible_with(
          variable_shape_map[var_name]):
        # Skip restoring this variable from ckpt if the shapes don't match.
        tf.logging.info(
            'Skip init [%s] from [%s] in ckpt because of shape mismatch: '
            '%s vs %s',
            var_tensor.name, var_name,
            var_tensor.get_shape(), variable_shape_map[var_name])
        continue

      tf.logging.info('Init %s from %s in ckpt.', var_tensor.name, var_name)
      final_variables_to_restore[var_name] = var_tensor

    self._init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
        self._pretrained_checkpoint_dir,
        final_variables_to_restore)

  def after_create_session(self, session, coord):
    tf.logging.info('Restoring pretrained checkpoint weights.')
    self._init_fn(session)
    tf.logging.info('Done restoring pretrained checkpoint weights.')
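
# A minimal usage sketch for the hook (hypothetical paths): attach it to
# `Estimator.train` so a fresh run initializes model variables from the
# pretrained checkpoint before training begins.
#
#   hook = InitFromPretrainedCheckpointHook('/tmp/inception_v3_checkpoint')
#   estimator.train(input_fn=train_input_fn, hooks=[hook])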


@six.add_metaclass(ABCMeta)
class BaseEstimator(object):
  """Abstract TCN base estimator class."""

  def __init__(self, config, logdir):
    """Constructor.

    Args:
      config: A Luatable-like T object holding the training config.
      logdir: String, a directory where checkpoints and summaries are written.
    """
    self._config = config
    self._logdir = logdir

  @abstractmethod
  def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function with signature ()->(dict of features, labels), where
        features is a dict mapping feature names to `Tensors` holding the
        corresponding feature data (typically just a single key/value pair
        'raw_data' -> image `Tensor` for TCN), and labels is a 1-D int32
        `Tensor` holding labels.
    """
    pass
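
    # A sketch of one conforming implementation (illustrative only; assumes
    # TFRecords of tf.train.Examples with 'raw_data' and 'label' features,
    # which is not necessarily the actual TCN feature spec):
    #
    #   def input_fn():
    #     dataset = tf.data.TFRecordDataset(records)
    #     def parse(serialized):
    #       feats = tf.parse_single_example(serialized, {
    #           'raw_data': tf.FixedLenFeature([], tf.string),
    #           'label': tf.FixedLenFeature([], tf.int64)})
    #       return {'raw_data': feats['raw_data']}, feats['label']
    #     dataset = dataset.map(parse).batch(self._batch_size)
    #     return dataset.make_one_shot_iterator().get_next()
    #   return input_fn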

  def preprocess_data(self, images, is_training):
    """Preprocesses raw images for either training or inference.

    Args:
      images: A 4-D float32 `Tensor` holding images to preprocess.
      is_training: Boolean, whether or not we're in training.

    Returns:
      data_preprocessed: data after the preprocessor.
    """
    config = self._config
    height = config.data.height
    width = config.data.width
    min_scale = config.data.augmentation.minscale
    max_scale = config.data.augmentation.maxscale
    p_scale_up = config.data.augmentation.proportion_scaled_up
    aug_color = config.data.augmentation.color
    fast_mode = config.data.augmentation.fast_mode
    crop_strategy = config.data.preprocessing.eval_cropping
    preprocessed_images = preprocessing.preprocess_images(
        images, is_training, height, width,
        min_scale, max_scale, p_scale_up,
        aug_color=aug_color, fast_mode=fast_mode,
        crop_strategy=crop_strategy)
    return preprocessed_images

  @abstractmethod
  def forward(self, images, is_training, reuse=False):
    """Defines the forward pass that converts batch images to embeddings.

    Method to be overridden by implementations.

    Args:
      images: A 4-D float32 `Tensor` holding images to be embedded.
      is_training: Boolean, whether or not we're in training mode.
      reuse: Boolean, whether or not to reuse embedder.
    Returns:
      embeddings: A 2-D float32 `Tensor` holding embedded images.
    """
    pass

  @abstractmethod
  def define_loss(self, embeddings, labels, is_training):
    """Defines the loss function on the embedding vectors.

    Method to be overridden by implementations.

    Args:
      embeddings: A 2-D float32 `Tensor` holding embedded images.
      labels: A 1-D int32 `Tensor` holding problem labels.
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      loss: tf.float32 scalar.
    """
    pass

  @abstractmethod
  def define_eval_metric_ops(self):
    """Defines the dictionary of eval metric tensors.

    Method to be overridden by implementations.

    Returns:
      eval_metric_ops:  A dict of name/value pairs specifying the
        metrics that will be calculated when the model runs in EVAL mode.
    """
    pass

  def get_train_op(self, loss):
    """Creates a training op.

    Args:
      loss: A float32 `Tensor` representing the total training loss.
    Returns:
      train_op: A slim.learning.create_train_op train_op.
    Raises:
      ValueError: If specified optimizer isn't supported.
    """
    # Get variables to train (defined in subclass).
    assert self.variables_to_train

    # Get learning rate schedule parameters from the config.
    decay_steps = self._config.learning.decay_steps
    decay_factor = self._config.learning.decay_factor
    learning_rate = float(self._config.learning.learning_rate)

    # Define a learning rate schedule.
    global_step = slim.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(
        learning_rate,
        global_step,
        decay_steps,
        decay_factor,
        staircase=True)
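    # With staircase=True, the decayed rate is piecewise constant:
    #   lr(step) = learning_rate * decay_factor ** floor(step / decay_steps)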

    # Create an optimizer.
    opt_type = self._config.learning.optimizer
    if opt_type == 'adam':
      opt = tf.train.AdamOptimizer(learning_rate)
    elif opt_type == 'momentum':
      opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
    elif opt_type == 'rmsprop':
      opt = tf.train.RMSPropOptimizer(learning_rate, momentum=0.9,
                                      epsilon=1.0, decay=0.9)
    else:
      raise ValueError('Unsupported optimizer %s' % opt_type)

    if self._config.use_tpu:
      opt = tpu_optimizer.CrossShardOptimizer(opt)

    # Create a training op.
    train_op = slim.learning.create_train_op(
        loss,
        optimizer=opt,
        variables_to_train=self.variables_to_train,
        update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))
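    # Passing the UPDATE_OPS collection above ensures ops such as batch-norm
    # moving-average updates run as part of every training step.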

    return train_op

  def _get_model_fn(self):
    """Defines behavior for training, evaluation, and inference (prediction).

    Returns:
      `model_fn` for `Estimator`.
    """
    # pylint: disable=unused-argument
    def model_fn(features, labels, mode, params):
      """Build the model based on features, labels, and mode.

      Args:
        features: Dict, strings to `Tensor` input data, returned by the
          input_fn.
        labels: The labels Tensor returned by the input_fn.
        mode: A string indicating the mode. This will be either
          tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.PREDICT,
          or tf.estimator.ModeKeys.EVAL.
        params: A dict holding training parameters, passed in during TPU
          training.

      Returns:
        A tf.estimator.EstimatorSpec specifying train/test/inference behavior.
      """
      is_training = mode == tf.estimator.ModeKeys.TRAIN

      # Get preprocessed images from the features dict.
      batch_preprocessed = features['batch_preprocessed']

      # Do a forward pass to embed data.
      batch_encoded = self.forward(batch_preprocessed, is_training)

      # Optionally set the pretrained initialization function.
      initializer_fn = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        initializer_fn = self.pretrained_init_fn

      # If we're training or evaluating, define total loss.
      total_loss = None
      if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss = self.define_loss(batch_encoded, labels, is_training)
        tf.losses.add_loss(loss)
        total_loss = tf.losses.get_total_loss()

      # If we're training, define a train op.
      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = self.get_train_op(total_loss)

      # If we're doing inference, set the output to be the embedded images.
      predictions_dict = None
      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions_dict = {'embeddings': batch_encoded}
        # Pass through additional metadata stored in features.
        for k, v in features.items():
          predictions_dict[k] = v

      # If we're evaluating, define some eval metrics.
      eval_metric_ops = None
      if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = self.define_eval_metric_ops()

      # Create a saver (used by the scaffold below) that keeps the configured
      # number of checkpoints.
      num_checkpoint_to_keep = self._config.logging.checkpoint.num_to_keep
      saver = tf.train.Saver(
          max_to_keep=num_checkpoint_to_keep)

      if is_training and self._config.use_tpu:
        # TPU doesn't have a scaffold option at the moment, so initialize
        # pretrained weights using a custom train_hook instead.
        return tpu_estimator.TPUEstimatorSpec(
            mode,
            loss=total_loss,
            eval_metrics=None,
            train_op=train_op,
            predictions=predictions_dict)
      else:
        # Build a scaffold to initialize pretrained weights.
        scaffold = tf.train.Scaffold(
            init_fn=initializer_fn,
            saver=saver,
            summary_op=None)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions_dict,
            loss=total_loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            scaffold=scaffold)
    return model_fn

  def train(self):
    """Runs training."""
    # Get a list of training tfrecords.
    config = self._config
    training_dir = config.data.training
    training_records = util.GetFilesRecursively(training_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined training input function.
    train_input_fn = self.construct_input_fn(
        training_records, is_training=True)

    # Create the estimator.
    estimator = self._build_estimator(is_training=True)

    train_hooks = None
    if config.use_tpu:
      # TPU training initializes pretrained weights using a custom train hook.
      train_hooks = []
      if tf.train.latest_checkpoint(self._logdir) is None:
        train_hooks.append(
            InitFromPretrainedCheckpointHook(
                config[config.embedder_strategy].pretrained_checkpoint))

    # Run training.
    estimator.train(input_fn=train_input_fn, hooks=train_hooks,
                    steps=config.learning.max_step)

  def _build_estimator(self, is_training):
    """Returns an Estimator object.

    Args:
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      A tf.estimator.Estimator.
    """
    config = self._config
    save_checkpoints_steps = config.logging.checkpoint.save_checkpoints_steps
    keep_checkpoint_max = config.logging.checkpoint.num_to_keep
    if is_training and config.use_tpu:
      iterations = config.tpu.iterations
      num_shards = config.tpu.num_shards
      run_config = tpu_config.RunConfig(
          save_checkpoints_secs=None,
          save_checkpoints_steps=save_checkpoints_steps,
          keep_checkpoint_max=keep_checkpoint_max,
          master=FLAGS.master,
          evaluation_master=FLAGS.master,
          model_dir=self._logdir,
          tpu_config=tpu_config.TPUConfig(
              iterations_per_loop=iterations,
              num_shards=num_shards,
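              # A single TPU host feeds at most 8 shards, so per-host input
              # is only enabled when the job fits on one host.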
              per_host_input_for_training=num_shards <= 8),
          tf_random_seed=FLAGS.tf_random_seed)

      batch_size = config.data.batch_size
      return tpu_estimator.TPUEstimator(
          model_fn=self._get_model_fn(),
          config=run_config,
          use_tpu=True,
          train_batch_size=batch_size,
          eval_batch_size=batch_size)
    else:
      run_config = tf.estimator.RunConfig().replace(
          model_dir=self._logdir,
          save_checkpoints_steps=save_checkpoints_steps,
          keep_checkpoint_max=keep_checkpoint_max,
          tf_random_seed=FLAGS.tf_random_seed)
      return tf.estimator.Estimator(
          model_fn=self._get_model_fn(),
          config=run_config)

  def evaluate(self):
    """Runs `Estimator` validation."""
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, is_training=False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / eval_batch_size)
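    # Note: integer division drops any partial final batch of eval samples.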
    estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches)

  def inference(
      self, inference_input, checkpoint_path, batch_size=None, **kwargs):
    """Defines 3 of modes of inference.

    Inputs:
    * Mode 1: Input is an input_fn.
    * Mode 2: Input is a TFRecord (or list of TFRecords).
    * Mode 3: Input is a numpy array holding an image (or array of images).

    Outputs:
    * Mode 1: this returns an iterator over embeddings and additional
      metadata. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#predict
      for details.
    * Mode 2: Returns an iterator over tuples of
      (embeddings, raw_image_strings, sequence_name), where embeddings is a
      2-D float32 numpy array holding [sequence_size, embedding_size] image
      embeddings, raw_image_strings is a 1-D string numpy array holding
      [sequence_size] jpeg-encoded image strings, and sequence_name is a
      string holding the name of the embedded sequence.
    * Mode 3: Returns a tuple of (embeddings, raw_image_strings), where
      embeddings is a 2-D float32 numpy array holding
      [batch_size, embedding_size] image embeddings, raw_image_strings is a
      1-D string numpy array holding [batch_size] jpeg-encoded image strings.

    Args:
      inference_input: This can be a tf.Estimator input_fn, a TFRecord path,
        a list of TFRecord paths, a numpy image, or an array of numpy images.
      checkpoint_path: String, path to the checkpoint to restore for inference.
      batch_size: Int, the size of the batch to use for inference.
      **kwargs: Additional keyword arguments, depending on the mode.
        See _input_fn_inference, _tfrecord_inference, and _np_inference.
    Returns:
      inference_output: Inference output depending on mode, see above for
        details.
    Raises:
      ValueError: If inference_input isn't a tf.Estimator input_fn,
        a TFRecord path, a list of TFRecord paths, or a numpy array.
    """
    # Mode 1: input is a callable tf.Estimator input_fn.
    if callable(inference_input):
      return self._input_fn_inference(
          input_fn=inference_input, checkpoint_path=checkpoint_path, **kwargs)
    # Mode 2: Input is a TFRecord path (or list of TFRecord paths).
    elif util.is_tfrecord_input(inference_input):
      return self._tfrecord_inference(
          records=inference_input, checkpoint_path=checkpoint_path,
          batch_size=batch_size, **kwargs)
    # Mode 3: Input is a numpy array of raw images.
    elif util.is_np_array(inference_input):
      return self._np_inference(
          np_images=inference_input, checkpoint_path=checkpoint_path, **kwargs)
    else:
      raise ValueError(
          'inference input must be a tf.Estimator input_fn, a TFRecord path, '
          'a list of TFRecord paths, or a numpy array. Got: %s' % str(type(
              inference_input)))
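
  # Usage sketch for modes 2 and 3 (hypothetical paths and data):
  #
  #   # Mode 2: embed full sequences from a TFRecord.
  #   for embeddings, raw_images, seq_name in estimator.inference(
  #       '/tmp/seqs.tfrecord', '/tmp/model.ckpt-10000', batch_size=32):
  #     ...
  #
  #   # Mode 3: embed a batch of numpy images in [0, 1].
  #   embeddings, raw_images = estimator.inference(
  #       np_images, '/tmp/model.ckpt-10000')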

  def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.Estimator inference.

    Args:
      input_fn: Function with signature ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore.
      predict_keys: List of strings, the keys of the `Tensors` in the features
        dict (returned by the input_fn) to evaluate during inference.
    Returns:
      predictions: An Iterator, yielding evaluated values of `Tensors`
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions

  def _tfrecord_inference(self, records, checkpoint_path, batch_size,
                          num_sequences=-1, reuse=False):
    """Mode 2: TFRecord inference.

    Args:
      records: List of strings, paths to TFRecords.
      checkpoint_path: String, path to a specific checkpoint to restore.
      batch_size: Int, size of inference batch.
      num_sequences: Int, number of sequences to embed. If -1,
        embed everything.
      reuse: Boolean, whether or not to reuse embedder weights.
    Yields:
      (embeddings, raw_image_strings, sequence_name):
        embeddings is a 2-D float32 numpy array holding
        [sequence_size, embedding_size] image embeddings.
        raw_image_strings is a 1-D string numpy array holding
        [sequence_size] jpeg-encoded image strings.
        sequence_name is a string holding the name of the embedded sequence.
    """
    tf.reset_default_graph()
    # Accept a single TFRecord path by wrapping it in a list.
    if not isinstance(records, list):
      records = [records]

    # Map the list of tfrecords to a dataset of preprocessed images.
    num_views = self._config.data.num_views
    (views, task, seq_len) = data_providers.full_sequence_provider(
        records, num_views)
    tensor_dict = {
        'raw_image_strings': views,
        'task': task,
        'seq_len': seq_len
    }

    # Create a preprocess function over raw image string placeholders.
    image_str_placeholder = tf.placeholder(tf.string, shape=[None])
    decoded = preprocessing.decode_images(image_str_placeholder)
    decoded.set_shape([batch_size, None, None, 3])
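    # The static batch dimension is pinned so arbitrarily long sequences can
    # be embedded below in fixed-size chunks of `batch_size` images per run.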
    preprocessed = self.preprocess_data(decoded, is_training=False)

    # Create an inference graph over preprocessed images.
    embeddings = self.forward(preprocessed, is_training=False, reuse=reuse)

    # Create a saver to restore model variables.
    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(tf.global_variables())

    # Create a session and restore model variables.
    with tf.train.MonitoredSession() as sess:
      saver.restore(sess, checkpoint_path)
      cnt = 0
      # If num_sequences is specified, embed that many sequences, else embed
      # everything.
      try:
        while num_sequences == -1 or cnt < num_sequences:
          # Get a preprocessed image sequence.
          np_data = sess.run(tensor_dict)
          np_raw_images = np_data['raw_image_strings']
          np_seq_len = np_data['seq_len']
          np_task = np_data['task']

          # Embed each view.
          embedding_size = self._config.embedding_size
          view_embeddings = [
              np.zeros((0, embedding_size)) for _ in range(num_views)]
          for view_index in range(num_views):
            view_raw = np_raw_images[view_index]
            # Embed the full sequence.
            t = 0
            while t < np_seq_len:
              # Decode and preprocess the batch of image strings.
              embeddings_np = sess.run(
                  embeddings, feed_dict={
                      image_str_placeholder: view_raw[t:t+batch_size]})
              view_embeddings[view_index] = np.append(
                  view_embeddings[view_index], embeddings_np, axis=0)
              t += batch_size
              tf.logging.info(
                  'Embedded %d images for task %s.',
                  min(t, np_seq_len), np_task)

          # Done embedding for all views.
          view_raw_images = np_data['raw_image_strings']
          yield (view_embeddings, view_raw_images, np_task)
          cnt += 1
      except tf.errors.OutOfRangeError:
        tf.logging.info('Done embedding entire dataset.')

  def _np_inference(self, np_images, checkpoint_path):
    """Mode 3: Call this repeatedly to do inference over numpy images.

    This mode is for when we want to do real-time inference over
    some stream of images (represented as numpy arrays).

    Args:
      np_images: A float32 numpy array holding images to embed.
      checkpoint_path: String, path to a specific checkpoint to restore.
    Returns:
      (embeddings, raw_image_strings):
        embeddings is a 2-D float32 numpy array holding
        [inferred batch_size, embedding_size] image embeddings.
        raw_image_strings is a 1-D string numpy array holding
        [inferred batch_size] jpeg-encoded image strings.
    """
    if isinstance(np_images, list):
      np_images = np.asarray(np_images)
    # Add a batch dimension if only 3-dimensional.
    if len(np_images.shape) == 3:
      np_images = np.expand_dims(np_images, axis=0)

    # If np_images span exactly [0, 255], rescale them to [0, 1].
    assert np.min(np_images) >= 0.
    if (np.min(np_images), np.max(np_images)) == (0, 255):
      np_images = np_images.astype(np.float32) / 255.
      assert (np.min(np_images), np.max(np_images)) == (0., 1.)

    # If this is the first pass, set up inference graph.
    if not hasattr(self, '_np_inf_tensor_dict'):
      self._setup_np_inference(np_images, checkpoint_path)

    # Convert np_images to embeddings.
    np_tensor_dict = self._sess.run(self._np_inf_tensor_dict, feed_dict={
        self._image_placeholder: np_images
    })
    return np_tensor_dict['embeddings'], np_tensor_dict['raw_image_strings']
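
  # E.g., a real-time loop (sketch; `get_frame` is a hypothetical source of
  # float32 HxWx3 images in [0, 1]):
  #
  #   while True:
  #     frame = get_frame()
  #     embedding, _ = estimator.inference(frame, '/tmp/model.ckpt-10000')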

  def _setup_np_inference(self, np_images, checkpoint_path):
    """Sets up and restores inference graph, creates and caches a Session."""
    tf.logging.info('Restoring model weights.')

    # Define inference over an image placeholder.
    _, height, width, _ = np.shape(np_images)
    image_placeholder = tf.placeholder(
        tf.float32, shape=(None, height, width, 3))

    # Preprocess batch.
    preprocessed = self.preprocess_data(image_placeholder, is_training=False)

    # Unscale and jpeg encode preprocessed images for display purposes.
    im_strings = preprocessing.unscale_jpeg_encode(preprocessed)

    # Do forward pass to get embeddings.
    embeddings = self.forward(preprocessed, is_training=False)

    # Create a saver to restore model variables.
    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(tf.global_variables())

    self._image_placeholder = image_placeholder
    self._batch_encoded = embeddings

    self._np_inf_tensor_dict = {
        'embeddings': embeddings,
        'raw_image_strings': im_strings,
    }

    # Create a session and restore model variables.
    self._sess = tf.Session()
    saver.restore(self._sess, checkpoint_path)