"""
YellowFin optimizer.

YellowFin and the Art of Momentum Tuning
https://arxiv.org/abs/1706.03471
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import tensorflow as tf
from tensorflow.python.framework import ops

# Small constant added to denominators and log arguments throughout the
# tuning arithmetic so it stays numerically stable.
EPS = 1.0e-6
# Stand-in for "no limit": used as the initial / fallback gradient clipping
# threshold, which effectively disables clipping.
LARGE_FLOAT_VAL = 1.0e15

class YFOptimizer(object):
  """
  Optimizer that implements the YellowFin algorithm.

  Implemented as a wrapper around tf.train.MomentumOptimizer. On every
  call to `apply_gradients`, YellowFin does three things:
    - updates running estimates of the curvature range, the gradient
      variance and the distance to the optimum;
    - uses those estimates to retune the learning rate and momentum;
    - applies the gradients through the wrapped optimizer.
  """
  # Available gate_gradients values
  GATE_NONE = tf.train.Optimizer.GATE_NONE
  GATE_OP = tf.train.Optimizer.GATE_OP
  GATE_GRAPH = tf.train.Optimizer.GATE_GRAPH

  def __init__(self, learning_rate=0.0001, momentum=0.0, clip_thresh=None,
               beta=0.999, curv_win_width=20, zero_debias=True, delta_mu=0.0,
               sparsity_debias=False, use_locking=False, name="YellowFin",
               use_nesterov=False, use_unsmoothed_lr_mu=True,
               h_max_log_smooth=True, h_min_log_smooth=True,
               use_adapt_grad_clip=True, stat_protect_fac=100.0):
    """
    Construct a new YellowFin optimizer.

    Args:
      learning_rate: Python scalar. The initial value of the learning rate;
        we use 1.0 in our paper.
      momentum: Python scalar. The initial value of momentum; we use
        0.0 in our paper.
      clip_thresh: Python scalar. The clipping threshold for
        `tf.clip_by_global_norm`. If None, no manual clipping is used.
      beta: Python scalar. Decay of the exponential moving averages used
        for all statistics. When `use_unsmoothed_lr_mu` is False, it is also
        the smoothing factor applied to the tuned learning rate and momentum.
      curv_win_width: Python int. Length of the sliding window of squared
        gradient norms whose min/max give the curvature range. It also sets
        the learning-rate warm-up length (10 * curv_win_width steps).
      zero_debias: Python boolean. Forwarded as `zero_debias` to the
        `tf.train.ExponentialMovingAverage` that holds the statistics.
      delta_mu: for extensions. Not necessary in the basic use.
      sparsity_debias: Python boolean. Gradient norm and curvature are
        biased to larger values when calculated with sparse gradient.
        This is useful when the model is very sparse, e.g. LSTM with
        word embedding. For non-sparse CNN, turning it off could
        slightly accelerate the speed.
      use_locking: If True, use locks for update operations.
      name: Optional name prefix for the operations created when
        applying gradients. Defaults to "YellowFin".
      use_nesterov: If True, the underlying MomentumOptimizer uses Nesterov
        Momentum. Set to False in the default YellowFin algorithm.
      use_unsmoothed_lr_mu: Python boolean. If True, the tuned learning rate
        and momentum are assigned directly each step. Otherwise they are
        first blended with the previous values using `beta`.
      h_max_log_smooth: Python boolean. If True, the moving average for the
        curvature maximum is taken over log(h_max) and exponentiated back.
      h_min_log_smooth: Python boolean. Same as `h_max_log_smooth`, for the
        curvature minimum.
      use_adapt_grad_clip: Python boolean. If True, gradients are clipped
        adaptively using thresholds derived from sqrt(h_max) of the
        previous step, so exploding gradients do not ruin the statistics.
      stat_protect_fac: Python scalar. With adaptive clipping, gradients
        are clipped to a global norm of sqrt(stat_protect_fac) times the
        previous adaptive threshold before the statistics are updated.

    Notes:
      `clip_thresh` is a threshold on the global norm of the raw gradients,
      i.e. before they are scaled by the learning rate.
      `delta_mu` can be a placeholder/variable/python scalar. Used for
      additional momentum in situations such as asynchronous-parallel
      training. The default is 0.0 for basic usage of the optimizer.

    Other features:
      If you want to manually control the learning rates,
      `self.lr_factor` is an interface to the outside. It is a
      multiplier for the internal learning rate in YellowFin. It is
      helpful when you want to do additional hand tuning or some
      decaying scheme for the internal learning rate. Example on using
      `lr_factor` can be found here:
      https://github.com/JianGoForIt/YellowFin/blob/master/char-rnn-tensorflow/train_YF.py#L140
    """
    self._lr = learning_rate
    self._mu = momentum

    self._lr_var = tf.Variable(
      learning_rate, dtype=tf.float32, name="YF_lr", trainable=False)
    self._mu_var = tf.Variable(
      momentum, dtype=tf.float32, name="YF_mu", trainable=False)
    # for step scheme or decaying scheme for the learning rates
    self.lr_factor = tf.Variable(
      1.0, dtype=tf.float32, name="YF_lr_factor", trainable=False)
    if clip_thresh is not None:
      self._clip_thresh_var = tf.Variable(
        clip_thresh, dtype=tf.float32, name="YF_clip_thresh",
        trainable=False)
    else:
      self._clip_thresh_var = None

    # the underlying momentum optimizer; lr and mu are tensors that follow
    # the tuned variables (times the user factor / plus the extra momentum)
    self._optimizer = tf.train.MomentumOptimizer(
      self._lr_var * self.lr_factor, self._mu_var + delta_mu,
      use_locking=use_locking, name=name, use_nesterov=use_nesterov)

    # moving average for statistics
    self._beta = beta
    self._moving_averager = None

    # internal step counter (independent of any user `global_step`)
    self._global_step = tf.Variable(0, trainable=False)

    # tuning only kicks in after the first step has produced statistics
    self._do_tune = tf.greater(self._global_step, tf.constant(0) )

    self._zero_debias = zero_debias
    self._sparsity_debias = sparsity_debias

    self._tvars = None

    # for curvature range
    self._curv_win_width = curv_win_width
    self._curv_win = None

    # option for using smoothed or unsmoothed lr and mu
    self._use_unsmoothed_lr_mu = use_unsmoothed_lr_mu

    # options for curvature envelop smoothing
    self._h_max_log_smooth = h_max_log_smooth
    self._h_min_log_smooth = h_min_log_smooth

    # for adaptive gradient clipping; LARGE_FLOAT_VAL means "no clipping yet"
    self._use_adapt_grad_clip = use_adapt_grad_clip
    self._adapt_grad_clip_thresh = \
      tf.Variable(LARGE_FLOAT_VAL, dtype=tf.float32, trainable=False)
    self._adapt_grad_clip_target_val = \
      tf.Variable(LARGE_FLOAT_VAL, dtype=tf.float32, trainable=False)

    # prevent exploding gradient from ruining the statistics
    self._stat_protect_fac = stat_protect_fac

  def curvature_range(self):
    """
    Build the ops estimating the curvature range [h_min, h_max].

    Writes the current squared gradient norm into a circular window of
    `curv_win_width` slots. The min/max over the filled part of the window
    are then (optionally log-)smoothed by the moving averager into
    `self._h_min` / `self._h_max`.

    Returns:
      A list holding the moving-average update op.
    """
    # set up the curvature window
    self._curv_win = tf.Variable(
      np.zeros([self._curv_win_width, ]), dtype=tf.float32,
      name="curv_win", trainable=False)
    # we can use log smoothing for curvature range to follow trend faster
    # self._curv_win = tf.scatter_update(
    #   self._curv_win, self._global_step % self._curv_win_width,
    #   tf.log(self._grad_norm_squared + EPS))
    # write this step's value into its circular slot; the updated tensor
    # returned by scatter_update is what the slice below reads
    self._curv_win = tf.scatter_update(
      self._curv_win, self._global_step % self._curv_win_width,
      self._grad_norm_squared + EPS)
    # note here the iterations start from iteration 0, so only the first
    # min(curv_win_width, step + 1) slots hold real measurements
    valid_window = tf.slice(
      self._curv_win, tf.constant([0, ]), tf.expand_dims(
        tf.minimum(tf.constant(self._curv_win_width),
                   self._global_step + 1), axis=0))

    if self._h_min_log_smooth:
      self._h_min_t = tf.log(tf.reduce_min(valid_window) + EPS)
    else:
      self._h_min_t = tf.reduce_min(valid_window)
    if self._h_max_log_smooth:
      self._h_max_t = tf.log(tf.reduce_max(valid_window) + EPS)
    else:
      self._h_max_t = tf.reduce_max(valid_window)

    curv_range_ops = []
    with tf.control_dependencies([self._h_min_t, self._h_max_t] ):
      avg_op = self._moving_averager.apply(
        [self._h_min_t, self._h_max_t])
      with tf.control_dependencies([avg_op]):
        # undo the log transform after averaging when log smoothing is on
        if self._h_min_log_smooth:
          self._h_min = tf.exp(
            tf.identity(self._moving_averager.average(self._h_min_t)))
        else:
          self._h_min = \
            tf.identity(self._moving_averager.average(self._h_min_t))
        if self._h_max_log_smooth:
          self._h_max = tf.exp(
            tf.identity(self._moving_averager.average(self._h_max_t)))
        else:
          self._h_max = \
            tf.identity(self._moving_averager.average(self._h_max_t))
      if self._sparsity_debias:
        self._h_min = self._h_min * self._sparsity_avg
        self._h_max = self._h_max * self._sparsity_avg
    curv_range_ops.append(avg_op)
    return curv_range_ops

  def grad_variance(self):
    """
    Build the ops estimating the gradient variance.

    Keeps a moving average of every gradient; IndexedSlices are densified
    with `unsorted_segment_sum` first. The variance is estimated as
    E[||g||^2] - ||E[g]||^2, floored at EPS.

    Returns:
      A list holding the moving-average update op.
    """
    grad_var_ops = []
    tensor_to_avg = []
    for t, g in zip(self._tvars, self._grads):
      if isinstance(g, ops.IndexedSlices):
        tensor_to_avg.append(
          tf.reshape(tf.unsorted_segment_sum(
            g.values, g.indices, g.dense_shape[0]),
            shape=t.get_shape()))
      else:
        tensor_to_avg.append(g)
    avg_op = self._moving_averager.apply(tensor_to_avg)
    grad_var_ops.append(avg_op)
    with tf.control_dependencies([avg_op]):
      self._grad_avg = [
        self._moving_averager.average(val) for val in tensor_to_avg]
      self._grad_avg_squared = [tf.square(val) for val in self._grad_avg]
    self._grad_var = tf.maximum(
      tf.constant(EPS, dtype=self._grad_norm_squared_avg.dtype),
      self._grad_norm_squared_avg
      - tf.add_n([tf.reduce_sum(val) for val in self._grad_avg_squared] ) )
    if self._sparsity_debias:
      self._grad_var *= self._sparsity_avg
    return grad_var_ops

  def dist_to_opt(self):
    """
    Build the ops estimating the distance to the optimum.

    The single-step estimate E[||g||] / E[||g||^2] is itself smoothed by
    the moving averager into `self._dist_to_opt_avg`.

    Returns:
      A list holding the two moving-average update ops.
    """
    dist_to_opt_ops = []
    # running average of the norm of gradient
    self._grad_norm = tf.sqrt(self._grad_norm_squared)
    avg_op = self._moving_averager.apply([self._grad_norm, ])
    dist_to_opt_ops.append(avg_op)
    with tf.control_dependencies([avg_op]):
      self._grad_norm_avg = self._moving_averager.average(
        self._grad_norm)
      # single iteration distance estimation
      # note that self._grad_norm_avg is per variable
      self._dist_to_opt = (self._grad_norm_avg
                 / (self._grad_norm_squared_avg + EPS) )
    # running average of distance
    avg_op = self._moving_averager.apply([self._dist_to_opt])
    dist_to_opt_ops.append(avg_op)
    with tf.control_dependencies([avg_op]):
      self._dist_to_opt_avg = tf.identity(
        self._moving_averager.average(self._dist_to_opt))
      if self._sparsity_debias:
        self._dist_to_opt_avg /= (tf.sqrt(self._sparsity_avg) + EPS)
    return dist_to_opt_ops

  def grad_sparsity(self):
    """
    Build the ops estimating the fraction of non-zero gradient entries.

    Sets `self._sparsity` (this step) and `self._sparsity_avg` (its moving
    average).

    Returns:
      The moving-average update op (a single op, not a list).
    """
    # If the sparse minibatch gradient has 10 percent of its entries
    # non-zero, its sparsity is 0.1.
    # The norm of dense gradient averaged from full dataset
    # are roughly estimated norm of minibatch
    # sparse gradient norm * sqrt(sparsity)
    # An extension maybe only correct the sparse blob.
    non_zero_cnt = tf.add_n([tf.count_nonzero(g) for g in self._grads])
    all_entry_cnt = tf.add_n([tf.size(g) for g in self._grads])
    self._sparsity = tf.cast(non_zero_cnt, self._grads[0].dtype) \
      / tf.cast(all_entry_cnt, self._grads[0].dtype)
    avg_op = self._moving_averager.apply([self._sparsity, ])
    with tf.control_dependencies([avg_op]):
      self._sparsity_avg = self._moving_averager.average(self._sparsity)
    return avg_op

  def before_apply(self):
    """
    Build every statistic-update op needed before tuning lr and mu.

    Creates the moving averager and the squared gradient norms (total and
    smoothed). It then builds the sparsity, curvature range, gradient
    variance and distance-to-optimum estimators.

    Returns:
      A `tf.group` of all the statistic-update ops.
    """
    self._moving_averager = tf.train.ExponentialMovingAverage(
      decay=self._beta, zero_debias=self._zero_debias)
    assert self._grads is not None and len(self._grads) > 0
    before_apply_ops = []

    # get per var g**2 and norm**2
    self._grad_squared = []
    for v, g in zip(self._tvars, self._grads):
      if g is None:
        continue
      with ops.colocate_with(v):
        self._grad_squared.append(tf.square(g))
    self._grad_norm_squared = [
      tf.reduce_sum(grad_squared) for grad_squared in self._grad_squared]

    if self._sparsity_debias:
      avg_op_sparsity = self.grad_sparsity()
      before_apply_ops.append(avg_op_sparsity)

    # the following running average on squared norm of gradient is shared
    # by `grad_variance` and `dist_to_opt`
    avg_op = self._moving_averager.apply(self._grad_norm_squared)
    with tf.control_dependencies([avg_op]):
      self._grad_norm_squared_avg = [self._moving_averager.average(val)
                                     for val in self._grad_norm_squared]
      # collapse the per-variable lists into scalar totals
      self._grad_norm_squared = tf.add_n(self._grad_norm_squared)
      self._grad_norm_squared_avg = tf.add_n(self._grad_norm_squared_avg)
    before_apply_ops.append(avg_op)

    with tf.control_dependencies([avg_op]):
      curv_range_ops = self.curvature_range()
      before_apply_ops += curv_range_ops
      grad_var_ops = self.grad_variance()
      before_apply_ops += grad_var_ops
      dist_to_opt_ops = self.dist_to_opt()
      before_apply_ops += dist_to_opt_ops
    return tf.group(*before_apply_ops)

  def get_lr_tensor(self):
    """
    Return the tuned learning rate tensor.

    The learning rate is (1 - sqrt(mu))^2 / h_min. It is linearly warmed up:
    scaled by (step + 1) / (10 * curv_win_width) until that ratio reaches 1.
    """
    lr = (1.0 - tf.sqrt(self._mu))**2 / (self._h_min + EPS)
    lr = tf.minimum(
      lr,
      lr * (tf.to_float(self._global_step) + 1.0) / 10.0
      / tf.to_float(tf.constant(self._curv_win_width) ) )
    return lr

  def get_cubic_root(self):
    """
    Return x = sqrt(mu), the real root of YellowFin's cubic.

    Uses Vieta's substitution.
    """
    # We have the equation x^2 D^2 + (1-x)^4 * C / h_min^2
    # where x = sqrt(mu).
    # We substitute x, which is sqrt(mu), with x = y + 1.
    # It gives y^3 + py = q
    # where p = (D^2 h_min^2)/(2*C) and q = -p.
    # We use the Vieta's substitution to compute the root.
    # There is only one real solution y; it is expected in [-1, 0], so
    # that x = y + 1 lies in [0, 1].
    # http://mathworld.wolfram.com/VietasSubstitution.html
    # assert_array = \
    #   [tf.Assert(tf.logical_not(tf.is_nan(self._dist_to_opt_avg) ), [self._dist_to_opt_avg,]),
    #   tf.Assert(tf.logical_not(tf.is_nan(self._h_min) ), [self._h_min,]),
    #   tf.Assert(tf.logical_not(tf.is_nan(self._grad_var) ), [self._grad_var,]),
    #   tf.Assert(tf.logical_not(tf.is_inf(self._dist_to_opt_avg) ), [self._dist_to_opt_avg,]),
    #   tf.Assert(tf.logical_not(tf.is_inf(self._h_min) ), [self._h_min,]),
    #   tf.Assert(tf.logical_not(tf.is_inf(self._grad_var) ), [self._grad_var,])]
    # with tf.control_dependencies(assert_array):
    # EPS in the numerator to prevent momentum being exactly one in case of 0 gradient
    p = ((self._dist_to_opt_avg + EPS)**2 * (self._h_min + EPS)**2
         / 2 / (self._grad_var + EPS))
    w3 = (-tf.sqrt(p**2 + 4.0 / 27.0 * p**3) - p) / 2.0
    # real cube root that keeps the sign of w3
    w = tf.sign(w3) * tf.pow(tf.abs(w3), 1.0/3.0)
    y = w - p / 3.0 / (w + EPS)
    x = y + 1
    return x

  def get_mu_tensor(self):
    """
    Return the tuned momentum tensor.

    The momentum is the maximum of two values:
      - the squared cubic root;
      - ((sqrt(dr) - 1) / (sqrt(dr) + 1))^2, where dr = h_max / h_min is
        the curvature dynamic range, floored at 1 + EPS.
    """
    root = self.get_cubic_root()
    dr = tf.maximum( (self._h_max + EPS) / (self._h_min + EPS), 1.0 + EPS)
    mu = tf.maximum(
      root**2, ((tf.sqrt(dr) - 1) / (tf.sqrt(dr) + 1))**2)
    return mu

  def update_hyper_param(self):
    """
    Build the ops that retune and store the learning rate and momentum.

    On the first step (`self._do_tune` False), the current variable values
    are kept. Otherwise fresh values are computed and, unless
    `use_unsmoothed_lr_mu`, blended with the old ones using `beta`.

    Returns:
      A `tf.group` of the assignment ops.
    """
    assign_hyper_ops = []
    self._mu = tf.identity(tf.cond(
      self._do_tune, lambda: self.get_mu_tensor(),
      lambda: self._mu_var))
    with tf.control_dependencies([self._mu]):
      self._lr = tf.identity(tf.cond(
        self._do_tune, lambda: self.get_lr_tensor(),
        lambda: self._lr_var))

    with tf.control_dependencies([self._mu, self._lr]):
      if self._use_unsmoothed_lr_mu:
        assign_hyper_ops.append(tf.assign(self._mu_var, self._mu) )
        assign_hyper_ops.append(tf.assign(self._lr_var, self._lr) )
      else:
        self._mu = self._beta * self._mu_var + (1 - self._beta) * self._mu
        self._lr = self._beta * self._lr_var + (1 - self._beta) * self._lr
        with tf.control_dependencies([self._mu, self._lr] ):
          assign_hyper_ops.append(tf.assign(self._mu_var, self._mu) )
          assign_hyper_ops.append(tf.assign(self._lr_var, self._lr) )
    assign_hyper_op = tf.group(*assign_hyper_ops)
    return assign_hyper_op

  def get_name(self):
    """Return the name of the underlying MomentumOptimizer."""
    return self._optimizer.get_name()

  def apply_gradients(self, grads_tvars, global_step=None, name=None):
    """
    Update statistics, retune lr/mu, then apply the gradients.

    Args:
      grads_tvars: Iterable of (gradient, variable) pairs. Pairs whose
        gradient is None are ignored.
      global_step: Optional variable forwarded to the underlying
        optimizer's `apply_gradients`.
      name: Optional name forwarded to the underlying optimizer.

    Returns:
      A `tf.group` of all ops built for this training step.

    Raises:
      ValueError: If no pair has a non-None gradient.
    """
    grads_tvars = list(grads_tvars)
    with_grad = [(g, t) for g, t in grads_tvars if g is not None]
    if not with_grad:
      raise ValueError(
        "No gradients provided for any variable: %s." %
        ([str(t) for _, t in grads_tvars],))
    self._grads, self._tvars = zip(*with_grad)

    # for manual gradient clipping
    if self._clip_thresh_var is not None:
      self._grads, self._grads_norm = tf.clip_by_global_norm(
        self._grads, self._clip_thresh_var)

    # loosely adaptive clipping of gradient in case exploding gradient ruins statistics
    if self._use_adapt_grad_clip:
      thresh = tf.cond(self._do_tune,
        lambda: tf.sqrt(self._stat_protect_fac * self._adapt_grad_clip_thresh**2),
        lambda: tf.to_float(tf.constant(LARGE_FLOAT_VAL)))
      self._grads, self._grads_norm = tf.clip_by_global_norm(self._grads, thresh)

    with tf.variable_scope("before_apply"):
      before_apply_op = self.before_apply()

    with tf.variable_scope("update_hyper"):
      with tf.control_dependencies([before_apply_op]):
        update_hyper_op = self.update_hyper_param()

    with tf.variable_scope("apply_updates"):
      with tf.control_dependencies([update_hyper_op]):

        # clip exploding gradient according to h_max
        if self._use_adapt_grad_clip:
          thresh = tf.cond(tf.greater(tf.global_norm(self._grads),
            self._adapt_grad_clip_thresh),
            lambda: self._adapt_grad_clip_target_val,
            lambda: tf.to_float(tf.constant(LARGE_FLOAT_VAL)))
          self._grads, self._grads_norm = tf.clip_by_global_norm(
            self._grads, thresh)

        apply_grad_op = self._optimizer.apply_gradients(
          zip(self._grads, self._tvars), global_step, name)

    with tf.control_dependencies([apply_grad_op]):
      self._increment_global_step_op = tf.assign(
        self._global_step, self._global_step + 1)

      # thresholds used by the adaptive clipping on the next step
      self._adapt_grad_clip_thresh_op = \
        tf.assign(self._adapt_grad_clip_thresh, tf.sqrt(self._h_max) )
      self._adapt_grad_clip_target_val_op = \
        tf.assign(self._adapt_grad_clip_target_val, tf.sqrt(self._h_max) )
      # self._adapt_grad_clip_target_val_op = \
      #   tf.assign(self._adapt_grad_clip_target_val, tf.sqrt(tf.sqrt(self._h_max * self._h_min)))

    return tf.group(before_apply_op, update_hyper_op, apply_grad_op,
                    self._adapt_grad_clip_thresh_op, self._adapt_grad_clip_target_val_op,
                    self._increment_global_step_op)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """
    Compute gradients of `loss` via the underlying MomentumOptimizer.

    Returns:
      A list of (gradient, variable) pairs.
    """
    return self._optimizer.compute_gradients(
      loss, var_list=var_list,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      grad_loss=grad_loss)

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP,
               aggregation_method=None,
               colocate_gradients_with_ops=False,
               name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradient before
    applying them, call `tf.gradients()` and `self.apply_gradients()`
    explicitly instead of using this function.

    Adapted from Tensorflow Optimizer base class member function.

    Raises:
      ValueError: If no variable receives a gradient from `loss`.
    """
    grads_and_vars = self.compute_gradients(
      loss, var_list=var_list,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
        "No gradients provided for any variable, check your graph for "
        "ops that do not support gradients, between variables "
        "%s and loss %s." %
        ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step, name)

  def get_slot(self, var, name):
    """
    Return a slot named `name` created for `var` by
    the underlying MomentumOptimizer.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    return self._optimizer.get_slot(var, name)

  def get_slot_names(self):
    """
    Return a list of the names of the slots created by the
    underlying MomentumOptimizer.

    Returns:
      A list of strings.
    """
    return self._optimizer.get_slot_names()