# -*- coding: utf-8 -*-
"""DenseNet-style convolutional network assembled from Keras layers.

The module exposes three building blocks (``conv_block``, ``dense_block``,
``transition_block``) and ``dense_cnn``, which wires them into a model ending
in a per-step softmax classifier.  Run as a script, it builds the model for a
``(32, 153, 1)`` input with 1000 classes, compiles it and saves an
architecture diagram.
"""
import os

from keras.layers import Flatten, Input
from keras.layers.convolutional import Conv2D, Conv2DTranspose, ZeroPadding2D
from keras.layers.core import Activation, Dense, Dropout, Permute, Reshape
from keras.layers.merge import concatenate
from keras.layers.normalization import BatchNormalization
from keras.layers.pooling import AveragePooling2D, GlobalAveragePooling2D
from keras.layers.wrappers import TimeDistributed
from keras.models import Model
from keras.regularizers import l2
from keras.utils import plot_model

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def conv_block(input, growth_rate, dropout_rate=None, weight_decay=1e-4):
    """Apply BatchNorm -> ReLU -> 3x3 convolution, then optional dropout.

    Args:
        input: Keras tensor to transform.  (The name shadows the builtin but
            is kept because it is part of the public signature.)
        growth_rate: Number of filters produced by the 3x3 convolution.
        dropout_rate: Dropout fraction; dropout is skipped when falsy
            (``None`` or ``0``).
        weight_decay: Accepted for signature parity with the other blocks but
            not used here -- the convolution has no kernel regularizer.

    Returns:
        The transformed Keras tensor.
    """
    # Normalization is taken over the last axis (the channel axis, assuming
    # channels_last data -- confirm against the backend configuration).
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(growth_rate, (3, 3), kernel_initializer='he_normal',
               padding='same')(x)
    if dropout_rate:
        x = Dropout(dropout_rate)(x)
    return x


def dense_block(x, nb_layers, nb_filter, growth_rate, droput_rate=0.2,
                weight_decay=1e-4):
    """Stack ``nb_layers`` conv blocks, concatenating each output onto ``x``.

    Args:
        x: Input Keras tensor.
        nb_layers: Number of ``conv_block`` layers to append.
        nb_filter: Running filter count before this block.
        growth_rate: Filters added by every ``conv_block``.
        droput_rate: Dropout fraction forwarded to ``conv_block``.  The
            misspelled name is kept so existing keyword callers still work.
        weight_decay: Forwarded to ``conv_block``.

    Returns:
        Tuple ``(x, nb_filter)``: the concatenated tensor and the filter
        count increased by ``nb_layers * growth_rate``.
    """
    for _ in range(nb_layers):
        cb = conv_block(x, growth_rate, droput_rate, weight_decay)
        # Each new feature map is appended along the last axis, so every
        # subsequent layer sees all earlier outputs.
        x = concatenate([x, cb], axis=-1)
        nb_filter += growth_rate
    return x, nb_filter


def transition_block(input, nb_filter, dropout_rate=None, pooltype=1,
                     weight_decay=1e-4):
    """Apply BatchNorm -> ReLU -> 1x1 convolution, then dropout and pooling.

    Args:
        input: Keras tensor to transform.
        nb_filter: Number of filters of the 1x1 convolution.
        dropout_rate: Dropout fraction; dropout is skipped when falsy.
        pooltype: Selects the pooling applied after the convolution:
            ``1`` -- ``ZeroPadding2D(padding=(0, 1))`` followed by 2x2
            average pooling with strides ``(2, 1)``;
            ``2`` -- 2x2 average pooling with strides ``(2, 2)``;
            ``3`` -- 2x2 average pooling with strides ``(2, 1)``, no padding.
            Any other value leaves the tensor unpooled.
        weight_decay: L2 regularization factor for the convolution kernel.

    Returns:
        Tuple ``(x, nb_filter)``: the transformed tensor and ``nb_filter``
        unchanged.
    """
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(nb_filter, (1, 1), kernel_initializer='he_normal',
               padding='same', use_bias=False,
               kernel_regularizer=l2(weight_decay))(x)

    if dropout_rate:
        x = Dropout(dropout_rate)(x)

    if pooltype == 2:
        x = AveragePooling2D((2, 2), strides=(2, 2))(x)
    elif pooltype == 1:
        x = ZeroPadding2D(padding=(0, 1))(x)
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    elif pooltype == 3:
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    return x, nb_filter


def dense_cnn(input, nclass):
    """Build the dense CNN model and print its summary.

    The network is: a strided 5x5 convolution, three dense blocks of eight
    layers each (growth rate 8) separated by two transition blocks that use
    ``pooltype=2``, a final BatchNorm + ReLU, and a softmax ``Dense`` layer
    applied to the flattened features of each step.

    Args:
        input: Keras ``Input`` tensor the model starts from.
        nclass: Number of units (classes) of the softmax output layer.

    Returns:
        The uncompiled Keras ``Model`` mapping ``input`` to the softmax
        output.
    """
    _dropout_rate = 0.2
    _weight_decay = 1e-4

    _nb_filter = 64
    # Stem: 64 filters, 5x5 kernel, stride 2.
    x = Conv2D(_nb_filter, (5, 5), strides=(2, 2),
               kernel_initializer='he_normal', padding='same',
               use_bias=False,
               kernel_regularizer=l2(_weight_decay))(input)

    # Dropout is disabled (None) inside the dense blocks and only applied in
    # the transition blocks.
    # 64 + 8 * 8 = 128 filters
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)
    # 128 -> 128 filters
    x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192 filters
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)
    # 192 -> 128 filters
    x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192 filters
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)

    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    x = Activation('relu')(x)

    # Swap the first two non-batch axes so TimeDistributed steps over what
    # was the second axis (presumably image width for a
    # (height, width, channels) input -- confirm against the data pipeline).
    x = Permute((2, 1, 3), name='permute')(x)
    x = TimeDistributed(Flatten(), name='flatten')(x)
    y_pred = Dense(nclass, name='out', activation='softmax')(x)

    basemodel = Model(inputs=input, outputs=y_pred)
    basemodel.summary()

    return basemodel


if __name__ == '__main__':
    inputs = Input(shape=(32, 153, 1), name='the_input')
    cnn = dense_cnn(inputs, 1000)
    cnn.compile(optimizer='adam', loss='categorical_crossentropy',
                metrics=['accuracy'])

    # plot_model does not create missing directories, so make sure the
    # target folder exists before writing the diagram.
    plot_path = './models/densenet.png'
    plot_dir = os.path.dirname(plot_path)
    if not os.path.isdir(plot_dir):
        os.makedirs(plot_dir)
    plot_model(cnn, to_file=plot_path)

    cnn.summary()