"""
tensorflow/keras utilities for the neuron project

If you use this code, please cite 
Dalca AV, Guttag J, Sabuncu MR
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 
CVPR 2018

Contact: adalca [at] csail [dot] mit [dot] edu
License: GPLv3
"""

import sys

# third party
import numpy as np
import keras.backend as K
from keras import losses
import tensorflow as tf

# local
from . import utils

class CategoricalCrossentropy(object):
    """
    Categorical crossentropy with optional categorical weights and spatial prior

    Adapted from weighted categorical crossentropy via wassname:
    https://gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d

    Variables:
        weights: numpy array of shape (C,) where C is the number of classes

    Usage:
        loss = CategoricalCrossentropy().loss # or
        loss = CategoricalCrossentropy(weights=weights).loss # or
        loss = CategoricalCrossentropy(..., vox_weights=prior).loss
        model.compile(loss=loss, optimizer='adam')
    """

    def __init__(self, weights=None, use_float16=False, vox_weights=None, crop_indices=None):
        """
        Parameters:
            vox_weights is either a numpy array the same size as y_true,
                or a string: 'y_true' or 'expy_true'
            crop_indices: indices to crop each element of the batch
                if each element is N-D (so y_true is N+1 dimensional)
                then crop_indices is a Tensor of crop ranges (indices)
                of size <= N-D. If it's < N-D, then it acts as a slice
                for the last few dimensions.
                See Also: tf.gather_nd
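
            Example (illustrative; `prior_vol` is a hypothetical numpy array
            with the same shape as y_true, used as voxel-wise weights):
                loss = CategoricalCrossentropy(weights=np.array([1., 2., 2.]),
                                               vox_weights=prior_vol).loss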
        """

        self.weights = weights
        self.use_float16 = use_float16
        self.vox_weights = vox_weights
        self.crop_indices = crop_indices

        if self.crop_indices is not None and vox_weights is not None:
            self.vox_weights = utils.batch_gather(self.vox_weights, self.crop_indices)

    def loss(self, y_true, y_pred):
        """ categorical crossentropy loss """

        if self.crop_indices is not None:
            y_true = utils.batch_gather(y_true, self.crop_indices)
            y_pred = utils.batch_gather(y_pred, self.crop_indices)

        if self.use_float16:
            y_true = K.cast(y_true, 'float16')
            y_pred = K.cast(y_pred, 'float16')

        # scale and clip probabilities
        # this should not be necessary for softmax output.
        y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
        y_pred = K.clip(y_pred, K.epsilon(), 1)

        # compute log probability
        log_post = K.log(y_pred)  # log of the predicted probability

        # loss
        loss = - y_true * log_post

        # weighted loss
        if self.weights is not None:
            loss *= self.weights

        if self.vox_weights is not None:
            loss *= self.vox_weights

        # take the total loss
        # loss = K.batch_flatten(loss)
        mloss = K.mean(K.sum(K.cast(loss, 'float32'), -1))
        mloss = tf.verify_tensor_all_finite(mloss, 'Loss not finite')
        return mloss


class Dice(object):
    """
    Dice of two Tensors.

    Tensors should either be:
    - probabilitic for each label
        i.e. [batch_size, *vol_size, nb_labels], where vol_size is the size of the volume (n-dims)
        e.g. for a 2D vol, y has 4 dimensions, where each entry is a prob for that voxel
    - max_label
        i.e. [batch_size, *vol_size], where vol_size is the size of the volume (n-dims).
        e.g. for a 2D vol, y has 3 dimensions, where each entry is the max label of that voxel

    Variables:
        nb_labels: number of labels (int)
        weights: optional numpy array of shape (L,) giving the relative weight of each label
        input_type is 'prob' or 'max_label'
        dice_type is 'hard' or 'soft'

    Usage:
        diceloss = metrics.Dice(nb_labels=3, weights=[1, 2, 3]).loss
        model.compile(loss=diceloss, ...)

    Test:
        import keras.utils as nd_utils
        reload(nrn_metrics)
        weights = [0.1, 0.2, 0.3, 0.4, 0.5]
        nb_labels = len(weights)
        vol_size = [10, 20]
        batch_size = 7

        dice_loss = metrics.Dice(nb_labels=nb_labels).loss
        dice = metrics.Dice(nb_labels=nb_labels).dice
        dice_wloss = metrics.Dice(nb_labels=nb_labels, weights=weights).loss

        # vectors
        lab_size = [batch_size, *vol_size]
        r = nd_utils.to_categorical(np.random.randint(0, nb_labels, lab_size), nb_labels)
        vec_1 = np.reshape(r, [*lab_size, nb_labels])
        r = nd_utils.to_categorical(np.random.randint(0, nb_labels, lab_size), nb_labels)
        vec_2 = np.reshape(r, [*lab_size, nb_labels])

        # get some standard vectors
        tf_vec_1 = tf.constant(vec_1, dtype=tf.float32)
        tf_vec_2 = tf.constant(vec_2, dtype=tf.float32)

        # compute some metrics
        res = [f(tf_vec_1, tf_vec_2) for f in [dice, dice_loss, dice_wloss]]
        res_same = [f(tf_vec_1, tf_vec_1) for f in [dice, dice_loss, dice_wloss]]

        # tf run
        init_op = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_op)
            sess.run(res)
            sess.run(res_same)
            print(res[2].eval())
            print(res_same[2].eval())
    """

    def __init__(self, nb_labels,
                 weights=None,
                 input_type='prob',
                 dice_type='soft',
                 approx_hard_max=True,
                 vox_weights=None,
                 crop_indices=None,
                 area_reg=0.1):  # regularization for bottom of Dice coeff
        """
        input_type is 'prob' or 'max_label'
        dice_type is 'hard' or 'soft'
        approx_hard_max - see note below

        Note: for hard dice, we grab the most likely label and then compute a
        one-hot encoding for each voxel with respect to the possible labels. To grab
        the most likely label, argmax() can be used, but only when Dice is used as a
        metric. For a Dice *loss*, argmax is not differentiable, so we can't use it;
        instead, when approx_hard_max is True, we approximate the prob -> one-hot
        translation with a differentiable 'hard max' (see _hard_max below).
        """

        self.nb_labels = nb_labels
        self.weights = None if weights is None else K.variable(weights)
        self.vox_weights = None if vox_weights is None else K.variable(vox_weights)
        self.input_type = input_type
        self.dice_type = dice_type
        self.approx_hard_max = approx_hard_max
        self.area_reg = area_reg
        self.crop_indices = crop_indices

        if self.crop_indices is not None and vox_weights is not None:
            self.vox_weights = utils.batch_gather(self.vox_weights, self.crop_indices)

    def dice(self, y_true, y_pred):
        """
        compute dice for given Tensors

        """
        if self.crop_indices is not None:
            y_true = utils.batch_gather(y_true, self.crop_indices)
            y_pred = utils.batch_gather(y_pred, self.crop_indices)

        if self.input_type == 'prob':
            # We assume that y_true is probabilistic, but just in case:
            y_true /= K.sum(y_true, axis=-1, keepdims=True)
            y_true = K.clip(y_true, K.epsilon(), 1)

            # make sure pred is a probability
            y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
            y_pred = K.clip(y_pred, K.epsilon(), 1)

        # Prepare the volumes to operate on
        # If we're doing 'hard' Dice, then we will prepare one-hot-based matrices of size
        # [batch_size, nb_voxels, nb_labels], where for each voxel in each batch entry,
        # the entries are either 0 or 1
        if self.dice_type == 'hard':

            # if given predicted probability, transform to "hard max"
            if self.input_type == 'prob':
                if self.approx_hard_max:
                    y_pred_op = _hard_max(y_pred, axis=-1)
                    y_true_op = _hard_max(y_true, axis=-1)
                else:
                    y_pred_op = _label_to_one_hot(K.argmax(y_pred, axis=-1), self.nb_labels)
                    y_true_op = _label_to_one_hot(K.argmax(y_true, axis=-1), self.nb_labels)

            # if given predicted label, transform to one hot notation
            else:
                assert self.input_type == 'max_label'
                y_pred_op = _label_to_one_hot(y_pred, self.nb_labels)
                y_true_op = _label_to_one_hot(y_true, self.nb_labels)

        # If we're doing soft Dice, a prob input is required, and the tensors can be
        # used as-is (they are flattened to [batch_size, nb_voxels, nb_labels] below)
        else:
            assert self.input_type == 'prob', "cannot do soft dice with max_label input"
            y_pred_op = y_pred
            y_true_op = y_true

        # reshape to [batch_size, nb_voxels, nb_labels] so that the sums below
        # run over all voxels of each batch entry
        flat_shape = tf.stack([-1, K.prod(K.shape(y_true_op)[1:-1]), K.shape(y_true_op)[-1]])
        y_true_op = K.reshape(y_true_op, flat_shape)
        y_pred_op = K.reshape(y_pred_op, flat_shape)

        # compute dice for each entry in batch.
        # dice will now be [batch_size, nb_labels]
        sum_dim = 1
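        # (squared-denominator) soft dice: 2 * sum(p * q) / (sum(p^2) + sum(q^2))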
        top = 2 * K.sum(y_true_op * y_pred_op, sum_dim)
        bottom = K.sum(K.square(y_true_op), sum_dim) + K.sum(K.square(y_pred_op), sum_dim)
        # regularize the denominator to avoid division by (near-)zero
        bottom = K.maximum(bottom, self.area_reg)
        return top / bottom

    def mean_dice(self, y_true, y_pred):
        """ weighted mean dice across all patches and labels """

        # compute dice, which will now be [batch_size, nb_labels]
        dice_metric = self.dice(y_true, y_pred)

        # weigh the entries in the dice matrix:
        if self.weights is not None:
            dice_metric *= self.weights
        if self.vox_weights is not None:
            dice_metric *= self.vox_weights

        # average over batch entries and labels
        mean_dice_metric = K.mean(dice_metric)
        mean_dice_metric = tf.verify_tensor_all_finite(mean_dice_metric, 'metric not finite')
        return mean_dice_metric


    def loss(self, y_true, y_pred):
        """ the loss. Assumes y_pred is prob (in [0,1] and sum_row = 1) """

        # compute dice, which will now be [batch_size, nb_labels]
        dice_metric = self.dice(y_true, y_pred)

        # loss
        dice_loss = 1 - dice_metric

        # weigh the entries in the dice matrix:
        if self.weights is not None:
            dice_loss *= self.weights

        # return one minus mean dice as loss
        mean_dice_loss = K.mean(dice_loss)
        mean_dice_loss = tf.verify_tensor_all_finite(mean_dice_loss, 'Loss not finite')
        return mean_dice_loss


class MeanSquaredError(object):
    """
    MSE with several weighting options
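
    Usage (illustrative sketch; the weighting choices and the 'adam' optimizer
    below are examples, not requirements):
        loss = MeanSquaredError().loss                        # plain MSE
        wloss = MeanSquaredError(vox_weights='y_true').loss   # weight voxels by y_true
        model.compile(loss=wloss, optimizer='adam')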
    """


    def __init__(self, weights=None, vox_weights=None, crop_indices=None):
        """
        Parameters:
            vox_weights is either a numpy array the same size as y_true,
                or a string: 'y_true' or 'expy_true'
            crop_indices: indices to crop each element of the batch
                if each element is N-D (so y_true is N+1 dimensional)
                then crop_indices is a Tensor of crop ranges (indices)
                of size <= N-D. If it's < N-D, then it acts as a slice
                for the last few dimensions.
                See Also: tf.gather_nd
        """
        self.weights = weights
        self.vox_weights = vox_weights
        self.crop_indices = crop_indices

        if self.crop_indices is not None and vox_weights is not None:
            self.vox_weights = utils.batch_gather(self.vox_weights, self.crop_indices)
        
    def loss(self, y_true, y_pred):

        if self.crop_indices is not None:
            y_true = utils.batch_gather(y_true, self.crop_indices)
            y_pred = utils.batch_gather(y_pred, self.crop_indices)

        ksq = K.square(y_pred - y_true)

        if self.vox_weights is not None:
            # vox_weights may be a string flag ('y_true' / 'expy_true') or a weight array
            if isinstance(self.vox_weights, str) and self.vox_weights == 'y_true':
                ksq *= y_true
            elif isinstance(self.vox_weights, str) and self.vox_weights == 'expy_true':
                ksq *= tf.exp(y_true)
            else:
                ksq *= self.vox_weights

        if self.weights is not None:
            ksq *= self.weights

        return K.mean(ksq)


class Mix(object):
    """ a mix of several losses """

    def __init__(self, losses, loss_wts=None):
        self.losses = losses
        self.loss_wts = loss_wts
        if loss_wts is None:
            self.loss_wts = np.ones(len(losses))

    def loss(self, y_true, y_pred):
        total_loss = K.variable(0)
        for idx, loss in enumerate(self.losses):
            total_loss += self.loss_wts[idx] * loss(y_true, y_pred)
        return total_loss


class WGAN_GP(object):
    """
    based on https://github.com/rarilurelo/keras_improved_wgan/blob/master/wgan_gp.py
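
    Usage sketch (illustrative; `disc` is assumed to be a keras model mapping a
    batch of samples to a scalar critic score per sample):
        gp_loss = WGAN_GP(disc, lambda_gp=10).loss
        model.compile(loss=gp_loss, optimizer='adam')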
    """

    def __init__(self, disc, batch_size=1, lambda_gp=10):
        self.disc = disc
        self.lambda_gp = lambda_gp
        self.batch_size = batch_size

    def loss(self, y_true, y_pred):

        # get the value for the true and fake images
        disc_true = self.disc(y_true)
        disc_pred = self.disc(y_pred)

        # sample x_hat along the straight line between the true and predicted samples.
        # the batch size is taken dynamically from y_pred (it is generally None at
        # graph-construction time); alpha is broadcast over the remaining
        # (assumed 4-D) sample dimensions
        alpha = K.random_uniform(shape=[K.shape(y_pred)[0], 1, 1, 1])
        diff = y_pred - y_true
        interp = y_true + alpha * diff

        # gradient penalty: (||grad of D(x_hat) w.r.t. x_hat||_2 - 1)^2, with the
        # L2 norm taken over all non-batch dimensions of the gradient
        gradients = K.gradients(self.disc(interp), [interp])[0]
        grad_norm = K.sqrt(K.sum(K.square(K.batch_flatten(gradients)), axis=1))
        grad_pen = K.mean(K.square(grad_norm - 1))

        # compute loss
        return (K.mean(disc_pred) - K.mean(disc_true)) + self.lambda_gp * grad_pen


class Nonbg(object):
    """ UNTESTED
    class to modify output on operating only on the non-bg class

    All data is aggregated and the (passed) metric is called on flattened true and
    predicted outputs in all (true) non-bg regions

    Usage:
        loss = l1  # any metric/loss taking (y_true, y_pred)
        nonbgloss = Nonbg(loss).loss
    """

    def __init__(self, metric):
        self.metric = metric

    def loss(self, y_true, y_pred):
        """ prepare a loss of the given metric/loss operating on non-bg data """
        # gather the (flattened) entries where y_true is non-background (non-zero)
        # and evaluate the wrapped metric on them
        nonbg_mask = K.not_equal(y_true, 0)
        y_true_fix = tf.boolean_mask(y_true, nonbg_mask)
        y_pred_fix = tf.boolean_mask(y_pred, nonbg_mask)
        return self.metric(y_true_fix, y_pred_fix)


def l1(y_true, y_pred):
    """ L1 metric (MAE) """
    return losses.mean_absolute_error(y_true, y_pred)


def l2(y_true, y_pred):
    """ L2 metric (MSE) """
    return losses.mean_squared_error(y_true, y_pred)


###############################################################################
# Helper Functions
###############################################################################

def _label_to_one_hot(tens, nb_labels):
    """
    Transform a label nD Tensor to a one-hot 3D Tensor. The input tensor is first
    batch-flattened, and then each batch and each voxel gets a one-hot representation
    """
    y = K.batch_flatten(tens)
    return K.one_hot(y, nb_labels)


def _hard_max(tens, axis):
    """
    we can't use the argmax function in a loss, as it's not differentiable
    We can use it in a metric, but not in a loss function
    therefore, we replace the 'hard max' operation (i.e. argmax + onehot)
    with this approximation
    """
    tensmax = K.max(tens, axis=axis, keepdims=True)
    eps_hot = K.maximum(tens - tensmax + K.epsilon(), 0)
    one_hot = eps_hot / K.epsilon()
    return one_hot