from typing import NamedTuple

import tensorflow as tf

from .types import *
from .query import *

from ..args import ACTIVATION_FNS
from ..attention import *
from ..input import get_table_with_embedding
from ..const import EPSILON
from ..util import *
from ..layers import *
from ..activations import *


MP_State = tf.Tensor

class MP_Node(NamedTuple):
	id: str
	properties: tf.Tensor
	state: MP_State

use_message_passing_fn = False
use_self_reference = False


def layer_normalize(tensor):
	'''Scale each batch item so its values sum to one (apologies if I've abused this term)'''

	in_shape = tf.shape(tensor)

	# Sum over every axis except the batch axis
	axes = list(range(1, len(tensor.shape)))
	t = tf.reduce_sum(tensor, axis=axes)
	t += EPSILON
	t = tf.reciprocal(t)
	t = tf.check_numerics(t, "1/sum")

	tensor = tf.einsum('brc,b->brc', tensor, t)

	tensor = dynamic_assert_shape(tensor, in_shape, "layer_normalize_tensor")
	return tensor


def calc_normalized_adjacency(context, node_state):
	# Aggregate via adjacency matrix with normalisation (that does not include self-edges)
	adj = tf.cast(context.features["kb_adjacency"], tf.float32)
	degree = tf.reduce_sum(adj, -1, keepdims=True)
	inv_degree = tf.reciprocal(degree)

	# Zero out padding nodes and nodes with no edges (whose reciprocal degree is inf)
	node_mask = tf.expand_dims(tf.sequence_mask(context.features["kb_nodes_len"], context.args["kb_node_max_len"]), -1)
	inv_degree = tf.where(node_mask, inv_degree, tf.zeros(tf.shape(inv_degree)))
	inv_degree = tf.where(tf.greater(degree, 0), inv_degree, tf.zeros(tf.shape(inv_degree)))
	inv_degree = tf.check_numerics(inv_degree, "inv_degree")

	adj_norm = inv_degree * adj
	adj_norm = tf.cast(adj_norm, node_state.dtype)
	adj_norm = tf.check_numerics(adj_norm, "adj_norm")

	# Each node receives the sum of its neighbours' states, scaled by the senders' degree
	node_incoming = tf.einsum('bnw,bnm->bmw', node_state, adj_norm)

	return node_incoming


def mp_matmul(state, mat, name):
	# A width-1 conv1d applies the same dense transform to every node position
	return tf.nn.conv1d(state, mat, 1, 'VALID', name=name)


def calc_right_shift(node_incoming):
	# Rotate the feature axis by one position
	shape = tf.shape(node_incoming)
	node_incoming = tf.concat([node_incoming[:, :, 1:], node_incoming[:, :, 0:1]], axis=-1)
	node_incoming = dynamic_assert_shape(node_incoming, shape, "node_incoming")
	return node_incoming


def node_dense(nodes, units, name, activation="linear"):
	with tf.variable_scope(name):
		assert nodes.shape[-1].value is not None, "Nodes must have fixed last dimension"

		w = tf.get_variable("w", [1, nodes.shape[-1], units], initializer=tf.contrib.layers.variance_scaling_initializer(factor=1.0))
		b = tf.get_variable("b", [1, units], initializer=tf.initializers.random_uniform)

		r = mp_matmul(nodes, w, 'matmul') + b
		r = ACTIVATION_FNS[activation](r)
		return r


def node_gru(context, node_state, node_incoming, padded_node_table):
	all_inputs = [node_state, node_incoming]

	if context.args["use_mp_node_id"]:
		all_inputs.append(padded_node_table[:, :, :context.args["embed_width"]])

	old_and_new = tf.concat(all_inputs, axis=-1)
	input_width = old_and_new.shape[-1]

	forget_w    = tf.get_variable("mp_forget_w",    [1, input_width, context.args["mp_state_width"]])
	forget_b    = tf.get_variable("mp_forget_b",    [1, context.args["mp_state_width"]])
	reuse_w     = tf.get_variable("mp_reuse_w",     [1, input_width, context.args["mp_state_width"]])
	transform_w = tf.get_variable("mp_transform_w", [1, 2 * context.args["mp_state_width"], context.args["mp_state_width"]])

	# Initially likely to be zero
	forget_signal = tf.nn.sigmoid(mp_matmul(old_and_new, forget_w, 'forget_signal') + forget_b)
	reuse_signal  = tf.nn.sigmoid(mp_matmul(old_and_new, reuse_w, 'reuse_signal'))

	reuse_and_new = tf.concat([reuse_signal * node_state, node_incoming], axis=-1)
	proposed_new_state = ACTIVATION_FNS[context.args["mp_activation"]](mp_matmul(reuse_and_new, transform_w, 'proposed_new_state'))

	node_state = (1 - forget_signal) * node_state + forget_signal * proposed_new_state

	return node_state
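

# A minimal sketch of how the helpers above compose into a single message-passing
# step: aggregate neighbour states via the normalised adjacency, then gate the
# aggregate against each node's previous state with the GRU-style update.
# The name `mp_step` and this composition are illustrative, not part of the
# original module; it assumes `node_state` is shaped
# [batch, kb_node_max_len, mp_state_width] and that `context` carries the
# `kb_adjacency` / `kb_nodes_len` features and args referenced above.
def mp_step(context, node_state, padded_node_table):
	node_incoming = calc_normalized_adjacency(context, node_state)
	return node_gru(context, node_state, node_incoming, padded_node_table)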