import tensorflow as tf
from keras import backend as K, initializers, regularizers, constraints
from keras.backend import image_data_format
from keras.backend.tensorflow_backend import _preprocess_conv2d_input, _preprocess_padding
from keras.engine.topology import InputSpec
from keras.layers import Conv2D
from keras.legacy.interfaces import conv2d_args_preprocessor, generate_legacy_interface
from keras.utils import conv_utils


# This code is mostly taken from the Keras Separable Convolution layer source
# code and changed according to needs.


def depthwise_conv2d_args_preprocessor(args, kwargs):
    """Translate legacy constructor arguments for `DepthwiseConv2D`.

    The legacy `init` keyword is renamed to `depthwise_initializer`; the
    remaining arguments are then handed to the stock Conv2D legacy
    preprocessor.

    # Arguments
        args: positional arguments given to the layer constructor.
        kwargs: keyword arguments given to the layer constructor
            (modified in place when `init` is present).

    # Returns
        A tuple `(args, kwargs, converted)` where `converted` lists the
        `(old_name, new_name)` pairs that were rewritten.
    """
    converted = []
    if 'init' in kwargs:
        init = kwargs.pop('init')
        kwargs['depthwise_initializer'] = init
        converted.append(('init', 'depthwise_initializer'))
    args, kwargs, _converted = conv2d_args_preprocessor(args, kwargs)
    return args, kwargs, converted + _converted


# Decorator that lets `DepthwiseConv2D.__init__` accept legacy argument names.
legacy_depthwise_conv2d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=depthwise_conv2d_args_preprocessor)


class DepthwiseConv2D(Conv2D):
    """Depthwise 2D convolution built on `tf.nn.depthwise_conv2d`.

    Every input channel is convolved with its own `depth_multiplier`
    kernels, so the layer emits `input_channels * depth_multiplier` output
    channels.

    # Arguments
        filters: forwarded to `Conv2D` and kept in the config so existing
            callers keep working. The number of output channels (and the
            bias length) is derived from the input channels and
            `depth_multiplier`, not from this value.
        kernel_size: height and width of the convolution window.
        strides: strides of the convolution along height and width.
        padding: `'valid'` or `'same'`.
        data_format: `'channels_last'` or `'channels_first'`.
        depth_multiplier: number of depthwise kernels per input channel.
        activation: activation applied to the output, if any.
        use_bias: whether a bias vector is added.
        depthwise_initializer: initializer for the depthwise kernel.
        bias_initializer: initializer for the bias vector.
        depthwise_regularizer: regularizer for the depthwise kernel.
        bias_regularizer: regularizer for the bias vector.
        activity_regularizer: regularizer applied to the layer output.
        depthwise_constraint: constraint for the depthwise kernel.
        bias_constraint: constraint for the bias vector.
        **kwargs: forwarded to `Conv2D` (e.g. `dilation_rate`, `name`).
    """

    @legacy_depthwise_conv2d_support
    def __init__(self, filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 depth_multiplier=1,
                 activation=None,
                 use_bias=True,
                 depthwise_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 depthwise_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 depthwise_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(DepthwiseConv2D, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            # Previously dropped, which made the argument a silent no-op.
            bias_initializer=bias_initializer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            bias_constraint=bias_constraint,
            **kwargs)
        self.depth_multiplier = depth_multiplier
        self.depthwise_initializer = initializers.get(depthwise_initializer)
        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
        self.depthwise_constraint = constraints.get(depthwise_constraint)

    def build(self, input_shape):
        """Create the depthwise kernel and, if enabled, the bias.

        # Arguments
            input_shape: shape tuple of the input; rank 4 is expected and
                the channel dimension must be known.

        # Raises
            ValueError: if the input has fewer than 4 dimensions or its
                channel dimension is `None`.
        """
        if len(input_shape) < 4:
            raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
                             'Received input shape: ' + str(input_shape))
        channel_axis = 1 if self.data_format == 'channels_first' else 3
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`DepthwiseConv2D` '
                             'should be defined. Found `None`.')
        input_dim = int(input_shape[channel_axis])
        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  input_dim,
                                  self.depth_multiplier)

        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)

        if self.use_bias:
            # One bias term per output channel; a depthwise convolution
            # yields input_dim * depth_multiplier channels.
            self.bias = self.add_weight(
                shape=(input_dim * self.depth_multiplier,),
                initializer=self.bias_initializer,
                name='bias',
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs):
        """Apply the depthwise convolution, bias and activation.

        # Arguments
            inputs: 4D input tensor laid out according to `data_format`.

        # Returns
            The output tensor, in the same layout as `inputs`.

        # Raises
            ValueError: if the data format is not recognised.
        """
        data_format = self.data_format
        if data_format is None:
            data_format = image_data_format()
        if data_format not in {'channels_first', 'channels_last'}:
            raise ValueError('Unknown data_format ' + str(data_format))

        # The convolution below is run on a channels-last (NHWC) tensor, so
        # channels-first inputs are transposed explicitly here and the result
        # is transposed back afterwards. This is done with tf.transpose
        # rather than the private backend helper so it does not depend on
        # that helper's return type.
        x = inputs
        if data_format == 'channels_first':
            x = tf.transpose(x, (0, 2, 3, 1))
        padding = _preprocess_padding(self.padding)
        strides = (1,) + self.strides + (1,)

        outputs = tf.nn.depthwise_conv2d(x, self.depthwise_kernel,
                                         strides=strides,
                                         padding=padding,
                                         rate=self.dilation_rate)
        if data_format == 'channels_first':
            outputs = tf.transpose(outputs, (0, 3, 1, 2))

        # Compare against None: the bias is a backend variable and must not
        # be evaluated for truthiness.
        if self.bias is not None:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        """Return the output shape for a given input shape.

        The channel dimension is the input channel count multiplied by
        `depth_multiplier` (or `None` when the input channel count is
        unknown).
        """
        if self.data_format == 'channels_first':
            channels = input_shape[1]
            rows = input_shape[2]
            cols = input_shape[3]
        else:
            rows = input_shape[1]
            cols = input_shape[2]
            channels = input_shape[3]

        rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
                                             self.padding,
                                             self.strides[0],
                                             dilation=self.dilation_rate[0])
        cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
                                             self.padding,
                                             self.strides[1],
                                             dilation=self.dilation_rate[1])
        if channels is None:
            out_channels = None
        else:
            out_channels = channels * self.depth_multiplier

        if self.data_format == 'channels_first':
            return (input_shape[0], out_channels, rows, cols)
        return (input_shape[0], rows, cols, out_channels)

    def get_config(self):
        """Return the layer config, with kernel_* entries replaced by depthwise_* ones."""
        config = super(DepthwiseConv2D, self).get_config()
        # The inherited kernel_* settings are not constructor arguments of
        # this layer; they are replaced by the depthwise_* ones below.
        config.pop('kernel_initializer')
        config.pop('kernel_regularizer')
        config.pop('kernel_constraint')
        config['depth_multiplier'] = self.depth_multiplier
        config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
        config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
        config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
        return config


# Long-form alias for the layer.
DepthwiseConvolution2D = DepthwiseConv2D