# Copyright 2019 Uizard Technologies
# Significantly inspired by:
# MixMatch - A Holistic Approach to Semi-Supervised Learning, Berthelot et al. (2019) - Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RealMix: Towards Realistic Deep Semi-Supervised Learning Algorithms"""


from absl import app, flags
from easydict import EasyDict
from libml import layers, utils, models
from libml.data_pair import DATASETS, stack_augment
from libml.data import DataSet, augment_cifar10, augment_color, augment_cutout, augment_stl10, augment_svhn, memoize, default_parse, dataset
from libml.layers import MixMode
from tqdm import trange

import functools
import itertools
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import glob
import numpy as np

FLAGS = flags.FLAGS


class RealMix(models.MultiModel):

    def augment(self, x, l, beta, **kwargs):
        assert 0, 'Do not call.'

    def guess_label(self, y, classifier, T, **kwargs):
        del kwargs
        logits_y = [classifier(yi, training=True) for yi in y]
        logits_y = tf.concat(logits_y, 0)
        # Compute predicted probability distribution py.
        p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
        p_model_y = tf.reduce_mean(p_model_y, axis=0)
        # Compute the target distribution.
        p_target = tf.pow(p_model_y, 1. / T)
        p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
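        # Example (illustrative): with T = 0.5, sharpening squares the averaged
        # probabilities before renormalizing, e.g. [0.6, 0.4] -> [0.36, 0.16] ->
        # [~0.69, ~0.31], pushing the guessed label towards a one-hot distribution.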
        return EasyDict(p_target=p_target, p_model=p_model_y)

    def get_tsa_threshold(self, schedule, global_step, num_train_steps, start, end):
        # Originally written in google-research/uda/image/main.py

        # Returns the current TSA (Training Signal Annealing) threshold given the
        # schedule, the current training step, the total number of training steps,
        # and the start and end thresholds.

        # Typical values are as follows:
        # schedule = "linear_schedule", "exp_schedule", "log_schedule"
        # global_step = self.step
        # num_train_steps = FLAGS.train_kimg << 10, or FLAGS.train_kimg * FLAGS.epochs
        # start = 1. / FLAGS.nclass
        # end = 1
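
        # Illustrative thresholds halfway through training (step_ratio = 0.5) with
        # start = 0.1 and end = 1:
        #   linear_schedule -> 0.1 + 0.5             * 0.9 ~= 0.55
        #   exp_schedule    -> 0.1 + exp(-2.5)       * 0.9 ~= 0.17
        #   log_schedule    -> 0.1 + (1 - exp(-2.5)) * 0.9 ~= 0.93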

        step_ratio = tf.to_float(global_step) / tf.to_float(num_train_steps)
        if schedule == "linear_schedule":
            coeff = step_ratio
        elif schedule == "exp_schedule":
            scale = 5
            # [exp(-5), exp(0)] = [1e-2, 1]
            coeff = tf.exp((step_ratio - 1) * scale)
        elif schedule == "log_schedule":
            scale = 5
            # [1 - exp(0), 1 - exp(-5)] = [0, 0.99]
            coeff = 1 - tf.exp((-step_ratio) * scale)
        else:
            raise ValueError("Unknown TSA schedule: %s" % schedule)
        return coeff * (end - start) + start


    def anneal_sup_loss(self, sup_logits, sup_labels, sup_loss, global_step):
        # Adapted from google-research/uda/image/main.py

        # This is a version of TSA (Training Signal Annealing) that has been
        # adapted for use with RealMix. Specifically, it can deal with ground
        # truth values between 0 and 1 as created by MixUp.

        # The start value for TSA.
        tsa_start = 1. / FLAGS.nclass

        # Probability threshold above which the loss for a supervised image is not computed.
        eff_train_prob_threshold = self.get_tsa_threshold(
            FLAGS.tsa, global_step, FLAGS.train_kimg * FLAGS.epochs,
            tsa_start, end=1)

        # Calculate probabilities of each class for each image.
        sup_probs = tf.nn.softmax(sup_logits, axis=-1)

        # Mask the predicted probabilities to only ground truth classes.
        ground_truth_class_threshold = tf.greater(sup_labels, tf.constant(0.0, tf.float32))
        ground_truth_class_mask = tf.cast(ground_truth_class_threshold, tf.float32)
        correct_label_probs = sup_probs * ground_truth_class_mask

        # Calculate TSA threshold for each ground truth probability.
        # This is necessary since MixUp generates values between 0 and 1.
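        # Example (illustrative): with a scalar TSA threshold of 0.5 and mixed labels
        # [0.7, 0.3, 0, ...], the per-class thresholds become [0.35, 0.15, 0, ...].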
        eff_train_prob_threshold = sup_labels * eff_train_prob_threshold

        ones = tf.ones(tf.shape(correct_label_probs), tf.float32)
        pos_tensor = tf.multiply(ones, FLAGS.nclass + 1)
        neg_tensor = tf.multiply(ones, -1)

        # The loss for an image is kept if at least one of its ground-truth class
        # probabilities is still below its threshold. A temporary mask is created
        # to find which images contain unmet thresholds.
        imgs_to_train_mask = tf.where(tf.less(correct_label_probs, eff_train_prob_threshold), \
                                pos_tensor, neg_tensor)
        imgs_to_train_mask = tf.reduce_mean(imgs_to_train_mask, axis=1)

        ones = tf.ones(tf.shape(imgs_to_train_mask), tf.float32)
        zeros = tf.zeros(tf.shape(imgs_to_train_mask), tf.float32)

        loss_mask = tf.where(tf.greater(imgs_to_train_mask, zeros), ones, zeros)
        loss_mask = tf.stop_gradient(loss_mask)

        # Mask the supervised loss and return the average.
        sup_loss = sup_loss * loss_mask
        avg_sup_loss = (tf.reduce_sum(sup_loss) /
                        tf.maximum(tf.reduce_sum(loss_mask), 1))

        return avg_sup_loss

    def confidence_mask_unsup(self, logits_y, labels_y, loss_l2u):
        # Adapted from google-research/uda/image/main.py

        # This function masks the unsupervised loss for predictions whose confidence
        # falls below a fixed threshold (FLAGS.percent_mask used as an absolute
        # probability). Note: this only works with the MSE loss, not KL-divergence.

        # Calculate largest predicted probability for each image.
        unsup_prob = tf.nn.softmax(logits_y, axis=-1)
        largest_prob = tf.reduce_max(unsup_prob, axis=-1)

        # Mask the loss for images that don't contain a predicted
        # probability above the threshold.
        loss_mask = tf.cast(tf.greater(largest_prob, FLAGS.percent_mask), tf.float32)
        tf.summary.scalar('losses/high_prob_ratio', tf.reduce_mean(loss_mask))
        loss_mask = tf.stop_gradient(loss_mask)
        loss_l2u = loss_l2u * tf.expand_dims(loss_mask, axis=-1)

        # Return the average unsupervised loss.
        avg_unsup_loss = (tf.reduce_sum(loss_l2u) /
                        tf.maximum(tf.reduce_sum(loss_mask) * FLAGS.nclass, 1))
        return avg_unsup_loss

    def percent_confidence_mask_unsup(self, logits_y, labels_y, loss_l2u):
        # Adapted from google-research/uda/image/main.py

        # This function masks the unsupervised loss for the least-confident fraction
        # of the batch, where FLAGS.percent_mask gives the fraction to mask out.
        # Note: this only works with the MSE loss, not KL-divergence.

        # Calculate largest predicted probability for each image.
        unsup_prob = tf.nn.softmax(logits_y, axis=-1)
        largest_prob = tf.reduce_max(unsup_prob, axis=-1)

        # Mask out the bottom percent_mask fraction of predictions by confidence.
        # In other words, find the largest predicted probability of the image ranked
        # percent_mask * num_samples from the bottom and use it as the threshold.

        # Compute the current confidence threshold from the sorted probabilities:
        sorted_probs = tf.sort(largest_prob, axis=-1, direction='ASCENDING')
        sort_index = tf.math.multiply(tf.to_float(tf.shape(sorted_probs)[0]), FLAGS.percent_mask)
        curr_confidence_mask = tf.slice(sorted_probs, [tf.to_int64(sort_index)], [1])
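
        # Example (illustrative): with 128 unlabeled predictions and percent_mask = 0.3,
        # sort_index is 38, so the 39th-smallest max-probability becomes the threshold
        # and roughly the least-confident 30% of the batch is masked out below.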

        # Mask the loss for images that don't contain a predicted
        # probability above the threshold.
        loss_mask = tf.cast(tf.greater(largest_prob, curr_confidence_mask), tf.float32)
        tf.summary.scalar('losses/high_prob_ratio', tf.reduce_mean(loss_mask))  # Ratio of unlabeled images above the threshold.
        tf.summary.scalar('losses/percent_confidence_mask', tf.reshape(curr_confidence_mask,[]))
        loss_mask = tf.stop_gradient(loss_mask)
        loss_l2u = loss_l2u * tf.expand_dims(loss_mask, axis=-1)

        # Return the average unsupervised loss.
        avg_unsup_loss = (tf.reduce_sum(loss_l2u) /
                        tf.maximum(tf.reduce_sum(loss_mask) * FLAGS.nclass, 1))
        return avg_unsup_loss

    def model(self, batch, lr, wd, ema, beta, w_match, warmup_kimg=1024, nu=2, mixmode='xxy.yxy', **kwargs):
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]

        # Create placeholders for the labeled images, unlabeled images,
        # and the ground truth supervised labels respectively.
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
        y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
        l_in = tf.placeholder(tf.int32, [None], 'labels')
        wd *= lr
        w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
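        # Assuming self.step counts training images (as in MixMatch), w_match ramps up
        # linearly from 0 to its full value over the first warmup_kimg * 1024 images.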
        augment = MixMode(mixmode)
        classifier = functools.partial(self.classifier, **kwargs)

        y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
        guess = self.guess_label(tf.split(y, nu), classifier, T=0.5, **kwargs)
        ly = tf.stop_gradient(guess.p_target)
        lx = tf.one_hot(l_in, self.nclass)

        # Create MixUp examples.
        xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
        x, y = xy[0], xy[1:]
        labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
        del xy, labels_xy
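
        # For reference, MixMatch-style MixUp (libml.layers.MixMode) draws
        # lam ~ Beta(beta, beta), takes lam = max(lam, 1 - lam), and mixes pairs:
        #   x_mix = lam * x_a + (1 - lam) * x_b
        #   l_mix = lam * l_a + (1 - lam) * l_b
        # so labels_x and labels_y are generally soft labels in [0, 1].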

        # Create batches that represent both labeled and unlabeled batches.
        # For more, see google-research/mixmatch/issues/5.
        batches = layers.interleave([x] + y, batch)
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        logits = [classifier(batches[0], training=True)]
        post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
        for batchi in batches[1:]:
            logits.append(classifier(batchi, training=True))
        logits = layers.interleave(logits, batch)
        logits_x = logits[0]
        logits_y = tf.concat(logits[1:], 0)

        # Calculate supervised and unsupervised losses.
        loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
        if FLAGS.tsa != "none":
            print("Using training signal annealing...")
            loss_xe = self.anneal_sup_loss(logits_x, labels_x, loss_xe, self.step)
        else:
            loss_xe = tf.reduce_mean(loss_xe)

        loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))


        if FLAGS.percent_mask > 0:
            print("Using percent-based confidence masking...")
            loss_l2u = self.percent_confidence_mask_unsup(logits_y, labels_y, loss_l2u)
        else:
            loss_l2u = tf.reduce_mean(loss_l2u)

        # Calculate largest predicted probability for each image.
        unsup_prob = tf.nn.softmax(logits_y, axis=-1)
        tf.summary.scalar('losses/min_unsup_prob', tf.reduce_min(tf.reduce_max(unsup_prob, axis=-1)))
        tf.summary.scalar('losses/mean_unsup_prob', tf.reduce_mean(tf.reduce_max(unsup_prob, axis=-1)))
        tf.summary.scalar('losses/max_unsup_prob', tf.reduce_max(tf.reduce_max(unsup_prob, axis=-1)))

        # Print losses to tensorboard.
        tf.summary.scalar('losses/xe', loss_xe)
        tf.summary.scalar('losses/l2u', loss_l2u)
        tf.summary.scalar('losses/overall', loss_xe + w_match * loss_l2u)

        # Apply an exponential moving average (EMA) to the model weights, following
        # Tarvainen & Valpola (2017). See https://arxiv.org/abs/1703.01780 for more.
        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)
        post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])

        train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        # Tuning op: only retrain batch norm.
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        classifier(batches[0], training=True)
        train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                              if v not in skip_ops])

        return EasyDict(
            x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
            classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)),
            eval_loss_op=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=classifier(x_in, getter=ema_getter, training=False), 
                labels=tf.one_hot(l_in, self.nclass))))


def get_dataset():
    assert FLAGS.dataset in DATASETS.keys() or FLAGS.custom_dataset, "Please enter a valid dataset name or use the --custom_dataset flag."

    # CIFAR10, CIFAR100, STL10, and SVHN are the pre-configured datasets
    # with each dataset's default augmentation.
    if FLAGS.dataset in DATASETS.keys():
        dataset = DATASETS[FLAGS.dataset]()

    # If the dataset has not been pre-configured, create it.
    else:
        label_size = [int(size) for size in FLAGS.label_size]
        valid_size = [int(size) for size in FLAGS.valid_size]

        augment_dict = {"cifar10": augment_cifar10, "cutout": augment_cutout, "svhn": augment_svhn, "stl10": augment_stl10, "color": augment_color}
        augmentation = augment_dict[FLAGS.augment]

        DATASETS.update([DataSet.creator(FLAGS.dataset.split(".")[0], seed, label, valid, [augmentation, stack_augment(augmentation)], \
            nclass=FLAGS.nclass, height=FLAGS.img_size, width=FLAGS.img_size, do_memoize=FLAGS.memoize)
                         for seed, label, valid in
                         itertools.product(range(2), label_size, valid_size)])
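
        # For reference (illustrative; assumes DataSet.creator registers names of the
        # form '<name>.<seed>@<label_size>-<valid_size>', matching the built-in
        # datasets), flags such as:
        #   --custom_dataset --dataset=mydata.0@1000-500 --nclass=5 --img_size=64 \
        #       --label_size="1000" --valid_size="500" --augment=color
        # would select the entry created above for seed 0, 1000 labels, and 500
        # validation examples.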
        dataset = DATASETS[FLAGS.dataset]()

    return dataset


def main(argv):
    del argv

    # Number of augmentations performed on each unlabeled image for the consistency
    # loss. Performance does not increase significantly with more augmentations.
    assert FLAGS.nu == 2

    dataset = get_dataset()

    log_width = utils.ilog2(dataset.width)
    model = RealMix(
        os.path.join(FLAGS.train_dir, dataset.name),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        ema=FLAGS.ema,
        beta=FLAGS.beta,
        w_match=FLAGS.w_match,
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat,
        tsa=FLAGS.tsa,
        ood_mask=FLAGS.percent_mask,
        augmentation=FLAGS.augment)

    # if FLAGS.perform_inference:
    #     print("Performing inference...")
    #     assert FLAGS.inference_dir and FLAGS.inference_ckpt
    #     inference_dir = FLAGS.inference_dir
    #     inference_ckpt = FLAGS.inference_ckpt

    #     # images = model.session.run(memoize(default_parse(dataset([inference_dir]))).prefetch(10))

    #     if inference_dir[-1] != "/":
    #         inference_dir += "/"
    #     inference_img_paths = [path for path in glob.glob(inference_dir + "*.png")]
    #     images = np.asarray([plt.imread(img_path) for img_path in inference_img_paths])
    #     images = images * (2.0 / 255) - 1.0
    #     model.eval_mode(ckpt=inference_ckpt)
    #     # batch = FLAGS.batch
    #     feed_extra = None
    #     logits = [model.session.run(model.ops.classify_op, feed_dict={
    #         model.ops.x: images[0:10], **(feed_extra or {})})]

    #     print(np.asarray(logits).shape)
    #     print(logits)
    #     for i in range(10):
    #         print(np.amax(logits, axis=-1)[:, i], inference_img_paths[i])

    print("Preparing to train the %s dataset with %d classes, img_size of %d, %s augmentation, %s tsa schedule, %f weight decay, and learning rate of %f using RealMix" \
            % (FLAGS.dataset, FLAGS.nclass, FLAGS.img_size, FLAGS.augment, FLAGS.tsa, FLAGS.wd, FLAGS.lr))
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)


if __name__ == '__main__':
    utils.setup_tf()
    flags.DEFINE_float('wd', 0.02, 'Weight decay.')
    flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
    flags.DEFINE_float('beta', 0.75, 'Mixup beta distribution.')
    flags.DEFINE_float('w_match', 75, 'Weight for distribution matching loss.')
    flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
    flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
    flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
    flags.DEFINE_bool('custom_dataset', True, 'True if using a custom dataset.')
    flags.DEFINE_integer('nclass', 10, 'Number of classes present in custom dataset.')
    flags.DEFINE_integer('img_size', 32, 'Size of images in the custom dataset.')
    flags.DEFINE_spaceseplist('label_size', ['250', '500', '1000', '2000', '4000'], 'List of different labeled data sizes.')
    flags.DEFINE_spaceseplist('valid_size', ['1', '500'], 'List of different validation sizes.')
    flags.DEFINE_enum('tsa', "none", enum_values=["none", "linear_schedule", "log_schedule", "exp_schedule"], help="Annealing schedule for training signal annealing (TSA); 'none' disables TSA.")
    flags.DEFINE_float('percent_mask', -1, 'Fraction (0-1) of the least-confident unlabeled predictions whose loss is masked out each batch; <= 0 disables masking.')
    flags.DEFINE_enum('augment', 'cifar10', enum_values=["cifar10", "color", "cutout", "svhn", "stl10"], help='Type of augmentation to use, as defined in libml.data.py')
    flags.DEFINE_bool('perform_inference', False, 'True if performing inference on a set of images.')
    flags.DEFINE_string('inference_dir', '', 'Directory of images to perform inference on.')
    flags.DEFINE_string('inference_ckpt', '', 'Checkpoint to perform inference with.')
    flags.DEFINE_bool('memoize', True, 'True if the dataset should be cached (memoized) in memory.')
    FLAGS.set_default('dataset', 'cifar10.3@250-5000')
    FLAGS.set_default('batch', 64)
    FLAGS.set_default('lr', 0.002)
    FLAGS.set_default('train_kimg', 1 << 16)
    app.run(main)