"""
BERT stands for Bidirectional Encoder Representations from Transformers.

It's a way of pre-training a Transformer to model a language, described in
the paper [BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding](https://arxiv.org/abs/1810.04805). A quote from it:

> BERT is designed to pre-train deep bidirectional representations
> by jointly conditioning on both left and right context in all layers.
> As a result, the pre-trained BERT representations can be fine-tuned
> with just one additional output layer to create state-of-the-art
> models for a wide range of tasks, such as question answering
> and language inference, without substantial task-specific architecture
> modifications.
"""

import random
from itertools import islice, chain
from typing import List, Callable

import numpy as np
# noinspection PyPep8Naming
from keras import backend as K
from keras.utils import get_custom_objects


class BatchGeneratorForBERT:
    """
    This class generates batches for a BERT-based language model
    in an abstract way, using an external function that samples
    sequences of token IDs of a given length.
    """

    reserved_positions = 3

    def __init__(self, sampler: Callable[[int], List[int]],
                 dataset_size: int,
                 sep_token_id: int,
                 cls_token_id: int,
                 mask_token_id: int,
                 first_normal_token_id: int,
                 last_normal_token_id: int,
                 sequence_length: int,
                 batch_size: int,
                 sentence_min_span: float = 0.25):
        """
        :param sampler: A callable object responsible for uniformly sampling
            pieces of the dataset (already turned into token IDs).
            It should take one int argument - the sample length, and return
            a list of token IDs of the requested size.
        :param dataset_size: How big the whole dataset is, measured in number
            of token IDs.
        :param sep_token_id: ID of a token used as a separator between
            the sentences (called "[SEP]" in the paper).
        :param cls_token_id: ID of a token marking the node/position
            responsible for classification (always the first node).
            The token is called "[CLS]" in the original paper.
        :param mask_token_id: ID of a token masking the original words
            of the sentence, which the network should learn to "restore" using
            the context.
        :param first_normal_token_id: ID of the first token representing
            a normal word/token, not a specialized token, like "[SEP]".
        :param last_normal_token_id: ID of the last token representing
            a normal word, not a specialized token.
        :param sequence_length: a sequence length that can be accepted
            by the model being trained / validated.
        :param batch_size: how many samples each batch should include.
        :param sentence_min_span: A floating-point number between 0 and 1,
            indicating the fraction of the `sequence_length` that
            the shortest sentence should occupy. For example,
            if the value is 0.25, each sentence will vary in length from 25%
            to 75% of the whole `sequence_length` (minus 3 reserved positions
            for [CLS] and [SEP] tokens).
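
        Example (a purely hypothetical sampler over a synthetic list of
        token IDs; real code would sample from a tokenized corpus):

        >>> corpus_ids = [random.randint(5, 99) for _ in range(100000)]
        >>> def sampler(length):
        ...     start = random.randint(0, len(corpus_ids) - length)
        ...     return corpus_ids[start:start + length]
        >>> generator = BatchGeneratorForBERT(
        ...     sampler, dataset_size=len(corpus_ids),
        ...     sep_token_id=1, cls_token_id=2, mask_token_id=3,
        ...     first_normal_token_id=5, last_normal_token_id=99,
        ...     sequence_length=128, batch_size=16)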
        """
        self.sampler = sampler
        self.steps_per_epoch = (
            # We sample the dataset randomly. So we can make only a crude
            # estimation of how many steps it should take to cover most of it.
            dataset_size // (sequence_length * batch_size))
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.sep_token_id = sep_token_id
        self.cls_token_id = cls_token_id
        self.mask_token_id = mask_token_id
        self.first_token_id = first_normal_token_id
        self.last_token_id = last_normal_token_id
        assert 0.0 < sentence_min_span <= 1.0
        self.sentence_min_length = max(
            int(sentence_min_span *
                (self.sequence_length - self.reserved_positions)),
            1)
        self.sentence_max_length = (
            self.sequence_length
            - self.reserved_positions
            - self.sentence_min_length)

    def generate_batches(self):
        """
        Keras-compatible generator of batches for BERT (can be used with
        `keras.models.Model.fit_generator`).

        Generates tuples of (inputs, targets).
        `inputs` is a list of two values:
            1. masked_sequence: an integer tensor shaped as
               (batch_size, sequence_length), containing token ids of
               the input sequence, with some words masked by the [MASK] token.
            2. segment_id: an integer tensor shaped as
               (batch_size, sequence_length),
               containing 0 or 1 depending on which segment (A or B)
               each position belongs to.

        `targets` is also a list of two values:
            1. combined_label: an integer tensor of shape
               (batch_size, sequence_length, 2), containing both
               - the original token ids
               - and the mask (0s and 1s, indicating places where
                 a word has been replaced),
               stacked along the last dimension.
               So combined_label[:, :, 0] would slice only the token ids,
               and combined_label[:, :, 1] would slice only the mask.
            2. has_next: a float32 tensor of shape (batch_size, 1) containing
               1s for all samples where "sentence B" directly follows
               "sentence A", and 0s otherwise.
        """
        samples = self.generate_samples()
        while True:
            next_bunch_of_samples = islice(samples, self.batch_size)
            has_next, mask, sequence, segment, masked_sequence = zip(
                *list(next_bunch_of_samples))
            combined_label = np.stack([sequence, mask], axis=-1)
            yield (
                [np.array(masked_sequence), np.array(segment)],
                [combined_label,
                 np.expand_dims(np.array(has_next, dtype=np.float32), axis=-1)]
            )

    def generate_samples(self):
        """
        Generates samples, one by one, for later concatenation into batches
        by `generate_batches()`.
        """
        while True:
            # Sentence A takes between sentence_min_length and
            # sentence_max_length positions (25% to 75% of the available
            # space with the default sentence_min_span of 0.25)
            a_length = random.randint(
                self.sentence_min_length,
                self.sentence_max_length)
            b_length = (
                self.sequence_length - self.reserved_positions - a_length)

            # Sampling sentences A and B,
            # making sure they follow each other 50% of the time
            has_next = random.random() < 0.5
            if has_next:
                # sentence B is a continuation of A
                full_sample = self.sampler(a_length + b_length)
                sentence_a = full_sample[:a_length]
                sentence_b = full_sample[a_length:]
            else:
                # sentence B is not a continuation of A
                # note that in theory the same or overlapping sentence
                # can be selected as B, but it's highly improbable
                # and shouldn't affect the performance
                sentence_a = self.sampler(a_length)
                sentence_b = self.sampler(b_length)
            assert len(sentence_a) == a_length
            assert len(sentence_b) == b_length
            sequence = (
                [self.cls_token_id] +
                sentence_a + [self.sep_token_id] +
                sentence_b + [self.sep_token_id])
            masked_sequence = sequence.copy()
            output_mask = np.zeros((len(sequence),), dtype=int)
            segment_id = np.full((len(sequence),), 1, dtype=int)
            segment_id[:a_length + 2] = 0
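            # Masked LM objective as in the BERT paper: each normal token
            # is selected for prediction with probability 15%; of those,
            # 80% are replaced with [MASK], 10% with a random token,
            # and 10% are left unchanged.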
            for word_pos in chain(
                    range(1, a_length + 1),
                    range(a_length + 2, a_length + 2 + b_length)):
                if random.random() < 0.15:
                    dice = random.random()
                    if dice < 0.8:
                        masked_sequence[word_pos] = self.mask_token_id
                    elif dice < 0.9:
                        masked_sequence[word_pos] = random.randint(
                            self.first_token_id, self.last_token_id)
                    # else: 10% of the time we just leave the word as is
                    output_mask[word_pos] = 1
            yield (int(has_next), output_mask, sequence,
                   segment_id, masked_sequence)


def masked_perplexity(y_true, y_pred):
    """
    Masked version of a popular metric for evaluating the performance of
    language modelling architectures. It assumes that y_true has shape
    (batch_size, sequence_length, 2), containing both
      - the original token ids
      - and the mask (0s and 1s, indicating places where
        a word has been replaced),
    stacked along the last dimension, while y_pred holds the predicted
    probability distributions over the vocabulary.
    Masked perplexity ignores all but masked words.

    More info: http://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf
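
    A tiny worked sketch with toy tensors (a single masked position and
    uniform predictions over a 10-token vocabulary, so the result is ~10):

    >>> y_true = K.constant([[[5, 1], [7, 0]]])        # token ids + mask
    >>> y_pred = K.constant(np.full((1, 2, 10), 0.1))  # (batch, seq, vocab)
    >>> perplexity = K.eval(masked_perplexity(y_true, y_pred))  # ~10.0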
    """
    y_true_value = y_true[:, :, 0]
    mask = y_true[:, :, 1]
    cross_entropy = K.sparse_categorical_crossentropy(y_true_value, y_pred)
    batch_perplexities = K.exp(
        K.sum(mask * cross_entropy, axis=-1) / (K.sum(mask, axis=-1) + 1e-6))
    return K.mean(batch_perplexities)


class MaskedPenalizedSparseCategoricalCrossentropy:
    """
    Masked cross-entropy loss function with penalized confidence
    (see `masked_perplexity` for details on how the masking works).
    It combines two terms: cross-entropy and negative entropy
    (weighted by the `penalty_weight` parameter), following the paper
    "Regularizing Neural Networks by Penalizing Confident Output Distributions"
    (https://arxiv.org/abs/1701.06548).
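
    In terms of the implementation below, the per-sample loss is
        masked_cross_entropy - penalty_weight * masked_entropy
    with both terms averaged over the masked positions only, so that
    confident (low-entropy) output distributions are penalized.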

    How to use:
    >>> model.compile(
    ...     optimizer,
    ...     loss=MaskedPenalizedSparseCategoricalCrossentropy(0.1))
    """
    def __init__(self, penalty_weight: float):
        self.penalty_weight = penalty_weight

    def __call__(self, y_true, y_pred):
        y_true_val = y_true[:, :, 0]
        mask = y_true[:, :, 1]

        # masked per-sample means of each loss
        num_items_masked = K.sum(mask, axis=-1) + 1e-6
        masked_cross_entropy = (
            K.sum(mask * K.sparse_categorical_crossentropy(y_true_val, y_pred),
                  axis=-1)
            / num_items_masked)
        masked_entropy = (
            K.sum(mask * -K.sum(y_pred * K.log(y_pred), axis=-1), axis=-1)
            / num_items_masked)
        return masked_cross_entropy - self.penalty_weight * masked_entropy

    def get_config(self):
        return {
            'penalty_weight': self.penalty_weight
        }


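# Register the custom loss and metric in Keras' global custom objects
# dictionary, so that models compiled with them can be restored via
# `keras.models.load_model` without passing `custom_objects` explicitly.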
get_custom_objects().update({
    'MaskedPenalizedSparseCategoricalCrossentropy':
        MaskedPenalizedSparseCategoricalCrossentropy,
    'masked_perplexity': masked_perplexity,
})