#!/usr/bin/python
# -*- coding: utf-8 -*-

"""A Variational Autoencoders for MNIST.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Input, Dense, Lambda, Conv2D, Conv2DTranspose, \
    Flatten, Reshape
from keras.models import Model
from keras import backend as K
from keras.datasets import mnist
from keras import metrics
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Training hyper-parameters and network dimensions.
EPOCH = 5
INPUT_DIM = 784  # flattened 28x28 MNIST image
BATCH_SIZE = 64
HIDDEN_VAR_DIM = 7 * 7 * 32  # flattened size of the encoder's final conv feature map
LATENT_VAR_DIM = 2  # 2-D latent space so the digit manifold can be plotted

# input image dimensions

(img_rows, img_cols, img_chns) = (28, 28, 1)

# Keras backends differ in tensor layout: 'channels_first' puts the channel
# axis before the spatial axes, 'channels_last' puts it after.  Pick the
# matching input-image shape and decoder feature-map shape.  Note that
# output_shape[1:] must flatten to HIDDEN_VAR_DIM (7 * 7 * 32).
if K.image_data_format() == 'channels_first':
    original_img_size = (img_chns, img_rows, img_cols)
    output_shape = (BATCH_SIZE, 32, 7, 7)
else:
    original_img_size = (img_rows, img_cols, img_chns)
    output_shape = (BATCH_SIZE, 7, 7, 32)


def sampling(args):
    """Reparameterization trick: draw z = mean + std * eps, eps ~ N(0, 1).

    ``args`` is the pair of tensors produced by the encoder; the second
    tensor is used directly as a standard deviation here (not a variance
    or log-variance).
    """
    mean, std = args
    batch = K.shape(mean)[0]
    eps = K.random_normal(shape=(batch, LATENT_VAR_DIM), mean=0.,
                          stddev=1.)
    return mean + std * eps


def encode(x):
    """Encoder: map a flattened image to the latent Gaussian parameters.

    Args:
        x: tensor of shape (batch, INPUT_DIM), pixel values in [0, 1].

    Returns:
        (z_mean, z_var): two (batch, LATENT_VAR_DIM) tensors — the mean and
        the non-negative standard deviation of the approximate posterior.
    """
    input_reshape = Reshape(original_img_size)(x)
    # Two strided convolutions downsample 28x28 -> 14x14 -> 7x7.
    conv1 = Conv2D(16, 5, strides=(2, 2), padding='same',
                   activation='relu')(input_reshape)
    conv2 = Conv2D(32, 5, strides=(2, 2), padding='same',
                   activation='relu')(conv1)
    hidden = Flatten()(conv2)
    # The posterior mean must be unconstrained: the prior is N(0, I) and the
    # manifold plot samples negative latent coordinates (norm.ppf grid).
    # A relu here would clamp z_mean to the non-negative orthant, so use a
    # linear activation instead.
    z_mean = Dense(LATENT_VAR_DIM)(hidden)
    # relu keeps the standard deviation non-negative, as the sampling and
    # KL terms assume.  NOTE(review): relu units can die at exactly 0;
    # 'softplus' would be a strictly-positive alternative.
    z_var = Dense(LATENT_VAR_DIM, activation='relu')(hidden)
    return (z_mean, z_var)


def decode(z):
    """Decoder: map a latent vector back to a flattened image in [0, 1]."""
    # Project the latent code into the 7x7x32 feature volume the encoder
    # ended with, then mirror the encoder's downsampling path.
    projected = Dense(HIDDEN_VAR_DIM, activation='relu')(z)
    feature_map = Reshape(output_shape[1:])(projected)
    # Two transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
    up1 = Conv2DTranspose(16, 5, strides=(2, 2), padding='same',
                          activation='relu')(feature_map)
    # Sigmoid keeps the reconstruction in [0, 1] for binary cross-entropy.
    up2 = Conv2DTranspose(1, 5, strides=(2, 2), padding='same',
                          activation='sigmoid')(up1)
    return Flatten()(up2)


def main(_):
    """Build a convolutional VAE, train it on MNIST, and plot the manifold
    of digits spanned by its 2-D latent space."""

    # ---- build the model ------------------------------------------------
    x = Input(shape=(INPUT_DIM, ))
    (z_mean, z_var) = encode(x)
    z = Lambda(sampling)([z_mean, z_var])
    x_decoded = decode(z)
    model = Model(inputs=x, outputs=x_decoded)

    def vae_loss(y_true, y_pred):
        """Reconstruction + KL divergence.

        Computed from the closed-over graph tensors (x, x_decoded, z_mean,
        z_var) — the (y_true, y_pred) arguments Keras passes in are ignored.
        """
        generation_loss = img_rows * img_cols \
            * metrics.binary_crossentropy(x, x_decoded)
        # KL(N(mu, sigma) || N(0, I)) with z_var acting as sigma; the 1e-8
        # keeps the log finite when sigma reaches 0.
        kl_loss = 0.5 * tf.reduce_sum(K.square(z_mean)
                + K.square(z_var) - K.log(K.square(z_var + 1e-8)) - 1,
                axis=1)
        return tf.reduce_mean(generation_loss + kl_loss)

    model.compile(optimizer='rmsprop', loss=vae_loss)

    # ---- load and normalize MNIST ---------------------------------------
    # Class labels are not needed: an autoencoder reconstructs its input.
    ((x_train, _), (x_test, _)) = mnist.load_data()

    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.
    x_train = x_train.reshape((len(x_train),
                              np.prod(x_train.shape[1:])))
    x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

    print(model.summary())

    # Train with the images themselves as the target.  The original code
    # passed the digit labels here; training still worked only because
    # vae_loss ignores its arguments, but the images are the semantically
    # correct (and shape-consistent) target for an autoencoder.
    model.fit(
        x_train,
        x_train,
        shuffle=True,
        epochs=EPOCH,
        batch_size=BATCH_SIZE,
        validation_data=(x_test, x_test),
        )

    # ---- visualize the latent manifold ----------------------------------
    # Decoder-only function: latent samples -> flattened images.
    # NOTE(review): the hard-coded layer indices (8 = first decoder layer,
    # 12 = final Flatten) are fragile and break if the architecture changes;
    # confirm against model.summary() before editing the layer stack.
    generator = K.function([model.layers[8].input],
                           [model.layers[12].output])

    n = 15  # figure with 15x15 digits
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))

    # Linearly spaced coordinates on the unit square are pushed through the
    # inverse CDF (ppf) of the Gaussian to produce latent values z, since
    # the prior over the latent space is a standard Gaussian.
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
    grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

    for (i, yi) in enumerate(grid_x):
        for (j, xi) in enumerate(grid_y):
            z_sample = np.array([[xi, yi]])
            # Repeat the sample across the batch; only row 0 is used below.
            z_sample = np.tile(z_sample,
                               BATCH_SIZE).reshape(BATCH_SIZE, 2)
            x_decoded = generator([z_sample])[0]
            digit = x_decoded[0].reshape(digit_size, digit_size)

            figure[i * digit_size:(i + 1) * digit_size, j * digit_size:
                   (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show()


if __name__ == '__main__':
    # TF1-style entry point: tf.app.run handles flag parsing, then invokes
    # main with the remaining argv.
    tf.app.run(main=main)