from __future__ import absolute_import

import torch
from torch import nn
from torch.autograd import Variable
import numpy as np


class TripletClusteringLoss(nn.Module):
    def __init__(self, clusters, margin=0,):
        super(TripletClusteringLoss, self).__init__()
        assert isinstance(clusters, torch.autograd.Variable)
        self.clusters = clusters
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.num_classes = clusters.size(0)
        self.num_features = clusters.size(1)
        self.dist = torch.pow(self.clusters, 2).sum(dim=1, keepdim=True)

    def forward(self, inputs, targets):
        assert self.num_features == input.size(1)
        n = inputs.size(0)
        dist = self.dist.expand(self.num_classes, n)
        dist += torch.pow(inputs, 2).sum(dim=1).t()
        dist.addmm_(1, -2, self.clusters, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        dist = dist.t()
        # For each anchor, find the hardest positive and negative
        mask = torch.zeros(n,self.num_classes,out=torch.ByteTensor())
        target_ids = targets.data.numpy().astype(int)
        mask[np.arange(n),target_ids] = 1
        dist_ap = dist[mask == 1]
        dist_an = dist[mask == 0].view(n, -1).min(dim=1)
        # Compute ranking hinge loss
        y = dist_an.data.new()
        y.resize_as_(dist_an.data)
        y.fill_(1)
        y = Variable(y)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
        return loss, prec

    def update_clusters(self,clusters):
        self.clusters = clusters