from typing import NamedTuple import tensorflow as tf from .types import * from .query import * from .messaging_cell_helpers import * from ..args import ACTIVATION_FNS from ..attention import * from ..input import get_table_with_embedding from ..const import EPSILON from ..util import * from ..layers import * from ..activations import * def messaging_cell(context:CellContext): node_table, node_table_width, node_table_len = get_table_with_embedding(context.args, context.features, context.vocab_embedding, "kb_node") node_table_width = context.args["embed_width"] node_table = node_table[:,:,0:node_table_width] in_signal = context.in_iter_id taps = {} def add_taps(val, prefix): ret,tps = val for k,v in tps.items(): taps[prefix+"_"+k] = v return ret in_write_signal = layer_dense(in_signal, context.args["mp_state_width"], "sigmoid") in_write_query = add_taps(generate_token_index_query(context, "mp_write_query"), "mp_write_query") read_queries = [] for i in range(context.args["mp_read_heads"]): read_queries.append(add_taps(generate_token_index_query(context, f"mp_read{i}_query"), f"mp_read{i}_query")) out_read_signals, node_state, taps2 = do_messaging_cell(context, node_table, node_table_width, node_table_len, in_write_query, in_write_signal, read_queries) return out_read_signals, node_state, {**taps, **taps2} def calc_normalized_adjacency(context, node_state): # Aggregate via adjacency matrix with normalisation (that does not include self-edges) adj = tf.cast(context.features["kb_adjacency"], tf.float32) degree = tf.reduce_sum(adj, -1, keepdims=True) inv_degree = tf.reciprocal(degree) node_mask = tf.expand_dims(tf.sequence_mask(context.features["kb_nodes_len"], context.args["kb_node_max_len"]), -1) inv_degree = tf.where(node_mask, inv_degree, tf.zeros(tf.shape(inv_degree))) inv_degree = tf.where(tf.greater(degree, 0), inv_degree, tf.zeros(tf.shape(inv_degree))) inv_degree = tf.check_numerics(inv_degree, "inv_degree") adj_norm = inv_degree * adj adj_norm = tf.cast(adj_norm, node_state.dtype) adj_norm = tf.check_numerics(adj_norm, "adj_norm") node_incoming = tf.einsum('bnw,bnm->bmw', node_state, adj_norm) return node_incoming def do_messaging_cell(context:CellContext, node_table, node_table_width, node_table_len, in_write_query, in_write_signal, in_read_queries): ''' Operate a message passing cell Each iteration it'll do one round of message passing Returns: read_signal, node_state for to_node in nodes: to_node.state = combine_incoming_signals([ message_pass(from_node, to_node) for from_node in to_node.neighbors ] + [node_self_update(to_node)]) ''' with tf.name_scope("messaging_cell"): taps = {} taps["mp_write_query"] = in_write_query taps["mp_write_signal"] = in_write_signal node_state_shape = tf.shape(context.in_node_state) node_state = context.in_node_state assert len(node_state.shape) == 3, f"Node state should have three dimensions, has {len(node_state.shape)}" padded_node_table = pad_to_table_len(node_table, node_state, name="padded_node_table") # -------------------------------------------------------------------------- # Write to graph # -------------------------------------------------------------------------- write_signal, _, a_taps = attention_write_by_key( keys=node_table, key_width=node_table_width, keys_len=node_table_len, query=in_write_query, value=in_write_signal, name="mp_write_signal" ) for k,v in a_taps.items(): taps["mp_write_"+k] = v write_signal = pad_to_table_len(write_signal, node_state, name="write_signal") node_state += write_signal node_state = dynamic_assert_shape(node_state, node_state_shape, "node_state") # -------------------------------------------------------------------------- # Calculate adjacency # -------------------------------------------------------------------------- node_state = calc_normalized_adjacency(context, node_state) # -------------------------------------------------------------------------- # Read from graph # -------------------------------------------------------------------------- out_read_signals = [] for idx, qry in enumerate(in_read_queries): out_read_signal, _, a_taps = attention_key_value( keys=padded_node_table, keys_len=node_table_len, key_width=node_table_width, query=qry, table=node_state, name=f"mp_read{idx}" ) out_read_signals.append(out_read_signal) for k,v in a_taps.items(): taps[f"mp_read{idx}_{k}"] = v taps[f"mp_read{idx}_signal"] = out_read_signal taps[f"mp_read{idx}_query"] = qry taps["mp_node_state"] = node_state node_state = dynamic_assert_shape(node_state, node_state_shape, "node_state") assert node_state.shape[-1] == context.in_node_state.shape[-1], "Node state should not lose dimension" return out_read_signals, node_state, taps