import keras.backend as K
from keras.engine.topology import Layer
from keras.layers.convolutional import Conv1D
import tensorflow as tf


################################################################################
# Quadratic-time MMD with Gaussian RBF

def _mix_rbf_kernel(X, Y, sigmas=(1.,), wts=None):
    if wts is None:
        # Float weights: tf.reduce_sum(wts) is returned as the kernel's constant
        # diagonal and must match the float32 kernel matrices downstream.
        wts = [1.0] * len(sigmas)

    XX = tf.matmul(X, X, transpose_b=True)
    XY = tf.matmul(X, Y, transpose_b=True)
    YY = tf.matmul(Y, Y, transpose_b=True)

    X_sqnorms = tf.diag_part(XX)
    Y_sqnorms = tf.diag_part(YY)

    r = lambda x: tf.expand_dims(x, 0)
    c = lambda x: tf.expand_dims(x, 1)

    K_XX, K_XY, K_YY = 0, 0, 0
    for sigma, wt in zip(sigmas, wts):
        gamma = 1 / (2 * sigma**2)
        # ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2, so each term below is
        # exp(-gamma * pairwise squared distances) for one bandwidth.
        K_XX += wt * tf.exp(-gamma * (-2 * XX + c(X_sqnorms) + r(X_sqnorms)))
        K_XY += wt * tf.exp(-gamma * (-2 * XY + c(X_sqnorms) + r(Y_sqnorms)))
        K_YY += wt * tf.exp(-gamma * (-2 * YY + c(Y_sqnorms) + r(Y_sqnorms)))

    return K_XX, K_XY, K_YY, tf.reduce_sum(wts)


def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = tf.cast(tf.shape(K_XX)[0], tf.float32)
    n = tf.cast(tf.shape(K_YY)[0], tf.float32)

    if biased:
        # keep_dims gives a (1, 1) result that downstream layers can tile along
        # the batch axis (note the unbiased branch returns a plain scalar).
        mmd2 = (tf.reduce_sum(K_XX, keep_dims=True) / (m * m)
              + tf.reduce_sum(K_YY, keep_dims=True) / (n * n)
              - 2 * tf.reduce_sum(K_XY, keep_dims=True) / (m * n))
    else:
        if const_diagonal is not False:
            # For the weighted RBF mixture every diagonal entry is k(x, x) =
            # sum of the kernel weights, so the traces are known in advance.
            trace_X = m * const_diagonal
            trace_Y = n * const_diagonal
        else:
            trace_X = tf.trace(K_XX)
            trace_Y = tf.trace(K_YY)

        # Unbiased estimator: drop the self-similarity diagonal and normalise
        # by m(m - 1) / n(n - 1).
        mmd2 = ((tf.reduce_sum(K_XX) - trace_X) / (m * (m - 1))
              + (tf.reduce_sum(K_YY) - trace_Y) / (n * (n - 1))
              - 2 * tf.reduce_sum(K_XY) / (m * n))

    return mmd2


def mix_rbf_mmd2(X, Y, sigmas=(1.,), wts=None, biased=True):
    # d = sum of kernel weights = k(x, x), the constant diagonal of K_XX / K_YY.
    K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigmas, wts)
    return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)


def rbf_mmd2(X, Y, sigma=1., biased=True):
    return mix_rbf_mmd2(X, Y, sigmas=[sigma], biased=biased)
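
# Usage sketch (hedged): compares two (batch, dim) feature matrices in
# TF1-style graph mode. The placeholder shapes and the multi-bandwidth choice
# below are illustrative assumptions, not part of the API above.
#
#   X = tf.placeholder(tf.float32, [None, 64])
#   Y = tf.placeholder(tf.float32, [None, 64])
#   mmd = mix_rbf_mmd2(X, Y, sigmas=(1., 5., 10.))  # biased estimate, shape (1, 1)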


################################################################################


################################################################################
# Customized layers

class Max_over_time(Layer):
    def __init__(self, **kwargs):
        self.supports_masking = True
        super(Max_over_time, self).__init__(**kwargs)

    def call(self, x, mask=None):
        if mask is not None:
            # Zero out masked timesteps before pooling. This assumes the
            # activations are non-negative (e.g. post-ReLU); otherwise a masked
            # zero could beat a genuinely negative maximum.
            mask = K.cast(mask, K.floatx())
            mask = K.expand_dims(mask)
            x = x * mask
        return K.max(x, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[2])

    def compute_mask(self, x, mask):
        # The time axis is reduced away, so no mask is propagated.
        return None
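
# Usage sketch (hedged; layer names and sizes are illustrative):
#
#   emb = Embedding(vocab_size, 100, mask_zero=True)(tokens)
#   pooled = Max_over_time()(emb)   # (batch, steps, 100) -> (batch, 100)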

class KL_loss(Layer):
    def __init__(self, batch_size, **kwargs):
        super(KL_loss, self).__init__(**kwargs)
        self.batch_size = batch_size


    def call(self, x, mask=None):
        a = x[0]
        b = x[1]

        a = K.mean(a, axis=0, keepdims=True)
        b = K.mean(b, axis=0, keepdims=True)
        a /= K.sum(a, keepdims=True)
        b /= K.sum(b, keepdims=True)

        a = K.clip(a, K.epsilon(), 1)
        b = K.clip(b, K.epsilon(), 1)

        loss = K.sum(a*K.log(a/b), axis=-1, keepdims=True) \
            + K.sum(b*K.log(b/a), axis=-1, keepdims=True)

        loss = K.repeat_elements(loss, self.batch_size, axis=0)
        
        return loss

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], 1)

    def compute_mask(self, x, mask):
        return None
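
# Usage sketch (hedged; `src_feat` / `tgt_feat` are hypothetical (batch, dim)
# tensors, e.g. pooled encoder features from two domains):
#
#   kl = KL_loss(batch_size=32)([src_feat, tgt_feat])   # shape (32, 1)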


class mmd_loss(Layer):
    def __init__(self, batch_size, **kwargs):
        super(mmd_loss, self).__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, x, mask=None):
        a = x[0]
        b = x[1]

        # Biased MMD^2 estimate between the two batches; rbf_mmd2 returns a
        # (1, 1) tensor, tiled here into a per-sample (batch_size, 1) loss.
        mmd = rbf_mmd2(a, b)
        mmd = K.repeat_elements(mmd, self.batch_size, axis=0)

        return mmd

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], 1)

    def compute_mask(self, x, mask):
        return None
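
# Usage sketch (hedged; same hypothetical feature tensors as above):
#
#   mmd = mmd_loss(batch_size=32)([src_feat, tgt_feat])  # shape (32, 1)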

class Ensemble_pred_loss(Layer):
    def __init__(self, **kwargs):
        super(Ensemble_pred_loss, self).__init__(**kwargs)

    def call(self, x, mask=None):
        pred = x[0]
        target = x[1]
        weight = x[2]

        # categorical_crossentropy returns shape (batch,); expand to (batch, 1)
        # so multiplying by a (batch, 1) weight input stays elementwise instead
        # of broadcasting to (batch, batch).
        error = K.expand_dims(K.categorical_crossentropy(target, pred), axis=-1)
        loss = error * weight

        return loss

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], 1)

    def compute_mask(self, x, mask):
        return None
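
# Usage sketch (hedged; `pred` are softmax outputs, `target` one-hot labels,
# and `sample_w` a hypothetical (batch, 1) per-sample weight input):
#
#   loss = Ensemble_pred_loss()([pred, target, sample_w])  # shape (batch, 1)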

class Conv1DWithMasking(Conv1D):
    def __init__(self, **kwargs):
        self.supports_masking = True
        super(Conv1DWithMasking, self).__init__(**kwargs)

    def compute_mask(self, x, mask):
        # Pass the incoming mask through unchanged; valid as long as the
        # convolution preserves the time dimension (e.g. padding='same').
        return mask
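

# Minimal smoke test -- a hedged sketch assuming a Keras version compatible
# with the keras.engine.topology import above (Keras 2.x on TF1). All layer
# sizes are arbitrary illustrative choices.
if __name__ == "__main__":
    from keras.layers import Input, Embedding
    from keras.models import Model

    tokens = Input(shape=(10,), dtype='int32')
    # mask_zero=True emits a mask that Conv1DWithMasking passes through and
    # Max_over_time consumes before pooling over time. ReLU keeps activations
    # non-negative, matching Max_over_time's masking assumption.
    emb = Embedding(input_dim=100, output_dim=8, mask_zero=True)(tokens)
    conv = Conv1DWithMasking(filters=16, kernel_size=3, padding='same',
                             activation='relu')(emb)
    pooled = Max_over_time()(conv)   # (batch, 10, 16) -> (batch, 16)
    Model(tokens, pooled).summary()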