"""Evaluate a language model."""

import argparse
import time, os
import math
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import _pickle as pickle

import data
import model

from utils import batchify, get_batch, repackage_hidden

# Command-line interface. Flag names, defaults and actions are unchanged; only
# help strings that contradicted the actual behaviour have been corrected.
parser = argparse.ArgumentParser(description='Evaluate Language Model')
parser.add_argument('--data', type=str, default='data/penn/',
                    help='location of the data corpus')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
# NOTE: store_false means args.cuda defaults to True and passing --cuda turns
# CUDA *off*. The flag semantics are kept for backward compatibility.
parser.add_argument('--cuda', action='store_false',
                    help='disable CUDA (CUDA is used by default; '
                         'passing this flag switches to CPU)')
parser.add_argument('--start_token', type=int, default=1000,
                    help='token position where loss calculation starts '
                         '(earlier tokens are only used as context)')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
                    help='model checkpoint to use')
parser.add_argument('--logfile', type=str, default='./',
                    help='file to write the pickled sequence length and '
                         'per-token losses to')
parser.add_argument('--use_test', action='store_true', default=False,
                    help='Run on test set')
args = parser.parse_args()
print(args)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    else:
        # --cuda is declared with action='store_false', so args.cuda is False
        # only when the user explicitly passed --cuda. The advice therefore
        # has to be to drop the flag, not to add it.
        print("WARNING: You have a CUDA device, so you should probably run without --cuda")

# Build the corpus from the directory given by --data.
corpus = data.Corpus(args.data)
print('Built corpus')

# Evaluation is done one sequence at a time, in windows of seq_len steps.
seq_len = 100
eval_batch_size = 1

# Choose the split to score: test when --use_test is given, validation otherwise.
if args.use_test:
    split_banner, split_tokens = 'On test set!!!', corpus.test
else:
    split_banner, split_tokens = 'On validation set!!!', corpus.valid
print(split_banner)
data_ = batchify(split_tokens, eval_batch_size, args)

# Load the serialized model and place it on the requested device.
# NOTE: this rebinds the module-level name `model` (previously the imported
# `model` module) to the loaded network; evaluate() reads this global.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints from a trusted source.
print('Load model from %s' % args.checkpoint)
start = time.time()
if args.cuda:
    model = torch.load(args.checkpoint)
else:
    # Remap every storage to CPU while loading so that a checkpoint saved from
    # a GPU can still be opened on a machine without CUDA.
    model = torch.load(args.checkpoint,
                       map_location=lambda storage, loc: storage)
print('[%.1f s]' % (time.time() - start))

if args.cuda:
    model.cuda()
else:
    model.cpu()

#criterion = nn.CrossEntropyLoss()

def evaluate(data_source, batch_size, seq_len):
    """Score ``data_source`` with the global ``model`` and return per-token losses.

    The data is consumed in windows of ``seq_len`` steps. Tokens before
    position ``args.start_token`` (1-based) are fed through the model to build
    up the hidden state but do not contribute to the loss; the window that
    contains ``args.start_token`` is trimmed so scoring starts exactly there.

    Args:
        data_source: batchified token tensor; only ``size(0)`` is read here,
            the rest is delegated to ``get_batch``.
        batch_size: batch size used to initialise the hidden state and to
            reshape the targets.
        seq_len: number of steps requested from ``get_batch`` per window.

    Returns:
        A tuple ``(mean_loss, losses)`` where ``mean_loss`` is the average
        negative log-likelihood over the scored tokens (a Python float) and
        ``losses`` is the list of individual losses in order.

    Raises:
        ValueError: if no token was scored (empty data, or
            ``args.start_token`` lies beyond the end of the data).

    Relies on the module-level ``model``, ``corpus`` and ``args``.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss = 0.0
    tokens = 0   # tokens consumed so far, i.e. 1-based position of the last one
    n = 0        # tokens whose loss has been accumulated
    save_all_losses = []
    data = None  # stays None when the loop body never runs

    hidden = model.init_hidden(batch_size)

    for i in range(0, data_source.size(0) - 1, seq_len):
        data, targets = get_batch(data_source, i, args, evaluation=True, seq_len=seq_len)
        # Count the steps actually returned: the final window may be shorter
        # than seq_len, and over-counting would push the trim index below out
        # of range.
        steps = data.size(0)
        tokens += steps

        output, hidden = model(data, hidden)
        # Detach the hidden state right away (also during the warm-up windows)
        # so no history is carried from one window to the next.
        hidden = repackage_hidden(hidden)

        # Log-probabilities over the last axis. The original permuted that
        # axis to the front and relied on the deprecated implicit dim (0 for
        # 3-D input); naming the axis is equivalent and warning-free.
        # Assumes output is 3-D with the vocabulary on axis 2, as the gather
        # below requires.
        output = nn.functional.log_softmax(output, dim=2)
        targets = targets.view(steps, batch_size, -1)
        CELoss = -torch.gather(output.data, dim=2, index=targets.data)
        # Drop the gathered axis and a singleton batch axis explicitly. A bare
        # squeeze() would also remove the step axis of a one-token window and
        # break the slicing / list concatenation below.
        CELoss = CELoss.squeeze(2).squeeze(1)

        if tokens < args.start_token:
            continue  # We are not ready to accumulate error yet
        if tokens - steps < args.start_token:
            # This window contains start_token: keep only the tail beginning
            # at that token (1 <= keep <= steps).
            keep = tokens - args.start_token + 1
            CELoss = CELoss[-keep:]
            print('First word: %s' % (corpus.dictionary.idx2word[int(data.data[-keep, 0])]))

        # float() works for both a Python number and a 0-dim tensor.
        total_loss += float(torch.sum(CELoss))
        # Count every scored token so the mean is also right for batch_size > 1.
        n += CELoss.numel()
        save_all_losses += CELoss.tolist()

    if n == 0:
        raise ValueError(
            'No tokens were scored: consumed %d token(s) but --start_token is %d'
            % (tokens, args.start_token))

    print('total: %d' % n)
    print('Last word: %s' % (corpus.dictionary.idx2word[int(data.data[-1, 0])]))
    return total_loss / float(n), save_all_losses

# Resolve the output path before the (slow) evaluation. The default --logfile
# is './', a directory, and open(..., 'wb') on a directory would only fail
# after all the work has been done.
logfile = args.logfile
if os.path.isdir(logfile):
    logfile = os.path.join(logfile, 'per_token_losses.pkl')
    print('--logfile is a directory; writing per-token losses to %s' % logfile)

loss, all_losses = evaluate(data_, eval_batch_size, seq_len)

print('=' * 89)
print('| loss {:5.5f} | ppl {:8.5f}'.format(
    loss, math.exp(loss)))
print('=' * 89)

# Two consecutive pickles: first the window length, then the list of losses.
with open(logfile, 'wb') as f:
    pickle.dump(seq_len, f)
    pickle.dump(all_losses, f)