import math
import numpy as np
import torch
import copy

import logging

# Debug switch: when True, the tuner reports its internal statistics through
# logging.debug on every step.
DEBUG = True
# eps for numerical stability
eps = 1e-6


class YFOptimizer(object):
  """Self-tuning momentum SGD (YellowFin-style tuner).

  Wraps a torch.optim.SGD instance and, on every step, re-estimates the
  learning rate and momentum from running statistics of the gradients
  (curvature range, gradient variance and distance to the optimum) before
  delegating the actual parameter update to the wrapped SGD optimizer.
  """

  def __init__(self, var_list, lr=0.1, mu=0.0, clip_thresh=None, weight_decay=0.0,
    beta=0.999, curv_win_width=20, zero_debias=True, sparsity_debias=True, delta_mu=0.0,
    auto_clip_fac=None, force_non_inc_step=False, lr_grad_norm_thresh=1.0,
    h_max_log_smooth=False, h_min_log_smooth=False, checkpoint_interval=500):
    """Create the tuner and the SGD optimizer it drives.

    Args:
      var_list: iterable of parameters to optimize. It is materialized into a
        list so that it can be traversed several times.
      lr: python scalar. The initial value of learning rate, we use 1.0 in our paper.
      mu: python scalar. The initial value of momentum, we use 0.0 in our paper.
      clip_thresh: python scalar. The manually-set threshold on the global
        gradient norm (applied with torch.nn.utils.clip_grad_norm_). If None,
        automatic clipping can be carried out instead; it is parameterized by
        auto_clip_fac and is switched off with auto_clip_fac = None.
      weight_decay: python scalar forwarded to the wrapped SGD optimizer.
      beta: python scalar. The smoothing parameter for estimations.
      curv_win_width: number of most recent squared gradient norms kept for
        the curvature-range estimate.
      zero_debias: whether running averages are corrected for their zero
        initialization.
      sparsity_debias: gradient norm and curvature are biased to larger values
        when calculated with sparse gradient. This is useful when the model is
        very sparse, e.g. LSTM with word embedding. For non-sparse CNN, turning
        it off could slightly accelerate the speed.
      delta_mu: for extensions (additional momentum in situations such as
        asynchronous-parallel training). Not necessary in the basic use; it is
        accepted for interface compatibility and not read by this class.
      auto_clip_fac: factor of the automatic clipping threshold
        sqrt(h_max) * auto_clip_fac; None disables automatic clipping.
      force_non_inc_step: in some very rare cases, it is necessary to force
        ||lr * gradient|| to be not increasing dramatically for stableness
        after some iterations. In practice, if turned on, we enforce
        lr * sqrt(smoothed ||grad||^2) to be less than 2x of the minimal value
        of historical value on smoothed || lr * grad ||. This feature is
        turned off by default.
      lr_grad_norm_thresh: upper bound enforced on lr * ||grad|| when
        force_non_inc_step is off.
      h_max_log_smooth: smooth the curvature maximum in log scale.
      h_min_log_smooth: smooth the curvature minimum in log scale.
      checkpoint_interval: period, in iterations, of the in-memory snapshot
        taken by step().

    Other features:
      If you want to manually control the learning rates, the lr factor
      (set_lr_factor / get_lr_factor) is an interface to the outside. It is a
      multiplier for the internal learning rate in YellowFin, helpful when you
      want to do additional hand tuning or some decaying scheme on top of the
      tuned learning rate.
    """
    self._lr = lr
    self._mu = mu
    # we convert var_list from generator to list so that
    # it can be used for multiple times
    self._var_list = list(var_list)
    self._clip_thresh = clip_thresh
    self._auto_clip_fac = auto_clip_fac
    self._beta = beta
    self._curv_win_width = curv_win_width
    self._zero_debias = zero_debias
    self._sparsity_debias = sparsity_debias
    self._force_non_inc_step = force_non_inc_step
    self._optimizer = torch.optim.SGD(self._var_list, lr=self._lr,
      momentum=self._mu, weight_decay=weight_decay)
    self._iter = 0
    # global states are the statistics
    self._global_state = {}

    # for decaying learning rate and etc.
    self._lr_factor = 1.0

    # lr threshold
    self._lr_grad_norm_thresh = lr_grad_norm_thresh

    # smoothing options
    self._h_max_log_smooth = h_max_log_smooth
    self._h_min_log_smooth = h_min_log_smooth

    # checkpoint interval
    self._checkpoint_interval = checkpoint_interval

    # curvature range; defined up front so that state_dict() is usable
    # before the first call to step()
    self._h_min = 0.0
    self._h_max = 0.0

    if DEBUG:
      logging.debug('This message should go to the log file')

  def state_dict(self):
    """Return the tuner state as a dict usable with load_state_dict().

    The model tensors and the global statistics are returned by reference;
    deep-copy the result if an independent snapshot is needed.
    """
    return {
      # for checkpoint saving
      "sgd_state_dict": self._optimizer.state_dict(),
      # for recover model internally in case of numerical issue
      "model_state_list": [p.data
        for group in self._optimizer.param_groups for p in group['params']],
      "global_state": self._global_state,
      "lr_factor": self._lr_factor,
      "iter": self._iter,
      "lr": self._lr,
      "mu": self._mu,
      "clip_thresh": self._clip_thresh,
      "beta": self._beta,
      "curv_win_width": self._curv_win_width,
      "zero_debias": self._zero_debias,
      "h_min": self._h_min,
      "h_max": self._h_max,
    }

  def load_state_dict(self, state_dict):
    """Restore the tuner, the wrapped SGD and the parameter values.

    Args:
      state_dict: a dict with the layout produced by state_dict().
    """
    # for checkpoint saving
    self._optimizer.load_state_dict(state_dict['sgd_state_dict'])
    # for recover model internally if any numerical issue happens
    param_id = 0
    for group in self._optimizer.param_groups:
      for p in group["params"]:
        # copy in place so the parameter keeps its identity and storage
        p.data.copy_(state_dict["model_state_list"][param_id])
        param_id += 1
    self._global_state = state_dict['global_state']
    self._lr_factor = state_dict['lr_factor']
    self._iter = state_dict['iter']
    self._lr = state_dict['lr']
    self._mu = state_dict['mu']
    self._clip_thresh = state_dict['clip_thresh']
    self._beta = state_dict['beta']
    self._curv_win_width = state_dict['curv_win_width']
    self._zero_debias = state_dict['zero_debias']
    self._h_min = state_dict["h_min"]
    self._h_max = state_dict["h_max"]

  def set_lr_factor(self, factor):
    """Set the external multiplier applied to the tuned learning rate."""
    self._lr_factor = factor

  def get_lr_factor(self):
    """Return the external multiplier applied to the tuned learning rate."""
    return self._lr_factor

  def zero_grad(self):
    """Clear the gradients of all optimized parameters."""
    self._optimizer.zero_grad()

  def zero_debias_factor(self):
    """Return 1 - beta^(iter + 1), the zero-initialization debias factor."""
    return 1.0 - self._beta ** (self._iter + 1)

  def zero_debias_factor_delay(self, delay):
    """Return the debias factor for an average that started at iteration `delay`."""
    # for exponentially averaged stat which starts at non-zero iter
    return 1.0 - self._beta ** (self._iter - delay + 1)

  def curvature_range(self):
    """Update self._h_min / self._h_max from the recent squared gradient norms.

    The current squared gradient norm is written into a circular window of
    curv_win_width entries; the minimum and maximum over the filled part of
    the window are exponentially averaged (optionally in log scale) and
    optionally debiased.
    """
    global_state = self._global_state
    beta = self._beta
    if self._iter == 0:
      global_state["curv_win"] = torch.zeros(self._curv_win_width, 1)
      global_state["h_min_avg"] = 0.0
      global_state["h_max_avg"] = 0.0
      self._h_min = 0.0
      self._h_max = 0.0
    curv_win = global_state["curv_win"]
    curv_win[self._iter % self._curv_win_width] = global_state["grad_norm_squared"]
    valid_end = min(self._curv_win_width, self._iter + 1)
    window = curv_win[:valid_end]

    # we use running average over log scale, accelerating
    # h_max / min in the begining to follow the varying trend of curvature.
    if self._h_min_log_smooth:
      h_min_sample = float(torch.min(torch.log(window + eps)))
    else:
      h_min_sample = float(torch.min(window))
    if self._h_max_log_smooth:
      h_max_sample = float(torch.max(torch.log(window + eps)))
    else:
      h_max_sample = float(torch.max(window))
    global_state["h_min_avg"] = global_state["h_min_avg"] * beta + (1 - beta) * h_min_sample
    global_state["h_max_avg"] = global_state["h_max_avg"] * beta + (1 - beta) * h_max_sample

    debias_factor = self.zero_debias_factor() if self._zero_debias else 1.0
    h_min = global_state["h_min_avg"] / debias_factor
    h_max = global_state["h_max_avg"] / debias_factor
    # averages taken in log scale are mapped back to linear scale
    self._h_min = np.exp(h_min) if self._h_min_log_smooth else h_min
    self._h_max = np.exp(h_max) if self._h_max_log_smooth else h_max

    if self._sparsity_debias:
      self._h_min *= self._sparsity_avg
      self._h_max *= self._sparsity_avg

  def grad_variance(self):
    """Update self._grad_var, the estimated gradient variance.

    Computed as (debiased average of ||grad||^2) minus
    ||debiased average of grad||^2, floored at eps.
    """
    global_state = self._global_state
    beta = self._beta
    grad_avg_norm_squared = 0.0
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        state = self._optimizer.state[p]
        # created lazily so a parameter that receives its first gradient
        # after iteration 0 is handled as well
        if "grad_avg" not in state:
          state["grad_avg"] = torch.zeros_like(grad)
        state["grad_avg"].mul_(beta).add_(grad, alpha=1 - beta)
        grad_avg_norm_squared += float(torch.sum(state["grad_avg"] * state["grad_avg"]))

    debias_factor = self.zero_debias_factor() if self._zero_debias else 1.0

    grad_var = global_state['grad_norm_squared_avg'] / debias_factor \
      - grad_avg_norm_squared / (debias_factor**2)
    # in case of negative variance: the two term are using different debias factors
    self._grad_var = max(grad_var, eps)
    if self._sparsity_debias:
      self._grad_var *= self._sparsity_avg

  def dist_to_opt(self):
    """Update self._dist_to_opt, the estimated distance to the optimum.

    It is the running average of (smoothed ||grad||) / (smoothed ||grad||^2).
    """
    global_state = self._global_state
    beta = self._beta
    if self._iter == 0:
      global_state["grad_norm_avg"] = 0.0
      global_state["dist_to_opt_avg"] = 0.0
    global_state["grad_norm_avg"] = \
      global_state["grad_norm_avg"] * beta + (1 - beta) * math.sqrt(global_state["grad_norm_squared"])
    global_state["dist_to_opt_avg"] = \
      global_state["dist_to_opt_avg"] * beta \
      + (1 - beta) * global_state["grad_norm_avg"] / (global_state['grad_norm_squared_avg'] + eps)
    if self._zero_debias:
      self._dist_to_opt = global_state["dist_to_opt_avg"] / self.zero_debias_factor()
    else:
      self._dist_to_opt = global_state["dist_to_opt_avg"]
    if self._sparsity_debias:
      self._dist_to_opt /= (np.sqrt(self._sparsity_avg) + eps)

  def grad_sparsity(self):
    """Update self._sparsity_avg, the smoothed fraction of non-zero gradient entries."""
    global_state = self._global_state
    if self._iter == 0:
      global_state["sparsity_avg"] = 0.0
    non_zero_cnt = 0.0
    all_entry_cnt = 0.0
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        grad_non_zero = grad.nonzero()
        if grad_non_zero.dim() > 0:
          non_zero_cnt += grad_non_zero.size()[0]
        all_entry_cnt += torch.numel(grad)
    # no gradient entries at all: report zero density instead of dividing by zero
    sparsity = non_zero_cnt / float(all_entry_cnt) if all_entry_cnt else 0.0
    beta = self._beta
    global_state["sparsity_avg"] = beta * global_state["sparsity_avg"] \
      + (1 - beta) * sparsity
    self._sparsity_avg = \
      global_state["sparsity_avg"] / self.zero_debias_factor()
    if DEBUG:
      logging.debug("sparsity %f, sparsity avg %f", sparsity, self._sparsity_avg)

  def lr_grad_norm_avg(self):
    """Track the smoothed log of ||lr * grad|| and its historical minimum.

    Also maintains the smoothed log of ||grad||^2. Both are consumed by
    update_hyper_param() when force_non_inc_step is on.
    """
    # this is for enforcing lr * grad_norm not
    # increasing dramatically in case of instability.
    #  Not necessary for basic use.
    # NOTE(review): the debias factor used below counts from iteration 0 even
    # though these averages start later; zero_debias_factor_delay() may have
    # been intended here -- confirm before changing.
    global_state = self._global_state
    beta = self._beta
    first_call = "lr_grad_norm_avg" not in global_state
    if first_call:
      global_state['grad_norm_squared_avg_log'] = 0.0
      global_state["lr_grad_norm_avg"] = 0.0
    global_state['grad_norm_squared_avg_log'] = \
      global_state['grad_norm_squared_avg_log'] * beta \
      + (1 - beta) * np.log(global_state['grad_norm_squared'] + eps)
    global_state["lr_grad_norm_avg"] = global_state["lr_grad_norm_avg"] * beta \
      + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared']) + eps)
    # we monitor the minimal smoothed ||lr * grad||
    current = np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor())
    if first_call:
      global_state["lr_grad_norm_avg_min"] = current
    else:
      global_state["lr_grad_norm_avg_min"] = \
        min(global_state["lr_grad_norm_avg_min"], current)

  def before_apply(self):
    """Refresh all gradient statistics and the smoothed lr / momentum.

    Must be called after the gradients are populated and before the wrapped
    optimizer applies the update.
    """
    # compute running average of gradient and norm of gradient
    beta = self._beta
    global_state = self._global_state
    if self._iter == 0:
      global_state["grad_norm_squared_avg"] = 0.0

    global_state["grad_norm_squared"] = 0.0
    for group_id, group in enumerate(self._optimizer.param_groups):
      for p_id, p in enumerate(group['params']):
        if p.grad is None:
          continue
        grad = p.grad.data
        # kept as a python float so every derived statistic is a plain scalar
        param_grad_norm_squared = float(torch.sum(grad * grad))
        global_state['grad_norm_squared'] += param_grad_norm_squared

        if DEBUG:
          logging.debug("Iteration  %f", self._iter)
          logging.debug("param grad squared gid %d, pid %d, %f, %f", group_id, p_id, param_grad_norm_squared,
            np.log(param_grad_norm_squared + eps) / np.log(10))

    global_state['grad_norm_squared_avg'] = \
      global_state['grad_norm_squared_avg'] * beta + (1 - beta) * global_state['grad_norm_squared']
    if DEBUG:
      logging.debug("overall grad norm squared %f, %f",
        global_state['grad_norm_squared'], np.log(global_state['grad_norm_squared'] + eps) / np.log(10))

    # the sparsity estimate has to come first: the other statistics scale by it
    if self._sparsity_debias:
      self.grad_sparsity()

    self.curvature_range()
    self.grad_variance()
    self.dist_to_opt()

    if DEBUG:
      logging.debug("h_min %f, %f", self._h_min, np.log(self._h_min + eps))
      logging.debug("h_max %f, %f", self._h_max, np.log(self._h_max + eps))
      logging.debug("dist %f, %f", self._dist_to_opt, np.log(self._dist_to_opt + eps))
      logging.debug("var %f, %f", self._grad_var, np.log(self._grad_var + eps))

    if self._iter > 0:
      # momentum first: the learning rate target is derived from it
      self.get_mu()
      self.get_lr()

      self._lr = beta * self._lr + (1 - beta) * self._lr_t
      self._mu = beta * self._mu + (1 - beta) * self._mu_t

      if DEBUG:
        logging.debug("lr_t %f", self._lr_t)
        logging.debug("mu_t %f", self._mu_t)
        logging.debug("lr %f", self._lr)
        logging.debug("mu %f", self._mu)

  def get_lr(self):
    """Set self._lr_t, the learning rate target, from self._mu_t and self._h_min."""
    self._lr_t = (1.0 - math.sqrt(self._mu_t))**2 / (self._h_min + eps)

  def get_cubic_root(self):
    """Return x = sqrt(mu) solving the cubic that defines the momentum target."""
    # We have the equation x^2 D^2 + (1-x)^4 * C / h_min^2
    # where x = sqrt(mu).
    # We substitute x, which is sqrt(mu), with x = y + 1.
    # It gives y^3 + py = q
    # where p = (D^2 h_min^2)/(2*C) and q = -p.
    # We use the Vieta's substution to compute the root.
    # There is only one real solution y (which is in [-1, 0], so x is in [0, 1] ).
    # eps in the numerator is to prevent momentum = 1 in case of zero gradient
    p = (self._dist_to_opt + eps)**2 * (self._h_min + eps)**2 / 2 / (self._grad_var + eps)
    w3 = (-math.sqrt(p**2 + 4.0 / 27.0 * p**3) - p) / 2.0
    w = math.copysign(1.0, w3) * math.pow(math.fabs(w3), 1.0/3.0)
    y = w - p / 3.0 / (w + eps)
    x = y + 1

    if DEBUG:
      logging.debug("p %f, den %f", p, self._grad_var + eps)
      logging.debug("w3 %f ", w3)
      logging.debug("y %f, den %f", y, w + eps)

    return x

  def get_mu(self):
    """Set self._mu_t, the momentum target.

    It is the larger of the squared cubic root and the value
    ((sqrt(h_max / h_min) - 1) / (sqrt(h_max / h_min) + 1))^2.
    """
    root = self.get_cubic_root()
    dr = (self._h_max + eps) / (self._h_min + eps)
    self._mu_t = max(root**2, ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1))**2)

  def update_hyper_param(self):
    """Push the tuned momentum and learning rate into the wrapped SGD groups."""
    lr = None
    if not self._force_non_inc_step:
      # keep lr * ||grad|| below lr_grad_norm_thresh
      lr = min(self._lr * self._lr_factor,
        self._lr_grad_norm_thresh / (math.sqrt(self._global_state["grad_norm_squared"]) + eps))
    elif self._iter > self._curv_win_width:
      # force to guarantee lr * grad_norm not increasing dramatically.
      # Not necessary for basic use. Please refer to the comments
      # in YFOptimizer.__init__ for more details
      # The statistics are refreshed once per step, not once per param group.
      self.lr_grad_norm_avg()
      debias_factor = self.zero_debias_factor()
      lr = min(self._lr * self._lr_factor,
        2.0 * self._global_state["lr_grad_norm_avg_min"]
        / (np.sqrt(np.exp(self._global_state['grad_norm_squared_avg_log'] / debias_factor)) + eps))

    for group in self._optimizer.param_groups:
      group['momentum'] = self._mu
      # lr stays untouched while the non-increasing mode is still warming up
      if lr is not None:
        group['lr'] = lr

  def auto_clip_thresh(self):
    """Return the automatic clipping threshold sqrt(h_max) * auto_clip_fac."""
    # Heuristic to automatically prevent sudden exploding gradient
    # Not necessary for basic use.
    return math.sqrt(self._h_max) * self._auto_clip_fac

  def step(self):
    """Clip gradients, retune lr / momentum and apply one SGD update."""
    # Weight decay is not added to the gradients here: it was passed to the
    # wrapped SGD optimizer at construction time, which applies it itself.
    if self._clip_thresh is not None:
      torch.nn.utils.clip_grad_norm_(self._var_list, self._clip_thresh)
    elif self._iter != 0 and self._auto_clip_fac is not None:
      # do not clip the first iteration
      torch.nn.utils.clip_grad_norm_(self._var_list, self.auto_clip_thresh())

    # before appply
    self.before_apply()

    # update learning rate and momentum
    self.update_hyper_param()

    # apply update
    self._optimizer.step()

    # periodically save model and states
    if self._iter % self._checkpoint_interval == 1:
      self._state_checkpoint = copy.deepcopy(self.state_dict())

    self._iter += 1
