from typing import Dict, Optional, List, Any
import logging

from overrides import overrides
import torch
from torch.nn.modules.linear import Linear

from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
from allennlp.modules import ConditionalRandomField
from allennlp.nn import InitializerApplicator, RegularizerApplicator
import allennlp.nn.util as util
from allennlp.training.metrics import CategoricalAccuracy, F1Measure

logger = logging.getLogger(__name__)

@Model.register("pico_crf_tagger")
class PicoCrfTagger(Model):
    """
    Exactly like the CrfTagger in AllenNLP:
    https://github.com/allenai/allennlp/blob/master/allennlp/models/crf_tagger.py

    But differences include:
    - No option for `constrain_crf_decoding` because only supports IO-encoding
      (because that's how EBMNLP dataset is annotated)
    - No option for `calculate_span_f1` because PICO is evaluated at token-level
    - No option for `verbose_metrics`.  Defaults to printing all because in PICO,
      we want to see F1 scores for each class.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 include_start_end_transitions: bool = True,
                 dropout: Optional[float] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self.label_namespace = "labels"
        self.num_tags = self.vocab.get_vocab_size(self.label_namespace)

        # encode text
        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
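        # Sanity check borrowed from AllenNLP's CrfTagger: the embedder's
        # output width must match what the encoder expects.
        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")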
        self.dropout = torch.nn.Dropout(dropout) if dropout else None

        # crf
        output_dim = self.encoder.get_output_dim()
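        # Project the encoded sequence to per-tag scores at each time step.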
        self.tag_projection_layer = TimeDistributed(Linear(output_dim, self.num_tags))
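        # With IO encoding any tag may follow any other, so the CRF needs
        # no transition constraints.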
        self.crf = ConditionalRandomField(self.num_tags,
                                          constraints=None,
                                          include_start_end_transitions=include_start_end_transitions)

        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3)
        }
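        # One token-level F1 per PICO class, so get_metrics() can report
        # each class as well as their macro-average.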
        for index, label in self.vocab.get_index_to_token_vocabulary(self.label_namespace).items():
            self.metrics['F1_' + label] = F1Measure(positive_label=index)

        initializer(self)

    @overrides
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                tags: Optional[torch.LongTensor] = None,
                metadata: Optional[List[Dict[str, Any]]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
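        """
        Embeds and encodes the input tokens, projects to tag logits, and
        Viterbi-decodes the most likely tag sequence with the CRF.  If gold
        ``tags`` are provided, also computes the negative CRF log-likelihood
        as the loss and updates the token-level metrics.
        """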

        # (batch, tokens, dim)
        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)
        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        # (batch, tokens, dim)
        encoded_text = self.encoder(embedded_text_input, mask)
        if self.dropout:
            encoded_text = self.dropout(encoded_text)

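        # (batch, tokens, num_tags)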
        logits = self.tag_projection_layer(encoded_text)
        best_paths = self.crf.viterbi_tags(logits, mask)

        # Just get the tags and ignore the score.
        predicted_tags = [x for x, y in best_paths]

        output = {"logits": logits, "mask": mask, "tags": predicted_tags}

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(logits, tags, mask)
            output["loss"] = -log_likelihood

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = torch.zeros_like(logits)
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            for metric in self.metrics.values():
                metric(class_probabilities, tags, mask.float())

        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids to the actual tags.
        ``output_dict["tags"]`` is a list of lists of tag_ids,
        so we use an ugly nested list comprehension.
        """
        output_dict["tags"] = [
            [
                self.vocab.get_token_from_index(tag, namespace=self.label_namespace)
                for tag in instance_tags
            ]
            for instance_tags in output_dict["tags"]
        ]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
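        """
        Returns token-level accuracies, per-class F1 scores, and their
        macro-average (``avg_f1``).
        """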
        metrics_to_return = {}

        total_f1, total_classes = 0, 0
        for metric_name, metric_obj in self.metrics.items():
            if metric_name.startswith('accuracy'):
                metrics_to_return[metric_name] = metric_obj.get_metric(reset)
            elif metric_name.startswith('F1_'):
                _precision, _recall, f1 = metric_obj.get_metric(reset)
                metrics_to_return[metric_name] = f1
                total_f1 += f1
                total_classes += 1
        # Macro-average the per-class F1 scores.
        metrics_to_return['avg_f1'] = total_f1 / total_classes if total_classes else 0.0

        return metrics_to_return
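

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original training setup.
    # In practice the model is built from a config file via its registered
    # name ("pico_crf_tagger"); the vocabulary entries, label set, and
    # embedder/encoder dimensions below are illustrative assumptions.
    from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import Embedding

    vocab = Vocabulary()
    for token in ["patients", "received", "placebo"]:
        vocab.add_token_to_namespace(token, namespace="tokens")
    for label in ["O", "I-PAR", "I-INT", "I-OUT"]:  # assumed IO label set
        vocab.add_token_to_namespace(label, namespace="labels")

    embedder = BasicTextFieldEmbedder({
        "tokens": Embedding(num_embeddings=vocab.get_vocab_size("tokens"),
                            embedding_dim=16)
    })
    encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(16, 8, batch_first=True, bidirectional=True))
    model = PicoCrfTagger(vocab, embedder, encoder, dropout=0.1)

    # A fake batch: 2 sequences of 5 token ids (0 is padding, 1 is OOV).
    fake_tokens = {"tokens": torch.randint(2, vocab.get_vocab_size("tokens"), (2, 5))}
    output = model(fake_tokens)
    print(model.decode(output)["tags"])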