from typing import Dict, Optional, List
from torch.nn import Linear
import torch
from torch.autograd import Variable
from torch.nn.functional import normalize
from allennlp.common import Params
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, last_dim_softmax, weighted_sum, replace_masked_values
from allennlp.training.metrics import CategoricalAccuracy
from endtasks import util
from endtasks.modules import VariationalDropout


@Model.register("esim-pair2vec")
class ESIMPair2Vec(Model):
    """
    This ``Model`` implements the ESIM sequence model described in `"Enhanced LSTM for Natural Language Inference"
    <https://www.semanticscholar.org/paper/Enhanced-LSTM-for-Natural-Language-Inference-Chen-Zhu/83e7654d545fbbaaf2328df365a781fb67b841b4>`_
    by Chen et al., 2017, extended with pair2vec embeddings: an embedding is computed for every
    (premise token, hypothesis token) pair, the pair embeddings are summed with the ESIM attention
    weights, and the result is appended to the ESIM "enhancement" layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    encoder_keys : ``List[str]``
        Forwarded to ``util.get_encoder_input`` together with the ``text_field_embedder`` to build
        the encoder input.
    mask_key : ``str``
        Forwarded to ``util.get_mask`` to obtain the premise and hypothesis masks.
    pair2vec_config_file : ``str``
    pair2vec_model_file : ``str``
        Both are forwarded to ``util.get_pair2vec`` to obtain the pair2vec model.
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    encoder : ``Seq2SeqEncoder``
        Applied to the embedded premise and to the embedded hypothesis (the same encoder is used
        for both).
    similarity_function : ``SimilarityFunction``
        Accepted for interface compatibility but currently unused: the similarity matrix is always
        computed with ``DotProductMatrixAttention``.
    projection_feedforward : ``FeedForward``
        Applied to the enhanced premise and hypothesis representations before the inference
        encoder.
    inference_encoder : ``Seq2SeqEncoder``
        Applied to the projected, enhanced premise and hypothesis representations.
    output_feedforward : ``FeedForward``
        Applied to the concatenation of the average- and max-pooled inference encoder outputs.
    output_logit : ``FeedForward``
        Maps the output of ``output_feedforward`` to the label logits.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    dropout : ``float``, optional (default=0.5)
        Dropout probability used for the inputs of both encoders (``VariationalDropout``) and for
        the pooled vector fed to ``output_feedforward``.  A falsy value disables both.
    pair2vec_dropout : ``float``, optional (default=0.0)
        Dropout probability applied to the attended pair embeddings.
    bidirectional_pair2vec : ``bool``, optional (default=``True``)
        If ``True``, the pair embedding of each token pair is concatenated with the pair embedding
        of the same pair with the argument order reversed.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 encoder_keys: List[str],
                 mask_key: str,
                 pair2vec_config_file: str,
                 pair2vec_model_file: str,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 dropout: float = 0.5,
                 pair2vec_dropout: float = 0.0,
                 bidirectional_pair2vec: bool = True,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)
        # Kept local, as in the original file, so the module-level imports stay untouched.
        from allennlp.modules.matrix_attention import DotProductMatrixAttention

        self._vocab = vocab
        self.pair2vec = util.get_pair2vec(pair2vec_config_file, pair2vec_model_file)
        self._encoder_keys = encoder_keys
        self._mask_key = mask_key
        self._text_field_embedder = text_field_embedder
        self._projection_feedforward = projection_feedforward
        self._encoder = encoder

        # NOTE: ``similarity_function`` is intentionally not stored; the attention is always a
        # plain dot product.
        self._matrix_attention = DotProductMatrixAttention()

        self._inference_encoder = inference_encoder
        self._pair2vec_dropout = torch.nn.Dropout(pair2vec_dropout)
        self._bidirectional_pair2vec = bidirectional_pair2vec

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = VariationalDropout(dropout)
        else:
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: Dict = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``.  Must contain a ``pair2vec_tokens`` entry.
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``.  Must contain a ``pair2vec_tokens`` entry.
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : Dict, optional (default = None)
            Accepted but not used by this method.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        premise_pair2vec_tokens = premise['pair2vec_tokens']
        hypothesis_pair2vec_tokens = hypothesis['pair2vec_tokens']

        embedded_premise = util.get_encoder_input(self._text_field_embedder, premise, self._encoder_keys)
        embedded_hypothesis = util.get_encoder_input(self._text_field_embedder, hypothesis, self._encoder_keys)
        premise_as_args = util.get_pair2vec_word_embeddings(self.pair2vec, premise_pair2vec_tokens)
        hypothesis_as_args = util.get_pair2vec_word_embeddings(self.pair2vec, hypothesis_pair2vec_tokens)

        premise_mask = util.get_mask(premise, self._mask_key).float()
        hypothesis_mask = util.get_mask(hypothesis, self._mask_key).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout is not None:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis with the shared encoder
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)
        # Shape: (batch_size, hypothesis_length, premise_length)
        # Computed once and reused for both the ESIM and the pair2vec attention.
        transposed_similarity_matrix = similarity_matrix.transpose(1, 2).contiguous()

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(transposed_similarity_matrix, premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # cross sequence embeddings, L2-normalised along the last dimension
        # (assumed shapes: (batch_size, premise_length, hypothesis_length, pair_dim) and
        # (batch_size, hypothesis_length, premise_length, pair_dim) -- TODO confirm against
        # ``util.get_pair_embeddings``)
        ph_pair_embeddings = normalize(
                util.get_pair_embeddings(self.pair2vec, premise_as_args, hypothesis_as_args), dim=-1
        )
        hp_pair_embeddings = normalize(
                util.get_pair_embeddings(self.pair2vec, hypothesis_as_args, premise_as_args), dim=-1
        )
        if self._bidirectional_pair2vec:
            # Append to every pair embedding the embedding of the same token pair with the
            # argument order reversed.  Both right-hand sides are evaluated before either name
            # is rebound, so each concatenation sees the original tensors.
            ph_pair_embeddings, hp_pair_embeddings = (
                    torch.cat((ph_pair_embeddings, hp_pair_embeddings.transpose(1, 2)), dim=-1),
                    torch.cat((hp_pair_embeddings, ph_pair_embeddings.transpose(1, 2)), dim=-1),
            )

        # pair2vec masks: positions whose pair2vec index is 0 or 1 are masked out
        # (presumably the padding and unknown-token indices -- verify against the pair2vec vocabulary)
        pair2vec_premise_mask = 1 - (torch.eq(premise_pair2vec_tokens, 0).long()
                                     + torch.eq(premise_pair2vec_tokens, 1).long())
        pair2vec_hypothesis_mask = 1 - (torch.eq(hypothesis_pair2vec_tokens, 0).long()
                                        + torch.eq(hypothesis_pair2vec_tokens, 1).long())

        # re-normalize attention using pair2vec masks; the ESIM attended vectors above were
        # already computed with the regular masks and are not affected by this
        h2p_attention = last_dim_softmax(transposed_similarity_matrix, pair2vec_premise_mask)
        p2h_attention = last_dim_softmax(similarity_matrix, pair2vec_hypothesis_mask)

        # Attention-weighted sum of pair embeddings for every token, zeroed at positions that
        # are themselves masked out for pair2vec.
        attended_hypothesis_pairs = (
                self._pair2vec_dropout(weighted_sum(ph_pair_embeddings, p2h_attention))
                * pair2vec_premise_mask.float().unsqueeze(-1)
        )
        attended_premise_pairs = (
                self._pair2vec_dropout(weighted_sum(hp_pair_embeddings, h2p_attention))
                * pair2vec_hypothesis_mask.float().unsqueeze(-1)
        )

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis,
                 attended_hypothesis_pairs],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise,
                 attended_premise_pairs],
                dim=-1
        )

        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout is not None:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling over the sequence dimension.
        # Masked positions are replaced by a large negative value so they never win the max.
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        # NOTE(review): a sequence whose mask is all zeros makes the denominator 0 here.
        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(premise_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(hypothesis_mask, 1, keepdim=True)

        # Now concat the four pooled vectors along the feature dimension
        v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout is not None:
            v = self.dropout(v)

        output_hidden = self._output_feedforward(v)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the current value of the accuracy metric under the key ``accuracy``; ``reset`` is
        forwarded to the metric's ``get_metric``.
        """
        return {
                'accuracy': self._accuracy.get_metric(reset),
                }