from tensorflow.python.ops import array_ops
from tensorflow.contrib import rnn
from tensorflow.contrib import seq2seq
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import _compute_attention


class SyncAttentionWrapper(seq2seq.AttentionWrapper):
    """An AttentionWrapper variant that computes attention *before* the cell step.

    Unlike the stock `seq2seq.AttentionWrapper`, which attends over the
    current cell output, this wrapper queries the attention mechanism with
    the previous cell state and feeds the resulting context into the cell
    input within the same time step.
    """

    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None):
        if not isinstance(cell, (rnn.LSTMCell, rnn.GRUCell)):
            raise ValueError(
                'SyncAttentionWrapper only supports LSTMCell and GRUCell, '
                'got: {}'.format(cell))
        super(SyncAttentionWrapper, self).__init__(
            cell,
            attention_mechanism,
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history,
            cell_input_fn=cell_input_fn,
            output_attention=output_attention,
            initial_cell_state=initial_cell_state,
            name=name)

    def call(self, inputs, state):
        if not isinstance(state, seq2seq.AttentionWrapperState):
            raise TypeError(
                'Expected state to be instance of AttentionWrapperState. '
                'Received type %s instead.' % type(state))

        if self._is_multi:
            previous_alignments = state.alignments
            previous_alignment_history = state.alignment_history
        else:
            previous_alignments = [state.alignments]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_attention_states = []
        all_histories = []
        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            # Query attention with the *previous* cell state (the hidden
            # state h for LSTM) rather than the current cell output, so the
            # context vector is available as input to this step's cell call.
            if isinstance(self._cell, rnn.LSTMCell):
                rnn_cell_state = state.cell_state.h
            else:
                rnn_cell_state = state.cell_state
            attention, alignments, next_attention_state = _compute_attention(
                attention_mechanism, rnn_cell_state, previous_alignments[i],
                self._attention_layers[i] if self._attention_layers else None)
            alignment_history = previous_alignment_history[i].write(
                state.time, alignments) if self._alignment_history else ()

            all_attention_states.append(next_attention_state)
            all_alignments.append(alignments)
            all_histories.append(alignment_history)
            all_attentions.append(attention)

        # Concatenate the per-mechanism contexts and mix them into the cell
        # input (by default, concatenation) before stepping the cell.
        attention = array_ops.concat(all_attentions, 1)
        cell_inputs = self._cell_input_fn(inputs, attention)
        cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

        next_state = seq2seq.AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            attention_state=self._item_or_tuple(all_attention_states),
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(all_histories))

        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state
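

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It shows one
# plausible way to wrap an LSTMCell with SyncAttentionWrapper under TF 1.x
# contrib APIs; the dimensions and the choice of BahdanauAttention below are
# assumptions for demonstration, not taken from the source.
if __name__ == '__main__':
    import tensorflow as tf

    batch_size, max_time, enc_units, dec_units = 4, 10, 128, 256

    # Stand-in encoder outputs to attend over.
    memory = tf.random_normal([batch_size, max_time, enc_units])
    attention_mechanism = seq2seq.BahdanauAttention(
        num_units=dec_units, memory=memory)

    decoder_cell = SyncAttentionWrapper(
        rnn.LSTMCell(dec_units),
        attention_mechanism,
        attention_layer_size=enc_units)

    # One decoder step: attention is computed from the previous hidden
    # state and concatenated with step_input before the LSTM is applied.
    initial_state = decoder_cell.zero_state(batch_size, tf.float32)
    step_input = tf.random_normal([batch_size, enc_units])
    output, next_state = decoder_cell(step_input, initial_state)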