import torch
import torch.nn.functional as F

# Classic adversarial loss
def loss_KL_d(dis_fake, dis_real):
    L1 = torch.mean(F.softplus(-dis_real))
    L2 = torch.mean(F.softplus(dis_fake))
    return L1 + L2

def loss_KL_g(dis_fake):
    return torch.mean(F.softplus(-dis_fake))


# Hinge loss
def loss_hinge_d(dis_fake, dis_real):
    L1 = torch.mean(F.relu(1 - dis_real))
    L2 = torch.mean(F.relu(1 + dis_fake))
    return L1 + L2

def loss_hinge_g(dis_fake):
    return -torch.mean(dis_fake)


# NLL loss
def loss_nll(bin_output, bin_label, multi_output, multi_label, lam=0.5):
    L1 = F.binary_cross_entropy_with_logits(bin_output, bin_label)
    L2 = F.cross_entropy(multi_output, multi_label)
    return lam * L1 + (1.0 - lam) * L2

# NLL loss with another weighting scheme
def loss_nll_v2(bin_output, bin_label, multi_output, multi_label, lam):
    L1 = F.binary_cross_entropy_with_logits(bin_output, bin_label)
    L2 = F.cross_entropy(multi_output, multi_label)
    return L1 + lam * L2

# Binary loss
def loss_bin(bin_output, bin_label):
    return F.binary_cross_entropy_with_logits(bin_output, bin_label)