from __future__ import absolute_import, division, print_function, unicode_literals

import numpy
import keras
from keras import backend as K

from keras.layers import Concatenate, Lambda

def split(start, stop):
    '''Return a Lambda layer that slices columns [start, stop) from a 2D tensor.'''
    # output_shape excludes the batch dimension.
    return Lambda(lambda x: x[:, start:stop], output_shape=(stop - start,))

def split_mixture_of_gaussians(x, n_components):
    '''Split a concatenated parameter tensor [pi | mu | log_sig] into its three parts.'''
    pi = split(0, n_components)(x)
    mu = split(n_components, 2*n_components)(x)
    log_sig = split(2*n_components, 3*n_components)(x)
    return pi, mu, log_sig

def log_norm_pdf(x, mu, log_sig):
    '''Element-wise Gaussian log-density: -0.5*log(2*pi) - log_sig - 0.5*((x - mu)/exp(log_sig))**2.'''
    z = (x - mu) / K.exp(K.clip(log_sig, -40, 40))  # TODO: get rid of this clipping
    return -0.5 * numpy.log(2 * numpy.pi) - log_sig - 0.5 * K.square(z)

def mix_gaussian_loss(x, mu, log_sig, w):
    '''
    Combine the mixture-of-Gaussians density and the loss in a single function
    so that the log-sum-exp trick can be applied for numerical stability.
    '''
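    # For reference, the log-sum-exp rearrangement applied below: with g_k = log N(x; mu_k, sig_k)
    # and m = max(0, max_k g_k),
    #     log sum_k pi_k * exp(g_k) = m + log sum_k pi_k * exp(g_k - m),
    # so the negative log-likelihood is -m - log sum_k pi_k * exp(g_k - m), and the shifted
    # exponentials exp(g_k - m) <= 1 cannot overflow.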
    if K.backend() == "tensorflow":
        x.set_shape([None, 1])
    # Evaluate the component log-densities with the target broadcast across all components.
    gauss = log_norm_pdf(K.repeat_elements(x=x, rep=K.int_shape(mu)[1], axis=1), mu, log_sig)
    # TODO: get rid of clipping.
    gauss = K.clip(gauss, -40, 40)
    # log-sum-exp trick: shift by the (non-negative) maximum before exponentiating...
    max_gauss = K.maximum(0., K.max(gauss))
    gauss = gauss - max_gauss
    out = K.sum(w * K.exp(gauss), axis=1)
    # ...and add the shift back to the log-likelihood, i.e. subtract it from the
    # negative log-likelihood.
    loss = K.mean(-K.log(out) - max_gauss)
    return loss

def mixture_of_gaussian_output(x, n_components):
    '''Map features x to concatenated mixture parameters [pi | mu | log_sig].'''
    mu = keras.layers.Dense(n_components, activation='linear')(x)
    log_sig = keras.layers.Dense(n_components, activation='linear')(x)
    pi = keras.layers.Dense(n_components, activation='softmax')(x)
    return Concatenate(axis=1)([pi, mu, log_sig])

def mixture_of_gaussian_loss(y_true, y_pred, n_components):
    '''Negative log-likelihood of y_true under the mixture encoded in y_pred.'''
    pi, mu, log_sig = split_mixture_of_gaussians(y_pred, n_components)
    return mix_gaussian_loss(y_true, mu, log_sig, pi)

def mixture_gaussian(n_components):
    '''
    Build a mixture-of-Gaussians output layer and a matching loss function
    that can be used in a Keras graph.
    '''

    def output(x):
        return mixture_of_gaussian_output(x, n_components)

    def keras_loss(y_true, y_pred):
        return mixture_of_gaussian_loss(y_true, y_pred, n_components)

    return output, keras_loss
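

# A minimal usage sketch: it assumes Keras 2 with the TensorFlow backend, and the layer
# sizes, component count, and synthetic data below are purely illustrative.
if __name__ == "__main__":
    from keras.layers import Dense, Input
    from keras.models import Model

    n_components = 3

    # Build the mixture-of-Gaussians head and its matching loss.
    mog_output, mog_loss = mixture_gaussian(n_components)

    # A small regression network whose output parameterises the mixture.
    inputs = Input(shape=(10,))
    hidden = Dense(32, activation='relu')(inputs)
    params = mog_output(hidden)  # shape: (batch, 3 * n_components)
    model = Model(inputs=inputs, outputs=params)
    model.compile(optimizer='adam', loss=mog_loss)

    # Fit on random data just to show the plumbing; targets must be shaped (batch, 1).
    X = numpy.random.randn(256, 10)
    y = numpy.random.randn(256, 1)
    model.fit(X, y, epochs=1, batch_size=32, verbose=0)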