import collections import functools import math import numpy as np import tensorflow as tf from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as layers_base from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest from tensorflow.contrib.rnn import LSTMStateTuple from tensorflow.contrib.seq2seq import AttentionMechanism _zero_state_tensors = rnn_cell_impl._zero_state_tensors class PointerWrapperState( collections.namedtuple("PointerWrapperState", ("cell_state", "attention", "time", "alignments", "alignment_history", "p_gen_history", "vocab_dist_history", "copy_dist_history", "final_dist_history"))): def clone(self, **kwargs): return super(PointerWrapperState, self)._replace(**kwargs) def _compute_pgen(cell_output, cell_state, input, context, out_choices): ptr_inputs = tf.concat((context, cell_state.c, cell_state.h, input), -1, name='ptr_inputs') p_gen = tf.layers.dense( ptr_inputs, out_choices, activation=tf.sigmoid, use_bias=True, name='pointer_generator') return tf.nn.softmax(p_gen) def _compute_attention(attention_mechanism, cell_output, previous_alignments, attention_layer): """Computes the attention and alignments for a given attention_mechanism.""" alignments = attention_mechanism( cell_output, previous_alignments=previous_alignments) expanded_alignments = array_ops.expand_dims(alignments, 1) context = math_ops.matmul(expanded_alignments, attention_mechanism.values) context = array_ops.squeeze(context, [1]) if attention_layer is not None: attention = attention_layer(array_ops.concat([cell_output, context], 1)) else: attention = context return attention, alignments def _calc_copy_dist(attn_dist, batch_size, vocab_size, num_source_OOVs, enc_batch_extended_vocab): # assign probabilities from copy distribution # into correspodning positions in extended_vocab_dist # to do this, we need to use scatter_nd # scatter_nd (in this case) requires two numbers # one is the index in batch-dimension # the other is the index in vocab-dimension # So first, we create a batch-matrix like: # [[1, 1, 1, 1, 1, ...], # [2, 2, 2, 2, 2, ...], # [...] # [N, N, N, N, N, ...]] # [1, 2, ..., N] # to [[1], [2], ..., [N]] # and finally to the final shape enc_seq_len = tf.shape(enc_batch_extended_vocab)[1] batch_nums = tf.range(0, limit=batch_size) batch_nums = tf.expand_dims(batch_nums, 1) batch_nums = tf.tile(batch_nums, [1, enc_seq_len]) # stick together batch-dim and index-dim indices = tf.stack((batch_nums, enc_batch_extended_vocab), axis=2) extended_vsize = vocab_size + num_source_OOVs scatter_shape = [batch_size, extended_vsize] # scatter the attention distributions # into the word-indices P_copy_projected = tf.scatter_nd( indices, attn_dist, scatter_shape) return P_copy_projected def _calc_final_dist(vocab_dist, copy_dists, p_gen, batch_size, vocab_size, num_source_OOVs): ''' calculate the final distribution w/ ptr net (one step) vocab_dist: predicted vocab distribution, tensor, shape b x v_size attn_dist: predicted attn distribution, tensor, shape b x v_size_ext p_gen: generation probability, tensor, shape b batch_size: int, batch size vocab_size: int, v_size num_source_OOVs: int, # of oovs enc_batch_extended_vocab: encoded context w/ extra vocabulary, e.g. replace all UNK with actual oov indices ''' #print(p_gen.get_shape().as_list()) p_generate = tf.expand_dims(p_gen[:, 0], 1) p_copies = tf.expand_dims(p_gen[:, 1:], 2) #print(p_generate.get_shape().as_list()) #print(p_copies.get_shape().as_list()) copy_dist = tf.stack(copy_dists, 1) #print(copy_dist.get_shape().as_list()) #raw_input() # P(gen) x P(vocab) weighted_P_vocab = p_generate * vocab_dist # (1 - P(gen)) x P(attention) weighted_P_copy = p_copies * copy_dist #print(weighted_P_copy.get_shape().as_list()) weighted_P_copy = tf.reduce_sum(weighted_P_copy, 1) #print(weighted_P_copy.get_shape().as_list()) #raw_input() # get the word-idx for all words extended_vsize = vocab_size + num_source_OOVs # placeholders to OOV words extra_zeros = tf.zeros((batch_size, num_source_OOVs)) # this distribution span the entire words weighted_P_vocab_extended = tf.concat( axis=1, values=[weighted_P_vocab, extra_zeros]) # Add the vocab distributions and the copy distributions together # to get the final distributions, final_dists is a list length # max_dec_steps; each entry is (batch_size, extended_vsize) # giving the final distribution for that decoder timestep # Note that for decoder timesteps and examples corresponding to # a [PAD] token, this is junk - ignore. final_dists = weighted_P_vocab_extended + weighted_P_copy return final_dists def _convert_to_output_dist(full_vocab_dist, vocab_size, unk_id): ''' convert a final distribution over the full vocab into an output distribution that maps oov probs to unk token full_vocab_dist: complete vocab distribution, shape b x (v_size + oov_size) vocab_size: int, vocab size unk_id: int, unk token id (e.g. things to map oov to) ''' extra_unk_probs = full_vocab_dist[:, vocab_size:] # [b x oov_size] extra_unk_probs = tf.reduce_sum(extra_unk_probs, axis=1) # [b] batch_size = tf.shape(full_vocab_dist)[0] batch_idx = tf.range(0, limit=batch_size) unk_idx = tf.fill([batch_size], unk_id) scatter_idx = tf.stack((batch_idx, unk_idx), axis=1) # [b x 2] scatter_shape = [batch_size, vocab_size] unk_dist = tf.scatter_nd( scatter_idx, extra_unk_probs, scatter_shape) known_vocab_dist = full_vocab_dist[:, :vocab_size] # [b x vocab_size] return known_vocab_dist + unk_dist class AttnPointerWrapper(rnn_cell_impl.RNNCell): """Wraps an cell with attention and pointer net """ def __init__(self, cell, attention_mechanism, output_layer, max_oovs, batch_size, memory_full_vocab, attention_layer_size=None, alignment_history=False, cell_input_fn=None, output_attention=False, output_generation_distribution=False, output_copy_distribution=False, output_combined_distribution=True, initial_cell_state=None, unk_id=None, name=None): super(AttnPointerWrapper, self).__init__(name=name) if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError( "cell must be an RNNCell, saw type: %s" % type(cell).__name__) if isinstance(attention_mechanism, (list, tuple)): self._is_multi = True attention_mechanisms = attention_mechanism for attention_mechanism in attention_mechanisms: if not isinstance(attention_mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must contain only instances of " "AttentionMechanism, saw type: %s" % type(attention_mechanism).__name__) else: self._is_multi = False if not isinstance(attention_mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must be an AttentionMechanism or list of " "multiple AttentionMechanism instances, saw type: %s" % type(attention_mechanism).__name__) attention_mechanisms = (attention_mechanism,) if cell_input_fn is None: cell_input_fn = ( lambda inputs, attention: array_ops.concat([inputs, attention], -1)) else: if not callable(cell_input_fn): raise TypeError( "cell_input_fn must be callable, saw type: %s" % type(cell_input_fn).__name__) if attention_layer_size is not None: attention_layer_sizes = tuple( attention_layer_size if isinstance(attention_layer_size, (list, tuple)) else (attention_layer_size,)) if len(attention_layer_sizes) != len(attention_mechanisms): raise ValueError( "If provided, attention_layer_size must contain exactly one " "integer per attention_mechanism, saw: %d vs %d" % (len(attention_layer_sizes), len(attention_mechanisms))) self._attention_layers = tuple( layers_core.Dense( attention_layer_size, name="attention_layer", use_bias=False) for attention_layer_size in attention_layer_sizes) self._attention_layer_size = sum(attention_layer_sizes) else: self._attention_layers = None self._attention_layer_size = sum( attention_mechanism.values.get_shape()[-1].value for attention_mechanism in attention_mechanisms) self._cell = cell self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn self._output_attention = output_attention self._output_generation_distribution = output_generation_distribution self._output_copy_distribution = output_copy_distribution self._output_combined_distribution = output_combined_distribution self._unk_id = unk_id self._alignment_history = alignment_history self._output_layer = output_layer self._max_oovs = max_oovs self._batch_size = batch_size if memory_full_vocab is not None: self._memory_full_vocab = tuple( memory_full_vocab if isinstance(memory_full_vocab, (list, tuple)) else (memory_full_vocab, )) if len(self._memory_full_vocab) != len(attention_mechanisms): raise ValueError("memory full vocab must be same size as" "attention mechanisms, saw %d vs %d" % (len(memory_full_vocab), len(attention_mechanisms))) if self._output_combined_distribution or \ self._output_generation_distribution or \ self._output_copy_distribution or \ self._output_attention: assert self._output_combined_distribution ^\ self._output_generation_distribution ^\ self._output_copy_distribution ^\ self._output_attention, "Can only output one type!" if self._output_combined_distribution or self._output_copy_distribution: assert self._unk_id is not None with ops.name_scope(name, "AttnPointerWrapperInit"): if initial_cell_state is None: self._initial_cell_state = None else: final_state_tensor = nest.flatten(initial_cell_state)[-1] state_batch_size = ( final_state_tensor.shape[0].value or array_ops.shape(final_state_tensor)[0]) error_message = ( "When constructing AttnPointerWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " "(encoder output) and initial_cell_state. Are you using " "the BeamSearchDecoder? You may need to tile your initial state " "via the tf.contrib.seq2seq.tile_batch function with argument " "multiple=beam_width.") with ops.control_dependencies( self._batch_size_checks(state_batch_size, error_message)): self._initial_cell_state = nest.map_structure( lambda s: array_ops.identity(s, name="check_initial_cell_state"), initial_cell_state) def _batch_size_checks(self, batch_size, error_message): return [check_ops.assert_equal(batch_size, attention_mechanism.batch_size, message=error_message) for attention_mechanism in self._attention_mechanisms] def _item_or_tuple(self, seq): """Returns `seq` as tuple or the singular element. Which is returned is determined by how the AttentionMechanism(s) were passed to the constructor. Args: seq: A non-empty sequence of items or generator. Returns: Either the values in the sequence as a tuple if AttentionMechanism(s) were passed to the constructor as a sequence or the singular element. """ t = tuple(seq) if self._is_multi: return t else: return t[0] @property def output_size(self): if self._output_combined_distribution or \ self._output_copy_distribution or \ self._output_generation_distribution: return self._output_layer.units if self._output_attention: return self._attention_layer_size else: return self._cell.output_size @property def state_size(self): return PointerWrapperState( cell_state=self._cell.state_size, time=tensor_shape.TensorShape([]), attention=self._attention_layer_size, alignments=self._item_or_tuple( a.alignments_size for a in self._attention_mechanisms), alignment_history=self._item_or_tuple( () for _ in self._attention_mechanisms), p_gen_history=tensor_shape.TensorShape([]), vocab_dist_history=self._output_layer.units, copy_dist_history=self._item_or_tuple( self._output_layer.units + self._max_oovs for _ in self._attention_mechanisms), final_dist_history=self._output_layer.units + self._max_oovs) # sometimes a TensorArray def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): if self._initial_cell_state is not None: cell_state = self._initial_cell_state else: cell_state = self._cell.zero_state(batch_size, dtype) error_message = ( "When calling zero_state of AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " "(encoder output) and the requested batch size. Are you using " "the BeamSearchDecoder? If so, make sure your encoder output has " "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and " "the batch_size= argument passed to zero_state is " "batch_size * beam_width.") with ops.control_dependencies( self._batch_size_checks(batch_size, error_message)): cell_state = nest.map_structure( lambda s: array_ops.identity(s, name="checked_cell_state"), cell_state) state = PointerWrapperState( cell_state=cell_state, time=array_ops.zeros([], dtype=dtypes.int32), attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype), alignments=self._item_or_tuple( attention_mechanism.initial_alignments(batch_size, dtype) for attention_mechanism in self._attention_mechanisms), alignment_history=self._item_or_tuple( tensor_array_ops.TensorArray(dtype=dtype, size=0, dynamic_size=True) if self._alignment_history else () for _ in self._attention_mechanisms), p_gen_history=tensor_array_ops.TensorArray(dtype=dtype, size=0, dynamic_size=True), vocab_dist_history=tensor_array_ops.TensorArray( dtype=tf.float32, size=0, dynamic_size=True), copy_dist_history=self._item_or_tuple( tensor_array_ops.TensorArray(dtype=dtype, size=0, dynamic_size=True) for _ in self._attention_mechanisms), final_dist_history=tensor_array_ops.TensorArray( dtype=tf.float32, size=0, dynamic_size=True)) return state def call(self, inputs, state): if not isinstance(state, PointerWrapperState): raise TypeError("Expected state to be instance of PointerWrapperState. " "Received type %s instead." % type(state)) # Step 1: Calculate the true inputs to the cell based on the # previous attention value. cell_inputs = self._cell_input_fn(inputs, state.attention) cell_state = state.cell_state cell_output, next_cell_state = self._cell(cell_inputs, cell_state) if isinstance(cell_state, LSTMStateTuple): last_out_state = cell_state else: last_out_state = cell_state[-1] cell_batch_size = ( cell_output.shape[0].value or array_ops.shape(cell_output)[0]) error_message = ( "When applying AttentionWrapper %s: " % self.name + "Non-matching batch sizes between the memory " "(encoder output) and the query (decoder output). Are you using " "the BeamSearchDecoder? You may need to tile your memory input via " "the tf.contrib.seq2seq.tile_batch function with argument " "multiple=beam_width.") with ops.control_dependencies( self._batch_size_checks(cell_batch_size, error_message)): cell_output = array_ops.identity( cell_output, name="checked_cell_output") if self._is_multi: previous_alignments = state.alignments previous_alignment_history = state.alignment_history previous_copy_dist_history = state.copy_dist_history else: previous_alignments = [state.alignments] previous_alignment_history = [state.alignment_history] previous_copy_dist_history = [state.copy_dist_history] previous_vocab_dist_history = state.vocab_dist_history print(previous_vocab_dist_history) print(cell_output) print(self._output_layer) vocab_dist = tf.nn.softmax(self._output_layer(cell_output)) print(vocab_dist) vocab_dist_history = previous_vocab_dist_history.write( state.time, vocab_dist) print("Vocab dist history") print(vocab_dist_history) vocab_size = self._output_layer.units all_alignments = [] all_attentions = [] all_histories = [] all_copy_dists = [] all_copy_dist_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): attention, alignments = _compute_attention( attention_mechanism, cell_output, previous_alignments[i], self._attention_layers[i] if self._attention_layers else None) copy_dist = _calc_copy_dist(alignments, self._batch_size, vocab_size, self._max_oovs, self._memory_full_vocab[i]) alignment_history = previous_alignment_history[i].write( state.time, alignments) if self._alignment_history else () copy_dist_history = previous_copy_dist_history[i].write( state.time, copy_dist) all_alignments.append(alignments) all_histories.append(alignment_history) all_attentions.append(attention) all_copy_dists.append(copy_dist) all_copy_dist_histories.append(copy_dist_history) attention_vect = array_ops.concat(all_attentions, 1) p_gen = _compute_pgen(cell_output, last_out_state, inputs, attention_vect, len(self._attention_mechanisms) + 1) # for gen previous_final_dist_history = state.final_dist_history previous_p_gen_history = state.p_gen_history final_dist = _calc_final_dist(vocab_dist, all_copy_dists, p_gen, self._batch_size, vocab_size, self._max_oovs) final_dist_history = previous_final_dist_history.write( state.time, final_dist) p_gen_history = previous_p_gen_history.write(state.time, p_gen) print("Final_dist_history") print(final_dist_history) attention = array_ops.concat(all_attentions, 1) next_state = PointerWrapperState( time=state.time + 1, cell_state=next_cell_state, attention=attention, alignments=self._item_or_tuple(all_alignments), alignment_history=self._item_or_tuple(all_histories), p_gen_history=p_gen_history, vocab_dist_history=vocab_dist_history, copy_dist_history=self._item_or_tuple(all_copy_dist_histories), final_dist_history=final_dist_history) if self._output_generation_distribution: return vocab_dist, next_state elif self._output_copy_distribution: return (_convert_to_output_dist(copy_dist, vocab_size, self._unk_id), next_state) elif self._output_combined_distribution: return (_convert_to_output_dist(final_dist, vocab_size, self._unk_id), next_state) elif self._output_attention: return attention, next_state else: return cell_output, next_state