from keras.models import Model
from keras import layers
from keras.layers import Input
from keras.layers import Lambda
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Add
from keras.layers import Dropout
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import DepthwiseConv2D
from keras.layers import ZeroPadding2D
from keras.layers import GlobalAveragePooling2D


# Accepted values for the ``skip_connection_type`` argument of
# ``_xception_block``.
_SKIP_CONNECTION_TYPES = ('conv', 'sum', 'none')


def _explicit_pad(x, kernel_size, rate):
    """Zero-pad the two spatial axes of ``x`` for a strided, dilated conv.

    The total padding is ``effective_kernel - 1`` where
    ``effective_kernel = kernel_size + (kernel_size - 1) * (rate - 1)``.
    It is split into a leading part (``pad_total // 2``) and a trailing part
    (the remainder) and applied identically to both spatial axes.

    Args:
        x: Input tensor to pad.
        kernel_size: Side length of the square convolution kernel.
        rate: Dilation rate of the convolution.

    Returns:
        The padded tensor.
    """
    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    # A flat 2-tuple would be read by ZeroPadding2D as
    # (symmetric_height_pad, symmetric_width_pad); the nested form states
    # (begin, end) per axis, which is what is intended when the two differ.
    return ZeroPadding2D(((pad_beg, pad_end), (pad_beg, pad_end)))(x)


def _conv2d_same(x, filters, prefix, stride=1, kernel_size=3, rate=1):
    """Apply a bias-free 2D convolution that pads explicitly when strided.

    With ``stride == 1`` the convolution uses ``padding='same'``. Otherwise
    the input is zero-padded by hand and the convolution runs with
    ``padding='valid'``.

    Args:
        x: Input tensor.
        filters: Number of output filters.
        prefix: Name given to the ``Conv2D`` layer.
        stride: Stride applied along both spatial axes.
        kernel_size: Side length of the square kernel.
        rate: Dilation rate applied along both spatial axes.

    Returns:
        The convolved tensor.
    """
    if stride == 1:
        padding = 'same'
    else:
        x = _explicit_pad(x, kernel_size, rate)
        padding = 'valid'

    return Conv2D(filters,
                  (kernel_size, kernel_size),
                  strides=(stride, stride),
                  padding=padding,
                  use_bias=False,
                  dilation_rate=(rate, rate),
                  name=prefix)(x)


def SepConv_BN(x, filters, prefix, stride=1, kernel_size=3, rate=1,
               depth_activation=False, epsilon=1e-3):
    """Depthwise-separable convolution with batch norm after each stage.

    The layer sequence is: depthwise conv -> BN -> 1x1 pointwise conv -> BN.
    ReLU placement depends on ``depth_activation``:

    * ``False``: a single ReLU is applied *before* the depthwise conv.
    * ``True``: a ReLU is applied *after* each of the two batch norms.

    Args:
        x: Input tensor.
        filters: Number of filters of the pointwise (1x1) convolution.
        prefix: Prefix for layer names; ``_depthwise``, ``_depthwise_BN``,
            ``_pointwise`` and ``_pointwise_BN`` are appended.
        stride: Stride of the depthwise convolution.
        kernel_size: Side length of the square depthwise kernel.
        rate: Dilation rate of the depthwise convolution.
        depth_activation: Selects the ReLU placement described above.
        epsilon: Epsilon passed to both ``BatchNormalization`` layers.

    Returns:
        The output tensor of the final stage.
    """
    # Strided depthwise convs are padded by hand and run with 'valid'.
    if stride == 1:
        depth_padding = 'same'
    else:
        x = _explicit_pad(x, kernel_size, rate)
        depth_padding = 'valid'

    # Without per-stage activations, activate once up front instead.
    if not depth_activation:
        x = Activation('relu')(x)

    # Depthwise (optionally dilated) convolution.
    x = DepthwiseConv2D((kernel_size, kernel_size),
                        strides=(stride, stride),
                        dilation_rate=(rate, rate),
                        padding=depth_padding,
                        use_bias=False,
                        name=prefix + '_depthwise')(x)
    x = BatchNormalization(name=prefix + '_depthwise_BN', epsilon=epsilon)(x)
    if depth_activation:
        x = Activation('relu')(x)

    # 1x1 pointwise convolution that sets the output channel count.
    x = Conv2D(filters, (1, 1), padding='same', use_bias=False,
               name=prefix + '_pointwise')(x)
    x = BatchNormalization(name=prefix + '_pointwise_BN', epsilon=epsilon)(x)
    if depth_activation:
        x = Activation('relu')(x)

    return x


def _xception_block(inputs, depth_list, prefix, skip_connection_type, stride,
                    rate=1, depth_activation=False, return_skip=False):
    """Build one Xception block of three separable convolutions.

    Only the third separable convolution uses ``stride``; the first two use
    stride 1. The output of the second one is kept as the ``skip`` tensor.

    Args:
        inputs: Input tensor.
        depth_list: Sequence with the filter counts of the three separable
            convolutions. Its last entry is also the filter count of the
            shortcut convolution when ``skip_connection_type == 'conv'``.
        prefix: Prefix for all layer names created by the block.
        skip_connection_type: One of ``'conv'`` (1x1 strided conv + BN on the
            shortcut, then add), ``'sum'`` (add the unmodified input) or
            ``'none'`` (no shortcut).
        stride: Stride of the third separable conv and of the shortcut conv.
        rate: Dilation rate for all three separable convolutions.
        depth_activation: Forwarded to ``SepConv_BN``.
        return_skip: If true, also return the intermediate ``skip`` tensor.

    Returns:
        ``outputs`` or, when ``return_skip`` is true, ``(outputs, skip)``.

    Raises:
        ValueError: If ``skip_connection_type`` is not a supported value.
    """
    # Fail early with a clear message rather than with an unbound-variable
    # error at the return statement after layers have already been created.
    if skip_connection_type not in _SKIP_CONNECTION_TYPES:
        raise ValueError(
            "skip_connection_type must be one of {}, got {!r}".format(
                _SKIP_CONNECTION_TYPES, skip_connection_type))

    residual = inputs
    for i in range(3):
        residual = SepConv_BN(residual,
                              depth_list[i],
                              prefix + '_separable_conv{}'.format(i + 1),
                              stride=stride if i == 2 else 1,
                              rate=rate,
                              depth_activation=depth_activation)
        if i == 1:
            # Feature map taken before the (possibly strided) third conv.
            skip = residual

    if skip_connection_type == 'conv':
        shortcut = _conv2d_same(inputs, depth_list[-1], prefix + '_shortcut',
                                kernel_size=1,
                                stride=stride)
        shortcut = BatchNormalization(name=prefix + '_shortcut_BN')(shortcut)
        outputs = layers.add([residual, shortcut])
    elif skip_connection_type == 'sum':
        outputs = layers.add([residual, inputs])
    else:  # 'none'
        outputs = residual

    if return_skip:
        return outputs, skip
    return outputs


def Xception(inputs, alpha=1, OS=16):
    """Build the Xception feature extractor on top of ``inputs``.

    Args:
        inputs: Input image tensor.
        alpha: Accepted for interface compatibility; not used by this
            function.
        OS: Output-stride selector. ``8`` keeps entry block 3 at stride 1 and
            uses larger dilation rates; any other value selects the
            configuration with a stride-2 entry block 3.

    Returns:
        A tuple ``(x, atrous_rates, skip1)`` where ``x`` is the final feature
        map, ``atrous_rates`` is the 3-tuple of dilation rates chosen by
        ``OS`` (returned for the caller's use; not applied here) and
        ``skip1`` is the intermediate tensor from ``entry_flow_block2``.
    """
    if OS == 8:
        entry_block3_stride = 1
        middle_block_rate = 2  # ! Not mentioned in paper, but required
        exit_block_rates = (2, 4)
        atrous_rates = (12, 24, 36)
    else:
        entry_block3_stride = 2
        middle_block_rate = 1
        exit_block_rates = (1, 2)
        atrous_rates = (6, 12, 18)

    # Spatial sizes in the comments below are illustrative and correspond to
    # a 512x512 input.
    # 256,256,32
    x = Conv2D(32, (3, 3), strides=(2, 2),
               name='entry_flow_conv1_1', use_bias=False,
               padding='same')(inputs)
    x = BatchNormalization(name='entry_flow_conv1_1_BN')(x)
    x = Activation('relu')(x)

    # 256,256,64
    x = _conv2d_same(x, 64, 'entry_flow_conv1_2', kernel_size=3, stride=1)
    x = BatchNormalization(name='entry_flow_conv1_2_BN')(x)
    x = Activation('relu')(x)

    # 256,256,128 -> 256,256,128 -> 128,128,128
    x = _xception_block(x, [128, 128, 128], 'entry_flow_block1',
                        skip_connection_type='conv', stride=2,
                        depth_activation=False)

    # 128,128,256 -> 128,128,256 -> 64,64,256
    # skip = 128,128,256
    x, skip1 = _xception_block(x, [256, 256, 256], 'entry_flow_block2',
                               skip_connection_type='conv', stride=2,
                               depth_activation=False, return_skip=True)

    x = _xception_block(x, [728, 728, 728], 'entry_flow_block3',
                        skip_connection_type='conv',
                        stride=entry_block3_stride,
                        depth_activation=False)

    # Middle flow: 16 identical residual units.
    for i in range(16):
        x = _xception_block(x, [728, 728, 728],
                            'middle_flow_unit_{}'.format(i + 1),
                            skip_connection_type='sum', stride=1,
                            rate=middle_block_rate,
                            depth_activation=False)

    # Exit flow.
    x = _xception_block(x, [728, 1024, 1024], 'exit_flow_block1',
                        skip_connection_type='conv', stride=1,
                        rate=exit_block_rates[0],
                        depth_activation=False)
    x = _xception_block(x, [1536, 1536, 2048], 'exit_flow_block2',
                        skip_connection_type='none', stride=1,
                        rate=exit_block_rates[1],
                        depth_activation=True)
    return x, atrous_rates, skip1