"""
This file handles the details of the loss function during training.

This includes: LossComputeBase, the standard NMTLossCompute, and
               utilities for sharded loss computation.
"""
from __future__ import division

import torch
import torch.nn as nn
from torch.autograd import Variable

import onmt
import onmt.io


class LossComputeBase(nn.Module):
    """
    Class for managing efficient loss computation. Handles
    sharding next step predictions and accumulating multiple
    loss computations.

    Users can implement their own loss computation strategy by making
    a subclass of this one. Subclasses need to implement the
    _compute_loss() and _make_shard_state() methods.

    Args:
        generator (:obj:`nn.Module`) :
             module that maps the output of the decoder to a
             distribution over the target vocabulary.
        tgt_vocab (:obj:`Vocab`) :
             torchtext vocab object representing the target output
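
    Example::

        # A minimal sketch of a hypothetical subclass. It assumes the
        # subclass defines `self.criterion` in its own __init__ (as
        # NMTLossCompute below does); the names here are illustrative.
        class MyLossCompute(LossComputeBase):
            def _make_shard_state(self, batch, output, range_, attns=None):
                return {"output": output,
                        "target": batch.tgt[range_[0] + 1: range_[1]]}

            def _compute_loss(self, batch, output, target):
                scores = self.generator(self._bottle(output))
                loss = self.criterion(scores, target.view(-1))
                stats = self._stats(loss.data.clone(), scores.data,
                                    target.view(-1).data)
                return loss, stats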
    """
    def __init__(self, generator, tgt_vocab):
        super(LossComputeBase, self).__init__()
        self.generator = generator
        self.tgt_vocab = tgt_vocab
        self.padding_idx = tgt_vocab.stoi[onmt.io.PAD_WORD]

    def _make_shard_state(self, batch, output, range_, attns=None):
        """
        Make shard state dictionary for shards() to return iterable
        shards for efficient loss computation. Subclass must define
        this method to match its own _compute_loss() interface.
        Args:
            batch: the current batch.
            output: the predicted output from the model.
            range_: the range of examples for computing loss (the
                    whole batch or a truncated slice of it).
            attns: the attns dictionary returned from the model.
        """
        raise NotImplementedError

    def _compute_loss(self, batch, output, target, **kwargs):
        """
        Compute the loss. Subclass must define this method.

        Args:

            batch: the current batch.
            output: the predicted output from the model.
            target: the gold target to compare the output against.
            **kwargs(optional): additional info for computing loss.
        """
        raise NotImplementedError

    def monolithic_compute_loss(self, batch, output, attns):
        """
        Compute the forward loss for the batch.

        Args:
          batch (batch): batch of labeled examples
          output (:obj:`FloatTensor`):
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict of :obj:`FloatTensor`) :
              dictionary of attention distributions
              `[tgt_len x batch x src_len]`
        Returns:
            the un-normalized loss and its :obj:`onmt.Statistics`
        """
        range_ = (0, batch.tgt.size(0))
        shard_state = self._make_shard_state(batch, output, range_, attns)
        loss, batch_stats = self._compute_loss(batch, **shard_state)

        return loss, batch_stats

    def sharded_compute_loss(self, batch, output, attns,
                             cur_trunc, trunc_size, shard_size):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard

        Returns:
            :obj:`onmt.Statistics`: loss statistics for the batch
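
        Example::

            # A usage sketch with illustrative numbers: back-propagate
            # through a 50-step truncation window in shards of at most
            # 32 target positions.
            stats = loss_compute.sharded_compute_loss(
                batch, output, attns, cur_trunc=0,
                trunc_size=50, shard_size=32)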

        """

        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        shard_state = self._make_shard_state(batch, output, range_, attns)

        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(batch.batch_size).backward()
            batch_stats.update(stats)

        return batch_stats

    def _stats(self, loss, scores, target):
        """
        Args:
            loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
            scores (:obj:`FloatTensor`): a score for each possible output
            target (:obj:`LongTensor`): true target token indices

        Returns:
            :obj:`Statistics` : statistics for this batch.
        """
        pred = scores.max(1)[1]
        non_padding = target.ne(self.padding_idx)
        num_correct = pred.eq(target) \
                          .masked_select(non_padding) \
                          .sum()
        return onmt.Statistics(loss[0], non_padding.sum(), num_correct)

    def _bottle(self, v):
        # Flatten `[tgt_len x batch x feat]` to `[tgt_len * batch, feat]`
        # so the generator can be applied to every position at once.
        return v.view(-1, v.size(2))

    def _unbottle(self, v, batch_size):
        # Inverse of _bottle: `[tgt_len * batch, feat]` back to
        # `[tgt_len x batch x feat]`.
        return v.view(-1, batch_size, v.size(1))


class NMTLossCompute(LossComputeBase):
    """
    Standard NMT Loss Computation.
    """
    def __init__(self, generator, tgt_vocab, label_smoothing=0.0):
        super(NMTLossCompute, self).__init__(generator, tgt_vocab)
        assert 0.0 <= label_smoothing <= 1.0

        if label_smoothing > 0:
            # When label smoothing is turned on, the KL-divergence between
            # q_{smoothed ground truth prob.}(w) and
            # p_{prob. computed by model}(w) is minimized.
            # If the label smoothing value is set to zero, the loss
            # is equivalent to NLLLoss or CrossEntropyLoss.
            # All non-true labels are uniformly set to low-confidence.
            self.criterion = nn.KLDivLoss(size_average=False)
            one_hot = torch.zeros(1, len(tgt_vocab))
            # The `- 2` excludes the padding index and the true label
            # itself from the smoothing mass.
            one_hot.fill_(label_smoothing / (len(tgt_vocab) - 2))
            one_hot[0][self.padding_idx] = 0
            self.register_buffer('one_hot', one_hot)
        else:
            weight = torch.ones(len(tgt_vocab))
            weight[self.padding_idx] = 0
            self.criterion = nn.NLLLoss(weight, size_average=False)
        self.confidence = 1.0 - label_smoothing
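
        # For example (illustrative numbers): with a vocabulary of 5 tokens
        # and label_smoothing=0.1, each non-true, non-padding label receives
        # 0.1 / (5 - 2) ~= 0.033 probability mass, and the true label keeps
        # confidence 1 - 0.1 = 0.9.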

    def _make_shard_state(self, batch, output, range_, attns=None):
        # The target is shifted by one relative to the decoder output:
        # position t of `output` predicts token t + 1 of `batch.tgt`.
        return {
            "output": output,
            "target": batch.tgt[range_[0] + 1: range_[1]],
        }

    def _compute_loss(self, batch, output, target):
        scores = self.generator(self._bottle(output))
        gtruth = target.view(-1)

        if self.confidence < 1:
            # Build the smoothed target distribution: the true label gets
            # `confidence` mass and every other non-padding label shares
            # the remaining smoothing mass (see `one_hot` in __init__).
            tdata = gtruth.data
            mask = torch.nonzero(tdata.eq(self.padding_idx)).squeeze()
            log_likelihood = torch.gather(scores.data, 1, tdata.unsqueeze(1))
            tmp_ = self.one_hot.repeat(gtruth.size(0), 1)
            tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
            if mask.dim() > 0:
                log_likelihood.index_fill_(0, mask, 0)
                tmp_.index_fill_(0, mask, 0)
            gtruth = Variable(tmp_, requires_grad=False)

        loss = self.criterion(scores, gtruth)

        if self.confidence < 1:
            # Report the true negative log-likelihood (not the KL value)
            # in the statistics.
            loss_data = - log_likelihood.sum(0)
        else:
            loss_data = loss.data.clone()

        stats = self._stats(loss_data, scores.data, target.view(-1).data)

        return loss, stats


def filter_shard_state(state):
    """Yield the (key, value) pairs of `state`, detaching each Variable
    that requires grad into a new graph leaf so that shards() can later
    collect the leaf gradients and back-propagate them through the
    original variables."""
    for k, v in state.items():
        if v is not None:
            if isinstance(v, Variable) and v.requires_grad:
                v = Variable(v.data, requires_grad=True, volatile=False)
            yield k, v


def shards(state, shard_size, eval=False):
    """
    Args:
        state: A dictionary which corresponds to the output of
               *LossCompute._make_shard_state(). The values for
               those keys are Tensor-like or None.
        shard_size: The maximum size of the shards yielded by the model.
        eval: If True, only yield the state, nothing else.
              Otherwise, yield shards.

    Yields:
        Each yielded shard is a dict.

    Side effect:
        After the last shard, this function does back-propagation.
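
    Example::

        # A sketch with hypothetical names (`out`, `tgt`, `compute`,
        # `batch`): each shard's loss is back-propagated inside the
        # loop, then shards() back-propagates the collected gradients
        # through the original `out`.
        state = {"output": out, "target": tgt}
        for shard in shards(state, shard_size=32):
            loss, _ = compute._compute_loss(batch, **shard)
            loss.div(batch.batch_size).backward()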
    """
    if eval:
        yield state
    else:
        # non_none: the subdict of the state dictionary where the values
        # are not None.
        non_none = dict(filter_shard_state(state))

        # Now, the iteration:
        # state is a dictionary of sequences of tensor-like but we
        # want a sequence of dictionaries of tensors.
        # First, unzip the dictionary into a sequence of keys and a
        # sequence of tensor-like sequences.
        keys, values = zip(*((k, torch.split(v, shard_size))
                             for k, v in non_none.items()))

        # Now, yield a dictionary for each shard. The keys are always
        # the same. values is a sequence of length #keys where each
        # element is a sequence of length #shards. We want to iterate
        # over the shards, not over the keys: therefore, the values need
        # to be re-zipped by shard and then each shard can be paired
        # with the keys.
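        # For example (hypothetical shapes), with two keys and two shards:
        # values = ((out_1, out_2), (tgt_1, tgt_2)), so zip(*values)
        # yields (out_1, tgt_1) and then (out_2, tgt_2).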
        for shard_tensors in zip(*values):
            yield dict(zip(keys, shard_tensors))

        # Each shard's loss has already been back-propagated by the caller,
        # so gradients have accumulated on the detached leaves created in
        # filter_shard_state(). Propagate them through the original
        # (pre-split) variables to finish the backward pass.
        variables = ((state[k], v.grad.data) for k, v in non_none.items()
                     if isinstance(v, Variable) and v.grad is not None)
        inputs, grads = zip(*variables)
        torch.autograd.backward(inputs, grads)
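

# A minimal end-to-end usage sketch (hypothetical names: `model`, `fields`,
# `batch`, `lengths`; the model's call signature is an assumption), run
# once per training step:
#
#     tgt_vocab = fields["tgt"].vocab
#     loss_compute = NMTLossCompute(model.generator, tgt_vocab,
#                                   label_smoothing=0.1)
#     outputs, attns, _ = model(batch.src, batch.tgt, lengths)
#     batch_stats = loss_compute.sharded_compute_loss(
#         batch, outputs, attns, cur_trunc=0,
#         trunc_size=batch.tgt.size(0), shard_size=32)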