import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ConstractiveThresholdHingeLoss(nn.Module):
    """Contrastive loss with a hinge threshold on the similar-pair distance."""

    def __init__(self, hingethresh=0.0, margin=2.0):
        super(ConstractiveThresholdHingeLoss, self).__init__()
        self.threshold = hingethresh
        self.margin = margin

    def forward(self, out_vec_t0, out_vec_t1, label):
        distance = F.pairwise_distance(out_vec_t0, out_vec_t1, p=2)
        # similar pairs (label == 0) are only penalised beyond the threshold
        similar_pair = torch.clamp(distance - self.threshold, min=0.0)
        # dissimilar pairs (label == 1) are penalised until they reach the margin
        dissimilar_pair = torch.clamp(self.margin - distance, min=0.0)
        #dissimilar_pair = torch.clamp(self.margin - (distance - self.threshold), min=0.0)
        contrastive_thresh_loss = torch.sum(
            (1 - label) * torch.pow(similar_pair, 2) + label * torch.pow(dissimilar_pair, 2)
        )
        return contrastive_thresh_loss
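
# Minimal usage sketch for ConstractiveThresholdHingeLoss; the batch of 64
# pairs of 128-d embeddings and the 0.1 threshold are illustrative values.
def _example_threshold_hinge_loss():
    criterion = ConstractiveThresholdHingeLoss(hingethresh=0.1, margin=2.0)
    out_t0 = torch.randn(64, 128)                 # embeddings from image t0
    out_t1 = torch.randn(64, 128)                 # matching embeddings from image t1
    label = torch.randint(0, 2, (64,)).float()    # 0 = unchanged pair, 1 = changed
    return criterion(out_t0, out_t1, label)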

class ConstractiveLoss(nn.Module):
    """Standard contrastive loss with a selectable distance metric."""

    def __init__(self, margin=2.0, dist_flag='l2'):
        super(ConstractiveLoss, self).__init__()
        self.margin = margin
        self.dist_flag = dist_flag

    def various_distance(self, out_vec_t0, out_vec_t1):
        if self.dist_flag == 'l2':
            return F.pairwise_distance(out_vec_t0, out_vec_t1, p=2)
        elif self.dist_flag == 'l1':
            return F.pairwise_distance(out_vec_t0, out_vec_t1, p=1)
        elif self.dist_flag == 'cos':
            # map cosine similarity into a distance-like quantity
            similarity = F.cosine_similarity(out_vec_t0, out_vec_t1)
            return 1 - 2 * similarity / np.pi
        raise ValueError('unknown dist_flag: %s' % self.dist_flag)

    def forward(self, out_vec_t0, out_vec_t1, label):
        distance = self.various_distance(out_vec_t0, out_vec_t1)
        contrastive_loss = torch.sum(
            (1 - label) * torch.pow(distance, 2) +
            label * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        )
        return contrastive_loss
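
# Usage sketch for ConstractiveLoss; dist_flag may be 'l2', 'l1' or 'cos',
# and the shapes below are arbitrary demo values.
def _example_contrastive_loss():
    criterion = ConstractiveLoss(margin=2.0, dist_flag='l2')
    out_t0, out_t1 = torch.randn(64, 128), torch.randn(64, 128)
    label = torch.randint(0, 2, (64,)).float()
    return criterion(out_t0, out_t1, label)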

class ConstractiveMaskLoss(nn.Module):
    """Applies a per-position contrastive loss over a change mask."""

    def __init__(self, thresh_flag=False, hinge_thresh=0.0, dist_flag='l2'):
        super(ConstractiveMaskLoss, self).__init__()
        if thresh_flag:
            self.sample_constractive_loss = ConstractiveThresholdHingeLoss(margin=2.0, hingethresh=hinge_thresh)
        else:
            self.sample_constractive_loss = ConstractiveLoss(margin=2.0, dist_flag=dist_flag)

    def forward(self, out_t0, out_t1, ground_truth):
        # flatten each (c, h, w) feature map into h*w rows of c-dimensional
        # vectors; the views below assume batch size n == 1
        n, c, h, w = out_t0.data.shape
        out_t0_rz = torch.transpose(out_t0.view(c, h * w), 1, 0)
        out_t1_rz = torch.transpose(out_t1.view(c, h * w), 1, 0)
        gt_rz = ground_truth.float().view(h * w).to(out_t0.device)
        loss = self.sample_constractive_loss(out_t0_rz, out_t1_rz, gt_rz)
        return loss
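
# Sketch of ConstractiveMaskLoss on feature maps with a binary change mask;
# the 128x32x32 shapes are made up, and everything stays on CPU here.
def _example_mask_loss():
    criterion = ConstractiveMaskLoss(thresh_flag=False, dist_flag='l2')
    feat_t0 = torch.randn(1, 128, 32, 32)
    feat_t1 = torch.randn(1, 128, 32, 32)
    mask = torch.randint(0, 2, (1, 1, 32, 32))    # 1 marks changed pixels
    return criterion(feat_t0, feat_t1, mask)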

class LogDetDivergence(nn.Module):
    """LogDet (Burg) divergence between W^T W and the identity, keeping the
    selected weight matrix close to orthonormal."""

    def __init__(self, model, param_name, dim=512):
        super(LogDetDivergence, self).__init__()
        self.param_name = param_name
        self.param_dict = dict(model.named_parameters())
        self.dim = dim
        self.identity_matrix = torch.eye(self.dim)

    def select_param(self):
        # return the weight tensor whose name contains param_name
        for layer_name, layer_param in self.param_dict.items():
            if self.param_name in layer_name and 'weight' in layer_name:
                return layer_param

    def forward(self):
        constrained_matrix = self.select_param()
        # a 1x1 conv weight of shape (dim, dim, 1, 1), squeezed to (dim, dim)
        matrix_ = constrained_matrix.squeeze(-1).squeeze(-1)
        matrix_t = torch.t(matrix_)
        matrixs = torch.mm(matrix_t, matrix_)
        # D(M, I) = tr(M) - logdet(M) - dim; the constant -dim is omitted
        # since it carries no gradient
        trace_ = torch.trace(matrixs)
        log_det = torch.logdet(matrixs)
        maha_loss = trace_ - log_det
        return maha_loss
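
# Sketch of LogDetDivergence on a toy model; the 1x1 conv named 'embed' and
# dim=8 are hypothetical stand-ins for the real constrained layer.
def _example_logdet_divergence():
    model = nn.Sequential()
    model.add_module('embed', nn.Conv2d(8, 8, kernel_size=1, bias=False))
    criterion = LogDetDivergence(model, param_name='embed', dim=8)
    return criterion()    # typically added to the task loss with some weight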

class Mahalanobis_Constraint(nn.Module):
    """Frobenius-norm penalty ||W^T W - I||_F^2 pushing the selected weight
    matrix towards orthonormality."""

    def __init__(self, model, param_name, dim=512):
        super(Mahalanobis_Constraint, self).__init__()
        self.param_name = param_name
        self.param_dict = dict(model.named_parameters())
        self.dim = dim
        self.identity_matrix = torch.eye(self.dim)

    def select_param(self):
        # return the weight tensor whose name contains param_name
        for layer_name, layer_param in self.param_dict.items():
            if self.param_name in layer_name and 'weight' in layer_name:
                return layer_param

    def forward(self):
        constrained_matrix = self.select_param()
        # a 1x1 conv weight of shape (dim, dim, 1, 1), squeezed to (dim, dim)
        matrix_ = constrained_matrix.squeeze(-1).squeeze(-1)
        matrix_t = torch.t(matrix_)
        identity = self.identity_matrix.to(matrix_.device)
        matrix_constrained = (torch.mm(matrix_t, matrix_) - identity).view(self.dim ** 2)
        regularizer = torch.pow(matrix_constrained, 2).sum(dim=0)
        return regularizer
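
# Sketch of Mahalanobis_Constraint used as a weighted regulariser next to a
# task loss; 'embed', dim=8 and the 1e-3 weight are illustrative choices.
def _example_mahalanobis_constraint():
    model = nn.Sequential()
    model.add_module('embed', nn.Conv2d(8, 8, kernel_size=1, bias=False))
    regularizer = Mahalanobis_Constraint(model, param_name='embed', dim=8)
    task_loss = torch.tensor(0.0)                 # stand-in for the real task loss
    return task_loss + 1e-3 * regularizer()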

class SampleHistogramLoss(nn.Module):
    """Histogram loss over per-pair distances: penalises the overlap between
    the distance histograms of unchanged and changed pairs. The soft binning
    follows the linear-interpolation scheme of the Histogram Loss paper."""

    def __init__(self, num_steps):
        super(SampleHistogramLoss, self).__init__()
        self.step = 1.0 / (num_steps - 1)
        self.t = torch.linspace(0, 1, num_steps).view(-1, 1)
        self.tsize = self.t.size()[0]

    def forward(self, feat_t0, feat_t1, label):
        # per-pair l2 distance, assumed to lie in [0, 1]
        dist = F.pairwise_distance(feat_t0, feat_t1, p=2).view(1, -1)
        label = label.view(-1)
        pos_mask, neg_mask = (label == 0), (label != 0)
        # differentiable histograms: each distance is linearly soft-assigned
        # to its two nearest bin centres
        t = self.t.to(dist.device)
        weights = torch.clamp(1 - torch.abs(dist - t) / self.step, min=0.0)
        hist_pos = weights[:, pos_mask].sum(1) / pos_mask.sum().clamp(min=1).float()
        hist_neg = weights[:, neg_mask].sum(1) / neg_mask.sum().clamp(min=1).float()
        # probability that a changed pair scores a smaller distance than an
        # unchanged one
        return torch.sum(torch.cumsum(hist_neg, dim=0) * hist_pos)
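
# Usage sketch for SampleHistogramLoss; the loss assumes distances in [0, 1],
# so the random features are scaled down purely to keep the demo in range.
def _example_sample_histogram_loss():
    criterion = SampleHistogramLoss(num_steps=100)
    feat_t0 = 0.05 * torch.randn(256, 128)
    feat_t1 = 0.05 * torch.randn(256, 128)
    label = torch.randint(0, 2, (256,))
    return criterion(feat_t0, feat_t1, label)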

class BhattacharyyaDistance(nn.Module):
    """Bhattacharyya coefficient between two normalised histograms.

    Returns the overlap sum(sqrt(h1 * h2)) in [0, 1]; minimising it pushes
    the histograms apart (the full distance would be -log of this value)."""

    def __init__(self):
        super(BhattacharyyaDistance, self).__init__()

    def forward(self, hist1, hist2):
        bh_dist = torch.sqrt(hist1 * hist2).sum()
        return bh_dist

class KLCoefficient(nn.Module):
    """Loss 1 / (1 + KL(hist2 || hist1)): minimising it drives the two
    histograms apart by increasing their KL divergence."""

    def __init__(self):
        super(KLCoefficient, self).__init__()

    def forward(self, hist1, hist2):
        # F.kl_div expects log-probabilities as its first argument; a small
        # epsilon guards against log(0) on empty bins
        kl = F.kl_div((hist1 + 1e-8).log(), hist2, reduction='sum')
        dist = 1. / (1. + kl)
        return dist
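
# Small illustration of the two histogram divergences on toy normalised
# histograms; the bin values are made up.
def _example_histogram_distances():
    hist_a = torch.tensor([0.7, 0.2, 0.1, 0.0])
    hist_b = torch.tensor([0.0, 0.1, 0.2, 0.7])
    overlap = BhattacharyyaDistance()(hist_a, hist_b)   # small when histograms are separated
    kl_coeff = KLCoefficient()(hist_a, hist_b)          # small when their KL divergence is large
    return overlap, kl_coeff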

class HistogramMaskLoss(nn.Module):
    """Compares the distance histograms of unchanged and changed positions
    under a change mask, using a histogram divergence as the loss."""

    def __init__(self, num_steps, dist_flag='l2'):
        super(HistogramMaskLoss, self).__init__()
        self.step = 1.0 / (num_steps - 1)
        self.t = torch.linspace(0, 1, num_steps).view(-1, 1)
        self.dist_flag = dist_flag
        self.distance = KLCoefficient()

    def various_distance(self, out_vec_t0, out_vec_t1):
        if self.dist_flag == 'l2':
            return F.pairwise_distance(out_vec_t0, out_vec_t1, p=2)
        elif self.dist_flag == 'cos':
            similarity = F.cosine_similarity(out_vec_t0, out_vec_t1)
            return 1 - 2 * similarity / np.pi
        raise ValueError('unknown dist_flag: %s' % self.dist_flag)

    def histogram(self, dist, size):
        # differentiable histogram (torch.histc has no gradient): each
        # distance is linearly soft-assigned to its two nearest bin centres
        t = self.t.to(dist.device)
        weights = torch.clamp(1 - torch.abs(dist.view(1, -1) - t) / self.step, min=0.0)
        return weights.sum(1) / max(size, 1)

    def forward(self, feat_t0, feat_t1, ground_truth):
        # flatten spatial positions; the views assume batch size n == 1
        n, c, h, w = feat_t0.data.shape
        out_t0_rz = torch.transpose(feat_t0.view(c, h * w), 1, 0)
        out_t1_rz = torch.transpose(feat_t1.view(c, h * w), 1, 0)
        gt = ground_truth.view(h * w)
        #### inspired by the source code of the Histogram Loss ####
        ### unchanged positions form positive pairs, changed ones negative pairs ###
        pos_inds, neg_inds = (gt == 0).nonzero().view(-1), (gt != 0).nonzero().view(-1)
        pos_size, neg_size = pos_inds.numel(), neg_inds.numel()
        ### per-position distance (assumed in [0, 1]) between the feature maps ###
        distance = self.various_distance(out_t0_rz, out_t1_rz).view(-1)
        ### build the distance histograms of positive and negative pairs ###
        hist_pos = self.histogram(distance[pos_inds], pos_size)
        hist_neg = self.histogram(distance[neg_inds], neg_size)
        loss = self.distance(hist_pos, hist_neg)
        return loss
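
# End-to-end sketch of HistogramMaskLoss; the random features are l2-normalised
# along the channel axis and scaled so pairwise distances stay roughly within
# the [0, 1] range that the histogram bins cover.
def _example_histogram_mask_loss():
    criterion = HistogramMaskLoss(num_steps=100, dist_flag='l2')
    feat_t0 = 0.5 * F.normalize(torch.randn(1, 128, 32, 32), dim=1)
    feat_t1 = 0.5 * F.normalize(torch.randn(1, 128, 32, 32), dim=1)
    mask = torch.randint(0, 2, (1, 1, 32, 32))
    return criterion(feat_t0, feat_t1, mask)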