import math
import numpy as np

from keras.layers import Input, LSTM, Dense
from keras.models import Model
from keras import backend as K
from keras.callbacks import EarlyStopping

Z_DIM = 32          # size of the latent vector z from the upstream VAE
ACTION_DIM = 3      # size of the environment's action vector

HIDDEN_UNITS = 256
GAUSSIAN_MIXTURES = 5

BATCH_SIZE = 32
EPOCHS = 20

def get_mixture_coef(y_pred):
    
    d = GAUSSIAN_MIXTURES * Z_DIM
    
    rollout_length = K.shape(y_pred)[1]
    
    # split the flat MDN output into mixture weights, means and log standard deviations
    pi = y_pred[:, :, :d]
    mu = y_pred[:, :, d:(2 * d)]
    log_sigma = y_pred[:, :, (2 * d):(3 * d)]
    # discrete = y_pred[:, :, (3 * d):]
    
    pi = K.reshape(pi, [-1, rollout_length, GAUSSIAN_MIXTURES, Z_DIM])
    mu = K.reshape(mu, [-1, rollout_length, GAUSSIAN_MIXTURES, Z_DIM])
    log_sigma = K.reshape(log_sigma, [-1, rollout_length, GAUSSIAN_MIXTURES, Z_DIM])

    pi = K.exp(pi) / K.sum(K.exp(pi), axis=2, keepdims=True)  # softmax over the mixtures
    sigma = K.exp(log_sigma)

    return pi, mu, sigma  # , discrete
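

# A minimal sketch (not part of the original training path) of how the mixture
# coefficients above could be sampled at inference time. The helper name
# `sample_from_mixture` is an assumption for illustration; it expects pi, mu and
# sigma as NumPy arrays of shape (GAUSSIAN_MIXTURES, Z_DIM) for a single timestep.
def sample_from_mixture(pi, mu, sigma):
    z = np.zeros(Z_DIM)
    for i in range(Z_DIM):
        # pick a mixture component for this latent dimension, then draw from it
        k = np.random.choice(GAUSSIAN_MIXTURES, p=pi[:, i])
        z[i] = mu[k, i] + sigma[k, i] * np.random.randn()
    return z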


def tf_normal(y_true, mu, sigma, pi):

    rollout_length = K.shape(y_true)[1]
    # replicate the target so it can be compared against every mixture component
    y_true = K.tile(y_true, (1, 1, GAUSSIAN_MIXTURES))
    y_true = K.reshape(y_true, [-1, rollout_length, GAUSSIAN_MIXTURES, Z_DIM])

    oneDivSqrtTwoPI = 1 / math.sqrt(2*math.pi)
    result = y_true - mu
    result = result * (1 / (sigma + 1e-8))
    result = -K.square(result) / 2
    result = K.exp(result) * (1 / (sigma + 1e-8)) * oneDivSqrtTwoPI  # gaussian density per component
    result = result * pi  # weight each component by its mixture probability
    result = K.sum(result, axis=2)  # sum over gaussians
    # result = K.prod(result, axis=2)  # multiply over latent dims
    return result
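
# For reference, the operations above evaluate, per timestep and per latent
# dimension, the mixture density
#     p(y) = sum_k pi_k * N(y | mu_k, sigma_k)
# where N(y | mu, sigma) = exp(-(y - mu)^2 / (2 * sigma^2)) / (sigma * sqrt(2 * pi)),
# with a small epsilon added to sigma for numerical stability.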


class RNN:
    def __init__(self):
        self.models = self._build()
        self.model = self.models[0]
        self.forward = self.models[1]
        self.z_dim = Z_DIM
        self.action_dim = ACTION_DIM
        self.hidden_units = HIDDEN_UNITS
        self.gaussian_mixtures = GAUSSIAN_MIXTURES

    def _build(self):

        #### THE MODEL THAT WILL BE TRAINED
        rnn_x = Input(shape=(None, Z_DIM + ACTION_DIM))
        lstm = LSTM(HIDDEN_UNITS, return_sequences=True, return_state=True)

        lstm_output, _, _ = lstm(rnn_x)
        # one pi, mu and log_sigma per mixture and per latent dimension
        mdn = Dense(GAUSSIAN_MIXTURES * (3 * Z_DIM))(lstm_output)  # + discrete_dim

        rnn = Model(rnn_x, mdn)

        #### THE MODEL USED DURING PREDICTION
        # shares the LSTM weights with the trained model, but exposes the
        # hidden and cell state so it can be stepped one timestep at a time
        state_input_h = Input(shape=(HIDDEN_UNITS,))
        state_input_c = Input(shape=(HIDDEN_UNITS,))
        state_inputs = [state_input_h, state_input_c]
        _, state_h, state_c = lstm(rnn_x, initial_state=[state_input_h, state_input_c])

        forward = Model([rnn_x] + state_inputs, [state_h, state_c])

        #### LOSS FUNCTION

        def rnn_r_loss(y_true, y_pred):

            pi, mu, sigma = get_mixture_coef(y_pred)

            result = tf_normal(y_true, mu, sigma, pi)

            # negative log likelihood of the true z under the predicted mixture
            result = -K.log(result + 1e-8)
            result = K.mean(result, axis=(1, 2))  # mean over rollout length and z dim

            return result
    
        def rnn_kl_loss(y_true, y_pred):
            # tracked as a metric only; it is commented out of the training loss below
            pi, mu, sigma = get_mixture_coef(y_pred)
            kl_loss = -0.5 * K.mean(1 + K.log(K.square(sigma)) - K.square(mu) - K.square(sigma), axis=[1, 2, 3])
            return kl_loss
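
        # the expression above is the closed-form KL divergence of each
        # predicted gaussian from a unit gaussian,
        #     KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2),
        # averaged over rollout length, mixtures and latent dimensions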

        def rnn_loss(y_true, y_pred):
            return rnn_r_loss(y_true, y_pred)  # + rnn_kl_loss(y_true, y_pred)


        rnn.compile(loss=rnn_loss, optimizer='rmsprop', metrics=[rnn_r_loss, rnn_kl_loss])

        return (rnn, forward)

    def set_weights(self, filepath):
        # despite the name, this loads previously saved weights from filepath
        self.model.load_weights(filepath)

    def train(self, rnn_input, rnn_output, validation_split=0.2):

        earlystop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, verbose=1, mode='auto')
        callbacks_list = [earlystop]

        self.model.fit(rnn_input, rnn_output,
            shuffle=True,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            validation_split=validation_split,
            callbacks=callbacks_list)

        self.model.save_weights('./rnn/weights.h5')  # assumes the ./rnn directory already exists

    def save_weights(self, filepath):
        self.model.save_weights(filepath)
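

# A minimal usage sketch, assuming latent rollouts z of shape
# (episodes, timesteps, Z_DIM) and actions a of shape
# (episodes, timesteps, ACTION_DIM). The random data and variable names are
# illustrative only, and train() expects the ./rnn directory to exist.
if __name__ == '__main__':
    rnn = RNN()

    z = np.random.randn(10, 100, Z_DIM).astype(np.float32)
    a = np.random.uniform(-1, 1, (10, 100, ACTION_DIM)).astype(np.float32)

    # the MDN-RNN learns to predict the next latent vector from the current
    # latent vector and action
    rnn_input = np.concatenate([z[:, :-1, :], a[:, :-1, :]], axis=-1)
    rnn_output = z[:, 1:, :]
    rnn.train(rnn_input, rnn_output)

    # stepping the prediction model one timestep at a time from a zero state,
    # as a controller would during a rollout
    h = np.zeros((1, HIDDEN_UNITS))
    c = np.zeros((1, HIDDEN_UNITS))
    step_input = np.concatenate([z[:1, :1, :], a[:1, :1, :]], axis=-1)
    h, c = rnn.forward.predict([step_input, h, c])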