"""Memory module for Kanerva Machines.

  Functions of the module always take inputs with shape:
  [seq_length, batch_size, ...]

  Examples:

    # Initialisation
    memory = KanervaMemory(code_size=100, memory_size=32)
    prior_memory = memory.get_prior_state(batch_size)

    # Update memory posterior
    posterior_memory, _, _, _ = memory.update_state(z_episode, prior_memory)

    # Read from the memory using cues z_q
    read_z, dkl_w = memory.read_with_z(z_q, posterior_memory)

    # Compute the KL-divergence between posterior and prior memory
    dkl_M = memory.get_dkl_total(posterior_memory)
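
    # Sample w from its prior, e.g. for generating from the memory
    # (`seq_length` and `batch_size` are illustrative placeholders)
    w_prior = memory.sample_prior_w(seq_length, batch_size)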
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

MemoryState = collections.namedtuple(
    'MemoryState',
    # Mean of memory slots, [batch_size, memory_size, code_size]
    # Covariance of memory slots, [batch_size, memory_size, memory_size]
    ('M_mean', 'M_cov'))

EPSILON = 1e-6


# disable lint warnings for cleaner algebraic expressions
# pylint: disable=invalid-name
class KanervaMemory(snt.AbstractModule):
  """A memory-based generative model."""

  def __init__(self,
               code_size,
               memory_size,
               num_opt_iters=1,
               w_prior_stddev=1.0,
               obs_noise_stddev=1.0,
               sample_w=False,
               sample_M=False,
               name='KanervaMemory'):
    """Initialise the memory module.

    Args:
      code_size: Integer specifying the size of each encoded input.
      memory_size: Integer specifying the total number of rows in the memory.
      num_opt_iters: Integer specifying the number of optimisation iterations.
      w_prior_stddev: Float specifying the standard deviation of w's prior.
      obs_noise_stddev: Float specifying the standard deviation of the
        observational noise.
      sample_w: Boolean specifying whether to sample w or simply take its mean.
      sample_M: Boolean specifying whether to sample M or simply take its mean.
      name: String specifying the name of this module.
    """
    super(KanervaMemory, self).__init__(name=name)
    self._memory_size = memory_size
    self._code_size = code_size
    self._num_opt_iters = num_opt_iters
    self._sample_w = sample_w
    self._sample_M = sample_M
    self._w_prior_stddev = tf.constant(w_prior_stddev)

    with self._enter_variable_scope():
      log_w_stddev = snt.TrainableVariable(
          [], name='w_stddev',
          initializers={'w': tf.constant_initializer(np.log(0.3))})()
      if obs_noise_stddev > 0.0:
        self._obs_noise_stddev = tf.constant(obs_noise_stddev)
      else:
        log_obs_stddev = snt.TrainableVariable(
            [], name='obs_stddev',
            initializers={'w': tf.constant_initializer(np.log(1.0))})()
        self._obs_noise_stddev = tf.exp(log_obs_stddev)
    self._w_stddev = tf.exp(log_w_stddev)
    self._w_prior_dist = tfp.distributions.MultivariateNormalDiag(
        loc=tf.zeros([self._memory_size]),
        scale_identity_multiplier=self._w_prior_stddev)

  def _build(self):
    raise ValueError('`_build()` should not be called for this module since '
                     'it takes no inputs and all of its variables are '
                     'constructed in `__init__`.')

  def _get_w_dist(self, mu_w):
    return tfp.distributions.MultivariateNormalDiag(
        loc=mu_w, scale_identity_multiplier=self._w_stddev)

  def sample_prior_w(self, seq_length, batch_size):
    """Sample w from its prior.

    Args:
      seq_length: Integer length of the sequence to sample.
      batch_size: Integer batch size of samples.
    Returns:
      w: A tensor with dimensions [seq_length, batch_size, memory_size].
    """
    return self._w_prior_dist.sample([seq_length, batch_size])

  def read_with_z(self, z, memory_state):
    """Query from memory (specified by memory_state) using embedding z.

    Args:
      z: Tensor with dimensions [episode_length, batch_size, code_size]
        containing an embedded input.
      memory_state: Instance of `MemoryState`.

    Returns:
      A tuple of tensors containing the mean of read embedding and the
        KL-divergence between the w used in reading and its prior.
    """
    M = self.sample_M(memory_state)
    w_mean = self._solve_w_mean(z, M)
    w_samples = self.sample_w(w_mean)
    dkl_w = self.get_dkl_w(w_mean)
    z_mean = self.get_w_to_z_mean(w_samples, M)
    return z_mean, dkl_w

  def wrap_z_dist(self, z_mean):
    """Wrap the mean of z as an observation (Gaussian) distribution."""
    return tfp.distributions.MultivariateNormalDiag(
        loc=z_mean, scale_identity_multiplier=self._obs_noise_stddev)

  def sample_w(self, w_mean):
    """Sample w from its posterior distribution."""
    if self._sample_w:
      return self._get_w_dist(w_mean).sample()
    else:
      return w_mean

  def sample_M(self, memory_state):
    """Sample the memory from its distribution specified by memory_state."""
    if self._sample_M:
      noise_dist = tfp.distributions.MultivariateNormalFullCovariance(
          covariance_matrix=memory_state.M_cov)
      # Draw one memory-sized noise vector per code dimension, giving
      # shape [code_size, batch_size, memory_size], then transpose to
      # [batch_size, memory_size, code_size] to match M_mean.
      noise = tf.transpose(noise_dist.sample(self._code_size),
                           [1, 2, 0])
      return memory_state.M_mean + noise
    else:
      return memory_state.M_mean

  def get_w_to_z_mean(self, w_p, R):
    """Return the mean of z by reading from memory using weights w_p."""
    return tf.einsum('sbm,bmc->sbc', w_p, R)  # Rw

  def _read_cov(self, w_samples, memory_state):
    episode_size, batch_size = w_samples.get_shape().as_list()[:2]
    _, U = memory_state  # cov: [B, M, M]
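    # Push the read weights through the memory covariance:
    #   wU  = w^T U   -> [S, B, M]
    #   wUw = w^T U w -> [S, B], the read-out variance contributed by M.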
    wU = tf.einsum('sbm,bmn->sbn', w_samples, U)
    wUw = tf.einsum('sbm,sbm->sb', wU, w_samples)
    wUw.get_shape().assert_is_compatible_with([episode_size, batch_size])
    return wU, wUw

  def get_dkl_total(self, memory_state):
    """Compute the KL-divergence between a memory distribution and its prior."""
    R, U = memory_state
    B, K, _ = R.get_shape().as_list()
    U.get_shape().assert_is_compatible_with([B, K, K])
    R_prior, U_prior = self.get_prior_state(B)
    p_diag = tf.matrix_diag_part(U_prior)
    q_diag = tf.matrix_diag_part(U)  # B, K
    t1 = self._code_size * tf.reduce_sum(q_diag / p_diag, -1)
    t2 = tf.reduce_sum((R - R_prior)**2 / tf.expand_dims(
        p_diag, -1), [-2, -1])
    t3 = -self._code_size * self._memory_size
    t4 = self._code_size * tf.reduce_sum(tf.log(p_diag) - tf.log(q_diag), -1)
    return 0.5 * (t1 + t2 + t3 + t4)

  def _get_dkl_update(self, memory_state, w_samples, new_z_mean, new_z_var):
    """Compute memory_kl after updating prior_state."""
    B, K, C = memory_state.M_mean.get_shape().as_list()
    S = w_samples.get_shape().as_list()[0]

    # check shapes
    w_samples.get_shape().assert_is_compatible_with([S, B, K])
    new_z_mean.get_shape().assert_is_compatible_with([S, B, C])

    delta = new_z_mean - self.get_w_to_z_mean(w_samples, memory_state.M_mean)
    _, wUw = self._read_cov(w_samples, memory_state)
    var_z = wUw + new_z_var + self._obs_noise_stddev**2
    beta = wUw / var_z

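    # beta in [0, 1) is the fraction of predictive variance explained by
    # memory uncertainty; the KL of the resulting rank-1 posterior update
    # has the closed form below.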
    dkl_M = -0.5 * (self._code_size * beta
                    - tf.reduce_sum(tf.expand_dims(beta / var_z, -1)
                                    * delta**2, -1)
                    + self._code_size * tf.log(1 - beta))
    dkl_M.get_shape().assert_is_compatible_with([S, B])
    return dkl_M

  @snt.reuse_variables
  def _get_prior_params(self):
    log_var = snt.TrainableVariable(
        [], name='prior_var_scale',
        initializers={'w': tf.constant_initializer(
            np.log(1.0))})()
    self._prior_var = tf.ones([self._memory_size]) * tf.exp(log_var) + EPSILON
    prior_cov = tf.matrix_diag(self._prior_var)
    prior_mean = snt.TrainableVariable(
        [self._memory_size, self._code_size],
        name='prior_mean',
        initializers={'w': tf.truncated_normal_initializer(
            mean=0.0, stddev=1.0)})()
    return prior_mean, prior_cov

  @property
  def prior_avg_var(self):
    """return the average of prior memory variance."""
    return tf.reduce_mean(self._prior_var)

  def _solve_w_mean(self, new_z_mean, M):
    """Minimise the conditional KL-divergence between z wrt w."""
    w_matrix = tf.matmul(M, M, transpose_b=True)
    w_rhs = tf.einsum('bmc,sbc->bms', M, new_z_mean)
    w_mean = tf.matrix_solve_ls(
        matrix=w_matrix, rhs=w_rhs,
        l2_regularizer=self._obs_noise_stddev**2 / self._w_prior_stddev**2)
    w_mean = tf.einsum('bms->sbm', w_mean)
    return w_mean

  def get_prior_state(self, batch_size):
    """Return the prior distribution of memory as a MemoryState."""
    prior_mean, prior_cov = self._get_prior_params()
    batch_prior_mean = tf.stack([prior_mean] * batch_size)
    batch_prior_cov = tf.stack([prior_cov] * batch_size)
    return MemoryState(M_mean=batch_prior_mean,
                       M_cov=batch_prior_cov)

  def update_state(self, z, memory_state):
    """Update the memory state using Bayes' rule.

    Args:
      z: A tensor with dimensions [episode_length, batch_size, code_size]
        containing a sequence of embeddings to write into memory.
      memory_state: A `MemoryState` namedtuple containing the memory state to
        be written to.

    Returns:
      A tuple containing the following elements:
      final_memory: A `MemoryState` namedtuple containing the new memory state
        after the update.
      w_mean_episode: The w used for each write in the episode (its
        posterior mean, or a sample of it when `sample_w=True`).
      dkl_w_episode: The KL-divergence of w for the written episode.
      dkl_M_episode: The KL-divergence between the memory states before and
        after the update.
    """

    episode_size, batch_size = z.get_shape().as_list()[:2]
    w_array = tf.TensorArray(dtype=tf.float32, size=episode_size,
                             element_shape=[1, batch_size, self._memory_size])
    dkl_w_array = tf.TensorArray(dtype=tf.float32, size=episode_size,
                                 element_shape=[1, batch_size])
    dkl_M_array = tf.TensorArray(dtype=tf.float32, size=episode_size,
                                 element_shape=[1, batch_size])
    init_var = (0, memory_state, w_array, dkl_w_array, dkl_M_array)
    cond = lambda i, m, d_2, d_3, d_4: i < episode_size
    def loop_body(i, old_memory, w_array, dkl_w_array, dkl_M_array):
      """Update memory step-by-step."""
      z_step = tf.expand_dims(z[i], 0)
      new_memory = old_memory
      for _ in range(self._num_opt_iters):
        w_step_mean = self._solve_w_mean(z_step, self.sample_M(new_memory))
        w_step_sample = self.sample_w(w_step_mean)
        new_memory = self._update_memory(old_memory,
                                         w_step_mean,
                                         z_step, 0)
      dkl_w_step = self.get_dkl_w(w_step_mean)
      dkl_M_step = self._get_dkl_update(old_memory,
                                        w_step_sample,
                                        z_step, 0)
      return (i+1,
              new_memory,
              w_array.write(i, w_step_sample),
              dkl_w_array.write(i, dkl_w_step),
              dkl_M_array.write(i, dkl_M_step))

    _, final_memory, w_mean, dkl_w, dkl_M = tf.while_loop(
        cond, loop_body, init_var)
    w_mean_episode = w_mean.concat()
    dkl_w_episode = dkl_w.concat()
    dkl_M_episode = dkl_M.concat()
    dkl_M_episode.get_shape().assert_is_compatible_with(
        [episode_size, batch_size])

    return final_memory, w_mean_episode, dkl_w_episode, dkl_M_episode

  def _update_memory(self, old_memory, w_samples, new_z_mean, new_z_var):
    """Setting new_z_var=0 for sample based update."""
    old_mean, old_cov = old_memory
    wR = self.get_w_to_z_mean(w_samples, old_memory.M_mean)
    wU, wUw = self._read_cov(w_samples, old_memory)
    sigma_z = wUw + new_z_var + self._obs_noise_stddev**2  # [S, B]
    delta = new_z_mean - wR  # [S, B, C]
    c_z = wU / tf.expand_dims(sigma_z, -1)  # [S, B, M]
    posterior_mean = old_mean + tf.einsum('sbm,sbc->bmc', c_z, delta)
    posterior_cov = old_cov - tf.einsum('sbm,sbn->bmn', c_z, wU)
    # Clip diagonal elements for numerical stability
    posterior_cov = tf.matrix_set_diag(
        posterior_cov,
        tf.clip_by_value(tf.matrix_diag_part(posterior_cov), EPSILON, 1e10))
    new_memory = MemoryState(M_mean=posterior_mean, M_cov=posterior_cov)
    return new_memory

  def get_dkl_w(self, w_mean):
    """Return the KL-divergence between posterior and prior weights w."""
    posterior_dist = self._get_w_dist(w_mean)
    dkl_w = posterior_dist.kl_divergence(self._w_prior_dist)
    dkl_w.get_shape().assert_is_compatible_with(
        w_mean.get_shape().as_list()[:-1])
    return dkl_w
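

if __name__ == '__main__':
  # Minimal smoke-test sketch, not part of the library API: write a random
  # episode into a small memory and read it back. The shapes and the use
  # of `tf.train.MonitoredSession` are illustrative assumptions for
  # TF1-style graph mode.
  episode_length, batch_size = 4, 2
  memory = KanervaMemory(code_size=16, memory_size=8)
  z_episode = tf.random_normal([episode_length, batch_size, 16])
  prior_memory = memory.get_prior_state(batch_size)
  posterior_memory, _, dkl_w, dkl_M = memory.update_state(
      z_episode, prior_memory)
  read_z, _ = memory.read_with_z(z_episode, posterior_memory)
  with tf.train.MonitoredSession() as session:
    read_out, total_dkl = session.run(
        [read_z, tf.reduce_mean(dkl_w + dkl_M)])
    print('read shape:', read_out.shape, 'mean KL:', total_dkl)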