from keras.layers import Input
from keras import layers
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import ZeroPadding2D
from keras.layers import BatchNormalization
from keras.models import Model
import keras.backend as K

from networks.classifiers import build_pyramid_pooling_module


# Layout of the residual stages built by ResNet101:
# (stage, filters, labels of the identity blocks, strides of block 'a', dilation)
# Every stage starts with a conv_block labelled 'a' followed by one
# identity_block per label. Labels feed into the layer names, so they must
# not change if weights are loaded by name.
_RESNET101_STAGES = (
    (2, (64, 64, 256), 'bc', (1, 1), 1),
    (3, (128, 128, 512), 'bcd', (2, 2), 1),
    (4, (256, 256, 1024), 'bcdefghijklmnopqrstuvw', (1, 1), 2),
    (5, (512, 512, 2048), 'bc', (1, 1), 4),
)


def _bn_axis():
    """Return the channel axis for the active Keras image data format."""
    return 3 if K.image_data_format() == 'channels_last' else 1


def BN(axis, name=""):
    """Create the batch normalization layer used throughout this network.

    # Arguments
        axis: integer, the axis to normalize (the channel axis).
        name: layer name.

    # Returns
        A `BatchNormalization` layer with momentum 0.1 and epsilon 1e-5.
    """
    # NOTE(review): in Keras, `momentum` weights the *previous* moving
    # average, so 0.1 makes the running statistics follow the current batch
    # closely. Confirm this is intended and not a PyTorch/Caffe-style
    # momentum carried over unchanged.
    return BatchNormalization(axis=axis, momentum=0.1, name=name, epsilon=1e-5)


def _bottleneck_main_path(input_tensor, kernel_size, filters, stage, block,
                          strides=(1, 1), dilation=1):
    """Build the 1x1 -> kxk -> 1x1 main path shared by both block types.

    The final ReLU is not applied here; callers add the shortcut first.

    # Returns
        Tuple `(x, conv_name_base, bn_name_base, bn_axis)` so the caller can
        name and normalize its shortcut branch consistently.
    """
    filters1, filters2, filters3 = filters
    bn_axis = _bn_axis()
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), strides=strides,
               name=conv_name_base + '2a', use_bias=False)(input_tensor)
    x = BN(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, padding='same',
               name=conv_name_base + '2b', use_bias=False,
               dilation_rate=dilation)(x)
    x = BN(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c',
               use_bias=False)(x)
    x = BN(axis=bn_axis, name=bn_name_base + '2c')(x)
    return x, conv_name_base, bn_name_base, bn_axis


def identity_block(input_tensor, kernel_size, filters, stage, block,
                   dilation=1):
    """The identity block is the block that has no conv layer at shortcut.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at
            main path
        filters: sequence of 3 integers, the filters of the 3 conv layers
            at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer
            names
        dilation: dilation of the intermediate convolution

    # Returns
        Output tensor for the block.
    """
    x, _, _, _ = _bottleneck_main_path(
        input_tensor, kernel_size, filters, stage, block, dilation=dilation)

    # The input is added back unchanged, so its channel count must equal
    # filters[2].
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block,
               strides=(1, 1), dilation=1):
    """conv_block is the block that has a conv layer at shortcut.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at
            main path
        filters: sequence of 3 integers, the filters of the 3 conv layers
            at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer
            names
        strides: strides of the first main-path conv and of the shortcut conv
        dilation: dilation of the intermediate convolution

    # Returns
        Output tensor for the block.

    The first conv layer at main path and the shortcut conv always use the
    same `strides`, so both branches produce matching spatial sizes.
    """
    x, conv_name_base, bn_name_base, bn_axis = _bottleneck_main_path(
        input_tensor, kernel_size, filters, stage, block,
        strides=strides, dilation=dilation)

    # Projection shortcut: a 1x1 conv matches the main path's channel count
    # (filters[2]) and spatial stride.
    shortcut = Conv2D(filters[2], (1, 1), strides=strides,
                      name=conv_name_base + '1',
                      use_bias=False)(input_tensor)
    shortcut = BN(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x


def ResNet101(input_tensor=None):
    """Build the dilated ResNet-101 feature extractor on `input_tensor`.

    Only the stem convolution, the max pooling and stage 3 use a stride of 2.
    Stages 4 and 5 keep the resolution and use dilation 2 and 4 instead.

    # Arguments
        input_tensor: Keras tensor to build the network on. Despite the
            `None` default it must be provided; it is passed straight into
            the first layer.

    # Returns
        Output tensor of the last stage-5 block (2048 filters).
    """
    x = ZeroPadding2D((3, 3))(input_tensor)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=False)(x)
    x = BN(axis=_bn_axis(), name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='valid')(x)

    for stage, filters, identity_labels, strides, dilation in _RESNET101_STAGES:
        x = conv_block(x, 3, filters, stage=stage, block='a',
                       strides=strides, dilation=dilation)
        for label in identity_labels:
            x = identity_block(x, 3, filters, stage=stage, block=label,
                               dilation=dilation)
    return x


def build_network(nb_classes, input_shape, resnet_layers=101,
                  classifier='psp', sigmoid=False, output_size=None,
                  num_input_channels=4):
    """Build the segmentation network: ResNet backbone plus classifier head.

    # Arguments
        nb_classes: number of classes, forwarded to the classifier head.
        input_shape: sequence whose first two entries are used as the first
            two dimensions of the input tensor; also forwarded unchanged to
            the classifier head.
        resnet_layers: depth of the ResNet backbone; only 101 is supported.
        classifier: classifier head to attach; only 'psp' is supported.
        sigmoid: forwarded unchanged to `build_pyramid_pooling_module`.
        output_size: forwarded unchanged to `build_pyramid_pooling_module`.
        num_input_channels: size of the last dimension of the input tensor.

    # Returns
        A Keras `Model` mapping the input tensor to the classifier output.

    # Raises
        ValueError: if `resnet_layers` or `classifier` is not supported.
    """
    # Validate before building any layers. The original code constructed
    # this ValueError without raising it, so an unsupported depth surfaced
    # later as a NameError on `res`.
    if resnet_layers != 101:
        raise ValueError('Resnet {} does not exist'.format(resnet_layers))
    if classifier != 'psp':
        raise ValueError('Classifier not implemented.')

    inp = Input((input_shape[0], input_shape[1], num_input_channels))
    res = ResNet101(inp)

    print("Building network based on ResNet %i and PSP module expecting inputs of shape %s predicting %i classes" % (
        resnet_layers, input_shape, nb_classes))
    x = build_pyramid_pooling_module(res, input_shape, nb_classes,
                                     sigmoid=sigmoid,
                                     output_size=output_size)

    model = Model(inputs=inp, outputs=x)
    return model