"""Build a fully-convolutional Keras model from Conv2D/BatchNorm/activation units.

Run as a script to print the model summary.
"""
# ``Input`` is imported from ``keras.layers``: it is exported there in Keras 2.x,
# tf.keras and Keras 3, whereas ``keras.models`` no longer exports it in Keras 3.
from keras.layers import Activation, BatchNormalization, Conv2D, Input, MaxPooling2D
from keras.models import Model


def conv_batch_norm_relu(x, n_filters, f, padding='same', activation='relu'):
    """Apply Conv2D, then BatchNormalization, then an activation to ``x``.

    Args:
        x: Input Keras tensor.
        n_filters: Number of convolution filters.
        f: Kernel size passed straight to ``Conv2D`` (int or tuple).
        padding: ``Conv2D`` padding mode.
        activation: Activation name or callable passed to ``Activation``.

    Returns:
        The output Keras tensor.
    """
    # NOTE: Conv2D keeps its default bias even though BatchNormalization follows
    # (BN's learned offset makes it redundant). Left as-is so parameter counts
    # and any previously saved weights stay compatible.
    x = Conv2D(n_filters, f, padding=padding)(x)
    x = BatchNormalization()(x)
    x = Activation(activation)(x)
    return x


# Backbone layout, applied in order. Each stage is a sequence of
# (n_filters, kernel_size) conv units; every stage except the last is followed
# by a 2x2, stride-2 max pool. The 1x1 units halve the channel count between
# the wider 3x3 units.
_BACKBONE_STAGES = (
    ((32, (3, 3)),),
    ((64, (3, 3)),),
    ((128, (3, 3)), (64, (1, 1)), (128, (3, 3))),
    ((256, (3, 3)), (128, (1, 1)), (256, (3, 3))),
    ((512, (3, 3)), (256, (1, 1)), (512, (3, 3)), (256, (1, 1)), (512, (3, 3))),
    ((1024, (3, 3)), (512, (1, 1)), (1024, (3, 3)), (512, (1, 1)), (1024, (3, 3))),
)


def model(input_shape=(224, 224, 3), n_output_channels=5):
    """Build the fully-convolutional network.

    The backbone is the stack described by ``_BACKBONE_STAGES``. Its five
    2x2/stride-2 max pools shrink each spatial dimension by a factor of 32,
    rounding down, so the default 224x224 input yields a 7x7 output grid. A
    1x1 Conv2D + BatchNormalization + sigmoid head produces the output map.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
            Defaults to ``(224, 224, 3)``.
        n_output_channels: Number of channels in the output map. Defaults to 5.

    Returns:
        An uncompiled ``keras.Model`` whose final layer is named ``'output'``.
    """
    inputs = Input(shape=input_shape)
    x = inputs
    last_stage = len(_BACKBONE_STAGES) - 1
    for stage_index, stage in enumerate(_BACKBONE_STAGES):
        for n_filters, kernel_size in stage:
            x = conv_batch_norm_relu(x, n_filters, kernel_size,
                                     padding='same', activation='relu')
        if stage_index != last_stage:
            x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    # 1x1 prediction head; the sigmoid squashes every output channel into (0, 1).
    x = Conv2D(n_output_channels, (1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('sigmoid', name='output')(x)
    return Model(inputs=inputs, outputs=x)


if __name__ == '__main__':
    # Bind to a new name so the module-level ``model`` builder is not shadowed.
    net = model()
    net.summary()