import torch
import random
import numpy as np
from config import global_config as cfg
from reader import CamRest676Reader, get_glove_matrix
from reader import KvretReader
from reader import UbuntuDialogueReader
from reader import JDCorpusReader
from unsup_net import UnsupervisedSEDST, cuda_
from torch.optim import Adam, RMSprop
from torch.autograd import Variable
from reader import pad_sequences
import argparse, time
from metric import CamRestEvaluator, KvretEvaluator, GenericEvaluator
import logging


class Model:
    """Training / evaluation driver around the SEDST network.

    Wires together a dataset reader, the network selected by ``cfg.m`` and a
    dataset-specific evaluator class, and implements the train / validate /
    eval loops plus checkpoint and embedding helpers. All hyper-parameters are
    read from the global ``cfg`` object.
    """

    def __init__(self, dataset, inference_only=False):
        """Build the reader, network and evaluator class for ``dataset``.

        Args:
            dataset: one of ``'camrest'``, ``'kvret'``, ``'ubuntu'``, ``'jd'``.
            inference_only: accepted for interface compatibility; unused here.
        """
        reader_dict = {
            'camrest': CamRest676Reader,
            'kvret': KvretReader,
            'ubuntu': UbuntuDialogueReader,
            'jd': JDCorpusReader
        }
        model_dict = {
            'SEDST': UnsupervisedSEDST,
        }
        evaluator_dict = {
            'camrest': CamRestEvaluator,
            'kvret': KvretEvaluator,
            'ubuntu': GenericEvaluator,
            'jd': GenericEvaluator
        }
        self.reader = reader_dict[dataset]()
        self.m = model_dict[cfg.m](embed_size=cfg.embedding_size,
                                   hidden_size=cfg.hidden_size,
                                   q_hidden_size=cfg.q_hidden_size,
                                   vocab_size=cfg.vocab_size,
                                   layer_num=cfg.layer_num,
                                   dropout_rate=cfg.dropout_rate,
                                   z_length=cfg.z_length,
                                   alpha=cfg.alpha,
                                   max_ts=cfg.max_ts,
                                   beam_search=cfg.beam_search,
                                   beam_size=cfg.beam_size,
                                   eos_token_idx=self.reader.vocab.encode('EOS_M'),
                                   vocab=self.reader.vocab,
                                   teacher_force=cfg.teacher_force,
                                   degree_size=cfg.degree_size)
        self.EV = evaluator_dict[dataset]  # evaluator class, instantiated in eval()
        if cfg.cuda: self.m = self.m.cuda()
        self.base_epoch = -1

    @staticmethod
    def _replace_oov_with_unk(seq):
        """Replace, in place, every id ``>= cfg.vocab_size`` in ``seq`` with 2 (unk)."""
        for j, word in enumerate(seq):
            if word >= cfg.vocab_size:
                seq[j] = 2  # unk

    def _convert_batch(self, py_batch, prev_z_py=None):
        """Turn one turn-level python batch into padded numpy arrays and tensors.

        ``py_batch`` is read for the keys ``'user'``, ``'u_len'``, ``'latent'``,
        ``'response'``, ``'post'``, ``'m_len'``, ``'p_len'``, ``'degree'`` and
        ``'supervised'``. ``prev_z_py`` (one id list per dialogue, or ``None``)
        is the previous turn's latent sequence. It is used according to
        ``cfg.prev_z_method``:

        * ``'concat'``: each prev-z, truncated after its first ``EOS_Z2``, is
          prepended to the user utterance, and ``u_len`` is updated.
        * ``'separate'``: prev-z is truncated the same way and returned as a
          separate padded input in ``kw_ret``.

        In both modes, ids ``>= cfg.vocab_size`` in ``prev_z_py`` are replaced
        by 2 (unk) before use. ``py_batch['user']``, ``py_batch['u_len']`` and
        ``prev_z_py`` are modified in place.

        Returns:
            ``(u_input, u_input_np, z_input, m_input, m_input_np, p_input,
            p_input_np, u_len, m_len, p_len, degree_input, supervised, kw_ret)``.
            The ``*_np`` arrays are ``pad_sequences`` output transposed with
            ``(1, 0)`` (presumably to time-major ``(T, B)`` -- confirm against
            ``pad_sequences``). ``kw_ret`` always holds ``'z_input_np'``.
        """
        u_input_py = py_batch['user']
        u_len_py = py_batch['u_len']
        kw_ret = {}
        if cfg.prev_z_method == 'concat' and prev_z_py is not None:
            eob = self.reader.vocab.encode('EOS_Z2')
            for i in range(len(u_input_py)):
                z = prev_z_py[i]
                # Keep everything up to and including the first EOS_Z2.
                if eob in z and z.index(eob) != len(z) - 1:
                    prefix_len = z.index(eob) + 1
                else:
                    prefix_len = len(z)
                # Map OOV / copied ids to unk BEFORE splicing them into the user
                # input; otherwise u_input could hold ids >= cfg.vocab_size.
                self._replace_oov_with_unk(z)
                u_input_py[i] = z[:prefix_len] + u_input_py[i]
                u_len_py[i] = len(u_input_py[i])
        elif cfg.prev_z_method == 'separate' and prev_z_py is not None:
            eob = self.reader.vocab.encode('EOS_Z2')
            for i, z in enumerate(prev_z_py):
                if eob in z and z.index(eob) != len(z) - 1:
                    prev_z_py[i] = z[:z.index(eob) + 1]
                self._replace_oov_with_unk(prev_z_py[i])
            prev_z_input_np = pad_sequences(prev_z_py, cfg.max_ts, padding='post', truncating='pre').transpose((1, 0))
            prev_z_len = np.array([len(_) for _ in prev_z_py])
            prev_z_input = cuda_(Variable(torch.from_numpy(prev_z_input_np).long()))
            kw_ret['prev_z_len'] = prev_z_len
            kw_ret['prev_z_input'] = prev_z_input
            kw_ret['prev_z_input_np'] = prev_z_input_np

        degree_input_np = np.array(py_batch['degree'])
        u_input_np = pad_sequences(u_input_py, cfg.max_ts, padding='post', truncating='pre').transpose((1, 0))
        z_input_np = pad_sequences(py_batch['latent'], padding='post').transpose((1, 0))
        # Responses keep their beginning when pretraining and their end otherwise.
        m_truncating = 'post' if cfg.pretrain else 'pre'
        m_input_np = pad_sequences(py_batch['response'], cfg.max_ts, padding='post',
                                   truncating=m_truncating).transpose((1, 0))
        p_input_np = pad_sequences(py_batch['post'], cfg.max_ts, padding='post', truncating='pre').transpose((1, 0))
        u_len = np.array(u_len_py)
        m_len = np.array(py_batch['m_len'])
        p_len = np.array(py_batch['p_len'])
        degree_input = cuda_(Variable(torch.from_numpy(degree_input_np).float()))
        u_input = cuda_(Variable(torch.from_numpy(u_input_np).long()))
        z_input = cuda_(Variable(torch.from_numpy(z_input_np).long()))
        m_input = cuda_(Variable(torch.from_numpy(m_input_np).long()))
        p_input = cuda_(Variable(torch.from_numpy(p_input_np).long()))
        supervised = py_batch['supervised'][0]
        kw_ret['z_input_np'] = z_input_np
        return u_input, u_input_np, z_input, m_input, m_input_np, p_input, p_input_np, u_len, m_len, p_len, \
               degree_input, supervised, kw_ret

    def train(self):
        """Run the unsupervised training loop with validation and early stopping.

        Epochs before ``cfg.base_epoch`` (and, within that epoch, iterations
        before ``cfg.base_iter``) are skipped so that training can resume. Each
        epoch builds a fresh Adam optimiser with the current learning rate and
        trains every dialogue turn by turn, carrying ``turn_states`` /
        ``turn_states_q`` across turns. After each epoch it validates and saves
        a checkpoint. Whenever the validation loss does not improve, the
        learning rate is multiplied by ``cfg.lr_decay``. Training stops after
        ``cfg.early_stop_count`` non-improving epochs; the counter is never
        reset, so these epochs need not be consecutive.
        """
        lr = cfg.lr
        prev_min_loss, early_stop_count = 1 << 30, cfg.early_stop_count
        train_time = 0
        for epoch in range(cfg.epoch_num):
            sw = time.time()
            if epoch < cfg.base_epoch:
                continue
            sup_loss, unsup_loss = 0, 0
            sup_cnt, unsup_cnt = 0, 0
            data_iterator = self.reader.mini_batch_iterator('train')
            # Re-created every epoch so the decayed lr takes effect; this also
            # resets Adam's moment estimates.
            optim = Adam(lr=lr, params=filter(lambda x: x.requires_grad, self.m.parameters()), weight_decay=1e-6)
            for iter_num, dial_batch in enumerate(data_iterator):
                if epoch == cfg.base_epoch and iter_num < cfg.base_iter:
                    continue
                turn_states = {}
                turn_states_q = {}
                prev_z = None
                trunc_cnt = 1
                for turn_num, turn_batch in enumerate(dial_batch):
                    if cfg.truncated:
                        logging.debug('iter %d turn %d' % (iter_num, turn_num))
                    optim.zero_grad()
                    u_input, u_input_np, z_input, m_input, m_input_np, p_input, p_input_np, u_len, \
                    m_len, p_len, degree_input, supervised, kw_ret \
                        = self._convert_batch(turn_batch, prev_z)

                    loss, m_loss, p_loss, kl_div_loss, turn_states, turn_states_q = self.m(
                        u_input=u_input,
                        z_input=None,
                        m_input=m_input,
                        p_len=p_len,
                        degree_input=degree_input,
                        u_input_np=u_input_np,
                        m_input_np=m_input_np,
                        z_supervised=False,
                        turn_states=turn_states,
                        p_input=p_input,
                        p_input_np=p_input_np,
                        u_len=u_len, m_len=m_len,
                        mode='train',
                        turn_states_q=turn_states_q,
                        **kw_ret)
                    # Truncated back-propagation across turns: on the last turn, and
                    # every cfg.trunc_turn turns, cut the graph by detaching
                    # turn_states; otherwise keep the graph for the next turn.
                    # NOTE(review): optim.step() below updates parameters in place
                    # even when the graph is retained, and turn_states_q is never
                    # detached; recent PyTorch versions may reject the next backward
                    # in that case -- verify when cfg.trunc_turn > 1.
                    if turn_num == len(dial_batch) - 1 or trunc_cnt % cfg.trunc_turn == 0:
                        for k in turn_states:
                            turn_states[k] = cuda_(Variable(turn_states[k].data))
                        loss.backward(retain_graph=False)
                    else:
                        loss.backward(retain_graph=True)
                    trunc_cnt += 1
                    # clip_grad_norm_ may return a tensor on the parameters' device;
                    # convert to a float so np.isnan / logging work on CUDA too.
                    grad = float(torch.nn.utils.clip_grad_norm_(self.m.parameters(), 4.0))
                    optim.step()
                    loss_val = loss.item()
                    unsup_loss += loss_val
                    # Periodic checkpointing in truncated mode, skipping NaN steps.
                    if cfg.truncated and not np.isnan(loss_val) and not np.isnan(
                            grad) and iter_num % 10 == 0 and iter_num != 0:
                        self.save_model(epoch)
                    unsup_cnt += 1
                    logging.debug(
                        'unsupervised loss:{} m_loss:{} p_loss:{} kl_div_loss:{} grad:{}'.format(
                            loss_val, m_loss.item(), p_loss.item(), kl_div_loss.item(), grad))
                    prev_z = turn_batch['latent']

            epoch_sup_loss, epoch_unsup_loss = sup_loss / (sup_cnt + 1e-8), unsup_loss / (unsup_cnt + 1e-8)
            train_time += time.time() - sw

            logging.info('Traning time: {}'.format(train_time))
            logging.info('avg training loss in epoch %d sup:%6f unsup:%6f' % (epoch, epoch_sup_loss, epoch_unsup_loss))
            # do validation
            valid_sup_loss, valid_unsup_loss = self.validate()
            logging.info('validation loss in epoch %d sup:%6f unsup:%6f' % (epoch, valid_sup_loss, valid_unsup_loss))
            logging.info('time for epoch %d: %6f' % (epoch, time.time() - sw))
            valid_loss = valid_sup_loss + valid_unsup_loss
            self.save_model(epoch)
            if valid_loss <= prev_min_loss:
                prev_min_loss = valid_loss
            else:
                early_stop_count -= 1
                lr *= cfg.lr_decay
                if not early_stop_count:
                    break
                logging.info('early stop countdown %d, learning rate %6f' % (early_stop_count, lr))

    def eval(self, data='test'):
        """Decode split ``data`` turn by turn and score it with the evaluator.

        The previous turn's decoded latent ``z_idx`` is fed as ``prev_z`` to the
        next turn. Each turn's output is passed to ``self.reader.wrap_result``;
        with ``cfg.last_turn_only`` set, only the last turn of each dialogue is
        passed. The evaluator is then built on ``cfg.result_path``.

        Returns:
            Whatever ``self.EV(...).run_metrics()`` returns.
        """
        self.m.eval()
        # Presumably resets the reader's result output before wrap_result -- confirm in reader.
        self.reader.result_file = None
        with torch.no_grad():
            data_iterator = self.reader.mini_batch_iterator(data)
            mode = 'test'  # if not cfg.pretrain else 'pretrain_test'
            for batch_num, dial_batch in enumerate(data_iterator):
                turn_states = {}
                turn_states_q = {}
                prev_z = None
                for turn_num, turn_batch in enumerate(dial_batch):
                    u_input, u_input_np, z_input, m_input, m_input_np, p_input, p_input_np, u_len, \
                    m_len, p_len, degree_input, supervised, kw_ret \
                        = self._convert_batch(turn_batch, prev_z)
                    m_idx, z_idx, turn_states = self.m(mode=mode, u_input=u_input, u_len=u_len, z_input=z_input,
                                                       m_input=m_input,
                                                       degree_input=degree_input, u_input_np=u_input_np,
                                                       m_input_np=m_input_np,
                                                       p_input=p_input, p_input_np=p_input_np, p_len=p_len,
                                                       m_len=m_len, z_supervised=None, turn_states=turn_states,
                                                       **kw_ret)
                    if not cfg.last_turn_only or turn_num == len(dial_batch) - 1:
                        self.reader.wrap_result(turn_batch, m_idx, z_idx)
                    prev_z = z_idx
            ev = self.EV(result_path=cfg.result_path)
            res = ev.run_metrics()
        self.m.train()
        return res

    def validate(self, data='dev'):
        """Compute the average response loss (``m_loss``) on split ``data``.

        The first turn of every dialogue, and any turn earlier than the last
        ``cfg.max_turn`` turns, is skipped. With ``cfg.last_turn_only`` set,
        only the final turn contributes to the average. As a side effect this
        also runs ``self.eval()`` on the test split; its metrics are discarded.

        Returns:
            ``(sup_loss, unsup_loss)``. ``sup_loss`` is always 0 here because
            no supervised turns are counted.
        """
        self.m.eval()
        with torch.no_grad():
            data_iterator = self.reader.mini_batch_iterator(data)
            sup_loss, unsup_loss = 0, 0
            sup_cnt, unsup_cnt = 0, 0
            for d, dial_batch in enumerate(data_iterator):
                turn_states = {}
                for turn_num, turn_batch in enumerate(dial_batch):
                    if turn_num <= 0 or turn_num < len(dial_batch) - cfg.max_turn:
                        continue
                    u_input, u_input_np, z_input, m_input, m_input_np, p_input, p_input_np, u_len, \
                    m_len, p_len, degree_input, supervised, kw_ret \
                        = self._convert_batch(turn_batch)

                    loss, m_loss, p_loss, kl_div_loss, turn_states, _ = self.m(u_input=u_input, z_input=None,
                                                                               m_input=m_input,
                                                                               z_supervised=False,
                                                                               turn_states=turn_states,
                                                                               u_input_np=u_input_np,
                                                                               m_input_np=m_input_np,
                                                                               p_input=p_input, p_input_np=p_input_np,
                                                                               p_len=p_len,
                                                                               u_len=u_len, m_len=m_len, mode='train',
                                                                               degree_input=degree_input,
                                                                               turn_states_q={}, **kw_ret)
                    if not cfg.last_turn_only or turn_num == len(dial_batch) - 1:
                        unsup_loss += m_loss.item()
                        unsup_cnt += 1
                    logging.debug(
                        'unsupervised loss:{} m_loss:{} p_loss:{} kl_div_loss:{}'.format(loss.item(), m_loss.item(),
                                                                                         p_loss.item(),
                                                                                         kl_div_loss.item()))
                    for k in turn_states:
                        turn_states[k] = turn_states[k].detach()

            sup_loss /= (sup_cnt + 1e-8)
            unsup_loss /= (unsup_cnt + 1e-8)
        self.m.train()
        # Run test decoding for its side effects (result file / metrics); the
        # returned metrics are not used here.
        self.eval()
        return sup_loss, unsup_loss

    def save_model(self, epoch, path=None):
        """Save network weights, the ``cfg`` attribute dict and ``epoch`` to ``path``.

        ``path`` defaults to ``cfg.model_path``.
        """
        if not path:
            path = cfg.model_path
        all_state = {'sedst': self.m.state_dict(),
                     'config': cfg.__dict__,
                     'epoch': epoch}
        with open(path, 'wb') as f:
            torch.save(all_state, f)

    def load_model(self, path=None):
        """Load weights saved by :meth:`save_model` (non-strict) and restore ``base_epoch``.

        Tensors are first mapped to CPU so that GPU checkpoints also load on
        CPU-only hosts; ``load_state_dict`` then copies them onto the device of
        the existing parameters. ``path`` defaults to ``cfg.model_path``.
        """
        if not path:
            path = cfg.model_path
        with open(path, 'rb') as f:
            all_state = torch.load(f, map_location='cpu')
        self.m.load_state_dict(all_state['sedst'], strict=False)
        self.base_epoch = all_state.get('epoch', 0)

    def freeze_module(self, module):
        """Set ``requires_grad = False`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def unfreeze_module(self, module):
        """Set ``requires_grad = True`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = True

    def load_glove_embedding(self, freeze=False):
        """Initialise the network's embedding-sized weights from a GloVe matrix.

        The matrix comes from ``get_glove_matrix(vocab, initial_arr)``, which is
        given the current ``u_encoder`` embedding (presumably used for words
        without a GloVe vector -- confirm in reader). It is copied into the
        ``u_encoder``, ``p_encoder``, ``m_decoder`` and ``p_decoder``
        embeddings, and transposed into ``qz_decoder.mu`` / ``pz_decoder.mu``.

        Args:
            freeze: if True, stop gradients for some of the embeddings.
        """
        initial_arr = self.m.u_encoder.embedding.weight.data.cpu().numpy()
        mat = get_glove_matrix(self.reader.vocab, initial_arr)
        embedding_arr = torch.from_numpy(mat)

        self.m.u_encoder.embedding.weight.data.copy_(embedding_arr)
        self.m.p_encoder.embedding.weight.data.copy_(embedding_arr)
        self.m.m_decoder.emb.weight.data.copy_(embedding_arr)
        self.m.p_decoder.emb.weight.data.copy_(embedding_arr)
        self.m.qz_decoder.mu.weight.data.copy_(embedding_arr.transpose(1, 0))
        self.m.pz_decoder.mu.weight.data.copy_(embedding_arr.transpose(1, 0))
        if freeze:
            # NOTE(review): `m_e` is not among the modules initialised above
            # (p_encoder / p_decoder are not frozen) -- confirm this attribute
            # exists on the network and is the intended one.
            self.freeze_module(self.m.u_encoder.embedding)
            self.freeze_module(self.m.m_e.embedding)
            self.freeze_module(self.m.m_decoder.emb)

    def count_params(self):
        """Print the total number of trainable parameters of the network."""
        module_parameters = filter(lambda p: p.requires_grad, self.m.parameters())
        param_cnt = sum([np.prod(p.size()) for p in module_parameters])

        print('total trainable params: %d' % param_cnt)


def main():
    """Command-line entry point.

    Flags:
        -mode     ``train`` (GloVe init, then training), ``adjust`` (load
                  ``cfg.model_path`` and continue training) or ``test`` (load
                  and evaluate). Any other value only builds the model.
        -dataset  passed to ``cfg.init_handler``; the part after the last ``-``
                  selects the reader / evaluator.
        -cfg      optional ``key=value`` overrides of existing ``cfg``
                  attributes. Each value is cast to the type of the attribute's
                  current value. Booleans accept ``false``/``0``/``no``/``off``
                  (case-insensitive) as False; anything else is True.

    Raises:
        ValueError: for a malformed override, or an override of an attribute
            whose current value is None (its type cannot be inferred).
    """
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    random.seed(1)
    np.random.seed(1)

    parser = argparse.ArgumentParser()
    parser.add_argument('-mode')
    parser.add_argument('-dataset')
    parser.add_argument('-cfg', nargs='*')
    args = parser.parse_args()

    cfg.init_handler(args.dataset)

    for pair in args.cfg or []:
        # Split on the first '=' only, so values may themselves contain '='.
        k, sep, v = pair.partition('=')
        if not sep:
            raise ValueError('cfg override must be key=value, got %r' % pair)
        current = getattr(cfg, k)
        if current is None:
            raise ValueError('cannot override cfg.%s: its current value is None, so its type is unknown' % k)
        if isinstance(current, bool):
            # bool('False') is True, so parse explicitly (case-insensitive).
            v = v.strip().lower() not in ('false', '0', 'no', 'off')
        else:
            v = type(current)(v)
        setattr(cfg, k, v)

    logging.debug(str(cfg))
    if cfg.cuda:
        torch.cuda.set_device(cfg.cuda_device)
        logging.debug('Device: {}'.format(torch.cuda.current_device()))
    cfg.mode = args.mode
    m = Model(args.dataset.split('-')[-1])
    m.count_params()
    if args.mode == 'train':
        m.load_glove_embedding()
        m.train()
    elif args.mode == 'adjust':
        m.load_model()
        m.train()
    elif args.mode == 'test':
        m.load_model()
        m.eval()
    else:
        logging.warning('unknown mode %r; nothing to do', args.mode)


if __name__ == '__main__':
    # Script entry point: run the CLI only when executed directly, not on import.
    main()