import torch.nn as nn
import torch
import pdb


class WeightedCrossEntropy(nn.Module):
    """Two-class cross-entropy computed from a single score per element.

    ``output`` is flattened and every element ``x`` is expanded into the
    two-column row ``[1 - x, x]``; those rows are handed to
    ``nn.CrossEntropyLoss`` as the class scores for classes 0 and 1, with
    the flattened ``target`` (cast to ``long``) as the class indices.

    NOTE(review): ``nn.CrossEntropyLoss`` applies a log-softmax to its
    input, so ``[1 - x, x]`` is treated as raw scores, not as
    probabilities. Presumably ``x`` is the predicted foreground
    probability -- confirm with the caller that this is intended.

    Args:
        weights: Optional per-class rescaling weights (two entries, class 0
            then class 1), forwarded to ``nn.CrossEntropyLoss``.
        size_average: Must be truthy; the loss is averaged over elements.

    Raises:
        NotImplementedError: If ``size_average`` is falsy.
    """

    def __init__(self, weights=None, size_average=True):
        super(WeightedCrossEntropy, self).__init__()
        if not size_average:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``.
            raise NotImplementedError(
                "WeightedCrossEntropy only supports size_average=True")
        self.weights = weights
        self.size_average = size_average

    def forward(self, output, target):
        """Return the scalar cross-entropy between ``output`` and ``target``.

        Args:
            output: Tensor of per-element scores for class 1, any shape.
            target: Tensor of 0/1 labels with the same number of elements.

        Returns:
            A scalar tensor holding the (weighted) mean loss.
        """
        # ``reshape`` rather than ``view``: ``view`` raises on
        # non-contiguous inputs (e.g. after ``permute``/``transpose``).
        score_one = output.reshape(-1)
        score_zero = 1 - score_one
        scores = torch.stack([score_zero, score_one], 1)
        labels = target.reshape(-1).long()

        weight = self.weights
        if weight is not None:
            # ``self.weights`` is a plain attribute, so it does not follow
            # the module across ``.to()``/``.cuda()``; align it here.
            weight = torch.as_tensor(
                weight, dtype=scores.dtype, device=scores.device)

        # ``reduction`` replaces the deprecated ``size_average`` flag
        # (True -> "mean", False -> "sum").
        reduction = "mean" if self.size_average else "sum"
        criterion = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
        return criterion(scores, labels)


class SeGANLoss(nn.Module):
    """Class-balanced binary cross-entropy on logits.

    The BCE-with-logits loss is computed separately over the elements whose
    target is 0 (background) and those whose target is 1 (foreground), and
    the two results are added. Because each part is averaged on its own,
    both classes contribute equally regardless of how many elements each
    one has. Elements whose target is neither 0 nor 1 are ignored.

    Args:
        weights: Stored on the instance but not used by ``forward``.
        size_average: Must be truthy; each class loss is a mean.

    Raises:
        NotImplementedError: If ``size_average`` is falsy.
    """

    def __init__(self, weights=None, size_average=True):
        super(SeGANLoss, self).__init__()
        if not size_average:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``.
            raise NotImplementedError(
                "SeGANLoss only supports size_average=True")
        self.weights = weights
        self.size_average = size_average

    @staticmethod
    def _masked_loss(criterion, output, target, mask):
        """Return ``criterion`` over the masked elements, or 0 if none."""
        logits = output[mask]
        if logits.numel() == 0:
            # The mean over an empty selection is NaN and would poison the
            # total. Summing the empty selection yields an exact zero that
            # is still attached to ``output``'s autograd graph.
            return logits.sum()
        # BCE-with-logits needs targets in the logits' dtype; integer label
        # maps are cast here.
        return criterion(logits, target[mask].to(logits.dtype))

    def forward(self, output, target):
        """Return background loss plus foreground loss as a scalar tensor.

        Args:
            output: Tensor of raw logits.
            target: Tensor of 0/1 labels with the same shape as ``output``.

        Returns:
            A scalar tensor; a class with no elements contributes zero.
        """
        # ``reduction`` replaces the deprecated ``size_average`` flag
        # (True -> "mean", False -> "sum").
        reduction = "mean" if self.size_average else "sum"
        criterion = nn.BCEWithLogitsLoss(reduction=reduction)
        background_loss = self._masked_loss(
            criterion, output, target, target == 0)
        foreground_loss = self._masked_loss(
            criterion, output, target, target == 1)
        return background_loss + foreground_loss