"""
A script to embed each labeled phrase in a dataset as a dense vector, then
find the top-k nearest neighbors of each phrase according to cosine
similarity.

1. Install missing dependencies.

    # More details: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md
    conda install faiss-cpu -c pytorch

2. Prepare data. For example, the chunking dataset from CoNLL 2000.

    wget https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz
    gunzip train.txt.gz
    python diora/misc/convert_conll_to_jsonl.py --path train.txt > conll-train.jsonl

3. Run this script.

    python diora/scripts/phrase_embed.py \
        --batch_size 10 \
        --emb w2v \
        --embeddings_path ~/data/glove.6B/glove.6B.50d.txt \
        --hidden_dim 50 \
        --log_every_batch 100 \
        --save_after 1000 \
        --data_type conll_jsonl \
        --validation_path ./conll-train.jsonl \
        --validation_filter_length 10

Control the number of neighbors to show with the `--k_top` flag.

Control the number of candidates to consider with the `--k_candidates` flag.
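
Example output format (ids, phrases, and scores below are illustrative
placeholders, not real results):

    [query] example_id=<id> phrase=<query phrase>
    rank=0 score=0.987 example_id=<id> phrase=<nearest neighbor>
    rank=1 score=0.954 example_id=<id> phrase=<second-nearest neighbor>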

"""


import json
import types
import itertools

import torch

import numpy as np

from train import argument_parser, parse_args, configure
from train import get_validation_dataset, get_validation_iterator
from train import build_net

from diora.logging.configuration import get_logger

try:
    import faiss
    from faiss import normalize_L2
except ImportError:
    print('Could not import `faiss`, which is used to find nearest neighbors.')


def get_cell_index(entity_labels, i_label=0, i_pos=1, i_size=2):
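    """Collect (batch_index, position, size, label) for every labeled span
    in `entity_labels`, returned as four parallel lists."""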
    def helper():
        for i, lst in enumerate(entity_labels):
            for el in lst:
                if el is None:
                    continue
                pos = el[i_pos]
                size = el[i_size]
                label = el[i_label]
                yield (i, pos, size, label)
    lst = list(helper())
    if len(lst) == 0:
        return [], [], [], []
    batch_index = [x[0] for x in lst]
    positions = [x[1] for x in lst]
    sizes = [x[2] for x in lst]
    labels = [x[3] for x in lst]

    return batch_index, positions, sizes, labels


def get_many_cells(diora, chart, batch_index, positions, sizes):
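    """Look up the chart cell vector for each (batch_index, position, size) span."""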
    length = diora.length

    idx = []
    for bi, pos, size in zip(batch_index, positions, sizes):
        level = size - 1
        offset = diora.index.get_offset(length)[level]
        absolute_pos = offset + pos
        idx.append(absolute_pos)

    cells = chart[batch_index, idx]

    return cells


def get_many_phrases(batch, batch_index, positions, sizes):
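    """Return each span as a tuple of token ids sliced from the batch."""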
    batch = batch.tolist()
    lst = []
    for bi, pos, size in zip(batch_index, positions, sizes):
        phrase = tuple(batch[bi][pos:pos+size])
        lst.append(phrase)
    return lst


class BatchRecorder(object):
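    """Accumulates values batch by batch, then flattens them across batches.

    `dtype` maps a key to 'list', 'np', or 'torch' and controls how that
    key is flattened. A small illustrative example:

        recorder = BatchRecorder(dtype={'vec': 'np'})
        recorder.record(vec=np.zeros((2, 4)), label=['NP', 'VP'])
        recorder.record(vec=np.ones((3, 4)), label=['PP'])
        result = recorder.get_flattened_result()
        # result['vec'].shape == (5, 4); result['label'] == ['NP', 'VP', 'PP']
    """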
    def __init__(self, dtype=None):
        super(BatchRecorder, self).__init__()
        self.cache = {}
        self.dtype = dtype if dtype is not None else {}
        self.dtype2flatten = {
            'list': self._flatten_list,
            'np': self._flatten_np,
            'torch': self._flatten_torch,
        }

    def _flatten_list(self, v):
        return list(itertools.chain(*v))

    def _flatten_np(self, v):
        return np.concatenate(v, axis=0)

    def _flatten_torch(self, v):
        return torch.cat(v, 0).cpu().data.numpy()

    def get_flattened_result(self):
        def helper():
            for k, v in self.cache.items():
                flatten = self.dtype2flatten[self.dtype.get(k, 'list')]
                yield k, flatten(v)
        return {k: v for k, v in helper()}
            
    def record(self, **kwargs):
        for k, v in kwargs.items():
            self.cache.setdefault(k, []).append(v)


class Index(object):
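    """Thin wrapper around a faiss inner-product index (`IndexFlatIP`).

    The vectors added to the index are L2-normalized beforehand, so the
    inner-product search is equivalent to ranking by cosine similarity.
    """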
    def __init__(self, dim=None):
        super(Index, self).__init__()
        self.D, self.I = None, None
        self.index = faiss.IndexFlatIP(dim)

    def add(self, vecs):
        self.index.add(vecs)

    def cache(self, vecs, k):
        self.D, self.I = self.index.search(vecs, k)

    def topk(self, q, k):
        for j in range(k):
            idx = self.I[q][j]
            dist = self.D[q][j]
            yield idx, dist


class NearestNeighborsLookup(object):
    def __init__(self):
        super(NearestNeighborsLookup, self).__init__()


def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']

    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)
    diora = trainer.net.diora

    # 1. Get all relevant phrase vectors.

    dtype = {
        'example_ids': 'list',
        'labels': 'list',
        'positions': 'list',
        'sizes': 'list',
        'phrases': 'list',
        'inside': 'torch',
        'outside': 'torch',
    }
    batch_recorder = BatchRecorder(dtype=dtype)

    # Eval mode (disables dropout, etc.).
    trainer.net.eval()

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    logger.info('Beginning to embed phrases.')

    with torch.no_grad():
        for i, batch_map in enumerate(batches):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Skip very short examples.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            entity_labels = batch_map['entity_labels']
            batch_index, positions, sizes, labels = get_cell_index(entity_labels)

            # Skip short phrases (size < 2). Filter `sizes` last, since the other filters depend on it.
            batch_index = [x for x, y in zip(batch_index, sizes) if y >= 2]
            positions = [x for x, y in zip(positions, sizes) if y >= 2]
            labels = [x for x, y in zip(labels, sizes) if y >= 2]
            sizes = [y for y in sizes if y >= 2]

            cell_index = (batch_index, positions, sizes)

            batch_result = {}
            batch_result['example_ids'] = [batch_map['example_ids'][idx] for idx in cell_index[0]]
            batch_result['labels'] = labels
            batch_result['positions'] = cell_index[1]
            batch_result['sizes'] = cell_index[2]
            batch_result['phrases'] = get_many_phrases(sentences, *cell_index)
            batch_result['inside'] = get_many_cells(diora, diora.inside_h, *cell_index)
            batch_result['outside'] = get_many_cells(diora, diora.outside_h, *cell_index)

            batch_recorder.record(**batch_result)

    result = batch_recorder.get_flattened_result()

    # 2. Build an index of nearest neighbors.
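    # The phrase representation is the concatenation of a span's inside and
    # outside vectors. `normalize_L2` rescales each row in place so that the
    # inner-product search below ranks by cosine similarity (faiss operates
    # on float32 arrays).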

    vectors = np.concatenate([result['inside'], result['outside']], axis=1)
    normalize_L2(vectors)

    index = Index(dim=vectors.shape[1])
    index.add(vectors)
    index.cache(vectors, options.k_candidates)

    # 3. Print a summary.

    example_ids = result['example_ids']
    phrases = result['phrases']

    assert len(example_ids) == len(phrases)
    assert len(example_ids) == vectors.shape[0]

    def stringify(phrase):
        return ' '.join([idx2word[idx] for idx in phrase])

    for i in range(vectors.shape[0]):
        topk = []

        for j, score in index.topk(i, options.k_candidates):
            # Skip same example.
            if example_ids[i] == example_ids[j]:
                continue
            # Skip string match.
            if phrases[i] == phrases[j]:
                continue
            topk.append((j, score))
            if len(topk) == options.k_top:
                break
        assert len(topk) == options.k_top, 'Did not find enough valid candidates. Try increasing --k_candidates.'

        # Print.
        print('[query] example_id={} phrase={}'.format(
            example_ids[i], stringify(phrases[i])))
        for rank, (j, score) in enumerate(topk):
            print('rank={} score={:.3f} example_id={} phrase={}'.format(
                rank, score, example_ids[j], stringify(phrases[j])))


if __name__ == '__main__':
    parser = argument_parser()
    parser.add_argument('--k_candidates', default=100, type=int)
    parser.add_argument('--k_top', default=3, type=int)
    options = parse_args(parser)
    configure(options)

    run(options)