import os
import sys
import argparse
import random
import pickle

sys.path.append(os.getcwd())

import numpy as np
import torch
from tqdm import tqdm

import src.models.models as models
import src.data.data as data
import src.data.config as cfg
import src.models.utils as model_utils
import utils.utils as utils

from src.data.utils import TextEncoder

parser = argparse.ArgumentParser()
parser.add_argument("--generation_set_size", type=str, default='full', choices=["full", "human"])
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--split", type=str, default="dev")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--experiment_num", type=str, default="0")
parser.add_argument("--model_name", type=str, default="models/atomic-generation/iteration-500-50000/transformer/categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1/model_transformer-nL_12-nH_12-hSize_768-edpt_0.1-adpt_0.1-rdpt_0.1-odpt_0.1-pt_gpt-afn_gelu-init_pt-vSize_40542/exp_generation-seed_123-l2_0.01-vl2_T-lrsched_warmup_linear-lrwarm_0.002-clip_1-loss_nll-b2_0.999-b1_0.9-e_1e-08/bs_1-smax_40-sample_greedy-numseq_1-gs_1000-es_1000-categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant/6.25e-05_adam_64_22000.pickle")
parser.add_argument("--gen_len", type=int, default=100)

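# Example invocation (the script's path and name depend on the repo layout):
#   python generate_atomic_greedy.py --experiment_num 0 --split dev --device 0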
args = parser.parse_args()
split = args.split

# Generate configuration files depending on experiment being run
utils.generate_config_files("atomic", args.experiment_num, eval_mode=True)

# Loads the correct configuration file
config_file = "config/atomic/config_{}.json".format(args.experiment_num)

# Read the config file into an options object
config = cfg.read_config(cfg.load_config(config_file))
cfg.device = args.device

eval_opt = cfg.get_eval_parameters(config)

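# The checkpoint pickle stores both the trained weights ("state_dict")
# and the options the model was trained with ("opt")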
model_stuff = data.load_checkpoint(args.model_name)

opt = model_stuff["opt"]
opt.eval.update(eval_opt)

# Set the random seeds
torch.manual_seed(opt.train.static.seed)
random.seed(opt.train.static.seed)
if config.gpu_mode:
    torch.cuda.manual_seed_all(opt.train.static.seed)

opt.train.dynamic.epoch = 0

print("Loading Data")

categories = opt.data.categories

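# Path to the preprocessed generation data; the file name is derived
# from the data options, with the keep-ratio component pinned to kr_1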
path = "data/atomic/processed/generation/{}.pickle".format(
    utils.make_name_string(opt.data).replace(
        "kr_{}".format(opt.data.get("kr", 1)), "kr_1"))
data_loader = data.make_data_loader(opt, categories)
loaded = data_loader.load_data(path)

data_loader.batch_size = opt.train.dynamic.bs

print("Done.")

text_encoder = TextEncoder(config.encoder_path, config.bpe_path)

special = [data.start_token, data.end_token]
special += ["<{}>".format(cat) for cat in categories]

special += [data.blank_token]

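# Swap in the vocabulary saved by the data loader so token ids match
# the ones the model was trained with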
text_encoder.encoder = data_loader.vocab_encoder
text_encoder.decoder = data_loader.vocab_decoder


context_size_event = data_loader.max_event
context_size_effect = data_loader.max_effect

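# The model's effective vocabulary includes one extra slot per context
# position for the learned positional embeddings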
n_special = len(special)
n_ctx = context_size_event + context_size_effect
n_vocab = len(text_encoder.encoder) + n_ctx

opt.net.vSize = n_vocab

print("Building Model")

print(opt.exp)

model = models.make_model(
    opt, n_vocab, n_ctx, 0, load=False, return_acts=False, return_probs=True)

models.load_state_dict(model, model_stuff["state_dict"])

if config.gpu_mode:
    print("Pushing to GPU: {}".format(config.gpu_index))
    cfg.device = config.gpu_index
    cfg.do_gpu = True
    torch.cuda.set_device(cfg.device)
    print("Done.")

# Move the model to the target device once and switch to eval mode
device = cfg.device
model.to(device)
model.eval()

# Re-seed with the user-specified seed so generation is reproducible
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

lm_model = model

def make_batch(X):
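    """Pair each token id with a positional index (offset past the vocab
    and special tokens) and return a (batch, seq, 2) long tensor on the
    target device. Note: the main loop below builds its inputs through
    the data loader, so this helper goes unused here."""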
    X = np.array(X)
    assert X.ndim in [1, 2]
    if X.ndim == 1:
        X = np.expand_dims(X, axis=0)
    pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])
    pos_enc = np.expand_dims(pos_enc, axis=0)
    batch = np.stack([X, pos_enc], axis=-1)
    batch = torch.tensor(batch, dtype=torch.long).to(device)
    return batch


def append_batch(X, next_idx, mask):
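    """Append the next token id (with an incremented position index) to
    the running input X and extend the attention mask by one step."""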
    next_pos = X[:, -1:, 1] + 1
    next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
    next_mask = torch.cat([mask, torch.ones(X.size(0), 1, device=mask.device)], 1)
    return torch.cat((X, next_x), 1), next_mask


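# Iterate over the chosen split in a fixed order for reproducibility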
data_loader.reset_offsets(splits=split, shuffle=False)

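# Select the example indices to generate for, based on --generation_set_size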
if args.generation_set_size == "full":
    # Deduplicate on the (event, relation) prefix so each unique
    # context is generated exactly once
    b = [tuple(j) for j in data_loader.sequences[split]['total'][:, :data_loader.max_event + 1].tolist()]
    total = []
    set_total = set()
    for i, sequence in enumerate(b):
        if sequence not in set_total:
            total.append(i)
            set_total.add(sequence)
elif args.generation_set_size == "human":
    # Restrict generation to the events reserved for human evaluation
    with open("data/atomic/{}-human-eval-events.txt".format(split)) as f:
        human_events = set(f.read().split("\n"))
    found = set()
    total = []
    for i, j in enumerate(data_loader.data[split]["total"]):
        if j[0] in human_events and (j[0], j[1]) not in found:
            found.add((j[0], j[1]))
            total.append(i)
else:
    # Otherwise, interpret the argument as an integer number of examples
    total = list(range(int(args.generation_set_size)))

# This script always decodes greedily (argmax at every step)
args.decoding_strategy = "greedy"

final_sequences = []

end_token = text_encoder.encoder[data.end_token]

# Derive the output path from the checkpoint path: swap in the actual
# generation set size (the checkpoint name contains "gs_1000"), strip
# the ".pickle" extension, and redirect from models/ to results/gens/
eval_file_name = args.model_name.replace("gs_1000", "gs_{}".format(
    args.generation_set_size))
eval_file_name = eval_file_name[:-7] + "/{}.gens".format(split)
eval_file_name = eval_file_name.replace("models/", "results/gens/")

print("Saving generations to: {}".format(eval_file_name))

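# Main loop: greedily decode one effect sequence per selected example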
with torch.no_grad():
    for idx in tqdm(total):
        sequence_all = {}

        batch, reset = data_loader.sample_batch(split=split, bs=1, idxs=[idx])

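        # Split the sampled sequence into the event + relation context
        # (XMB, with attention mask MMB) and the reference effect (Ref)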
        XMB = batch["sequences"][:, :context_size_event + 1]
        Ref = batch["sequences"][:, context_size_event + 1:]
        MMB = batch["attention_mask"][:, :context_size_event + 1]

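        # Decode the event text (blanks shown as "___") and the relation
        # name for the output record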
        init = "".join([text_encoder.decoder[i].replace('</w>', ' ').replace(
                "<blank>", "___ ") for i in XMB[:, :-1].squeeze().tolist() if i])
        attr = text_encoder.decoder[XMB[:, -1].item()].strip("<>")

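        # Attach positional indices; the model consumes (token, position)
        # pairs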
        XMB = model_utils.prepare_position_embeddings(
            opt, text_encoder.encoder, XMB.unsqueeze(-1))

        sequence_all["event"] = init
        sequence_all["effect_type"] = attr

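        # First decoding step: greedily pick the highest-probability
        # token after the context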
        lm_probs = lm_model(XMB.unsqueeze(1), sequence_mask=MMB)

        values, indices = lm_probs[:, -1, :].max(dim=-1)
        seqs = indices.clone().unsqueeze(0)

        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((indices.view(1, -1), next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat([MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

        # Greedy decoding loop: repeatedly append the argmax token until
        # the end token appears or the maximum effect length is reached
        for step in range(args.gen_len):
            lm_probs = lm_model(XMB.unsqueeze(1), sequence_mask=MMB)

            # Greedy: take the single highest-probability next token
            values, next_idx = lm_probs[:, -1, :].max(dim=-1)

            next_idx = next_idx.unsqueeze(1)

            seqs = torch.cat([seqs, next_idx], 1)

            if (next_idx.item() == end_token) or step == context_size_effect - 1:
                break

            XMB, MMB = append_batch(XMB, next_idx, MMB)

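        # Convert the generated ids back to text, dropping end tokens
        # and </w> word-boundary markers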
        beams = []

        for beam in seqs:
            beams.append(" ".join("".join(
                [text_encoder.decoder[tok.item()].replace(
                    '</w>', ' ').replace('\n', '')
                 for tok in beam if tok != end_token]).split()))

        sequence_all['beams'] = beams
        final_sequences.append(sequence_all)

# Make sure the output directory exists before writing
utils.mkpath("/".join(eval_file_name.split("/")[:-1]))

with open(eval_file_name, "wb") as f:
    pickle.dump(final_sequences, f)