# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Training utilities.
"""

import copy
import logging
import random
import time

import numpy as np
import torch
import torch.nn as nn
from torch import optim

import vis

class Criterion(object):
    """Weighted CrossEntropyLoss that can mask out unwanted tokens."""

    def __init__(self, dictionary, device_id=None, bad_toks=(), reduction='mean'):
        # weight every token equally, then zero out the tokens that should
        # never contribute to the loss
        w = torch.Tensor(len(dictionary)).fill_(1)
        for tok in bad_toks:
            w[dictionary.get_idx(tok)] = 0.0
        if device_id is not None:
            w = w.cuda(device_id)
        # https://pytorch.org/docs/stable/nn.html
        self.crit = nn.CrossEntropyLoss(w, reduction=reduction)

    def __call__(self, out, tgt):
        return self.crit(out, tgt)

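# Usage sketch (mirrors how Engine wires these up below; `word_dict` and
# `item_dict` are the model's dictionaries):
#   crit = Criterion(word_dict)
#   sel_crit = Criterion(item_dict, bad_toks=['<disconnect>', '<disagree>'])
#   loss = crit(out.view(-1, len(word_dict)), tgt)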

class Engine(object):
    """The training engine.

    Performs training and evaluation.
    """
    def __init__(self, model, args, device_id=None, verbose=False):
        self.model = model
        self.args = args
        self.device_id = device_id
        self.verbose = verbose
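        # plain SGD; torch.optim.SGD rejects nesterov=True with zero momentum,
        # hence the guard on args.momentum below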
        self.opt = optim.SGD(self.model.parameters(), lr=self.args.lr,
            momentum=self.args.momentum,
            nesterov=(self.args.nesterov and self.args.momentum > 0))
        self.crit = Criterion(self.model.word_dict, device_id=device_id)
        self.sel_crit = Criterion(
            self.model.item_dict, device_id=device_id, bad_toks=['<disconnect>', '<disagree>'])
        if self.args.visual:
            self.model_plot = vis.ModulePlot(self.model, plot_weight=False, plot_grad=True)
            self.loss_plot = vis.Plot(['train', 'valid', 'valid_select'],
                'loss', 'loss', 'epoch', running_n=1)
            self.ppl_plot = vis.Plot(['train', 'valid', 'valid_select'],
                'perplexity', 'ppl', 'epoch', running_n=1)

    @staticmethod
    def forward(model, batch, requires_grad=False):
        """A helper that performs a forward pass on a single batch."""
        with torch.set_grad_enabled(requires_grad):
            # unpack the batch into context, input, target and selection target
            ctx, inpt, tgt, sel_tgt = batch
            # ctx:     (6, <=batch_size)
            # inpt:    (max_len-1, <=batch_size)
            # tgt:     ((max_len-1) * <=batch_size,)
            # sel_tgt: (6 * <=batch_size,)

            # get context hidden state
            ctx_h = model.forward_context(ctx) # (1, <=batch_size, nhid_ctx)
            # create initial hidden state for the language rnn
            lang_h = model.zero_hid(ctx_h.size(1), model.args.nhid_lang) # (1, <=batch_size, nhid_lang)

            # perform forward for the language model
            out, lang_h = model.forward_lm(inpt, lang_h, ctx_h)
            # out: (max_len-1, <=batch_size, vocab_size)
            # lang_h: (max_len-1, <=batch_size, nhid_lang)

            # perform forward for the selection
            sel_out = model.forward_selection(inpt, lang_h, ctx_h) # (6*<=batch_size, len(item_dict))

            return out, lang_h, tgt, sel_out, sel_tgt

    def get_model(self):
        """Returns the underlying model."""
        return self.model

    def train_pass(self, N, trainset):
        """Training pass."""
        # put the model into training mode
        self.model.train()

        total_loss = 0
        start_time = time.time()

        # training loop
        for batch in trainset:
            self.t += 1
            # forward pass
            out, hid, tgt, sel_out, sel_tgt = Engine.forward(self.model, batch, requires_grad=True)
            # out:     (max_len-1, <=batch_size, vocab_size)
            # hid:     (max_len-1, <=batch_size, nhid_lang)
            # tgt:     ((max_len-1) * <=batch_size,)
            # sel_out: (6 * <=batch_size, len(item_dict))
            # sel_tgt: (6 * <=batch_size,)

            # compute the LM loss and the selection loss, weighted by sel_weight
            loss = self.crit(out.view(-1, N), tgt)
            loss += self.sel_crit(sel_out, sel_tgt) * self.model.args.sel_weight
            self.opt.zero_grad()
            # backward step with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.opt.step()

            if self.args.visual and self.t % 100 == 0:
                self.model_plot.update(self.t)

            total_loss += loss.item()

        total_loss /= len(trainset)
        time_elapsed = time.time() - start_time
        return total_loss, time_elapsed

    def train_single(self, N, trainset):
        """A helper function to train on a random batch."""
        batch = random.choice(trainset)
        out, hid, tgt, sel_out, sel_tgt = Engine.forward(self.model, batch, requires_grad=True)
        loss = self.crit(out.view(-1, N), tgt) + \
            self.sel_crit(sel_out, sel_tgt) * self.model.args.sel_weight
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
        self.opt.step()
        return loss

    def valid_pass(self, N, validset, validset_stats):
        """Validation pass."""
        # put the model into evaluation mode
        self.model.eval()

        valid_loss, select_loss = 0, 0
        for batch in validset:
            # compute forward pass
            out, hid, tgt, sel_out, sel_tgt = Engine.forward(self.model, batch, requires_grad=False)

            # evaluate the LM and selection losses; scaling the mean LM loss by
            # tgt.size(0) turns it back into a sum over tokens, normalized below
            valid_loss += tgt.size(0) * self.crit(out.view(-1, N), tgt).item()
            select_loss += self.sel_crit(sel_out, sel_tgt).item()

        # normalize the LM loss by the number of non-padding words in the input
        # rather than by all target tokens, since the latter include padding
        return valid_loss / validset_stats['nonpadn'], select_loss / len(validset)

    def iter(self, N, epoch, lr, traindata, validdata):
        """Performs one iteration of training.

        Runs one epoch over the training and validation datasets.
        """
        trainset, trainset_stats = traindata
        validset, validset_stats = validdata

        train_loss, train_time = self.train_pass(N, trainset)
        # re-score the training set in eval mode to sanity-check selection loss
        _, check_train_select_loss = self.valid_pass(N, trainset, trainset_stats)
        valid_loss, valid_select_loss = self.valid_pass(N, validset, validset_stats)

        if self.verbose:
            logging.info('| epoch %03d | train_loss %.3f | train_ppl %.3f | s/epoch %.2f | lr %0.8f' % (
                epoch, train_loss, np.exp(train_loss), train_time, lr))
            logging.info('| epoch %03d | valid_loss %.3f | valid_ppl %.3f' % (
                epoch, valid_loss, np.exp(valid_loss)))
            logging.info('| epoch %03d | train_select_loss %.3f | train_select_ppl %.3f (check)' % (
                epoch, check_train_select_loss, np.exp(check_train_select_loss)))
            logging.info('| epoch %03d | valid_select_loss %.3f | valid_select_ppl %.3f' % (
                epoch, valid_select_loss, np.exp(valid_select_loss)))

        if self.args.visual:
            self.loss_plot.update('train', epoch, train_loss)
            self.loss_plot.update('valid', epoch, valid_loss)
            self.loss_plot.update('valid_select', epoch, valid_select_loss)
            self.ppl_plot.update('train', epoch, np.exp(train_loss))
            self.ppl_plot.update('valid', epoch, np.exp(valid_loss))
            self.ppl_plot.update('valid_select', epoch, np.exp(valid_select_loss))

        return train_loss, valid_loss, valid_select_loss

    def train(self, corpus):
        """Entry point."""
        N = len(corpus.word_dict)
        best_model, best_valid_select_loss = None, 1e20
        lr = self.args.lr
        last_decay_epoch = 0
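        # global step counter, used to throttle the weight/gradient plots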
        self.t = 0

        validdata = corpus.valid_dataset(self.args.bsz, device_id=self.device_id)
        for epoch in range(1, self.args.max_epoch + 1):
            traindata = corpus.train_dataset(self.args.bsz, device_id=self.device_id)
            # keep all three losses so the final return is defined even if the
            # annealing loop below exits before completing an iteration
            train_loss, valid_loss, valid_select_loss = self.iter(
                N, epoch, lr, traindata, validdata)

            if valid_select_loss < best_valid_select_loss:
                best_valid_select_loss = valid_select_loss
                best_model = copy.deepcopy(self.model)

        if self.verbose:
            logging.info('| start annealing | best valid_select_loss %.3f | best valid_select_ppl %.3f' % (
                best_valid_select_loss, np.exp(best_valid_select_loss)))

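        # annealing phase (capped at 100 epochs): continue from the best model,
        # decaying the learning rate every args.decay_every epochs until it
        # drops below args.min_lr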
        self.model = best_model
        for epoch in range(self.args.max_epoch + 1, 100):
            if epoch - last_decay_epoch >= self.args.decay_every:
                last_decay_epoch = epoch
                lr /= self.args.decay_rate
                if lr < self.args.min_lr:
                    break
                self.opt = optim.SGD(self.model.parameters(), lr=lr)

            traindata = corpus.train_dataset(self.args.bsz, device_id=self.device_id)
            train_loss, valid_loss, valid_select_loss = self.iter(
                N, epoch, lr, traindata, validdata)

        return train_loss, valid_loss, valid_select_loss
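

# Usage sketch (hypothetical wiring; in the surrounding repo a driver script
# builds `model` and `corpus` from command-line args first):
#   engine = Engine(model, args, device_id=0, verbose=True)
#   train_loss, valid_loss, valid_select_loss = engine.train(corpus)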