from typing import List, Optional, Dict, Tuple
import os
from collections import defaultdict
import pickle

from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

from qanta.guesser.abstract import AbstractGuesser
from qanta.datasets.abstract import QuestionText


class TfidfGuesser(AbstractGuesser):
    """Guess answers by TF-IDF similarity against one document per answer.

    Training concatenates every training question that shares an answer into a
    single document, fits a word 1-3 gram ``TfidfVectorizer`` on those
    documents and keeps the resulting document matrix. Guessing scores each
    question against every answer document with a sparse dot product (with
    scikit-learn's default L2 row normalization this is cosine similarity).
    """

    def __init__(self, config_num: Optional[int]):
        super().__init__(config_num)
        # Fitted vectorizer; None until train() or load() runs.
        self.tfidf_vectorizer = None
        # Sparse (n_answers, n_features) TF-IDF matrix, one row per answer document.
        self.tfidf_matrix = None
        # Maps a row index of tfidf_matrix to its answer.
        self.i_to_ans = None

    def train(self, training_data) -> None:
        """Fit the vectorizer and build one TF-IDF row per distinct answer.

        Args:
            training_data: Indexable whose entry ``[0]`` is a sequence of
                questions (each a sequence of text strings that are joined with
                spaces) and whose entry ``[1]`` is the parallel sequence of
                answers. Other entries, if any, are ignored.
        """
        questions = training_data[0]
        answers = training_data[1]

        # Collect text pieces per answer and join once; repeated `str +=` on a
        # growing document is quadratic in the number of questions per answer.
        answer_parts = defaultdict(list)
        for q, ans in zip(questions, answers):
            answer_parts[ans].append(' '.join(q))

        y_array = list(answer_parts)
        x_array = [' '.join(answer_parts[ans]) for ans in y_array]

        self.i_to_ans = dict(enumerate(y_array))
        self.tfidf_vectorizer = TfidfVectorizer(
            ngram_range=(1, 3), min_df=2, max_df=.9
        ).fit(x_array)
        self.tfidf_matrix = self.tfidf_vectorizer.transform(x_array)

    def guess(self, questions: List[QuestionText], max_n_guesses: Optional[int]) -> List[List[Tuple[str, float]]]:
        """Return the highest-scoring answers for each question.

        Args:
            questions: Question texts to score.
            max_n_guesses: Maximum number of guesses returned per question.
                ``None`` returns every known answer; values above the number
                of answers are clamped, and values <= 0 yield empty lists.

        Returns:
            One list per question of ``(answer, score)`` pairs sorted by
            descending score. Ties keep the lower answer index first, so the
            first guess is the same answer ``argmax`` would pick.
        """
        representations = self.tfidf_vectorizer.transform(questions)
        # (n_questions, n_answers) scores, kept sparse; CSR makes row slicing cheap.
        guess_matrix = self.tfidf_matrix.dot(representations.T).T.tocsr()
        n_answers = guess_matrix.shape[1]
        if max_n_guesses is None:
            n_guesses = n_answers
        else:
            n_guesses = max(0, min(max_n_guesses, n_answers))

        guesses = []
        for row in range(guess_matrix.shape[0]):
            # Densify one row at a time so memory stays O(n_answers) instead of
            # O(n_questions * n_answers).
            scores = np.asarray(guess_matrix[row].toarray()).ravel()
            # Stable sort on negated scores keeps tied answers in index order.
            top = np.argsort(-scores, kind='stable')[:n_guesses]
            guesses.append([(self.i_to_ans[int(j)], float(scores[j])) for j in top])

        return guesses

    def save(self, directory: str) -> None:
        """Pickle the fitted state to ``<directory>/params.pickle``."""
        with open(os.path.join(directory, 'params.pickle'), 'wb') as f:
            pickle.dump({
                'config_num': self.config_num,
                'i_to_ans': self.i_to_ans,
                'tfidf_vectorizer': self.tfidf_vectorizer,
                'tfidf_matrix': self.tfidf_matrix
            }, f)

    @classmethod
    def load(cls, directory: str):
        """Rebuild a guesser from the ``params.pickle`` written by :meth:`save`.

        NOTE: ``pickle.load`` can execute arbitrary code, so only load
        directories produced by ``save`` from a trusted source.
        """
        with open(os.path.join(directory, 'params.pickle'), 'rb') as f:
            params = pickle.load(f)
        # Use cls rather than the concrete class name so subclasses load as themselves.
        guesser = cls(params['config_num'])
        guesser.tfidf_vectorizer = params['tfidf_vectorizer']
        guesser.tfidf_matrix = params['tfidf_matrix']
        guesser.i_to_ans = params['i_to_ans']
        return guesser

    @classmethod
    def targets(cls) -> List[str]:
        """Return the file names that :meth:`save` writes into its directory."""
        return ['params.pickle']