"""Implementation of RETAIN Keras from Edward Choi""" import os import argparse import numpy as np import pandas as pd import tensorflow as tf import keras.layers as L from keras import backend as K from keras.models import Model from keras.callbacks import ModelCheckpoint, Callback from keras.preprocessing import sequence from keras.utils.data_utils import Sequence from keras.regularizers import l2 from keras.constraints import non_neg, Constraint from keras_exp.multigpu import get_available_gpus, make_parallel from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve class SequenceBuilder(Sequence): """Generate Batches of data""" def __init__(self, data, target, batch_size, ARGS, target_out=True): #Receive all appropriate data self.codes = data[0] index = 1 if ARGS.numeric_size: self.numeric = data[index] index += 1 if ARGS.use_time: self.time = data[index] self.num_codes = ARGS.num_codes self.target = target self.batch_size = batch_size self.target_out = target_out self.numeric_size = ARGS.numeric_size self.use_time = ARGS.use_time self.n_steps = ARGS.n_steps #self.balance = (1-(float(sum(target))/len(target)))/(float(sum(target))/len(target)) def __len__(self): """Compute number of batches. Add extra batch if the data doesn't exactly divide into batches """ if len(self.codes)%self.batch_size == 0: return len(self.codes) // self.batch_size return len(self.codes) // self.batch_size+1 def __getitem__(self, idx): """Get batch of specific index""" def pad_data(data, length_visits, length_codes, pad_value=0): """Pad data to desired number of visiits and codes inside each visit""" zeros = np.full((len(data), length_visits, length_codes), pad_value) for steps, mat in zip(data, zeros): if steps != [[-1]]: for step, mhot in zip(steps, mat[-len(steps):]): #Populate the data into the appropriate visit mhot[:len(step)] = step return zeros #Compute reusable batch slice batch_slice = slice(idx*self.batch_size, (idx+1)*self.batch_size) x_codes = self.codes[batch_slice] #Max number of visits and codes inside the visit for this batch pad_length_visits = min(max(map(len, x_codes)), self.n_steps) pad_length_codes = max(map(lambda x: max(map(len, x)), x_codes)) #Number of elements in a batch (useful in case of partial batches) length_batch = len(x_codes) #Pad data x_codes = pad_data(x_codes, pad_length_visits, pad_length_codes, self.num_codes) outputs = [x_codes] #Add numeric data if necessary if self.numeric_size: x_numeric = self.numeric[batch_slice] x_numeric = pad_data(x_numeric, pad_length_visits, self.numeric_size, -99.0) outputs.append(x_numeric) #Add time data if necessary if self.use_time: x_time = sequence.pad_sequences(self.time[batch_slice], dtype=np.float32, maxlen=pad_length_visits, value=+99).reshape(length_batch, pad_length_visits, 1) outputs.append(x_time) #Add target if necessary (training vs validation) if self.target_out: target = self.target[batch_slice].reshape(length_batch, 1, 1) #sample_weights = (target*(self.balance-1)+1).reshape(length_batch, 1) #In our experiments sample weights provided worse results return (outputs, target) return outputs class FreezePadding_Non_Negative(Constraint): """Freezes the last weight to be near 0 and prevents non-negative embeddings""" def __call__(self, w): other_weights = K.cast(K.greater_equal(w, 0)[:-1], K.floatx()) last_weight = K.cast(K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.), K.floatx()) appended = K.concatenate([other_weights, last_weight], axis=0) w *= appended return w class FreezePadding(Constraint): """Freezes the last weight to be near 0.""" def __call__(self, w): other_weights = K.cast(K.ones(K.shape(w))[:-1], K.floatx()) last_weight = K.cast(K.equal(K.reshape(w[-1, :], (1, K.shape(w)[1])), 0.), K.floatx()) appended = K.concatenate([other_weights, last_weight], axis=0) w *= appended return w def read_data(ARGS): """Read the data from provided paths and assign it into lists""" data_train_df = pd.read_pickle(ARGS.path_data_train) data_test_df = pd.read_pickle(ARGS.path_data_test) y_train = pd.read_pickle(ARGS.path_target_train)['target'].values y_test = pd.read_pickle(ARGS.path_target_test)['target'].values data_output_train = [data_train_df['codes'].values] data_output_test = [data_test_df['codes'].values] if ARGS.numeric_size: data_output_train.append(data_train_df['numerics'].values) data_output_test.append(data_test_df['numerics'].values) if ARGS.use_time: data_output_train.append(data_train_df['to_event'].values) data_output_test.append(data_test_df['to_event'].values) return (data_output_train, y_train, data_output_test, y_test) def model_create(ARGS): """Create and Compile model and assign it to provided devices""" def retain(ARGS): """Create the model""" #Define the constant for model saving reshape_size = ARGS.emb_size+ARGS.numeric_size if ARGS.allow_negative: embeddings_constraint = FreezePadding() beta_activation = 'tanh' output_constraint = None else: embeddings_constraint = FreezePadding_Non_Negative() beta_activation = 'sigmoid' output_constraint = non_neg() #Get available gpus , returns empty list if none glist = get_available_gpus() def reshape(data): """Reshape the context vectors to 3D vector""" return K.reshape(x=data, shape=(K.shape(data)[0], 1, reshape_size)) #Code Input codes = L.Input((None, None), name='codes_input') inputs_list = [codes] #Calculate embedding for each code and sum them to a visit level codes_embs_total = L.Embedding(ARGS.num_codes+1, ARGS.emb_size, name='embedding', embeddings_constraint=embeddings_constraint)(codes) codes_embs = L.Lambda(lambda x: K.sum(x, axis=2))(codes_embs_total) #Numeric input if needed if ARGS.numeric_size: numerics = L.Input((None, ARGS.numeric_size), name='numeric_input') inputs_list.append(numerics) full_embs = L.concatenate([codes_embs, numerics], name='catInp') else: full_embs = codes_embs #Apply dropout on inputs full_embs = L.Dropout(ARGS.dropout_input)(full_embs) #Time input if needed if ARGS.use_time: time = L.Input((None, 1), name='time_input') inputs_list.append(time) time_embs = L.concatenate([full_embs, time], name='catInp2') else: time_embs = full_embs #Setup Layers #This implementation uses Bidirectional LSTM instead of reverse order # (see https://github.com/mp2893/retain/issues/3 for more details) #If training on GPU and Tensorflow use CuDNNLSTM for much faster training if glist: alpha = L.Bidirectional(L.CuDNNLSTM(ARGS.recurrent_size, return_sequences=True), name='alpha') beta = L.Bidirectional(L.CuDNNLSTM(ARGS.recurrent_size, return_sequences=True), name='beta') else: alpha = L.Bidirectional(L.LSTM(ARGS.recurrent_size, return_sequences=True, implementation=2), name='alpha') beta = L.Bidirectional(L.LSTM(ARGS.recurrent_size, return_sequences=True, implementation=2), name='beta') alpha_dense = L.Dense(1, kernel_regularizer=l2(ARGS.l2)) beta_dense = L.Dense(ARGS.emb_size+ARGS.numeric_size, activation=beta_activation, kernel_regularizer=l2(ARGS.l2)) #Compute alpha, visit attention alpha_out = alpha(time_embs) alpha_out = L.TimeDistributed(alpha_dense, name='alpha_dense_0')(alpha_out) alpha_out = L.Softmax(axis=1)(alpha_out) #Compute beta, codes attention beta_out = beta(time_embs) beta_out = L.TimeDistributed(beta_dense, name='beta_dense_0')(beta_out) #Compute context vector based on attentions and embeddings c_t = L.Multiply()([alpha_out, beta_out, full_embs]) c_t = L.Lambda(lambda x: K.sum(x, axis=1))(c_t) #Reshape to 3d vector for consistency between Many to Many and Many to One implementations contexts = L.Lambda(reshape)(c_t) #Make a prediction contexts = L.Dropout(ARGS.dropout_context)(contexts) output_layer = L.Dense(1, activation='sigmoid', name='dOut', kernel_regularizer=l2(ARGS.l2), kernel_constraint=output_constraint) #TimeDistributed is used for consistency # between Many to Many and Many to One implementations output = L.TimeDistributed(output_layer, name='time_distributed_out')(contexts) #Define the model with appropriate inputs model = Model(inputs=inputs_list, outputs=[output]) return model #Set Tensorflow to grow GPU memory consumption instead of grabbing all of it at once K.clear_session() config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) config.gpu_options.allow_growth = True tfsess = tf.Session(config=config) K.set_session(tfsess) #If there are multiple GPUs set up a multi-gpu model glist = get_available_gpus() if len(glist) > 1: with tf.device('/cpu:0'): model = retain(ARGS) model_final = make_parallel(model, glist) else: model_final = retain(ARGS) #Compile the model - adamax has produced best results in our experiments model_final.compile(optimizer='adamax', loss='binary_crossentropy', metrics=['accuracy'], sample_weight_mode="temporal") return model_final def create_callbacks(model, data, ARGS): """Create the checkpoint and logging callbacks""" class LogEval(Callback): """Logging Callback""" def __init__(self, filepath, model, data, ARGS, interval=1): super(Callback, self).__init__() self.filepath = filepath self.interval = interval self.data_test, self.y_test = data self.generator = SequenceBuilder(data=self.data_test, target=self.y_test, batch_size=ARGS.batch_size, ARGS=ARGS, target_out=False) self.model = model def on_epoch_end(self, epoch, logs={}): #Compute ROC-AUC and average precision the validation data every interval epochs if epoch % self.interval == 0: #Compute predictions of the model y_pred = [x[-1] for x in self.model.predict_generator(self.generator, verbose=0, use_multiprocessing=True, workers=5, max_queue_size=5)] score_roc = roc_auc_score(self.y_test, y_pred) score_pr = average_precision_score(self.y_test, y_pred) #Create log files if it doesn't exist, otherwise write to it if os.path.exists(self.filepath): append_write = 'a' else: append_write = 'w' with open(self.filepath, append_write) as file_output: file_output.write("\nEpoch: {:d}- ROC-AUC: {:.6f} ; PR-AUC: {:.6f}"\ .format(epoch, score_roc, score_pr)) print("\nEpoch: {:d} - ROC-AUC: {:.6f} PR-AUC: {:.6f}"\ .format(epoch, score_roc, score_pr)) #Create callbacks if not os.path.exists(ARGS.directory): os.makedirs(ARGS.directory) checkpoint = ModelCheckpoint(filepath=ARGS.directory+'/weights.{epoch:02d}.hdf5') log = LogEval(ARGS.directory+'/log.txt', model, data, ARGS) return(checkpoint, log) def train_model(model, data_train, y_train, data_test, y_test, ARGS): """Train the Model with appropriate callbacks and generator""" checkpoint, log = create_callbacks(model, (data_test, y_test), ARGS) train_generator = SequenceBuilder(data=data_train, target=y_train, batch_size=ARGS.batch_size, ARGS=ARGS) model.fit_generator(generator=train_generator, epochs=ARGS.epochs, max_queue_size=15, use_multiprocessing=True, callbacks=[checkpoint, log], verbose=1, workers=3, initial_epoch=0) def main(ARGS): """Main function""" print('Reading Data') data_train, y_train, data_test, y_test = read_data(ARGS) print('Creating Model') model = model_create(ARGS) print('Training Model') train_model(model=model, data_train=data_train, y_train=y_train, data_test=data_test, y_test=y_test, ARGS=ARGS) def parse_arguments(parser): """Read user arguments""" parser.add_argument('--num_codes', type=int, required=True, help='Number of medical codes') parser.add_argument('--numeric_size', type=int, default=0, help='Size of numeric inputs, 0 if none') parser.add_argument('--use_time', action='store_true', help='If argument is present the time input will be used') parser.add_argument('--emb_size', type=int, default=200, help='Size of the embedding layer') parser.add_argument('--epochs', type=int, default=1, help='Number of epochs') parser.add_argument('--n_steps', type=int, default=300, help='Maximum number of visits after which the data is truncated') parser.add_argument('--recurrent_size', type=int, default=200, help='Size of the recurrent layers') parser.add_argument('--path_data_train', type=str, default='data/data_train.pkl', help='Path to train data') parser.add_argument('--path_data_test', type=str, default='data/data_test.pkl', help='Path to test data') parser.add_argument('--path_target_train', type=str, default='data/target_train.pkl', help='Path to train target') parser.add_argument('--path_target_test', type=str, default='data/target_test.pkl', help='Path to test target') parser.add_argument('--batch_size', type=int, default=32, help='Batch Size') parser.add_argument('--dropout_input', type=float, default=0.0, help='Dropout rate for embedding') parser.add_argument('--dropout_context', type=float, default=0.0, help='Dropout rate for context vector') parser.add_argument('--l2', type=float, default=0.0, help='L2 regularitzation value') parser.add_argument('--directory', type=str, default='Model', help='Directory to save the model and the log file to') parser.add_argument('--allow_negative', action='store_true', help='If argument is present the negative weights for embeddings/attentions\ will be allowed (original RETAIN implementaiton)') args = parser.parse_args() return args if __name__ == '__main__': PARSER = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) ARGS = parse_arguments(PARSER) main(ARGS)