import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropyLoss2d(nn.Module):
    """Per-pixel cross-entropy loss for dense predictions.

    Applies ``log_softmax`` over the class dimension (dim 1) of the raw
    scores and feeds the result to ``nn.NLLLoss``. This is numerically the
    same computation as ``nn.CrossEntropyLoss``.

    Args:
        weights: optional per-class weight tensor, forwarded to
            ``nn.NLLLoss(weight=...)``.
        ignore_index: target value whose positions are excluded from the
            loss. Defaults to -100 (the PyTorch default), so existing
            callers see no change.
    """

    def __init__(self, weights=None, ignore_index=-100):
        super(CrossEntropyLoss2d, self).__init__()

        # nn.NLLLoss handles (N, C, d1, d2, ...) input directly; the
        # NLLLoss2d alias is deprecated and only emits a warning.
        self.loss = nn.NLLLoss(weight=weights, ignore_index=ignore_index)
        # Previously .cuda() was called unconditionally, which fails on
        # CPU-only machines when class weights are given. Keep the GPU
        # placement when CUDA exists so existing callers are unaffected.
        # `self.loss` is a registered submodule, so `criterion.to(device)`
        # also moves the weights.
        if torch.cuda.is_available():
            self.loss.cuda()

    def forward(self, outputs, targets):
        """Return the cross-entropy of raw scores ``outputs`` vs ``targets``.

        Args:
            outputs: unnormalised class scores; assumed batched with the
                class axis at dim 1, e.g. (N, C, H, W).
            targets: integer class indices, e.g. (N, H, W).
        """
        # Explicit dim=1: the implicit-dim form is deprecated. For 4-D
        # input it resolved to dim 1 as well.
        return self.loss(F.log_softmax(outputs, dim=1), targets)