# -*- coding: utf-8 -*-

import codecs
from collections import Counter, defaultdict
from itertools import chain, count

import torch
import torchtext.data
import torchtext.vocab

PAD_WORD = '<blank>'
UNK = 0
BOS_WORD = '<s>'
EOS_WORD = '</s>'
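
# torchtext's Vocab keeps `stoi` as a defaultdict built with a lambda,
# which does not pickle; the hooks below swap in a plain dict on save
# and restore the defaultdict on load.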


def __getstate__(self):
    return dict(self.__dict__, stoi=dict(self.stoi))


def __setstate__(self, state):
    self.__dict__.update(state)
    self.stoi = defaultdict(lambda: UNK, self.stoi)


torchtext.vocab.Vocab.__getstate__ = __getstate__
torchtext.vocab.Vocab.__setstate__ = __setstate__


def extract_features(tokens):
    "Given a list of token separate out words and features (if any)."
    if not tokens:
        return [], [], -1
    split_tokens = [token.split(u"│") for token in tokens]
    split_tokens = [token for token in split_tokens if token[0]]
    token_size = len(split_tokens[0])
    assert all(len(token) == token_size for token in split_tokens), \
        "all words must have the same number of features"
    words_and_features = list(zip(*split_tokens))
    words = words_and_features[0]
    features = words_and_features[1:]
    return words, features, token_size - 1
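
# A minimal sketch of extract_features' behavior (hypothetical tokens;
# the separator is the same character used in the split above):
#
#     >>> extract_features([u"the│DT", u"cat│NN"])
#     (('the', 'cat'), [('DT', 'NN')], 1)
#     >>> extract_features([u"plain", u"words"])
#     (('plain', 'words'), [], 0)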


def merge_vocabs(vocabs, vocab_size=None):
    """
    Merge individual vocabularies (assumed to be generated from disjoint
    documents) into a larger vocabulary.

    Args:
        vocabs: `torchtext.vocab.Vocab` vocabularies to be merged
        vocab_size: `int` the final vocabulary size. `None` for no limit.
    Returns:
        `torchtext.vocab.Vocab`
    """
    # Sum the frequency counters; chaining only the keys would lose
    # the counts.
    merged = sum([vocab.freqs for vocab in vocabs], Counter())
    return torchtext.vocab.Vocab(merged,
                                 specials=[PAD_WORD, BOS_WORD, EOS_WORD],
                                 max_size=vocab_size)
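
# Usage sketch (hypothetical vocabularies; the freqs counters are
# summed, so shared words accumulate counts across corpora):
#
#     >>> v1 = torchtext.vocab.Vocab(Counter({"a": 3, "b": 1}))
#     >>> v2 = torchtext.vocab.Vocab(Counter({"a": 2, "c": 5}))
#     >>> merge_vocabs([v1, v2], vocab_size=10).freqs["a"]
#     5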


def make_features(batch, side):
    """
    Args:
        batch (Variable): a batch of source or target data.
        side (str): for source or for target.
    Returns:
        A sequence of src/tgt tensors with optional feature tensors
        of size (len x batch).
    """
    assert side in ['src', 'tgt']
    if isinstance(batch.__dict__[side], tuple):
        data = batch.__dict__[side][0]
    else:
        data = batch.__dict__[side]
    feat_start = side + "_feat_"
    # Sort the feature *keys* and then gather their tensors; sorting
    # the tensors themselves is ill-defined.
    feat_keys = sorted(k for k in batch.__dict__ if k.startswith(feat_start))
    features = [batch.__dict__[k] for k in feat_keys]
    levels = [data] + features
    return torch.cat([level.unsqueeze(2) for level in levels], 2)
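
# Usage sketch (assumes `batch` came from an OrderedIterator over an
# ONMTDataset, both defined below):
#
#     src = make_features(batch, "src")   # (src_len x batch x 1+nfeat)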


def join_dicts(*args):
    """
    args: dictionaries with disjoint keys
    returns: a single dictionary that has the union of these keys
    """
    return dict(chain(*[d.items() for d in args]))
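
# For example:
#
#     >>> join_dicts({"src": ["a"]}, {"tgt": ["b"]})
#     {'src': ['a'], 'tgt': ['b']}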


class OrderedIterator(torchtext.data.Iterator):
    def create_batches(self):
        if self.train:
            self.batches = torchtext.data.pool(
                self.data(), self.batch_size,
                self.sort_key, self.batch_size_fn,
                random_shuffler=self.random_shuffler)
        else:
            self.batches = []
            for b in torchtext.data.batch(self.data(), self.batch_size,
                                          self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))
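
# Usage sketch (hypothetical settings; `dataset` is an ONMTDataset whose
# fields already have vocabularies built; device=-1 keeps batches on the
# CPU in this torchtext version):
#
#     train_iter = OrderedIterator(dataset=dataset, batch_size=64,
#                                  device=-1, train=True)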


class ONMTDataset(torchtext.data.Dataset):
    """Defines a dataset for machine translation."""

    @staticmethod
    def sort_key(ex):
        "Sort in reverse size order"
        return -len(ex.src)

    def __init__(self, src_path, tgt_path, fields, opt,
                 src_img_dir=None, **kwargs):
        """
        Create a translation dataset given paths and fields.

        src_path: location of source-side data
        tgt_path: location of target-side data, or None. If given, the
                  source and target data must have the same number of
                  lines.
        fields: a dict of torchtext fields, as built by `get_fields`
        src_img_dir: if not None, uses images instead of text for the
                     source. TODO: finish
        """
        if src_img_dir:
            self.type_ = "img"
        else:
            self.type_ = "text"

        if self.type_ == "text":
            self.src_vocabs = []
            src_truncate = 0 if opt is None else opt.src_seq_length_trunc
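            # Peek at one example to learn the feature count, then
            # restart the generator for the full pass over the file.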
            src_point = next(self._read_corpus_file(src_path, src_truncate))
            self.nfeatures = src_point[2]
            src_data = self._read_corpus_file(src_path, src_truncate)
            src_examples = self._construct_examples(src_data, "src")
        else:
            # TODO finish this.
            # `transforms` is only bound once load_image_libs() runs,
            # so guard against the NameError.
            if "transforms" not in globals():
                load_image_libs()

        if tgt_path is not None:
            tgt_truncate = 0 if opt is None else opt.tgt_seq_length_trunc
            tgt_data = self._read_corpus_file(tgt_path, tgt_truncate)
            # Can't assert len(src_data) == len(tgt_data) here: both
            # are generators.
            tgt_examples = self._construct_examples(tgt_data, "tgt")
        else:
            tgt_examples = None

        # examples: one for each src line or (src, tgt) line pair.
        # Each element is a dictionary whose keys represent at minimum
        # the src tokens and their indices and potentially also the
        # src and tgt features and alignment information.
        if tgt_examples is not None:
            examples = (join_dicts(src, tgt)
                        for src, tgt in zip(src_examples, tgt_examples))
        else:
            examples = src_examples

        def dynamic_dict(examples):
            for example in examples:
                src = example["src"]
                src_vocab = torchtext.vocab.Vocab(Counter(src))
                self.src_vocabs.append(src_vocab)
                # mapping source tokens to indices in the dynamic dict
                src_map = torch.LongTensor([src_vocab.stoi[w] for w in src])
                example["src_map"] = src_map

                if "tgt" in example:
                    tgt = example["tgt"]
                    mask = torch.LongTensor(
                        [0] + [src_vocab.stoi[w] for w in tgt] + [0])
                    example["alignment"] = mask
                yield example

        if opt is None or opt.dynamic_dict:
            examples = dynamic_dict(examples)

        # Peek at the first example to see which fields are used.
        ex = next(examples)
        keys = ex.keys()
        fields = [(k, fields[k]) for k in (list(keys) + ["indices"])]

        def construct_final(examples):
            for i, ex in enumerate(examples):
                yield torchtext.data.Example.fromlist(
                    [ex[k] for k in keys] + [i], fields)

        def filter_pred(example):
            if tgt_examples is None:
                return 0 < len(example.src) <= opt.src_seq_length
            else:
                return (0 < len(example.src) <= opt.src_seq_length and
                        0 < len(example.tgt) <= opt.tgt_seq_length)

        super(ONMTDataset, self).__init__(
            list(construct_final(chain([ex], examples))),
            fields,
            filter_pred if opt is not None else None)

    def _read_corpus_file(self, path, truncate):
        """
        path: location of a src or tgt file
        truncate: maximum sequence length (0 for unlimited)

        returns: (word, features, nfeat) triples for each line
        """
        with codecs.open(path, "r", "utf-8") as corpus_file:
            lines = (line.split() for line in corpus_file)
            if truncate:
                lines = (line[:truncate] for line in lines)
            for line in lines:
                yield extract_features(line)

    def _construct_examples(self, lines, side):
        assert side in ["src", "tgt"]
        for line in lines:
            words, feats, _ = line
            example_dict = {side: words}
            if feats:
                prefix = side + "_feat_"
                example_dict.update((prefix + str(j), f) for j, f in enumerate(feats))
            yield example_dict

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def __reduce_ex__(self, proto):
        "This is a hack. Something is broken with torch pickle."
        return super(ONMTDataset, self).__reduce_ex__(proto)

    def collapse_copy_scores(self, scores, batch, tgt_vocab):
        """Given scores over an expanded (copy) vocabulary
        corresponding to a batch, sums each copy score into the
        matching dictionary word's score when that word is
        ambiguous, i.e. present in both vocabularies.
        """
        offset = len(tgt_vocab)
        for b in range(batch.batch_size):
            index = batch.indices.data[b]
            src_vocab = self.src_vocabs[index]
            for i in range(1, len(src_vocab)):
                sw = src_vocab.itos[i]
                ti = tgt_vocab.stoi[sw]
                if ti != 0:
                    scores[:, b, ti] += scores[:, b, offset + i]
                    scores[:, b, offset + i].fill_(1e-20)
        return scores

    @staticmethod
    def load_fields(vocab):
        vocab = dict(vocab)
        fields = ONMTDataset.get_fields(
            len(ONMTDataset.collect_features(vocab)))
        for k, v in vocab.items():
            # Hack: defaultdicts built with lambdas can't be pickled,
            # so the plain dict is restored to a defaultdict here.
            v.stoi = defaultdict(lambda: UNK, v.stoi)
            fields[k].vocab = v
        return fields

    @staticmethod
    def save_vocab(fields):
        vocab = []
        for k, f in fields.items():
            if 'vocab' in f.__dict__:
                f.vocab.stoi = dict(f.vocab.stoi)
                vocab.append((k, f.vocab))
        return vocab
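
    # Round-trip sketch (hypothetical path; save_vocab converts the
    # defaultdicts so the result pickles cleanly, and load_fields
    # restores them):
    #
    #     torch.save(ONMTDataset.save_vocab(fields), "vocab.pt")
    #     fields = ONMTDataset.load_fields(torch.load("vocab.pt"))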

    @staticmethod
    def collect_features(fields, side="src"):
        assert side in ["src", "tgt"]
        feats = []
        for j in count():
            key = side + "_feat_" + str(j)
            if key not in fields:
                break
            feats.append(key)
        return feats

    @staticmethod
    def collect_feature_dicts(fields):
        feature_dicts = []
        for j in count():
            key = "src_feat_" + str(j)
            if key not in fields:
                break
            feature_dicts.append(fields[key].vocab)
        return feature_dicts

    @staticmethod
    def get_fields(nFeatures=0):
        fields = {}
        fields["src"] = torchtext.data.Field(
            pad_token=PAD_WORD,
            include_lengths=True)

        # fields = [("src_img", torchtext.data.Field(
        #     include_lengths=True))]

        for j in range(nFeatures):
            fields["src_feat_" + str(j)] = \
                torchtext.data.Field(pad_token=PAD_WORD)

        fields["tgt"] = torchtext.data.Field(
            init_token=BOS_WORD, eos_token=EOS_WORD,
            pad_token=PAD_WORD)

        def make_src(data, _):
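            # Build a (src_len x batch x dyn_vocab) one-hot tensor that
            # marks each source position's dynamic-vocab index.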
            src_size = max([t.size(0) for t in data])
            src_vocab_size = max([t.max() for t in data]) + 1
            alignment = torch.zeros(src_size, len(data), src_vocab_size)
            for i, sent in enumerate(data):
                for j, t in enumerate(sent):
                    alignment[j, i, t] = 1
            return alignment

        fields["src_map"] = torchtext.data.Field(
            use_vocab=False, tensor_type=torch.FloatTensor,
            postprocessing=make_src, sequential=False)

        def make_tgt(data, _):
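            # Zero-pad every alignment vector out to the longest
            # target sequence in the batch.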
            tgt_size = max([t.size(0) for t in data])
            alignment = torch.zeros(tgt_size, len(data)).long()
            for i, sent in enumerate(data):
                alignment[:sent.size(0), i] = sent
            return alignment

        fields["alignment"] = torchtext.data.Field(
            use_vocab=False, tensor_type=torch.LongTensor,
            postprocessing=make_tgt, sequential=False)

        fields["indices"] = torchtext.data.Field(
            use_vocab=False, tensor_type=torch.LongTensor,
            sequential=False)

        return fields

    @staticmethod
    def build_vocab(train, opt):
        fields = train.fields
        fields["src"].build_vocab(train, max_size=opt.src_vocab_size,
                                  min_freq=opt.src_words_min_frequency)
        for j in range(train.nfeatures):
            fields["src_feat_" + str(j)].build_vocab(train)
        fields["tgt"].build_vocab(train, max_size=opt.tgt_vocab_size,
                                  min_freq=opt.tgt_words_min_frequency)

        # Merge the input and output vocabularies.
        if opt.share_vocab:
            # `tgt_vocab_size` is ignored when sharing vocabularies
            merged_vocab = merge_vocabs(
                [fields["src"].vocab, fields["tgt"].vocab],
                vocab_size=opt.src_vocab_size)
            fields["src"].vocab = merged_vocab
            fields["tgt"].vocab = merged_vocab


def load_image_libs():
    "Conditional import of torch image libs."
    global Image, transforms
    from PIL import Image
    from torchvision import transforms