#!/usr/bin/env python3
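"""Model definitions for a VAE/GAN on 64x64 images.

`create_models` builds the encoder, decoder (generator) and discriminator as
separate Keras models, and `build_graph` wires them into the sub-models used
for training, attaching the KL-divergence and learned-similarity
(discriminator-feature) losses.
"""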

import numpy as np

from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Input, Conv2D, BatchNormalization, Dense, Conv2DTranspose, Flatten, Reshape, \
    Lambda, LeakyReLU, Activation
from keras.regularizers import l2

from .losses import mean_gaussian_negative_log_likelihood


def create_models(n_channels=3, recon_depth=9, wdecay=1e-5, bn_mom=0.9, bn_eps=1e-6):
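    """Build the encoder, decoder and discriminator as separate Keras models.

    # Arguments:
        n_channels: number of image channels; the networks operate on 64x64 images.
        recon_depth: 1-based index of the discriminator layer whose activations are
            exposed as the feature output used by the learned similarity metric.
        wdecay: L2 kernel-regularization factor.
        bn_mom, bn_eps: momentum and epsilon of the conv-block BatchNormalization layers.
    # Returns:
        (encoder, decoder, discriminator) Keras models.
    """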

    image_shape = (64, 64, n_channels)
    n_encoder = 1024
    n_discriminator = 512
    latent_dim = 128
    decode_from_shape = (8, 8, 256)
    n_decoder = np.prod(decode_from_shape)
    leaky_relu_alpha = 0.2

    def conv_block(x, filters, leaky=True, transpose=False, name=''):
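        """Convolution (or transposed convolution) -> BatchNorm -> activation.

        If `x` is None, the list of layers is returned (for use in Sequential);
        otherwise the layers are applied to `x` and the output tensor is returned.
        """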
        conv = Conv2DTranspose if transpose else Conv2D
        activation = LeakyReLU(leaky_relu_alpha) if leaky else Activation('relu')
        layers = [
            conv(filters, 5, strides=2, padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name=name + 'conv'),
            BatchNormalization(momentum=bn_mom, epsilon=bn_eps, name=name + 'bn'),
            activation
        ]
        if x is None:
            return layers
        for layer in layers:
            x = layer(x)
        return x

    # Encoder
    def create_encoder():
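        """Map a 64x64 image to the mean and log-variance of q(z|x)."""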
        x = Input(shape=image_shape, name='enc_input')

        y = conv_block(x, 64, name='enc_blk_1_')
        y = conv_block(y, 128, name='enc_blk_2_')
        y = conv_block(y, 256, name='enc_blk_3_')
        y = Flatten()(y)
        y = Dense(n_encoder, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='enc_h_dense')(y)
        y = BatchNormalization(name='enc_h_bn')(y)
        y = LeakyReLU(leaky_relu_alpha)(y)

        z_mean = Dense(latent_dim, name='z_mean', kernel_initializer='he_uniform')(y)
        z_log_var = Dense(latent_dim, name='z_log_var', kernel_initializer='he_uniform')(y)

        return Model(x, [z_mean, z_log_var], name='encoder')

    # Decoder
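    # Maps a latent vector to a 64x64 image: dense projection to an 8x8x256
    # feature map, three stride-2 transposed-conv blocks, then a tanh output.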
    decoder = Sequential([
        Dense(n_decoder, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', input_shape=(latent_dim,), name='dec_h_dense'),
        BatchNormalization(name='dec_h_bn'),
        LeakyReLU(leaky_relu_alpha),
        Reshape(decode_from_shape),
        *conv_block(None, 256, transpose=True, name='dec_blk_1_'),
        *conv_block(None, 128, transpose=True, name='dec_blk_2_'),
        *conv_block(None, 32, transpose=True, name='dec_blk_3_'),
        Conv2D(n_channels, 5, activation='tanh', padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dec_output')
    ], name='decoder')

    # Discriminator
    def create_discriminator():
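        """Real/fake classifier that also exposes the features at depth `recon_depth`."""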
        x = Input(shape=image_shape, name='dis_input')

        layers = [
            Conv2D(32, 5, padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dis_blk_1_conv'),
            LeakyReLU(leaky_relu_alpha),
            *conv_block(None, 128, leaky=True, name='dis_blk_2_'),
            *conv_block(None, 256, leaky=True, name='dis_blk_3_'),
            *conv_block(None, 256, leaky=True, name='dis_blk_4_'),
            Flatten(),
            Dense(n_discriminator, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dis_dense'),
            BatchNormalization(name='dis_bn'),
            LeakyReLU(leaky_relu_alpha),
            Dense(1, activation='sigmoid', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dis_output')
        ]

        y = x
        y_feat = None
        for i, layer in enumerate(layers, 1):
            y = layer(y)
            # Capture the activations at depth `recon_depth` as the feature output
            if i == recon_depth:
                y_feat = y

        return Model(x, [y, y_feat], name='discriminator')

    encoder = create_encoder()
    discriminator = create_discriminator()

    return encoder, decoder, discriminator


def _sampling(args):
    """Reparameterization trick by sampling fr an isotropic unit Gaussian.
       Instead of sampling from Q(z|X), sample eps = N(0,I)

    # Arguments:
        args (tensor): mean and log of variance of Q(z|X)
    # Returns:
        z (tensor): sampled latent vector
    """
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


def build_graph(encoder, decoder, discriminator, recon_vs_gan_weight=1e-6):
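    """Connect encoder, decoder and discriminator into the models used for training.

    # Arguments:
        encoder, decoder, discriminator: models returned by `create_models`.
        recon_vs_gan_weight: weight of the feature-reconstruction (learned
            similarity) loss relative to the GAN loss in the decoder objective.
    # Returns:
        (encoder_train, decoder_train, discriminator_train, vae, vaegan) models.
    """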
    image_shape = K.int_shape(encoder.input)[1:]
    latent_shape = K.int_shape(decoder.input)[1:]

    sampler = Lambda(_sampling, output_shape=latent_shape, name='sampler')

    # Inputs
    x = Input(shape=image_shape, name='input_image')
    # z_p is sampled directly from an isotropic Gaussian prior
    z_p = Input(shape=latent_shape, name='z_p')

    # Build computational graph

    z_mean, z_log_var = encoder(x)
    z = sampler([z_mean, z_log_var])

    x_tilde = decoder(z)
    x_p = decoder(z_p)

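    # The discriminator returns a real/fake prediction and intermediate features;
    # only the prediction is needed for samples drawn from the prior (x_p).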
    dis_x, dis_feat = discriminator(x)
    dis_x_tilde, dis_feat_tilde = discriminator(x_tilde)
    dis_x_p = discriminator(x_p)[0]

    # Compute losses

    # Learned similarity metric
    dis_nll_loss = mean_gaussian_negative_log_likelihood(dis_feat, dis_feat_tilde)

    # KL divergence loss
    kl_loss = K.mean(-0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1))

    # Create models for training
    encoder_train = Model(x, dis_feat_tilde, name='e')
    encoder_train.add_loss(kl_loss)
    encoder_train.add_loss(dis_nll_loss)

    decoder_train = Model([x, z_p], [dis_x_tilde, dis_x_p], name='de')
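    # recon_vs_gan_weight is the fraction of the decoder objective assigned to the
    # reconstruction term; dividing by (1 - weight) expresses it relative to the
    # GAN terms, which are expected to be added with unit weight at compile time.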
    normalized_weight = recon_vs_gan_weight / (1. - recon_vs_gan_weight)
    decoder_train.add_loss(normalized_weight * dis_nll_loss)

    discriminator_train = Model([x, z_p], [dis_x, dis_x_tilde, dis_x_p], name='di')

    # Additional models for testing
    vae = Model(x, x_tilde, name='vae')
    vaegan = Model(x, dis_x_tilde, name='vaegan')

    return encoder_train, decoder_train, discriminator_train, vae, vaegan
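

# Illustrative usage (a minimal sketch; the optimizer, learning rate and the
# binary cross-entropy GAN losses below are assumptions, not settings taken
# from this repository's training script):
#
#     from keras.optimizers import RMSprop
#
#     encoder, decoder, discriminator = create_models()
#     encoder_train, decoder_train, discriminator_train, vae, vaegan = \
#         build_graph(encoder, decoder, discriminator)
#
#     # encoder_train carries its KL and feature-NLL losses via add_loss(),
#     # so compile() needs no explicit loss.
#     encoder_train.compile(RMSprop(lr=3e-4))
#     # decoder_train and discriminator_train additionally get a GAN loss on
#     # each real/fake output.
#     decoder_train.compile(RMSprop(lr=3e-4), loss='binary_crossentropy')
#     discriminator_train.compile(RMSprop(lr=3e-4), loss='binary_crossentropy')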