"""Class for generating sequences.

Adapted from
https://github.com/tensorflow/models/blob/master/im2txt/im2txt/inference_utils/sequence_generator.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import heapq

from .config import EOS


class Sequence(object):
    """Represents a complete or partial sequence."""

    def __init__(self, output, state, logprob, score, attention=None):
        """Initializes the Sequence.

        Args:
          output: List of word ids in the sequence.
          state: Model state after generating the previous word.
          logprob: Log-probability of the sequence.
          score: Score of the sequence.
          attention: Optional list of attention scores collected so far,
            or None when attention is not tracked.
        """
        self.output = output
        self.state = state
        self.logprob = logprob
        self.score = score
        self.attention = attention

    def __cmp__(self, other):
        """Compares Sequences by score (Python 2 only)."""
        if not isinstance(other, Sequence):
            return NotImplemented
        if self.score == other.score:
            return 0
        elif self.score < other.score:
            return -1
        else:
            return 1

    # For Python 3 compatibility (__cmp__ is deprecated).
    def __lt__(self, other):
        """Orders Sequences by score; this is what heapq and sort() use."""
        # Return NotImplemented rather than asserting: asserts vanish under
        # -O, and NotImplemented lets Python raise the usual TypeError.
        if not isinstance(other, Sequence):
            return NotImplemented
        return self.score < other.score

    # Also for Python 3 compatibility.
    def __eq__(self, other):
        """Two Sequences are equal when their scores are equal."""
        if not isinstance(other, Sequence):
            return NotImplemented
        return self.score == other.score


class TopN(object):
    """Maintains the top n elements of an incrementally provided set."""

    def __init__(self, n):
        """Creates an empty container that keeps at most `n` elements."""
        self._n = n
        # Min-heap of the retained elements; None after extract() until
        # reset() is called.
        self._data = []

    def size(self):
        """Returns the number of elements currently held."""
        assert self._data is not None
        return len(self._data)

    def push(self, x):
        """Pushes a new element, evicting the smallest one when full."""
        assert self._data is not None
        if len(self._data) < self._n:
            heapq.heappush(self._data, x)
        else:
            # Push x, then drop whichever element is now the smallest
            # (possibly x itself).
            heapq.heappushpop(self._data, x)

    def extract(self, sort=False):
        """Extracts all elements from the TopN. This is a destructive operation.

        The only method that can be called immediately after extract() is
        reset().

        Args:
          sort: Whether to return the elements in descending sorted order.

        Returns:
          A list of data; the top n elements provided to the set.
        """
        assert self._data is not None
        data = self._data
        self._data = None
        if sort:
            data.sort(reverse=True)
        return data

    def reset(self):
        """Returns the TopN to an empty state."""
        self._data = []


class SequenceGenerator(object):
    """Class to generate sequences from an image-to-text model."""

    def __init__(self,
                 decode_step,
                 eos_id=EOS,
                 beam_size=3,
                 max_sequence_length=50,
                 get_attention=False,
                 length_normalization_factor=0.0,
                 length_normalization_const=5.,
                 device_ids=None):
        """Initializes the generator.

        Args:
          decode_step: function called as
            decode_step(inputs, states, k=..., get_attention=..., ...) and
            unpacked here as (words, logprobs, new_states), each indexed
            per input entry.
          eos_id: the token number symbolizing the end of sequence.
          beam_size: Beam size to use when generating sequences.
          max_sequence_length: The maximum sequence length before stopping
            the search.
          get_attention: If True, `state.attention_score` of every state
            returned by decode_step is accumulated on each Sequence.
          length_normalization_factor: If > 0, a number x such that a
            completed sequence is scored by
            logprob / ((c + length) / (c + 1)) ** x rather than by logprob,
            where c is length_normalization_const. This changes the
            relative scores of sequences depending on their lengths
            (alpha in https://arxiv.org/abs/1609.08144).
          length_normalization_const: the constant c above; 5 in
            https://arxiv.org/abs/1609.08144
          device_ids: forwarded unchanged to decode_step on every step
            after the first one.
        """
        self.decode_step = decode_step
        self.eos_id = eos_id
        self.beam_size = beam_size
        self.max_sequence_length = max_sequence_length
        self.length_normalization_factor = length_normalization_factor
        self.length_normalization_const = length_normalization_const
        self.get_attention = get_attention
        self.device_ids = device_ids

    def beam_search(self, initial_input, initial_state=None):
        """Runs beam search sequence generation on a batch of inputs.

        Args:
          initial_input: An initial input for the model -
            list of batch size holding the first input for every entry.
          initial_state (optional): An initial state for the model -
            list of batch size holding the current state for every entry.

        Returns:
          A list of batch size holding, for every entry, the highest-scoring
          Sequence among its completed hypotheses (or among its partial
          hypotheses when none was completed).
        """
        batch_size = len(initial_input)
        partial_sequences = [TopN(self.beam_size) for _ in range(batch_size)]
        complete_sequences = [TopN(self.beam_size) for _ in range(batch_size)]

        # NOTE(review): unlike the per-step call below, this first call does
        # not forward device_ids - confirm that this is intentional.
        words, logprobs, new_state = self.decode_step(
            initial_input, initial_state,
            k=self.beam_size,
            feed_all_timesteps=True,
            get_attention=self.get_attention)
        for b in range(batch_size):
            # Create first beam_size candidate hypotheses for each entry in
            # batch
            for k in range(self.beam_size):
                if self.get_attention:
                    attention = [new_state[b].attention_score]
                else:
                    attention = None
                seq = Sequence(
                    output=initial_input[b] + [words[b][k]],
                    state=new_state[b],
                    logprob=logprobs[b][k],
                    score=logprobs[b][k],
                    attention=attention)
                partial_sequences[b].push(seq)

        # Run beam search.
        for _ in range(self.max_sequence_length - 1):
            partial_sequences_list = [p.extract() for p in partial_sequences]
            for p in partial_sequences:
                p.reset()

            # Keep a flattened list of partial hypotheses, to easily feed
            # through a model as whole batch
            flattened_partial = [
                s for sub_partial in partial_sequences_list
                for s in sub_partial]

            input_feed = [c.output for c in flattened_partial]
            state_feed = [c.state for c in flattened_partial]
            if not input_feed:
                # We have run out of partial candidates; happens when
                # beam_size=1
                break

            # Feed current hypotheses through the model, and receive new
            # outputs and states. logprobs are needed to rank hypotheses.
            # One extra candidate (beam_size + 1) is requested so that a
            # hypothesis ending in EOS can be replaced by another one.
            words, logprobs, new_states = self.decode_step(
                input_feed, state_feed,
                k=self.beam_size + 1,
                get_attention=self.get_attention,
                device_ids=self.device_ids)

            # idx walks the flattened hypotheses in the same order in which
            # they were fed to the model.
            idx = 0
            for b in range(batch_size):
                # For every entry in batch, find and trim to the most likely
                # beam_size hypotheses
                for partial in partial_sequences_list[b]:
                    state = new_states[idx]
                    if self.get_attention:
                        attention = partial.attention + \
                            [new_states[idx].attention_score]
                    else:
                        attention = None
                    k = 0
                    num_hyp = 0
                    while num_hyp < self.beam_size:
                        w = words[idx][k]
                        output = partial.output + [w]
                        logprob = partial.logprob + logprobs[idx][k]
                        score = logprob
                        k += 1
                        num_hyp += 1

                        if w.item() == self.eos_id:
                            if self.length_normalization_factor > 0:
                                norm_const = self.length_normalization_const
                                length_penalty = \
                                    (norm_const + len(output)) / (norm_const + 1)
                                # Divide out of place: `score` aliases
                                # `logprob`, and an in-place `/=` on a
                                # tensor/array would also overwrite the raw
                                # log-probability stored on the Sequence.
                                score = logprob / (
                                    length_penalty **
                                    self.length_normalization_factor)
                            beam = Sequence(output, state,
                                            logprob, score, attention)
                            complete_sequences[b].push(beam)
                            # we can fit another hypothesis as this one is
                            # over
                            num_hyp -= 1
                        else:
                            beam = Sequence(output, state,
                                            logprob, score, attention)
                            partial_sequences[b].push(beam)
                    idx += 1

        # If we have no complete sequences then fall back to the partial
        # sequences. But never output a mixture of complete and partial
        # sequences because a partial sequence could have a higher score
        # than all the complete sequences.
        for b in range(batch_size):
            if not complete_sequences[b].size():
                complete_sequences[b] = partial_sequences[b]
        seqs = [complete.extract(sort=True)[0]
                for complete in complete_sequences]
        return seqs