import tensorflow as tf
from keras import backend as K
from keras.layers import InputSpec, Layer

# -------------------------------------------------------------------------------------
# Attention Layer from Self-Attention Generative Adversarial Networks
# Paper: https://arxiv.org/abs/1805.08318
# Author of the layer: Hao Chen
# Source: https://stackoverflow.com/questions/50819931/self-attention-gan-in-keras
# -------------------------------------------------------------------------------------

class Attention(Layer):
    """Self-attention block for 4D (NHWC) feature maps.

    f and g project the input down to ch // 8 channels (keys and queries);
    h keeps the full ch channels (values). The output is gamma * o + x,
    where gamma is a learned scalar initialized to zero, so the layer
    starts out as an identity and learns how much to attend during training.
    """

    def __init__(self, ch, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.channels = ch
        self.filters_f_g = self.channels // 8
        self.filters_h = self.channels

    def build(self, input_shape):
        kernel_shape_f_g = (1, 1) + (self.channels, self.filters_f_g)
        kernel_shape_h = (1, 1) + (self.channels, self.filters_h)
        # Create the trainable weights for this layer:
        self.gamma = self.add_weight(name='gamma', shape=[1], initializer='zeros', trainable=True)
        self.kernel_f = self.add_weight(shape=kernel_shape_f_g, initializer='glorot_uniform', name='kernel_f')
        self.kernel_g = self.add_weight(shape=kernel_shape_f_g, initializer='glorot_uniform', name='kernel_g')
        self.kernel_h = self.add_weight(shape=kernel_shape_h, initializer='glorot_uniform', name='kernel_h')
        self.bias_f = self.add_weight(shape=(self.filters_f_g,), initializer='zeros', name='bias_f')
        self.bias_g = self.add_weight(shape=(self.filters_f_g,), initializer='zeros', name='bias_g')
        self.bias_h = self.add_weight(shape=(self.filters_h,), initializer='zeros', name='bias_h')
        super(Attention, self).build(input_shape)
        # Set input spec: require 4D inputs with a fixed channel count.
        self.input_spec = InputSpec(ndim=4, axes={3: input_shape[-1]})
        self.built = True

    def call(self, x):
        def hw_flatten(t):
            # Collapse the spatial dims: [bs, h, w, c] -> [bs, h*w, c]
            return K.reshape(t, shape=[K.shape(t)[0], K.shape(t)[1] * K.shape(t)[2], K.shape(t)[-1]])

        f = K.conv2d(x, kernel=self.kernel_f, strides=(1, 1), padding='same')  # [bs, h, w, c']
        f = K.bias_add(f, self.bias_f)
        g = K.conv2d(x, kernel=self.kernel_g, strides=(1, 1), padding='same')  # [bs, h, w, c']
        g = K.bias_add(g, self.bias_g)
        h = K.conv2d(x, kernel=self.kernel_h, strides=(1, 1), padding='same')  # [bs, h, w, c]
        h = K.bias_add(h, self.bias_h)

        s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N], N = h*w
        beta = K.softmax(s, axis=-1)  # attention map
        o = K.batch_dot(beta, hw_flatten(h))  # [bs, N, c]
        o = K.reshape(o, shape=K.shape(x))  # [bs, h, w, c]
        x = self.gamma * o + x  # residual connection scaled by learned gamma
        return x

    def compute_output_shape(self, input_shape):
        return input_shape
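
# -------------------------------------------------------------------------------------
# Usage sketch (not part of the original snippet): wiring the layer into a small
# functional-API model. The feature-map size (32x32x64) and the surrounding Conv2D
# layers are illustrative assumptions; only Attention itself comes from the source.
# Note that `ch` must match the number of channels of the incoming tensor, since
# the 1x1 kernels are built against that channel count.
# -------------------------------------------------------------------------------------
if __name__ == '__main__':
    from keras.layers import Conv2D, Input
    from keras.models import Model

    inp = Input(shape=(32, 32, 64))                                  # assumed input size
    x = Conv2D(64, (3, 3), padding='same', activation='relu')(inp)   # assumed backbone layer
    x = Attention(ch=64)(x)                                          # ch == input channels
    out = Conv2D(3, (1, 1), padding='same', activation='tanh')(x)    # assumed output head
    model = Model(inp, out)
    model.summary()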