import argparse
import os
from itertools import islice
from typing import Iterable, List, Optional

from keras import optimizers, losses
from keras.models import load_model
# noinspection PyPep8Naming
from keras import backend as K
from keras import callbacks
import numpy as np

from . import wikitext
from .bpe import BPEEncoder, ID_FOR_PADDING
from .utils import (
    load_optimizer_weights, contain_tf_gpu_mem_usage, CosineLRSchedule)
from .models import (
    universal_transformer_gpt_model, vanilla_transformer_gpt_model)


def pad_lm_samples(samples: Iterable[List[int]],
                   required_sequence_length: int):
    tail_padding = [ID_FOR_PADDING]
    for sample in samples:
        assert len(sample) > 0
        sample.extend(tail_padding * (required_sequence_length - len(sample)))


def training_data_to_samples(training_set_name: str,
                             encoder: BPEEncoder,
                             max_sequence_length: int) -> np.ndarray:
    """
    Reads WikiText dataset, interpreting each line as an independent sequence,
    then splits those lines with BPE tokenizer and turns them into word ids
    based on previously constructed BPE vocabulary (both the tokenizer
    and the vocabulary are parts of the BPEEncoder instance).

    Those word id's then packed into a matrix the size of
    (number of lines x max_sequence_length + 1), which can be later sliced
    to get X and Y matrices of sequences for training).
    """
    training_set = wikitext.read_wikitext_file(training_set_name)
    useful_sequences = []
    for line in training_set.splitlines():
        clean_line = line.strip()
        is_header = clean_line.startswith('=') and clean_line.endswith('=')
        if is_header or not clean_line:
            continue
        # the encoder is supposed to add <SEQ> and </SEQ>
        id_word_pairs = list(encoder(clean_line))
        useful_sequences.append(
            [word_id for word_id, _ in id_word_pairs[:max_sequence_length]])

    pad_lm_samples(useful_sequences, max_sequence_length + 1)
    result = np.empty(
        (len(useful_sequences), max_sequence_length + 1),
        dtype='int32')
    for i, sequence in enumerate(useful_sequences):
        result[i, :] = sequence
    return result


def training_data_to_dense_samples(training_set_name: str,
                                   encoder: BPEEncoder,
                                   max_sequence_length: int) -> np.ndarray:
    """
    Reads WikiText dataset, interpreting each line as an independent sequence,
    then splits those lines with BPE tokenizer and turns them into word ids
    based on previously constructed BPE vocabulary (both the tokenizer
    and the vocabulary are parts of the BPEEncoder instance).

    Those word id's then packed into a matrix the size of
    (number of lines x max_sequence_length + 1), which can be later sliced
    to get X and Y matrices of sequences for training).
    """
    training_set = wikitext.read_wikitext_file(training_set_name)
    useful_sequences = []

    def stream_bpe_tokens():
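        # Yields (word_id, word) pairs for the entire dataset as one
        # continuous stream; only empty lines are skipped.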
        for line in training_set.splitlines():
            clean_line = line.strip()
            if not clean_line:
                continue
            # the encoder is supposed to add <SEQ> and </SEQ>
            id_word_pairs = encoder(clean_line)
            yield from id_word_pairs

    id_word_stream = stream_bpe_tokens()
    while True:
        chunk = list(islice(id_word_stream, max_sequence_length))
        if not chunk:
            break
        sample_sequence = [word_id for word_id, _ in chunk]
        useful_sequences.append(sample_sequence)

    pad_lm_samples(useful_sequences, max_sequence_length + 1)
    result = np.empty(
        (len(useful_sequences), max_sequence_length + 1),
        dtype='int32')
    for i, sequence in enumerate(useful_sequences):
        result[i, :] = sequence
    return result


def perplexity(y_true, y_pred):
    """
    Popular metric for evaluating language modelling architectures.
    More info: http://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf
    """
    cross_entropy = K.sparse_categorical_crossentropy(y_true, y_pred)
    return K.mean(K.exp(K.mean(cross_entropy, axis=-1)))


def main(model_save_path: str,
         model_name: str,
         tensorboard_log_path: Optional[str],
         num_epochs: int,
         learning_rate: float,
         batch_size: int,
         max_seq_length: int,
         word_embedding_size: int,
         load_weights_only: bool,
         show_model_summary: bool):
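    # contain_tf_gpu_mem_usage presumably configures TensorFlow so it does
    # not grab all available GPU memory up front (helper from .utils).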
    contain_tf_gpu_mem_usage()
    encoder = wikitext.build_wikitext_bpe_encoder()

    def x_y_for_dataset(dataset_name):
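        # Each padded row of length max_seq_length + 1 is split into an input
        # sequence (all but the last token) and a target sequence shifted by
        # one position; the extra trailing axis gives the targets the shape
        # expected by sparse_categorical_crossentropy.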
        fat_sample = training_data_to_dense_samples(
            dataset_name, encoder, max_seq_length)
        _x = fat_sample[:, :max_seq_length]
        _y = np.expand_dims(fat_sample[:, 1:], axis=-1)
        return _x, _y

    x, y = x_y_for_dataset(wikitext.TRAINING_SET_NAME)

    def compile_new_model():
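        # The two variants build different architectures but share depth,
        # head count, loss and metric; their Adam optimizers differ in the
        # beta_1 value and in gradient value clipping for 'vanilla'.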
        if model_name == 'universal':
            optimizer = optimizers.Adam(
                lr=learning_rate, beta_1=0.6, beta_2=0.999)
            _model = universal_transformer_gpt_model(
                max_seq_length=max_seq_length,
                vocabulary_size=encoder.vocabulary_size(),
                word_embedding_size=word_embedding_size,
                transformer_depth=5,
                num_heads=8)
            _model.compile(
                optimizer,
                loss=losses.sparse_categorical_crossentropy,
                metrics=[perplexity])
        elif model_name == 'vanilla':
            optimizer = optimizers.Adam(
                lr=learning_rate, beta_1=0.9, beta_2=0.999, clipvalue=5.0)
            _model = vanilla_transformer_gpt_model(
                max_seq_length=max_seq_length,
                vocabulary_size=encoder.vocabulary_size(),
                word_embedding_size=word_embedding_size,
                transformer_depth=5,
                num_heads=8)
            _model.compile(
                optimizer,
                loss=losses.sparse_categorical_crossentropy,
                metrics=[perplexity])
        else:
            raise RuntimeError(f'Unknown model {model_name}')
        return _model

    if os.path.exists(model_save_path):
        if load_weights_only:
            print('Loading weights from', model_save_path)
            model = compile_new_model()
            model.load_weights(
                model_save_path, skip_mismatch=True, by_name=True)
            load_optimizer_weights(model, model_save_path)
        else:
            print('Loading the whole model from', model_save_path)
            model = load_model(
                model_save_path,
                custom_objects={
                    'perplexity': perplexity,
                })
    else:
        model = compile_new_model()

    if show_model_summary:
        model.summary(line_length=120)

    lr_scheduler = callbacks.LearningRateScheduler(
        CosineLRSchedule(lr_high=learning_rate,
                         lr_low=learning_rate / 32,
                         initial_period=num_epochs),
        verbose=1)
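    # The cosine schedule above presumably anneals the learning rate from
    # lr_high down to lr_low over the course of training; the checkpoint
    # below keeps only the best model as measured by validation loss.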
    model_callbacks = [
        callbacks.ModelCheckpoint(
            model_save_path,
            monitor='val_loss', save_best_only=True, verbose=1),
        lr_scheduler,
    ]
    if tensorboard_log_path:
        model_callbacks.append(callbacks.TensorBoard(tensorboard_log_path))
    model.fit(
        x, y,
        validation_data=x_y_for_dataset(wikitext.VALIDATION_SET_NAME),
        batch_size=batch_size, epochs=num_epochs,
        callbacks=model_callbacks)
    # Evaluation using test set
    print('-' * 80)
    test_x, test_y = x_y_for_dataset(wikitext.TEST_SET_NAME)
    test_metrics = model.evaluate(test_x, test_y, batch_size=batch_size)
    for metric_name, metric_value in zip(model.metrics_names, test_metrics):
        print(f'Test {metric_name}:', metric_value)


if __name__ == '__main__':
    _argparser = argparse.ArgumentParser(
        description='A simple example of training a Transformer language model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    _argparser.add_argument(
        '--save', type=str, required=True, metavar='PATH',
        help='A path where the best model should be saved / restored from')
    _argparser.add_argument(
        '--tensorboard-log', type=str, metavar='PATH', default=None,
        help='Path to a directory for Tensorboard logs')
    _argparser.add_argument(
        '--epochs', type=int, default=200, metavar='INTEGER',
        help='The number of epochs to train')
    _argparser.add_argument(
        '--lr', type=float, default=2e-4, metavar='FLOAT',
        help='Learning rate')
    _argparser.add_argument(
        '--batch-size', type=int, default=32, metavar='INTEGER',
        help='Training batch size')
    _argparser.add_argument(
        '--seq-len', type=int, default=256, metavar='INTEGER',
        help='Max sequence length')
    _argparser.add_argument(
        '--we-size', type=int, default=512, metavar='INTEGER',
        help='Word embedding size')
    _argparser.add_argument(
        '--model', type=str, default='universal', metavar='NAME',
        choices=['universal', 'vanilla'],
        help='The type of the model to train: "vanilla" or "universal"')
    _argparser.add_argument(
        '--load-weights-only', action='store_true',
        help='Use the save file only to initialize weights '
             '(do not load the whole model)')
    _argparser.add_argument(
        '--model-summary', action='store_true',
        help='Display the summary of the model before the training begins')
    _args = _argparser.parse_args()

    main(model_save_path=_args.save,
         model_name=_args.model,
         tensorboard_log_path=_args.tensorboard_log,
         num_epochs=_args.epochs,
         learning_rate=_args.lr,
         batch_size=_args.batch_size,
         max_seq_length=_args.seq_len,
         word_embedding_size=_args.we_size,
         load_weights_only=_args.load_weights_only,
         show_model_summary=_args.model_summary)