import src.data.utils as data_utils
import src.data.atomic as adata
import src.data.config as cfg

import torch
import random
from tqdm import tqdm


def map_name(name, opt):
    """Return the data file name used for the split called ``name``.

    Args:
        name: Split identifier. ``"train"`` and ``"test"`` are recognised;
            every other value is treated as the development split.
        opt: Options object. ``opt.trainsize`` is read for the train split
            and ``opt.devversion`` for the development split.

    Returns:
        The file name (without directory) as a string.
    """
    if name == "test":
        return "test.txt"

    if name == "train":
        train_template = "train{}k.txt"
        return train_template.format(opt.trainsize)

    # Anything that is neither train nor test falls back to the dev file.
    dev_template = "dev{}.txt"
    return dev_template.format(opt.devversion)


# Relation names, listed in alphabetical order. The same names are the keys
# of the `split_into_words` mapping defined in this module.
conceptnet_relations = [
    'AtLocation', 'CapableOf', 'Causes', 'CausesDesire',
    'CreatedBy', 'DefinedAs', 'DesireOf', 'Desires', 'HasA',
    'HasFirstSubevent', 'HasLastSubevent', 'HasPainCharacter',
    'HasPainIntensity', 'HasPrerequisite', 'HasProperty',
    'HasSubevent', 'InheritsFrom', 'InstanceOf', 'IsA',
    'LocatedNear', 'LocationOfAction', 'MadeOf', 'MotivatedByGoal',
    'NotCapableOf', 'NotDesires', 'NotHasA', 'NotHasProperty',
    'NotIsA', 'NotMadeOf', 'PartOf', 'ReceivesAction', 'RelatedTo',
    'SymbolOf', 'UsedFor'
]


# Maps each relation name to a lower-case, space-separated phrase. The loader
# in this module substitutes the phrase for the relation name when
# `opt.data.rel == "language"`.
split_into_words = {
    'AtLocation': "at location",
    'CapableOf': "capable of",
    'Causes': "causes",
    'CausesDesire': "causes desire",
    'CreatedBy': "created by",
    'DefinedAs': "defined as",
    'DesireOf': "desire of",
    'Desires': "desires",
    'HasA': "has a",
    'HasFirstSubevent': "has first subevent",
    'HasLastSubevent': "has last subevent",
    'HasPainCharacter': "has pain character",
    'HasPainIntensity': "has pain intensity",
    # NOTE(review): "prequisite" looks like a misspelling of "prerequisite".
    # It is runtime text that gets encoded, so changing it would change the
    # token ids produced for this relation -- confirm before correcting.
    'HasPrerequisite': "has prequisite",
    'HasProperty': "has property",
    'HasSubevent': "has subevent",
    'InheritsFrom': "inherits from",
    'InstanceOf': 'instance of',
    'IsA': "is a",
    'LocatedNear': "located near",
    'LocationOfAction': "location of action",
    'MadeOf': "made of",
    'MotivatedByGoal': "motivated by goal",
    'NotCapableOf': "not capable of",
    'NotDesires': "not desires",
    'NotHasA': "not has a",
    'NotHasProperty': "not has property",
    'NotIsA': "not is a",
    'NotMadeOf': "not made of",
    'PartOf': "part of",
    'ReceivesAction': "receives action",
    'RelatedTo': "related to",
    'SymbolOf': "symbol of",
    'UsedFor': "used for"
}


class GenerationDataLoader(adata.DataLoader):
    """Loads (e1, relation, e2, label) tuples and turns them into tensors.

    Each stored example is a 4-tuple ``(e1, relation, e2, label)``. Encoded
    examples are laid out in one fixed-width ``LongTensor`` row made of three
    zero-padded segments: ``max_e1`` columns for e1, ``max_r`` columns for
    the relation and ``max_e2`` columns for e2.

    NOTE(review): ``self.data``, ``self.offsets``, ``self.masks`` and
    ``self.sequences`` are used as ``{split: {category: ...}}`` mappings and
    are presumably created by ``adata.DataLoader.__init__`` -- confirm
    against the base class.
    """

    def __init__(self, opt, categories=None):
        """Reset every split to a single, empty ``"total"`` category.

        Args:
            opt: Options object; kept on ``self.opt`` and forwarded to the
                base class.
            categories: Unused here; kept for interface compatibility.
        """
        super(GenerationDataLoader, self).__init__(opt)
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        # All of these are filled in by make_tensors().
        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_e1 = None
        self.max_e2 = None
        self.max_r = None

    def offset_summary(self, split):
        """Return the sum of the offsets of every category of ``split``."""
        return sum(self.offsets[split].values())

    @staticmethod
    def _read_tuples(directory, file_name):
        """Read a tab-separated file; return one field list per non-empty line."""
        # The context manager closes the handle even if reading fails.
        with open("{}/{}".format(directory, file_name), "r") as handle:
            lines = handle.read().split("\n")
        return [line.split("\t") for line in lines if line]

    def load_data(self, path):
        """Populate ``self.data`` from ``path``.

        If ``path`` contains ``".pickle"`` a previously saved loader is
        restored through ``data_utils.load_existing_data_loader``. Otherwise
        ``path`` is treated as a directory holding one tab-separated file per
        split whose columns are: relation, e1, e2, label/score.

        For the ``dev`` and ``test`` splits the fourth column is converted to
        ``int`` and the examples are additionally stored under the
        ``"positive"`` (non-zero label) and ``"negative"`` (zero label)
        categories. For other splits the fourth column is kept as the raw
        string. When ``opt.data.devversion == "12"`` the dev split is the
        concatenation of ``dev1.txt`` and ``dev2.txt``.

        ``opt.data.rel`` selects how the relation is rendered:
        ``"language"`` uses the ``split_into_words`` phrase, ``"relation"``
        wraps the name as ``"<Name>"``. Any other value leaves the split's
        data untouched (the files are still read).

        Args:
            path: Pickle path or data directory.

        Returns:
            ``True`` when an existing loader was restored from a pickle,
            ``False`` when raw files were parsed.
        """
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)
            return True

        rel_format = self.opt.data.rel
        if rel_format == "language":
            render_relation = split_into_words.__getitem__
        elif rel_format == "relation":
            render_relation = "<{}>".format
        else:
            render_relation = None

        for split in self.data:
            if split == "dev" and self.opt.data.devversion == "12":
                tuples = (self._read_tuples(path, "dev1.txt") +
                          self._read_tuples(path, "dev2.txt"))
            else:
                tuples = self._read_tuples(
                    path, map_name(split, self.opt.data))

            if render_relation is None:
                continue

            # Only the evaluation splits carry a binary label.
            labeled = split in ["dev", "test"]

            examples = [
                (fields[1].lower().strip(),
                 render_relation(fields[0]),
                 fields[2].lower().strip(),
                 int(fields[3]) if labeled else fields[3])
                for fields in tuples
            ]

            self.data[split]["total"] = examples
            if labeled:
                self.data[split]["positive"] = [
                    example for example in examples if example[3]]
                self.data[split]["negative"] = [
                    example for example in examples if not example[3]]

        return False

    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        """Encode the loaded examples and store them as padded tensors.

        Args:
            text_encoder: Object exposing ``encoder`` / ``decoder``
                vocabularies; it is also handed to
                ``get_generation_sequences``.
            special: Stored on ``self.special_chars``.
            splits: Splits to encode.
            test: Forwarded to ``get_generation_sequences``.

        Side effects:
            Sets ``self.masks[split]["total"]`` to the per-example
            ``(len_e1, len_r, len_e2)`` lengths, sets ``self.max_e1``,
            ``self.max_r`` and ``self.max_e2``, and fills
            ``self.sequences[split]["total"]``. For ``dev`` / ``test`` it also
            fills the ``"negative"`` and ``"positive"`` tensors. For
            ``train`` the discarded examples are removed from
            ``self.data``.
        """
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        row_labels = {}
        for split in splits:
            sequences[split], discarded = get_generation_sequences(
                self.data, split, text_encoder, test, self.opt.data.maxe1,
                self.opt.data.maxe2)

            # Build the lookup set once rather than once per example.
            dropped = set(discarded)
            examples = self.data[split]["total"]

            if split == "train":
                self.data[split]["total"] = [
                    example for i, example in enumerate(examples)
                    if i not in dropped]

            if split in ["test", "dev"]:
                # Labels must follow the rows that were actually encoded:
                # the encoded list can be shorter than the example list when
                # rows were discarded or encoding stopped early.
                kept = [example for i, example in enumerate(examples)
                        if i not in dropped]
                row_labels[split] = [
                    example[3] for example in kept[:len(sequences[split])]]

            self.masks[split]["total"] = [
                (len(seq[0]), len(seq[1]), len(seq[2]))
                for seq in sequences[split]]

        self.max_e1 = max([max([l[0] for l in self.masks[split]["total"]])
                           for split in self.masks])
        self.max_r = max([max([l[1] for l in self.masks[split]["total"]])
                          for split in self.masks])
        self.max_e2 = max([max([l[2] for l in self.masks[split]["total"]])
                           for split in self.masks])

        print(self.max_e1)
        print(self.max_r)
        print(self.max_e2)

        width = self.max_e1 + self.max_e2 + self.max_r
        start_r = self.max_e1
        start_e2 = self.max_e1 + self.max_r

        for split in splits:
            num_elements = len(sequences[split])
            total = torch.zeros(num_elements, width, dtype=torch.long)

            for i, seq in enumerate(sequences[split]):
                # Every segment starts at a fixed column; the remainder of
                # the segment stays zero.
                total[i, :len(seq[0])] = torch.LongTensor(seq[0])
                total[i, start_r:start_r + len(seq[1])] = \
                    torch.LongTensor(seq[1])
                total[i, start_e2:start_e2 + len(seq[2])] = \
                    torch.LongTensor(seq[2])

            self.sequences[split]["total"] = total

            if split in ["test", "dev"]:
                print(split)
                labels = row_labels[split]
                self.sequences[split]["negative"] = total.index_select(
                    0, torch.LongTensor(
                        [row for row, label in enumerate(labels)
                         if not label]))
                self.sequences[split]["positive"] = total.index_select(
                    0, torch.LongTensor(
                        [row for row, label in enumerate(labels) if label]))

    def sample_batch(self, split, bs, cat="total", idxs=None):
        """Return the next batch of ``split`` / ``cat`` and an end flag.

        Args:
            split: Split to sample from.
            bs: Batch size.
            cat: Category to sample from.
            idxs: Optional row indices. When given (and non-empty) exactly
                these rows are returned instead of the next ``bs`` rows.

        Returns:
            ``(batch, done)``. ``batch`` holds ``"sequences"``,
            ``"attention_mask"``, ``"loss_mask"`` and ``"key"`` (the tuple
            ``(cat, offset, offset + bs)`` using the offset before the
            call). ``done`` is ``True`` for ``train`` once fewer than ``bs``
            rows remain after this batch, and for any split once the offset
            has reached the end of the tensor.
        """
        offset = self.offsets[split][cat]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split][cat].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split][cat].device))
        else:
            seqs = self.sequences[split][cat][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        # Loss is only computed past the e1 and relation segments.
        batch["loss_mask"] = make_loss_mask(seqs, self.max_e1 + self.max_r)
        batch["key"] = (cat, offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split][cat] = offset

        if split == "train" and offset + bs > len(self.sequences[split][cat]):
            return batch, True
        elif offset >= len(self.sequences[split][cat]):
            return batch, True
        else:
            return batch, False

    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
        """Set the offsets of ``keys`` to zero and optionally reshuffle.

        Args:
            splits: A split name or a list of split names.
            shuffle: When true, ``shuffle_sequences`` is called per split.
            keys: Categories to reset; defaults to ``"total"``,
                ``"positive"`` and ``"negative"``.
        """
        if isinstance(splits, str):
            splits = [splits]

        if keys is None:
            keys = ["total", "positive", "negative"]

        for split in splits:
            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)

    def shuffle_sequences(self, split="train", keys=None):
        """Apply one random permutation to sequences, data and masks.

        The ``"positive"`` and ``"negative"`` categories are never shuffled.

        Args:
            split: Split to shuffle.
            keys: Categories to shuffle; defaults to every category present
                in ``self.data[split]``.
        """
        if keys is None:
            keys = self.data[split].keys()

        for key in keys:
            if key in ["positive", "negative"]:
                continue
            idxs = list(range(len(self.data[split][key])))

            random.shuffle(idxs)

            stored = self.sequences[split][key]
            # The index tensor has to live on the same device as the tensor
            # being indexed (sample_batch does the same).
            self.sequences[split][key] = stored.index_select(
                0, torch.LongTensor(idxs).to(stored.device))

            # Keep the raw examples and the length masks in the same order.
            self.data[split][key] = [self.data[split][key][i] for i in idxs]
            self.masks[split][key] = [self.masks[split][key][i] for i in idxs]


def make_attention_mask(sequences):
    """Return a float mask that is 1 where ``sequences`` is non-zero.

    The result is moved to ``cfg.device``.
    """
    non_zero = sequences.ne(0)
    return non_zero.float().to(cfg.device)


def make_loss_mask(sequences, max_event):
    """Return a float mask selecting the positions that contribute to the loss.

    Positions are 1 where ``sequences`` is non-zero, except that the first
    ``max_event`` columns are forced to 0. The first column is then dropped,
    so the result has one column fewer than ``sequences``. The result is
    moved to ``cfg.device``.
    """
    loss_mask = sequences.ne(0).float()
    # Nothing in the leading ``max_event`` columns is scored.
    loss_mask[:, :max_event] = 0
    shifted = loss_mask[:, 1:]
    return shifted.to(cfg.device)


def get_generation_sequences(data, split, text_encoder, test,
                             max_e1=10, max_e2=15):
    """Encode every example of ``data[split]["total"]``.

    Training examples whose encoded e1 is longer than ``max_e1`` or whose
    encoded e2 is longer than ``max_e2`` are discarded. Examples of every
    other split are always kept, so the encoded rows stay aligned with the
    stored examples.

    Args:
        data: Mapping ``{split: {"total": [(e1, relation, e2, label), ...]}}``.
        split: Name of the split to encode.
        text_encoder: Encoder handed to ``do_example`` and
            ``compile_final_sequence``.
        test: When true, stop after an example is kept once more than ten
            examples have been consumed.
        max_e1: Maximum number of e1 tokens for a training example.
        max_e2: Maximum number of e2 tokens for a training example.

    Returns:
        ``(sequences, discarded)``: the encoded ``[e1, relation, e2]``
        triples and the positions (within ``data[split]["total"]``) of the
        discarded examples.
    """
    sequences = []
    discarded = []

    examples = tqdm(data[split]["total"])
    for position, (event1, relation, event2, _) in enumerate(examples):
        e1, r, e2 = do_example(text_encoder, event1, relation, event2)

        # Parenthesised on purpose: ``and`` binds tighter than ``or``, so
        # without the parentheses the e2 check would apply to every split.
        if split == "train" and (len(e1) > max_e1 or len(e2) > max_e2):
            discarded.append(position)
            continue

        sequences.append(compile_final_sequence(e1, e2, r, text_encoder))

        # ``position + 1`` examples have been consumed at this point.
        if test and position + 1 > 10:
            break

    return sequences, discarded


def do_example(text_encoder, event1, relation, event2):
    """Encode one (event1, relation, event2) example.

    Args:
        text_encoder: Object with an ``encode(list_of_str, verbose=...)``
            method and an ``encoder`` mapping.
        event1: Text of the first entity.
        relation: Relation text. If it contains any upper-case character it
            is looked up directly in ``text_encoder.encoder`` and becomes a
            single token; otherwise it is encoded like ordinary text.
        event2: Text of the second entity, or ``None``.

    Returns:
        ``(event1_tokens, relation_tokens, event2_tokens)`` where
        ``event2_tokens`` is ``None`` when ``event2`` is ``None``.
    """
    def encode_text(text):
        return text_encoder.encode([text], verbose=False)[0]

    event1_tokens = encode_text(event1)

    if relation == relation.lower():
        relation_tokens = encode_text(relation)
    else:
        relation_tokens = [text_encoder.encoder[relation]]

    event2_tokens = None if event2 is None else encode_text(event2)

    return event1_tokens, relation_tokens, event2_tokens


def compile_final_sequence(final_event1, final_event2, final_relation, text_encoder):
    """Assemble the ``[e1, relation, e2 + <END>]`` triple for one example.

    Args:
        final_event1: Token ids of the first entity.
        final_event2: Token ids of the second entity.
        final_relation: Token ids of the relation.
        text_encoder: Object whose ``encoder`` mapping contains ``"<END>"``.

    Returns:
        A three-element list ``[final_event1, final_relation, e2_with_end]``
        where ``e2_with_end`` is a new list holding ``final_event2`` followed
        by the ``"<END>"`` token id.

    Raises:
        TypeError: If ``final_event2`` is ``None`` (not iterable).
        KeyError: If ``"<END>"`` is missing from ``text_encoder.encoder``.
    """
    # Build a new list so the caller's e2 tokens are left untouched; appending
    # in place would add another <END> every time the same list is reused.
    event2_with_end = list(final_event2) + [text_encoder.encoder["<END>"]]

    return [final_event1, final_relation, event2_with_end]