##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: chenyuru
## This source code is licensed under the MIT-style license
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import pi, sqrt

safe_log = lambda x: torch.log(torch.clamp(x, 1e-8, 1e8))

class _BaseKLDivergence(nn.Module):
    def __init__(self):
        super(_BaseKLDivergence, self).__init__()

    def forward(self, Q, P):
        """
        Parameters
        ----------
        P: ground truth probability distribution [batch_size, n, n]
        Q: predicted probability distribution [batch_size, n, n]

        Description
        -----------
        compute the KL divergence of attention maps. Here P and Q denote 
        the pixel-level attention map with n spatial positions.
        """
        kl_loss = P * safe_log(P / Q)
        pixel_loss = torch.sum(kl_loss, dim=-1)
        total_loss = torch.mean(pixel_loss)
        return total_loss
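
# Illustrative usage sketch (hypothetical helper, not part of the original module):
# shows the argument order of forward(Q, P), where Q is the predicted map and P the
# ground truth; both are row-stochastic tensors of shape [batch_size, n, n].
def _example_kl_divergence():
    b, n = 2, 16
    P = F.softmax(torch.randn(b, n, n), dim=-1)  # placeholder ground-truth attention map
    Q = F.softmax(torch.randn(b, n, n), dim=-1)  # placeholder predicted attention map
    return _BaseKLDivergence()(Q, P)             # mean over positions of KL(P || Q)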

class AttentionLoss2d(_BaseKLDivergence):
    def __init__(self, scale=1):
        super(AttentionLoss2d, self).__init__()
        self.scale = scale

    def get_similarity(self, depth):
        b, _, h, w = depth.shape
        M = depth.reshape((b, h * w, 1))
        N = depth.reshape((b, 1, h * w))
        # pairwise similarity: softmax over negative absolute log-depth differences;
        # non-positive depth values give NaNs through log(), which are zeroed out below
        W = F.softmax(-torch.abs(torch.log(M) - torch.log(N)), -1)
        W[torch.isnan(W)] = 0
        return W

    def get_gt_sim_map(self, label):
        b, _, h, w = label.shape
        res_label = F.interpolate(label, size=(h//8//self.scale, w//8//self.scale), mode='nearest')
        gt_sim_map = self.get_similarity(res_label)
        return gt_sim_map
    
    def forward(self, sim_map, label):
        """
        Parameters
        ----------
        sim_map: [batch_size, n, n]
        label: [batch_size, 1, h, w]
        """
        gt_sim_map = self.get_gt_sim_map(label)
        return super(AttentionLoss2d, self).forward(sim_map, gt_sim_map)
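
# Illustrative usage sketch (hypothetical helper, assumed shapes): the ground-truth
# similarity map is built from the depth label downsampled by a factor of 8 * scale,
# so a predicted sim_map needs n = (h // 8 // scale) * (w // 8 // scale) positions.
def _example_attention_loss():
    b, h, w, scale = 1, 64, 64, 1
    n = (h // 8 // scale) * (w // 8 // scale)
    label = torch.rand(b, 1, h, w) + 0.1           # strictly positive depths keep log() finite
    sim_map = F.softmax(torch.randn(b, n, n), -1)  # placeholder predicted attention map
    return AttentionLoss2d(scale=scale)(sim_map, label)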

class _BaseEntropyLoss2d(nn.Module):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None):
        """
        Parameters
        ----------
        ignore_index : Specifies a target value that is ignored
                       and does not contribute to the input gradient
        reduction : Specifies the reduction to apply to the output: 
                    'mean' | 'sum'. 'mean': elemenwise mean, 
                    'sum': class dim will be summed and batch dim will be averaged.
        use_weight : whether to use weights of classes.
        weight : Tensor, optional
                a manual rescaling weight given to each class.
                If given, has to be a Tensor of size "nclasses"
        """
        super(_BaseEntropyLoss2d, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.use_weights = use_weights
        if use_weights:
            print("w/ class balance")
            print(weight)
            self.weight = torch.FloatTensor(weight).cuda()
        else:
            print("w/o class balance")
            self.weight = None

    def get_entropy(self, pred, label):
        """
        Return
        ------
        entropy : shape [batch_size, h, w, c]
        Description
        -----------
        Information Entropy based loss need to get the entropy according to your implementation, 
        each element denotes the loss of a certain position and class.
        """
        raise NotImplementedError

    def forward(self, pred, label):
        """
        Parameters
        ----------
        pred: [batch_size, num_classes, h, w]
        label: [batch_size, h, w]
        """
        assert not label.requires_grad
        assert pred.dim() == 4
        assert label.dim() == 3
        assert pred.size(0) == label.size(0), "{0} vs {1} ".format(pred.size(0), label.size(0))
        assert pred.size(2) == label.size(1), "{0} vs {1} ".format(pred.size(2), label.size(1))
        assert pred.size(3) == label.size(2), "{0} vs {1} ".format(pred.size(3), label.size(2))

        n, c, h, w = pred.size()
        if self.use_weights:
            if self.weight is None:
                print('label size {}'.format(label.shape))
                freq = np.zeros(c)
                for k in range(c):
                    mask = (label[:, :, :] == k)
                    freq[k] = torch.sum(mask)
                    print('{}th frequency {}'.format(k, freq[k]))
                weight = freq / np.sum(freq) * c
                weight = np.median(weight) / weight
                self.weight = torch.FloatTensor(weight).cuda()
                print('Online class weight: {}'.format(self.weight))
        else:
            self.weight = 1
        if self.ignore_index is None:
            self.ignore_index = c + 1

        entropy = self.get_entropy(pred, label)

        mask = label != self.ignore_index
        weighted_entropy = entropy * self.weight

        if self.reduction == 'sum':
            loss = torch.sum(weighted_entropy, -1)[mask].mean()
        elif self.reduction == 'mean':
            loss = torch.mean(weighted_entropy, -1)[mask].mean()
        else:
            raise ValueError("reduction must be 'mean' or 'sum', got {}".format(self.reduction))
        return loss
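
# Illustrative sketch (hypothetical helper): the online class weights computed in
# forward() follow median-frequency balancing; a toy stand-alone computation with
# three classes and made-up pixel counts.
def _example_median_frequency_weights():
    freq = np.array([900.0, 90.0, 10.0])      # per-class pixel counts
    weight = freq / np.sum(freq) * len(freq)  # relative frequency, scaled by the class count
    weight = np.median(weight) / weight       # rarer classes receive larger weights
    return weight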


class OrdinalRegression2d(_BaseEntropyLoss2d):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None):
        super(OrdinalRegression2d, self).__init__(ignore_index, reduction, use_weights, weight)

    def get_entropy(self, pred, label):
        n, c, h, w = pred.size()
        label = label.unsqueeze(3).long()
        pred = pred.permute(0, 2, 3, 1)
        mask10 = ((torch.arange(c)).cuda() <  label).float()
        mask01 = ((torch.arange(c)).cuda() >= label).float()
        entropy = safe_log(pred) * mask10 + safe_log(1 - pred) * mask01
        return -entropy
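
# Illustrative usage sketch (hypothetical helper, CUDA assumed since the loss calls
# .cuda() internally): pred holds per-bin probabilities in (0, 1), e.g. sigmoid
# outputs, and label holds the integer ordinal bin index of each pixel.
def _example_ordinal_regression():
    if not torch.cuda.is_available():
        return None
    n, c, h, w = 2, 10, 8, 8
    pred = torch.sigmoid(torch.randn(n, c, h, w)).cuda()  # placeholder bin probabilities
    label = torch.randint(0, c, (n, h, w)).cuda()          # placeholder ordinal labels
    return OrdinalRegression2d()(pred, label)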

def NormalDist(x, sigma):
    f = torch.exp(-x**2/(2*sigma**2)) / sqrt(2*pi*sigma**2)
    return f

class CrossEntropy2d(_BaseEntropyLoss2d):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None,
                 eps=0.0, priorType='uniform'):
        """
        Parameters
        ----------
        eps : label smoothing factor
        prior : prior distribution, if uniform it is equivalent to the 
                label smoothing trick (https://arxiv.org/abs/1512.00567).
                However, gaussian distribution is more friendly with the depth estimation I think.
        """
        super(CrossEntropy2d, self).__init__(ignore_index, reduction, use_weights, weight)
        self.eps = eps
        self.priorType = priorType

    def get_entropy(self, pred, label):
        n, c, h, w = pred.size()
        label = label.unsqueeze(3).long()
        pred = F.softmax(pred, 1).permute(0, 2, 3, 1)
        one_hot_label = ((torch.arange(c)).cuda() == label).float()

        if self.eps == 0:
            prior = 0
        else:
            if self.priorType == 'gaussian':
                tensor = (torch.arange(c).cuda() - label).float()
                prior = NormalDist(tensor, c / 10)
            elif self.priorType == 'uniform':
                prior = 1 / (c-1)

        smoothed_label = (1 - self.eps) * one_hot_label + self.eps * prior * (1 - one_hot_label)
        entropy = smoothed_label * safe_log(pred) + (1 - smoothed_label) * safe_log(1 - pred)
        return -entropy 
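
# Illustrative usage sketch (hypothetical helper, CUDA assumed since the loss calls
# .cuda() internally): with eps > 0 and priorType='gaussian', each one-hot target is
# smoothed towards neighbouring depth bins using NormalDist with std c / 10.
def _example_cross_entropy():
    if not torch.cuda.is_available():
        return None
    n, c, h, w = 2, 10, 8, 8
    pred = torch.randn(n, c, h, w).cuda()          # raw logits; softmax is applied inside
    label = torch.randint(0, c, (n, h, w)).cuda()  # placeholder class labels
    return CrossEntropy2d(eps=0.1, priorType='gaussian')(pred, label)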

class OhemCrossEntropy2d(CrossEntropy2d):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None,
                 eps=0.0, priorType='uniform', thresh=0.6, min_kept=0, 
                 ):
        """
        Parameters
        ----------
        thresh : OHEM (online hard example mining) threshold of correct probability
        min_kept : OHEM of minimal kept pixels

        Description
        -----------
        modified from https://github.com/PkuRainBow/OCNet.pytorch/blob/master/utils/loss.py#L68
        """
        super(OhemCrossEntropy2d, self).__init__(ignore_index, reduction, use_weights, weight,
                                                 eps, priorType)
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)

    def get_ohem_label(self, pred, label):
        n, c, h, w = pred.size()
        if self.ignore_index is None:
            self.ignore_index = c + 1

        input_label = label.data.cpu().numpy().ravel().astype(np.int32)
        x = np.rollaxis(pred.data.cpu().numpy(), 1).reshape((c, -1))
        input_prob = np.exp(x - x.max(axis=0, keepdims=True))
        input_prob /= input_prob.sum(axis=0, keepdims=True)

        valid_flag = input_label != self.ignore_index
        valid_inds = np.where(valid_flag)[0]
        valid_label = input_label[valid_flag]
        num_valid = valid_flag.sum()
        if self.min_kept >= num_valid:
            print('Labels: {}'.format(num_valid))
        elif num_valid > 0:
            valid_prob = input_prob[:,valid_flag]
            valid_prob = valid_prob[valid_label, np.arange(len(valid_label), dtype=np.int32)]
            threshold = self.thresh
            if self.min_kept > 0:
                index = valid_prob.argsort()
                threshold_index = index[ min(len(index), self.min_kept) - 1 ]
                if valid_prob[threshold_index] > self.thresh:
                    threshold = valid_prob[threshold_index]
            kept_flag = valid_prob <= threshold
            valid_kept_inds = valid_inds[kept_flag]
            valid_inds = valid_kept_inds

        self.ohem_ratio = len(valid_inds) / num_valid
        #print('Max prob: {:.4f}, hard ratio: {:.4f} = {} / {} '.format(input_prob.max(), self.ohem_ratio, len(valid_inds), num_valid))
        valid_kept_label = input_label[valid_inds].copy()
        input_label.fill(self.ignore_index)
        input_label[valid_inds] = valid_kept_label
        #valid_flag_new = input_label != self.ignore_index
        # print(np.sum(valid_flag_new))
        label = torch.from_numpy(input_label.reshape(label.size())).long().cuda()
        return label
        
    def get_entropy(self, pred, label):
        label = self.get_ohem_label(pred, label)
        entropy = super(OhemCrossEntropy2d, self).get_entropy(pred, label)
        return entropy
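
# Illustrative usage sketch (hypothetical helper, CUDA assumed since the loss calls
# .cuda() internally): OHEM keeps only pixels whose correct-class probability is at
# most `thresh` (raising the threshold if needed so at least `min_kept` pixels
# survive), relabels the rest to ignore_index, and applies the usual cross entropy.
def _example_ohem_cross_entropy():
    if not torch.cuda.is_available():
        return None
    n, c, h, w = 2, 10, 8, 8
    pred = torch.randn(n, c, h, w).cuda()          # raw logits
    label = torch.randint(0, c, (n, h, w)).cuda()  # placeholder class labels
    return OhemCrossEntropy2d(thresh=0.6, min_kept=50)(pred, label)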