#!/usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from estimators.utilities import Generator, Critic
from estimators.utilities import add_random_input, add_random_labels
from estimators.utilities import set_learning_rate, set_global_step, rescale
from estimators.utilities import gradient_step, sc_summary, save_generated_cells
from MulticoreTSNE import MulticoreTSNE as TSNE
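# Multicore t-SNE instance (20 worker threads) shared by all validation steps.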
tsne = TSNE(n_jobs=20)


class cscGAN:
    """
    Contains the class for the conditional scGAN (cscGAN).
    Methods include the creation of the graph, the training of the model,
     the validation and generation of the cells.
    """
    def __init__(self, train_files, valid_files, genes_no, clusters_no,
                 scaling, scale_value, max_steps, batch_size, latent_dim,
                 gen_layers, output_lsn, gene_cond_type, critic_layers,
                 optimizer, lambd, beta1, beta2, decay, alpha_0, alpha_final):
        """
        Constructor for the cscGAN.

        Parameters
        ----------
        train_files : list
            List of TFRecord files used for training.
        valid_files : list
            List of TFRecord files used for validation.
        genes_no : int
            Number of genes in the expression matrix.
        clusters_no : int
            Number of clusters.
        scaling : str
            Method used to scale the data, see the scaling method of the
            GeneMatrix class in preprocessing/process_raw.py for more details.
        scale_value : int, float
            Parameter of the scaling function.
        max_steps : int
            Number of steps in the (outer) training loop.
        batch_size : int
            Batch size used for the training.
        latent_dim : int
            Dimension of the latent space from which the input noise of the
            generator is sampled.
        gen_layers : list
            List of integers corresponding to the number of neurons of each
            layer of the generator.
        output_lsn : int, None
            Parameter of the LSN layer at the output of the generator
            (i.e. total number of counts per generated cell).
            If set to None, the layer won't be added to the generator.
        gene_cond_type : str
            Type of conditional normalization layer used in the generator.
            Can be either "batchnorm" or "layernorm".
            For any other value, no such layer is added to the model
            (there will be no conditioning in the generation).
        critic_layers : list
            List of integers corresponding to the number of neurons of each
            layer of the critic.
        optimizer : str
            Optimizer used for the training.
            Set to "AMSGrad" to use AMSGrad; for any other value, Adam
            will be used.
        lambd : float
            Regularization hyper-parameter to be used with the gradient
             penalty in the WGAN loss.
        beta1 : float
            Exponential decay for the first-moment estimates.
        beta2 : float
            Exponential decay for the second-moment estimates.
        decay : bool
            If True, uses an exponential decay of the learning rate.
        alpha_0 : float
            Initial learning rate value.
        alpha_final : float
            Final value of the learning rate if the decay is set to True.
        """

        # read the parameters
        self.clusters_no = clusters_no
        self.latent_dim = latent_dim
        self.lambd = lambd
        self.gen_cond_type = gene_cond_type
        self.critic_layers = critic_layers
        self.gen_layers = gen_layers
        self.output_lsn = output_lsn
        self.optimizer = optimizer
        self.beta1 = beta1
        self.beta2 = beta2
        self.lr_decay = decay
        self.alpha_0 = alpha_0
        self.alpha_final = alpha_final
        self.batch_size = batch_size
        self.max_steps = max_steps
        self.scaling = scaling
        self.scale_value = scale_value
        self.train_files = train_files
        self.valid_files = valid_files
        self.genes_no = genes_no

        # prepare input pipeline for training
        self.train_cells, self.train_cells_clusters = self.make_input_fn(
            self.train_files)

        # prepare input pipeline for validation
        self.test_cells, self.test_cells_clusters = self.make_input_fn(
            self.valid_files)

        # module parameters
        self.generator = None
        self.critic_real = None
        self.critic_fake = None
        self.gradient_penalty = None
        self.gen_loss = None
        self.critic_loss = None
        self.global_step = None
        self.incr_global_step = None
        self.learning_rate = None
        self.critic_train = None
        self.critic_grads_and_vars = None
        self.gen_train = None
        self.gen_grads_and_vars = None
        self.model_train = None
        self.output_tensor = None
        self.build_model()

        # add visualization
        self.visualization()

        # the total number of all trainable parameters
        self.parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    def make_input_fn(self, file_paths, epochs=None):
        """
        Function that loads the TFRecords files and creates the placeholders
        for the data inputs.

        Parameters
        ----------
        file_paths : list
            List of TFRecord files from which to read from.
        epochs : int
            Integer specifying the number of times to read through the dataset.
            If None, cycles through the dataset forever.
            NOTE - If specified, creates a variable that must be initialized,
            so call tf.local_variables_initializer() and run the op in a session.
            Default is None.

        Returns
        -------
        features : Tensor
            Tensor containing a batch of cells (vector of expression levels).
        cluster : Tensor
            Tensor containing (a batch of) the cluster indexes of the
            corresponding cells.
        """

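        # Each serialized example holds a sparse gene-expression vector
        # ('indices'/'values', densified to genes_no entries) and the cell's
        # integer cluster index ('cluster_int').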
        feature_map = {'scg': tf.SparseFeature(index_key='indices',
                                               value_key='values',
                                               dtype=tf.float32,
                                               size=self.genes_no),
                       'cluster_int': tf.FixedLenFeature(1, tf.int64)}

        options = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.GZIP)

        batched_features = tf.contrib.learn.read_batch_features(
            file_pattern=file_paths,
            batch_size=self.batch_size,
            features=feature_map,
            reader=lambda: tf.TFRecordReader(
                options=options),
            num_epochs=epochs)

        scg = batched_features['scg']

        sparse = tf.sparse_reshape(scg, (self.batch_size, self.genes_no))

        dense = tf.sparse_tensor_to_dense(sparse)

        cluster = tf.squeeze(tf.to_int32(batched_features['cluster_int']))

        features = tf.reshape(dense, (self.batch_size, self.genes_no))

        return features, cluster

    def build_model(self):
        """
        Method that initializes the cscGAN model, creates the graph and
        defines the loss and optimizer.

        Returns
        -------

        """
        # training or inference (used for the batch normalization)
        is_training = tf.placeholder(dtype=tf.bool, name='is_training')

        clusters_ratios = tf.placeholder(dtype=tf.float32,
                                         shape=(1, self.clusters_no),
                                         name='clusters_ratios')

        z_input = add_random_input(self.batch_size, self.latent_dim)
        input_clusters = add_random_labels(clusters_ratios, self.batch_size)
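        # z_input: latent noise; input_clusters: cluster labels sampled
        # according to clusters_ratios. Both feed the conditional generator.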

        # create generator
        self.generator = Generator.create_cond_generator(
            z_input=z_input,
            batch_size=self.batch_size,
            latent_dim=self.latent_dim,
            output_cells_dim=self.genes_no,
            var_scope='generator',
            gen_layers=self.gen_layers,
            output_lsn=self.output_lsn,
            gen_cond_type=self.gen_cond_type,
            clusters_no=self.clusters_no,
            input_clusters=input_clusters,
            is_training=is_training,
            clusters_ratios=clusters_ratios,
            reuse=None)

        # Critic with real cells as input
        with tf.name_scope('real_critic'):
            self.critic_real = Critic.create_cond_critic(
                xinput=self.train_cells,
                input_clusters=self.train_cells_clusters,
                var_scope="critic",
                critic_layers=self.critic_layers,
                clusters_no=self.clusters_no,
                reuse=None)

        # Critic with generated cells as input (shares weights with critic_real)
        with tf.name_scope('fake_critic'):
            self.critic_fake = \
                Critic.create_cond_critic(
                    xinput=self.generator.fake_outputs,
                    input_clusters=self.generator.input_clusters,
                    var_scope="critic",
                    critic_layers=self.critic_layers,
                    clusters_no=self.clusters_no, reuse=True)

        # Critic loss (Wasserstein estimate plus gradient penalty)
        with tf.name_scope('critic_loss'):
            critic_loss_wgan = tf.reduce_mean(self.critic_fake.dist) \
                               - tf.reduce_mean(self.critic_real.dist)
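            # critic_loss_wgan estimates the negative Wasserstein distance:
            # E[D(fake)] - E[D(real)], which the critic minimizes.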

            # The following lines implement the gradient penalty term
            alpha = tf.random_uniform(shape=[self.batch_size, 1],
                                      minval=0., maxval=1.)

            generator_interpolates = \
                Generator.create_cond_generator(
                    z_input=z_input,
                    batch_size=self.batch_size,
                    latent_dim=self.latent_dim,
                    output_cells_dim=self.genes_no,
                    var_scope='generator',
                    gen_layers=self.gen_layers,
                    output_lsn=self.output_lsn,
                    gen_cond_type=self.gen_cond_type,
                    clusters_no=self.clusters_no,
                    input_clusters=self.train_cells_clusters,
                    is_training=is_training,
                    clusters_ratios=clusters_ratios,
                    reuse=True)

            differences = generator_interpolates.fake_outputs - self.train_cells

            interpolates = self.train_cells + (alpha * differences)

            with tf.name_scope('help_critic'):
                critic_interpolates = \
                    Critic.create_cond_critic(
                        xinput=interpolates,
                        input_clusters=self.train_cells_clusters,
                        var_scope="critic",
                        critic_layers=self.critic_layers,
                        clusters_no=self.clusters_no,
                        reuse=True)

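            # Two-sided penalty (WGAN-GP, Gulrajani et al., 2017): push the
            # L2 norm of the critic's gradient at the interpolated points
            # towards 1 to enforce the 1-Lipschitz constraint.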
            gradients = tf.gradients(critic_interpolates.dist, [interpolates])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
            self.gradient_penalty = tf.reduce_mean((slopes - 1) ** 2)
            critic_loss_wgan += self.lambd * self.gradient_penalty
            self.critic_loss = critic_loss_wgan

        # gen loss
        with tf.name_scope('generator_loss'):
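            # The generator maximizes the critic score of its samples,
            # i.e. it minimizes -E[D(G(z))].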
            self.gen_loss = -tf.reduce_mean(self.critic_fake.dist)

        # add global step
        self.global_step, self.incr_global_step = set_global_step()

        # add decaying learning rate
        self.learning_rate = set_learning_rate(self.alpha_0,
                                               self.alpha_final,
                                               self.global_step,
                                               self.max_steps)

        # training the critic
        with tf.name_scope("critic_train"):
            critic_params = [var for var in tf.trainable_variables()
                             if var.name.startswith('critic')]
            self.critic_train, self.critic_grads_and_vars =\
                gradient_step(critic_params,
                              training_loss=self.critic_loss,
                              learning_rate=self.learning_rate,
                              beta1=self.beta1,
                              beta2=self.beta2,
                              optimizer=self.optimizer)

        # training the generator
        with tf.name_scope("generator_train"):
            with tf.control_dependencies(
                    [self.critic_train] +
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
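                # The generator update only runs once the critic update and
                # any UPDATE_OPS (e.g. batch-norm moving averages) are done.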
                gen_params = [var for var in tf.trainable_variables()
                              if var.name.startswith('gen')]
                self.gen_train, self.gen_grads_and_vars = gradient_step(
                    gen_params,
                    training_loss=self.gen_loss,
                    learning_rate=self.learning_rate,
                    beta1=self.beta1,
                    beta2=self.beta2,
                    optimizer=self.optimizer)

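        # model_train advances the global step and runs a generator update;
        # the critic op stays separate so it can run several times per
        # generator step (see the training loop).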
        self.model_train = tf.group(self.incr_global_step, self.gen_train)
        self.critic_train = tf.group(self.critic_train)

    def visualization(self):
        """
        Method creating the placeholders to log the values for monitoring
        with the Tensorboard.

        Returns
        -------

        """

        # histograms and mean of real and generated cells
        sc_summary("single_cell_fake", self.generator.fake_outputs)
        sc_summary("single_cell_real", self.train_cells)

        # loss functions visualization
        tf.summary.scalar("gen_loss", self.gen_loss)
        tf.summary.scalar("critic_loss", self.critic_loss)

        tf.summary.scalar("Penalty", self.gradient_penalty)
        tf.summary.scalar("learning_rate", self.learning_rate)

        tf.summary.histogram("Distance_fake", self.critic_fake.dist)
        tf.summary.histogram("Distance_real", self.critic_real.dist)

        # visualize trainable variables
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name + "/values", var)

        # visualize gradients
        for grad, var in self.gen_grads_and_vars + self.critic_grads_and_vars:
            if grad is not None:
                tf.summary.histogram(var.op.name + "/gradients", grad)

    def training(self, exp_folder, clusters_ratios, checkpoint=None,
                 progress_freq=1000, summary_freq=200, save_freq=5000,
                 validation_freq=500, critic_iter=5, valid_cells_no=500):
        """
        Method that trains the cscGAN.

        Parameters
        ----------
        exp_folder : str
            Path where TF will write the logs, save the model, the t-SNE plots etc.
        clusters_ratios : dict
            Dictionary containing the different cluster names and their
            ratio in the data.
        checkpoint : str, None
            Path to the checkpoint to start from, or None to start training
             from scratch.
            Default is None.
        progress_freq : int
            Period (in steps) between displays of the loss values on the
            standard output.
            Default is 1000.
        summary_freq : int
            Period (in steps) between logs for the Tensorboard.
            Default is 200.
        save_freq : int
            Period (in steps) between saves of the model.
            Default is 5000.
        validation_freq : int
            Period (in steps) between computations of the validation
            measures (e.g. t-SNE plots).
            Default is 500.
        critic_iter : int
            Number of training iterations of the critic (inner loop) for each
            iteration on the generator (outer loop).
            Default is 5.
        valid_cells_no : int
            Number of cells in the validation set.
            Default is 500.

        Returns
        -------

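        Example
        -------
        A minimal sketch; the folder and the ratios are illustrative::

            model.training(exp_folder='jobs/my_experiment',
                           clusters_ratios={'0': 0.6, '1': 0.4})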
        """
        exp_name = exp_folder.split('/')[-1]

        # Transform the cluster ratios dictionary into a list ordered by
        # cluster name, so that each ratio stays aligned with its cluster index
        clusters_ratios = [value for (key, value) in sorted(clusters_ratios.items())]
        clusters_ratios = np.reshape(clusters_ratios, (1, len(clusters_ratios)))

        # Number of different models to keep (each time a model is saved,
        #  it will overwrite the oldest)
        saver = tf.train.Saver(max_to_keep=1)

        train_supervisor = tf.train.Supervisor(logdir=exp_folder,
                                               save_summaries_secs=0,
                                               saver=None)

        start = time.time()

        # Start the TF session
        with train_supervisor.managed_session() as sess:

            train_feed_dict = {self.generator.is_training: True,
                               self.generator.clusters_ratios: clusters_ratios}

            print("Parameter Count is [ %d ]." % (sess.run(self.parameter_count)))

            # load checkpoint and instantiate the start_step accordingly
            if checkpoint is not None:
                print("Loading model from checkpoint....")
                checkpoint = tf.train.latest_checkpoint(checkpoint)
                saver.restore(sess, checkpoint)
                start_step = sess.run(train_supervisor.global_step)
            else:
                start_step = 0

            critic_fetch = {"train": self.critic_train}

            # Outer loop, one step for the generator and several for the critic
            for step in range(start_step, self.max_steps):

                # small utility function to perform tasks at defined intervals
                def should(freq):
                    return freq > 0 and \
                           ((step + 1) % freq == 0 or
                            step == self.max_steps - 1)

                # Inner loop, for each generator step, several critic steps
                if step > 0:
                    for i_critic in range(critic_iter):
                        sess.run(fetches=critic_fetch, feed_dict=train_feed_dict)

                model_fetches = {"train": self.model_train}

                # Add the corresponding summary tensors to the fetches
                if should(summary_freq):
                    model_fetches["summary"] = train_supervisor.summary_op

                if should(progress_freq):
                    model_fetches["gen_loss"] = self.gen_loss
                    model_fetches["critic_loss"] = self.critic_loss

                results = sess.run(model_fetches, feed_dict=train_feed_dict)

                # Update the summaries for Tensorboard
                if should(summary_freq):
                    print("Recording summary ...")
                    train_supervisor.summary_writer.add_summary(
                        results["summary"], step)

                # Launch the validation steps
                if should(validation_freq):
                    self.validation(sess, valid_cells_no,
                                    exp_folder, step, clusters_ratios)

                # Print out the progress
                if should(progress_freq):
                    rate = (step + 1) / (time.time() - start)
                    remaining = (self.max_steps - (step + 1)) / rate
                    print("[ %s ] Step number %d ." % (exp_name, step))
                    print("[ %s ] Running rate  %0.3f steps/sec."
                          % (exp_name, rate))
                    print("[ %s ] Estimated remaining time  %d m"
                          % (exp_name, remaining // 60))
                    print("[ %s ] Critic batch loss %0.3f"
                          % (exp_name, results["critic_loss"]))
                    print("[ %s ] Generator batch loss %0.f"
                          % (exp_name, results["gen_loss"]))

                # Save the model
                if should(save_freq):
                    saver.save(sess, os.path.join(exp_folder, "model"),
                               global_step=step)

    def read_valid_cells(self, sess, cells_no):
        """
        Method that reads a given number of cells from the validation set.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of validation cells to read.

        Returns
        -------
        real_cells : numpy array
            Matrix with the validation cells (the number of cells is rounded
            up to a whole number of batches).
        real_clusters : list
            List containing the corresponding cluster indexes.
        """

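        # Round up so that at least cells_no cells are fetched; the surplus
        # from the last (partial) batch is kept.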
        batches_no = int(np.ceil(cells_no / self.batch_size))

        real_cells = []
        real_clusters = []
        for i_batch in range(batches_no):
            test_inputs, test_clusters = sess.run(
                [self.test_cells, self.test_cells_clusters])
            real_cells.append(test_inputs)
            real_clusters.append(test_clusters)

        real_cells = np.array(real_cells, dtype=np.float32)
        real_cells = real_cells.reshape((-1, self.genes_no))

        real_cells = rescale(real_cells, scaling=self.scaling,
                             scale_value=self.scale_value)

        return real_cells, real_clusters

    def generate_cells(self, cells_no, clusters_ratios=None,
                       sess=None, save_path=None, checkpoint=None):
        """
        Method that generates cells from the current model.

        Parameters
        ----------
        cells_no : int or list
            Number of cells to generate.
            If clusters_ratios is provided, should be an int (the total
            number of cells).
            If clusters_ratios is None, should be a list with the number of
            cells to generate for each cluster.
        clusters_ratios : numpy array
            Array containing the different cluster ratios to use for
            the conditional generation.
            Default is None.
        sess : Session
            The TF Session in use.
            If None, a Session is created.
            Default is None.
        save_path : str
            Path in which to write the generated cells.
            If None, the cells are only returned and not written.
            Default is None.
        checkpoint : str, None
            Path to the checkpoint from which to load the model.
            If None, uses the current model loaded in the session.
            Default is None.

        Returns
        -------
        fake_cells : Numpy array
            2-D Array with the gene expression matrix of the generated cells.
        fake_labels : Numpy array
            Array containing the cluster index of the generated cells.
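
        Example
        -------
        A minimal sketch; the counts and paths are illustrative::

            fake_cells, fake_labels = model.generate_cells(
                cells_no=[100] * model.clusters_no,
                save_path='generated_cells.h5',
                checkpoint='path/to/job/folder')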
        """

        if sess is None:
            sess = tf.Session()

        if checkpoint is not None:
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint))

        fake_cells = np.empty((0, self.genes_no), dtype=np.float32)
        fake_labels = np.empty([0, 1], dtype=np.int32)

        if clusters_ratios is None and len(cells_no) > 1:

            for cluster, cells_per_cluster in enumerate(cells_no):
                if int(cells_per_cluster) == 0:
                    continue
                clusters_ratios = np.zeros((1, self.clusters_no), dtype=np.float)
                clusters_ratios[0][cluster] = 1

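                # A one-hot ratio vector makes the recursive call generate
                # cells from the current cluster only.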
                fc, fl = self.generate_cells(sess=sess, checkpoint=checkpoint,
                                             cells_no=int(cells_per_cluster),
                                             clusters_ratios=clusters_ratios)

                fake_cells = np.append(fake_cells, fc, axis=0)
                fake_labels = np.append(fake_labels, fl)

        else:

            batches_no = int(np.ceil(cells_no / self.batch_size))

            clusters_ratios_ph = self.generator.clusters_ratios
            fake_labels_tensor = self.generator.input_clusters
            is_training = self.generator.is_training
            fake_cells_tensor = self.generator.fake_outputs

            eval_feed_dict = {is_training: False,
                              clusters_ratios_ph: clusters_ratios}

            for i_batch in range(batches_no):
                fc, fl = sess.run([fake_cells_tensor, fake_labels_tensor],
                                  feed_dict=eval_feed_dict)
                fake_cells = np.append(fake_cells, fc, axis=0)
                fake_labels = np.append(fake_labels, fl)

            fake_labels = fake_labels[0:cells_no]
            fake_cells = fake_cells[0:cells_no]

        fake_cells = rescale(fake_cells, scaling=self.scaling,
                             scale_value=self.scale_value)

        if save_path is not None:
            save_generated_cells(fake_cells, save_path, fake_labels)

        return fake_cells, fake_labels

    def validation(self, sess, cells_no, exp_folder,
                   train_step, clusters_ratios):
        """
        Method that runs the validation steps on the current model.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of cells to use for the validation step.
        exp_folder : str
            Path to the job folder in which the outputs will be saved.
        train_step : int
            Index of the current training step.
        clusters_ratios : list
            List containing the different cluster ratios to use for the
            conditional generation.

        Returns
        -------

        """

        print("Find tSNE embedding for the generated and the validation cells")
        self.generate_tSNE_image(sess, cells_no, exp_folder,
                                 train_step, clusters_ratios)

    def generate_tSNE_image(self, sess, cells_no, exp_folder,
                            train_step, clusters_ratios):
        """
        Generates and saves a t-SNE plot of the real and generated cells.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of real and of generated cells (each) to use for
            the plot.
        exp_folder : str
            Path to the job folder in which the outputs will be saved.
        train_step : int
            Index of the current training step.
        clusters_ratios : list
            List containing the different cluster ratios to use for the
            conditional generation.

        Returns
        -------

        """

        tsne_logdir = os.path.join(exp_folder, 'TSNE')
        if not os.path.isdir(tsne_logdir):
            os.makedirs(tsne_logdir)

        # generate fake cells
        fake_cells, fake_clusters = self.generate_cells(
            checkpoint=None,
            cells_no=cells_no,
            clusters_ratios=clusters_ratios,
            sess=sess)

        valid_cells, valid_clusters = self.read_valid_cells(sess, cells_no)

        real_cells_clusters = np.array(valid_clusters, dtype=np.float32).flatten()
        fake_cells_clusters = np.array(fake_clusters, dtype=np.float32).flatten()

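        # Embed real and generated cells jointly so that both lie in the
        # same t-SNE space.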
        embedded_cells = tsne.fit_transform(
            np.concatenate((valid_cells, fake_cells), axis=0))
        embedded_cells_real = embedded_cells[0:real_cells_clusters.shape[0], :]
        embedded_cells_fake = embedded_cells[real_cells_clusters.shape[0]:, :]

        colormap = cm.nipy_spectral
        colors = [colormap(i) for i in np.linspace(0, 1, self.clusters_no)]

        plt.clf()
        plt.figure(figsize=(16, 12))

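        # Real cells are drawn as stars, generated cells as circles,
        # with one color per cluster.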
        for i in range(self.clusters_no):
            mask = real_cells_clusters == i

            plt.scatter(embedded_cells_real[mask, 0],
                        embedded_cells_real[mask, 1],
                        c=colors[i], marker='*',
                        label='real_' + str(i))

        for i in range(self.clusters_no):
            mask = fake_cells_clusters == i
            plt.scatter(embedded_cells_fake[mask, 0],
                        embedded_cells_fake[mask, 1],
                        c=colors[i], marker='o',
                        label='fake_' + str(i))

        plt.grid(True)
        plt.legend(loc='lower left',
                   numpoints=1, ncol=3,
                   fontsize=8, bbox_to_anchor=(0, 0))
        plt.savefig(tsne_logdir + '/step_' + str(train_step) + '.jpg')
        plt.close()