from keras.models import Model
from keras.layers import Input, Add, Activation, Dropout, GlobalAveragePooling2D, Dense
from keras.layers import Conv2D, BatchNormalization
from keras import backend as K


def initial_conv(input):
    # 3x3 stem convolution, followed by BN + ReLU.
    x = Conv2D(16, (3, 3), padding='same', kernel_initializer='he_normal',
               use_bias=False)(input)

    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5,
                           gamma_initializer='uniform')(x)
    x = Activation('relu')(x)
    return x


def expand_conv(init, base, k, strides=(1, 1)):
    # First block of each stage: widens the channel count to base * k and
    # (optionally) downsamples; a strided 1x1 projection matches the skip path.
    x = Conv2D(base * k, (3, 3), padding='same', strides=strides,
               kernel_initializer='he_normal', use_bias=False)(init)

    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5,
                           gamma_initializer='uniform')(x)
    x = Activation('relu')(x)

    x = Conv2D(base * k, (3, 3), padding='same',
               kernel_initializer='he_normal', use_bias=False)(x)

    skip = Conv2D(base * k, (1, 1), padding='same', strides=strides,
                  kernel_initializer='he_normal', use_bias=False)(init)

    m = Add()([x, skip])
    return m


def conv_block(input, base, k=1, dropout=0.0):
    # Pre-activation residual block: BN-ReLU-Conv twice, with optional dropout
    # between the convolutions, plus an identity skip connection.
    init = input

    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5,
                           gamma_initializer='uniform')(input)
    x = Activation('relu')(x)
    x = Conv2D(base * k, (3, 3), padding='same',
               kernel_initializer='he_normal', use_bias=False)(x)

    if dropout > 0.0:
        x = Dropout(dropout)(x)

    x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5,
                           gamma_initializer='uniform')(x)
    x = Activation('relu')(x)
    x = Conv2D(base * k, (3, 3), padding='same',
               kernel_initializer='he_normal', use_bias=False)(x)

    m = Add()([init, x])
    return m


def create_wide_residual_network(input_dim, nb_classes=100, N=2, k=1, dropout=0.0,
                                 final_activation='softmax', verbose=1, name=None):
    """
    Creates a Wide Residual Network with the specified parameters.

    :param input_dim: Input shape tuple, e.g. (32, 32, 3)
    :param nb_classes: Number of output classes
    :param N: Number of residual blocks per stage. Compute N = (n - 4) / 6.
              Example : For a depth of 16, n = 16, N = (16 - 4) / 6 = 2
              Example2: For a depth of 28, n = 28, N = (28 - 4) / 6 = 4
              Example3: For a depth of 40, n = 40, N = (40 - 4) / 6 = 6
    :param k: Width multiplier of the network
    :param dropout: Adds dropout if the value is greater than 0.0
    :param final_activation: Activation function of the last layer
    :param verbose: If non-zero, prints a line describing the created WRN
    :param name: Optional name for the Keras model
    :return: A Keras Model instance
    """
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    ip = Input(shape=input_dim)

    x = initial_conv(ip)
    nb_conv = 4

    # Three stages with 16, 32 and 64 base filters; the second and third
    # stages downsample via a strided convolution in their expand block.
    for block_index, base in enumerate([16, 32, 64]):
        x = expand_conv(x, base, k, strides=(2, 2) if block_index > 0 else (1, 1))
        nb_conv += 2

        for i in range(N - 1):
            x = conv_block(x, base, k, dropout)
            nb_conv += 2

    x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5,
                           gamma_initializer='uniform')(x)
    x = Activation('relu')(x)

    x = GlobalAveragePooling2D()(x)
    x = Dense(nb_classes, activation=final_activation,
              name='prob' if final_activation == 'softmax' else 'embedding')(x)

    model = Model(ip, x, name=name)

    if verbose:
        print("Wide Residual Network-%d-%d created." % (nb_conv, k))
    return model

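
# Illustrative helper (an assumption, not part of the original script): maps a
# target network depth n to the N argument above via the relation N = (n - 4) / 6
# given in the docstring of create_wide_residual_network.
def depth_to_N(depth):
    assert (depth - 4) % 6 == 0, "depth must satisfy depth = 6 * N + 4"
    return (depth - 4) // 6
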
if __name__ == "__main__":
    from keras.utils import plot_model

    init = (32, 32, 3)

    # N=2, k=2 yields a WRN-16-2 (depth 16, width 2), matching the plot filename.
    wrn_16_2 = create_wide_residual_network(init, nb_classes=10, N=2, k=2, dropout=0.0)

    wrn_16_2.summary()

    plot_model(wrn_16_2, "WRN-16-2.png", show_shapes=True, show_layer_names=True)
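
    # Optional training sketch (an assumption, not in the original script):
    # compiles and briefly fits the model on CIFAR-10 as a sanity check. The
    # SGD settings (lr=0.1, momentum=0.9, Nesterov) follow common WRN practice.
    from keras.datasets import cifar10
    from keras.optimizers import SGD
    from keras.utils import to_categorical

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)

    wrn_16_2.compile(optimizer=SGD(lr=0.1, momentum=0.9, nesterov=True),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
    wrn_16_2.fit(x_train, y_train, batch_size=128, epochs=1,
                 validation_data=(x_test, y_test))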