from __future__ import absolute_import

import keras
import keras.backend as K

from .. import similarities


class PermanentDropout(keras.layers.Dropout):
    """Dropout layer that stays active at test time as well as training time.

    Unlike the parent `Dropout.call`, this override applies `K.dropout`
    unconditionally instead of only during the training phase. It is one
    possible improvement for generator models.

    Args:
        p: float between 0 and 1. Fraction of the input units to drop.
    """

    def call(self, x, mask=None):
        # Rates outside the open interval (0, 1) leave the input untouched.
        if not (0. < self.p < 1.):
            return x
        # No train/test switch here: the same dropout op is used in both phases.
        return K.dropout(x, self.p, self._get_noise_shape(x))


class BatchSimilarity(keras.layers.Layer):
    """Calculates intra-batch similarity, for minibatch discrimination.

    The minibatch similarities can be added as features for the existing
    layer by using a Merge layer. The layer only works for inputs with shape
    (batch_size, num_features). Inputs with more dimensions can be flattened.

    In order to make this layer linear time with respect to the batch size,
    instead of doing a pairwise comparison between each pair of samples in
    the batch, for each sample a random sample is uniformly selected with
    which to do pairwise comparison.

    Args:
        similarity: str, the similarity type. See gandlf.similarities for the
            possible types. Alternatively, it can be a function which takes
            two tensors as inputs and returns their similarity. A list or
            tuple of similarities will apply all the similarities.
        n: int or list of ints (one for each similarity), number of times to
            repeat each similarity. Each repetition draws a fresh random
            pairing of samples.

    Output shape:
        (batch_size, sum(n)): one column per similarity repetition.

    Raises:
        ValueError: if `n` is a list/tuple whose length differs from the
            number of similarities, or if the input is not 2D.

    Reference: "Improved Techniques for Training GANs"
        https://arxiv.org/abs/1606.03498
    """

    def __init__(self, similarity='exp_l1', n=1, **kwargs):
        if not isinstance(similarity, (list, tuple)):
            similarity = [similarity]
        if isinstance(n, (list, tuple)):
            n = list(n)
            # call() pairs n with the similarities via zip, but the declared
            # output shape uses sum(n); a length mismatch would make the
            # two disagree, so reject it here.
            if len(n) != len(similarity):
                raise ValueError('BatchSimilarity got %d similarities but %d '
                                 'values for n.' % (len(similarity), len(n)))
        else:
            n = [n for _ in similarity]

        self.similarities = [similarities.get(s) for s in similarity]
        self.n = n

        super(BatchSimilarity, self).__init__(**kwargs)

    @staticmethod
    def _check_2d_input_shape(input_shape):
        """Raise ValueError unless `input_shape` has exactly two dimensions."""
        if len(input_shape) != 2:
            raise ValueError('The input to a BatchSimilarity layer must be '
                             '2D. Got %d dims.' % len(input_shape))

    def build(self, input_shape):
        self._check_2d_input_shape(input_shape)

    def call(self, x, mask=None):
        # The batch size does not change between repetitions; compute it once.
        batch_size = K.shape(x)[0]
        sims = []
        for repeats, sim in zip(self.n, self.similarities):
            for _rep in range(repeats):
                # Draw one random partner index per sample, so the cost is
                # linear in the batch size rather than all-pairs.
                idx = K.random_uniform((batch_size,), low=0, high=batch_size,
                                       dtype='int32')
                x_shuffled = K.gather(x, idx)
                pair_sim = sim(x, x_shuffled)
                # Assumes `sim` reduces over the feature axis; re-add a
                # trailing axis so results concatenate column-wise.
                for _axis in range(K.ndim(x) - 1):
                    pair_sim = K.expand_dims(pair_sim, dim=1)
                sims.append(pair_sim)

        return K.concatenate(sims, axis=-1)

    def get_output_shape_for(self, input_shape):
        self._check_2d_input_shape(input_shape)
        output_shape = list(input_shape)
        output_shape[-1] = sum(self.n)
        return tuple(output_shape)

    def get_config(self):
        # `n` must be serialized too, otherwise a layer rebuilt from this
        # config would default to n=1 and produce a different output width.
        # NOTE(review): custom similarity functions are stored by __name__
        # only; confirm similarities.get can resolve them when reloading.
        config = {'similarity': [s.__name__ for s in self.similarities],
                  'n': list(self.n)}
        base_config = super(BatchSimilarity, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))