import json
import os
import random
import itertools
import math
from collections import defaultdict

import numpy as np

from cnn_dm.prepro import para2sents
from my.tensorflow import grouper
from my.utils import index


class Data(object):
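    """
    Abstract interface over example storage that may be too large to keep in memory
    (e.g. one file per example on disk); see MyData below for a filesystem-backed
    implementation.
    """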
    def get_size(self):
        raise NotImplementedError()

    def get_by_idxs(self, idxs):
        """
        Efficient way to obtain a batch of items from filesystem
        :param idxs:
        :return dict: {'X': [,], 'Y', }
        """
        data = defaultdict(list)
        for idx in idxs:
            each_data = self.get_one(idx)
            for key, val in each_data.items():
                data[key].append(val)
        return data

    def get_one(self, idx):
        raise NotImplementedError()

    def get_empty(self):
        raise NotImplementedError()

    def __add__(self, other):
        raise NotImplementedError()


class MyData(Data):
    def __init__(self, config, root_dir, file_names):
        self.root_dir = root_dir
        self.file_names = file_names
        self.config = config

    def get_one(self, idx):
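        # The question file is assumed to follow the layout implied by the reads below
        # (CNN/DailyMail-style): url, blank, passage, blank, question, blank, answer,
        # blank, then one candidate per line in the form "<candidate id>:<original text>".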
        file_name = self.file_names[idx]
        with open(os.path.join(self.root_dir, file_name), 'r') as fh:
            url = fh.readline().strip()
            _ = fh.readline()
            para = fh.readline().strip()
            _ = fh.readline()
            ques = fh.readline().strip()
            _ = fh.readline()
            answer = fh.readline().strip()
            _ = fh.readline()
            cands = [line.strip() for line in fh]
            cand_ents = [cand.split(":")[0] for cand in cands]  # keep only the candidate id (the part before ':')
            wordss = para2sents(para, self.config.width)
            ques_words = ques.split(" ")

            x = wordss
            cx = [[list(word) for word in words] for words in wordss]
            q = ques_words
            cq = [list(word) for word in ques_words]
            y = answer
            c = cand_ents

            data = {'x': x, 'cx': cx, 'q': q, 'cq': cq, 'y': y, 'c': c, 'ids': file_name}
            return data

    def get_empty(self):
        return MyData(self.config, self.root_dir, [])

    def __add__(self, other):
        file_names = self.file_names + other.file_names
        return MyData(self.config, self.root_dir, file_names)

    def get_size(self):
        return len(self.file_names)


class DataSet(object):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        self.data = data  # e.g. {'X': [0, 1, 2], 'Y': [2, 3, 4]}
        self.data_type = data_type
        self.shared = shared
        total_num_examples = self.get_data_size()
        self.valid_idxs = range(total_num_examples) if valid_idxs is None else valid_idxs
        # number of usable examples; equals total_num_examples unless a subset of valid_idxs is given
        self.num_examples = len(self.valid_idxs)

    def _sort_key(self, idx):
        rx = self.data['*x'][idx]
        x = self.shared['x'][rx[0]][rx[1]]
        return max(map(len, x))

    def get_data_size(self):
        if isinstance(self.data, dict):
            return len(next(iter(self.data.values())))
        elif isinstance(self.data, Data):
            return self.data.get_size()
        raise Exception()

    def get_by_idxs(self, idxs):
        if isinstance(self.data, dict):
            out = defaultdict(list)
            for key, val in self.data.items():
                out[key].extend(val[idx] for idx in idxs)
            return out
        elif isinstance(self.data, Data):
            return self.data.get_by_idxs(idxs)
        raise Exception()

    def get_one(self, idx):
        if isinstance(self.data, dict):
            out = {key: [val[idx]] for key, val in self.data.items()}
            return out
        elif isinstance(self.data, Data):
            return self.data.get_one(idx)
        raise Exception()

    def get_batches(self, batch_size, num_batches=None, shuffle=False, cluster=False):
        """

        :param batch_size:
        :param num_batches:
        :param shuffle:
        :param cluster: cluster examples by their lengths; this might give performance boost (i.e. faster training).
        :return:
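
        A minimal usage sketch (the names below are illustrative, not part of this class):

            for batch_idxs, batch_ds in dataset.get_batches(60, num_batches=1000, shuffle=True):
                feed_dict = build_feed_dict(batch_ds)  # hypothetical model-specific helper
                ...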
        """
        num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
        if num_batches is None:
            num_batches = num_batches_per_epoch
        num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))

        if shuffle:
            random_idxs = random.sample(self.valid_idxs, len(self.valid_idxs))
            if cluster:
                sorted_idxs = sorted(random_idxs, key=self._sort_key)
                sorted_grouped = lambda: list(grouper(sorted_idxs, batch_size))
                grouped = lambda: random.sample(sorted_grouped(), num_batches_per_epoch)
            else:
                random_grouped = lambda: list(grouper(random_idxs, batch_size))
                grouped = random_grouped
        else:
            raw_grouped = lambda: list(grouper(self.valid_idxs, batch_size))
            grouped = raw_grouped

        batch_idx_tuples = itertools.chain.from_iterable(grouped() for _ in range(num_epochs))
        for _ in range(num_batches):
            batch_idxs = tuple(i for i in next(batch_idx_tuples) if i is not None)
            batch_data = self.get_by_idxs(batch_idxs)
            shared_batch_data = {}
            for key, val in batch_data.items():
                if key.startswith('*'):
                    assert self.shared is not None
                    shared_key = key[1:]
                    shared_batch_data[shared_key] = [index(self.shared[shared_key], each) for each in val]
            batch_data.update(shared_batch_data)

            batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
            yield batch_idxs, batch_ds

    def get_multi_batches(self, batch_size, num_batches_per_step, num_steps=None, shuffle=False, cluster=False):
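        # Each yielded step packs num_batches_per_step consecutive batches, e.g. so that
        # one training step can feed several model replicas one batch each (an
        # interpretation based on how the large batch is regrouped and divided below).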
        batch_size_per_step = batch_size * num_batches_per_step
        batches = self.get_batches(batch_size_per_step, num_batches=num_steps, shuffle=shuffle, cluster=cluster)
        multi_batches = (tuple(zip(grouper(idxs, batch_size, shorten=True, num_groups=num_batches_per_step),
                         data_set.divide(num_batches_per_step))) for idxs, data_set in batches)
        return multi_batches

    def get_empty(self):
        if isinstance(self.data, dict):
            data = {key: [] for key in self.data}
        elif isinstance(self.data, Data):
            data = self.data.get_empty()
        else:
            raise Exception()
        return DataSet(data, self.data_type, shared=self.shared)

    def __add__(self, other):
        if isinstance(self.data, dict):
            data = {key: val + other.data[key] for key, val in self.data.items()}
        elif isinstance(self.data, Data):
            data = self.data + other.data
        else:
            raise Exception()

        # offset the other DataSet's indices by the size of this DataSet's underlying data
        valid_idxs = list(self.valid_idxs) + [valid_idx + self.get_data_size() for valid_idx in other.valid_idxs]
        return DataSet(data, self.data_type, shared=self.shared, valid_idxs=valid_idxs)

    def divide(self, integer):
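        # Split this DataSet into `integer` DataSets of roughly equal size
        # (used by get_multi_batches above to give each step slot its own piece).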
        batch_size = int(math.ceil(self.num_examples / integer))
        idxs_gen = grouper(self.valid_idxs, batch_size, shorten=True, num_groups=integer)
        data_gen = (self.get_by_idxs(idxs) for idxs in idxs_gen)
        ds_tuple = tuple(DataSet(data, self.data_type, shared=self.shared) for data in data_gen)
        return ds_tuple


class MyDataSet(DataSet):
    def __init__(self, data, data_type, shared=None, valid_idxs=None):
        super(MyDataSet, self).__init__(data, data_type, shared=shared, valid_idxs=valid_idxs)
        # The file list (shared['sorted']) is assumed to be ordered by length, so the last
        # valid example has the largest number of sentences.
        shared['max_num_sents'] = len(self.get_one(self.valid_idxs[-1])['x'])

    def _sort_key(self, idx):
        # Examples are already stored in length order (see shared['sorted'] in read_data),
        # so the index itself serves as a length-based sort key.
        return idx


def read_data(config, data_type, ref, data_filter=None):
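    """
    Load shared_{data_type}.json and wrap the per-example question files in a MyDataSet.
    If `ref` is set, an existing shared.json vocabulary is reused instead of building
    word2idx/char2idx from the counters.
    """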
    shared_path = os.path.join(config.data_dir, "shared_{}.json".format(data_type))
    with open(shared_path, 'r') as fh:
        shared = json.load(fh)

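    # shared_{data_type}.json is assumed to provide at least: 'sorted' (file names ordered
    # by length), the word/char counters, the (lower_)word2vec dict, and the max_* size
    # statistics consumed by update_config below.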
    paths = shared['sorted']
    if config.filter_ratio < 1.0:
        stop = int(round(len(paths) * config.filter_ratio))
        paths = paths[:stop]
    num_examples = len(paths)
    valid_idxs = range(num_examples)

    print("Loaded {}/{} examples from {}".format(len(valid_idxs), num_examples, data_type))

    shared_path = config.shared_path or os.path.join(config.out_dir, "shared.json")
    if not ref:
        word2vec_dict = shared['lower_word2vec'] if config.lower_word else shared['word2vec']
        word_counter = shared['lower_word_counter'] if config.lower_word else shared['word_counter']
        char_counter = shared['char_counter']
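        # Build the trainable vocabularies. Word indices 0-2 are reserved for the special
        # tokens -NULL-, -UNK-, and -ENT- added below, so real words start at 3; char
        # indices 0-1 are reserved for -NULL- and -UNK-, so real characters start at 2.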
        if config.finetune:
            shared['word2idx'] = {word: idx + 3 for idx, word in
                                  enumerate(word for word, count in word_counter.items()
                                            if count > config.word_count_th or (config.known_if_glove and word in word2vec_dict))}
        else:
            assert config.known_if_glove
            assert config.use_glove_for_unk
            shared['word2idx'] = {word: idx + 3 for idx, word in
                                  enumerate(word for word, count in word_counter.items()
                                            if count > config.word_count_th and word not in word2vec_dict)}
        shared['char2idx'] = {char: idx + 2 for idx, char in
                              enumerate(char for char, count in char_counter.items()
                                        if count > config.char_count_th)}
        NULL = "-NULL-"
        UNK = "-UNK-"
        ENT = "-ENT-"
        shared['word2idx'][NULL] = 0
        shared['word2idx'][UNK] = 1
        shared['word2idx'][ENT] = 2
        shared['char2idx'][NULL] = 0
        shared['char2idx'][UNK] = 1

        with open(shared_path, 'w') as fh:
            json.dump({'word2idx': shared['word2idx'], 'char2idx': shared['char2idx']}, fh)
    else:
        with open(shared_path, 'r') as fh:
            new_shared = json.load(fh)
        for key, val in new_shared.items():
            shared[key] = val

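    # When use_glove_for_unk is set, words covered by GloVe but absent from the trainable
    # vocabulary get their own index space and a fixed embedding matrix (new_emb_mat);
    # the model is presumably expected to look these words up there rather than in the
    # learned embedding table.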
    if config.use_glove_for_unk:
        # create new word2idx and word2vec
        word2vec_dict = shared['lower_word2vec'] if config.lower_word else shared['word2vec']
        new_word2idx_dict = {word: idx for idx, word in enumerate(
            word for word in word2vec_dict.keys() if word not in shared['word2idx'])}
        shared['new_word2idx'] = new_word2idx_dict
        idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
        new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
        shared['new_emb_mat'] = new_emb_mat

    data = MyData(config, os.path.join(config.root_dir, data_type), paths)
    data_set = MyDataSet(data, data_type, shared=shared, valid_idxs=valid_idxs)
    return data_set


def get_cnn_data_filter(config):
    # Placeholder: no filtering is applied to the CNN/DailyMail data
    # (read_data currently ignores its data_filter argument), so everything is accepted.
    return True


def update_config(config, data_sets):
    config.max_num_sents = 0
    config.max_sent_size = 0
    config.max_ques_size = 0
    config.max_word_size = 0
    for data_set in data_sets:
        shared = data_set.shared
        config.max_sent_size = max(config.max_sent_size, shared['max_sent_size'])
        config.max_ques_size = max(config.max_ques_size, shared['max_ques_size'])
        config.max_word_size = max(config.max_word_size, shared['max_word_size'])
        config.max_num_sents = max(config.max_num_sents, shared['max_num_sents'])

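    # Cap characters-per-word at the configured threshold so the character-level input
    # width stays bounded.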
    config.max_word_size = min(config.max_word_size, config.word_size_th)

    config.char_vocab_size = len(data_sets[0].shared['char2idx'])
    config.word_emb_size = len(next(iter(data_sets[0].shared['word2vec'].values())))
    config.word_vocab_size = len(data_sets[0].shared['word2idx'])