"""
@author: Yannick Wilhelm
@email: yannick.wilhelm@gmx.de
@date: February 2017

Deep residual learning blocks (He 2015, Deep Residual Learning)

Every builder in this module takes a Keras tensor, stacks convolution /
batch-normalization / activation layers on it and returns the resulting Keras
tensor. Layer names are derived from the ``stage`` and ``block`` arguments as
``res<stage>_<block>_branch<suffix>`` and ``bn<stage>_<block>_branch<suffix>``.
"""

from keras.layers import Conv2D
from keras.layers import Conv3D
from keras.layers import Conv3DTranspose
from keras.layers import BatchNormalization
from keras.layers import advanced_activations
from keras.layers import Activation
from keras.layers import MaxPooling2D
from keras.layers import Add
from keras.layers import LeakyReLU
from keras.layers import Lambda
import keras.backend as K

# Provides squeeze_excitation_block and squeeze_excitation_block_3D used below.
from networks.multiclass.SENets.squeeze_excitation_block import *


def _channel_axis():
    """Return the channel-axis index for the active Keras image data format."""
    return -1 if K.image_data_format() == 'channels_last' else 1


def _layer_name_bases(stage, block):
    """Return the ``(conv, batch-norm)`` layer-name prefixes for a block."""
    suffix = str(stage) + '_' + str(block) + '_branch'
    return 'res' + suffix, 'bn' + suffix


def identity_bottleneck_block(input_tensor, filters, stage, block, se_enabled=False, se_ratio=16):
    """Build a 2D bottleneck residual block with an identity shortcut.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv ('same') -> BN -> ReLU ->
    1x1 conv -> [squeeze-and-excitation] -> BN. The result is added to
    ``input_tensor`` and passed through a ReLU.

    Args:
        input_tensor: Keras tensor fed into the block. It is added unchanged
            to the main path, so its shape must match the main-path output.
        filters: Sequence of three filter counts, one per convolution.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block`` is applied after
            the last convolution and before its batch normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block``.

    Returns:
        The output Keras tensor of the block.
    """
    filters_1, filters_2, filters_3 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)

    x = Conv2D(filters_1, (1, 1),
               kernel_initializer='he_normal',
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters_2, (3, 3),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters_3, (1, 1),
               kernel_initializer='he_normal',
               name=conv_name_base + '2c')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    x = Add()([x, input_tensor])
    x = Activation('relu')(x)
    return x


def identity_block(input_tensor, filters, stage, block, se_enabled=False, se_ratio=16):
    """Build a basic 2D residual block with an identity shortcut.

    Main path: 3x3 conv ('same') -> BN -> ReLU -> 3x3 conv ('same') ->
    [squeeze-and-excitation] -> BN. The result is added to ``input_tensor``
    and passed through a ReLU.

    Args:
        input_tensor: Keras tensor fed into the block. It is added unchanged
            to the main path, so its shape must match the main-path output.
        filters: Sequence of two filter counts, one per convolution.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block`` is applied after
            the last convolution and before its batch normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block``.

    Returns:
        The output Keras tensor of the block.
    """
    filters_1, filters_2 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)

    x = Conv2D(filters_1, (3, 3),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters_2, (3, 3),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)

    x = Add()([x, input_tensor])
    x = Activation('relu')(x)
    return x


def projection_block(input_tensor, filters, stage, block, se_enabled=False, se_ratio=16):
    """Build a basic 2D residual block that downsamples by a factor of two.

    Main path: 3x3 conv (stride 2, 'same') -> BN -> ReLU -> 3x3 conv ('same')
    -> [squeeze-and-excitation] -> BN. The shortcut is a 1x1 convolution with
    stride 2 followed by batch normalization. Both paths are added and passed
    through a ReLU.

    Args:
        input_tensor: Keras tensor fed into the block.
        filters: Sequence of two filter counts; the second one is also used
            for the shortcut convolution.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block`` is applied after
            the last main-path convolution and before its batch normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block``.

    Returns:
        The output Keras tensor of the block.
    """
    filters_1, filters_2 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)

    # downsampling directly by convolution with stride 2
    x = Conv2D(filters_1, (3, 3),
               padding='same',
               strides=(2, 2),
               kernel_initializer='he_normal',
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters_2, (3, 3),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)

    # projection shortcut convolution
    x_shortcut = Conv2D(filters_2, (1, 1),
                        strides=(2, 2),
                        kernel_initializer='he_normal',
                        name=conv_name_base + '1')(input_tensor)
    x_shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(x_shortcut)

    # addition of shortcut
    x = Add()([x, x_shortcut])
    x = Activation('relu')(x)
    return x


def projection_bottleneck_block(input_tensor, filters, stage, block, se_enabled=False, se_ratio=16):
    """Build a 2D bottleneck residual block that downsamples by a factor of two.

    Main path: 1x1 conv (stride 2) -> BN -> ReLU -> 3x3 conv ('same') -> BN ->
    ReLU -> 1x1 conv -> [squeeze-and-excitation] -> BN. The shortcut is a 1x1
    convolution with stride 2 followed by batch normalization. Both paths are
    added and passed through a ReLU.

    Args:
        input_tensor: Keras tensor fed into the block.
        filters: Sequence of three filter counts; the third one is also used
            for the shortcut convolution.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block`` is applied after
            the last main-path convolution and before its batch normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block``.

    Returns:
        The output Keras tensor of the block.
    """
    filters_1, filters_2, filters_3 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)

    x = Conv2D(filters_1, (1, 1),
               strides=(2, 2),
               kernel_initializer='he_normal',
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters_2, (3, 3),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters_3, (1, 1),
               kernel_initializer='he_normal',
               name=conv_name_base + '2c')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    # projection shortcut
    x_shortcut = Conv2D(filters_3, (1, 1),
                        strides=(2, 2),
                        kernel_initializer='he_normal',
                        name=conv_name_base + '1')(input_tensor)
    x_shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(x_shortcut)

    x = Add()([x, x_shortcut])
    x = Activation('relu')(x)
    return x


def zero_padding_block(input_tensor, filters, stage, block, se_enabled=False, se_ratio=16):
    """Build a downsampling 2D residual block with a zero-padding shortcut.

    Main path: 3x3 conv (stride 2, 'same') -> BN -> ReLU -> 3x3 conv ('same')
    -> [squeeze-and-excitation] -> BN.

    Shortcut: the input is subsampled with a 1x1 max pooling of stride 2 and
    its channels are doubled by appending zeros (``zeropad``), followed by
    batch normalization. Both paths are added and passed through a ReLU.

    Both convolutions use 'same' padding so the main path has the same
    spatial size as the subsampled shortcut; the shortcut contains no
    convolution.

    Args:
        input_tensor: Keras tensor fed into the block.
        filters: Sequence of two filter counts. Because the shortcut doubles
            the channel count, the second value must be twice the number of
            channels of ``input_tensor``.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block`` is applied after
            the last main-path convolution and before its batch normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block``.

    Returns:
        The output Keras tensor of the block.

    Raises:
        ValueError: If the channel count of ``input_tensor`` is statically
            known and ``filters[1]`` is not exactly twice that count.
    """
    filters_1, filters_2 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)

    # Fail early with a readable message instead of a shape error inside Add.
    input_shape = K.int_shape(input_tensor)
    in_channels = input_shape[bn_axis] if input_shape is not None else None
    if in_channels is not None and filters_2 != 2 * in_channels:
        raise ValueError(
            'zero_padding_block doubles the shortcut channels: expected '
            'filters[1] == ' + str(2 * in_channels) + ', got ' + str(filters_2))

    # downsampling directly by convolution with stride 2
    x = Conv2D(filters_1, (3, 3),
               strides=(2, 2),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters_2, (3, 3),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)

    # parameter-free shortcut: subsample every second pixel, then zero-pad the channels
    x_shortcut = MaxPooling2D(pool_size=(1, 1), strides=(2, 2), padding='same')(input_tensor)
    x_shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(x_shortcut)
    x_shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(x_shortcut)

    # addition of shortcut
    x = Add()([x, x_shortcut])
    x = Activation('relu')(x)
    return x


def zeropad(x):
    """Double the channel count of ``x`` by appending an all-zero copy.

    The zeros are concatenated along the channel axis of the active Keras
    image data format (axis 1 for 'channels_first', the last axis for
    'channels_last').

    Args:
        x: Backend tensor to pad.

    Returns:
        A tensor whose channel axis is twice as long as that of ``x``.
    """
    zeros = K.zeros_like(x)
    return K.concatenate([x, zeros], axis=_channel_axis())


def zeropad_output_shape(input_shape):
    """Return the output shape produced by ``zeropad`` for ``input_shape``.

    Args:
        input_shape: Shape of a 4D tensor (batch axis included).

    Returns:
        A tuple equal to ``input_shape`` with the channel dimension doubled.
        An unknown (``None``) channel dimension stays ``None``.

    Raises:
        ValueError: If ``input_shape`` does not have exactly four entries.
    """
    shape = list(input_shape)
    if len(shape) != 4:
        raise ValueError('zeropad expects a 4D input shape, got ' + str(tuple(shape)))
    axis = _channel_axis()
    if shape[axis] is not None:
        shape[axis] *= 2
    return tuple(shape)


########################################################################################################################
### 3D Residual Blocks #################################################################################################
########################################################################################################################

def identity_block_3D(input_tensor, filters, kernel_size=(3, 3, 3), stage=0, block=0, se_enabled=False,
                      se_ratio=16):
    """Build a basic 3D residual block with an identity shortcut.

    Main path: conv (stride 1, 'same') -> BN -> LeakyReLU(0.01) -> conv
    (stride 1, 'same') -> [squeeze-and-excitation] -> BN. The result is added
    to ``input_tensor`` and passed through a LeakyReLU(0.01).

    Args:
        input_tensor: Keras tensor fed into the block. It is added unchanged
            to the main path, so its shape must match the main-path output.
        filters: Sequence of two filter counts, one per convolution.
        kernel_size: Kernel size used by both convolutions.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block_3D`` is applied
            after the last convolution and before its batch normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block_3D``.

    Returns:
        The output Keras tensor of the block.
    """
    filters_1, filters_2 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)

    x = Conv3D(filters=filters_1,
               kernel_size=kernel_size,
               strides=(1, 1, 1),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = LeakyReLU(alpha=0.01)(x)

    x = Conv3D(filters=filters_2,
               kernel_size=kernel_size,
               strides=(1, 1, 1),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block_3D(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)

    x = Add()([x, input_tensor])
    x = LeakyReLU(alpha=0.01)(x)
    return x


def projection_block_3D(input_tensor, filters, kernel_size=(3, 3, 3), stage=0, block=0, se_enabled=False,
                        se_ratio=16):
    """Build a basic 3D residual block that downsamples by a factor of two.

    Main path: conv (stride 2, 'same') -> BN -> LeakyReLU(0.01) -> conv
    (stride 1, 'same') -> [squeeze-and-excitation] -> BN. The shortcut is a
    2x2x2 convolution with stride 2 ('same') followed by batch normalization.
    Both paths are added and passed through a LeakyReLU(0.01).

    Args:
        input_tensor: Keras tensor fed into the block.
        filters: Sequence of two filter counts; the second one is also used
            for the shortcut convolution.
        kernel_size: Kernel size used by the two main-path convolutions.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block_3D`` is applied
            after the last main-path convolution and before its batch
            normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block_3D``.

    Returns:
        The output Keras tensor of the block.
    """
    filters_1, filters_2 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)

    # downsampling directly by convolution with stride 2
    x = Conv3D(filters=filters_1,
               kernel_size=kernel_size,
               strides=(2, 2, 2),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = LeakyReLU(alpha=0.01)(x)

    x = Conv3D(filters=filters_2,
               kernel_size=kernel_size,
               strides=(1, 1, 1),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block_3D(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)

    # projection shortcut convolution
    x_shortcut = Conv3D(filters=filters_2,
                        kernel_size=(2, 2, 2),
                        strides=(2, 2, 2),
                        padding='same',
                        kernel_initializer='he_normal',
                        name=conv_name_base + '1')(input_tensor)
    x_shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(x_shortcut)

    # addition of shortcut
    x = Add()([x, x_shortcut])
    x = LeakyReLU(alpha=0.01)(x)
    return x


########################################################################################################################

########################################################################################################################
### Transposed 3D projection block
########################################################################################################################

def transposed_projection_block_3D(input_tensor, filters, kernel_size=(3, 3, 3), stage=0, block=0,
                                   se_enabled=False, se_ratio=16):
    """Build a 3D residual block that upsamples by a factor of two.

    Main path: transposed conv (stride 2, 'same') -> BN -> LeakyReLU(0.01) ->
    conv (stride 1, 'same') -> [squeeze-and-excitation] -> BN. The shortcut is
    a 2x2x2 transposed convolution with stride 2 ('same') followed by batch
    normalization. Both paths are added and passed through a LeakyReLU(0.01).

    Transposed convolutions are named with the prefix
    ``transposed_res<stage>_<block>_branch``.

    Args:
        input_tensor: Keras tensor fed into the block.
        filters: Sequence of two filter counts; the second one is also used
            for the shortcut transposed convolution.
        kernel_size: Kernel size used by the two main-path layers.
        stage: Value used (via ``str``) in the generated layer names.
        block: Value used (via ``str``) in the generated layer names.
        se_enabled: When true, ``squeeze_excitation_block_3D`` is applied
            after the last main-path convolution and before its batch
            normalization.
        se_ratio: Forwarded as ``ratio`` to ``squeeze_excitation_block_3D``.

    Returns:
        The output Keras tensor of the block.
    """
    filters_1, filters_2 = filters
    bn_axis = _channel_axis()
    conv_name_base, bn_name_base = _layer_name_bases(stage, block)
    transposed_conv_name_base = 'transposed_' + conv_name_base

    # upsampling directly by transposed convolution with stride 2
    x = Conv3DTranspose(filters=filters_1,
                        kernel_size=kernel_size,
                        strides=(2, 2, 2),
                        padding='same',
                        data_format=K.image_data_format(),
                        kernel_initializer='he_normal',
                        name=transposed_conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = LeakyReLU(alpha=0.01)(x)

    x = Conv3D(filters=filters_2,
               kernel_size=kernel_size,
               strides=(1, 1, 1),
               padding='same',
               kernel_initializer='he_normal',
               name=conv_name_base + '2b')(x)

    # squeeze and excitation block
    if se_enabled:
        x = squeeze_excitation_block_3D(x, ratio=se_ratio)

    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)

    # projection shortcut transposed convolution
    # NOTE: unlike the other blocks, this shortcut layer uses the '2b' suffix rather than '1';
    # the name is left as it is so that existing layer names stay stable.
    x_shortcut = Conv3DTranspose(filters=filters_2,
                                 kernel_size=(2, 2, 2),
                                 strides=(2, 2, 2),
                                 padding='same',
                                 data_format=K.image_data_format(),
                                 kernel_initializer='he_normal',
                                 name=transposed_conv_name_base + '2b')(input_tensor)
    x_shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(x_shortcut)

    # addition of shortcut
    x = Add()([x, x_shortcut])
    x = LeakyReLU(alpha=0.01)(x)
    return x

########################################################################################################################