import torch as th
import torch.nn as nn

class LabelSmoothing(nn.Module):
    """
    Computer loss at one time step.
    """
    def __init__(self, size, smoothing=0.0):
        """Label Smoothing module
        args:
            size: vocab_size
            smoothing: smoothing ratio
        """
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size
        self.smoothing = smoothing

    def forward(self, x, target):
        # x: (*, n_classes)
        # target: (*)
        assert x.size(1) == self.size
        with th.no_grad():
            tgt_dist = th.zeros_like(x, dtype=th.float)
            tgt_dist.fill_(self.smoothing / (self.size - 1))
            tgt_dist.scatter_(1, target.unsqueeze(1), 1 - self.smoothing)

        return self.criterion(x, tgt_dist)