from __future__ import print_function
import datetime
import time
import torch
import torch.nn as nn
import torch.optim as optim
import codecs
import pickle
import math

from model_word_ada.LM import LM
from model_word_ada.basic import BasicRNN
from model_word_ada.ddnet import DDRNN
from model_word_ada.radam import RAdam
from model_word_ada.ldnet import LDRNN
from model_word_ada.densenet import DenseRNN
from model_word_ada.dataset import LargeDataset, EvalDataset
from model_word_ada.adaptive import AdaptiveSoftmax
import model_word_ada.utils as utils

# from tensorboardX import SummaryWriter
# writer = SummaryWriter(logdir='./cps/gadam/log_1bw_full/')

import argparse
import json
import os
import sys
import itertools
import functools

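# Compute perplexity on an evaluation set: accumulate token-level NLL from the
# model's log-probabilities and exponentiate the average. `limited` caps the
# number of tokens evaluated; pass a negative value to run over the full set.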
def evaluate(data_loader, lm_model, criterion, limited = 76800):
    print('evaluating')
    lm_model.eval()

    iterator = data_loader.get_tqdm()

    lm_model.init_hidden()
    total_loss = 0
    total_len = 0
    # no gradients are needed during evaluation, so disable autograd tracking
    with torch.no_grad():
        for word_t, label_t in iterator:
            label_t = label_t.view(-1)
            tmp_len = label_t.size(0)
            output = lm_model.log_prob(word_t)
            total_loss += tmp_len * utils.to_scalar(criterion(output, label_t))
            total_len += tmp_len

            if limited >= 0 and total_len > limited:
                break

    ppl = math.exp(total_loss / total_len)
    print('PPL: ' + str(ppl))

    return ppl

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_folder', default='/data/billionwords/one_billion/')
    parser.add_argument('--load_checkpoint', default='')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--sequence_length', type=int, default=20)
    parser.add_argument('--hid_dim', type=int, default=2048)
    parser.add_argument('--word_dim', type=int, default=300)
    parser.add_argument('--label_dim', type=int, default=-1)
    parser.add_argument('--layer_num', type=int, default=2)
    parser.add_argument('--droprate', type=float, default=0.1)
    parser.add_argument('--add_relu', action='store_true')
    parser.add_argument('--layer_drop', type=float, default=0.5)
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--epoch', type=int, default=14)
    parser.add_argument('--clip', type=float, default=5)
    parser.add_argument('--update', choices=['Adam', 'Adagrad', 'Adadelta', 'RAdam', 'SGD'], default='Adam')
    parser.add_argument('--rnn_layer', choices=['Basic', 'DDNet', 'DenseNet', 'LDNet'], default='Basic')
    parser.add_argument('--rnn_unit', choices=['gru', 'lstm', 'rnn', 'bnlstm'], default='lstm')
    parser.add_argument('--lr', type=float, default=-1)
    parser.add_argument('--lr_decay', type=lambda t: [int(tup) for tup in t.split(',')], default=[8])
    parser.add_argument('--cut_off', nargs='+', type=int, default=[4000, 40000, 200000])
    parser.add_argument('--interval', type=int, default=100)
    parser.add_argument('--check_interval', type=int, default=4000)
    parser.add_argument('--checkpath', default='./cps/gadam/')
    parser.add_argument('--model_name', default='adam')
    args = parser.parse_args()

    if args.gpu >= 0:
        torch.cuda.set_device(args.gpu)
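
    # make sure the checkpoint directory exists so the torch.save calls later do not fail
    if not os.path.exists(args.checkpath):
        os.makedirs(args.checkpath)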

    print('loading dataset')
    with open(os.path.join(args.dataset_folder, 'test.pk'), 'rb') as f:
        dataset = pickle.load(f)
    w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range']
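    # (w_map: word-to-index vocabulary; test_data: the indexed evaluation corpus;
    #  range_idx: shard/range bookkeeping consumed by LargeDataset below)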

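    # Adaptive-softmax cut-offs: frequent words go to the head, the rest into
    # tail clusters; the final cut-off is extended past the vocabulary size so
    # every word index is covered.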
    cut_off = args.cut_off + [len(w_map) + 1]

    train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length)
    test_loader = EvalDataset(test_data, args.batch_size)

    print('building model')

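    # Pick the recurrent architecture; LDNet additionally takes a per-layer
    # drop probability (--layer_drop).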
    rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = args.layer_drop)}
    rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)

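    # When --label_dim is set, the adaptive softmax is built on label_dim rather
    # than the RNN output size (the LM module is then expected to bridge the
    # two); otherwise it consumes the RNN output directly.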
    if args.label_dim > 0:
        soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
    else:
        soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)

    lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu=args.add_relu)
    lm_model.rand_ini()
    # lm_model.cuda()

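    # Optimizer choices: SGD runs with momentum 0.9, and RAdam is the rectified
    # Adam implementation from model_word_ada.radam. When --lr is not set
    # (i.e. <= 0), each optimizer falls back to its own default learning rate.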
    optim_map = {'Adam' : optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta, 'RAdam': RAdam, 'SGD': functools.partial(optim.SGD, momentum=0.9)}
    if args.lr > 0:
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr)
    else:
        optimizer = optim_map[args.update](lm_model.parameters())

    if args.load_checkpoint:
        if os.path.isfile(args.load_checkpoint):
            print("loading checkpoint: '{}'".format(args.load_checkpoint))
            checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage)
            lm_model.load_state_dict(checkpoint_file['lm_model'], strict=False)
            # Optimizer.load_state_dict takes only the state dict; the extra flag would raise a TypeError
            optimizer.load_state_dict(checkpoint_file['opt'])
        else:
            print("no checkpoint found at: '{}'".format(args.load_checkpoint))

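    # lm_model.log_prob returns log-probabilities, so NLLLoss is the matching
    # evaluation criterion; the .cuda() calls below assume a CUDA device is
    # available (consistent with the --gpu handling above)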
    test_lm = nn.NLLLoss()
    
    test_lm.cuda()
    lm_model.cuda()
    
    batch_index = 0
    epoch_loss = 0
    full_epoch_loss = 0
    best_train_ppl = float('inf')
    cur_lr = args.lr

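    # Main training loop. Each epoch streams batches from LargeDataset, clips
    # gradients to --clip, prints a running perplexity every --check_interval
    # batches, optionally decays the learning rate, then evaluates on the test
    # set and saves a checkpoint. A KeyboardInterrupt triggers a final
    # evaluation and save before exiting.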
    try:
        for indexs in range(args.epoch):

            print('#' * 89)
            print('Start: {}'.format(indexs))
            
            iterator = train_loader.get_tqdm()
            full_epoch_loss = 0

            lm_model.train()

            for word_t, label_t in iterator:

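                # cur_idx resets to 1 when the loader starts a new shard/file;
                # re-initialize the hidden state so context is not carried across shards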
                if 1 == train_loader.cur_idx:
                    lm_model.init_hidden()

                label_t = label_t.view(-1)

                lm_model.zero_grad()
                loss = lm_model(word_t, label_t)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(lm_model.parameters(), args.clip)
                optimizer.step()

                batch_index += 1 

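                # every --interval batches the raw loss would be logged to TensorBoard
                # (the writer calls are currently commented out); every --check_interval
                # batches the running training perplexity is printed and the accumulator reset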
                if 0 == batch_index % args.interval:
                    s_loss = utils.to_scalar(loss)
                    # writer.add_scalars('loss_tracking/train_loss', {args.model_name:s_loss}, batch_index)
                
                epoch_loss += utils.to_scalar(loss)
                full_epoch_loss += utils.to_scalar(loss)
                if 0 == batch_index % args.check_interval:
                    epoch_ppl = math.exp(epoch_loss / args.check_interval)
                    # writer.add_scalars('loss_tracking/train_ppl', {args.model_name: epoch_ppl}, batch_index)
                    print('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index))
                    epoch_loss = 0

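            # multiply the learning rate by 0.1 at each epoch listed in --lr_decay;
            # this only takes effect when an explicit --lr was given, since cur_lr
            # starts from the default of -1 otherwise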
            if indexs in args.lr_decay and cur_lr > 0:
                cur_lr *= 0.1
                print('adjust_learning_rate...')
                utils.adjust_learning_rate(optimizer, cur_lr)
    
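            # evaluate on the full held-out set (limited=-1) and save a checkpoint after every epoch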
            test_ppl = evaluate(test_loader, lm_model, test_lm, -1)

            # writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, indexs)
            print('test_ppl: {} @ index: {}'.format(test_ppl, indexs))
        
            torch.save({'lm_model': lm_model.state_dict(), 'opt': optimizer.state_dict()}, os.path.join(args.checkpath, args.model_name + '.model'))

    except KeyboardInterrupt:

        print('Exiting from training early')
        test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
        # writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, args.epoch)

        torch.save({'lm_model': lm_model.state_dict(), 'opt': optimizer.state_dict()}, os.path.join(args.checkpath, args.model_name + '.model'))

    # writer.close()