import tensorflow as tf
from tensorflow.keras import backend as K, initializers, regularizers, constraints
from tensorflow.keras.layers import Layer, Dense

from spektral.layers import ops


class GlobalPooling(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.pooling_op = None
        self.batch_pooling_op = None

    def build(self, input_shape):
        if isinstance(input_shape, list) and len(input_shape) == 2:
            self.data_mode = 'disjoint'
        else:
            if len(input_shape) == 2:
                self.data_mode = 'single'
            else:
                self.data_mode = 'batch'
        super().build(input_shape)

    def call(self, inputs):
        if self.data_mode == 'disjoint':
            X = inputs[0]
            I = inputs[1]
            if K.ndim(I) == 2:
                I = I[:, 0]
        else:
            X = inputs
        if self.data_mode == 'disjoint':
            return self.pooling_op(X, I)
        else:
            return self.batch_pooling_op(
                X, axis=-2, keepdims=(self.data_mode == 'single'))

    def compute_output_shape(self, input_shape):
        if self.data_mode == 'single':
            return (1,) + input_shape[-1:]
        elif self.data_mode == 'batch':
            return input_shape[:-2] + input_shape[-1:]
        else:
            # Input shape is a list of shapes for X and I
            return input_shape[0]

    def get_config(self):
        return super().get_config()


class GlobalSumPool(GlobalPooling):
    """
    A global sum pooling layer. Pools a graph by computing the sum of its
    node features.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], N, F)`;
    - Graph IDs of shape `(N, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, F)` (if single mode, shape will
    be `(1, F)`).

    **Arguments**

    None.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.pooling_op = tf.math.segment_sum
        self.batch_pooling_op = tf.reduce_sum


class GlobalAvgPool(GlobalPooling):
    """
    An average pooling layer. Pools a graph by computing the average of its
    node features.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], N, F)`;
    - Graph IDs of shape `(N, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, F)` (if single mode, shape will
    be `(1, F)`).

    **Arguments**

    None.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.pooling_op = tf.math.segment_mean
        self.batch_pooling_op = tf.reduce_mean


class GlobalMaxPool(GlobalPooling):
    """
    A max pooling layer. Pools a graph by computing the maximum of its node
    features.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], N, F)`;
    - Graph IDs of shape `(N, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, F)` (if single mode, shape will
    be `(1, F)`).

    **Arguments**

    None.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.pooling_op = tf.math.segment_max
        self.batch_pooling_op = tf.reduce_max
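
# Minimal usage sketch for the global pooling layers above; the shapes and
# index values below are illustrative assumptions, not requirements.
def _example_global_sum_pool():
    # Disjoint mode: X stacks the node features of all graphs and I maps each
    # node to the index of the graph it belongs to.
    X = tf.random.normal((5, 3))             # 5 nodes in total, 3 features each
    I = tf.constant([0, 0, 0, 1, 1])         # nodes 0-2 -> graph 0, nodes 3-4 -> graph 1
    out_disjoint = GlobalSumPool()([X, I])   # shape (2, 3): one row per graph

    # Batch mode: a dense (batch, N, F) tensor is reduced over the node axis.
    X_batch = tf.random.normal((8, 10, 3))
    out_batch = GlobalSumPool()(X_batch)     # shape (8, 3)
    return out_disjoint, out_batch
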
class GlobalAttentionPool(GlobalPooling):
    r"""
    A gated attention global pooling layer as presented by
    [Li et al. (2017)](https://arxiv.org/abs/1511.05493).

    This layer computes:
    $$
        \X' = \sum\limits_{i=1}^{N} (\sigma(\X \W_1 + \b_1) \odot (\X \W_2 + \b_2))_i
    $$
    where \(\sigma\) is the sigmoid activation function.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], N, F)`;
    - Graph IDs of shape `(N, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, channels)` (if single mode,
    shape will be `(1, channels)`).

    **Arguments**

    - `channels`: integer, number of output channels;
    - `kernel_initializer`: initializer for the kernel matrices;
    - `bias_initializer`: initializer for the bias vectors;
    - `kernel_regularizer`: regularization applied to the kernel matrices;
    - `bias_regularizer`: regularization applied to the bias vectors;
    - `kernel_constraint`: constraint applied to the kernel matrices;
    - `bias_constraint`: constraint applied to the bias vectors.
    """
    def __init__(self,
                 channels,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    def build(self, input_shape):
        super().build(input_shape)
        layer_kwargs = dict(
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint
        )
        self.features_layer = Dense(self.channels,
                                    name='features_layer',
                                    **layer_kwargs)
        self.attention_layer = Dense(self.channels,
                                     activation='sigmoid',
                                     name='attn_layer',
                                     **layer_kwargs)
        self.built = True

    def call(self, inputs):
        if self.data_mode == 'disjoint':
            X, I = inputs
            if K.ndim(I) == 2:
                I = I[:, 0]
        else:
            X = inputs
        inputs_linear = self.features_layer(X)
        attn = self.attention_layer(X)
        masked_inputs = inputs_linear * attn
        if self.data_mode in {'single', 'batch'}:
            output = K.sum(masked_inputs, axis=-2,
                           keepdims=self.data_mode == 'single')
        else:
            output = tf.math.segment_sum(masked_inputs, I)

        return output

    def compute_output_shape(self, input_shape):
        if self.data_mode == 'single':
            return (1,) + (self.channels,)
        elif self.data_mode == 'batch':
            return input_shape[:-2] + (self.channels,)
        else:
            output_shape = input_shape[0]
            output_shape = output_shape[:-1] + (self.channels,)
            return output_shape

    def get_config(self):
        config = {
            'channels': self.channels,
            'kernel_initializer': self.kernel_initializer,
            'bias_initializer': self.bias_initializer,
            'kernel_regularizer': self.kernel_regularizer,
            'bias_regularizer': self.bias_regularizer,
            'kernel_constraint': self.kernel_constraint,
            'bias_constraint': self.bias_constraint,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
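
# Minimal usage sketch for GlobalAttentionPool in batch mode; the shapes and
# the number of channels below are illustrative assumptions.
def _example_global_attention_pool():
    X_batch = tf.random.normal((8, 10, 3))           # 8 graphs, 10 nodes, 3 features
    out = GlobalAttentionPool(channels=16)(X_batch)  # shape (8, 16)
    return out
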
class GlobalAttnSumPool(GlobalPooling):
    r"""
    A node-attention global pooling layer. Pools a graph by learning attention
    coefficients to sum node features.

    This layer computes:
    $$
        \alpha = \textrm{softmax}( \X \a); \\
        \X' = \sum\limits_{i=1}^{N} \alpha_i \cdot \X_i
    $$
    where \(\a \in \mathbb{R}^F\) is a trainable vector. Note that the
    softmax is applied across nodes, and not across features.

    **Mode**: single, disjoint, mixed, batch.

    **Input**

    - Node features of shape `([batch], N, F)`;
    - Graph IDs of shape `(N, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, F)` (if single mode, shape will
    be `(1, F)`).

    **Arguments**

    - `attn_kernel_initializer`: initializer for the attention weights;
    - `attn_kernel_regularizer`: regularization applied to the attention
    kernel matrix;
    - `attn_kernel_constraint`: constraint applied to the attention kernel
    matrix.
    """
    def __init__(self,
                 attn_kernel_initializer='glorot_uniform',
                 attn_kernel_regularizer=None,
                 attn_kernel_constraint=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)
        self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)
        self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)

    def build(self, input_shape):
        assert len(input_shape) >= 2
        if isinstance(input_shape, list) and len(input_shape) == 2:
            self.data_mode = 'disjoint'
            F = input_shape[0][-1]
        else:
            if len(input_shape) == 2:
                self.data_mode = 'single'
            else:
                self.data_mode = 'batch'
            F = input_shape[-1]
        # Attention kernel
        self.attn_kernel = self.add_weight(shape=(F, 1),
                                           initializer=self.attn_kernel_initializer,
                                           regularizer=self.attn_kernel_regularizer,
                                           constraint=self.attn_kernel_constraint,
                                           name='attn_kernel')
        self.built = True

    def call(self, inputs):
        if self.data_mode == 'disjoint':
            X, I = inputs
            if K.ndim(I) == 2:
                I = I[:, 0]
        else:
            X = inputs
        attn_coeff = K.dot(X, self.attn_kernel)
        attn_coeff = K.squeeze(attn_coeff, -1)
        attn_coeff = K.softmax(attn_coeff)
        if self.data_mode == 'single':
            output = K.dot(attn_coeff[None, ...], X)
        elif self.data_mode == 'batch':
            output = K.batch_dot(attn_coeff, X)
        else:
            output = attn_coeff[:, None] * X
            output = tf.math.segment_sum(output, I)

        return output

    def get_config(self):
        config = {
            'attn_kernel_initializer': self.attn_kernel_initializer,
            'attn_kernel_regularizer': self.attn_kernel_regularizer,
            'attn_kernel_constraint': self.attn_kernel_constraint,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
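
# Minimal usage sketch for GlobalAttnSumPool in single mode (one graph); the
# shapes below are illustrative assumptions.
def _example_global_attn_sum_pool():
    X = tf.random.normal((5, 3))   # a single graph with 5 nodes, 3 features each
    out = GlobalAttnSumPool()(X)   # shape (1, 3)
    return out
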
class SortPool(Layer):
    r"""
    A SortPool layer as described by
    [Zhang et al.](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf).
    This layer takes a graph signal \(\mathbf{X}\) and returns the topmost k
    rows according to the last column. If \(\mathbf{X}\) has fewer than k
    rows, the result is zero-padded to k.

    **Mode**: single, disjoint, batch.

    **Input**

    - Node features of shape `([batch], N, F)`;
    - Graph IDs of shape `(N, )` (only in disjoint mode);

    **Output**

    - Pooled node features of shape `(batch, k, F)` (if single mode, shape
    will be `(k, F)`).

    **Arguments**

    - `k`: integer, number of nodes to keep.
    """
    def __init__(self, k):
        super(SortPool, self).__init__()
        k = int(k)
        if k <= 0:
            raise ValueError("K must be a positive integer")
        self.k = k

    def build(self, input_shape):
        if isinstance(input_shape, list) and len(input_shape) == 2:
            self.data_mode = 'disjoint'
            self.F = input_shape[0][-1]
        else:
            if len(input_shape) == 2:
                self.data_mode = 'single'
            else:
                self.data_mode = 'batch'
            self.F = input_shape[-1]

    def call(self, inputs):
        if self.data_mode == 'disjoint':
            X, I = inputs
            X = ops.disjoint_signal_to_batch(X, I)
        else:
            X = inputs
            if self.data_mode == 'single':
                X = tf.expand_dims(X, 0)

        N = tf.shape(X)[-2]
        sort_perm = tf.argsort(X[..., -1], direction='DESCENDING')
        X_sorted = tf.gather(X, sort_perm, axis=-2, batch_dims=1)

        def truncate():
            _X_out = X_sorted[..., : self.k, :]
            return _X_out

        def pad():
            padding = [[0, 0], [0, self.k - N], [0, 0]]
            _X_out = tf.pad(X_sorted, padding)
            return _X_out

        X_out = tf.cond(tf.less_equal(self.k, N), truncate, pad)

        if self.data_mode == 'single':
            X_out = tf.squeeze(X_out, [0])
            X_out.set_shape((self.k, self.F))
        elif self.data_mode == 'batch' or self.data_mode == 'disjoint':
            X_out.set_shape((None, self.k, self.F))

        return X_out

    def get_config(self):
        config = {
            'k': self.k,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if self.data_mode == 'single':
            return self.k, self.F
        elif self.data_mode == 'batch':
            return input_shape[0], self.k, self.F
        else:
            # Disjoint mode: the number of output graphs is not known statically
            return None, self.k, self.F
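
# Minimal usage sketch for SortPool in batch mode; the shapes and the value of
# k below are illustrative assumptions.
def _example_sort_pool():
    X_batch = tf.random.normal((8, 10, 4))  # 8 graphs, 10 nodes, 4 features
    out = SortPool(k=3)(X_batch)            # shape (8, 3, 4): top-3 nodes per graph
    return out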