# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# IMPORTS
import os
import torch
import time
import matplotlib.pyplot as plt
import numpy as np
import itertools
import glob

from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import utils
from skimage import color
from models.losses import CombinedLoss


##
# Helper functions
##
def create_exp_directory(exp_dir_name):
    """
    Function to create a directory if it does not exist yet.
    :param str exp_dir_name: name of directory to create.
    :return:
    """
    if not os.path.exists(exp_dir_name):
        try:
            os.makedirs(exp_dir_name)
            print("Successfully Created Directory @ {}".format(exp_dir_name))
        except:
            print("Directory Creation Failed - Check Path")
    else:
        print("Directory {} Exists ".format(exp_dir_name))


def dice_confusion_matrix(batch_output, labels_batch, num_classes):
    """
    Function to compute the dice confusion matrix.
    :param batch_output:
    :param labels_batch:
    :param num_classes:
    :return:
    """
    dice_cm = torch.zeros(num_classes, num_classes)

    for i in range(num_classes):
        gt = (labels_batch == i).float()

        for j in range(num_classes):
            pred = (batch_output == j).float()
            inter = torch.sum(torch.mul(gt, pred)) + 0.0001
            union = torch.sum(gt) + torch.sum(pred) + 0.0001
            dice_cm[i, j] = 2 * torch.div(inter, union)

    avg_dice = torch.mean(torch.diagflat(dice_cm))

    return avg_dice, dice_cm


def iou_score(pred_cls, true_cls, nclass=79):
    """
    compute the intersection-over-union score
    both inputs should be categorical (as opposed to one-hot)
    """
    intersect_ = []
    union_ = []

    for i in range(1, nclass):
        intersect = ((pred_cls == i).float() + (true_cls == i).float()).eq(2).sum().item()
        union = ((pred_cls == i).float() + (true_cls == i).float()).ge(1).sum().item()
        intersect_.append(intersect)
        union_.append(union)

    return np.array(intersect_), np.array(union_)


def accuracy(pred_cls, true_cls, nclass=79):
    """
    Function to calculate accuracy (TP/(TP + FP + TN + FN)
    :param pytorch.Tensor pred_cls: network prediction (categorical)
    :param pytorch.Tensor true_cls: ground truth (categorical)
    :param int nclass: number of classes
    :return:
    """
    positive = torch.histc(true_cls.cpu().float(), bins=nclass, min=0, max=nclass, out=None)
    per_cls_counts = []
    tpos = []

    for i in range(1, nclass):
        true_positive = ((pred_cls == i).float() + (true_cls == i).float()).eq(2).sum().item()
        tpos.append(true_positive)
        per_cls_counts.append(positive[i])

    return np.array(tpos), np.array(per_cls_counts)


##
# Plotting functions
##
def plot_predictions(images_batch, labels_batch, batch_output, plt_title, file_save_name):
    """
    Function to plot predictions from validation set.
    :param images_batch:
    :param labels_batch:
    :param batch_output:
    :param plt_title:
    :param file_save_name:
    :return:
    """

    f = plt.figure(figsize=(20, 20))
    n, c, h, w = images_batch.shape
    mid_slice = c // 2
    images_batch = torch.unsqueeze(images_batch[:, mid_slice, :, :], 1)
    grid = utils.make_grid(images_batch.cpu(), nrow=4)

    plt.subplot(131)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Slices')

    grid = utils.make_grid(labels_batch.unsqueeze_(1).cpu(), nrow=4)[0]
    color_grid = color.label2rgb(grid.numpy(), bg_label=0)
    plt.subplot(132)
    plt.imshow(color_grid)
    plt.title('Ground Truth')

    grid = utils.make_grid(batch_output.unsqueeze_(1).cpu(), nrow=4)[0]
    color_grid = color.label2rgb(grid.numpy(), bg_label=0)
    plt.subplot(133)
    plt.imshow(color_grid)
    plt.title('Prediction')

    plt.suptitle(plt_title)
    plt.tight_layout()

    f.savefig(file_save_name, bbox_inches='tight')
    plt.close(f)
    plt.gcf().clear()


def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues, file_save_name="temp.pdf"):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    :param cm:
    :param classes:
    :param title:
    :param cmap:
    :param file_save_name:
    :return:
    """
    f = plt.figure(figsize=(35, 35))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))

    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm.max() / 2.

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

    f.savefig(file_save_name, bbox_inches='tight')
    plt.close(f)
    plt.gcf().clear()


##
# Training routine
##
class Solver(object):
    """
    Class for training neural networks
    """

    # gamma is the factor for lowering the lr and step_size is when it gets lowered
    default_lr_scheduler_args = {"gamma": 0.05,
                                 "step_size": 5}

    def __init__(self, num_classes, optimizer=torch.optim.Adam, optimizer_args={}, loss_func=CombinedLoss(),
                 lr_scheduler_args={}):

        # Merge and update the default arguments - optimizer
        self.optimizer_args = optimizer_args

        lr_scheduler_args_merged = Solver.default_lr_scheduler_args.copy()
        lr_scheduler_args_merged.update(lr_scheduler_args)

        # Merge and update the default arguments - lr scheduler
        self.lr_scheduler_args = lr_scheduler_args_merged

        self.optimizer = optimizer
        self.loss_func = loss_func
        self.num_classes = num_classes
        self.classes = list(range(self.num_classes))

    def train(self, model, train_loader, train_loader_test, validation_loader, class_names, num_epochs,
              log_params, expdir, scheduler_type, torch_v11, resume=True):
        """
        Train Model with provided parameters for optimization
        Inputs:
        -- model - model to be trained
        -- train_loader - training DataLoader Object
        -- validation_loader - validation DataLoader Object
        -- num_epochs = total number of epochs
        -- log_params - parameters for logging the progress
        -- expdir --directory to save check points

        """
        create_exp_directory(expdir)  # Experimental directory
        create_exp_directory(log_params["logdir"])  # Logging Directory

        # Instantiate the optimizer class
        optimizer = self.optimizer(model.parameters(), **self.optimizer_args)

        # Instantiate the scheduler class
        if scheduler_type == "StepLR":
            scheduler = lr_scheduler.StepLR(optimizer, step_size=self.lr_scheduler_args["step_size"],
                                            gamma=self.lr_scheduler_args["gamma"])
        else:
            scheduler = None

        # Set up logger format
        a = "{}\t" * (self.num_classes - 2) + "{}"

        epoch = -1  # To allow for restoration
        print('-------> Starting to train')

        # Code for restoring model
        if resume:

            try:
                prior_model_paths = sorted(glob.glob(os.path.join(expdir, 'Epoch_*')), key=os.path.getmtime)

                if prior_model_paths:
                    current_model = prior_model_paths.pop()

                state = torch.load(current_model)

                # Restore model dictionary
                model.load_state_dict(state["model_state_dict"])
                optimizer.load_state_dict(state["optimizer_state_dict"])
                scheduler.load_state_dict(state["scheduler_state_dict"])
                epoch = state["epoch"]

                print("Successfully restored the model state. Resuming training from Epoch {}".format(epoch + 1))

            except Exception as e:
                print("No model to restore. Resuming training from Epoch 0. {}".format(e))

        log_params["logger"].info("{} parameters in total".format(sum(x.numel() for x in model.parameters())))

        while epoch < num_epochs:

            epoch = epoch + 1
            epoch_start = time.time()

            # Update learning rate based on epoch number (only for pytorch version <1.2)
            if torch_v11 and scheduler is not None:
                scheduler.step()

            loss_batch = np.zeros(1)

            for batch_idx, sample_batch in enumerate(train_loader):

                # Assign data
                images_batch, labels_batch, weights_batch = sample_batch['image'], sample_batch['label'], \
                                                            sample_batch['weight']

                # Map to variables
                images_batch = Variable(images_batch)
                labels_batch = Variable(labels_batch)
                weights_batch = Variable(weights_batch)

                if torch.cuda.is_available():
                    images_batch, labels_batch, weights_batch = images_batch.cuda(), labels_batch.cuda(), \
                                                                weights_batch.type(torch.FloatTensor).cuda()

                model.train()  # Set to training mode!
                optimizer.zero_grad()

                predictions = model(images_batch)

                loss_total, loss_dice, loss_ce = self.loss_func(predictions, labels_batch, weights_batch)

                loss_total.backward()

                optimizer.step()

                loss_batch += loss_total.item()

                if batch_idx % (len(train_loader) // 2) == 0 or batch_idx == len(train_loader) - 1:
                    log_params["logger"].info("Train Epoch: {} [{}/{}] ({:.0f}%)] "
                                              "with loss: {}".format(epoch, batch_idx,
                                                                     len(train_loader),
                                                                     100. * batch_idx / len(train_loader),
                                                                     loss_batch / (batch_idx + 1)))

                del images_batch, labels_batch, weights_batch, predictions, loss_total, loss_dice, loss_ce

            # Update learning rate at the end based on epoch number (only for pytorch version > 1.1)
            if not torch_v11 and scheduler is not None:
                scheduler.step()

            epoch_finish = time.time() - epoch_start

            log_params["logger"].info("Train Epoch {} finished in {:.04f} seconds.".format(epoch, epoch_finish))
            # End of Training, time to accumulate results

            # Testing Loop on Training Data
            # Set evaluation mode on the model
            model.eval()

            val_loss_total = 0
            val_loss_dice = 0
            val_loss_ce = 0

            ints_ = np.zeros(self.num_classes - 1)
            unis_ = np.zeros(self.num_classes - 1)
            per_cls_counts = np.zeros(self.num_classes - 1)
            accs = np.zeros(self.num_classes - 1)  # -1 to exclude background (still included in val loss)

            with torch.no_grad():

                if train_loader_test is not None:
                    cnf_matrix_train = torch.zeros(self.num_classes, self.num_classes)

                    val_start = time.time()

                    for batch_idx, sample_batch in enumerate(train_loader_test):

                        images_batch, labels_batch, weights_batch = sample_batch['image'], sample_batch['label'], \
                                                                    sample_batch['weight']

                        # Map to variables
                        images_batch = Variable(images_batch)
                        labels_batch = Variable(labels_batch)
                        weights_batch = Variable(weights_batch)

                        if torch.cuda.is_available():
                            images_batch, labels_batch, weights_batch = images_batch.cuda(), labels_batch.cuda(), \
                                                                        weights_batch.type(torch.FloatTensor).cuda()

                        predictions = model(images_batch)

                        _, batch_output = torch.max(predictions, dim=1)

                        _, cm_batch = dice_confusion_matrix(batch_output, labels_batch, self.num_classes)

                        cnf_matrix_train += cm_batch.cpu()

                        # Plot sample predictions
                        if batch_idx == 0:
                            plt_title = 'Train Results Epoch ' + str(epoch)

                            file_save_name = os.path.join(log_params["logdir"],
                                                          'Epoch_' + str(epoch) + '_Train_Predictions.pdf')

                            plot_predictions(images_batch, labels_batch, batch_output, plt_title, file_save_name)

                        if batch_idx % 5 == 0:
                            print("Test on Train Data --Epoch: {}. Iter: {} / {}.".format(epoch, batch_idx,
                                                                                          len(train_loader_test)))

                        del images_batch, labels_batch, weights_batch, predictions, batch_output, cm_batch

                    cnf_matrix_train = cnf_matrix_train / (batch_idx + 1)
                    train_end = time.time() - val_start
                    print("Completed Testing on Training Dataset in {:0.4f} s".format(train_end))

                    # print(cnf_matrix_train)
                    save_name = os.path.join(log_params["logdir"], 'Epoch_' + str(epoch) + '_Train_Dice_CM.pdf')
                    plot_confusion_matrix(cnf_matrix_train.cpu().numpy(), self.classes, file_save_name=save_name)

                if validation_loader is not None:

                    val_start = time.time()
                    cnf_matrix_validation = torch.zeros(self.num_classes, self.num_classes)

                    for batch_idx, sample_batch in enumerate(validation_loader):

                        images_batch, labels_batch, weights_batch = sample_batch['image'], sample_batch['label'], \
                                                                    sample_batch['weight']

                        # Map to variables (no longer necessary after pytorch 0.40)
                        images_batch = Variable(images_batch)
                        labels_batch = Variable(labels_batch)
                        weights_batch = Variable(weights_batch)

                        if torch.cuda.is_available():
                            images_batch, labels_batch, weights_batch = images_batch.cuda(), labels_batch.cuda(), \
                                                                        weights_batch.type(torch.FloatTensor).cuda()

                        # Get logits, sum up batch loss and get final predictions (argmax)
                        predictions = model(images_batch)
                        loss_total, loss_dice, loss_ce = self.loss_func(predictions, labels_batch, weights_batch)
                        val_loss_total += loss_total.item()
                        val_loss_dice += loss_dice.item()
                        val_loss_ce += loss_ce.item()

                        _, batch_output = torch.max(predictions, dim=1)

                        # Calculate iou_scores, accuracy and dice confusion matrix + sum over previous batches
                        int_, uni_ = iou_score(batch_output, labels_batch, self.num_classes)
                        ints_ += int_
                        unis_ += uni_

                        tpos, pcc = accuracy(batch_output, labels_batch, self.num_classes)
                        accs += tpos
                        per_cls_counts += pcc

                        _, cm_batch = dice_confusion_matrix(batch_output, labels_batch, self.num_classes)
                        cnf_matrix_validation += cm_batch.cpu()

                        # Plot sample predictions
                        if batch_idx == 0:
                            plt_title = 'Validation Results Epoch ' + str(epoch)

                            file_save_name = os.path.join(log_params["logdir"],
                                                          'Epoch_' + str(epoch) + '_Validations_Predictions.pdf')

                            plot_predictions(images_batch, labels_batch, batch_output, plt_title, file_save_name)

                        del images_batch, labels_batch, weights_batch, predictions, batch_output, \
                            int_, uni_, tpos, pcc, loss_total, loss_dice, loss_ce  # cm_batch,

                    # Get final measures and log them
                    ious = ints_ / unis_
                    accs /= per_cls_counts
                    val_loss_total /= (batch_idx + 1)
                    val_loss_dice /= (batch_idx + 1)
                    val_loss_ce /= (batch_idx + 1)
                    cnf_matrix_validation = cnf_matrix_validation / (batch_idx + 1)
                    val_end = time.time() - val_start

                    print("Completed Validation Dataset in {:0.4f} s".format(val_end))

                    save_name = os.path.join(log_params["logdir"], 'Epoch_' + str(epoch) + '_Validation_Dice_CM.pdf')
                    plot_confusion_matrix(cnf_matrix_validation.cpu().numpy(), self.classes, file_save_name=save_name)

                    # Log metrics
                    log_params["logger"].info("[Epoch {} stats]: MIoU: {:.4f}; "
                                              "Mean Accuracy: {:.4f}; "
                                              "Avg loss total: {:.4f}; "
                                              "Avg loss dice: {:.4f}; "
                                              "Avg loss ce: {:.4f}".format(epoch, np.mean(ious), np.mean(accs),
                                                                           val_loss_total, val_loss_dice, val_loss_ce))

                    log_params["logger"].info(a.format(*class_names))
                    log_params["logger"].info(a.format(*ious))

            # Saving Models

            if epoch % log_params["log_iter"] == 0:
                save_name = os.path.join(expdir, 'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl')
                checkpoint = {"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
                              "epoch": epoch}
                if scheduler is not None:
                    checkpoint["scheduler_state_dict"] = scheduler.state_dict()

                torch.save(checkpoint, save_name)

            model.train()