from keras.engine import Model
from keras.layers import Layer, Bidirectional, TimeDistributed, \
    Dense, LSTM, Masking, Input, RepeatVector, Dropout, Convolution1D, \
    BatchNormalization, Activation
from keras.layers.merge import concatenate, add
import keras.backend as K
from keras.regularizers import l2
from .structure_processor import NUM_FEATURES


# Per-residue false negative indicator: 1.0 where a positive label was
# predicted as negative (Keras averages this into a rate)
def false_neg(y_true, y_pred):
    return K.squeeze(K.clip(y_true - K.round(y_pred), 0.0, 1.0), axis=-1)


# Per-residue false positive indicator: 1.0 where a negative label was
# predicted as positive
def false_pos(y_true, y_pred):
    return K.squeeze(K.clip(K.round(y_pred) - y_true, 0.0, 1.0), axis=-1)


# Masking layer whose mask is computed by an arbitrary function of the input.
# Should probably triple-check that it works as expected
class MaskingByLambda(Layer):
    def __init__(self, func, **kwargs):
        self.supports_masking = True
        self.mask_func = func
        super(MaskingByLambda, self).__init__(**kwargs)

    def compute_mask(self, input, input_mask=None):
        return self.mask_func(input, input_mask)

    def call(self, x, mask=None):
        exd_mask = K.expand_dims(self.mask_func(x, mask), axis=-1)
        return x * K.cast(exd_mask, K.floatx())


# Mask function that ignores the layer's own input and uses the given tensor
# (e.g. a label mask fed in as a model input) as the mask
def mask_by_input(tensor):
    return lambda input, mask: tensor


# 1D convolution that supports masking by retaining the mask of the input
class MaskedConvolution1D(Convolution1D):
    def __init__(self, *args, **kwargs):
        self.supports_masking = True
        assert kwargs['padding'] == 'same'  # Only makes sense for 'same'
        super(MaskedConvolution1D, self).__init__(*args, **kwargs)

    def compute_mask(self, input, input_mask=None):
        return input_mask

    def call(self, x, mask=None):
        assert mask is not None
        mask = K.expand_dims(mask, axis=-1)
        x = super(MaskedConvolution1D, self).call(x)
        return x * K.cast(mask, K.floatx())


# Antibody/antigen model: a BiLSTM encoding of the whole antigen sequence is
# repeated along the CDR axis and concatenated onto per-residue CDR features
# before per-residue classification
def ab_ag_seq_model(max_ag_len, max_cdr_len):
    input_ag = Input(shape=(max_ag_len, NUM_FEATURES))
    ag_seq = Masking()(input_ag)
    enc_ag = Bidirectional(LSTM(128, dropout=0.1, recurrent_dropout=0.1),
                           merge_mode='concat')(ag_seq)

    input_ab = Input(shape=(max_cdr_len, NUM_FEATURES))
    label_mask = Input(shape=(max_cdr_len,))
    seq = Masking()(input_ab)
    loc_fts = MaskedConvolution1D(64, 5, padding='same', activation='elu')(seq)
    glb_fts = Bidirectional(LSTM(256, dropout=0.15, recurrent_dropout=0.2,
                                 return_sequences=True),
                            merge_mode='concat')(loc_fts)

    enc_ag_rep = RepeatVector(max_cdr_len)(enc_ag)
    ab_ag_repr = concatenate([glb_fts, enc_ag_rep])
    ab_ag_repr = MaskingByLambda(mask_by_input(label_mask))(ab_ag_repr)
    ab_ag_repr = Dropout(0.3)(ab_ag_repr)

    aa_probs = TimeDistributed(Dense(1, activation='sigmoid'))(ab_ag_repr)
    model = Model(inputs=[input_ag, input_ab, label_mask], outputs=aa_probs)
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['binary_accuracy', false_pos, false_neg],
                  sample_weight_mode="temporal")
    return model


# Shared antibody-only architecture: masked convolution with a residual
# connection, then a BiLSTM and a per-residue sigmoid classifier
def base_ab_seq_model(max_cdr_len):
    input_ab = Input(shape=(max_cdr_len, NUM_FEATURES))
    label_mask = Input(shape=(max_cdr_len,))
    seq = MaskingByLambda(mask_by_input(label_mask))(input_ab)
    loc_fts = MaskedConvolution1D(28, 3, padding='same', activation='elu',
                                  kernel_regularizer=l2(0.01))(seq)
    res_fts = add([seq, loc_fts])  # residual connection around the convolution
    glb_fts = Bidirectional(LSTM(256, dropout=0.15, recurrent_dropout=0.2,
                                 return_sequences=True),
                            merge_mode='concat')(res_fts)
    fts = Dropout(0.3)(glb_fts)
    probs = TimeDistributed(Dense(1, activation='sigmoid',
                                  kernel_regularizer=l2(0.01)))(fts)
    return input_ab, label_mask, res_fts, probs


def ab_seq_model(max_cdr_len):
    input_ab, label_mask, _, probs = base_ab_seq_model(max_cdr_len)
    model = Model(inputs=[input_ab, label_mask], outputs=probs)
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['binary_accuracy', false_pos, false_neg],
                  sample_weight_mode="temporal")
    return model


# Variant exposing the residual convolution features as a second output
# (not compiled, so only usable for inference/feature inspection)
def conv_output_ab_seq_model(max_cdr_len):
    input_ab, label_mask, res_fts, probs = base_ab_seq_model(max_cdr_len)
    model = Model(inputs=[input_ab, label_mask], outputs=[probs, res_fts])
    return model
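

# Usage sketch (illustrative, not part of the original module): builds the
# antibody-only model and fits it on random data to show the expected input
# shapes. MAX_CDR_LEN, the batch size, and the random arrays are assumptions
# chosen only for demonstration. Run as a module within its package
# (python -m <package>.model) so the relative import of NUM_FEATURES resolves.
if __name__ == "__main__":
    import numpy as np

    MAX_CDR_LEN = 32  # assumed maximum padded CDR length
    model = ab_seq_model(MAX_CDR_LEN)

    # Batch of 8 zero-padded CDR sequences of per-residue feature vectors
    cdrs = np.random.rand(8, MAX_CDR_LEN, NUM_FEATURES)

    # 1.0 for real residues, 0.0 for padding; this tensor drives the masking
    # layers and, reused as a temporal sample weight (the models compile with
    # sample_weight_mode="temporal"), zeroes the loss on padded positions
    mask = np.zeros((8, MAX_CDR_LEN))
    mask[:, :20] = 1.0

    # Binary per-residue labels (paratope / non-paratope)
    labels = np.random.randint(0, 2, size=(8, MAX_CDR_LEN, 1)).astype("float32")

    model.fit([cdrs, mask], labels, sample_weight=mask, batch_size=8, epochs=1)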