import argparse
import numpy as np
import random
from tqdm import tqdm
import torch
import torch.nn as nn
import os, sys
parentPath = os.path.abspath("..")
sys.path.insert(0, parentPath)  # add the parent folder to the path so common modules can be imported
from helper import indexes2sent
import models, data, configs
from metrics import Metrics
from data_loader import APIDataset, load_dict, load_vecs


def evaluate(model, metrics, test_loader, vocab_desc, vocab_api, repeat, decode_mode, f_eval):
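    """Decode `repeat` candidate API sequences per test query (using `decode_mode`),
    score them against the reference with BLEU, and log queries, references, and
    samples to `f_eval`. Returns the averaged recall and precision BLEU."""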
    device = next(model.parameters()).device
    
    recall_bleus, prec_bleus = [], []
    local_t = 0
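    # the test loader is built with batch_size=1, so each batch holds a single query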
    for descs, apiseqs, desc_lens, api_lens in tqdm(test_loader):
        
        if local_t > 1000:  # cap evaluation at the first 1000 test queries
            break
        
        desc_str, _ = indexes2sent(descs[0].numpy(), vocab_desc)
        
        descs, desc_lens = [tensor.to(device) for tensor in [descs, desc_lens]]
        sample_words, sample_lens = model.sample(descs, desc_lens, repeat, decode_mode)
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, vocab_api)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        ref_str, _ = indexes2sent(apiseqs[0].numpy(), vocab_api)
        ref_tokens = ref_str.split(' ')
        
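        # score each sampled candidate against the single reference: the best
        # BLEU across candidates is treated as recall, the average as precision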
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)
        
        local_t += 1
        f_eval.write("Batch %d \n" % (local_t))  # log the query context
        f_eval.write(f"Query: {desc_str} \n")
        f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))  # log the ground-truth API sequence
        for r_id, pred_sent in enumerate(pred_sents):
            f_eval.write("Sample %d >> %s\n" % (r_id, pred_sent.replace(" ' ", "'")))
        f_eval.write("\n")

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 1e-12)  # epsilon guards against division by zero
    
    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f" % (recall_bleu, prec_bleu, f1)
    print(report)
    f_eval.write(report + "\n")
    print("Done testing")
    
    return recall_bleu, prec_bleu

def main(args):
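    """Load the test data and vocabularies, restore the requested checkpoint, and run evaluation."""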
    conf = getattr(configs, 'config_'+args.model)()
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    else:
        print("Note that our pre-trained models require CUDA to evaluate.")
    
    # Load data
    test_set = APIDataset(args.data_path + 'test.desc.h5', args.data_path + 'test.apiseq.h5', conf['max_sent_len'])
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1, shuffle=False, num_workers=1)
    vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
    vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
    metrics = Metrics()
    
    # Load the model and restore the checkpoint weights
    model = getattr(models, args.model)(conf)
    ckpt = f'./output/{args.model}/{args.expname}/{args.timestamp}/models/model_epo{args.reload_from}.pkl'
    model.load_state_dict(torch.load(ckpt))
    model.eval()  # disable dropout and other training-only behavior
    if torch.cuda.is_available():
        model = model.cuda()  # evaluate() infers its device from the model's parameters

    f_eval = open(f"./output/{args.model}/{args.expname}/results.txt", "w")
    
    evaluate(model, metrics, test_loader, vocab_desc, vocab_api, args.n_samples, args.decode_mode, f_eval)
    f_eval.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch DeepAPI for Eval')
    parser.add_argument('--data_path', type=str, default='./data/', help='location of the data corpus')
    parser.add_argument('--model', type=str, default='RNNSeq2Seq', help='model name')
    parser.add_argument('--expname', type=str, default='basic', help='experiment name, distinguishing different parameter settings')
    parser.add_argument('--timestamp', type=str, default='201909270147', help='time stamp')
    parser.add_argument('--reload_from', type=int, default=10000, help='epoch number of the checkpoint to load')
    
    parser.add_argument('--n_samples', type=int, default=10, help='number of API sequences to sample per query')
    parser.add_argument('--decode_mode', type=str, default='sample',
                        help='decoding mode for generation: beamsearch, greedy or sample')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    args = parser.parse_args()
    print(vars(args))
    main(args)
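
# Example invocation (script name assumed; paths follow the ./output/<model>/<expname>/<timestamp>/
# layout that the checkpoint and results paths above expect):
#   python eval.py --model RNNSeq2Seq --expname basic --timestamp 201909270147 \
#                  --reload_from 10000 --decode_mode sample --n_samples 10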