from __future__ import absolute_import

import torch
from torch import nn
from torch.autograd import Variable
from ..evaluation_metrics import accuracy


class DeepLoss(nn.Module):
    """Combined loss: batch-hard triplet ranking plus optional cross-entropy.

    ``forward`` receives three outputs ``(cnn, rnn, main)``. ``main`` is
    always scored with the triplet ranking loss; ``cnn`` and ``rnn`` are
    additionally scored with cross-entropy once ``epoch >= add_soft``.

    Args:
        margin: margin passed to ``nn.MarginRankingLoss``.
    """

    def __init__(self, margin=0):
        super(DeepLoss, self).__init__()
        self.triplet_criterion = nn.MarginRankingLoss(margin=margin)
        self.soft_criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, targets, epoch, add_soft=0):
        """Compute the total loss and a precision figure for one batch.

        Args:
            inputs: sequence unpacked as ``(cnn, rnn, main)``. ``main`` is
                given to ``tri_loss``; ``cnn`` and ``rnn`` are given to
                ``softmax`` (presumably class logits -- confirm with caller).
            targets: labels shared by all three outputs.
            epoch: current epoch, compared against ``add_soft``.
            add_soft: first epoch at which the cross-entropy terms are
                added. With the default of 0 they are added for every
                non-negative epoch.

        Returns:
            ``(loss, prec)``. Before ``add_soft`` these are the triplet loss
            and its precision. Afterwards ``loss`` is the sum of the three
            losses and ``prec`` is the maximum of the three precisions.
        """
        cnn, rnn, main = inputs
        if epoch < add_soft:
            # Warm-up phase: only the triplet loss on the main output.
            loss, prec = self.tri_loss(main, targets)
        else:
            loss_main, prec_main = self.tri_loss(main, targets)
            loss_cnn, prec_cnn = self.softmax(cnn, targets)
            loss_rnn, prec_rnn = self.softmax(rnn, targets)
            loss = loss_main + loss_cnn + loss_rnn
            # NOTE(review): prec_main is a fraction in [0, 1]; the scale of
            # the values returned by ``accuracy`` is not visible here --
            # confirm they are comparable.
            prec = max(prec_main, prec_cnn, prec_rnn)
        return loss, prec

    def tri_loss(self, inputs, targets):
        """Triplet ranking loss using the hardest positive/negative per anchor.

        Args:
            inputs: 2-D tensor with one embedding per row.
            targets: labels, one per row of ``inputs``; rows with equal
                labels are treated as positives of each other.

        Returns:
            ``(loss, prec)`` where ``loss`` is the margin ranking loss
            between the hardest-negative and hardest-positive distances and
            ``prec`` is the fraction (Python float) of anchors whose hardest
            negative is farther away than their hardest positive.

        An anchor that has no negative in the batch gets an infinite
        hardest-negative distance, so it contributes zero loss and counts
        as correct in ``prec``.
        """
        n = inputs.size(0)

        # Pairwise Euclidean distances: ||a||^2 + ||b||^2 - 2 * a.b
        sq_norms = inputs.pow(2).sum(dim=1, keepdim=True)
        dist = sq_norms + sq_norms.t() - 2.0 * inputs.mm(inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # different_id[i][j] is true when samples i and j carry different labels.
        label_rows = targets.expand(n, n)
        different_id = label_rows.ne(label_rows.t())
        same_id = label_rows.eq(label_rows.t())

        # For each anchor, find the hardest positive and negative. Entries
        # that must not take part in a reduction are replaced with -inf/+inf
        # so a plain row-wise max/min selects among the remaining ones.
        dist_ap = dist.masked_fill(different_id, float('-inf')).max(dim=1)[0]
        dist_an = dist.masked_fill(same_id, float('inf')).min(dim=1)[0]

        # Compute ranking hinge loss: target 1 asks for dist_an > dist_ap.
        y = torch.ones_like(dist_an)
        loss = self.triplet_criterion(dist_an, dist_ap, y)
        prec = float((dist_an > dist_ap).sum()) / n
        return loss, prec

    def softmax(self, inputs, targets):
        """Cross-entropy loss plus the precision reported by ``accuracy``.

        Args:
            inputs: tensor passed to ``nn.CrossEntropyLoss`` as its input.
            targets: class labels.

        Returns:
            ``(loss, prec)`` where ``prec`` is element 0 of the single value
            returned by the project's ``accuracy`` helper.
        """
        loss = self.soft_criterion(inputs, targets)
        # accuracy is expected to return exactly one entry here.
        prec, = accuracy(inputs.detach(), targets.detach())
        prec = prec[0]
        return loss, prec

    def normalize(self, inputs, p=2):
        """Divide the p-th power of each element by its row's sum of p-th powers.

        NOTE(review): this is ``x**p / sum(x**p)`` per row, not the usual
        ``x / ||x||_p`` normalisation -- confirm this is what callers expect.
        """
        powered = inputs.pow(p)
        return powered / powered.sum(dim=1, keepdim=True).expand_as(powered)

    def fusion(self, dist, targets):
        """Placeholder; not implemented and returns ``None``."""
        pass