import tensorflow as tf import collections from tensorflow.python.util import nest from tensorflow.contrib.rnn import RNNCell from dynamic_decode import transpose_batch_time from greedy_decoder_cell import DecoderOutput class BeamSearchDecoderCellState(collections.namedtuple( "BeamSearchDecoderCellState", ("cell_state", "log_probs"))): """State of the Beam Search decoding cell_state: shape = structure of [batch_size, beam_size, ?] cell state for all the hypotheses embedding: shape = [batch_size, beam_size, embedding_size] embeddings of the previous time step for each hypothesis log_probs: shape = [batch_size, beam_size] log_probs of the hypotheses finished: shape = [batch_size, beam_size] boolean to know if one beam hypothesis has reached token id_end """ pass class BeamSearchDecoderOutput(collections.namedtuple( "BeamSearchDecoderOutput", ("logits", "ids", "parents"))): """Stores the logic for the beam search decoding logits: shape = [batch_size, beam_size, vocab_size] scores before softmax of the beam search hypotheses ids: shape = [batch_size, beam_size] ids of the best words at this time step parents: shape = [batch_size, beam_size] ids of the beam index from previous time step """ pass class BeamSearchDecoderCell(object): def __init__(self, embeddings, cell, batch_size, start_token, end_token, beam_size=5, div_gamma=1, div_prob=0): """Initializes parameters for Beam Search Args: embeddings: (tf.Variable) shape = (vocab_size, embedding_size) cell: instance of Cell that defines a step function, etc. batch_size: tf.int extracted with tf.Shape or int start_token: id of start token end_token: int, id of the end token beam_size: int, size of the beam div_gamma: float, amount of penalty to add to beam hypo for diversity. Coefficient of penaly will be log(div_gamma). Use value between 0 and 1. (1 means no penalty) div_prob: only apply div penalty with probability div_prob. div_prob = 0. means never apply penalty """ self._embeddings = embeddings self._cell = cell self._dim_embeddings = embeddings.shape[-1].value self._batch_size = batch_size self._start_token = start_token self._beam_size = beam_size self._end_token = end_token self._vocab_size = embeddings.shape[0].value self._div_gamma = float(div_gamma) self._div_prob = float(div_prob) @property def output_dtype(self): """Needed for custom dynamic_decode for the TensorArray of results""" return BeamSearchDecoderOutput(logits=self._cell.output_dtype, ids=tf.int32, parents=tf.int32) @property def final_output_dtype(self): """For the finalize method""" return DecoderOutput(logits=self._cell.output_dtype, ids=tf.int32) @property def state_size(self): return BeamSearchDecoderOutput( logits=tf.TensorShape([self._beam_size, self._vocab_size]), ids=tf.TensorShape([self._beam_size]), parents=tf.TensorShape([self._beam_size])) @property def final_output_size(self): return DecoderOutput(logits=tf.TensorShape([self._beam_size, self._vocab_size]), ids=tf.TensorShape([self._beam_size])) def initial_state(self): """Returns initial state for the decoder""" # cell initial state cell_state = self._cell.initial_state() cell_state = nest.map_structure(lambda t: tile_beam(t, self._beam_size), cell_state) # prepare other initial states log_probs = tf.zeros([self._batch_size, self._beam_size], dtype=self._cell.output_dtype) return BeamSearchDecoderCellState(cell_state, log_probs) def initial_inputs(self): return tf.tile(tf.reshape(self._start_token, [1, 1, self._dim_embeddings]), multiples=[self._batch_size, self._beam_size, 1]) def initialize(self): initial_state = self.initial_state() initial_inputs = self.initial_inputs() initial_finished = tf.zeros(shape=[self._batch_size, self._beam_size], dtype=tf.bool) return initial_state, initial_inputs, initial_finished def step(self, time, state, embedding, finished): """ Args: time: tensorf or int embedding: shape [batch_size, beam_size, d] state: structure of shape [bach_size, beam_size, ...] finished: structure of shape [batch_size, beam_size, ...] """ # merge batch and beam dimension before callling step of cell cell_state = nest.map_structure(merge_batch_beam, state.cell_state) embedding = merge_batch_beam(embedding) # compute new logits logits, new_cell_state = self._cell.step(embedding, cell_state) # split batch and beam dimension before beam search logic new_logits = split_batch_beam(logits, self._beam_size) new_cell_state = nest.map_structure( lambda t: split_batch_beam(t, self._beam_size), new_cell_state) # compute log probs of the step # shape = [batch_size, beam_size, vocab_size] step_log_probs = tf.nn.log_softmax(new_logits) # shape = [batch_size, beam_size, vocab_size] step_log_probs = mask_probs(step_log_probs, self._end_token, finished) # shape = [batch_size, beam_size, vocab_size] log_probs = tf.expand_dims(state.log_probs, axis=-1) + step_log_probs log_probs = add_div_penalty(log_probs, self._div_gamma, self._div_prob, self._batch_size, self._beam_size, self._vocab_size) # compute the best beams # shape = (batch_size, beam_size * vocab_size) log_probs_flat = tf.reshape(log_probs, [self._batch_size, self._beam_size * self._vocab_size]) # if time = 0, consider only one beam, otherwise beams are equal log_probs_flat = tf.cond(time > 0, lambda: log_probs_flat, lambda: log_probs[:, 0]) new_probs, indices = tf.nn.top_k(log_probs_flat, self._beam_size) # of shape [batch_size, beam_size] new_ids = indices % self._vocab_size new_parents = indices // self._vocab_size # get ids of words predicted and get embedding new_embedding = tf.nn.embedding_lookup(self._embeddings, new_ids) # compute end of beam finished = gather_helper(finished, new_parents, self._batch_size, self._beam_size) new_finished = tf.logical_or(finished, tf.equal(new_ids, self._end_token)) new_cell_state = nest.map_structure( lambda t: gather_helper(t, new_parents, self._batch_size, self._beam_size), new_cell_state) # create new state of decoder new_state = BeamSearchDecoderCellState(cell_state=new_cell_state, log_probs=new_probs) new_output = BeamSearchDecoderOutput(logits=new_logits, ids=new_ids, parents=new_parents) return (new_output, new_state, new_embedding, new_finished) def finalize(self, final_outputs, final_state): """ Args: final_outputs: structure of tensors of shape [time dimension, batch_size, beam_size, d] final_state: instance of BeamSearchDecoderOutput Returns: [time, batch, beam, ...] structure of Tensor """ # reverse the time dimension maximum_iterations = tf.shape(final_outputs.ids)[0] final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]), final_outputs) # initial states def create_ta(d): return tf.TensorArray(dtype=d, size=maximum_iterations) initial_time = tf.constant(0, dtype=tf.int32) initial_outputs_ta = nest.map_structure(create_ta, self.final_output_dtype) initial_parents = tf.tile( tf.expand_dims(tf.range(self._beam_size), axis=0), multiples=[self._batch_size, 1]) def condition(time, outputs_ta, parents): return tf.less(time, maximum_iterations) # beam search decoding cell def body(time, outputs_ta, parents): # get ids, logits and parents predicted at time step by decoder input_t = nest.map_structure(lambda t: t[time], final_outputs) # extract the entries corresponding to parents new_state = nest.map_structure( lambda t: gather_helper(t, parents, self._batch_size, self._beam_size), input_t) # create new output new_output = DecoderOutput(logits=new_state.logits, ids=new_state.ids) # write beam ids outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, new_output) return (time + 1), outputs_ta, parents res = tf.while_loop( condition, body, loop_vars=[initial_time, initial_outputs_ta, initial_parents], back_prop=False) # unfold and stack the structure from the nested tas final_outputs = nest.map_structure(lambda ta: ta.stack(), res[1]) # reverse time step final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]), final_outputs) return DecoderOutput(logits=final_outputs.logits, ids=final_outputs.ids) def sample_bernoulli(p, s): """Samples a boolean tensor with shape = s according to bernouilli""" return tf.greater(p, tf.random_uniform(s)) def add_div_penalty(log_probs, div_gamma, div_prob, batch_size, beam_size, vocab_size): """Adds penalty to beam hypothesis following this paper by Li et al. 2016 "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" Args: log_probs: (tensor of floats) shape = (batch_size, beam_size, vocab_size) div_gamma: (float) diversity parameter div_prob: (float) adds penalty with proba div_prob """ if div_gamma is None or div_prob is None: return log_probs if div_gamma == 1. or div_prob == 0.: return log_probs # 1. get indices that would sort the array top_probs, top_inds = tf.nn.top_k(log_probs, k=vocab_size, sorted=True) # 2. inverse permutation to get rank of each entry top_inds = tf.reshape(top_inds, [-1, vocab_size]) index_rank = tf.map_fn(tf.invert_permutation, top_inds, back_prop=False) index_rank = tf.reshape(index_rank, shape=[batch_size, beam_size, vocab_size]) # 3. compute penalty penalties = tf.log(div_gamma) * tf.cast(index_rank, log_probs.dtype) # 4. only apply penalty with some probability apply_penalty = tf.cast( sample_bernoulli(div_prob, [batch_size, beam_size, vocab_size]), penalties.dtype) penalties *= apply_penalty return log_probs + penalties def merge_batch_beam(t): """ Args: t: tensor of shape [batch_size, beam_size, ...] whose dimensions after beam_size must be statically known Returns: t: tensorf of shape [batch_size * beam_size, ...] """ batch_size = tf.shape(t)[0] beam_size = t.shape[1].value if t.shape.ndims == 2: return tf.reshape(t, [batch_size*beam_size, 1]) elif t.shape.ndims == 3: return tf.reshape(t, [batch_size*beam_size, t.shape[-1].value]) elif t.shape.ndims == 4: return tf.reshape(t, [batch_size*beam_size, t.shape[-2].value, t.shape[-1].value]) else: raise NotImplementedError def split_batch_beam(t, beam_size): """ Args: t: tensorf of shape [batch_size*beam_size, ...] Returns: t: tensor of shape [batch_size, beam_size, ...] """ if t.shape.ndims == 1: return tf.reshape(t, [-1, beam_size]) elif t.shape.ndims == 2: return tf.reshape(t, [-1, beam_size, t.shape[-1].value]) elif t.shape.ndims == 3: return tf.reshape(t, [-1, beam_size, t.shape[-2].value, t.shape[-1].value]) else: raise NotImplementedError def tile_beam(t, beam_size): """ Args: t: tensor of shape [batch_size, ...] Returns: t: tensorf of shape [batch_size, beam_size, ...] """ # shape = [batch_size, 1 , x] t = tf.expand_dims(t, axis=1) if t.shape.ndims == 2: multiples = [1, beam_size] elif t.shape.ndims == 3: multiples = [1, beam_size, 1] elif t.shape.ndims == 4: multiples = [1, beam_size, 1, 1] return tf.tile(t, multiples) def mask_probs(probs, end_token, finished): """ Args: probs: tensor of shape [batch_size, beam_size, vocab_size] end_token: (int) finished: tensor of shape [batch_size, beam_size], dtype = tf.bool """ # one hot of shape [vocab_size] vocab_size = probs.shape[-1].value one_hot = tf.one_hot(end_token, vocab_size, on_value=0., off_value=probs.dtype.min, dtype=probs.dtype) # expand dims of shape [batch_size, beam_size, 1] finished = tf.expand_dims(tf.cast(finished, probs.dtype), axis=-1) return (1. - finished) * probs + finished * one_hot def gather_helper(t, indices, batch_size, beam_size): """ Args: t: tensor of shape = [batch_size, beam_size, d] indices: tensor of shape = [batch_size, beam_size] Returns: new_t: tensor w shape as t but new_t[:, i] = t[:, new_parents[:, i]] """ range_ = tf.expand_dims(tf.range(batch_size) * beam_size, axis=1) indices = tf.reshape(indices + range_, [-1]) output = tf.gather( tf.reshape(t, [batch_size*beam_size, -1]), indices) if t.shape.ndims == 2: return tf.reshape(output, [batch_size, beam_size]) elif t.shape.ndims == 3: d = t.shape[-1].value return tf.reshape(output, [batch_size, beam_size, d])