# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom neural network layers and primitives.
"""
import tensorflow as tf

from libml.data import DataSet


def smart_shape(x):
    """Returns the shape of `x` as a list, taking statically unknown
    dimensions from tf.shape(x) and known ones from x.shape."""
    s, t = x.shape, tf.shape(x)
    return [t[i] if s[i].value is None else s[i] for i in range(len(s))]
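
# Example (illustrative): for x = tf.placeholder(tf.float32, [None, 32, 32, 3]),
# smart_shape(x) returns [tf.shape(x)[0], 32, 32, 3]; the batch dimension stays
# dynamic while the known dimensions are kept static.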


def entropy_from_logits(logits):
    """Computes entropy from classifier logits.

    Args:
        logits: a tensor of shape (batch_size, class_count) representing the
            logits of a classifier.

    Returns:
        A tensor of shape (batch_size,) of floats giving the per-example
        entropies.
    """
    distribution = tf.contrib.distributions.Categorical(logits=logits)
    return distribution.entropy()
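
# Example (illustrative): uniform logits attain the maximum entropy, e.g.
# entropy_from_logits(tf.zeros([4, 10])) evaluates to four entries of
# ln(10) ~= 2.303.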


def entropy_penalty(logits, entropy_penalty_multiplier, mask):
    """Computes an entropy penalty using the classifier logits.

    Args:
        logits: a tensor of shape (batch_size, class_count) representing the
            logits of a classifier.
        entropy_penalty_multiplier: A float by which the entropy is multiplied.
        mask: A (batch_size,) tensor of 0/1 values that zeroes out the penalty
            for masked-out examples.

    Returns:
        The scalar mean entropy penalty over the batch.
    """
    entropy = entropy_from_logits(logits)
    losses = entropy * entropy_penalty_multiplier
    losses *= tf.cast(mask, tf.float32)
    return tf.reduce_mean(losses)
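
# Note: with an all-ones mask this is the batch-mean entropy times the
# multiplier; zeros in the mask remove examples from the numerator only, since
# tf.reduce_mean still divides by the full batch size.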


def kl_divergence_from_logits(logits_a, logits_b):
    """Gets KL divergence from logits parameterizing categorical distributions.

    Args:
        logits_a: A tensor of logits parameterizing the first distribution.
        logits_b: A tensor of logits parameterizing the second distribution.

    Returns:
        The (batch_size,) shaped tensor of KL divergences.
    """
    distribution1 = tf.contrib.distributions.Categorical(logits=logits_a)
    distribution2 = tf.contrib.distributions.Categorical(logits=logits_b)
    return tf.contrib.distributions.kl_divergence(distribution1, distribution2)
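
# Example (illustrative): kl_divergence_from_logits(logits, logits) yields a
# (batch_size,) tensor of zeros; note the result is asymmetric in general,
# i.e. KL(a || b) != KL(b || a).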


def mse_from_logits(output_logits, target_logits):
    """Computes MSE between predictions associated with logits.

    Args:
        output_logits: A tensor of logits from the primary model.
        target_logits: A tensor of logits from the secondary model.

    Returns:
        A (batch_size,) tensor of per-example MSE between the softmax
        predictions.
    """
    diffs = tf.nn.softmax(output_logits) - tf.nn.softmax(target_logits)
    squared_diffs = tf.square(diffs)
    return tf.reduce_mean(squared_diffs, -1)
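
# Example (illustrative): identical logits give zero, and because softmax
# outputs are probabilities each per-example value is at most 2 / class_count,
# so this loss stays bounded where the KL divergence above does not.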


def interleave_offsets(batch, nu):
    """Splits `batch` into nu + 1 near-equal groups; returns cumulative offsets."""
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets
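
# Example: interleave_offsets(10, 2) forms nu + 1 = 3 groups of sizes
# [3, 3, 4] (the remainder goes to the trailing groups) and returns the
# cumulative boundaries [0, 3, 6, 10].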


def interleave(xy, batch):
    """Swaps slices between the batches in `xy` so each output batch holds a
    piece of every input batch; applying the same call again undoes it."""
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [tf.concat(v, axis=0) for v in xy]
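
# Note: in MixMatch-style training this ensures batch-norm statistics are
# computed on a mix of labeled and unlabeled examples; calling interleave a
# second time with the same `batch` swaps the same slices back, restoring the
# original order.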


def renorm(v):
    """Rescales `v` so it sums to 1 along its last axis."""
    return v / tf.reduce_sum(v, axis=-1, keepdims=True)


def shakeshake(a, b, training):
    """Shake-shake mix of `a` and `b`: a random per-example convex combination
    with independent coefficients on the forward and backward passes."""
    if not training:
        return 0.5 * (a + b)
    mu = tf.random_uniform([tf.shape(a)[0]] + [1] * (len(a.shape) - 1), 0, 1)
    mixf = a + mu * (b - a)
    # Reverse mu along the batch axis so the backward coefficients differ.
    mixb = a + mu[::-1] * (b - a)
    return tf.stop_gradient(mixf - mixb) + mixb
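
# Note: tf.stop_gradient(mixf - mixb) + mixb equals mixf in the forward pass
# but routes gradients through mixb, so the backward pass sees the reversed
# coefficients mu[::-1] rather than mu; this forward/backward decorrelation is
# the core of the shake-shake trick.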


class PMovingAverage:
    """Moving average of the model's class distribution over the last
    `buf_size` batches, stored as a ring buffer."""

    def __init__(self, name, nclass, buf_size):
        self.ma = tf.Variable(tf.ones([buf_size, nclass]) / nclass, trainable=False, name=name)

    def __call__(self):
        v = tf.reduce_mean(self.ma, axis=0)
        return v / tf.reduce_sum(v)

    def update(self, entry):
        entry = tf.reduce_mean(entry, axis=0)
        return tf.assign(self.ma, tf.concat([self.ma[1:], [entry]], axis=0))
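
# Usage sketch (hypothetical names, for illustration):
#   p_model = PMovingAverage('p_model', nclass=10, buf_size=128)
#   update_op = p_model.update(tf.nn.softmax(logits))  # push newest batch mean
#   p_avg = p_model()  # normalized mean over the buffered distributions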


class PData:
    """Class prior p(y): a dataset-provided constant when available, otherwise
    a uniform prior kept up to date as an exponential moving average."""

    def __init__(self, dataset: DataSet):
        self.has_update = False
        if dataset.p_unlabeled is not None:
            self.p_data = tf.constant(dataset.p_unlabeled, name='p_data')
        elif dataset.p_labeled is not None:
            self.p_data = tf.constant(dataset.p_labeled, name='p_data')
        else:
            self.p_data = tf.Variable(renorm(tf.ones([dataset.nclass])), trainable=False, name='p_data')
            self.has_update = True

    def __call__(self):
        return self.p_data / tf.reduce_sum(self.p_data)

    def update(self, entry, decay=0.999):
        entry = tf.reduce_mean(entry, axis=0)
        return tf.assign(self.p_data, self.p_data * decay + entry * (1 - decay))
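
# Usage sketch (illustrative): when the dataset provides no label marginals,
# the prior starts uniform and should be kept current by running
# p_data.update(probabilities) alongside the training step; the default decay
# of 0.999 makes it a slow exponential moving average.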


class MixMode:
    # A class for mixing data for various combinations of labeled and
    # unlabeled examples.
    # x = labeled example
    # y = unlabeled example
    # For example "xx.yxy" means: mix x with x, mix y with both x and y.
    MODES = 'xx.yy xxy.yxy xx.yxy xx.yx xx. .yy xxy. .yxy .'.split()

    def __init__(self, mode):
        assert mode in self.MODES
        self.mode = mode

    @staticmethod
    def augment_pair(x0, l0, x1, l1, beta, **kwargs):
        """Mixup of (x0, l0) with a shuffled (x1, l1); the Beta-drawn weight
        is clipped to >= 0.5 so the result stays closer to the first pair."""
        del kwargs
        mix = tf.distributions.Beta(beta, beta).sample([tf.shape(x0)[0], 1, 1, 1])
        mix = tf.maximum(mix, 1 - mix)
        index = tf.random_shuffle(tf.range(tf.shape(x0)[0]))
        xs = tf.gather(x1, index)
        ls = tf.gather(l1, index)
        xmix = x0 * mix + xs * (1 - mix)
        lmix = l0 * mix[:, :, 0, 0] + ls * (1 - mix[:, :, 0, 0])
        return xmix, lmix

    @staticmethod
    def augment(x, l, beta, **kwargs):
        return MixMode.augment_pair(x, l, x, l, beta, **kwargs)

    def __call__(self, xl: list, ll: list, betal: list):
        assert len(xl) == len(ll) >= 2
        assert len(betal) == 2
        if self.mode == '.':
            return xl, ll
        elif self.mode == 'xx.':
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            return [mx0] + xl[1:], [ml0] + ll[1:]
        elif self.mode == '.yy':
            mx1, ml1 = self.augment(
                tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1])
            return (xl[:1] + tf.split(mx1, len(xl) - 1),
                    ll[:1] + tf.split(ml1, len(ll) - 1))
        elif self.mode == 'xx.yy':
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            mx1, ml1 = self.augment(
                tf.concat(xl[1:], 0), tf.concat(ll[1:], 0), betal[1])
            return ([mx0] + tf.split(mx1, len(xl) - 1),
                    [ml0] + tf.split(ml1, len(ll) - 1))
        elif self.mode == 'xxy.':
            mx, ml = self.augment(
                tf.concat(xl, 0), tf.concat(ll, 0),
                sum(betal) / len(betal))
            return (tf.split(mx, len(xl))[:1] + xl[1:],
                    tf.split(ml, len(ll))[:1] + ll[1:])
        elif self.mode == '.yxy':
            mx, ml = self.augment(
                tf.concat(xl, 0), tf.concat(ll, 0),
                sum(betal) / len(betal))
            return (xl[:1] + tf.split(mx, len(xl))[1:],
                    ll[:1] + tf.split(ml, len(ll))[1:])
        elif self.mode == 'xxy.yxy':
            mx, ml = self.augment(
                tf.concat(xl, 0), tf.concat(ll, 0),
                sum(betal) / len(betal))
            return tf.split(mx, len(xl)), tf.split(ml, len(ll))
        elif self.mode == 'xx.yxy':
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            mx1, ml1 = self.augment(tf.concat(xl, 0), tf.concat(ll, 0), betal[1])
            mx1, ml1 = [tf.split(m, len(xl))[1:] for m in (mx1, ml1)]
            return [mx0] + mx1, [ml0] + ml1
        elif self.mode == 'xx.yx':
            mx0, ml0 = self.augment(xl[0], ll[0], betal[0])
            mx1, ml1 = zip(*[
                self.augment_pair(xl[i], ll[i], xl[0], ll[0], betal[1])
                for i in range(1, len(xl))
            ])
            return [mx0] + list(mx1), [ml0] + list(ml1)
        raise NotImplementedError(self.mode)
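
# Usage sketch (hypothetical shapes, for illustration):
#   mix = MixMode('xxy.yxy')
#   # xl: [labeled images, unlabeled batch 1, unlabeled batch 2], each of
#   # shape (batch, h, w, c); ll: matching (batch, nclass) label probabilities.
#   mixed_x, mixed_l = mix(xl, ll, [0.75, 0.75])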