# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: TF-IDF retrieval model: index dialog contexts with gensim and return the most similar contexts with their responses.
"""

import time

from gensim import corpora, models, similarities

from dialogbot.reader.data_helper import load_corpus_file
from dialogbot.utils.logger import logger


class TfidfModel:
    def __init__(self, corpus_file, word2id):
        time_s = time.time()
        # Load (context, response) pairs; contexts are indexed for retrieval,
        # responses are returned as candidate answers.
        self.contexts, self.responses = load_corpus_file(corpus_file, word2id, size=50000)

        self._train_model()
        # Transform the bag-of-words corpus into TF-IDF space and build
        # a dense similarity index over it.
        self.corpus_mm = self.tfidf_model[self.corpus]
        self.index = similarities.MatrixSimilarity(self.corpus_mm, num_features=len(self.dct))
        logger.debug("Time to build tfidf model from %s: %.2f seconds." % (corpus_file, time.time() - time_s))

    def _train_model(self, min_freq=1):
        # Build the dictionary from the context corpus.
        self.dct = corpora.Dictionary(self.contexts)
        # Filter low-frequency words from the dictionary.
        low_freq_ids = [id_ for id_, freq in
                        self.dct.dfs.items() if freq <= min_freq]
        self.dct.filter_tokens(low_freq_ids)
        self.dct.compactify()
        # Build the tfidf model on the bag-of-words corpus.
        self.corpus = [self.dct.doc2bow(s) for s in self.contexts]
        self.tfidf_model = models.TfidfModel(self.corpus)

    def _text2vec(self, text):
        # Convert a tokenized text into its TF-IDF vector.
        bow = self.dct.doc2bow(text)
        return self.tfidf_model[bow]

    def similarity(self, query, size=10):
        # Return the top `size` (doc_id, score) pairs most similar to the tokenized query.
        vec = self._text2vec(query)
        sims = self.index[vec]
        sim_sort = sorted(enumerate(sims),
                          key=lambda item: item[1], reverse=True)
        return sim_sort[:size]

    def get_docs(self, sim_items):
        # Map (doc_id, score) pairs back to their original contexts and responses.
        docs = [self.contexts[id_] for id_, score in sim_items]
        answers = [self.responses[id_] for id_, score in sim_items]
        return docs, answers
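

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only. Assumptions: the corpus path
    # "data/question.corpus.txt" and the word2id vocabulary are hypothetical
    # placeholders; substitute the real corpus file and vocabulary built
    # elsewhere in dialogbot.
    word2id = {}  # hypothetical token -> id mapping expected by load_corpus_file
    model = TfidfModel("data/question.corpus.txt", word2id)
    query_tokens = ["how", "to", "return", "goods"]  # pre-tokenized query
    sim_items = model.similarity(query_tokens, size=5)
    docs, answers = model.get_docs(sim_items)
    for (doc_id, score), doc, answer in zip(sim_items, docs, answers):
        print(doc_id, score, doc, answer)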