import json
import logging
import numpy as np
import os
import pickle
import tensorflow as tf

from linguistic_style_transfer_model.config import global_config
from linguistic_style_transfer_model.config.model_config import mconf
from linguistic_style_transfer_model.evaluators import content_preservation, style_transfer
from linguistic_style_transfer_model.utils import data_processor, custom_decoder

logger = logging.getLogger(global_config.logger_name)


class AdversarialAutoencoder:

    def get_sentence_embedding(self, encoder_embedded_sequence):
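        """Encode the embedded input sequence with a bidirectional GRU and
        return the concatenation of the final forward and backward states."""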

        scope_name = "sentence_embedding"
        with tf.name_scope(scope_name):
            encoder_cell_fw = tf.nn.rnn_cell.DropoutWrapper(
                cell=tf.contrib.rnn.GRUCell(num_units=mconf.encoder_rnn_size),
                input_keep_prob=self.recurrent_state_keep_prob,
                output_keep_prob=self.recurrent_state_keep_prob,
                state_keep_prob=self.recurrent_state_keep_prob)
            encoder_cell_bw = tf.nn.rnn_cell.DropoutWrapper(
                cell=tf.contrib.rnn.GRUCell(num_units=mconf.encoder_rnn_size),
                input_keep_prob=self.recurrent_state_keep_prob,
                output_keep_prob=self.recurrent_state_keep_prob,
                state_keep_prob=self.recurrent_state_keep_prob)

            _, encoder_states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=encoder_cell_fw, cell_bw=encoder_cell_bw,
                inputs=encoder_embedded_sequence, scope=scope_name,
                sequence_length=self.sequence_lengths, dtype=tf.float32)

            return tf.concat(values=encoder_states, axis=1, name="sentence_embedding")

    def get_style_embedding(self, sentence_embedding):
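        """Project the sentence embedding through two dropout-regularized dense
        layers that parameterize the style latent distribution (mu and sigma)."""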

        with tf.name_scope("style_embedding"):
            style_embedding_mu = tf.nn.dropout(
                x=tf.layers.dense(
                    inputs=sentence_embedding,
                    units=mconf.style_embedding_size,
                    activation=tf.nn.leaky_relu, name="style_embedding_mu"),
                keep_prob=self.fully_connected_keep_prob)

            style_embedding_sigma = tf.nn.dropout(
                x=tf.layers.dense(
                    inputs=sentence_embedding,
                    units=mconf.style_embedding_size,
                    activation=tf.nn.leaky_relu, name="style_embedding_sigma"),
                keep_prob=self.fully_connected_keep_prob)

            return style_embedding_mu, style_embedding_sigma

    def get_content_embedding(self, sentence_embedding):
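        """Project the sentence embedding through two dropout-regularized dense
        layers that parameterize the content latent distribution (mu and sigma)."""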

        with tf.name_scope("content_embedding"):
            content_embedding_mu = tf.nn.dropout(
                x=tf.layers.dense(
                    inputs=sentence_embedding,
                    units=mconf.content_embedding_size,
                    activation=tf.nn.leaky_relu, name="content_embedding_mu"),
                keep_prob=self.fully_connected_keep_prob)

            content_embedding_sigma = tf.nn.dropout(
                x=tf.layers.dense(
                    inputs=sentence_embedding,
                    units=mconf.content_embedding_size,
                    activation=tf.nn.leaky_relu, name="content_embedding_sigma"),
                keep_prob=self.fully_connected_keep_prob)

            return content_embedding_mu, content_embedding_sigma

    def get_content_adversary_prediction(self, style_embedding):
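        """Predict the bag-of-words content distribution from the style embedding;
        trained as an adversary to remove content information from the style space."""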

        content_adversary_mlp = tf.nn.dropout(
            x=tf.layers.dense(
                inputs=style_embedding, units=global_config.bow_size,
                activation=tf.nn.leaky_relu, name="content_adversary_mlp"),
            keep_prob=self.fully_connected_keep_prob)

        content_adversary_prediction = tf.layers.dense(
            inputs=content_adversary_mlp, units=global_config.bow_size,
            activation=tf.nn.softmax, name="content_adversary_prediction")

        return content_adversary_prediction

    def get_style_adversary_prediction(self, content_embedding, num_labels):
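        """Predict the style label from the content embedding; trained as an
        adversary to remove style information from the content space."""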

        style_adversary_mlp = tf.nn.dropout(
            x=tf.layers.dense(
                inputs=content_embedding, units=mconf.content_embedding_size,
                activation=tf.nn.leaky_relu, name="style_adversary_mlp"),
            keep_prob=self.fully_connected_keep_prob)

        style_adversary_prediction = tf.layers.dense(
            inputs=style_adversary_mlp, units=num_labels,
            activation=tf.nn.softmax, name="style_adversary_prediction")

        return style_adversary_prediction

    def generate_output_sequence(self, embedded_sequence, generative_embedding,
                                 decoder_embeddings, word_index, batch_size):
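        """Decode word sequences from the latent generative embedding.

        Builds a teacher-forced training decoder and a greedy inference decoder
        that share the same GRU cell and output projection; both are conditioned
        on the latent vector through custom_decoder.CustomBasicDecoder.
        """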

        decoder_cell = tf.nn.rnn_cell.DropoutWrapper(
            cell=tf.contrib.rnn.GRUCell(num_units=mconf.decoder_rnn_size),
            input_keep_prob=self.recurrent_state_keep_prob,
            output_keep_prob=self.recurrent_state_keep_prob,
            state_keep_prob=self.recurrent_state_keep_prob)

        projection_layer = tf.layers.Dense(units=global_config.vocab_size, use_bias=False)

        init_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32)

        training_decoder_scope_name = "training_decoder"
        with tf.name_scope(training_decoder_scope_name):
            training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=embedded_sequence,
                sequence_length=self.sequence_lengths)

            training_decoder = custom_decoder.CustomBasicDecoder(
                cell=decoder_cell, helper=training_helper,
                initial_state=init_state,
                latent_vector=generative_embedding,
                output_layer=projection_layer)
            training_decoder.initialize(training_decoder_scope_name)

            training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=training_decoder, impute_finished=True,
                maximum_iterations=global_config.max_sequence_length,
                scope=training_decoder_scope_name)

        inference_decoder_scope_name = "inference_decoder"
        with tf.name_scope(inference_decoder_scope_name):
            greedy_embedding_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding=decoder_embeddings,
                start_tokens=tf.fill(dims=[batch_size],
                                     value=word_index[global_config.sos_token]),
                end_token=word_index[global_config.eos_token])

            inference_decoder = custom_decoder.CustomBasicDecoder(
                cell=decoder_cell, helper=greedy_embedding_helper,
                initial_state=init_state,
                latent_vector=generative_embedding,
                output_layer=projection_layer)
            inference_decoder.initialize(inference_decoder_scope_name)

            inference_decoder_output, _, final_sequence_lengths = \
                tf.contrib.seq2seq.dynamic_decode(
                    decoder=inference_decoder, impute_finished=True,
                    maximum_iterations=global_config.max_sequence_length,
                    scope=inference_decoder_scope_name)

        return [training_decoder_output.rnn_output, inference_decoder_output.sample_id, final_sequence_lengths]

    def get_kl_loss(self, mu, log_sigma):
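        """KL divergence of the approximate posterior from the standard normal
        prior, averaged over the batch."""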
        return tf.reduce_mean(
            input_tensor=-0.5 * tf.reduce_sum(
                input_tensor=1 + log_sigma - tf.square(mu) - tf.exp(log_sigma),
                axis=1))

    def sample_prior(self, mu, log_sigma):
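        """Reparameterized sampling: mu + epsilon * exp(log_sigma), epsilon ~ N(0, I)."""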
        epsilon = tf.random_normal(tf.shape(log_sigma), name="epsilon")
        return mu + epsilon * tf.exp(log_sigma)

    def compute_batch_entropy(self, x):
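        """Mean Shannon entropy of the predicted distributions in the batch."""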
        return tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=-x * tf.log(x + mconf.epsilon), axis=1))

    def build_model(self, word_index, encoder_embedding_matrix, decoder_embedding_matrix, num_labels):
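        """Construct the computation graph: placeholders, word embeddings, the
        sequence encoder, the style and content variational latent spaces, the
        training/inference decoders, the adversarial and multi-task classifiers,
        and all losses and summaries."""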

        # model inputs
        self.input_sequence = tf.placeholder(
            dtype=tf.int32, shape=[None, global_config.max_sequence_length],
            name="input_sequence")
        logger.debug("input_sequence: {}".format(self.input_sequence))

        batch_size = tf.shape(self.input_sequence)[0]
        logger.debug("batch_size: {}".format(batch_size))

        self.input_label = tf.placeholder(
            dtype=tf.float32, shape=[None, num_labels], name="input_label")
        logger.debug("input_label: {}".format(self.input_label))

        self.sequence_lengths = tf.placeholder(
            dtype=tf.int32, shape=[None], name="sequence_lengths")
        logger.debug("sequence_lengths: {}".format(self.sequence_lengths))

        self.input_bow_representations = tf.placeholder(
            dtype=tf.float32, shape=[None, global_config.bow_size],
            name="input_bow_representations")
        logger.debug("input_bow_representations: {}".format(self.input_bow_representations))

        self.inference_mode = tf.placeholder(dtype=tf.bool, name="inference_mode")
        logger.debug("inference_mode: {}".format(self.inference_mode))

        self.generation_mode = tf.placeholder(dtype=tf.bool, name="generation_mode")
        logger.debug("generation_mode: {}".format(self.generation_mode))

        self.recurrent_state_keep_prob = tf.cond(
            pred=tf.math.logical_or(self.inference_mode, self.generation_mode),
            true_fn=lambda: 1.0,
            false_fn=lambda: mconf.recurrent_state_keep_prob)

        self.fully_connected_keep_prob = tf.cond(
            pred=tf.math.logical_or(self.inference_mode, self.generation_mode),
            true_fn=lambda: 1.0,
            false_fn=lambda: mconf.fully_connected_keep_prob)

        self.sequence_word_keep_prob = tf.cond(
            pred=tf.math.logical_or(self.inference_mode, self.generation_mode),
            true_fn=lambda: 1.0,
            false_fn=lambda: mconf.sequence_word_keep_prob)

        self.conditioning_embedding = tf.placeholder(
            dtype=tf.float32, shape=[None, mconf.style_embedding_size],
            name="conditioning_embedding")
        logger.debug("conditioning_embedding: {}".format(self.conditioning_embedding))

        self.sampled_content_embedding = tf.placeholder(
            dtype=tf.float32, shape=[None, mconf.content_embedding_size],
            name="sampled_content_embedding")
        logger.debug("sampled_content_embedding: {}".format(self.sampled_content_embedding))

        self.epoch = tf.placeholder(dtype=tf.float32, shape=(), name="epoch")
        logger.debug("epoch: {}".format(self.epoch))

        self.style_kl_weight = tf.placeholder(dtype=tf.float32, shape=(), name="style_kl_weight")
        logger.debug("style_kl_weight: {}".format(self.style_kl_weight))

        self.content_kl_weight = tf.placeholder(dtype=tf.float32, shape=(), name="content_kl_weight")
        logger.debug("content_kl_weight: {}".format(self.content_kl_weight))

        decoder_input = tf.concat(
            values=[tf.fill(dims=[batch_size, 1], value=word_index[global_config.sos_token]),
                    self.input_sequence], axis=1, name="decoder_input")

        with tf.device('/cpu:0'):
            with tf.variable_scope("embeddings", reuse=tf.AUTO_REUSE):
                # word embedding matrices
                encoder_embeddings = tf.get_variable(
                    initializer=encoder_embedding_matrix, dtype=tf.float32,
                    trainable=True, name="encoder_embeddings")
                logger.debug("encoder_embeddings: {}".format(encoder_embeddings))

                decoder_embeddings = tf.get_variable(
                    initializer=decoder_embedding_matrix, dtype=tf.float32,
                    trainable=True, name="decoder_embeddings")
                logger.debug("decoder_embeddings: {}".format(decoder_embeddings))

                # embedded sequences
                encoder_embedded_sequence = tf.nn.dropout(
                    x=tf.nn.embedding_lookup(params=encoder_embeddings, ids=self.input_sequence),
                    keep_prob=self.sequence_word_keep_prob,
                    name="encoder_embedded_sequence")
                logger.debug("encoder_embedded_sequence: {}".format(encoder_embedded_sequence))

                decoder_embedded_sequence = tf.nn.dropout(
                    x=tf.nn.embedding_lookup(params=decoder_embeddings, ids=decoder_input),
                    keep_prob=self.sequence_word_keep_prob,
                    name="decoder_embedded_sequence")
                logger.debug("decoder_embedded_sequence: {}".format(decoder_embedded_sequence))

        sentence_embedding = self.get_sentence_embedding(encoder_embedded_sequence)

        # style embedding
        style_embedding_mu, style_embedding_sigma = self.get_style_embedding(sentence_embedding)
        unweighted_style_kl_loss = self.get_kl_loss(style_embedding_mu, style_embedding_sigma)
        self.style_kl_loss = unweighted_style_kl_loss * self.style_kl_weight
        sampled_style_embedding = self.sample_prior(style_embedding_mu, style_embedding_sigma)

        self.style_embedding = tf.cond(
            pred=tf.math.logical_or(self.inference_mode, self.generation_mode),
            true_fn=lambda: self.conditioning_embedding,
            false_fn=lambda: sampled_style_embedding)
        logger.debug("style_embedding: {}".format(self.style_embedding))

        # content embedding
        content_embedding_mu, content_embedding_sigma = self.get_content_embedding(sentence_embedding)
        unweighted_content_kl_loss = self.get_kl_loss(content_embedding_mu, content_embedding_sigma)
        self.content_kl_loss = unweighted_content_kl_loss * self.content_kl_weight
        sampled_content_embedding = self.sample_prior(content_embedding_mu, content_embedding_sigma)

        pre_content_embedding = tf.cond(
            pred=self.inference_mode,
            true_fn=lambda: content_embedding_mu,
            false_fn=lambda: sampled_content_embedding)
        self.content_embedding = tf.cond(
            pred=self.generation_mode,
            true_fn=lambda: self.sampled_content_embedding,
            false_fn=lambda: pre_content_embedding
        )
        logger.debug("content_embedding: {}".format(self.content_embedding))

        # concatenated generative embedding
        generative_embedding = tf.layers.dense(
            inputs=tf.concat(values=[self.style_embedding, self.content_embedding], axis=1),
            units=mconf.decoder_rnn_size, activation=tf.nn.leaky_relu,
            name="generative_embedding")
        logger.debug("generative_embedding: {}".format(generative_embedding))

        # sequence predictions
        with tf.name_scope('sequence_prediction'):
            training_output, self.inference_output, self.final_sequence_lengths = \
                self.generate_output_sequence(
                    decoder_embedded_sequence, generative_embedding, decoder_embeddings,
                    word_index, batch_size)
            logger.debug("training_output: {}".format(training_output))
            logger.debug("inference_output: {}".format(self.inference_output))

        # adversarial loss
        with tf.name_scope('adversarial_objectives'):
            # style adversary
            style_adversary_prediction = self.get_style_adversary_prediction(content_embedding_mu, num_labels)
            logger.debug("style_adversary_prediction: {}".format(style_adversary_prediction))

            self.quantized_style_adversary_prediction = tf.contrib.seq2seq.hardmax(
                logits=style_adversary_prediction, name="quantized_style_adversary_prediction")

            self.style_adversary_entropy = self.compute_batch_entropy(style_adversary_prediction)
            logger.debug("style_adversary_entropy: {}".format(self.style_adversary_entropy))

            self.style_adversary_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=self.input_label, logits=style_adversary_prediction,
                label_smoothing=0.1)
            logger.debug("style_adversary_loss: {}".format(self.style_adversary_loss))

            # content adversary
            content_adversary_prediction = self.get_content_adversary_prediction(self.style_embedding)
            logger.debug("content_adversary_prediction: {}".format(content_adversary_prediction))

            self.content_adversary_entropy = self.compute_batch_entropy(content_adversary_prediction)
            logger.debug("content_adversary_entropy: {}".format(self.content_adversary_entropy))

            self.content_adversary_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=self.input_bow_representations, logits=content_adversary_prediction,
                label_smoothing=0.1)
            logger.debug("content_adversary_loss: {}".format(self.content_adversary_loss))

        # multi-task objectives
        with tf.name_scope('multitask_objectives'):
            # style multitask
            style_multitask_prediction = tf.nn.dropout(
                x=tf.layers.dense(
                    inputs=style_embedding_mu, units=num_labels,
                    activation=tf.nn.softmax, name="style_multitask_prediction"),
                keep_prob=self.fully_connected_keep_prob)
            logger.debug("style_multitask_prediction: {}".format(style_multitask_prediction))

            self.quantized_style_multitask_prediction = tf.contrib.seq2seq.hardmax(
                logits=style_multitask_prediction, name="quantized_style_multitask_prediction")

            self.style_multitask_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=self.input_label, logits=style_multitask_prediction, label_smoothing=0.1)
            logger.debug("style_multitask_loss: {}".format(self.style_multitask_loss))

            # bow multitask
            content_multitask_prediction = tf.nn.dropout(
                x=tf.layers.dense(
                    inputs=content_embedding_mu, units=global_config.bow_size,
                    activation=tf.nn.leaky_relu, name="content_multitask_prediction"),
                keep_prob=self.fully_connected_keep_prob)
            logger.debug("content_multitask_prediction: {}".format(content_multitask_prediction))

            self.content_multitask_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=self.input_bow_representations, logits=content_multitask_prediction,
                label_smoothing=0.1)
            logger.debug("content_multitask_loss: {}".format(self.content_multitask_loss))

        # overall latent-space classifier: not required for style transfer,
        # only used to demonstrate disentanglement of the latent spaces
        style_overall_prediction = tf.nn.dropout(
            x=tf.layers.dense(
                inputs=tf.concat(values=[style_embedding_mu, content_embedding_mu], axis=1),
                units=num_labels, activation=tf.nn.softmax,
                name="style_overall_prediction"),
            keep_prob=self.fully_connected_keep_prob)
        logger.debug("style_overall_prediction: {}".format(style_overall_prediction))

        self.quantized_style_overall_prediction = tf.contrib.seq2seq.hardmax(
            logits=style_overall_prediction, name="quantized_style_overall_prediction")

        self.style_overall_prediction_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=self.input_label, logits=style_overall_prediction, label_smoothing=0.1)
        logger.debug("style_overall_prediction_loss: {}".format(self.style_overall_prediction_loss))

        # reconstruction loss
        with tf.name_scope('reconstruction_loss'):
            batch_maxlen = tf.reduce_max(self.sequence_lengths)
            logger.debug("batch_maxlen: {}".format(batch_maxlen))

            # the training decoder only emits as many time-steps as the longest
            # sequence in the current batch, so trim the targets to match
            target_sequence = tf.slice(
                input_=self.input_sequence,
                begin=[0, 0],
                size=[batch_size, batch_maxlen],
                name="target_sequence")
            logger.debug("target_sequence: {}".format(target_sequence))

            output_sequence_mask = tf.sequence_mask(
                lengths=tf.add(x=self.sequence_lengths, y=1),
                maxlen=batch_maxlen,
                dtype=tf.float32)

            self.reconstruction_loss = tf.contrib.seq2seq.sequence_loss(
                logits=training_output, targets=target_sequence,
                weights=output_sequence_mask)
            logger.debug("reconstruction_loss: {}".format(self.reconstruction_loss))

        # TensorBoard scalar summaries
        tf.summary.scalar(tensor=self.reconstruction_loss, name="reconstruction_loss_summary")
        tf.summary.scalar(tensor=self.style_multitask_loss, name="style_multitask_loss_summary")
        tf.summary.scalar(tensor=self.style_adversary_loss, name="style_adversary_loss_summary")
        tf.summary.scalar(tensor=self.content_adversary_loss, name="content_adversary_loss_summary")
        tf.summary.scalar(tensor=self.content_multitask_loss, name="content_multitask_loss_summary")
        tf.summary.scalar(tensor=unweighted_style_kl_loss, name="unweighted_style_kl_loss_summary")
        tf.summary.scalar(tensor=unweighted_content_kl_loss, name="unweighted_content_kl_loss_summary")
        tf.summary.scalar(tensor=self.style_kl_loss, name="style_kl_loss_summary")
        tf.summary.scalar(tensor=self.content_kl_loss, name="content_kl_loss_summary")

    def get_batch_indices(self, batch_number, data_limit):
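        """Return the [start, end) slice indices for the given batch, capped at data_limit."""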

        start_index = batch_number * mconf.batch_size
        end_index = min((batch_number + 1) * mconf.batch_size, data_limit)

        return start_index, end_index

    def run_batch(self, sess, start_index, end_index, fetches, padded_sequences,
                  one_hot_labels, text_sequence_lengths,
                  conditioning_embedding, inference_mode, generation_mode,
                  style_kl_weight, content_kl_weight, current_epoch):
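        """Run one session step on the [start_index, end_index) slice of the data.

        Builds the feed dict (drawing a random conditioning embedding during
        training and a random content embedding from the prior) and returns the
        fetched ops.
        """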

        if not inference_mode and not generation_mode:
            conditioning_embedding = np.random.uniform(
                size=(end_index - start_index, mconf.style_embedding_size),
                low=-0.05, high=0.05).astype(dtype=np.float32)

        sampled_content_embedding = np.random.normal(
            size=(end_index - start_index, mconf.content_embedding_size)).astype(dtype=np.float32)

        bow_representations = data_processor.get_bow_representations(
            padded_sequences[start_index: end_index])

        ops = sess.run(
            fetches=fetches,
            feed_dict={
                self.input_sequence: padded_sequences[start_index: end_index],
                self.input_label: one_hot_labels[start_index: end_index],
                self.sequence_lengths: text_sequence_lengths[start_index: end_index],
                self.input_bow_representations: bow_representations,
                self.inference_mode: inference_mode,
                self.generation_mode: generation_mode,
                self.conditioning_embedding: conditioning_embedding,
                self.sampled_content_embedding: sampled_content_embedding,
                self.style_kl_weight: style_kl_weight,
                self.content_kl_weight: content_kl_weight,
                self.epoch: current_epoch
            })

        return ops

    def get_annealed_weight(self, iteration, lambda_weight):
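        """Tanh-shaped KL weight annealing: rises smoothly from ~0 toward
        2 * lambda_weight, crossing lambda_weight at
        iteration = 1.5 * mconf.kl_anneal_iterations."""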
        return (np.tanh(
            (iteration - mconf.kl_anneal_iterations * 1.5) /
            (mconf.kl_anneal_iterations / 3)) + 1) * lambda_weight

    def train(self, sess, data_size, padded_sequences, text_sequence_lengths, one_hot_labels, num_labels,
              word_index, encoder_embedding_matrix, decoder_embedding_matrix, validation_sequences,
              validation_sequence_lengths, validation_labels, inverse_word_index, validation_actual_word_lists,
              options):
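        """Train the model end to end.

        Each batch runs the composite reconstruction/multi-task/KL update
        together with the style adversary, content adversary, and overall
        latent-space classifier updates. Every epoch the model is checkpointed
        and the collected style/content embeddings are dumped; validation runs
        every global_config.validation_interval epochs.
        """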

        writer = tf.summary.FileWriter(logdir=global_config.log_directory, graph=sess.graph)

        trainable_variables = tf.trainable_variables()
        logger.debug("trainable_variables: {}".format(trainable_variables))

        self.composite_loss = 0.0
        self.composite_loss += self.reconstruction_loss
        self.composite_loss += self.style_multitask_loss * mconf.style_multitask_loss_weight
        self.composite_loss += self.content_multitask_loss * mconf.content_multitask_loss_weight
        self.composite_loss -= self.style_adversary_entropy * mconf.style_adversary_loss_weight
        self.composite_loss -= self.content_adversary_entropy * mconf.content_adversary_loss_weight
        self.composite_loss += self.style_kl_loss
        self.composite_loss += self.content_kl_loss
        tf.summary.scalar(tensor=self.composite_loss, name="composite_loss_summary")
        self.all_summaries = tf.summary.merge_all()

        # optimize adversarial classification
        style_adversary_variable_labels = ["style_adversary"]
        content_adversary_variable_labels = ["content_adversary"]
        # style
        style_adversary_training_optimizer = tf.train.RMSPropOptimizer(
            learning_rate=mconf.style_adversary_learning_rate)
        style_adversary_training_variables = [
            x for x in trainable_variables if any(
                scope in x.name for scope in style_adversary_variable_labels)]
        logger.debug("style_adversary_training_optimizer.variables: {}".format(
            style_adversary_training_variables))
        style_adversary_training_operation = style_adversary_training_optimizer.minimize(
            loss=self.style_adversary_loss,
            var_list=style_adversary_training_variables)
        # content
        content_adversary_training_optimizer = tf.train.RMSPropOptimizer(
            learning_rate=mconf.content_adversary_learning_rate)
        content_adversary_training_variables = [
            x for x in trainable_variables if any(
                scope in x.name for scope in content_adversary_variable_labels)]
        logger.debug("content_adversary_training_optimizer.variables: {}".format(
            content_adversary_training_variables))
        content_adversary_training_operation = content_adversary_training_optimizer.minimize(
            loss=self.content_adversary_loss,
            var_list=content_adversary_training_variables)

        # optimize overall latent space classification
        style_overall_variable_labels = ["style_overall"]
        style_overall_optimizer = tf.train.RMSPropOptimizer(
            learning_rate=mconf.autoencoder_learning_rate)
        style_overall_training_variables = [
            x for x in trainable_variables if any(scope in x.name for scope in style_overall_variable_labels)]
        logger.debug("style_overall_training_variables: {}".format(style_overall_training_variables))
        style_overall_training_operation = style_overall_optimizer.minimize(
            loss=self.style_overall_prediction_loss,
            var_list=style_overall_training_variables)

        # optimize reconstruction
        reconstruction_training_optimizer = tf.train.AdamOptimizer(
            learning_rate=mconf.autoencoder_learning_rate)
        reconstruction_training_variables = [
            x for x in trainable_variables if all(
                scope not in x.name for scope in
                (style_adversary_variable_labels +
                 content_adversary_variable_labels +
                 style_overall_variable_labels))]
        logger.debug("reconstruction_training_optimizer.variables: {}".format(reconstruction_training_variables))
        reconstruction_training_operation = reconstruction_training_optimizer.minimize(
            loss=self.composite_loss, var_list=reconstruction_training_variables)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        num_batches = data_size // mconf.batch_size
        if data_size % mconf.batch_size:
            num_batches += 1
        logger.debug("Training - texts shape: {}; labels shape {}"
                     .format(padded_sequences.shape, one_hot_labels.shape))

        iteration = 0
        style_kl_weight, content_kl_weight = 0, 0
        for current_epoch in range(1, options.training_epochs + 1):

            all_style_embeddings = list()
            all_content_embeddings = list()

            shuffle_indices = np.random.permutation(np.arange(data_size))

            shuffled_padded_sequences = padded_sequences[shuffle_indices]
            shuffled_one_hot_labels = one_hot_labels[shuffle_indices]
            shuffled_text_sequence_lengths = text_sequence_lengths[shuffle_indices]

            for batch_number in range(num_batches):
                (start_index, end_index) = self.get_batch_indices(
                    batch_number=batch_number, data_limit=data_size)

                if iteration < mconf.kl_anneal_iterations:
                    style_kl_weight = self.get_annealed_weight(iteration, mconf.style_kl_lambda)
                    content_kl_weight = self.get_annealed_weight(iteration, mconf.content_kl_lambda)

                fetches = \
                    [reconstruction_training_operation,
                     style_adversary_training_operation,
                     content_adversary_training_operation,
                     style_overall_training_operation,
                     self.reconstruction_loss,
                     self.style_multitask_loss,
                     self.content_multitask_loss,
                     self.style_adversary_loss,
                     self.style_adversary_entropy,
                     self.content_adversary_loss,
                     self.content_adversary_entropy,
                     self.style_kl_loss,
                     self.content_kl_loss,
                     self.composite_loss,
                     self.style_embedding,
                     self.content_embedding,
                     self.all_summaries]

                [_, _, _, _,
                 reconstruction_loss,
                 style_multitask_loss, content_multitask_loss,
                 style_adversary_crossentropy, style_adversary_entropy,
                 content_adversary_crossentropy, content_adversary_entropy,
                 style_kl_loss, content_kl_loss,
                 composite_loss,
                 style_embeddings, content_embeddings,
                 all_summaries] = \
                    self.run_batch(
                        sess, start_index, end_index, fetches,
                        shuffled_padded_sequences, shuffled_one_hot_labels,
                        shuffled_text_sequence_lengths, None, False, False,
                        style_kl_weight, content_kl_weight, current_epoch)

                log_msg = "[R: {:.2f}, " \
                          "SMT: {:.2f}, CMT: {:.2f}, " \
                          "SCE: {:.2f}, SE: {:.2f}, " \
                          "CCE: {:.2f}, CE: {:.2f}, " \
                          "SKL: {:.2f}, CKL: {:.2f}] " \
                          "Epoch {}-{}: {:.4f}"
                logger.info(log_msg.format(
                    reconstruction_loss,
                    style_multitask_loss, content_multitask_loss,
                    style_adversary_crossentropy, style_adversary_entropy,
                    content_adversary_crossentropy, content_adversary_entropy,
                    style_kl_loss, content_kl_loss,
                    current_epoch, batch_number, composite_loss))

                all_style_embeddings.extend(style_embeddings)
                all_content_embeddings.extend(content_embeddings)

                iteration += 1

                writer.add_summary(all_summaries, iteration)
                writer.flush()

            saver.save(sess=sess, save_path=global_config.model_save_path)

            np.save(file=global_config.all_style_embeddings_path, arr=np.asarray(all_style_embeddings))
            np.save(file=global_config.all_content_embeddings_path, arr=all_content_embeddings)
            with open(global_config.all_shuffled_labels_path, 'wb') as pickle_file:
                pickle.dump(shuffled_one_hot_labels, pickle_file)

            average_label_embeddings = data_processor.get_average_label_embeddings(
                data_size, options.dump_embeddings, current_epoch)
            with open(global_config.average_label_embeddings_path, 'wb') as pickle_file:
                pickle.dump(average_label_embeddings, pickle_file)

            if not current_epoch % global_config.validation_interval:
                self.run_validation(options, num_labels, validation_sequences, validation_sequence_lengths,
                                    validation_labels, validation_actual_word_lists, all_style_embeddings,
                                    shuffled_one_hot_labels, inverse_word_index, current_epoch, sess)

        writer.close()

    def run_validation(self, options, num_labels, validation_sequences, validation_sequence_lengths,
                       validation_labels, validation_actual_word_lists, all_style_embeddings,
                       shuffled_one_hot_labels, inverse_word_index, current_epoch, sess):
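        """Style-transfer the validation sentences to each target label and score them.

        For every label, sentences from the other labels are re-generated while
        conditioning on the mean style embedding of that label; aggregate
        style-transfer, content-preservation, and word-overlap scores are logged.
        """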

        logger.info("Running Validation {}:".format(current_epoch // global_config.validation_interval))

        glove_model = content_preservation.load_glove_model(options.validation_embeddings_file_path)

        validation_style_transfer_scores = list()
        validation_content_preservation_scores = list()
        validation_word_overlap_scores = list()
        for i in range(num_labels):

            logger.info("validating label {}".format(i))

            label_embeddings = list()
            validation_sequences_to_transfer = list()
            validation_labels_to_transfer = list()
            validation_sequence_lengths_to_transfer = list()

            for k in range(len(all_style_embeddings)):
                if shuffled_one_hot_labels[k].tolist().index(1) == i:
                    label_embeddings.append(all_style_embeddings[k])

            for k in range(len(validation_sequences)):
                if validation_labels[k].tolist().index(1) != i:
                    validation_sequences_to_transfer.append(validation_sequences[k])
                    validation_labels_to_transfer.append(validation_labels[k])
                    validation_sequence_lengths_to_transfer.append(validation_sequence_lengths[k])

            style_embedding = np.mean(np.asarray(label_embeddings), axis=0)

            validation_batches = len(validation_sequences_to_transfer) // mconf.batch_size
            if len(validation_sequences_to_transfer) % mconf.batch_size:
                validation_batches += 1

            validation_generated_sequences = list()
            validation_generated_sequence_lengths = list()
            for val_batch_number in range(validation_batches):
                (start_index, end_index) = self.get_batch_indices(
                    batch_number=val_batch_number,
                    data_limit=len(validation_sequences_to_transfer))

                conditioning_embedding = np.tile(
                    A=style_embedding, reps=(end_index - start_index, 1))

                [validation_generated_sequences_batch, validation_sequence_lengths_batch] = \
                    self.run_batch(
                        sess, start_index, end_index,
                        [self.inference_output, self.final_sequence_lengths],
                        validation_sequences_to_transfer, validation_labels_to_transfer,
                        validation_sequence_lengths_to_transfer,
                        conditioning_embedding, True, False, 0, 0, current_epoch)
                validation_generated_sequences.extend(validation_generated_sequences_batch)
                validation_generated_sequence_lengths.extend(validation_sequence_lengths_batch)

            trimmed_generated_sequences = \
                [[index for index in sequence
                  if index != global_config.predefined_word_index[global_config.eos_token]]
                 for sequence in [x[:(y - 1)] for (x, y) in zip(
                    validation_generated_sequences, validation_generated_sequence_lengths)]]

            generated_word_lists = \
                [data_processor.generate_words_from_indices(x, inverse_word_index)
                 for x in trimmed_generated_sequences]

            generated_sentences = [" ".join(x) for x in generated_word_lists]

            output_file_path = "output/{}-training/validation_sentences_{}.txt".format(
                global_config.experiment_timestamp, i)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
            with open(output_file_path, 'w') as output_file:
                for sentence in generated_sentences:
                    output_file.write(sentence + "\n")

            [style_transfer_score, confusion_matrix] = style_transfer.get_style_transfer_score(
                options.classifier_saved_model_path, output_file_path, str(i), None)
            logger.debug("style_transfer_score: {}".format(style_transfer_score))
            logger.debug("confusion_matrix:\n{}".format(confusion_matrix))

            content_preservation_score = content_preservation.get_content_preservation_score(
                validation_actual_word_lists, generated_word_lists, glove_model)
            logger.debug("content_preservation_score: {}".format(content_preservation_score))

            word_overlap_score = content_preservation.get_word_overlap_score(
                validation_actual_word_lists, generated_word_lists)
            logger.debug("word_overlap_score: {}".format(word_overlap_score))

            validation_style_transfer_scores.append(style_transfer_score)
            validation_content_preservation_scores.append(content_preservation_score)
            validation_word_overlap_scores.append(word_overlap_score)

        aggregate_style_transfer = np.mean(np.asarray(validation_style_transfer_scores))
        logger.info("Aggregate Style Transfer: {}".format(aggregate_style_transfer))

        aggregate_content_preservation = np.mean(np.asarray(validation_content_preservation_scores))
        logger.info("Aggregate Content Preservation: {}".format(aggregate_content_preservation))

        aggregate_word_overlap = np.mean(np.asarray(validation_word_overlap_scores))
        logger.info("Aggregate Word Overlap: {}".format(aggregate_word_overlap))

        with open(global_config.validation_scores_path, 'a+') as validation_scores_file:
            validation_record = {
                "epoch": current_epoch,
                "style-transfer": aggregate_style_transfer,
                "content-preservation": aggregate_content_preservation,
                "word-overlap": aggregate_word_overlap
            }
            validation_scores_file.write(json.dumps(validation_record) + "\n")

    def transform_sentences(self, sess, padded_sequences, text_sequence_lengths, style_embedding,
                            num_labels, model_save_path):
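        """Restore a saved model and re-generate the given sequences conditioned
        on the provided style embedding, returning the generated sequences along
        with the latent-space classifier predictions and per-batch
        reconstruction losses."""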

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=sess, save_path=model_save_path)

        data_size = len(padded_sequences)
        generated_sequences = list()
        final_sequence_lengths = list()
        overall_label_predictions = list()
        style_label_predictions = list()
        adversarial_label_predictions = list()
        cross_entropy_scores = list()
        num_batches = data_size // mconf.batch_size
        if data_size % mconf.batch_size:
            num_batches += 1

        # the input labels are not used in inference mode, so feed an all-zero placeholder
        one_hot_labels_placeholder = np.zeros(
            shape=(data_size, num_labels), dtype=np.int32)

        end_index = None
        style_kl_weight = 0
        content_kl_weight = 0
        current_epoch = 0
        for batch_number in range(num_batches):
            (start_index, end_index) = self.get_batch_indices(
                batch_number=batch_number, data_limit=data_size)

            conditioning_embedding = np.tile(A=style_embedding, reps=(end_index - start_index, 1))

            generated_sequences_batch, final_sequence_lengths_batch, \
            overall_label_predictions_batch, style_label_predictions_batch, \
            adversarial_label_predictions_batch, cross_entropy_score = \
                self.run_batch(
                    sess, start_index, end_index,
                    [self.inference_output, self.final_sequence_lengths,
                     self.quantized_style_overall_prediction,
                     self.quantized_style_multitask_prediction,
                     self.quantized_style_adversary_prediction,
                     self.reconstruction_loss],
                    padded_sequences, one_hot_labels_placeholder, text_sequence_lengths,
                    conditioning_embedding, True, False, style_kl_weight, content_kl_weight, current_epoch)

            generated_sequences.extend(generated_sequences_batch)
            final_sequence_lengths.extend(final_sequence_lengths_batch)
            overall_label_predictions.extend(overall_label_predictions_batch)
            style_label_predictions.extend(style_label_predictions_batch)
            adversarial_label_predictions.extend(adversarial_label_predictions_batch)
            cross_entropy_scores.append(cross_entropy_score)

        return generated_sequences, final_sequence_lengths, overall_label_predictions, \
               style_label_predictions, adversarial_label_predictions, cross_entropy_scores

    def generate_novel_sentences(self, sess, style_embedding, data_size, num_labels, model_save_path):
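        """Restore a saved model and sample new sentences in generation mode,
        conditioning the decoder on the given style embedding and on content
        embeddings drawn from the standard normal prior."""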
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=sess, save_path=model_save_path)

        generated_sequences = list()
        final_sequence_lengths = list()

        num_batches = data_size // mconf.batch_size
        if data_size % mconf.batch_size:
            num_batches += 1

        end_index = None
        style_kl_weight = 0
        content_kl_weight = 0
        current_epoch = 0

        # dummy encoder inputs; in generation mode the decoder is driven only by the
        # conditioning (style) embedding and the sampled content embedding
        dummy_sequences = np.zeros(
            shape=(data_size, global_config.max_sequence_length), dtype=np.int32)
        dummy_one_hot_labels = np.zeros(shape=(data_size, num_labels), dtype=np.float32)
        dummy_sequence_lengths = np.zeros(shape=data_size, dtype=np.int32)

        for batch_number in range(num_batches):
            (start_index, end_index) = self.get_batch_indices(
                batch_number=batch_number, data_limit=data_size)

            conditioning_embedding = np.tile(A=style_embedding, reps=(end_index - start_index, 1))

            generated_sequences_batch, final_sequence_lengths_batch = \
                self.run_batch(
                    sess, start_index, end_index,
                    [self.inference_output, self.final_sequence_lengths],
                    dummy_sequences, dummy_one_hot_labels, dummy_sequence_lengths,
                    conditioning_embedding, False, True, style_kl_weight, content_kl_weight, current_epoch)

            generated_sequences.extend(generated_sequences_batch)
            final_sequence_lengths.extend(final_sequence_lengths_batch)

        return generated_sequences, final_sequence_lengths