from __future__ import absolute_import

from keras.engine import InputSpec
from keras.layers import Wrapper, Merge


class Residual(Wrapper):
    """This wrapper automatically applies a residual to a model.

    For an input `x` and a model `F(x)`, the residual wrapper gives the
    output `y = x + F(x)`. In this configuration, the output of `F(x)` must
    have the same shape as `x`. Other merge modes are supported besides
    summation.

        inputs = Input(shape=(5,))

        # Apply the residual normally
        output1 = Residual(Dense(5), merge_mode='sum')(inputs)

        # Raises an exception due to mismatching shapes
        output2 = Residual(Dense(3), merge_mode='sum')(inputs)

        # Product: `y = x * F(x)`
        output3 = Residual(Dense(5), merge_mode='mul')(inputs)

    For more modes, see: https://keras.io/layers/core/#merge

    Alternatively, a `Merge` layer whose `mode` is a function of the pair
    `[x, F(x)]` can be passed to define a custom merge:

        from keras.layers import Merge

        def diff_merge():
            # The merge function receives x_fx = [x, fx]
            diff = lambda x: x[1] - x[0]
            return Merge(mode=diff, output_shape=lambda x: x[0])

        # Difference: `y = F(x) - x`
        output4 = Residual(Dense(5), merge_mode=diff_merge())(inputs)

    Args:
        layer: The layer to wrap.
        merge_mode: The merge operation: either a `Merge` mode name such as
            `'sum'` or `'mul'`, or a `Merge` layer instance.
    """

    def __init__(self, layer, merge_mode='sum', **kwargs):
        # Normalize a mode name to a Merge layer up front so that
        # `get_config` works even before the layer has been called
        # (previously it crashed calling `.get_config()` on a str).
        self.merge_mode = self._as_merge_layer(merge_mode)
        # Masks are forwarded to the wrapped layer in `call`.
        self.supports_masking = True
        super(Residual, self).__init__(layer, **kwargs)

    @staticmethod
    def _as_merge_layer(merge_mode):
        """Return `merge_mode` as a `Merge` layer, wrapping a mode name."""
        if isinstance(merge_mode, str):
            return Merge(mode=merge_mode)
        return merge_mode

    def build(self, input_shape):
        """Build the wrapped layer after checking its output shape.

        Raises:
            Exception: if the wrapped layer's output shape differs from
                `input_shape`, since the residual merge needs equal shapes.
        """
        output_shape = self.layer.get_output_shape_for(input_shape)
        if output_shape != input_shape:
            raise Exception('Cannot apply residual to layer "{}": '
                            'mismatching input and output shapes: '
                            'input="{}" and output="{}"'
                            .format(self.layer.name, input_shape,
                                    output_shape))
        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True
        self.input_spec = [InputSpec(shape=input_shape)]
        super(Residual, self).build()

    def get_output_shape_for(self, input_shape):
        # `build` enforces that the wrapped layer preserves the shape.
        return input_shape

    def call(self, x, mask=None):
        """Return `merge_mode([x, layer(x)])`."""
        layer_output = self.layer.call(x, mask)
        # Idempotent; only matters if `merge_mode` was reassigned to a mode
        # name after construction.
        self.merge_mode = self._as_merge_layer(self.merge_mode)
        return self.merge_mode([x, layer_output])

    @classmethod
    def from_config(cls, config):
        """Rebuild a Residual, restoring the serialized Merge layer."""
        from keras.utils.layer_utils import layer_from_config
        merge_mode = layer_from_config(config.pop('merge_mode'))
        residual = super(Residual, cls).from_config(config)
        residual.merge_mode = merge_mode
        return residual

    def get_config(self):
        """Serialize the wrapper, storing the merge op as a Merge config."""
        merge_layer = self._as_merge_layer(self.merge_mode)
        config = {'merge_mode': {'class_name': 'Merge',
                                 'config': merge_layer.get_config()}}
        base_config = super(Residual, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))