import logging
import torch
import numpy as np
from torch.nn.functional import nll_loss
from typing import Optional, Dict, List, Any

from allennlp.models.model import Model
from allennlp.data import Vocabulary
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
from allennlp.nn import InitializerApplicator, RegularizerApplicator, util
from allennlp.training.metrics import Average, BooleanAccuracy, CategoricalAccuracy
from allennlp.modules.input_variational_dropout import InputVariationalDropout

from allennlp.tools import squad_eval

from models.layers import FusionLayer, BilinearSeqAtt

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@Model.register("slqa_h")
class MultiGranularityHierarchicalAttentionFusionNetworks(Model):
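    """
    An AllenNLP model implementing a multi-granularity hierarchical attention
    fusion network (SLQA) for conversational reading comprehension.  The
    passage is encoded once per dialog and repeated across turns; question and
    passage are co-attended and fused, the passage is self-aligned, and a
    summarized question vector drives bilinear span prediction plus a 3-way
    yes/no classifier.
    """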

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 projected_layer: Seq2SeqEncoder,
                 flow_layer: Seq2SeqEncoder,
                 contextual_passage: Seq2SeqEncoder,
                 contextual_question: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 regularizer: Optional[RegularizerApplicator] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):

        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._encoding_dim = self._phrase_layer.get_output_dim()
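        # Re-injects ELMo (1024-d) into the encoded text and projects back
        # down to the encoding dimension.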
        self.projected_layer = torch.nn.Linear(self._encoding_dim + 1024, self._encoding_dim)
        self.fuse = FusionLayer(self._encoding_dim)
        self.projected_lstm = projected_layer
        self.flow = flow_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        self.linear_self_align = torch.nn.Linear(self._encoding_dim, 1)
        self.bilinear_layer_s = BilinearSeqAtt(self._encoding_dim, self._encoding_dim)
        self.bilinear_layer_e = BilinearSeqAtt(self._encoding_dim, self._encoding_dim)
        self.yesno_predictor = torch.nn.Linear(self._encoding_dim, 3)
        self.relu = torch.nn.ReLU()

        self._max_span_length = 30  # longest answer span, in tokens, considered at decoding

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

        initializer(self)

    def forward(self, question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: Optional[torch.IntTensor] = None,
                span_end: Optional[torch.IntTensor] = None,
                yesno_list: Optional[torch.IntTensor] = None,
                metadata: Optional[List[Dict[str, Any]]] = None) -> Dict[str, torch.Tensor]:
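        """
        ``question`` carries an extra dialog dimension
        (batch_size, max_qa_count, max_q_len, ...); ``passage`` is encoded
        once per dialog and repeated across turns.  ``span_start``,
        ``span_end`` and ``yesno_list`` are gold labels padded with -1 for
        missing turns.  ``metadata`` holds the original passage text, token
        offsets, instance ids and reference answers used for decoding and
        the official F1.
        """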

        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        # Labels are padded with -1, so non-negative entries mark real turns;
        # guard against label-free (inference-time) batches.
        qa_mask = torch.ge(yesno_list, 0).view(total_qa_count) if yesno_list is not None else None

        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        # total_qa_count * max_q_len * encoding_dim
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len, self._text_field_embedder.get_output_dim())
        embedded_passage = self._text_field_embedder(passage)

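        # The embedder output is assumed to be the concatenation
        # [word embeddings (200) | ELMo (1024) | manual features (40)];
        # split it apart so ELMo can be re-injected after the phrase layer.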
        word_emb_ques, elmo_ques, ques_feat = torch.split(embedded_question, [200, 1024, 40], dim=2)
        word_emb_pass, elmo_pass, pass_feat = torch.split(embedded_passage, [200, 1024, 40], dim=2)
        embedded_question = torch.cat([word_emb_ques, elmo_ques], dim=2)
        embedded_passage = torch.cat([word_emb_pass, elmo_pass], dim=2)

        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(embedded_passage)
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

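        # Encode both sides with the shared phrase layer, then concatenate
        # ELMo back in and project down to the encoding dimension.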
        encode_passage = self._phrase_layer(embedded_passage, passage_mask)
        projected_passage = self.relu(self.projected_layer(torch.cat([encode_passage, elmo_pass], dim=2)))

        encode_question = self._phrase_layer(embedded_question, question_mask)
        projected_question = self.relu(self.projected_layer(torch.cat([encode_question, elmo_ques], dim=2)))

        encoded_passage = self._variational_dropout(projected_passage)
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count, passage_length, self._encoding_dim)
        repeated_pass_feat = (pass_feat.unsqueeze(1).repeat(1, max_qa_count, 1, 1)).view(total_qa_count, passage_length, 40)
        encoded_question = self._variational_dropout(projected_question)

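        # Word-level co-attention: s[k, i, j] scores question token i against
        # passage token j.  alpha normalizes over question tokens to form
        # question-aware passage vectors; beta normalizes over passage tokens
        # to form passage-aware question vectors.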
        # Similarity matrix s: cnt * m * n == total_qa_count * max_q_len * passage_length
        s = torch.bmm(encoded_question, repeated_encoded_passage.transpose(2, 1))
        alpha = util.masked_softmax(s, question_mask.unsqueeze(2).expand(s.size()), dim=1)
        # cnt * n * h
        aligned_p = torch.bmm(alpha.transpose(2, 1), encoded_question)

        # cnt * m * n
        beta = util.masked_softmax(s, repeated_passage_mask.unsqueeze(1).expand(s.size()), dim=2)
        # cnt * m * h
        un_aligned_q = torch.bmm(beta, repeated_encoded_passage)

        # Dialog flow: regroup so the flow encoder runs across dialog turns
        # (QA pairs) at each question position.
        # (b * qa) * m * h -> (b * m) * qa * h
        un_aligned_q = (un_aligned_q.reshape(batch_size, max_qa_count, max_q_len, -1)
                        .transpose(2, 1)
                        .reshape(batch_size * max_q_len, max_qa_count, -1))
        tmp_q_mask = (question_mask.reshape(batch_size, max_qa_count, max_q_len)
                      .transpose(2, 1)
                      .reshape(batch_size * max_q_len, max_qa_count))
        aligned_q = (self.flow(un_aligned_q, tmp_q_mask)
                     .reshape(batch_size, max_q_len, max_qa_count, -1)
                     .transpose(2, 1)
                     .reshape(total_qa_count, max_q_len, self._encoding_dim))

        fused_p = self.fuse(repeated_encoded_passage, aligned_p)
        fused_q = self.fuse(encoded_question, aligned_q)

        # add manual features here
        q_aware_p = self.projected_lstm(torch.cat([fused_p, repeated_pass_feat], dim=2), repeated_passage_mask)

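        # Passage self-alignment: score every passage token against every
        # other one; the diagonal is zeroed so a token's similarity with
        # itself does not dominate the softmax.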
        # cnt * n * n
        self_p = torch.bmm(q_aware_p, q_aware_p.transpose(2, 1))
        diag_idx = torch.arange(passage_length, device=self_p.device)
        self_p[:, diag_idx, diag_idx] = 0
        lamb = util.masked_softmax(self_p, repeated_passage_mask.unsqueeze(1).expand(self_p.size()), dim=2)
        # cnt * n * h
        self_aligned_p = torch.bmm(lamb, q_aware_p)

        # cnt * n * h
        fused_self_p = self.fuse(q_aware_p, self_aligned_p)
        contextual_p = self.contextual_layer_p(fused_self_p, repeated_passage_mask)

        contextual_q = self.contextual_layer_q(fused_q, question_mask)
        # cnt * m
        gamma = util.masked_softmax(self.linear_self_align(contextual_q).squeeze(2), question_mask, dim=1)
        # cnt * h
        weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1)

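        # Bilinear span prediction: score each passage position against the
        # summarized question vector, separately for start and end.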
        span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p)
        span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p)

        # Outer product of the end logits (cnt * n * 1) with the question
        # summary (cnt * 1 * h) yields cnt * n * h, projected to three yes/no
        # class logits per passage position: cnt * n * 3.
        span_yesno_logits = self.yesno_predictor(torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1)))

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

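        # Greedily decode the best (start, end) span and its yes/no label.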
        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits, span_yesno_logits, self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss for training

        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1), ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # Select the yes/no logits at the gold (and predicted) end
            # positions.  span_yesno_logits is cnt * n * 3, so after view(-1)
            # the three logits for question i at end position e sit at flat
            # indices 3 * (i * passage_length + e) + {0, 1, 2}; max(..., 0)
            # guards against ignored targets (span_end == -1).
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).data.cpu().numpy()
            for i in range(total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new_tensor(gold_span_end_loc, dtype=torch.long)
            pred_span_end_loc = []
            for i in range(total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new_tensor(pred_span_end_loc, dtype=torch.long)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)

            output_dict["loss"] = loss

        # Compute the official leave-one-out F1 and add the decoded spans,
        # question ids, and yes/no predictions to the output.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                # Reset for every question so a turn without reference answers
                # does not inherit the previous turn's F1.
                f1_score = 0.0
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                yesno_pred = predicted_span[2]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_query_id_list.append(iid)
                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Leave one reference out at a time: score the
                        # prediction against the remaining N - 1 human
                        # references, then average over all rounds.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, refs))
                        f1_score = sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
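        # Map predicted yes/no label indices back to token strings in the
        # "yesno_labels" vocabulary namespace.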
        yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list] for yn_list in output_dict.pop("yesno")]
        output_dict['yesno'] = yesno_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'yesno': self._span_yesno_accuracy.get_metric(reset),
                'f1': self._official_f1.get_metric(reset), }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
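        """
        Greedy span decoding: for each row, track the best start logit seen
        so far and pick the end position that maximizes start + end within
        ``max_span_length`` tokens, then read off the yes/no prediction at
        that end position.  Returns a (batch_size, 3) long tensor of
        (span_start, span_end, yesno).
        """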
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 3), dtype=torch.long)
        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()

        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            # Read the yes/no prediction at the chosen end position.
            j = int(best_word_span[b_i, 1])
            best_word_span[b_i, 2] = int(np.argmax(span_yesno_logits[b_i, j]))
        return best_word_span