from __future__ import print_function

import argparse
import torch

import numpy as np

from torch.autograd import Variable
from torch.optim import Adam

from multi_categorical_gans.datasets.dataset import Dataset
from multi_categorical_gans.datasets.formats import data_formats, loaders

from multi_categorical_gans.methods.general.autoencoder import AutoEncoder
from multi_categorical_gans.methods.general.generator import Generator
from multi_categorical_gans.methods.general.discriminator import Discriminator
from multi_categorical_gans.methods.general.wgan_gp import calculate_gradient_penalty

from multi_categorical_gans.utils.categorical import load_variable_sizes_from_metadata, categorical_variable_loss
from multi_categorical_gans.utils.commandline import DelayedKeyboardInterrupt, parse_int_list
from multi_categorical_gans.utils.cuda import to_cuda_if_available, to_cpu_if_available
from multi_categorical_gans.utils.initialization import load_or_initialize
from multi_categorical_gans.utils.logger import Logger
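
# Example invocation (illustrative only; the file and directory names below are
# placeholders, and this script is assumed to be saved as `trainer.py`):
#
#   python trainer.py data/train.npz \
#       models/autoencoder.pt models/generator.pt models/discriminator.pt \
#       loss.csv --metadata data/metadata.json --temperature 0.666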


def add_noise_to_code(code, noise_radius):
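    """Return `code` plus zero-mean Gaussian noise with std `noise_radius`.

    A radius of zero disables the noise; the training loop anneals the
    radius after every epoch (see `ae_noise_anneal`).
    """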
    if noise_radius > 0:
        means = torch.zeros_like(code)
        gauss_noise = torch.normal(means, noise_radius)
        return code + to_cuda_if_available(Variable(gauss_noise))
    else:
        return code


def train(autoencoder,
          generator,
          discriminator,
          train_data,
          val_data,
          output_ae_path,
          output_gen_path,
          output_disc_path,
          output_loss_path,
          batch_size=1000,
          start_epoch=0,
          num_epochs=1000,
          num_ae_steps=1,
          num_disc_steps=2,
          num_gen_steps=1,
          noise_size=128,
          l2_regularization=0.001,
          learning_rate=0.001,
          ae_noise_radius=0.2,
          ae_noise_anneal=0.995,
          normalize_code=True,
          variable_sizes=None,
          temperature=None,
          penalty=0.1
          ):
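    """ARAE training loop.

    Each epoch interleaves `num_ae_steps` autoencoder reconstruction steps,
    `num_disc_steps` critic (WGAN-GP) steps on real versus generated codes,
    and `num_gen_steps` generator steps, then logs the mean losses and
    checkpoints all three models.
    """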
    autoencoder, generator, discriminator = to_cuda_if_available(autoencoder, generator, discriminator)

    optim_ae = Adam(autoencoder.parameters(), weight_decay=l2_regularization, lr=learning_rate)
    optim_gen = Adam(generator.parameters(), weight_decay=l2_regularization, lr=learning_rate)
    optim_disc = Adam(discriminator.parameters(), weight_decay=l2_regularization, lr=learning_rate)

    logger = Logger(output_loss_path, append=start_epoch > 0)

    for epoch_index in range(start_epoch, num_epochs):
        logger.start_timer()

        # train
        autoencoder.train(mode=True)
        generator.train(mode=True)
        discriminator.train(mode=True)

        ae_losses = []
        disc_losses = []
        gen_losses = []

        more_batches = True
        train_data_iterator = train_data.batch_iterator(batch_size)

        while more_batches:
            # train autoencoder
            for _ in range(num_ae_steps):
                try:
                    batch = next(train_data_iterator)
                except StopIteration:
                    more_batches = False
                    break

                autoencoder.zero_grad()

                batch_original = Variable(torch.from_numpy(batch))
                batch_original = to_cuda_if_available(batch_original)
                batch_code = autoencoder.encode(batch_original, normalize_code=normalize_code)
                batch_code = add_noise_to_code(batch_code, ae_noise_radius)

                batch_reconstructed = autoencoder.decode(batch_code, training=True, temperature=temperature)

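                # reconstruction loss, computed per categorical variable when
                # `variable_sizes` is given (MC-ARAE)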
                ae_loss = categorical_variable_loss(batch_reconstructed, batch_original, variable_sizes)
                ae_loss.backward()

                optim_ae.step()

                ae_loss = to_cpu_if_available(ae_loss)
                ae_losses.append(ae_loss.data.numpy())

            # train discriminator
            for _ in range(num_disc_steps):
                try:
                    batch = next(train_data_iterator)
                except StopIteration:
                    more_batches = False
                    break

                discriminator.zero_grad()
                autoencoder.zero_grad()

                # first train the discriminator only with real data
                real_features = Variable(torch.from_numpy(batch))
                real_features = to_cuda_if_available(real_features)
                real_code = autoencoder.encode(real_features, normalize_code=normalize_code)
                real_code = add_noise_to_code(real_code, ae_noise_radius)
                real_pred = discriminator(real_code)
                real_loss = - real_pred.mean(0).view(1)
                real_loss.backward()

                # then train the discriminator only with fake data
                noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_())
                noise = to_cuda_if_available(noise)
                fake_code = generator(noise)
                fake_code = fake_code.detach()  # do not propagate to the generator
                fake_pred = discriminator(fake_code)
                fake_loss = fake_pred.mean(0).view(1)
                fake_loss.backward()

                # WGAN-GP gradient penalty: push the norm of the critic's
                # gradient at interpolations between real and fake codes
                # towards 1, weighted by `penalty`
                gradient_penalty = calculate_gradient_penalty(discriminator, penalty, real_code, fake_code)
                gradient_penalty.backward()

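                # the real codes were not detached, so `real_loss.backward()`
                # also filled encoder gradients; stepping `optim_ae` here
                # updates the code space together with the critic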
                optim_ae.step()
                optim_disc.step()

                disc_loss = real_loss + fake_loss + gradient_penalty
                disc_loss = to_cpu_if_available(disc_loss)
                disc_losses.append(disc_loss.data.numpy())

                del disc_loss
                del gradient_penalty
                del fake_loss
                del real_loss

            # train generator
            for _ in range(num_gen_steps):
                generator.zero_grad()

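                # `batch` still refers to the last batch drawn above; only its
                # length is needed, to size the noise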
                noise = Variable(torch.FloatTensor(len(batch), noise_size).normal_())
                noise = to_cuda_if_available(noise)
                gen_code = generator(noise)
                fake_pred = discriminator(gen_code)
                fake_loss = - fake_pred.mean(0).view(1)
                fake_loss.backward()

                optim_gen.step()

                fake_loss = to_cpu_if_available(fake_loss)
                gen_losses.append(fake_loss.data.numpy())

                del fake_loss

        # log mean losses for the epoch
        logger.log(epoch_index, num_epochs, "autoencoder", "train_mean_loss", np.mean(ae_losses))
        logger.log(epoch_index, num_epochs, "discriminator", "train_mean_loss", np.mean(disc_losses))
        logger.log(epoch_index, num_epochs, "generator", "train_mean_loss", np.mean(gen_losses))

        # save models for the epoch
        with DelayedKeyboardInterrupt():
            torch.save(autoencoder.state_dict(), output_ae_path)
            torch.save(generator.state_dict(), output_gen_path)
            torch.save(discriminator.state_dict(), output_disc_path)
            logger.flush()

        ae_noise_radius *= ae_noise_anneal

    logger.close()


def main():
    options_parser = argparse.ArgumentParser(description="Train ARAE or MC-ARAE. "
                                                         + "Pass '--metadata' and '--temperature' to use MC-ARAE.")

    options_parser.add_argument("data", type=str, help="Training data. See 'data_format' parameter.")

    options_parser.add_argument("output_autoencoder", type=str, help="Autoencoder output file.")
    options_parser.add_argument("output_generator", type=str, help="Generator output file.")
    options_parser.add_argument("output_discriminator", type=str, help="Discriminator output file.")
    options_parser.add_argument("output_loss", type=str, help="Loss output file.")

    options_parser.add_argument("--input_autoencoder", type=str, help="Autoencoder input file.", default=None)
    options_parser.add_argument("--input_generator", type=str, help="Generator input file.", default=None)
    options_parser.add_argument("--input_discriminator", type=str, help="Discriminator input file.", default=None)

    options_parser.add_argument("--metadata", type=str,
                                help="Information about the categorical variables in json format.")

    options_parser.add_argument(
        "--validation_proportion", type=float,
        default=.1,
        help="Ratio of data for validation."
    )

    options_parser.add_argument(
        "--data_format",
        type=str,
        default="sparse",
        choices=data_formats,
        help="Either a dense numpy array or a sparse csr matrix."
    )

    options_parser.add_argument(
        "--code_size",
        type=int,
        default=128,
        help="Dimension of the autoencoder latent space."
    )

    options_parser.add_argument(
        "--noise_size",
        type=int,
        default=128,
        help="Dimension of the generator input noise."
    )

    options_parser.add_argument(
        "--encoder_hidden_sizes",
        type=str,
        default="",
        help="Size of each hidden layer in the encoder separated by commas (no spaces)."
    )

    options_parser.add_argument(
        "--decoder_hidden_sizes",
        type=str,
        default="",
        help="Size of each hidden layer in the decoder separated by commas (no spaces)."
    )

    options_parser.add_argument(
        "--batch_size",
        type=int,
        default=100,
        help="Amount of samples per batch."
    )

    options_parser.add_argument(
        "--start_epoch",
        type=int,
        default=0,
        help="Starting epoch."
    )

    options_parser.add_argument(
        "--num_epochs",
        type=int,
        default=5000,
        help="Number of epochs."
    )

    options_parser.add_argument(
        "--l2_regularization",
        type=float,
        default=0,
        help="L2 regularization weight for every parameter."
    )

    options_parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Adam learning rate."
    )

    options_parser.add_argument(
        "--generator_hidden_sizes",
        type=str,
        default="100,100,100",
        help="Size of each hidden layer in the generator separated by commas (no spaces)."
    )

    options_parser.add_argument(
        "--bn_decay",
        type=float,
        default=0.9,
        help="Batch normalization decay for the generator and discriminator."
    )

    options_parser.add_argument(
        "--discriminator_hidden_sizes",
        type=str,
        default="100",
        help="Size of each hidden layer in the discriminator separated by commas (no spaces)."
    )

    options_parser.add_argument(
        "--num_autoencoder_steps",
        type=int,
        default=1,
        help="Number of successive training steps for the autoencoder."
    )

    options_parser.add_argument(
        "--num_discriminator_steps",
        type=int,
        default=1,
        help="Number of successive training steps for the discriminator."
    )

    options_parser.add_argument(
        "--num_generator_steps",
        type=int,
        default=1,
        help="Number of successive training steps for the generator."
    )

    options_parser.add_argument(
        "--autoencoder_noise_radius",
        type=float,
        default=0,
        help="Gaussian noise standard deviation for the latent code (autoencoder regularization)."
    )

    options_parser.add_argument(
        "--autoencoder_noise_anneal",
        type=float,
        default=0.995,
        help="Anneal the noise radius by this value after every epoch."
    )

    options_parser.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Gumbel-Softmax temperature."
    )

    options_parser.add_argument(
        "--penalty",
        type=float,
        default=0.1,
        help="WGAN-GP gradient penalty lambda."
    )

    options_parser.add_argument("--seed", type=int, help="Random number generator seed.", default=42)

    options = options_parser.parse_args()

    if options.seed is not None:
        np.random.seed(options.seed)
        torch.manual_seed(options.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(options.seed)

    features = loaders[options.data_format](options.data)
    data = Dataset(features)
    train_data, val_data = data.split(1.0 - options.validation_proportion)

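    # MC-ARAE mode needs both the variable metadata and the Gumbel-Softmax temperature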
    if options.metadata is not None and options.temperature is not None:
        variable_sizes = load_variable_sizes_from_metadata(options.metadata)
        temperature = options.temperature
    else:
        variable_sizes = None
        temperature = None

    autoencoder = AutoEncoder(
        features.shape[1],
        code_size=options.code_size,
        encoder_hidden_sizes=parse_int_list(options.encoder_hidden_sizes),
        decoder_hidden_sizes=parse_int_list(options.decoder_hidden_sizes),
        variable_sizes=variable_sizes
    )

    load_or_initialize(autoencoder, options.input_autoencoder)

    generator = Generator(
        options.noise_size,
        options.code_size,
        hidden_sizes=parse_int_list(options.generator_hidden_sizes),
        bn_decay=options.bn_decay
    )

    load_or_initialize(generator, options.input_generator)

    discriminator = Discriminator(
        options.code_size,
        hidden_sizes=parse_int_list(options.discriminator_hidden_sizes),
        bn_decay=0,  # no batch normalization for the critic
        critic=True
    )

    load_or_initialize(discriminator, options.input_discriminator)

    train(
        autoencoder,
        generator,
        discriminator,
        train_data,
        val_data,
        options.output_autoencoder,
        options.output_generator,
        options.output_discriminator,
        options.output_loss,
        batch_size=options.batch_size,
        start_epoch=options.start_epoch,
        num_epochs=options.num_epochs,
        num_ae_steps=options.num_autoencoder_steps,
        num_disc_steps=options.num_discriminator_steps,
        num_gen_steps=options.num_generator_steps,
        noise_size=options.noise_size,
        l2_regularization=options.l2_regularization,
        learning_rate=options.learning_rate,
        ae_noise_radius=options.autoencoder_noise_radius,
        ae_noise_anneal=options.autoencoder_noise_anneal,
        variable_sizes=variable_sizes,
        temperature=temperature,
        penalty=options.penalty
    )


if __name__ == "__main__":
    main()