from typing import Dict, Optional, List, Any

import torch
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn import util
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy, SquadEmAndF1
from torch.nn.functional import nll_loss

from modules.pair_encoder.pair_encoder import AttentionEncoder
from modules.pointer_network.pointer_network import QAOutputLayer

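# Registered as "r_net", so the model can be selected from a training config.
# A minimal sketch of the "model" block follows; every sub-module key below is
# illustrative only, since the exact parameters depend on how this repository's
# AttentionEncoder and QAOutputLayer register themselves:
#
#   "model": {
#     "type": "r_net",
#     "text_field_embedder": {"token_embedders": {...}},
#     "question_encoder": {"type": "gru", ...},
#     "passage_encoder": {"type": "gru", ...},
#     "pair_encoder": {...},
#     "self_encoder": {...},
#     "output_layer": {...},
#     "share_encoder": false
#   }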
@Model.register("r_net")
class RNet(Model):
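    """An AllenNLP implementation of R-Net (Wang et al., 2017): the question and
    passage are encoded separately, fused with gated question-passage (pair)
    attention, refined with passage self-matching attention, and an answer span
    is predicted with a pointer network.
    """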
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 passage_encoder: Seq2SeqEncoder,
                 pair_encoder: AttentionEncoder,
                 self_encoder: AttentionEncoder,
                 output_layer: QAOutputLayer,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 share_encoder: bool = False):

        super().__init__(vocab, regularizer)
        self.text_field_embedder = text_field_embedder
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.pair_encoder = pair_encoder
        self.self_encoder = self_encoder
        self.output_layer = output_layer

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self.share_encoder = share_encoder
        initializer(self)

    def forward(self,
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # Embed question and passage tokens with the shared text field embedder.
        question_embedded = self.text_field_embedder(question)
        passage_embedded = self.text_field_embedder(passage)

        question_mask = get_text_field_mask(question).byte()
        passage_mask = get_text_field_mask(passage).byte()

        # Encode each sequence; optionally reuse the question encoder for the passage.
        question_encoded = self.question_encoder(
            question_embedded, question_mask)

        if self.share_encoder:
            passage_encoded = self.question_encoder(passage_embedded, passage_mask)
        else:
            passage_encoded = self.passage_encoder(passage_embedded, passage_mask)

        # Gated question-passage attention, followed by passage self-matching attention.
        passage_encoded = self.pair_encoder(
            passage_encoded, passage_mask, question_encoded, question_mask)
        passage_encoded = self.self_encoder(
            passage_encoded, passage_mask, passage_encoded, passage_mask)

        # The pointer network produces start and end logits over passage positions.
        span_start_logits, span_end_logits = self.output_layer(
            question_encoded, question_mask, passage_encoded, passage_mask)

        # Compute span probabilities, predictions, and loss. The code below is
        # adapted from allennlp.models.BidirectionalAttentionFlow.
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(
            span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(
            span_end_logits, passage_mask, -1e7)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if span_start is not None:
            # Sum of negative log-likelihoods of the gold start and end positions.
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                            span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                             span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            batch_size = question_embedded.size(0)
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        """Return, for each batch element, the span (start, end) with start <= end
        that maximizes span_start_logits[start] + span_end_logits[end]."""
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                # Track the best start logit seen up to position j, which
                # guarantees the predicted start never follows the end.
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                # Keep the span with the highest combined start + end score.
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
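

if __name__ == "__main__":
    # A minimal sanity check for get_best_span: it exercises only the static
    # method, so no vocabulary, embedder, or trained weights are needed. With
    # these logits the best span starts at index 1 and ends at index 2.
    start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
    end_logits = torch.tensor([[0.0, 0.1, 3.0, 0.2]])
    print(RNet.get_best_span(start_logits, end_logits))  # tensor([[1, 2]])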