# Copyright 2017 Nader Akoury. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

""" Module containing samplers of the latent space """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.distributions as distributions
import tensorflow.contrib.framework as framework
import tensorflow.contrib.layers as layers

import glas.model.attention as attentions
import glas.model.rnn as rnn
import glas.utils.graph as graph_utils
from glas.utils.ops import hellinger


SAMPLE_TYPES = ['basic', 'hellinger', 'chisq', 'estimated', 'uniform']


class BasicSampler(rnn.RNN):
    """ The basic latent sampler which uses the reparamaterization trick with a normal distribution
    and calculates latent loss as KL divergence from the standard normal distribution. """
    def __init__(self, config, attention, latent_space, scope='BasicSampler'):
        """ Initialize the sampler """
        super(BasicSampler, self).__init__(scope=scope)

        self.posteriors = []
        self.samples = config.samples
        self.sample_size = config.sample_size

        self.attention = attention
        self.latent_space = latent_space

        shape = (config.batch_size, config.sample_size)
        self.prior = distributions.Normal(tf.zeros(shape), tf.ones(shape), name='prior')

    def compute_moments(self, distribution_or_tensor):
        """ Update the moving averages of the moments based on the passed in tensor """
        if isinstance(distribution_or_tensor, tf.Tensor):
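            # Moments are taken over every axis except the last, yielding a per-feature
            # mean and variance across the batch (and any leading step axes).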
            axes = list(range(distribution_or_tensor.get_shape().ndims - 1))
            return tf.nn.moments(distribution_or_tensor, axes)
        elif isinstance(distribution_or_tensor, distributions.Distribution):
            return distribution_or_tensor.mean(), distribution_or_tensor.variance()
        else:
            raise ValueError(
                'Can only compute moments of a tf.Tensor or distributions.Distribution')

    def approximate_posterior(self, tensor, scope='posterior'):
        """ Calculate the approximate posterior given the tensor """
        # Generate mu and sigma of the Gaussian for the approximate posterior
        with tf.variable_scope(scope, 'posterior', [tensor]):
            mean = layers.linear(tensor, self.sample_size, scope='mean')

            # Use the log of sigma for numerical stability
            log_sigma = layers.linear(tensor, self.sample_size, scope='log_sigma')

            # Create the Gaussian distribution
            sigma = tf.exp(log_sigma)
            posterior = distributions.Normal(mean, sigma, name='posterior')

            self.collect_named_outputs(posterior.loc)
            self.collect_named_outputs(posterior.scale)
            self.posteriors.append(posterior)

            return posterior

    def calculate_latent_loss(self, latent_weights):
        """ Calculate the latent loss in the form of KL divergence """
        for posterior in self.posteriors:
            # For a diagonal Gaussian posterior and the standard normal prior, the KL divergence
            # is computed analytically per dimension: 0.5 * (mu^2 + sigma^2 - 1 - log(sigma^2))
            kl_divergence = distributions.kl(posterior, self.prior)
            kl_divergence = tf.reduce_sum(latent_weights * kl_divergence, 1, name='kl_divergence')
            tf.losses.add_loss(tf.reduce_mean(kl_divergence, 0, name='kl_divergence/avg'))

    @framework.add_arg_scope
    @rnn.RNN.step_fn
    def random_sample(self, outputs_collections=None):  # pylint: disable=unused-argument
        """ Sample the prior """
        return self.sample(self.prior), None

    def sample(self, distribution_or_tensor, reuse=None):
        """ Sample the passed in distribution or tensor """
        reuse = True if reuse or self.step > 0 else None
        with tf.variable_scope(self.variable_scope, reuse=reuse):
            if isinstance(distribution_or_tensor, tf.Tensor):
                return distribution_or_tensor
            elif isinstance(distribution_or_tensor, distributions.Distribution):
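                # The distributions used here are all fully reparameterized, so sampling draws
                # noise and transforms it deterministically and gradients flow back to the
                # posterior parameters; averaging `self.samples` draws reduces the variance of
                # the Monte Carlo estimate.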
                return tf.reduce_mean(distribution_or_tensor.sample(self.samples), 0)
            else:
                raise ValueError('Can only sample a tf.Tensor or distributions.Distribution')

    def attend(self, tensor):
        """ Use attention over the latent space """
        if self.attention is not None and not isinstance(self.attention, attentions.NoAttention):
            focus = self.attention.read(self.latent_space, tensor)
            tf.add_to_collection(graph_utils.GraphKeys.RNN_OUTPUTS, focus)

            return focus

        return tensor

    @framework.add_arg_scope
    @rnn.RNN.step_fn
    def __call__(self, tensor, outputs_collections=None):
        """ Execute the next time step of the cell """
        focus = self.attend(tensor)
        posterior = self.approximate_posterior(focus)

        sample = self.sample(posterior)
        return sample, None

    next = __call__


class HellingerSampler(BasicSampler):
    """ Latent sampler that uses the Hellinger distance squared rather than the KL divergence. """
    def __init__(self, config, attention, latent_space, scope='HellingerSampler'):
        """ Initialize the sampler """
        super(HellingerSampler, self).__init__(
            config, attention, latent_space, scope=scope)

    def calculate_latent_loss(self, latent_weights):
        """ Calculate the latent loss in the form of KL divergence """
        for posterior in self.posteriors:
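            # The squared Hellinger distance between two univariate Gaussians has a closed form:
            #   H^2 = 1 - sqrt(2*s1*s2 / (s1^2 + s2^2)) * exp(-(m1 - m2)^2 / (4 * (s1^2 + s2^2)))
            # The imported `hellinger` op (glas.utils.ops) is assumed to compute this
            # analytically per dimension.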
            hellinger_distance = latent_weights * hellinger(posterior, self.prior)
            hellinger_distance = tf.reduce_sum(hellinger_distance, 1, name='hellinger')
            tf.losses.add_loss(tf.reduce_mean(hellinger_distance, 0, name='hellinger/avg'))


class MinChiSquaredDistribution(distributions.Distribution):
    """ The minimum chi squared distribution given a mean.

    The prior is assumed to be a standard normal distribution. The probability density function 'f'
    of the minimum chi squared distribution from a prior 'g' given an arithmetic mean is:
        f(x) = g(x) * ((m2_g - m1_f * m1_g) + x * (m1_f - m1_g)) / sigma_g^2

    For more info see section 3.1 from http://web.unbc.ca/~kumarp/d4.pdf """
    def __init__(self,
                 mean,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='MinChiSquaredDistribution'):
        """Construct the mininimum chi squared distribution from the mean. """
        parameters = locals()
        parameters.pop('self')

        with tf.name_scope(name, values=[mean]) as name_scope:
            self._avg = tf.identity(mean, name='mean')
            super(MinChiSquaredDistribution, self).__init__(
                dtype=self._avg.dtype,
                reparameterization_type=distributions.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=[self._avg],
                name=name_scope)

    @staticmethod
    def _param_shapes(sample_shape):
        return {'mean': tf.convert_to_tensor(sample_shape, dtype=tf.int32)}

    def _batch_shape_tensor(self):
        return tf.shape(self._avg)

    def _batch_shape(self):
        return self._avg.get_shape()

    def _event_shape_tensor(self):
        return tf.constant([], dtype=tf.int32)

    def _event_shape(self):
        return tf.TensorShape([])

    def _sample_n(self, n, seed=None):
        """ Sample the minimum chi squared distribution using reparameterization """
        shape = tf.concat(([n], tf.shape(self.mean())), 0)
        sampled = tf.random_normal(shape=shape, mean=0, stddev=1, dtype=self._avg.dtype, seed=seed)
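        # With z ~ N(0, 1), the transformed sample z + mean * z^2 has expectation
        # E[z] + mean * E[z^2] = mean, which matches _mean() below.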
        return sampled * (1.0 + sampled * self._avg)

    def _log_prob(self, x):
        raise NotImplementedError('Not implemented yet.')

    def _log_cdf(self, x):
        raise NotImplementedError('Not implemented yet.')

    def _cdf(self, x):
        raise NotImplementedError('Not implemented yet.')

    def _log_survival_function(self, x):
        raise NotImplementedError('Not implemented yet.')

    def _survival_function(self, x):
        raise NotImplementedError('Not implemented yet.')

    def _entropy(self):
        raise NotImplementedError('Not implemented yet.')

    def _mean(self):
        return self._avg

    def _variance(self):
        # The variance of the minimum chi squared divergence probability distribution is given
        # by the following (NOTE: mt_f is the t-th moment of the distribution f):
        #   sigma_f^2 = ((m2_g-m1_f*m1_g)*m2_g+(m1_f-m1_g)*m3_g-mu_f^2*sigma_g^2)/sigma_g^2
        #
        # When using the standard normal distribution as the prior 'g' note that:
        #   m1_g = 0, m2_g = 1, m3_g = 0
        #
        # So this becomes:
        #    sigma_f^2 = 1 - mu_f^2
        return 1.0 - tf.square(self._avg)

    def _std(self):
        return tf.sqrt(self._variance())

    def _mode(self):
        return self._avg


class ChiSquaredSampler(BasicSampler):
    """ Latent sampler that uses attention.

    Minimize the chi squared divergence using the first moment of the probability distribution. """
    def __init__(self, config, attention, latent_space, scope='ChiSquaredSampler'):
        """ Initialize the sampler """
        super(ChiSquaredSampler, self).__init__(
            config, attention, latent_space, scope=scope)

        shape = (config.batch_size, self.sample_size)
        self.prior = distributions.Normal(tf.zeros(shape), tf.ones(shape), name='prior')

    def approximate_posterior(self, tensor, scope='posterior'):
        """ Calculate the approximate posterior given the tensor """
        # Generate the minimum chi squared divergence distribution 'f' from the prior 'g'
        with tf.variable_scope(scope, 'posterior', [tensor]):
            mean = layers.linear(tensor, self.sample_size, scope='mean')

            # Create the minimum chi squared distribution
            posterior = MinChiSquaredDistribution(mean, name='posterior')

            self.collect_named_outputs(posterior.mean())
            self.collect_named_outputs(posterior.variance())
            self.posteriors.append(posterior)

            return posterior

    def calculate_latent_loss(self, latent_weights):
        """ Calculate the latent loss in the form of KL divergence """
        for posterior in self.posteriors:
            # Minimize the chi squared divergence of the posterior 'f' from the prior 'g' (a
            # standard normal distribution). This amounts to minimizing the squared difference
            # between the first moment of f and the first moment of g, divided by the variance
            # of g (NOTE: mt_f is the t-th moment of the distribution f):
            #    min(chisq) = (m1_f - m1_g)^2 / sigma_g^2
            #
            # The idea behind using the chi squared divergence is that it is an upper bound for the
            # Kullback-Leibler divergence. The following inequality holds:
            #    KL(f||g) <= log(1 + Chi^2(f||g))
            #
            # So minimize this bound rather than the chi squared divergence directly
            mean, _ = self.compute_moments(posterior)

            axes = tf.range(1, tf.rank(mean))
            chisq = tf.log1p(tf.square(mean - self.prior.mean()) / self.prior.variance())
            chisq = tf.reduce_sum(latent_weights * chisq, axes)
            tf.losses.add_loss(tf.reduce_mean(chisq, name='chisq'))


class EstimatedSampler(ChiSquaredSampler):
    """ Latent sampler that uses attention.

    Estimates the first two moments of the input tensor and then applies the probability integral
    transform: assuming the incoming tensor is drawn from a normal distribution F(x), it is
    transformed into a sample from a uniform distribution G(x). """
    def __init__(self, config, attention, latent_space, scope='EstimatedSampler'):
        """ Initialize the sampler """
        super(EstimatedSampler, self).__init__(
            config, attention, latent_space, scope=scope)

        shape = (config.batch_size,) + attention.read_size(latent_space)
        self.prior = distributions.Uniform(tf.zeros(shape), tf.ones(shape), name='prior')

    def approximate_posterior(self, tensor, scope='posterior'):
        """ Calculate the approximate posterior given the tensor """
        # Assume the incoming random variable 'X' is drawn from a normal distribution and use the
        # probability integral transform to transform 'X' into 'Y' which is drawn from a standard
        # uniform distribution.
        mean, variance = self.compute_moments(tensor)
        normal = distributions.Normal(mean, tf.sqrt(variance))
        posterior = normal.cdf(tensor)
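        # Unlike the other samplers, the posterior here is a plain tensor (the transformed
        # sample), which the tensor branches of sample() and compute_moments() handle directly.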

        self.collect_named_outputs(posterior)
        self.posteriors.append(posterior)

        return posterior


class UniformSampler(EstimatedSampler):
    """ Latent sampler that uses attention.

    Reparameterize the incoming distribution as a uniform distribution specified by its mean and
    variance. """
    def __init__(self, config, attention, latent_space, scope='UniformSampler'):
        """ Initialize the sampler """
        super(UniformSampler, self).__init__(
            config, attention, latent_space, scope=scope)

        shape = (config.batch_size, self.sample_size)
        self.prior = distributions.Uniform(tf.zeros(shape), tf.ones(shape), name='prior')

    def approximate_posterior(self, tensor, scope='posterior'):
        """ Calculate the approximate posterior given the tensor """
        # Generate the mean and variance of the uniform approximate posterior
        sample_size = self.prior.batch_shape.as_list()[-1]
        with tf.variable_scope(scope, 'posterior', [tensor]):
            mean = layers.linear(tensor, sample_size, scope='mean')

            # Use the log of the variance for numerical stability
            log_variance = layers.linear(tensor, sample_size, scope='log_variance')

            # Create the Uniform distribution
            variance = tf.exp(log_variance)
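            # A Uniform(low, high) distribution has variance (high - low)^2 / 12, so a
            # half-width of sqrt(3 * variance) around the mean reproduces the predicted variance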
            delta = tf.sqrt(3.0 * variance)
            posterior = distributions.Uniform(mean - delta, mean + delta, name='posterior')

            self.collect_named_outputs(posterior.low)
            self.collect_named_outputs(posterior.high)
            self.posteriors.append(posterior)

            return posterior


def create_latent_space(batch_size, shape, steps=None):
    """ Create the latent space """
    # Setup the latent space. The latent space is a single trainable tensor with the given shape
    # that is tiled across the batch, giving dimensions [batch_size] + shape. If steps are
    # provided there is a latent space per step with dimensions [steps, batch_size] + shape.
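    # For example (hypothetical sizes), create_latent_space(32, (8, 8)) yields a [32, 8, 8]
    # tensor, while create_latent_space(32, (8, 8), steps=4) yields a [4, 32, 8, 8] tensor.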
    latent_shape = shape
    if steps is not None:
        latent_shape = (steps,) + latent_shape

    latent_space = framework.model_variable(
        'LatentSpace', shape=latent_shape, trainable=True,
        initializer=tf.random_uniform_initializer(0.0, 1e-3))
    latent_space = tf.tile(latent_space, (batch_size,) + (1,) * (len(latent_shape) - 1))
    latent_space = tf.reshape(latent_space, (batch_size,) + latent_shape)

    if steps is not None:
        permutation = (1, 0) + tuple(x + 2 for x in range(len(shape)))
        latent_space = tf.transpose(latent_space, permutation)

    return latent_space


def create_sampler(config):
    """ Create the appropriate sampler based on the passed in config """
    if config.sample_type == 'basic':
        sample_type = BasicSampler
    elif config.sample_type == 'hellinger':
        sample_type = HellingerSampler
    elif config.sample_type == 'chisq':
        sample_type = ChiSquaredSampler
    elif config.sample_type == 'estimated':
        sample_type = EstimatedSampler
    elif config.sample_type == 'uniform':
        sample_type = UniformSampler
    else:
        raise ValueError('Unknown sample_type: {0}'.format(config.sample_type))

    latent_size = (config.latent_size, config.latent_size)
    latent_space = create_latent_space(config.batch_size, latent_size)
    attention = attentions.create_attention(
        config.sample_attention_type, latent_size, read_size=config.latent_read_size)

    return sample_type(config, attention, latent_space)
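

# A rough usage sketch, assuming a config object exposing the fields referenced above
# (batch_size, samples, sample_size, latent_size, sample_type, sample_attention_type,
# latent_read_size) and hypothetical `inputs`, `num_steps` and `latent_weights` values:
#
#     sampler = create_sampler(config)
#     for _ in range(num_steps):
#         sample = sampler(inputs)  # one latent sample per time step
#     sampler.calculate_latent_loss(latent_weights)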