import logging
import os

import keras
import numpy
from keras import backend as K, Model
from keras.callbacks import CSVLogger
from keras.layers import Input, Dense, BatchNormalization, LeakyReLU, Dropout, Lambda
from keras.models import load_model
from scipy import sparse

import scgen
from .util import balancer, extractor, shuffle_adata

log = logging.getLogger(__file__)


def _to_dense(matrix):
    """Return `matrix` as a dense array, converting only if it is scipy-sparse."""
    return matrix.toarray() if sparse.issparse(matrix) else matrix


class VAEArithKeras:
    """
        VAE with Arithmetic vector Network class. This class contains the implementation of Variational
        Auto-encoder network with Vector Arithmetics.

        Parameters
        ----------
        x_dimension: integer
            number of gene expression space dimensions.
        z_dimension: integer
            number of latent space dimensions.
        kwargs:
            :key dropout_rate: float
                    dropout rate (default 0.2)
            :key learning_rate: float
                learning rate of optimization algorithm (default 0.001)
            :key model_path: basestring
                directory in which the models are saved after training (default "./models/")
            :key alpha: float
                weight of the KL divergence term in the VAE loss (default 0.00005)

        Notes
        -----
        Constructing an instance builds and compiles the network and prints the Keras model summary.

        See also
        --------
        CVAE from scgen.models._cvae : Conditional VAE implementation.
    """

    def __init__(self, x_dimension, z_dimension=100, **kwargs):
        self.x_dim = x_dimension
        self.z_dim = z_dimension
        self.learning_rate = kwargs.get("learning_rate", 0.001)
        self.dropout_rate = kwargs.get("dropout_rate", 0.2)
        self.model_to_use = kwargs.get("model_path", "./models/")
        self.alpha = kwargs.get("alpha", 0.00005)
        self.x = Input(shape=(x_dimension,), name="input")
        self.z = Input(shape=(z_dimension,), name="latent")
        self.init_w = keras.initializers.glorot_normal()
        self._create_network()
        self._loss_function()
        self.vae_model.summary()

    def _dense_block(self, tensor, units=800):
        """
            Applies one hidden block: bias-free Dense -> BatchNormalization -> LeakyReLU -> Dropout.

            Parameters
            ----------
            tensor: Tensor
                Input tensor of the block.
            units: int
                Width of the dense layer.

            Returns
            -------
            Tensor
                Output tensor of the block.
        """
        h = Dense(units, kernel_initializer=self.init_w, use_bias=False)(tensor)
        h = BatchNormalization(axis=1)(h)
        h = LeakyReLU()(h)
        return Dropout(self.dropout_rate)(h)

    def _encoder(self):
        """
            Constructs the encoder sub-network of VAE. This function implements the
            encoder part of Variational Auto-encoder. It will transform primary
            data in the `n_vars` dimension-space to the `z_dimension` latent space.

            As a side effect, `self.encoder_model` is created. Its output is the *sampled*
            latent vector (layer "Z"), not the mean.

            Parameters
            ----------
            No parameters are needed.

            Returns
            -------
            mean: Tensor
                A dense layer consists of means of gaussian distributions of latent space dimensions.
            log_var: Tensor
                A dense layer consists of log transformed variances of gaussian distributions of latent space dimensions.
        """
        h = self._dense_block(self.x)
        h = self._dense_block(h)
        # h = Dense(512, kernel_initializer=self.init_w, use_bias=False)(h)
        # h = BatchNormalization()(h)
        # h = LeakyReLU()(h)
        # h = Dropout(self.dropout_rate)(h)
        # h = Dense(256, kernel_initializer=self.init_w, use_bias=False)(h)
        # h = BatchNormalization()(h)
        # h = LeakyReLU()(h)
        # h = Dropout(self.dropout_rate)(h)
        mean = Dense(self.z_dim, kernel_initializer=self.init_w)(h)
        log_var = Dense(self.z_dim, kernel_initializer=self.init_w)(h)
        z = Lambda(self._sample_z, output_shape=(self.z_dim,), name="Z")([mean, log_var])
        self.encoder_model = Model(inputs=self.x, outputs=z, name="encoder")
        return mean, log_var

    def _decoder(self):
        """
            Constructs the decoder sub-network of VAE. This function implements the
            decoder part of Variational Auto-encoder. It will transform constructed
            latent space to the previous space of data with n_dimensions = n_vars.

            As a side effect, `self.decoder_model` is created.

            Parameters
            ----------
            No parameters are needed.

            Returns
            -------
            h: Tensor
                A Tensor for last dense layer with the shape of [n_vars, ] to reconstruct data.
        """
        h = self._dense_block(self.z)
        h = self._dense_block(h)
        # h = Dense(768, kernel_initializer=self.init_w, use_bias=False)(h)
        # h = BatchNormalization()(h)
        # h = LeakyReLU()(h)
        # h = Dropout(self.dropout_rate)(h)
        # h = Dense(1024, kernel_initializer=self.init_w, use_bias=False)(h)
        # h = BatchNormalization()(h)
        # h = LeakyReLU()(h)
        # h = Dropout(self.dropout_rate)(h)
        h = Dense(self.x_dim, kernel_initializer=self.init_w, use_bias=True)(h)
        self.decoder_model = Model(inputs=self.z, outputs=h, name="decoder")
        return h

    @staticmethod
    def _sample_z(args):
        """
            Samples from standard Normal distribution with shape [size, z_dim] and
            applies re-parametrization trick: `mu + exp(log_var / 2) * eps`.

            Parameters
            ----------
            args: list or tuple of two Tensors
                `(mu, log_var)` as computed in the `_encoder` function.

            Returns
            -------
            The computed Tensor of samples with shape [size, z_dim].
        """
        mu, log_var = args
        batch_size = K.shape(mu)[0]
        z_dim = K.shape(mu)[1]
        eps = K.random_normal(shape=[batch_size, z_dim])
        return mu + K.exp(log_var / 2) * eps

    def _create_network(self):
        """
            Constructs the whole VAE network. It is step-by-step constructing the VAE
            network. First, It will construct the encoder part and get mu, log_var of
            latent space. Second, It will construct the decoder part. Finally, the
            encoder and decoder models are chained into `self.vae_model`.

            Parameters
            ----------
            No parameters are needed.

            Returns
            -------
            Nothing will be returned.
        """
        self.mu, self.log_var = self._encoder()
        self.x_hat = self._decoder()
        self.vae_model = Model(inputs=self.x,
                               outputs=self.decoder_model(self.encoder_model(self.x)),
                               name="VAE")

    def _loss_function(self):
        """
            Defines the loss function of VAE network after constructing the whole
            network. This will define the KL Divergence and Reconstruction loss for
            VAE and also defines the Optimization algorithm for network. The VAE Loss
            is the batch mean of `recon_loss + alpha * kl_loss`.

            Parameters
            ----------
            No parameters are needed.

            Returns
            -------
            Nothing will be returned.
        """

        def vae_loss(y_true, y_pred):
            return K.mean(recon_loss(y_true, y_pred) + self.alpha * kl_loss(y_true, y_pred))

        def kl_loss(y_true, y_pred):
            # The arguments are ignored: the KL term is computed from the encoder tensors
            # `self.mu` / `self.log_var` captured by this closure.
            return 0.5 * K.sum(K.exp(self.log_var) + K.square(self.mu) - 1. - self.log_var, axis=1)

        def recon_loss(y_true, y_pred):
            return 0.5 * K.sum(K.square((y_true - y_pred)), axis=1)

        self.vae_optimizer = keras.optimizers.Adam(lr=self.learning_rate)
        self.vae_model.compile(optimizer=self.vae_optimizer, loss=vae_loss, metrics=[kl_loss, recon_loss])

    def to_latent(self, data):
        """
            Map `data` in to the latent space. This function will feed data
            in encoder part of VAE and compute the latent space coordinates
            for each sample in data.

            The encoder model outputs the sampled latent vector, so the result
            contains random noise drawn at prediction time.

            Parameters
            ----------
            data: numpy nd-array
                Numpy nd-array to be mapped to latent space, in shape [n_obs, n_vars].

            Returns
            -------
            latent: numpy nd-array
                Returns array containing latent space encoding of 'data'
        """
        latent = self.encoder_model.predict(data)
        return latent

    def _avg_vector(self, data):
        """
            Computes the average of points which computed from mapping `data`
            to encoder part of VAE.

            Parameters
            ----------
            data: numpy nd-array
                Numpy nd-array matrix to be mapped to latent space, in shape [n_obs, n_vars].

            Returns
            -------
                The average of latent space mapping in numpy nd-array.
        """
        latent = self.to_latent(data)
        latent_avg = numpy.average(latent, axis=0)
        return latent_avg

    def reconstruct(self, data):
        """
            Map back the latent space encoding via the decoder.

            Parameters
            ----------
            data: numpy nd-array
                Latent space encoding in shape [n_obs, z_dim].

            Returns
            -------
            rec_data: 'numpy nd-array'
                Returns 'numpy nd-array` containing reconstructed 'data' in shape [n_obs, n_vars].
        """
        rec_data = self.decoder_model.predict(x=data)
        return rec_data

    def linear_interpolation(self, source_adata, dest_adata, n_steps):
        """
            Maps the mean expression of `source_adata` and `dest_adata` into latent space
            and linearly interpolate `n_steps` points between them.

            Parameters
            ----------
            source_adata: `~anndata.AnnData`
                Annotated data matrix of source cells in gene expression space (`x.X` must be in shape [n_obs, n_vars])
            dest_adata: `~anndata.AnnData`
                Annotated data matrix of destinations cells in gene expression space (`y.X` must be in shape [n_obs, n_vars])
            n_steps: int
                Number of steps to interpolate points between `source_adata`, `dest_adata`.

            Returns
            -------
            interpolation: numpy nd-array
                Returns the `numpy nd-array` of interpolated points in gene expression space.

            Example
            --------
            >>> import anndata
            >>> import scgen
            >>> train_data = anndata.read("./data/train.h5ad")
            >>> validation_data = anndata.read("./data/validation.h5ad")
            >>> network = scgen.VAEArithKeras(x_dimension=train_data.shape[1], model_path="./models/test")
            >>> network.train(train_data=train_data, validation_data=validation_data, shuffle=True, n_epochs=2)
            >>> source = train_data[((train_data.obs["cell_type"] == "CD8T") & (train_data.obs["condition"] == "control"))]
            >>> destination = train_data[((train_data.obs["cell_type"] == "CD8T") & (train_data.obs["condition"] == "stimulated"))]
            >>> interpolation = network.linear_interpolation(source, destination, n_steps=25)
        """
        # Works for both sparse and dense `X`; the previous dense branch called `.A`,
        # which a plain numpy array does not provide.
        source_average = _to_dense(source_adata.X).mean(axis=0).reshape((1, source_adata.shape[1]))
        dest_average = _to_dense(dest_adata.X).mean(axis=0).reshape((1, dest_adata.shape[1]))

        start = self.to_latent(source_average)
        end = self.to_latent(dest_average)

        # Column vector of weights broadcasts against the [1, z_dim] end points,
        # yielding one interpolated latent vector per row.
        alpha_values = numpy.linspace(0, 1, n_steps).reshape((-1, 1))
        vectors = start * (1 - alpha_values) + end * alpha_values

        interpolation = self.reconstruct(vectors)
        return interpolation

    def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_predict=None,
                celltype_to_predict=None, obs_key="all"):
        """
            Predicts the cell type provided by the user in stimulated condition.

            Parameters
            ----------
            adata: `~anndata.AnnData`
                Annotated data matrix used to estimate the control -> stimulated shift.
            conditions: dict
                Dictionary with keys `"ctrl"` and `"stim"` giving the labels of the two
                conditions in `adata.obs[condition_key]`.
            cell_type_key: basestring
                Column of `adata.obs` holding the cell type labels.
            condition_key: basestring
                Column of `adata.obs` holding the condition labels.
            adata_to_predict: `~anndata.AnnData`
                Adata for unperturbed cells you want to be predicted.
            celltype_to_predict: basestring
                The cell type you want to be predicted.
            obs_key: basestring or dict
                `"all"` to use every cell, or a one-entry dictionary
                `{obs_column: values}` restricting the cells used for estimating the shift.

            Returns
            -------
            predicted_cells: numpy nd-array
                `numpy nd-array` of predicted cells in primary space.
            delta: numpy nd-array
                Difference between the average stimulated and control cells in latent space.

            Raises
            ------
            ValueError
                If both or neither of `celltype_to_predict` and `adata_to_predict` are given.

            Example
            --------
            >>> import anndata
            >>> import scgen
            >>> train_data = anndata.read("./data/train.h5ad")
            >>> validation_data = anndata.read("./data/validation.h5ad")
            >>> network = scgen.VAEArithKeras(x_dimension=train_data.shape[1], model_path="./models/test")
            >>> network.train(train_data=train_data, validation_data=validation_data, shuffle=True, n_epochs=2)
            >>> unperturbed_cd4t = train_data[((train_data.obs["cell_type"] == "CD4T") & (train_data.obs["condition"] == "control"))]
            >>> pred, delta = network.predict(adata=train_data, conditions={"ctrl": "control", "stim": "stimulated"},
            ...                               cell_type_key="cell_type", condition_key="condition",
            ...                               adata_to_predict=unperturbed_cd4t)
        """
        # Validate the arguments before doing any expensive subsetting/balancing.
        if celltype_to_predict is not None and adata_to_predict is not None:
            raise ValueError("Please provide either a cell type or adata not both!")
        if celltype_to_predict is None and adata_to_predict is None:
            raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

        if obs_key == "all":
            subset = adata
            balance = True
        else:
            key = next(iter(obs_key))
            values = obs_key[key]
            subset = adata[adata.obs[key].isin(values)]
            # Balancing is only applied when more than one value was selected.
            balance = len(values) > 1

        # Use the caller-supplied condition column rather than a hard-coded "condition".
        ctrl_x = subset[subset.obs[condition_key] == conditions["ctrl"], :]
        stim_x = subset[subset.obs[condition_key] == conditions["stim"], :]
        if balance:
            ctrl_x = balancer(ctrl_x, cell_type_key=cell_type_key, condition_key=condition_key)
            stim_x = balancer(stim_x, cell_type_key=cell_type_key, condition_key=condition_key)

        if celltype_to_predict is not None:
            ctrl_pred = extractor(adata, celltype_to_predict, conditions, cell_type_key, condition_key)[1]
        else:
            ctrl_pred = adata_to_predict

        # Draw the same number of cells from both conditions before averaging.
        eq = min(ctrl_x.X.shape[0], stim_x.X.shape[0])
        cd_ind = numpy.random.choice(ctrl_x.shape[0], size=eq, replace=False)
        stim_ind = numpy.random.choice(stim_x.shape[0], size=eq, replace=False)

        # Each matrix is densified independently, so mixed sparse/dense inputs work too.
        latent_ctrl = self._avg_vector(_to_dense(ctrl_x.X)[cd_ind, :])
        latent_sim = self._avg_vector(_to_dense(stim_x.X)[stim_ind, :])
        delta = latent_sim - latent_ctrl

        latent_cd = self.to_latent(_to_dense(ctrl_pred.X))
        stim_pred = delta + latent_cd
        predicted_cells = self.reconstruct(stim_pred)
        return predicted_cells, delta

    def restore_model(self):
        """
            Restores the VAE, encoder and decoder models from the files `vae.h5`,
            `encoder.h5` and `decoder.h5` in the `model_to_use` directory, then
            re-compiles the VAE model.

            Parameters
            ----------
            No parameters are needed.

            Returns
            -------
            Nothing will be returned.

            Example
            --------
            >>> import anndata
            >>> import scgen
            >>> train_data = anndata.read("./data/train.h5ad")
            >>> network = scgen.VAEArithKeras(x_dimension=train_data.shape[1], model_path="./models/test")
            >>> network.restore_model()
        """
        self.vae_model = load_model(os.path.join(self.model_to_use, 'vae.h5'), compile=False)
        self.encoder_model = load_model(os.path.join(self.model_to_use, 'encoder.h5'), compile=False)
        self.decoder_model = load_model(os.path.join(self.model_to_use, 'decoder.h5'), compile=False)
        # NOTE(review): the loss compiled here still closes over `self.mu` / `self.log_var`
        # from the graph built in `__init__`, not from the freshly loaded models -- confirm
        # that continuing training after a restore behaves as intended.
        self._loss_function()

    def train(self, train_data, validation_data=None,
              n_epochs=25,
              batch_size=32,
              early_stop_limit=20,
              threshold=0.0025,
              initial_run=True,
              shuffle=True,
              verbose=1,
              save=True,
              checkpoint=50,
              **kwargs):
        """
            Trains the network for `n_epochs` epochs with given `train_data`
            and validates the model using `validation_data` if it was given.
            Training metrics are written to `./csv_logger.log`.

            Parameters
            ----------
            train_data: scanpy AnnData
                Annotated Data Matrix for training VAE network. Sparse `X` is converted
                to a dense array for training; the passed object itself is not modified here.
            validation_data: scanpy AnnData
                Annotated Data Matrix for validating VAE network after each epoch.
            n_epochs: int
                Number of epochs to iterate and optimize network weights
            batch_size: integer
                size of each batch of training dataset to be fed to network while training.
            early_stop_limit: int
                Currently unused (the early stopping callback is disabled).
            threshold: float
                Currently unused (the early stopping callback is disabled).
            initial_run: bool
                if `True`: a start-of-training message is logged.
                if `False`: the message is skipped. The model is not restored
                automatically; call `restore_model` first to resume training.
            shuffle: bool
                if `True`: shuffles the training dataset
            verbose: int
                Verbosity passed on to the Keras `fit` call.
            save: bool
                if `True`: the VAE, encoder and decoder models are saved in `model_to_use`.
            checkpoint: int
                Currently unused (the periodic visualization callback is disabled).
            kwargs:
                Currently unused; extra keyword arguments are accepted and ignored.

            Returns
            -------
            result:
                The object returned by the Keras `fit` call.

            Example
            --------
            >>> import anndata
            >>> import scgen
            >>> train_data = anndata.read("./data/train.h5ad")
            >>> validation_data = anndata.read("./data/validation.h5ad")
            >>> network = scgen.VAEArithKeras(x_dimension=train_data.shape[1], model_path="./models/test")
            >>> network.train(train_data=train_data, validation_data=validation_data, shuffle=True, n_epochs=2)
        """
        if initial_run:
            log.info("----Training----")
        if shuffle:
            train_data = shuffle_adata(train_data)

        # Densify into a local so the caller's AnnData is never overwritten.
        train_x = _to_dense(train_data.X)

        # def on_epoch_end(epoch, logs):
        #     if epoch % checkpoint == 0:
        #         path_to_save = os.path.join(kwargs.get("path_to_save"), f"epoch_{epoch}") + "/"
        #         scgen.visualize_trained_network_results(self, vis_data, kwargs.get("cell_type"),
        #                                                 kwargs.get("conditions"),
        #                                                 kwargs.get("condition_key"), kwargs.get("cell_type_key"),
        #                                                 path_to_save,
        #                                                 plot_umap=False,
        #                                                 plot_reg=True)

        callbacks = [
            # LambdaCallback(on_epoch_end=on_epoch_end),
            # EarlyStopping(patience=early_stop_limit, monitor='loss', min_delta=threshold),
            CSVLogger(filename="./csv_logger.log")
        ]

        fit_arguments = dict(x=train_x,
                             y=train_x,
                             epochs=n_epochs,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             callbacks=callbacks,
                             verbose=verbose)
        if validation_data is not None:
            # Validation data gets the same sparse -> dense treatment as the training data.
            valid_x = _to_dense(validation_data.X)
            fit_arguments["validation_data"] = (valid_x, valid_x)
        result = self.vae_model.fit(**fit_arguments)

        if save is True:
            os.makedirs(self.model_to_use, exist_ok=True)
            self.vae_model.save(os.path.join(self.model_to_use, "vae.h5"), overwrite=True)
            self.encoder_model.save(os.path.join(self.model_to_use, "encoder.h5"), overwrite=True)
            self.decoder_model.save(os.path.join(self.model_to_use, "decoder.h5"), overwrite=True)
            log.info(f"Models are saved in file: {self.model_to_use}. Training finished")
        return result