from typing import Dict, Optional

from overrides import overrides
import torch

from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward, Seq2SeqEncoder, SimilarityFunction, TextFieldEmbedder
from allennlp.modules.similarity_functions.dot_product import DotProductSimilarity
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.training.metrics import CategoricalAccuracy

from lib.modules import CoverageLoss
from lib.nn.util import unbind_tensor_dict
from lib.models.multee_esim import MulteeEsim


@Model.register("single_correct_mcq_multee_esim")
class SingleCorrectMcqMulteeEsim(MulteeEsim):
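    """
    Multee (ESIM-based) model for single-correct-answer multiple-choice questions.

    The ``hypotheses`` input holds one hypothesis per answer choice. Each hypothesis is
    scored independently against the shared premises by the base ``MulteeEsim`` model,
    the entailment logit is kept as that choice's score, and the scores are softmaxed
    across choices. When ``relevance_presence_mask`` is provided, the auxiliary coverage
    loss from the base model is averaged over choices and added to the answer loss.
    """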

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 final_feedforward: FeedForward,
                 coverage_loss: CoverageLoss,
                 similarity_function: SimilarityFunction = DotProductSimilarity(),
                 dropout: float = 0.5,
                 contextualize_pair_comparators: bool = False,
                 pair_context_encoder: Seq2SeqEncoder = None,
                 pair_feedforward: FeedForward = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        # Pass all arguments through verbatim as keyword arguments; otherwise FromParams
        # cannot construct the superclass correctly.
        super().__init__(vocab=vocab,
                         text_field_embedder=text_field_embedder,
                         encoder=encoder,
                         similarity_function=similarity_function,
                         projection_feedforward=projection_feedforward,
                         inference_encoder=inference_encoder,
                         output_feedforward=output_feedforward,
                         output_logit=output_logit,
                         final_feedforward=final_feedforward,
                         contextualize_pair_comparators=contextualize_pair_comparators,
                         coverage_loss=coverage_loss,
                         pair_context_encoder=pair_context_encoder,
                         pair_feedforward=pair_feedforward,
                         dropout=dropout,
                         initializer=initializer,
                         regularizer=regularizer)
        self._answer_loss = torch.nn.CrossEntropyLoss()

        self._accuracy = CategoricalAccuracy()

    @overrides
    def forward(self,  # type: ignore
                premises: Dict[str, torch.LongTensor],
                hypotheses: Dict[str, torch.LongTensor],
                paragraph: Dict[str, torch.LongTensor],
                answer_index: torch.LongTensor = None,
                relevance_presence_mask: torch.Tensor = None) -> Dict[str, torch.Tensor]:
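        # ``hypotheses`` is batched along dim 1 (one hypothesis per answer choice);
        # split it into a list of per-choice tensor dicts.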
        hypothesis_list = unbind_tensor_dict(hypotheses, dim=1)

        label_logits = []
        premises_attentions = []
        premises_aggregation_attentions = []
        coverage_losses = []
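        # Score each answer choice independently by running the base MulteeEsim
        # entailment model on (premises, hypothesis, paragraph).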
        for hypothesis in hypothesis_list:
            output_dict = super().forward(premises=premises, hypothesis=hypothesis, paragraph=paragraph)
            individual_logit = output_dict["label_logits"][:, self._label2idx["entailment"]]  # the entailment logit is this choice's score
            label_logits.append(individual_logit)

            premises_attention = output_dict.get("premises_attention", None)
            premises_attentions.append(premises_attention)
            premises_aggregation_attention = output_dict.get("premises_aggregation_attention", None)
            premises_aggregation_attentions.append(premises_aggregation_attention)
            if relevance_presence_mask is not None:
                coverage_loss = output_dict["coverage_loss"]
                coverage_losses.append(coverage_loss)

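        # Stack per-choice outputs; label_logits becomes (batch_size, num_choices).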
        label_logits = torch.stack(label_logits, dim=-1)
        premises_attentions = torch.stack(premises_attentions, dim=1)
        premises_aggregation_attentions = torch.stack(premises_aggregation_attentions, dim=1)
        if relevance_presence_mask is not None:
            coverage_losses = torch.stack(coverage_losses, dim=0)

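        # Normalize the per-choice entailment scores into a distribution over answer choices.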
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        output_dict = {"label_logits": label_logits,
                       "label_probs": label_probs,
                       "premises_attentions": premises_attentions,
                       "premises_aggregation_attentions": premises_aggregation_attentions}

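        # Compute the loss and accuracy only when gold answer indices are provided.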
        if answer_index is not None:
            # Cross-entropy loss over answer choices.
            loss = self._answer_loss(label_logits, answer_index)
            # Add the averaged auxiliary coverage loss when relevance supervision is available.
            if relevance_presence_mask is not None:
                loss += coverage_losses.mean()
            output_dict["loss"] = loss

            self._accuracy(label_logits, answer_index)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        accuracy_metric = self._accuracy.get_metric(reset)
        return {'accuracy': accuracy_metric}