# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library with common functions for training and eval."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

import tensorflow as tf

from tensorflow.contrib.slim.nets import resnet_v2


def default_hparams():
  """Builds the default hyperparameters.

  Returns:
    tf.contrib.training.HParams with batch sizes, regularization,
    adversarial-training, optimizer and learning-rate-schedule settings.
  """
  # Batch size for training and evaluation.
  batch_params = dict(batch_size=32, eval_batch_size=50)

  # General training parameters.
  training_params = dict(weight_decay=0.0001, label_smoothing=0.1)

  # Parameters of the adversarial training.
  adversarial_params = dict(
      train_adv_method='clean',  # adversarial training method
      train_lp_weight=0.0)  # weight of adversarial logit pairing loss

  # Parameters of the optimizer.
  optimizer_params = dict(
      optimizer='rms',  # one of 'rms', 'momentum', 'adam'
      momentum=0.9,
      rmsprop_decay=0.9,  # decay term for RMSProp
      rmsprop_epsilon=1.0)  # epsilon term for RMSProp

  # Parameters of the learning rate schedule.
  lr_params = dict(
      lr_schedule='exp_decay',  # one of 'exp_decay', 'step', 'fixed'
      learning_rate=0.045,
      lr_decay_factor=0.94,  # multiplicative factor of the exponential decay
      lr_num_epochs_per_decay=2.0,  # number of epochs between decays
      # Values and epoch boundaries for the 'step' schedule; there is one more
      # value than boundaries (10 values, 9 boundaries).
      lr_list=([k / 6.0 for k in range(1, 6)] +
               [1.0, 0.1, 0.01, 0.001, 0.0001]),
      lr_decay_epochs=list(range(1, 6)) + [30, 60, 80, 90])

  values = {}
  for group in (batch_params, training_params, adversarial_params,
                optimizer_params, lr_params):
    values.update(group)
  return tf.contrib.training.HParams(**values)


def get_lr_schedule(hparams, examples_per_epoch, replicas_to_aggregate=1):
  """Returns TensorFlow op which computes the learning rate.

  Args:
    hparams: hyper parameters. Reads `batch_size` and `lr_schedule`, plus the
      fields used by the selected schedule: `learning_rate`,
      `lr_decay_factor` and `lr_num_epochs_per_decay` for 'exp_decay';
      `lr_decay_epochs` and `lr_list` for 'step'; `learning_rate` for 'fixed'.
    examples_per_epoch: number of training examples per epoch.
    replicas_to_aggregate: number of training replicas running in parallel.
      When positive, steps per epoch are divided by it and the resulting
      learning rate is multiplied by it.

  Raises:
    ValueError: if learning rate schedule specified in hparams is incorrect.

  Returns:
    learning_rate: tensor with learning rate (a plain Python float for the
      'fixed' schedule).
    steps_per_epoch: number of training steps per epoch (float).
  """
  global_step = tf.train.get_or_create_global_step()
  steps_per_epoch = float(examples_per_epoch) / float(hparams.batch_size)
  if replicas_to_aggregate > 0:
    # Every step aggregates one batch from each replica.
    steps_per_epoch /= replicas_to_aggregate

  if hparams.lr_schedule == 'exp_decay':
    # `int` rather than the Python 2-only `long` (identical on Python 2, where
    # int auto-promotes). Clamp to 1 so a very small epoch cannot produce
    # decay_steps == 0, which would divide by zero inside the decay.
    decay_steps = max(
        1, int(steps_per_epoch * hparams.lr_num_epochs_per_decay))
    learning_rate = tf.train.exponential_decay(
        hparams.learning_rate,
        global_step,
        decay_steps,
        hparams.lr_decay_factor,
        staircase=True)
  elif hparams.lr_schedule == 'step':
    # Convert epoch boundaries into global-step boundaries.
    lr_decay_steps = [int(epoch * steps_per_epoch)
                      for epoch in hparams.lr_decay_epochs]
    learning_rate = tf.train.piecewise_constant(
        global_step, lr_decay_steps, hparams.lr_list)
  elif hparams.lr_schedule == 'fixed':
    learning_rate = hparams.learning_rate
  else:
    raise ValueError('Invalid value of lr_schedule: %s' % hparams.lr_schedule)

  if replicas_to_aggregate > 0:
    # Scale the rate with the number of aggregated replicas.
    learning_rate *= replicas_to_aggregate

  return learning_rate, steps_per_epoch


def get_optimizer(hparams, learning_rate):
  """Builds the optimizer selected by `hparams.optimizer`.

  Args:
    hparams: hyper parameters. `optimizer` picks the algorithm ('rms',
      'momentum' or 'adam'); `momentum`, `rmsprop_decay` and
      `rmsprop_epsilon` configure it.
    learning_rate: learning rate tensor.

  Raises:
    ValueError: if type of optimizer specified in hparams is incorrect.

  Returns:
    Instance of optimizer class.
  """
  kind = hparams.optimizer
  if kind == 'adam':
    return tf.train.AdamOptimizer(learning_rate)
  if kind == 'momentum':
    return tf.train.MomentumOptimizer(
        learning_rate, momentum=hparams.momentum)
  if kind == 'rms':
    return tf.train.RMSPropOptimizer(
        learning_rate,
        decay=hparams.rmsprop_decay,
        momentum=hparams.momentum,
        epsilon=hparams.rmsprop_epsilon)
  raise ValueError('Invalid value of optimizer: %s' % hparams.optimizer)


# Registry of ResNet constructors, keyed by the model name passed to get_model.
RESNET_MODELS = dict(resnet_v2_50=resnet_v2.resnet_v2_50)


def get_model(model_name, num_classes):
  """Returns function which creates model.

  Args:
    model_name: Name of the model; must be a key of RESNET_MODELS.
    num_classes: Number of classes.

  Raises:
    ValueError: If model_name is invalid.

  Returns:
    Function `model_fn(images, is_training, reuse=tf.AUTO_REUSE)` which builds
    the model when called and returns its logits reshaped to
    [-1, num_classes].
  """
  # Validate against the registry up front. Previously any name starting with
  # 'resnet' was accepted here and only failed (with a KeyError) once the
  # returned function was called.
  if model_name not in RESNET_MODELS:
    raise ValueError('Invalid model: %s' % model_name)
  resnet_fn = RESNET_MODELS[model_name]

  def resnet_model(images, is_training, reuse=tf.AUTO_REUSE):
    """Builds the selected ResNet on `images` and returns its logits."""
    with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()):
      logits, _ = resnet_fn(images, num_classes, is_training=is_training,
                            reuse=reuse)
      # Flatten any extra (e.g. spatial) dims into [batch, num_classes].
      logits = tf.reshape(logits, [-1, num_classes])
    return logits

  return resnet_model


def filter_trainable_variables(trainable_scopes):
  """Keep only trainable variables which are prefixed with given scopes.

  This function removes all variables whose op name does not start with one
  of the given trainable_scopes from the default graph's collection of
  trainable variables. Useful during network fine tuning, when you only need
  to train a subset of variables.

  Args:
    trainable_scopes: either list of trainable scopes or string with comma
      separated list of trainable scopes. If it is empty, or contains only
      empty entries, the collection is left untouched.
  """
  if not trainable_scopes:
    return
  if isinstance(trainable_scopes, six.string_types):
    trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
  # Drop empty entries: an empty prefix would match every variable.
  # str.startswith accepts a tuple of prefixes, so build one tuple once.
  prefixes = tuple({scope for scope in trainable_scopes if scope})
  if not prefixes:
    return
  # get_collection_ref returns the graph's own list, so it must be mutated in
  # place rather than rebound.
  trainable_collection = tf.get_collection_ref(
      tf.GraphKeys.TRAINABLE_VARIABLES)
  # A single O(n) in-place rebuild. This replaces repeated list.remove()
  # calls, each of which was an O(n) scan that relied on equality comparison
  # of variable objects.
  trainable_collection[:] = [
      v for v in trainable_collection if v.op.name.startswith(prefixes)
  ]