from typing import Tuple, List, Optional

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import Model
from tensorflow.keras.layers import Embedding, Dropout, Dense, Activation
from tensorflow.keras.layers import LSTM, Bidirectional, Layer
from sacred import Ingredient
import numpy as np

import rinokeras as rk
from rinokeras.layers import Stack

from tape.data_utils.vocabs import PFAM_VOCAB, UNIPROT_BEPLER
from .AbstractTapeModel import AbstractTapeModel

bepler_hparams = Ingredient('bepler')


@bepler_hparams.config
def configure_bepler():
    dropout = 0.1  # noqa: F841
    use_pfam_alphabet: bool = True  # noqa: F841


class RandomReplaceMask(Layer):
    """ Copied from rinokeras because we may eventually want different
    replacement-masking schemes.

    Replaces some percentage of the input tokens with tokens drawn uniformly
    at random from the vocabulary. Used as a noising step when training
    BERT-style models.

    Based on https://arxiv.org/abs/1810.04805.

    Args:
        percentage (float): Fraction of input tokens to replace, in [0, 1)
        n_symbols (int): Vocabulary size to draw replacement tokens from
    """

    def __init__(self,
                 percentage: float,
                 n_symbols: Optional[int] = None,
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if not 0 <= percentage < 1:
            raise ValueError(
                "Masking percentage must be in [0, 1). Received {}".format(percentage))
        self.percentage = percentage
        self.n_symbols = n_symbols

    def _generate_bert_mask(self, inputs):
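        # Draw an independent uniform(0, 1) sample per position; positions whose
        # sample falls below `percentage` are selected for replacement.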
        mask_shape = K.shape(inputs)
        bert_mask = K.random_uniform(mask_shape) < self.percentage
        return bert_mask

    def call(self,
             inputs: tf.Tensor,
             mask: Optional[tf.Tensor] = None):
        """
        Args:
            inputs (tf.Tensor[ndims=2, int]): Tensor of values to mask
            mask (Optional[tf.Tensor[bool]]): Locations in the inputs to that are valid
                                                     (i.e. not padding, start tokens, etc.)
        Returns:
            masked_inputs (tf.Tensor[ndims=2, int]): Tensor of masked values
            bert_mask: Locations in the input that were masked
        """

        random_mask = self._generate_bert_mask(inputs)

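        # Never corrupt invalid positions (padding, start tokens, etc.).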
        if mask is not None:
            random_mask &= mask

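        # Zero out the selected positions, then refill them with tokens drawn
        # uniformly from [0, n_symbols).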
        masked_inputs = inputs * tf.cast(~random_mask, inputs.dtype)

        random_mask = tf.cast(random_mask, inputs.dtype)

        masked_inputs += K.random_uniform(
            K.shape(random_mask), 0, self.n_symbols, dtype=inputs.dtype) * random_mask

        return masked_inputs


class BiLM(Model):

    @bepler_hparams.capture
    def __init__(self, n_symbols: int, dropout: float = 0, use_pfam_alphabet: bool = True):
        super().__init__()

        self._use_pfam_alphabet = use_pfam_alphabet

        if use_pfam_alphabet:
            self.embed = Embedding(n_symbols, n_symbols)
        else:
            n_symbols = 21
            self.embed = Embedding(n_symbols + 1, n_symbols)

        self.dropout = Dropout(dropout)
        self.rnn = Stack([
            LSTM(1024, return_sequences=True, use_bias=True,
                 implementation=2, recurrent_activation='sigmoid'),
            LSTM(1024, return_sequences=True, use_bias=True,
                 implementation=2, recurrent_activation='sigmoid')])

        self.compute_logits = Dense(n_symbols, use_bias=True, activation='linear')

    def transform(self, z_fwd, z_rvs, mask_fwd, mask_rvs, sequence_lengths):
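        # The forward and reverse streams run through the same stacked LSTMs; the
        # reverse hidden states are flipped back into forward time order
        # (lengths - 1 because the final position was dropped in embed_and_split).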
        h_fwd = []
        h = z_fwd

        for layer in self.rnn.layers:
            h = layer(h, mask=mask_fwd)
            h = self.dropout(h)
            h_fwd.append(h)

        h_rvs = []
        h = z_rvs
        for layer in self.rnn.layers:
            h = layer(h, mask=mask_rvs)
            h = self.dropout(h)
            h_rvs.append(
                tf.reverse_sequence(h, sequence_lengths - 1, seq_axis=1))

        return h_fwd, h_rvs

    def embed_and_split(self, x, sequence_lengths, pad=False):
        if pad:
            if not self._use_pfam_alphabet:
                # Shift every token up by one so index 0 is free for the
                # boundary/padding token, then re-zero the padded positions.
                x = x + 1
                mask = rk.utils.convert_sequence_length_to_sequence_mask(x, sequence_lengths)
                x = x * tf.cast(mask, x.dtype)

            # Add start/stop positions (token 0) and account for them in the lengths
            x = tf.pad(x, [[0, 0], [1, 1]])
            sequence_lengths += 2

        mask = rk.utils.convert_sequence_length_to_sequence_mask(x, sequence_lengths)

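        # Embed, then split into a forward stream (all but the last position)
        # and a reverse stream (sequence reversed, then all but the last position).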
        z = self.embed(x)
        z_fwd = z[:, :-1]
        mask_fwd = mask[:, :-1]

        z_rvs = tf.reverse_sequence(z, sequence_lengths, seq_axis=1)[:, :-1]
        mask_rvs = tf.reverse_sequence(mask, sequence_lengths, seq_axis=1)[:, :-1]

        return z_fwd, z_rvs, mask_fwd, mask_rvs, sequence_lengths

    def call(self, inputs, encode=False):
        inputs, sequence_lengths = inputs
        z_fwd, z_rvs, mask_fwd, mask_rvs, sequence_lengths = self.embed_and_split(
            inputs, sequence_lengths, pad=encode)
        h_fwd_list, h_rvs_list = self.transform(
            z_fwd, z_rvs, mask_fwd, mask_rvs, sequence_lengths)

        h_fwd = h_fwd_list[-1]
        h_rvs = h_rvs_list[-1]

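        # Top-layer forward and reverse states, shifted by one so they line up
        # on the same positions before concatenation.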
        lm_outputs = tf.concat((h_fwd[:, 1:], h_rvs[:, :-1]), -1)
        logp_fwd = self.compute_logits(h_fwd)
        logp_rvs = self.compute_logits(h_rvs)

        # pad the forward logits with a zero timestep at the front and the
        # reverse logits with one at the end so the two directions align
        # before being summed
        logp_fwd = tf.pad(logp_fwd, [[0, 0], [1, 0], [0, 0]])
        logp_rvs = tf.pad(logp_rvs, [[0, 0], [0, 1], [0, 0]])

        logp = tf.nn.log_softmax(logp_fwd + logp_rvs)

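        # The returned embedding concatenates the hidden states from every
        # layer in both directions, trimmed to a common length.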
        concat = []
        for h_fwd, h_rvs in zip(h_fwd_list, h_rvs_list):
            h_fwd = h_fwd[:, :-1]
            h_rvs = h_rvs[:, 1:]

            concat.extend([h_fwd, h_rvs])

        h = tf.concat(concat, -1)

        return {'logp': logp, 'h': h, 'lm_outputs': lm_outputs}


class LMEmbed(Model):

    @bepler_hparams.capture
    def __init__(self, n_symbols: int, dropout: float = 0, use_pfam_alphabet: bool = True):
        super().__init__()

        if not use_pfam_alphabet:
            n_symbols = 21

        self.embed = Embedding(n_symbols, 512)
        self.lm = BiLM(n_symbols, dropout)
        self.proj = Dense(512, use_bias=True, activation='linear')
        self.transform = Activation('relu')

    def call(self, inputs):
        inputs, sequence_lengths = inputs
        h_lm = self.lm((inputs, sequence_lengths), encode=True)
        lm_outputs = h_lm['lm_outputs']
        h_lm = h_lm['h']

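        # Sum the learned token embedding with a projection of the BiLM hidden
        # states, then apply a ReLU to produce the combined embedding.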
        h = self.embed(inputs)
        h_lm = self.proj(h_lm)
        h = self.transform(h + h_lm)
        return h, lm_outputs


class BeplerModel(AbstractTapeModel):

    @bepler_hparams.capture
    def __init__(self,
                 n_symbols: int,
                 dropout: float = 0,
                 use_pfam_alphabet: bool = True):
        if not use_pfam_alphabet:
            n_symbols = 21

        super().__init__(n_symbols)
        self._use_pfam_alphabet = use_pfam_alphabet

        self.embed = LMEmbed(n_symbols, dropout)
        self.dropout = Dropout(dropout)
        lstm = Stack([
            Bidirectional(
                LSTM(512, return_sequences=True, use_bias=True,
                     recurrent_activation='sigmoid', implementation=2))
            for _ in range(3)])
        self.rnn = lstm
        self.proj = Dense(100, use_bias=True, activation='linear')
        self.random_replace = RandomReplaceMask(0.05, n_symbols)

    def convert_sequence_vocab(self, sequence):
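        # Map pfam-encoded tokens onto the Bepler/UniProt alphabet with a numpy
        # lookup wrapped in tf.py_func; unknown residues map to 20, padding to 0.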
        PFAM_TO_BEPLER_ENCODED = {
            encoding: UNIPROT_BEPLER.get(aa, 20)
            for aa, encoding in PFAM_VOCAB.items()}
        PFAM_TO_BEPLER_ENCODED[PFAM_VOCAB['<PAD>']] = 0

        def to_uniprot_bepler(seq):
            new_seq = np.zeros_like(seq)

            for pfam_encoding, uniprot_encoding in PFAM_TO_BEPLER_ENCODED.items():
                new_seq[seq == pfam_encoding] = uniprot_encoding

            return new_seq

        new_sequence = tf.py_func(to_uniprot_bepler, [sequence], sequence.dtype)
        new_sequence.set_shape(sequence.shape)

        return new_sequence

    def call(self, inputs):
        sequence = inputs['primary']
        protein_length = inputs['protein_length']

        if not self._use_pfam_alphabet:
            sequence = self.convert_sequence_vocab(sequence)
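        # At training time, randomly replace ~5% of tokens with random vocabulary
        # tokens as a noising step; at test time the sequence is left untouched.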
        sequence = K.in_train_phase(self.random_replace(sequence), sequence)

        mask = rk.utils.convert_sequence_length_to_sequence_mask(sequence, protein_length)
        embed, lm_outputs = self.embed((sequence, protein_length))
        tf.add_to_collection('checkpoints', embed)
        rnn_out = self.rnn(embed, mask=mask)
        tf.add_to_collection('checkpoints', rnn_out)
        rnn_out = self.dropout(rnn_out)
        proj = self.proj(rnn_out)
        tf.add_to_collection('checkpoints', proj)

        inputs['encoder_output'] = proj
        inputs['lm_outputs'] = lm_outputs

        return inputs

    def get_optimal_batch_sizes(self) -> Tuple[List[int], List[int]]:
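        # Sequence-length bucket boundaries and per-bucket batch sizes (scaled by
        # available GPU memory); there is one more batch size than boundary so the
        # last entry covers sequences longer than the final bucket.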
        bucket_sizes = np.array([100, 200, 300, 400, 600, 900, 1000, 1200, 1300, 2000, 3000])
        batch_sizes = np.array([5, 5, 4, 3, 2, 1, 0.75, 0.5, 0.5, 0.25, 0, 0])

        batch_sizes = np.asarray(batch_sizes * self._get_gpu_memory(), np.int32)
        return bucket_sizes, batch_sizes
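

# Minimal usage sketch (hypothetical tensor names; assumes the `bepler` ingredient
# is attached to a running sacred Experiment so the captured hyperparameters are
# injected):
#
#     model = BeplerModel(n_symbols=len(PFAM_VOCAB))
#     outputs = model({'primary': sequence_ids,         # int32 [batch, length]
#                      'protein_length': seq_lengths})  # int32 [batch]
#     encoder_output = outputs['encoder_output']        # float32 [batch, length, 100]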