# coding=utf-8

"""
Implementation of the RNN model
"""

import torch
import torch.nn as tnn
import torch.nn.utils.rnn as tnnur

import models.vocabulary as mv


class RNN(tnn.Module):
    """
    Implements an N-layer LSTM network with an embedding layer in front
    and an output linear layer back to the size of the vocabulary.
    """

    def __init__(self, vocabulary_size, num_dimensions, num_layers, embedding_layer_size, dropout):
        """
        Builds the embedding layer, the stacked LSTM and the output linear layer.
        :param vocabulary_size: Size of the vocabulary.
        :param num_dimensions: Size of each of the LSTM layers.
        :param num_layers: Number of LSTM layers.
        :param embedding_layer_size: Size of the embedding layer.
        :param dropout: Dropout applied after the embedding and between LSTM layers.
        :return:
        """
        super(RNN, self).__init__()

        self.num_dimensions = num_dimensions
        self.embedding_layer_size = embedding_layer_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.vocabulary_size = vocabulary_size

        self._embedding = tnn.Sequential(
            tnn.Embedding(self.vocabulary_size, self.embedding_layer_size),
            tnn.Dropout(self.dropout)
        )
        self._rnn = tnn.LSTM(self.embedding_layer_size, self.num_dimensions,
                             num_layers=self.num_layers, dropout=self.dropout, batch_first=True)
        self._linear = tnn.Linear(self.num_dimensions, self.vocabulary_size)

    def forward(self, padded_seqs, seq_lengths, hidden_state=None):  # pylint: disable=W0221
        """
        Performs a forward pass on the model. Note: you pass the **whole** sequence.
        :param padded_seqs: Padded input tensor (batch_size, seq_size).
        :param seq_lengths: Length of each sequence in the batch (tensor or sequence of numbers).
        :param hidden_state: Optional pair (h, c) of LSTM states; zeros are used when omitted.
        :return: A tuple with the logits (zeroed at padded positions) and the output hidden state.
        """
        # pack_padded_sequence expects the lengths as a 1D int64 tensor living on the CPU,
        # whatever device/dtype the caller used.
        lengths = torch.as_tensor(seq_lengths, dtype=torch.int64).cpu()

        # When no state is given the LSTM itself creates zero states on the device and dtype
        # of its input, so the model works on CPU and GPU alike.
        if hidden_state is not None:
            hidden_state = tuple(hidden_state)

        embedded = self._embedding(padded_seqs)  # (batch,seq,embedding)
        packed = tnnur.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False)
        packed_output, hidden_state = self._rnn(packed, hidden_state)
        padded_output, output_lengths = tnnur.pad_packed_sequence(packed_output, batch_first=True)

        # Build the padding mask from the real lengths instead of guessing it from the output
        # values: an output feature may legitimately be exactly zero at a valid position.
        positions = torch.arange(padded_output.size(1), device=padded_output.device)
        valid = positions.unsqueeze(0) < output_lengths.to(padded_output.device).unsqueeze(1)
        mask = valid.unsqueeze(dim=-1).to(padded_output.dtype)  # (batch,seq,1)

        logits = self._linear(padded_output) * mask
        return (logits, hidden_state)

    def get_params(self):
        """
        Returns the configuration parameters of the model.
        :return: A dict with the params of the model.
        """
        return {
            'dropout': self.dropout,
            'num_dimensions': self.num_dimensions,
            'num_layers': self.num_layers,
            'embedding_layer_size': self.embedding_layer_size,
            'vocabulary_size': self.vocabulary_size
        }


class Model:
    """
    Implements an RNN model using SMILES.
    """

    def __init__(self, vocabulary, tokenizer, network_params=None, max_sequence_length=256, no_cuda=False,
                 mode="train"):
        """
        Implements an RNN.
        :param vocabulary: Vocabulary to use.
        :param tokenizer: Tokenizer to use.
        :param network_params: Network params to initialize the RNN.
        :param max_sequence_length: Sequences longer than this value will not be processed.
        :param no_cuda: The model is explicitly initialized as not using cuda, even if cuda is available.
        :param mode: Training or eval mode.
        """
        self.vocabulary = vocabulary
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length

        if not isinstance(network_params, dict):
            network_params = {}

        self.network = RNN(**network_params)
        if torch.cuda.is_available() and not no_cuda:
            self.network.cuda()

        # Index 0 is treated as padding and never contributes to the loss.
        self.nll_loss = tnn.NLLLoss(reduction="none", ignore_index=0)

        self.set_mode(mode)

    @classmethod
    def load_from_file(cls, file_path, mode="train"):
        """
        Loads a model from a single file
        :param file_path: Path of the file where the model data was previously stored.
        :param mode: Mode to load the model as (training or eval).
        :return: A new instance of the Model or an exception if it was not possible to load it.
        """
        # NOTE: torch.load unpickles arbitrary Python objects (the vocabulary and the tokenizer
        # are stored as such); only load files that come from a trusted source.
        # Without CUDA every stored tensor is remapped to the CPU.
        map_location = None if torch.cuda.is_available() else "cpu"
        save_dict = torch.load(file_path, map_location=map_location)

        # Only build a default tokenizer when the file does not provide one.
        if "tokenizer" in save_dict:
            tokenizer = save_dict["tokenizer"]
        else:
            tokenizer = mv.SMILESTokenizer()

        model = cls(
            vocabulary=save_dict['vocabulary'],
            tokenizer=tokenizer,
            network_params=save_dict.get("network_params", {}),
            max_sequence_length=save_dict['max_sequence_length'],
            mode=mode
        )
        model.network.load_state_dict(save_dict["network"])
        return model

    def set_mode(self, mode):
        """
        Changes the mode of the RNN to training or eval.
        :param mode: Mode to change to ("sampling" and "eval" select eval mode, anything else training).
        :return: The model instance.
        """
        if mode in ("sampling", "eval"):
            self.network.eval()
        else:
            self.network.train()
        return self

    def save(self, path):
        """
        Saves the model to a file.
        :param path: Path to save the model to.
        """
        save_dict = {
            'vocabulary': self.vocabulary,
            'tokenizer': self.tokenizer,
            'max_sequence_length': self.max_sequence_length,
            'network': self.network.state_dict(),
            'network_params': self.network.get_params()
        }
        torch.save(save_dict, path)

    def likelihood(self, padded_seqs, seq_lengths):
        """
        Retrieves the likelihood of a given sequence. Used in training.
        :param padded_seqs: (batch_size, sequence_length) A batch of padded sequences.
        :param seq_lengths: Length of each sequence in a tensor.
        :return:  (batch_size) Negative log likelihood for each example.
        """
        # The network predicts token t+1 from token t, so it sees one step less than the
        # sequence length. Lengths are kept on the CPU as required by sequence packing.
        input_lengths = torch.as_tensor(seq_lengths).cpu() - 1
        logits, _ = self.network(padded_seqs, input_lengths)
        log_probs = logits.log_softmax(dim=2).transpose(1, 2)  # (batch, voc, seq - 1)
        return self.nll_loss(log_probs, padded_seqs[:, 1:]).sum(dim=1)

    @torch.no_grad()
    def sample_smiles(self, num):
        """
        Samples n SMILES from the model.
        :param num: Number of SMILES to sample.
        :return: An iterator with (smiles, negative log likelihood) pairs
        """
        if num <= 0:
            return zip([], [])

        # Sample on whatever device the network lives on instead of assuming CUDA.
        device = next(self.network.parameters()).device

        input_vector = torch.full((num, 1), self.vocabulary["^"], dtype=torch.long, device=device)  # (batch, 1)
        seq_lengths = torch.ones(num, dtype=torch.long)  # (batch), on CPU as needed for packing
        sequences = []
        hidden_state = None
        nlls = torch.zeros(num, device=device)
        not_finished = torch.ones(num, 1, dtype=torch.long, device=device)
        for _ in range(self.max_sequence_length - 1):
            logits, hidden_state = self.network(input_vector, seq_lengths, hidden_state)  # (batch, 1, voc)
            # Drop only the time dimension so that a batch of one keeps its batch dimension.
            step_logits = logits.squeeze(dim=1)  # (batch, voc)
            probs = step_logits.softmax(dim=1)
            log_probs = step_logits.log_softmax(dim=1)
            # Finished sequences are forced to index 0, which the loss ignores.
            input_vector = torch.multinomial(probs, 1) * not_finished  # (batch, 1)
            sequences.append(input_vector)
            nlls += self.nll_loss(log_probs, input_vector.squeeze(dim=1))
            # Indices 0 and 1 both stop a sequence; everything else keeps sampling.
            not_finished = (input_vector > 1).type(torch.long)
            if not_finished.sum() == 0:
                break

        smiles = [self.tokenizer.untokenize(self.vocabulary.decode(seq))
                  for seq in torch.cat(sequences, 1).cpu().numpy()]
        nlls = nlls.cpu().numpy().tolist()
        return zip(smiles, nlls)