#!/usr/bin/env python
# -*- coding: utf-8 -*-
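"""Convert pretrained word embeddings stored as text (GloVe or word2vec
format) into torch tensors aligned with a preprocessed vocabulary file
(a list of (name, torchtext.Vocab) pairs, as noted in get_vocabs below).

Saves one tensor per side: <output_file>.enc.pt and <output_file>.dec.pt.
"""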
from __future__ import print_function
from __future__ import division
import six
import sys
import numpy as np
import argparse
import torch


def get_vocabs(dict_file):
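    """Load a preprocessed vocab file and return (src_vocab, tgt_vocab)."""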
    vocabs = torch.load(dict_file)

    enc_vocab, dec_vocab = None, None

    # the vocab object is a list of tuple (name, torchtext.Vocab)
    # we iterate over this list and associate vocabularies based on the name
    for name, vocab in vocabs:
        if name == 'src':
            enc_vocab = vocab
        elif name == 'tgt':
            dec_vocab = vocab
    assert enc_vocab is not None and dec_vocab is not None

    print("From: %s" % dict_file)
    print("\t* source vocab: %d words" % len(enc_vocab))
    print("\t* target vocab: %d words" % len(dec_vocab))

    return enc_vocab, dec_vocab


def get_embeddings(file_enc, opt, flag):
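    """Parse an embeddings text file into a {word: [float, ...]} dict.

    `flag` is 'enc' or 'dec'; besides selecting the reported side, only
    the encoder side honors -skip_lines.
    """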
    embs = dict()
    # Both sides share the same parsing loop; only the encoder side
    # honors -skip_lines (used to jump over e.g. a word2vec header).
    skip = opt.skip_lines if flag == 'enc' else 0
    with open(file_enc, 'rb') as f:
        for i, line in enumerate(f):
            if i < skip:
                continue
            line = line.decode('utf8').strip()
            if not line:
                continue
            l_split = line.split(' ')
            if len(l_split) == 2:
                # word2vec-style header line ("<vocab_size> <dim>"): skip it
                continue
            embs[l_split[0]] = [float(em) for em in l_split[1:]]
    side = "encoder" if flag == 'enc' else "decoder"
    print("Got {} {} embeddings from {}".format(len(embs), side, file_enc))

    return embs


def match_embeddings(vocab, emb, opt):
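    """Build a len(vocab) x dim matrix of pretrained vectors.

    Rows for vocabulary words missing from `emb` stay zero-initialized.
    Returns the matrix as a torch.Tensor plus match/miss counts.
    """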
    dim = len(six.next(six.itervalues(emb)))
    filtered_embeddings = np.zeros((len(vocab), dim))
    count = {"match": 0, "miss": 0}
    for w, w_id in vocab.stoi.items():
        if w in emb:
            filtered_embeddings[w_id] = emb[w]
            count['match'] += 1
        else:
            if opt.verbose:
                print(u"not found:\t{}".format(w), file=sys.stderr)
            count['miss'] += 1

    return torch.Tensor(filtered_embeddings), count


TYPES = ["GloVe", "word2vec"]
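
# Example invocation (file names are illustrative):
#   python embeddings_to_torch.py \
#       -emb_file_enc glove.6B.300d.txt \
#       -emb_file_dec glove.6B.300d.txt \
#       -dict_file data/demo.vocab.pt \
#       -output_file data/glove_embeddings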


def main():

    parser = argparse.ArgumentParser(description='embeddings_to_torch.py')
    parser.add_argument('-emb_file_enc', required=True,
                        help="source (encoder-side) embeddings file")
    parser.add_argument('-emb_file_dec', required=True,
                        help="target (decoder-side) embeddings file")
    parser.add_argument('-output_file', required=True,
                        help="Output file for the prepared data")
    parser.add_argument('-dict_file', required=True,
                        help="Dictionary file")
    parser.add_argument('-verbose', action="store_true", default=False)
    parser.add_argument('-skip_lines', type=int, default=0,
                        help="Skip the first N lines of the embedding file")
    parser.add_argument('-type', choices=TYPES, default="GloVe")
    opt = parser.parse_args()

    enc_vocab, dec_vocab = get_vocabs(opt.dict_file)
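    # word2vec text files start with a "<vocab_size> <dim>" header line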
    if opt.type == "word2vec":
        opt.skip_lines = 1

    embeddings_enc = get_embeddings(opt.emb_file_enc, opt, flag='enc')
    embeddings_dec = get_embeddings(opt.emb_file_dec, opt, flag='dec')

    filtered_enc_embeddings, enc_count = match_embeddings(enc_vocab,
                                                          embeddings_enc,
                                                          opt)
    filtered_dec_embeddings, dec_count = match_embeddings(dec_vocab,
                                                          embeddings_dec,
                                                          opt)
    print("\nMatching: ")
    match_percent = [c['match'] / (c['match'] + c['miss']) * 100
                     for c in [enc_count, dec_count]]
    print("\t* enc: %d match, %d missing, (%.2f%%)" % (enc_count['match'],
                                                       enc_count['miss'],
                                                       match_percent[0]))
    print("\t* dec: %d match, %d missing, (%.2f%%)" % (dec_count['match'],
                                                       dec_count['miss'],
                                                       match_percent[1]))

    print("\nFiltered embeddings:")
    print("\t* enc: ", filtered_enc_embeddings.size())
    print("\t* dec: ", filtered_dec_embeddings.size())

    enc_output_file = opt.output_file + ".enc.pt"
    dec_output_file = opt.output_file + ".dec.pt"
    print("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
          % (enc_output_file, dec_output_file))
    torch.save(filtered_enc_embeddings, enc_output_file)
    torch.save(filtered_dec_embeddings, dec_output_file)
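    # The saved tensors can then be used as pretrained word vectors at
    # training time (in OpenNMT-py, typically via the -pre_word_vecs_enc
    # and -pre_word_vecs_dec options; check your trainer's flags).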
    print("\nDone.")


if __name__ == "__main__":
    main()