import tensorflow as tf from graphsaint.tensorflow_version.inits import glorot,zeros,trained,ones,xavier,uniform from graphsaint.globals import * F_ACT = {'I': lambda x:x, 'relu': tf.nn.relu, 'leaky_relu': tf.nn.leaky_relu} # global unique layer ID dictionary for layer name assignment _LAYER_UIDS = {} def get_layer_uid(layer_name=''): """Helper function, assigns unique layer IDs.""" if layer_name not in _LAYER_UIDS: _LAYER_UIDS[layer_name] = 1 return 1 else: _LAYER_UIDS[layer_name] += 1 return _LAYER_UIDS[layer_name] class Layer: """Base layer class. Defines basic API for all layer objects. Implementation inspired by keras (http://keras.io). # Properties name: String, defines the variable scope of the layer. logging: Boolean, switches Tensorflow histogram logging on/off # Methods _call(inputs): Defines computation graph of layer (i.e. takes input, returns output) __call__(inputs): Wrapper for _call() _log_vars(): Log all variables """ def __init__(self, **kwargs): allowed_kwargs = {'name', 'logging', 'mulhead'} for kwarg in kwargs.keys(): assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg name = kwargs.get('name') if not name: layer = self.__class__.__name__.lower() name = layer + '_' + str(get_layer_uid(layer)) self.name = name self.vars = {} logging = kwargs.get('logging', False) self.logging = logging def _call(self, inputs): return inputs def __call__(self, inputs): with tf.name_scope(self.name): if self.logging: if type(inputs)==type([]) or type(inputs)==type((1,2)): _ip = inputs[0] else: _ip = inputs tf.summary.histogram(self.name + '/inputs', _ip) outputs = self._call(inputs) if self.logging: tf.summary.histogram(self.name + '/outputs', outputs) return outputs def _log_vars(self): for var in self.vars: tf.summary.histogram(self.name + '/vars/' + var, self.vars[var]) class JumpingKnowledge(Layer): def __init__(self, arch_gcn, dim_input_jk, mode=None, **kwargs): """ """ super(JumpingKnowledge,self).__init__(**kwargs) self.mode = mode if not mode: return self.act = F_ACT[arch_gcn['act']] self.bias = arch_gcn['bias'] self.dim_in = dim_input_jk self.dim_out = arch_gcn['dim'] with tf.variable_scope(self.name + '_vars'): self.vars['weights'] = glorot([self.dim_in,self.dim_out],name='weights') self.vars['bias'] = zeros([self.dim_out],name='bias') if self.bias == 'norm': self.vars['offset'] = zeros([1,self.dim_out],name='offset') self.vars['scale'] = ones([1,self.dim_out],name='scale') def _call(self, inputs): feats_l,idx_conv = inputs if not self.mode: return feats_l[-1] elif self.mode == 'concat': feats_sel = [f for i,f in enumerate(feats_l) if i in idx_conv] feats_aggr = tf.concat(feats_sel, axis=1) elif self.mode == 'max_pool': feats_sel = [f for i,f in enumerate(feats_l) if i in idx_conv] feats_stack = tf.stack(feats_sel) feats_aggr = tf.reduce_max(feats_stack,axis=0) else: raise NotImplementedError vw = tf.matmul(feats_aggr,self.vars['weights']) vw += self.vars['bias'] vw = self.act(vw) if self.bias == 'norm': mean,variance = tf.nn.moments(vw,axes=[1],keep_dims=True) vw = tf.nn.batch_normalization(vw,mean,variance,self.vars['offset'],self.vars['scale'],1e-9) return vw class HighOrderAggregator(Layer): """ If order == 1, then this layer is the normal GCN layer. If order == 0, this layer is equivalent to a dense layer (only self-to-self propagation). If order > 1, then this layer is a high-order layer propagating multi-hop information. """ def __init__(self, dim_in, dim_out, dropout=0., act='relu', \ order=1, aggr='mean', is_train=True, bias='norm', **kwargs): super(HighOrderAggregator,self).__init__(**kwargs) self.dropout = dropout self.bias = bias self.act = F_ACT[act] self.order = order self.aggr = aggr self.is_train = is_train if dim_out > 0: with tf.variable_scope(self.name + '_vars'): for o in range(self.order+1): _k = 'order{}_weights'.format(o) self.vars[_k] = glorot([dim_in,dim_out],name=_k) for o in range(self.order+1): _k = 'order{}_bias'.format(o) self.vars[_k] = zeros([dim_out],name=_k) if self.bias == 'norm': for o in range(self.order+1): _k1 = 'order{}_offset'.format(o) _k2 = 'order{}_scale'.format(o) self.vars[_k1] = zeros([1,dim_out],name=_k1) self.vars[_k2] = ones([1,dim_out],name=_k2) print('>> layer {}, dim: [{},{}]'.format(self.name, dim_in, dim_out)) if self.logging: self._log_vars() self.dim_in = dim_in self.dim_out = dim_out def _F_nonlinear(self,vecs,order): vw = tf.matmul(vecs,self.vars['order{}_weights'.format(order)]) vw += self.vars['order{}_bias'.format(order)] vw = self.act(vw) if self.bias == 'norm': # batch norm realized by tf.nn.batch_norm (consistent with SGCN implementation) mean,variance = tf.nn.moments(vw,axes=[1],keep_dims=True) _off = 'order{}_offset'.format(order) _sca = 'order{}_scale'.format(order) vw = tf.nn.batch_normalization(vw,mean,variance,self.vars[_off],self.vars[_sca],1e-9) return vw def _call(self, inputs): # vecs: input feature of the current layer. # adj_partition_list: the row partitions of the full graph adj # (only used in full-batch evaluation on the val/test sets) vecs, adj_norm, len_feat, adj_partition_list, _ = inputs vecs = tf.nn.dropout(vecs, 1-self.dropout) vecs_hop = [tf.identity(vecs) for o in range(self.order+1)] for o in range(self.order): for a in range(o+1): ans1 = tf.sparse_tensor_dense_matmul(adj_norm,vecs_hop[o+1]) ans_partition = [tf.sparse_tensor_dense_matmul(adj,vecs_hop[o+1]) for adj in adj_partition_list] ans2 = tf.concat(ans_partition,0) vecs_hop[o+1]=tf.cond(self.is_train,lambda: tf.identity(ans1),lambda: tf.identity(ans2)) vecs_hop = [self._F_nonlinear(v,o) for o,v in enumerate(vecs_hop)] if self.aggr == 'mean': ret = vecs_hop[0] for o in range(len(vecs_hop)-1): ret += vecs_hop[o+1] elif self.aggr == 'concat': ret = tf.concat(vecs_hop,axis=1) else: raise NotImplementedError return ret class AttentionAggregator(Layer): """ Attention mechanism by GAT. We remove the softmax step since during minibatch training, we cannot see all neighbors of a node. """ def __init__(self, dim_in, dim_out, dropout=0., act='relu', order=1, aggr='mean', is_train=True, bias='norm', **kwargs): assert order <= 1, "now only support attention for order 0/1 layers" super(AttentionAggregator,self).__init__(**kwargs) self.dropout = dropout self.bias = bias self.act = F_ACT[act] self.order = order self.aggr = aggr self.is_train = is_train if 'mulhead' in kwargs.keys(): self.mulhead = int(kwargs['mulhead']) else: self.mulhead = 1 with tf.variable_scope(self.name + '_vars'): self.vars['order0_weights'] = glorot([dim_in,dim_out],name='order0_weights') for k in range(self.mulhead): self.vars['order1_weights_h{}'.format(k)] = glorot([dim_in,int(dim_out/self.mulhead)],name='order1_weights_h{}'.format(k)) self.vars['order0_bias'] = zeros([dim_out],name='order0_bias') for k in range(self.mulhead): self.vars['order1_bias_h{}'.format(k)] = zeros([int(dim_out/self.mulhead)],name='order1_bias_h{}'.format(k)) if self.bias == 'norm': for o in range(self.order+1): _k1 = 'order{}_offset'.format(o) _k2 = 'order{}_scale'.format(o) self.vars[_k1] = zeros([1,dim_out],name=_k1) self.vars[_k2] = ones([1,dim_out],name=_k2) for k in range(self.mulhead): self.vars['attention_0_h{}'.format(k)] = glorot([1,int(dim_out/self.mulhead)],name='attention_0_h{}'.format(k)) self.vars['attention_1_h{}'.format(k)] = glorot([1,int(dim_out/self.mulhead)],name='attention_1_h{}'.format(k)) self.vars['att_bias_0_h{}'.format(k)] = zeros([1],name='att_bias_0_h{}'.format(k)) self.vars['att_bias_1_h{}'.format(k)] = zeros([1],name='att_bias_1_h{}'.format(k)) print('>> layer {}, dim: [{},{}]'.format(self.name, dim_in, dim_out)) if self.logging: self._log_vars() self.dim_in = dim_in self.dim_out = dim_out def _F_edge_weight(self,adj_part,vecs_neigh,vecs_self,offset=0): adj_mask = tf.dtypes.cast(tf.dtypes.cast(adj_part, tf.bool), tf.float32) a1 = tf.SparseTensor(adj_mask.indices,tf.nn.embedding_lookup(vecs_neigh,adj_mask.indices[:,1]),adj_mask.dense_shape) a2 = tf.SparseTensor(adj_mask.indices,tf.nn.embedding_lookup(vecs_self,adj_mask.indices[:,0]+offset),adj_mask.dense_shape) alpha = tf.SparseTensor(adj_mask.indices,tf.nn.relu(a1.values+a2.values),adj_mask.dense_shape) adj_weighted = tf.SparseTensor(adj_mask.indices,adj_part.values*alpha.values,adj_mask.dense_shape) return adj_weighted def _call(self, inputs): vecs, adj_norm, len_feat, adj_partition_list, dim0_adj_sub = inputs adj_norm = tf.cond(self.is_train,lambda: adj_norm,lambda: tf.sparse.concat(0,adj_partition_list)) vecs_do1 = tf.nn.dropout(vecs, 1-self.dropout) vecs_do2 = tf.nn.dropout(vecs, 1-self.dropout) vw_self = tf.matmul(vecs_do2,self.vars['order0_weights']) ret_self = self.act(vw_self + self.vars['order0_bias']) if self.bias == 'norm': mean,variance = tf.nn.moments(ret_self,axes=[1],keep_dims=True) ret_self = tf.nn.batch_normalization(ret_self,mean,variance,self.vars['order0_offset'],self.vars['order0_scale'],1e-9) if self.order == 0: return ret_self # the aggr below only applies to order 1 layers ret_neigh_l_subg = list() ret_neigh_l_fullg = list() offset = 0 vw_neigh = list() vw_neigh_att = list() vw_self_att = list() for i in range(self.mulhead): vw_neigh.append(tf.matmul(vecs_do1,self.vars['order1_weights_h{}'.format(i)])) vw_neigh_att.append(tf.reduce_sum(vw_neigh[i]*self.vars['attention_1_h{}'.format(i)],axis=-1)\ + self.vars['att_bias_1_h{}'.format(i)]) vw_self_att.append(tf.reduce_sum(vw_neigh[i]*self.vars['attention_0_h{}'.format(i)],axis=-1)\ + self.vars['att_bias_0_h{}'.format(i)]) for i in range(self.mulhead): adj_weighted = self._F_edge_weight(adj_norm,vw_neigh_att[i],vw_self_att[i],offset=0) ret_neigh_i = self.act(tf.sparse_tensor_dense_matmul(adj_weighted,vw_neigh[i])) \ + self.vars['order1_bias_h{}'.format(i)] ret_neigh_l_subg.append(ret_neigh_i) for _adj in adj_partition_list: ret_neigh_la = list() for i in range(self.mulhead): adj_weighted = self._F_edge_weight(_adj,vw_neigh_att[i],vw_self_att[i],offset=offset) ret_neigh_i = self.act(tf.sparse_tensor_dense_matmul(adj_weighted,vw_neigh[i]) \ + self.vars['order1_bias_h{}'.format(i)]) ret_neigh_la.append(ret_neigh_i) ret_neigh_l_fullg.append(tf.concat(ret_neigh_la,axis=1)) offset += dim0_adj_sub ret_neigh = tf.cond(self.is_train, lambda: tf.concat(ret_neigh_l_subg,axis=1), lambda: tf.concat(ret_neigh_l_fullg,axis=0)) if self.bias == 'norm': mean,variance = tf.nn.moments(ret_neigh,axes=[1],keep_dims=True) ret_neigh = tf.nn.batch_normalization(ret_neigh,mean,variance,self.vars['order1_offset'],self.vars['order1_scale'],1e-9) if self.aggr == 'mean': ret = ret_neigh + ret_self elif self.aggr == 'concat': ret = tf.concat([ret_self,ret_neigh],axis=1) else: raise NotImplementedError return ret