# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Neural Network Gaussian Process (nngp) kernel computation.

Implementation based on
"Deep Neural Networks as Gaussian Processes" by
Jaehoon Lee, Yasaman Bahri, Roman Novak, Samuel S. Schoenholz,
Jeffrey Pennington, Jascha Sohl-Dickstein
arXiv:1711.00165 (https://arxiv.org/abs/1711.00165).
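
Schematically, the kernel is built layer by layer via the recursion from the
paper (with phi the point-wise nonlinearity and d_in the input dimension):

  K^0(x, x') = bias_var + weight_var * (x . x') / d_in
  K^l(x, x') = bias_var + weight_var * E[phi(z(x)) phi(z(x'))],
      z ~ GP(0, K^{l-1}).

The expectation is evaluated by numerical Gaussian integration on a
precomputed (variance, correlation) grid and linear interpolation.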
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import os

import numpy as np
import tensorflow as tf

import interp

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_boolean("use_precomputed_grid", True,
                     "Option to save/load pre-computed grid")
flags.DEFINE_integer(
    "fraction_of_int32", 32,
    "allow batches at most of size int32.max / fraction_of_int32")


class NNGPKernel(object):
  """The iterative covariance Kernel for Neural Network Gaussian Process.

  Args:
    depth: int, number of hidden layers in corresponding NN.
    nonlin_fn: tf op corresponding to the point-wise non-linearity in the
      corresponding NN, e.g. tf.nn.relu, tf.nn.sigmoid,
      lambda x: x * tf.nn.sigmoid(x), ...
    weight_var: initial value for the weight_variances parameter.
    bias_var: initial value for the bias_variance parameter.
    n_gauss: int, number of Gaussian integration grid points. Choose an odd
      integer so that there is a grid point at 0.
    n_var: Number of variance grid points.
    n_corr: Number of correlation grid points.
    max_var: float, largest pre-activation variance covered by the grid.
    max_gauss: float, range (-max_gauss, max_gauss) for Gaussian integration.
    use_fixed_point_norm: bool, normalize input to the variance fixed point.
      Defaults to False, normalizing input to unit variance over the input
      dimension.
    grid_path: str, directory for saving/loading the precomputed grid. Must be
      specified when --use_precomputed_grid is set.
    sess: optional tf.Session, stored on the instance.
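
  Example usage (a sketch; `train_x`/`test_x` stand in for input tensors, and
  the default --use_precomputed_grid flag requires a writable `grid_path`):

    kernel = NNGPKernel(depth=3, nonlin_fn=tf.nn.relu, grid_path="/tmp/nngp")
    k_train_train = kernel.k_full(train_x)         # [n_train, n_train]
    k_test_train = kernel.k_full(test_x, train_x)  # [n_test, n_train]
    k_diag_train = kernel.k_diag(train_x)          # [n_train]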
  """

  def __init__(self,
               depth=1,
               nonlin_fn=tf.tanh,
               weight_var=1.,
               bias_var=1.,
               n_gauss=101,
               n_var=151,
               n_corr=131,
               max_var=100,
               max_gauss=100,
               use_fixed_point_norm=False,
               grid_path=None,
               sess=None):
    self.depth = depth
    self.weight_var = weight_var
    self.bias_var = bias_var
    self.use_fixed_point_norm = use_fixed_point_norm
    self.sess = sess
    if FLAGS.use_precomputed_grid and (grid_path is None):
      raise ValueError("grid_path must be specified to use precomputed grid.")
    self.grid_path = grid_path

    self.nonlin_fn = nonlin_fn
    (self.var_aa_grid, self.corr_ab_grid, self.qaa_grid,
     self.qab_grid) = self.get_grid(n_gauss, n_var, n_corr, max_var, max_gauss)

    if self.use_fixed_point_norm:
      self.var_fixed_point_np, self.var_fixed_point = self.get_var_fixed_point()

  def get_grid(self, n_gauss, n_var, n_corr, max_var, max_gauss):
    """Get covariance grid by loading or computing a new one.
    """
    # File configuration for precomputed grid
    if FLAGS.use_precomputed_grid:
      grid_path = self.grid_path
      # TODO(jaehlee) np.save has a broadcasting error when n_var == n_corr.
      if n_var == n_corr:
        n_var += 1
      grid_file_name = "grid_{0:s}_ng{1:d}_ns{2:d}_nc{3:d}".format(
          self.nonlin_fn.__name__, n_gauss, n_var, n_corr)
      grid_file_name += "_mv{0:d}_mg{1:d}".format(max_var, max_gauss)

    # Load grid file if it exists already
    if (FLAGS.use_precomputed_grid and
        tf.gfile.Exists(os.path.join(grid_path, grid_file_name))):
      with tf.gfile.Open(os.path.join(grid_path, grid_file_name), "rb") as f:
        grid_data_np = np.load(f)
        tf.logging.info("Loaded interpolation grid from %s"%
                        os.path.join(grid_path, grid_file_name))
        grid_data = (tf.convert_to_tensor(grid_data_np[0], dtype=tf.float64),
                     tf.convert_to_tensor(grid_data_np[1], dtype=tf.float64),
                     tf.convert_to_tensor(grid_data_np[2], dtype=tf.float64),
                     tf.convert_to_tensor(grid_data_np[3], dtype=tf.float64))

    else:
      tf.logging.info("Generating interpolation grid...")
      grid_data = _compute_qmap_grid(self.nonlin_fn, n_gauss, n_var, n_corr,
                                     max_var=max_var, max_gauss=max_gauss)
      if FLAGS.use_precomputed_grid:
        with tf.Session() as sess:
          grid_data_np = sess.run(grid_data)
        tf.gfile.MakeDirs(grid_path)
        with tf.gfile.Open(os.path.join(grid_path, grid_file_name), "wb") as f:
          np.save(f, grid_data_np)

        with tf.gfile.Open(os.path.join(grid_path, grid_file_name), "rb") as f:
          grid_data_np = np.load(f)
          tf.logging.info("Loaded interpolation grid from %s"%
                          os.path.join(grid_path, grid_file_name))
          grid_data = (tf.convert_to_tensor(grid_data_np[0], dtype=tf.float64),
                       tf.convert_to_tensor(grid_data_np[1], dtype=tf.float64),
                       tf.convert_to_tensor(grid_data_np[2], dtype=tf.float64),
                       tf.convert_to_tensor(grid_data_np[3], dtype=tf.float64))

    return grid_data

  def get_var_fixed_point(self):
    with tf.name_scope("get_var_fixed_point"):
      # Normalized inputs have variance 1, so the layer-0 variance is
      # weight_var * 1 + bias_var; start the fixed-point iteration there.
      current_qaa = self.weight_var * tf.constant(
          [1.], dtype=tf.float64) + self.bias_var

      diff = 1.
      prev_qaa_np = 1.
      it = 0
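      # Iterate the variance map q <- weight_var * F(q) + bias_var (F is the
      # Gaussian-integral lookup on the grid) until it converges to the
      # fixed-point variance, or 300 iterations elapse.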
      while diff > 1e-6 and it < 300:
        samp_qaa = interp.interp_lin(
            self.var_aa_grid, self.qaa_grid, current_qaa)
        samp_qaa = self.weight_var * samp_qaa + self.bias_var
        current_qaa = samp_qaa

        with tf.Session() as sess:
          current_qaa_np = sess.run(current_qaa)
        diff = np.abs(current_qaa_np - prev_qaa_np)
        it += 1
        prev_qaa_np = current_qaa_np
      return current_qaa_np, current_qaa

  def k_diag(self, input_x, return_full=True):
    """Iteratively building the diagonal part (variance) of the NNGP kernel.

    Args:
      input_x: tensor of inputs of shape [num_data, input_dim].
      return_full: bool. If True, the output has shape [num_data]; otherwise a
        scalar is returned (all entries are equal since inputs are normalized).

    Side effect: sets self.layer_qaa_dict, mapping {layer index: qaa at that
      layer}.

    Returns:
      qaa: variance at the output.
    """
    with tf.name_scope("Kdiag"):
      # Layer-0 variance: the fixed-point value if requested, otherwise
      # weight_var * 1 + bias_var for unit-variance inputs.
      if self.use_fixed_point_norm:
        current_qaa = self.var_fixed_point
      else:
        current_qaa = self.weight_var * tf.convert_to_tensor(
            [1.], dtype=tf.float64) + self.bias_var
      self.layer_qaa_dict = {0: current_qaa}
      for l in range(self.depth):
        with tf.name_scope("layer_%d" % l):
          samp_qaa = interp.interp_lin(
              self.var_aa_grid, self.qaa_grid, current_qaa)
          samp_qaa = self.weight_var * samp_qaa + self.bias_var
          self.layer_qaa_dict[l + 1] = samp_qaa
          current_qaa = samp_qaa

      if return_full:
        qaa = tf.tile(current_qaa[:1], ([input_x.shape[0].value]))
      else:
        qaa = current_qaa[0]
      return qaa

  def k_full(self, input1, input2=None):
    """Iteratively building the full NNGP kernel.
    """
    input1 = self._input_layer_normalization(input1)
    if input2 is None:
      input2 = input1
    else:
      input2 = self._input_layer_normalization(input2)

    with tf.name_scope("k_full"):
      cov_init = tf.matmul(
          input1, input2, transpose_b=True) / input1.shape[1].value

      self.k_diag(input1)
      q_aa_init = self.layer_qaa_dict[0]

      q_ab = cov_init
      q_ab = self.weight_var * q_ab + self.bias_var
      corr = q_ab / q_aa_init[0]

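      # The 2-D interpolation below materializes intermediate tensors whose
      # size scales with the number of entries in `corr`, so large inputs are
      # processed in row batches to keep indices within int32 range (see
      # _get_batch_size_and_count and --fraction_of_int32).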
      if FLAGS.fraction_of_int32 > 1:
        batch_size, batch_count = self._get_batch_size_and_count(input1, input2)
        with tf.name_scope("q_ab"):
          q_ab_all = []
          for b_x in range(batch_count):
            with tf.name_scope("batch_%d" % b_x):
              corr_flat_batch = corr[
                  batch_size * b_x : batch_size * (b_x + 1), :]
              corr_flat_batch = tf.reshape(corr_flat_batch, [-1])

              for l in range(self.depth):
                with tf.name_scope("layer_%d" % l):
                  q_aa = self.layer_qaa_dict[l]
                  q_ab = interp.interp_lin_2d(x=self.var_aa_grid,
                                              y=self.corr_ab_grid,
                                              z=self.qab_grid,
                                              xp=q_aa,
                                              yp=corr_flat_batch)

                  q_ab = self.weight_var * q_ab + self.bias_var
                  corr_flat_batch = q_ab / self.layer_qaa_dict[l + 1][0]

              q_ab_all.append(q_ab)

          q_ab_all = tf.parallel_stack(q_ab_all)
      else:
        with tf.name_scope("q_ab"):
          corr_flat = tf.reshape(corr, [-1])
          for l in range(self.depth):
            with tf.name_scope("layer_%d" % l):
              q_aa = self.layer_qaa_dict[l]
              q_ab = interp.interp_lin_2d(x=self.var_aa_grid,
                                          y=self.corr_ab_grid,
                                          z=self.qab_grid,
                                          xp=q_aa,
                                          yp=corr_flat)
              q_ab = self.weight_var * q_ab + self.bias_var
              corr_flat = q_ab / self.layer_qaa_dict[l + 1][0]
          q_ab_all = q_ab

    return tf.reshape(q_ab_all, cov_init.shape, "qab")

  def _input_layer_normalization(self, x):
    """Input normalization to unit variance or fixed point variance.
    """
    with tf.name_scope("input_layer_normalization"):
      # Layer norm, fix to unit variance
      eps = 1e-15
      mean, var = tf.nn.moments(x, axes=[1], keep_dims=True)
      x_normalized = (x - mean) / tf.sqrt(var + eps)
      if self.use_fixed_point_norm:
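        # Rescale so that the first-layer variance weight_var * var(x) +
        # bias_var equals the fixed-point variance.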
        x_normalized *= tf.sqrt(
            (self.var_fixed_point[0] - self.bias_var) / self.weight_var)
      return x_normalized

  def _get_batch_size_and_count(self, input1, input2):
    """Compute batch size and number to split when input size is large.

    Args:
      input1: tensor, input tensor to covariance matrix
      input2: tensor, second input tensor to covariance matrix

    Returns:
      batch_size: int, size of each batch
      batch_count: int, number of batches
    """
    input1_size = input1.shape[0].value
    input2_size = input2.shape[0].value

    batch_size = min(np.iinfo(np.int32).max //
                     (FLAGS.fraction_of_int32 * input2_size), input1_size)
    while input1_size % batch_size != 0:
      batch_size -= 1

    batch_count = input1_size // batch_size
    return batch_size, batch_count


def _fill_qab_slice(idx, z1, z2, var_aa, corr_ab, nonlin_fn):
  """Helper method used for parallel computation for full qab."""
  log_weights_ab_unnorm = -(z1**2 + z2**2 - 2 * z1 * z2 * corr_ab) / (
      2 * var_aa[idx] * (1 - corr_ab**2))
  log_weights_ab = log_weights_ab_unnorm - tf.reduce_logsumexp(
      log_weights_ab_unnorm, axis=[0, 1], keep_dims=True)
  weights_ab = tf.exp(log_weights_ab)

  qab_slice = tf.reduce_sum(
      nonlin_fn(z1) * nonlin_fn(z2) * weights_ab, axis=[0, 1])
  qab_slice = tf.Print(qab_slice, [idx], "Generating slice: ")
  return qab_slice


def _compute_qmap_grid(nonlin_fn,
                       n_gauss,
                       n_var,
                       n_corr,
                       log_spacing=False,
                       min_var=1e-8,
                       max_var=100.,
                       max_corr=0.99999,
                       max_gauss=10.):
  """Construct graph for covariance grid to use for kernel computation.

  Given variance and correlation (or covariance) of pre-activation, perform
  Gaussian integration to get covariance of post-activation.

  Args:
    nonlin_fn: tf op corresponding to the point-wise non-linearity in the
      corresponding NN, e.g. tf.nn.relu, tf.nn.sigmoid,
      lambda x: x * tf.nn.sigmoid(x), ...
    n_gauss: int, number of Gaussian integration points with equal spacing
      over (-max_gauss, max_gauss). Choose an odd integer so that there is a
      grid point at 0.
    n_var: int, number of variance grid points.
    n_corr: int, number of correlation grid points.
    log_spacing: bool, whether to use a log-linear instead of a linear
      variance grid.
    min_var: float, smallest variance value of the grid.
    max_var: float, largest variance value of the grid.
    max_corr: float, largest correlation value of the grid. Should be
      slightly smaller than 1.
    max_gauss: float, range (-max_gauss, max_gauss) for Gaussian integration.

  Returns:
    var_grid_pts: tensor of size [n_var], grid points where the variance is
      evaluated.
    corr_grid_pts: tensor of size [n_corr], grid points where the correlation
      is evaluated.
    qaa: tensor of size [n_var], variance of the post-activation at the given
      pre-activation variance.
    qab: tensor of size [n_var, n_corr], covariance of the post-activation at
      the given pre-activation variance and correlation.

  Raises:
    ValueError: if n_gauss is an even integer.
  """
  if n_gauss % 2 != 1:
    raise ValueError("n_gauss=%d should be an odd integer" % n_gauss)

  with tf.name_scope("compute_qmap_grid"):
    min_var = tf.convert_to_tensor(min_var, dtype=tf.float64)
    max_var = tf.convert_to_tensor(max_var, dtype=tf.float64)
    max_corr = tf.convert_to_tensor(max_corr, dtype=tf.float64)
    max_gauss = tf.convert_to_tensor(max_gauss, dtype=tf.float64)

    # Evaluation points for numerical integration over a Gaussian.
    z1 = tf.reshape(tf.linspace(-max_gauss, max_gauss, n_gauss), (-1, 1, 1))
    z2 = tf.transpose(z1, perm=[1, 0, 2])

    if log_spacing:
      var_aa = tf.exp(tf.linspace(tf.log(min_var), tf.log(max_var), n_var))
    else:
      # Evaluation points for pre-activations variance and correlation
      var_aa = tf.linspace(min_var, max_var, n_var)
    corr_ab = tf.reshape(tf.linspace(-max_corr, max_corr, n_corr), (1, 1, -1))

    # compute q_aa
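    # q_aa(v) = E_{z ~ N(0, v)}[nonlin_fn(z)^2], approximated by reweighting
    # the uniform grid z1 with (logsumexp-normalized) Gaussian density weights.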
    log_weights_aa_unnorm = -0.5 * (z1**2 / tf.reshape(var_aa, [1, 1, -1]))
    log_weights_aa = log_weights_aa_unnorm - tf.reduce_logsumexp(
        log_weights_aa_unnorm, axis=[0, 1], keep_dims=True)
    weights_aa = tf.exp(log_weights_aa)
    qaa = tf.reduce_sum(nonlin_fn(z1)**2 * weights_aa, axis=[0, 1])

    # compute q_ab
    # weights to reweight uniform samples by, for q_ab.
    # (weights are probability of z1, z2 under Gaussian
    #  w/ variance var_aa and covariance var_aa*corr_ab)
    # weights_ab will have shape [n_g, n_g, n_v, n_c]
    def fill_qab_slice(idx):
      return _fill_qab_slice(idx, z1, z2, var_aa, corr_ab, nonlin_fn)

    qab = tf.map_fn(
        fill_qab_slice,
        tf.range(n_var),
        dtype=tf.float64,
        parallel_iterations=multiprocessing.cpu_count())

    var_grid_pts = tf.reshape(var_aa, [-1])
    corr_grid_pts = tf.reshape(corr_ab, [-1])

    return var_grid_pts, corr_grid_pts, qaa, qab