from keras.layers import ReLU, BatchNormalization from keras import backend as K from octave_conv import initial_octconv, final_octconv, octconv_block def initial_oct_conv_bn_relu(ip, filters, kernel_size=(3, 3), strides=(1, 1), alpha=0.5, padding='same', dilation=None, bias=False, activation=True): channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 x_high, x_low = initial_octconv(ip, filters, kernel_size, strides, alpha, padding, dilation, bias) relu = ReLU() x_high = BatchNormalization(axis=channel_axis)(x_high) if activation: x_high = relu(x_high) x_low = BatchNormalization(axis=channel_axis)(x_low) if activation: x_low = relu(x_low) return x_high, x_low def final_oct_conv_bn_relu(ip_high, ip_low, filters, kernel_size=(3, 3), strides=(1, 1), padding='same', dilation=None, bias=False, activation=True): channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 x = final_octconv(ip_high, ip_low, filters, kernel_size, strides, padding, dilation, bias) x = BatchNormalization(axis=channel_axis)(x) if activation: x = ReLU()(x) return x def oct_conv_bn_relu(ip_high, ip_low, filters, kernel_size=(3, 3), strides=(1, 1), alpha=0.5, padding='same', dilation=None, bias=False, activation=True): channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 x_high, x_low = octconv_block(ip_high, ip_low, filters, kernel_size, strides, alpha, padding, dilation, bias) relu = ReLU() x_high = BatchNormalization(axis=channel_axis)(x_high) if activation: x_high = relu(x_high) x_low = BatchNormalization(axis=channel_axis)(x_low) if activation: x_low = relu(x_low) return x_high, x_low