import tensorflow as tf
import tensorflow.keras.backend as kb

from megnet.activations import softplus2
from megnet.layers.graph.base import GraphNetworkLayer


class InteractionLayer(GraphNetworkLayer):
    """
    The continuous-filter InteractionLayer in SchNet.

    Schütt et al. SchNet: A continuous-filter convolutional neural network
    for modeling quantum interactions

    Args:
        activation (callable): Default: softplus2. The activation function used
            for each sub-neural network. Examples include 'relu', 'softmax',
            'tanh' and 'sigmoid'.
        use_bias (bool): Default: True. Whether to use the bias term in the
            neural network.
        kernel_initializer (str): Default: 'glorot_uniform'. Initialization
            function for the layer kernel weights.
        bias_initializer (str): Default: 'zeros'.
        activity_regularizer (str): Default: None. The regularization function
            for the output.
        kernel_constraint (str): Default: None. Keras constraint for kernel
            values.
        bias_constraint (str): Default: None. Keras constraint for bias values.

    Methods:
        call(inputs, mask=None): the logic of the layer, returns the final graph
        compute_output_shape(input_shape): compute static output shapes,
            returns list of tuple shapes
        build(input_shape): initialize the weights and biases for each function
        phi_e(inputs): update function for bonds, returns the updated bond
            attribute e_p
        rho_e_v(e_p, inputs): aggregate updated bonds e_p to per-atom
            attributes, b_e_p
        phi_v(b_e_p, inputs): update the atom attributes using the result of
            the previous step b_e_p and all the inputs, returns v_p
        rho_e_u(e_p, inputs): aggregate bonds to global attributes
        rho_v_u(v_p, inputs): aggregate atoms to global attributes
        phi_u(b_e_p, b_v_p, inputs): return the global attributes unchanged
        get_config(): part of the Keras interface for serialization
    """

    def __init__(self,
                 activation=softplus2,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super().__init__(activation=activation,
                         use_bias=use_bias,
                         kernel_initializer=kernel_initializer,
                         bias_initializer=bias_initializer,
                         kernel_regularizer=kernel_regularizer,
                         bias_regularizer=bias_regularizer,
                         activity_regularizer=activity_regularizer,
                         kernel_constraint=kernel_constraint,
                         bias_constraint=bias_constraint,
                         **kwargs)

    def build(self, input_shapes):
        vdim = input_shapes[0][2]
        edim = input_shapes[1][2]

        # Filter-generating network: one (edim -> vdim) layer followed by two
        # (vdim -> vdim) layers applied to the expanded bond features.
        with kb.name_scope(self.name):
            with kb.name_scope('phi_e'):
                e_shapes = [[edim, vdim]] + [[vdim, vdim]] * 2
                self.phi_e_weights = [
                    self.add_weight(shape=i,
                                    initializer=self.kernel_initializer,
                                    name='weight_e_%d' % j,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)
                    for j, i in enumerate(e_shapes)]
                if self.use_bias:
                    self.phi_e_biases = [
                        self.add_weight(shape=(i[-1],),
                                        initializer=self.bias_initializer,
                                        name='bias_e_%d' % j,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
                        for j, i in enumerate(e_shapes)]
                else:
                    # Keep per-layer indexing in _mlp valid when biases are off.
                    self.phi_e_biases = [None] * len(e_shapes)

        # Atom-wise dense layers: one before the cfconv and two after it.
        with kb.name_scope(self.name):
            with kb.name_scope('phi_v'):
                v_shapes = [[vdim, vdim]] * 3
                self.phi_v_weights = [
                    self.add_weight(shape=i,
                                    initializer=self.kernel_initializer,
                                    name='weight_v_%d' % j,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)
                    for j, i in enumerate(v_shapes)]
                if self.use_bias:
                    self.phi_v_biases = [
                        self.add_weight(shape=(i[-1],),
                                        initializer=self.bias_initializer,
                                        name='bias_v_%d' % j,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
                        for j, i in enumerate(v_shapes)]
                else:
                    self.phi_v_biases = [None] * len(v_shapes)
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape

    def phi_e(self, inputs):
        # The bond attributes are passed through unchanged; the filter network
        # is applied inside rho_e_v instead.
        nodes, edges, u, index1, index2, gnode, gbond = inputs
        return edges

    def rho_e_v(self, e_p, inputs):
        """
        Reduce edge attributes to node attributes via the continuous-filter
        convolution, eqn 5 in the paper.

        Args:
            e_p: updated bond attributes
            inputs: the whole input list

        Returns:
            summed tensor
        """
        nodes, edges, u, index1, index2, gnode, gbond = inputs
        # Atom-wise dense layer applied before the convolution.
        atomwise1 = self._mlp(nodes, self.phi_v_weights[0], self.phi_v_biases[0])
        # Filter-generating network W(r): three dense layers on the bond
        # features, with activation on the first two only.
        cfconv1 = self.activation(
            self._mlp(edges, self.phi_e_weights[0], self.phi_e_biases[0]))
        cfconv2 = self.activation(
            self._mlp(cfconv1, self.phi_e_weights[1], self.phi_e_biases[1]))
        cfconv_out = self._mlp(cfconv2, self.phi_e_weights[2], self.phi_e_biases[2])
        index1 = tf.reshape(index1, (-1,))
        index2 = tf.reshape(index2, (-1,))
        # Gather the neighbor-atom features for each bond. The tensors carry a
        # leading batch axis of size 1, hence axis=1.
        fr = tf.gather(atomwise1, index2, axis=1)
        # x_i + sum_j x_j * W(r_ij). tf.math.segment_sum requires index1 to be
        # sorted, which is the megnet bond-ordering convention.
        after_cfconv = atomwise1 + tf.transpose(
            a=tf.math.segment_sum(
                tf.transpose(a=fr * cfconv_out, perm=[1, 0, 2]), index1),
            perm=[1, 0, 2])
        # Two more atom-wise layers after the convolution.
        atomwise2 = self.activation(
            self._mlp(after_cfconv, self.phi_v_weights[1], self.phi_v_biases[1]))
        atomwise3 = self._mlp(atomwise2, self.phi_v_weights[2], self.phi_v_biases[2])
        return atomwise3

    def phi_v(self, b_ei_p, inputs):
        # Residual update of the atom attributes.
        nodes, edges, u, index1, index2, gnode, gbond = inputs
        return nodes + b_ei_p

    def rho_e_u(self, e_p, inputs):
        # SchNet interaction blocks do not update the global state.
        return 0

    def rho_v_u(self, v_p, inputs):
        return 0

    def phi_u(self, b_e_p, b_v_p, inputs):
        # Return the global state unchanged.
        return inputs[2]

    def _mlp(self, input_, weights, bias):
        output = kb.dot(input_, weights)
        if bias is not None:
            output = output + bias
        return output

    def get_config(self):
        base_config = super().get_config()
        return base_config
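

# A minimal smoke-test sketch of the layer on a toy graph. This assumes
# megnet's GraphNetworkLayer input convention of
# [nodes, edges, state, index1, index2, gnode, gbond], each with a leading
# batch axis of 1, and index1 sorted so that tf.math.segment_sum is valid.
# All sizes and index values below are illustrative, not part of the layer.
if __name__ == '__main__':
    n_atoms, n_bonds, vdim, edim = 4, 6, 16, 20
    nodes = tf.random.normal((1, n_atoms, vdim))   # per-atom features
    edges = tf.random.normal((1, n_bonds, edim))   # expanded bond features
    state = tf.zeros((1, 1, 1))                    # global state, passed through
    # Bond k connects atom index1[k] -> index2[k]; index1 must be sorted.
    index1 = tf.constant([[0, 0, 1, 2, 2, 3]])
    index2 = tf.constant([[1, 2, 0, 0, 3, 2]])
    gnode = tf.zeros((1, n_atoms), dtype=tf.int32)  # graph id of each atom
    gbond = tf.zeros((1, n_bonds), dtype=tf.int32)  # graph id of each bond

    layer = InteractionLayer()
    out_nodes, out_edges, out_state = layer(
        [nodes, edges, state, index1, index2, gnode, gbond])
    print(out_nodes.shape)  # expected: (1, 4, 16), same shape as the input nodes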