import json
import os
import time
from datetime import timedelta

import numpy as np
import torch
# NOTE(review): the module path for remove_duplicates was missing from this
# line; guacamol.utils.data is assumed -- confirm against the project.
from guacamol.utils.data import remove_duplicates
from torch.utils.data import TensorDataset

from .rnn_model import SmilesRnn
from .smiles_char_dict import SmilesCharDictionary

def get_tensor_dataset(numpy_array):
    """
    Gets a numpy array of indices, converts it into a Torch tensor,
    divides it into inputs and targets and wraps it
    into a TensorDataset

    Inputs are every column but the last; targets are every column but the
    first, i.e. the inputs shifted by one position along dimension 1.

    Args:
        numpy_array: to be converted (indexed along two dimensions)

    Returns: a TensorDataset
    """
    tensor = torch.from_numpy(numpy_array).long()

    inp = tensor[:, :-1]
    target = tensor[:, 1:]

    return TensorDataset(inp, target)

def get_tensor_dataset_on_device(numpy_array, device):
    """
    Get tensor dataset and send it to a device

    Args:
        numpy_array: to be converted
        device: cuda | cpu

    Returns:
        a TensorDataset on the required device
    """
    dataset = get_tensor_dataset(numpy_array)
    # Replace each wrapped tensor with its copy on the requested device.
    dataset.tensors = tuple(t.to(device) for t in dataset.tensors)
    return dataset

def load_model(model_class, model_definition, model_weights, device, copy_to_cpu=True):
    """
    Builds a model from its JSON definition and loads its saved weights.

    Args:
        model_class: what class of model? It is called with the contents of
            the JSON definition as keyword arguments
        model_definition: path to model json
        model_weights: path to model weights
        device: cuda or cpu; the loaded model is moved to this device
        copy_to_cpu: bool; if True the weights are deserialized with an
            identity `map_location`, otherwise torch's default placement
            is used

    Returns: an RNN model
    """
    # Context manager so the definition file is closed after reading.
    with open(model_definition) as definition_file:
        raw_dict = json.load(definition_file)

    model = model_class(**raw_dict)

    # Choose the mapping up front instead of a lambda that may return None.
    map_location = (lambda storage, loc: storage) if copy_to_cpu else None
    model.load_state_dict(torch.load(model_weights, map_location=map_location))

    return model.to(device)

def load_rnn_model(model_definition, model_weights, device, copy_to_cpu=True):
    """
    Load a SmilesRnn by delegating to `load_model` with the model class fixed.

    Args:
        model_definition: path to model json
        model_weights: path to model weights
        device: cuda or cpu
        copy_to_cpu: bool, forwarded unchanged to `load_model`

    Returns: whatever `load_model` returns for SmilesRnn
    """
    rnn = load_model(
        SmilesRnn,
        model_definition,
        model_weights,
        device,
        copy_to_cpu,
    )
    return rnn

def save_model(model, base_dir, base_name):
    """
    Save a model's weights and its configuration next to each other.

    Two files are written into `base_dir`: `<base_name>.pt` with the model's
    state dict and `<base_name>.json` with the model configuration.

    Args:
        model: model to be saved
        base_dir: directory that receives both files
        base_name: file name without extension
    """
    model_params = os.path.join(base_dir, base_name + '.pt')
    torch.save(model.state_dict(), model_params)

    model_config = os.path.join(base_dir, base_name + '.json')
    with open(model_config, 'w') as mc:
        # NOTE(review): assumes the model exposes its constructor keyword
        # arguments as `model.config` (load_model rebuilds the model from
        # this JSON) -- confirm against the model class.
        mc.write(json.dumps(model.config))

def load_smiles_from_file(smiles_path, rm_invalid=True, rm_duplicates=True, max_len=100):
    """
    Given a file of SMILES strings, provides a zero padded NumPy array
    with their index representation. Sequences longer than `max_len` are
    discarded. The final array will have dimension (all_valid_smiles, max_len+2)
    as a beginning and end of sequence tokens are added to each string.

    Args:
        smiles_path: a text file with one SMILES string per line
        rm_invalid: bool, forwarded to `load_smiles_from_list`
        rm_duplicates: bool, forwarded to `load_smiles_from_list`
        max_len: dimension 1 of returned array, sequences will be padded

    Returns:
        sequences: a numpy array of SMILES character indices
        valid_mask: list with one entry per line of the file - a boolean mask vector indicating if each index
          maps to a valid smiles
    """
    # Context manager so the file handle is released once the lines are read.
    with open(smiles_path) as smiles_file:
        smiles_list = smiles_file.readlines()
    return load_smiles_from_list(smiles_list, rm_invalid=rm_invalid, rm_duplicates=rm_duplicates, max_len=max_len)

def load_smiles_from_list(smiles_list, rm_invalid=True, rm_duplicates=True, max_len=100):
    """
    Given a list of SMILES strings, provides a zero padded NumPy array
    with their index representation. Sequences longer than `max_len` are
    discarded. The final array will have dimension (all_valid_smiles, max_len+2)
    as a beginning and end of sequence tokens are added to each string.

    Args:
        smiles_list: a list of SMILES strings
        rm_invalid: bool if True remove invalid smiles from final output. Note that if True the length of the output
          does not equal the size of the input `smiles_list`. If False, each invalid entry is replaced by the
          placeholder 'C'. Default True
        rm_duplicates: bool if True remove duplicates from final output. Note that if True the length of the
          output does not equal the size of the input `smiles_list`. Default True
        max_len: dimension 1 of returned array, sequences will be padded

    Returns:
        sequences: a numpy array of SMILES character indices
        valid_mask: list of len(smiles_list) - a boolean mask vector indicating if each index maps to a valid smiles
    """
    sd = SmilesCharDictionary()

    # filter valid smiles strings
    valid_smiles = []
    valid_mask = [False] * len(smiles_list)
    for i, s in enumerate(smiles_list):
        s = s.strip()
        if sd.allowed(s) and len(s) <= max_len:
            valid_smiles.append(s)
            valid_mask[i] = True
        elif not rm_invalid:
            valid_smiles.append('C')  # default placeholder

    if rm_duplicates:
        unique_smiles = remove_duplicates(valid_smiles)
    else:
        unique_smiles = valid_smiles

    # max len + two chars for start token 'Q' and stop token '\n'
    max_seq_len = max_len + 2

    # allocate the zero matrix to be filled
    sequences = np.zeros((len(unique_smiles), max_seq_len), dtype=np.int32)

    for row, mol in enumerate(unique_smiles):
        enc_smi = sd.BEGIN + sd.encode(mol) + sd.END
        # positions past the end of the encoded string keep the zero padding
        for col, char in enumerate(enc_smi):
            sequences[row, col] = sd.char_idx[char]

    return sequences, valid_mask

def rnn_start_token_vector(batch_size, device='cpu'):
    """
    Returns a vector of start tokens for SmilesRnn.
    This vector can be used to start sampling a batch of SMILES strings.

    Args:
        batch_size: how many SMILES will be generated at the same time in SmilesRnn
        device: cpu | cuda

    Returns:
        a tensor (batch_size x 1) containing the start token
    """
    sd = SmilesCharDictionary()
    return torch.LongTensor(batch_size, 1).fill_(sd.begin_idx).to(device)

def time_since(start_time):
    """
    Format the wall-clock time elapsed since `start_time`.

    Args:
        start_time: a timestamp as returned by `time.time()`

    Returns:
        the elapsed time, truncated to whole seconds, as the string form
        of a `timedelta` (e.g. '1:01:01')
    """
    elapsed = timedelta(seconds=int(time.time() - start_time))
    return str(elapsed)

def set_random_seed(seed, device):
    Set the random seed for Numpy and PyTorch operations
        seed: seed for the random number generators
        device: "cpu" or "cuda"
    if device == 'cuda':