# Copyright 2015 Google Inc. All Rights Reserved.
# Modifications copyright 2017 Jan Buys.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.util import nest

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

import tensorflow as tf

import data_utils

linear = tf.nn.rnn_cell._linear  # pylint: disable=protected-access


# TODO rename: remove _ (not local method)
def _extract_embed(embedding, update_embedding=True):
  """Get a loop_function that embeds symbols.

  Args:
    embedding: list of embedding tensors for symbols.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """
  def embed_function(symbol):
    emb = tf.nn.embedding_lookup(embedding, symbol)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    if not update_embedding:
      emb = tf.stop_gradient(emb)
    return emb
  return embed_function


def _extract_argmax_and_embed(embedding, output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and added B.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """
  def loop_function(prev, _):
    if output_projection is not None:
      prev = tf.matmul(prev, output_projection[0]) + output_projection[1]
    prev_symbol = tf.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = tf.stop_gradient(emb_prev)
    return emb_prev
  return loop_function
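
# Illustrative usage sketch (not part of the original module): the loop
# functions above are passed to a decoder that, at each step, feeds the
# embedding of the previously produced symbol back in. All names here are
# hypothetical, and the legacy (pre-1.0) TensorFlow API used throughout this
# file is assumed.
def _example_feed_previous_step(prev_logits, vocab_size=10, embedding_size=4):
  """Embeds the argmax symbol of the previous decoder output."""
  embedding = tf.get_variable("example_embedding",
                              [vocab_size, embedding_size])
  loop_function = _extract_argmax_and_embed(
      embedding, output_projection=None, update_embedding=False)
  # A decoder calls the loop function once per step; the second argument
  # (the step index) is unused by this loop function.
  return loop_function(prev_logits, 0)  # [batch_size, embedding_size]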


def tile_embedding_attention(emb_inp, symbol_inp, initial_state,
                             attention_states, beam_size, embedding_size):
  """Makes beam_size copies of the inputs, initial state and attention states."""
  tile_emb_inp = []
  for inp in emb_inp:
    tile_emb = tf.tile(tf.reshape(inp, [1, -1]), tf.pack([beam_size, 1]))
    tile_emb = tf.reshape(tile_emb, [-1, embedding_size])
    tile_emb_inp.append(tile_emb)
  tile_symbol_inp = []
  for inp in symbol_inp:
    tile_sym = tf.tile(tf.reshape(inp, [1, 1]), tf.pack([beam_size, 1]))
    tile_sym = tf.reshape(tile_sym, [-1])
    tile_symbol_inp.append(tile_sym)
  tile_initial_state = tf.tile(tf.reshape(initial_state, [1, -1]),
                               tf.pack([beam_size, 1]))
  attn_length = attention_states.get_shape()[1].value
  attn_size = attention_states.get_shape()[2].value
  tile_attention_states = tf.tile(attention_states,
                                  tf.pack([beam_size, 1, 1]))
  tile_attention_states = tf.reshape(tile_attention_states,
                                     [-1, attn_length, attn_size])
  return (tile_emb_inp, tile_symbol_inp, tile_initial_state,
          tile_attention_states)


def attention(query, num_heads, y_w, v, hidden, hidden_features,
              attention_vec_size, attn_length, use_global_attention=False):
  """Puts attention masks on hidden using hidden_features and query.

  Args:
    query: vector, usually the current decoder state. 2D Tensor
      [batch_size x state_size].
    num_heads: int. Currently always 1.
    y_w: per-head list of (W, b) pairs applied to the query.
    v: attention model parameters.
    hidden: attention_states.
    hidden_features: same linear layer applied to all attention_states.
    attention_vec_size: attention embedding size.
    attn_length: number of inputs over which the attention spans.
    use_global_attention: if True, scores with a multiplicative (dot-product)
      model instead of the additive tanh model.

  Returns:
    at_logits: unnormalized attention scores over the attention inputs.
    at_probs: softmax over the attention inputs.
    ds: attention-weighted averaged attention vectors.
  """
  at_logits = []  # Results of attention logits.
  at_probs = []  # Results of attention probabilities.
  ds = []  # Results of attention reads will be stored here.
  if nest.is_sequence(query):  # If the query is a tuple, flatten it.
    query_list = nest.flatten(query)
    for q in query_list:  # Check that ndims == 2 if specified.
      ndims = q.get_shape().ndims
      if ndims:
        assert ndims == 2
    query = tf.concat(1, query_list)
  for a in xrange(num_heads):
    with tf.variable_scope("Attention_%d" % a):
      y = tf.matmul(query, y_w[a][0]) + y_w[a][1]
      y = tf.reshape(y, [-1, 1, 1, attention_vec_size])
      # Attention mask is a softmax of v^T * tanh(...).
      if use_global_attention:
        s = tf.reduce_sum(hidden_features[a] * y, [2, 3])
      else:
        # Broadcast to add y (query vector) to all hidden_features.
        s = tf.reduce_sum(
            v[a] * tf.tanh(hidden_features[a] + y), [2, 3])
      at_logits.append(s)
      att = tf.nn.softmax(s)
      at_probs.append(att)
      # Now calculate the attention-weighted vector d.
      d = tf.reduce_sum(
          tf.reshape(att, [-1, attn_length, 1, 1]) * hidden, [1, 2])
      ds.append(tf.reshape(d, [-1, attention_vec_size]))
  return at_logits, at_probs, ds
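
# Illustrative sketch (not part of the original module): how the tensors fed
# to attention() are typically precomputed in this family of seq2seq
# decoders. Variable names are hypothetical; the legacy TensorFlow API is
# assumed.
def _example_attention_setup(attention_states, state_size):
  """Precomputes hidden, hidden_features, v and y_w for one attention head."""
  attn_length = attention_states.get_shape()[1].value
  attn_size = attention_states.get_shape()[2].value
  # attention() reshapes the attention read to attention_vec_size, so the
  # two sizes must agree.
  attention_vec_size = attn_size
  # Reshape so that a 1x1 convolution applies the same linear layer to every
  # attention position.
  hidden = tf.reshape(attention_states, [-1, attn_length, 1, attn_size])
  k = tf.get_variable("AttnW_0", [1, 1, attn_size, attention_vec_size])
  hidden_features = [tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME")]
  v = [tf.get_variable("AttnV_0", [attention_vec_size])]
  # Per-head (W, b) pair applied to the query inside attention().
  y_w = [(tf.get_variable("AttnQueryW_0", [state_size, attention_vec_size]),
          tf.get_variable("AttnQueryB_0", [attention_vec_size]))]
  return hidden, hidden_features, v, y_w, attention_vec_size, attn_length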


def extend_outputs_to_labels(outputs, label_inputs, label_logits,
                             label_vectors, feed_previous):
  """Include (predicted) input labels in encoder attention vectors."""
  new_outputs = []
  for i, cell_output in enumerate(outputs):
    input_label = label_inputs[i]
    if feed_previous:
      input_label = tf.argmax(label_logits[i], 1)
    label_emb = tf.nn.embedding_lookup(label_vectors, input_label)
    concat_emb = tf.concat(1, [cell_output, label_emb])
    new_outputs.append(concat_emb)
  return new_outputs


def gumbel_noise(batch_size, logit_size):
  """Computes Gumbel noise.

  When the output is added to a logit, taking the argmax is equivalent to
  sampling from the softmax distribution over the logits (the Gumbel-max
  trick).
  """
  size = tf.pack([batch_size, logit_size])
  uniform_sample = tf.random_uniform(size, 0, 1, dtype=tf.float32,
                                     seed=None, name=None)
  noise = -tf.log(-tf.log(uniform_sample))
  return noise


def init_thin_stack(batch_size, max_num_concepts):
  """Initializes the thin stack.

  Returns:
    thin_stack: Tensor with the stack content.
    thin_stack_head_next: Index pointers to element after stack head.
  """
  # Stack initialized to -1, points to initial state.
  thin_stack = -tf.ones(tf.pack([batch_size, max_num_concepts]),
                        dtype=tf.int32)
  # Reshape to ensure dimension 1 is known.
  thin_stack = tf.reshape(thin_stack, [-1, max_num_concepts])
  # Set to 0 at position 0.
  inds = tf.transpose(tf.to_int64(tf.pack(
      [tf.range(batch_size),
       tf.zeros(tf.pack([batch_size]), dtype=tf.int32)])))
  delta = tf.SparseTensor(inds,
                          tf.ones(tf.pack([batch_size]), dtype=tf.int32),
                          tf.pack([tf.to_int64(batch_size),
                                   max_num_concepts]))
  new_thin_stack = thin_stack + tf.sparse_tensor_to_dense(delta)
  # Position 0 is for empty stack; position after head always >= 1.
  thin_stack_head_next = tf.ones(tf.pack([batch_size]), dtype=tf.int32)
  return new_thin_stack, thin_stack_head_next


def write_thin_stack(thin_stack, stack_pointers, decoder_position, batch_size,
                     max_num_concepts):
  """Writes the current decoder position to the thin stack at the given pointers."""
  new_vals = tf.fill(tf.pack([batch_size]), decoder_position)
  return write_thin_stack_vals(thin_stack, stack_pointers, new_vals,
                               batch_size, max_num_concepts)


def write_thin_stack_vals(thin_stack, stack_pointers, new_vals, batch_size,
                          max_num_concepts):
  """Writes the given values to the thin stack at the given pointers."""
  # SparseTensor requires type int64.
  stack_inds = tf.transpose(tf.to_int64(tf.pack(
      [tf.range(batch_size), stack_pointers])))
  current_vals = tf.gather_nd(thin_stack, stack_inds)
  delta = tf.SparseTensor(stack_inds, new_vals - current_vals,
                          tf.pack([tf.to_int64(batch_size),
                                   max_num_concepts]))
  new_thin_stack = thin_stack + tf.sparse_tensor_to_dense(delta)
  return new_thin_stack


def pure_reduce_thin_stack(thin_stack_head_next, transition_state):
  """Applies reduce to the thin stack head pointer if in reduce state."""
  # Pop if current transition state is reduce.
  stack_head_updates = tf.sparse_to_dense(data_utils.RE_STATE,
      tf.pack([data_utils.NUM_TR_STATES]), -1)
  new_thin_stack_head_next = tf.add(thin_stack_head_next,
      tf.gather(stack_head_updates, transition_state))
  return new_thin_stack_head_next
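
# Illustrative sketch (not part of the original module): thin-stack
# semantics. With batch_size=2 and max_num_concepts=4, init_thin_stack()
# yields
#   thin_stack           = [[0, -1, -1, -1], [0, -1, -1, -1]]
#   thin_stack_head_next = [1, 1]
# i.e. position 0 marks the empty stack and the head pointer starts just
# above it.
def _example_thin_stack_write(decoder_position=5, batch_size=2,
                              max_num_concepts=4):
  """Writes the current decoder position at every batch entry's head pointer."""
  thin_stack, head_next = init_thin_stack(batch_size, max_num_concepts)
  # The caller would afterwards advance head_next for the batch entries whose
  # transition was a shift (see shift_thin_stack below).
  return write_thin_stack(thin_stack, head_next, decoder_position,
                          batch_size, max_num_concepts)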

def reduce_thin_stack(thin_stack, thin_stack_head_next, batch_size,
                      max_num_concepts, decoder_position, transition_state):
  """Applies reduce to the thin stack head pointer if in reduce state."""
  # Pop if current transition state is reduce.
  stack_head_updates = tf.sparse_to_dense(data_utils.RE_STATE,
      tf.pack([data_utils.NUM_TR_STATES]), -1)
  new_thin_stack_head_next = tf.add(thin_stack_head_next,
      tf.gather(stack_head_updates, transition_state))
  return new_thin_stack_head_next


def update_buffer_head(buffer_head, predicted_attns, transition_state):
  """Replaces the buffer head with the predicted attention if in gen state."""
  updates = tf.sparse_to_dense(tf.pack([data_utils.GEN_STATE]),
      tf.pack([data_utils.NUM_TR_STATES]), True, default_value=False)
  is_gen_state = tf.gather(updates, transition_state)
  new_buffer_head = tf.select(is_gen_state, predicted_attns, buffer_head)
  return new_buffer_head


def pure_shift_thin_stack(thin_stack_head_next, transition_state):
  """Applies shift to the thin stack head pointer if in shift state."""
  # Push if previous transition state is shift (or pointer shift).
  stack_head_updates = tf.sparse_to_dense(tf.pack(
      [data_utils.GEN_STATE]), tf.pack([data_utils.NUM_TR_STATES]), 1)
  new_thin_stack_head_next = tf.add(thin_stack_head_next,
      tf.gather(stack_head_updates, transition_state))
  return new_thin_stack_head_next


def shift_thin_stack(thin_stack, thin_stack_head_next, batch_size,
                     max_num_concepts, decoder_position,
                     prev_transition_state):
  """Applies shift to the thin stack and its head if in shift state."""
  # Head points to item after stack top, so always update the stack entry.
  new_thin_stack = write_thin_stack(thin_stack, thin_stack_head_next,
      decoder_position, batch_size, max_num_concepts)
  # Push if previous transition state is shift (or pointer shift).
  stack_head_updates = tf.sparse_to_dense(tf.pack(
      [data_utils.GEN_STATE]), tf.pack([data_utils.NUM_TR_STATES]), 1)
  new_thin_stack_head_next = tf.add(thin_stack_head_next,
      tf.gather(stack_head_updates, prev_transition_state))
  return new_thin_stack, new_thin_stack_head_next


def update_reduce_thin_stack(thin_stack, thin_stack_head_next, batch_size,
                             max_num_concepts, decoder_position,
                             transition_state):
  """If in reduce state, replaces the stack top with current decoder_position."""
  # Aim at head for reduce (update), head_next otherwise (no update).
  re_index_updates = tf.sparse_to_dense(data_utils.RE_STATE,
      tf.pack([data_utils.NUM_TR_STATES]), -1)
  re_stack_head = tf.add(thin_stack_head_next,
      tf.gather(re_index_updates, transition_state))
  # Update the stack.
  new_thin_stack = write_thin_stack(thin_stack, re_stack_head,
      decoder_position, batch_size, max_num_concepts)
  return new_thin_stack


def extract_stack_head_entries(thin_stack, thin_stack_head_next, batch_size):
  """Finds entries (indices) at stack head for every instance in batch."""
  stack_head_inds = tf.sub(thin_stack_head_next,
      tf.ones(tf.pack([batch_size]), dtype=tf.int32))
  # For every batch entry, get the thin stack head entry.
  stack_inds = tf.transpose(tf.pack(
      [tf.range(batch_size), stack_head_inds]))
  stack_heads = tf.gather_nd(thin_stack, stack_inds)
  return stack_heads


def mask_decoder_restrictions(logit, logit_size, decoder_restrictions,
                              transition_state):
  """Enforces decoder restrictions determined by the transition state."""
  restrict_mask_list = []
  with tf.device("/cpu:0"):  # sparse_to_dense must be on CPU for now.
    for restr in decoder_restrictions:
      restrict_mask_list.append(tf.sparse_to_dense(restr,
          tf.pack([logit_size]), np.inf, default_value=-np.inf))
  mask = tf.gather(tf.pack(restrict_mask_list), transition_state)
  new_logit = tf.minimum(logit, mask)
  return new_logit
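
# Illustrative sketch (not part of the original module): one shift step
# followed by reading off the resulting stack tops. Transition state ids come
# from data_utils (GEN_STATE marks a shift); entries in other states keep
# their head pointer unchanged.
def _example_shift_and_peek(thin_stack, head_next, decoder_position,
                            prev_transition_state, batch_size,
                            max_num_concepts):
  """Shifts the current decoder position and returns the new stack tops."""
  new_stack, new_head_next = shift_thin_stack(
      thin_stack, head_next, batch_size, max_num_concepts,
      decoder_position, prev_transition_state)
  # The stack top sits one position below the head pointer.
  stack_tops = extract_stack_head_entries(new_stack, new_head_next,
                                          batch_size)
  return new_stack, new_head_next, stack_tops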
  For each batch entry k:
    If thin_stack_head_next[k] <= 1 (i.e., the stack is empty),
      let logit[k][REDUCE_ID] = -np.inf,
    else leave the logit unchanged.
  """
  # Allow reduce only if at least 1 item on stack, i.e., pointer >= 2.
  update_vals = tf.pack([-np.inf, -np.inf, 0.0])
  update_val = tf.gather(update_vals,
      tf.minimum(thin_stack_head_next,
                 2*tf.ones(tf.pack([batch_size]), dtype=tf.int32)))
  re_filled = tf.fill(tf.pack([batch_size]),
      tf.to_int64(data_utils.REDUCE_ID))
  re_inds = tf.transpose(tf.pack(
      [tf.to_int64(tf.range(batch_size)), re_filled]))
  re_delta = tf.SparseTensor(re_inds, update_val, tf.to_int64(
      tf.pack([batch_size, logit_size])))
  new_logit = logit + tf.sparse_tensor_to_dense(re_delta)
  return new_logit


def mask_decoder_only_shift(logit, thin_stack_head_next, transition_state_map,
                            logit_size, batch_size):
  """Ensures that if the stack is empty, the decoder has to shift (GEN_STATE).

  For each batch entry k:
    If thin_stack_head_next[k] <= 1 (i.e., the stack is empty),
      add -np.inf to the logits of all symbols whose transition state is
      reduce or arc,
    else leave the logits unchanged.
  """
  stack_is_empty_bool = tf.less_equal(thin_stack_head_next, 1)
  stack_is_empty = tf.select(stack_is_empty_bool,
      tf.ones(tf.pack([batch_size]), dtype=tf.int32),
      tf.zeros(tf.pack([batch_size]), dtype=tf.int32))
  stack_is_empty = tf.reshape(stack_is_empty, [-1, 1])
  # Re and Arc states are disallowed (but not root).
  state_is_disallowed_updates = tf.sparse_to_dense(
      tf.pack([data_utils.RE_STATE, data_utils.ARC_STATE]),
      tf.pack([data_utils.NUM_TR_STATES]), 1)
  logit_states = tf.gather(transition_state_map, tf.range(logit_size))
  state_is_disallowed = tf.gather(state_is_disallowed_updates, logit_states)
  state_is_disallowed = tf.reshape(state_is_disallowed, [1, -1])
  index_delta = tf.matmul(stack_is_empty, state_is_disallowed)  # 1 if disallowed.
  values = tf.pack([0, -np.inf])
  delta = tf.gather(values, index_delta)
  new_logit = logit + delta
  return new_logit
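
# Illustrative sketch (not part of the original module): the masking above
# adds 0 to allowed logits and -inf to disallowed ones, so a downstream
# softmax assigns disallowed symbols zero probability. A plain numpy
# analogue:
def _example_logit_mask():
  """Numpy analogue of the additive -inf masking used above."""
  logit = np.array([1.0, 2.0, 3.0])
  mask = np.array([0.0, -np.inf, 0.0])  # Disallow symbol 1.
  return logit + mask  # [1.0, -inf, 3.0]; softmax gives symbol 1 zero mass.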

def mask_decoder_only_reduce(logit, thin_stack_head_next, transition_state_map,
                             max_stack_size, logit_size, batch_size):
  """Ensures that if the stack is full, the decoder has to reduce.

  For each batch entry k:
    If thin_stack_head_next[k] >= max_stack_size - 1 (i.e., the stack is
      full), add -np.inf to the logits of all symbols whose transition state
      is not reduce, arc or root,
    else leave the logits unchanged.
  """
  stack_is_full_bool = tf.greater_equal(thin_stack_head_next,
                                        max_stack_size - 1)
  stack_is_full = tf.select(stack_is_full_bool,
      tf.ones(tf.pack([batch_size]), dtype=tf.int32),
      tf.zeros(tf.pack([batch_size]), dtype=tf.int32))
  stack_is_full = tf.reshape(stack_is_full, [-1, 1])
  # Re, Arc and Root states are allowed.
  state_is_disallowed_updates = tf.sparse_to_dense(
      tf.pack([data_utils.RE_STATE, data_utils.ARC_STATE,
               data_utils.ROOT_STATE]),
      tf.pack([data_utils.NUM_TR_STATES]), 0, 1)
  logit_states = tf.gather(transition_state_map, tf.range(logit_size))
  state_is_disallowed = tf.gather(state_is_disallowed_updates, logit_states)
  state_is_disallowed = tf.reshape(state_is_disallowed, [1, -1])
  index_delta = tf.matmul(stack_is_full, state_is_disallowed)  # 1 if disallowed.
  values = tf.pack([0, -np.inf])
  delta = tf.gather(values, index_delta)
  new_logit = logit + delta
  return new_logit


def gather_nd_lstm_states(states_c, states_h, inds, batch_size, input_size,
                          state_size):
  """Gathers an LSTM state (c, h) pair indexed by inds for each batch entry."""
  concat_states_c = tf.concat(1, states_c)
  concat_states_h = tf.concat(1, states_h)
  new_prev_state_c = gather_nd_states(concat_states_c, inds,
      batch_size, input_size, state_size)
  new_prev_state_h = gather_nd_states(concat_states_h, inds,
      batch_size, input_size, state_size)
  return tf.nn.rnn_cell.LSTMStateTuple(new_prev_state_c, new_prev_state_h)


def gather_nd_states(inputs, inds, batch_size, input_size, state_size):
  """Gathers an embedding for each batch entry with index inds from inputs.

  Args:
    inputs: Tensor [batch_size, input_size, state_size].
    inds: Tensor [batch_size].

  Returns:
    output: Tensor [batch_size, state_size].
  """
  sparse_inds = tf.transpose(tf.pack(
      [tf.range(batch_size), inds]))
  dense_inds = tf.sparse_to_dense(sparse_inds,
      tf.pack([batch_size, input_size]),
      tf.ones(tf.pack([batch_size])))
  output_sum = tf.reduce_sum(
      tf.reshape(dense_inds, [-1, input_size, 1, 1])
      * tf.reshape(inputs, [-1, input_size, 1, state_size]), [1, 2])
  output = tf.reshape(output_sum, [-1, state_size])
  return output


def binary_select_state(state, updates, transition_state, batch_size):
  """Gathers state or zero for each batch entry."""
  update_inds = tf.gather(updates, transition_state)
  sparse_diag = tf.transpose(tf.pack(
      [tf.range(batch_size), tf.range(batch_size)]))
  dense_inds = tf.sparse_to_dense(sparse_diag,
      tf.pack([batch_size, batch_size]),
      tf.to_float(update_inds))
  new_state = tf.matmul(dense_inds, state)
  return new_state


def hard_state_selection(attn_inds, hidden, batch_size, attn_length):
  """Selects the attention state at attn_inds for each batch entry."""
  batch_inds = tf.transpose(tf.pack(
      [tf.to_int64(tf.range(batch_size)), tf.to_int64(attn_inds)]))
  align_index = tf.to_float(tf.sparse_to_dense(batch_inds,
      tf.to_int64(tf.pack([batch_size, attn_length])), 1))
  attns = tf.reduce_sum(
      hidden * tf.reshape(align_index, [-1, attn_length, 1, 1]), [1, 2])
  return attns
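
# Illustrative sketch (not part of the original module): gather_nd_states()
# implements a batched one-hot gather, selecting one state vector per batch
# entry out of input_size candidates.
def _example_gather_states():
  """Selects inputs[0, 2] and inputs[1, 0] -> [[4., 5.], [6., 7.]]."""
  inputs = tf.constant(np.arange(12, dtype=np.float32).reshape(2, 3, 2))
  inds = tf.constant([2, 0], dtype=tf.int32)
  return gather_nd_states(inputs, inds, batch_size=2, input_size=3,
                          state_size=2)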

def gather_forced_att_logits(encoder_input_symbols, encoder_decoder_vocab_map,
                             att_logit, batch_size, attn_length,
                             target_vocab_size):
  """Gathers attention weights as logits for forced attention."""
  flat_input_symbols = tf.reshape(encoder_input_symbols, [-1])
  flat_label_symbols = tf.gather(encoder_decoder_vocab_map,
      flat_input_symbols)
  flat_att_logits = tf.reshape(att_logit, [-1])
  flat_range = tf.to_int64(tf.range(tf.shape(flat_label_symbols)[0]))
  batch_inds = tf.floordiv(flat_range, attn_length)
  position_inds = tf.mod(flat_range, attn_length)
  attn_vocab_inds = tf.transpose(tf.pack(
      [batch_inds, position_inds, tf.to_int64(flat_label_symbols)]))
  # Exclude indexes of entries with flat_label_symbols[i] = -1.
  included_flat_indexes = tf.reshape(tf.where(tf.not_equal(
      flat_label_symbols, -1)), [-1])
  included_attn_vocab_inds = tf.gather(attn_vocab_inds,
      included_flat_indexes)
  included_flat_att_logits = tf.gather(flat_att_logits,
      included_flat_indexes)
  sparse_shape = tf.to_int64(tf.pack(
      [batch_size, attn_length, target_vocab_size]))
  sparse_label_logits = tf.SparseTensor(included_attn_vocab_inds,
      included_flat_att_logits, sparse_shape)
  forced_att_logit_sum = tf.sparse_reduce_sum(sparse_label_logits, [1])
  forced_att_logit = tf.reshape(forced_att_logit_sum,
      [-1, target_vocab_size])
  return forced_att_logit


def gather_prev_stack_state_index(pointer_vals, prev_index, transition_state,
                                  batch_size):
  """Gathers new previous state index."""
  new_pointer_vals = tf.reshape(pointer_vals, [-1, 1])
  # Helper tensors.
  prev_vals = tf.reshape(tf.fill(
      tf.pack([batch_size]), prev_index), [-1, 1])
  trans_inds = tf.transpose(tf.pack(
      [tf.range(batch_size), transition_state]))
  # Gather new prev state for main rnn: pointer vals if reduce, else prev.
  # State inds dimension [batch_size, NUM_TR_STATES].
  state_inds = tf.concat(1, [prev_vals]*6 + [new_pointer_vals, prev_vals])
  prev_state_index = tf.gather_nd(state_inds, trans_inds)
  return prev_state_index


def gather_prev_stack_aux_state_index(pointer_vals, prev_index,
                                      transition_state, batch_size):
  """Gathers new prev state index for aux rnn: as for main, but zero if shift."""
  new_pointer_vals = tf.reshape(pointer_vals, [-1, 1])
  # Helper tensors.
  prev_vals = tf.reshape(tf.fill(
      tf.pack([batch_size]), prev_index), [-1, 1])
  trans_inds = tf.transpose(tf.pack(
      [tf.range(batch_size), transition_state]))
  batch_zeros = tf.reshape(tf.zeros(
      tf.pack([batch_size]), dtype=tf.int32), [-1, 1])
  # Gather new prev state for aux rnn.
  # State inds dimension [batch_size, NUM_TR_STATES].
  state_inds = tf.concat(1, [prev_vals, batch_zeros] + [prev_vals]*4
                         + [new_pointer_vals, prev_vals])
  prev_state_index = tf.gather_nd(state_inds, trans_inds)
  return prev_state_index
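
# Illustrative note (not part of the original module): the column layout
# [prev_vals]*6 + [new_pointer_vals, prev_vals] above implies
# data_utils.NUM_TR_STATES == 8, with the pointer-consuming (reduce) state at
# index 6: a batch entry in that state takes its previous RNN state index
# from the stack pointer, while entries in any other state take the
# preceding decoder position. A hypothetical call, assuming prev_index is
# the previous decoder position:
def _example_prev_state_index(pointer_vals, transition_state, batch_size,
                              decoder_position=3):
  """Picks the previous decoder position, or the pointer value on a reduce."""
  return gather_prev_stack_state_index(pointer_vals, decoder_position - 1,
                                       transition_state, batch_size)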