import tensorflow as tf
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.seq2seq import AttentionWrapper
from tensorflow.contrib.seq2seq import BasicDecoder
from tensorflow.contrib.seq2seq import TrainingHelper
from tensorflow.python.layers.core import Dense

from icecaps.estimators.abstract_recurrent_estimator import AbstractRecurrentEstimator
from icecaps.decoding.basic_decoder_custom import BasicDecoder as MMIDecoder
from icecaps.decoding.beam_search_decoder_custom import BeamSearchDecoder
from icecaps.decoding.dynamic_decoder_custom import dynamic_decode
from icecaps.util.vocabulary import Vocabulary


class Seq2SeqDecoderEstimator(AbstractRecurrentEstimator):
    """Recurrent decoder estimator for sequence-to-sequence models.

    Depending on the estimator mode the graph is built for:

    * TRAIN / EVAL: teacher-forced decoding of the "targets" feature.
    * PREDICT: beam-search decoding, or (when ``is_mmi_model`` is set)
      scoring of the provided targets with the custom MMI decoder.

    The encoder's final state is read from ``features["state"]``; when
    attention is enabled the encoder outputs and their lengths are read
    from ``features["inputs"]`` and ``features["length"]``.
    """

    def __init__(self, model_dir="/tmp", params=None, config=None, scope="", is_mmi_model=False):
        """Store the MMI flag and defer the rest to the base estimator.

        Args:
            model_dir: Directory handed to the base estimator.
            params: Hyperparameter dictionary; ``None`` means an empty dict.
            config: Run configuration handed to the base estimator.
            scope: Variable scope name used by ``_model_fn``.
            is_mmi_model: If True, PREDICT mode scores given targets instead
                of running beam search.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # between every estimator instance.
        if params is None:
            params = dict()
        # Set before the base constructor runs, in case it reads the flag.
        self.is_mmi_model = is_mmi_model
        super().__init__(model_dir, params, config, scope)

    def _model_fn(self, features, mode, params):
        """Build the graph for ``mode`` and return its ``EstimatorSpec``.

        Returns ``None`` implicitly if ``mode`` is not PREDICT, TRAIN or EVAL.
        """
        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            self.extract_args(features, mode, params)
            self.batch_size = tf.shape(self.features["inputs"])[0]
            self.build_rnn()

            if mode == tf.estimator.ModeKeys.PREDICT:
                if self.is_mmi_model:
                    self.build_inputs()
                    self.build_mmi_decoder()
                    self.predictions = {
                        "inputs": tf.convert_to_tensor(self.features["original_inputs"]),
                        "targets": tf.convert_to_tensor(self.features["targets"]),
                        "scores": self.mmi_scores,
                    }
                    export_outputs = None
                else:
                    self.build_rt_decoder()
                    self.predictions = {
                        "inputs": self.inputs_pred,
                        "outputs": self.rt_hypotheses,
                        "scores": self.scores
                    }
                    export_outputs = {
                        'predict_output': tf.estimator.export.PredictOutput(self.predictions)}
                return tf.estimator.EstimatorSpec(
                    mode, predictions=self.predictions, export_outputs=export_outputs)

            # TRAIN and EVAL share the teacher-forced decoder and the loss.
            self.build_inputs()
            self.build_train_decoder()
            self.build_loss()

            if mode == tf.estimator.ModeKeys.TRAIN:
                self.build_optimizer()
                return tf.estimator.EstimatorSpec(
                    mode, loss=self.reported_loss, train_op=self.train_op)

            if mode == tf.estimator.ModeKeys.EVAL:
                print("Number of parameters: " + str(self.get_num_model_params()))
                # Per-row sum of masked squared token differences; a row
                # counts as correct for "program_accuracy" only if it is 0.
                masked_sq_error = tf.reduce_sum(
                    tf.cast(self.target_mask, tf.int32) * tf.squared_difference(
                        self.targets_sparse, tf.cast(self.gs_hypotheses, tf.int32)),
                    -1)
                self.eval_metric_ops = {
                    "gs_token_accuracy": tf.metrics.accuracy(
                        labels=self.targets_sparse,
                        predictions=self.gs_hypotheses,
                        weights=self.target_mask),
                    "program_accuracy": tf.metrics.accuracy(
                        labels=tf.zeros([self.batch_size], dtype=tf.int32),
                        predictions=masked_sq_error),
                }
                return tf.estimator.EstimatorSpec(
                    mode, loss=self.reported_loss, eval_metric_ops=self.eval_metric_ops)

    @classmethod
    def construct_expected_params(cls):
        """Extend the base hyperparameter table with decoder-specific defaults."""
        expected_params = super().construct_expected_params()
        expected_params["use_attention"] = cls.make_param(False)
        expected_params["attention_input_feeding"] = cls.make_param(False)
        expected_params["attention_type"] = cls.make_param("luong")
        expected_params["shrink_vocab"] = cls.make_param(0)
        expected_params["repetition_allowance"] = cls.make_param(0.01)
        expected_params["repetition_penalty"] = cls.make_param(1.0)
        expected_params["post_repetition_penalty"] = cls.make_param(5.0)
        return expected_params

    def extract_args(self, features, mode, params):
        """Run the base extraction and derive the beam-search flag.

        ``self.beam_search_decoding`` is a boolean tensor that is true only
        in PREDICT mode with ``beam_width > 1`` and a non-MMI model.
        """
        super().extract_args(features, mode, params)
        self.beam_search_decoding = tf.constant(
            self.mode == tf.estimator.ModeKeys.PREDICT
            and self.hparams.beam_width > 1
            and not self.is_mmi_model)

    def build_inputs(self):
        """Prepare teacher-forcing inputs, targets and the target mask.

        The decoder input is the embedded "targets" feature with a start
        token prepended; the training target is the sparse sequence with an
        end token appended along the time axis.
        """
        super().build_inputs("targets")
        start_tokens_sparse = tf.ones(
            shape=[self.batch_size, 1], dtype=tf.int32) * self.vocab.start_token_id
        start_tokens_dense = self.embed_sparse_to_dense(start_tokens_sparse)
        self.inputs_dense = tf.concat([start_tokens_dense, self.inputs_dense], axis=1)
        # The prepended start token makes every sequence one step longer.
        self.inputs_length += 1  # [batch_size]
        self.inputs_max_length += 1
        # [batch_size, max_time_steps + 1]
        end_tokens_sparse = tf.ones(
            shape=[self.batch_size, 1], dtype=tf.int32) * self.vocab.end_token_id
        self.targets_sparse = tf.concat([self.inputs_sparse, end_tokens_sparse], axis=1)
        self.target_mask = tf.sequence_mask(
            lengths=self.inputs_length,
            maxlen=self.inputs_max_length,
            dtype=tf.float32)

    def build_attention_mechanism(self):
        """Create the attention mechanism selected by ``hparams.attention_type``.

        Returns:
            A ``LuongAttention`` or ``BahdanauAttention`` over
            ``self.feedforward_inputs``.

        Raises:
            ValueError: If the attention type is neither 'luong' nor 'bahdanau'.
        """
        if self.hparams.attention_type == 'luong':
            attention_mechanism = seq2seq.LuongAttention(
                self.hparams.hidden_units,
                self.feedforward_inputs,
                self.feedforward_inputs_length)
        elif self.hparams.attention_type == 'bahdanau':
            attention_mechanism = seq2seq.BahdanauAttention(
                self.hparams.hidden_units,
                self.feedforward_inputs,
                self.feedforward_inputs_length,)
        else:
            raise ValueError(
                "Currently, the only supported attention types are 'luong' and 'bahdanau'.")
        # The caller (build_attention_wrapper) uses the returned object.
        return attention_mechanism

    def _attention_input_feeding(self, input_feed, attention=None):
        """Cell-input callback passed to ``AttentionWrapper`` as ``cell_input_fn``.

        Args:
            input_feed: Decoder input for the current step.
            attention: Attention output from the previous step, supplied by
                ``AttentionWrapper`` as the second positional argument.

        Returns:
            ``input_feed`` unchanged when input feeding is disabled; otherwise
            a dense projection of ``input_feed`` concatenated with ``attention``.

        Raises:
            ValueError: If input feeding is enabled but no attention is given.
        """
        if not self.hparams.attention_input_feeding:
            return input_feed
        if attention is None:
            raise ValueError(
                "attention_input_feeding is enabled but no attention tensor was provided.")
        self.attention_input_layer = Dense(
            self.hparams.hidden_units, name='attention_input_layer')
        return self.attention_input_layer(tf.concat([input_feed, attention], -1))

    def build_attention_wrapper(self, final_cell):
        """Wrap ``final_cell`` in an ``AttentionWrapper`` over the encoder outputs.

        The attention memory and its lengths are tiled ``beam_width`` times
        when beam search is active so they line up with the tiled state.
        """
        self.feedforward_inputs = tf.cond(
            self.beam_search_decoding,
            lambda: seq2seq.tile_batch(
                self.features["inputs"], multiplier=self.hparams.beam_width),
            lambda: self.features["inputs"])
        self.feedforward_inputs_length = tf.cond(
            self.beam_search_decoding,
            lambda: seq2seq.tile_batch(
                self.features["length"], multiplier=self.hparams.beam_width),
            lambda: self.features["length"])
        attention_mechanism = self.build_attention_mechanism()
        return AttentionWrapper(
            cell=final_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.hparams.hidden_units,
            cell_input_fn=self._attention_input_feeding,
            initial_cell_state=self.initial_state[-1] if self.hparams.depth > 1 else self.initial_state)

    def build_rnn(self):
        """Build the initial state, embeddings, decoder cell and output layer."""
        self.initial_state = tf.cond(
            self.beam_search_decoding,
            lambda: seq2seq.tile_batch(
                self.features["state"], self.hparams.beam_width),
            lambda: self.features["state"],
            name="initial_state")
        self.build_embeddings()
        cell_list = self.build_deep_cell(return_raw_list=True)
        if self.hparams.use_attention:
            # Take dtype from the incoming state so the attention state matches it.
            state_dtype = tf.nest.flatten(self.initial_state)[0].dtype
            # With beam search the attention memory is tiled, so the wrapper
            # state must be created for batch_size * beam_width rows.
            decoder_batch_size = tf.cond(
                self.beam_search_decoding,
                lambda: self.batch_size * self.hparams.beam_width,
                lambda: self.batch_size)
            # NOTE(review): `build_attention` is not defined in this class
            # (only `build_attention_wrapper` is) -- presumably inherited;
            # confirm against AbstractRecurrentEstimator.
            final_cell = self.build_attention(cell_list[-1])
            cell_list[-1] = final_cell
            attention_state = final_cell.zero_state(
                batch_size=decoder_batch_size, dtype=state_dtype)
            if self.hparams.depth > 1:
                # Only the top layer is wrapped, so only its state is replaced.
                self.initial_state[-1] = attention_state
            else:
                self.initial_state = attention_state
        with tf.name_scope('rnn_cell'):
            self.cell = self.build_deep_cell(cell_list)
            self.output_layer = Dense(self.vocab.size(), name='output_layer')

    def build_train_decoder(self):
        """Build the teacher-forced decoder used for TRAIN and EVAL.

        Sets ``self.logits``, ``self.log_probs`` and the greedy hypotheses
        ``self.gs_hypotheses`` (argmax over the last axis).
        """
        with tf.name_scope('train_decoder'):
            training_helper = TrainingHelper(
                inputs=self.inputs_dense,
                sequence_length=self.inputs_length,
                time_major=False,
                name='training_helper')
            with tf.name_scope('basic_decoder'):
                training_decoder = BasicDecoder(
                    cell=self.cell,
                    helper=training_helper,
                    initial_state=self.initial_state,
                    output_layer=self.output_layer)
            with tf.name_scope('dynamic_decode'):
                (outputs, self.last_state, self.outputs_length) = (seq2seq.dynamic_decode(
                    decoder=training_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=self.inputs_max_length))
                self.logits = tf.identity(outputs.rnn_output)
                self.log_probs = tf.nn.log_softmax(self.logits)
                self.gs_hypotheses = tf.argmax(self.log_probs, -1)

    def build_rt_decoder(self):
        """Build the beam-search decoder used for PREDICT.

        Sets ``self.rt_hypotheses``, ``self.inputs_pred`` and ``self.scores``.
        Results are sorted by descending score and then stably by ascending
        input query id, so hypotheses are grouped per query, best first.
        """
        self.build_embeddings()
        start_tokens_sparse = tf.ones(
            shape=[self.batch_size], dtype=tf.int32) * self.vocab.start_token_id
        with tf.name_scope('beamsearch_decoder'):
            rt_decoder = BeamSearchDecoder(
                cell=self.cell,
                embedding=self.embed_sparse_to_dense,
                start_tokens=start_tokens_sparse,
                end_token=self.vocab.end_token_id,
                initial_state=self.initial_state,
                beam_width=self.hparams.beam_width,
                output_layer=self.output_layer,
                skip_tokens_decoding=self.vocab.skip_tokens,
                shrink_vocab=self.hparams.shrink_vocab)
            (hypotheses, input_query_ids, scores) = dynamic_decode(
                decoder=rt_decoder,
                output_time_major=False,
                maximum_iterations=self.hparams.max_length,
                repetition=self.hparams.repetition_penalty)

            # First pass: order everything by score, best first.
            sort_ids = tf.argsort(
                scores, direction="DESCENDING", stable=True, axis=0)
            scores = tf.gather_nd(scores, sort_ids)
            hypotheses = tf.gather_nd(hypotheses, sort_ids)
            input_query_ids = tf.gather_nd(input_query_ids, sort_ids)

            # Second pass: the stable sort by query id keeps the score order
            # within each query.
            sort_ids = tf.argsort(
                input_query_ids, direction="ASCENDING", stable=True, axis=0)
            scores = tf.gather_nd(scores, sort_ids)
            hypotheses = tf.gather_nd(hypotheses, sort_ids)
            input_query_ids = tf.gather_nd(input_query_ids, sort_ids)

            input_queries = tf.gather_nd(
                tf.convert_to_tensor(self.features["original_inputs"]),
                input_query_ids)

        self.rt_hypotheses = tf.identity(hypotheses)
        self.inputs_pred = tf.identity(input_queries)
        self.scores = tf.identity(scores)

    def build_mmi_decoder(self):
        """Score the provided targets with the custom MMI decoder.

        Sets ``self.mmi_scores`` to the per-example sum of the decoder scores
        gathered at each target token.
        """
        with tf.name_scope('mmi_scorer'):
            training_helper = TrainingHelper(
                inputs=self.inputs_dense,
                sequence_length=self.inputs_length,
                time_major=False,
                name='mmi_training_helper')
            with tf.name_scope('mmi_basic_decoder'):
                training_decoder = MMIDecoder(
                    cell=self.cell,
                    helper=training_helper,
                    initial_state=self.initial_state,
                    output_layer=self.output_layer)
            with tf.name_scope('mmi_dynamic_decoder'):
                (outputs, self.last_state, self.outputs_length) = seq2seq.dynamic_decode(
                    decoder=training_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=self.inputs_max_length)
            # Move the leading axis last so the gather below can index
            # (time step, token id, batch row).
            # NOTE(review): assumes outputs.scores is laid out as
            # (batch, time, vocab) -- confirm against MMIDecoder.
            self.scores_raw = tf.identity(
                tf.transpose(outputs.scores, [1, 2, 0]))

            targets = self.features["targets"]
            targets = tf.cast(targets, dtype=tf.int32)
            # Length = number of positions that are not the end token.
            target_len = tf.cast(
                tf.count_nonzero(targets - self.vocab.end_token_id, -1),
                dtype=tf.int32)
            max_target_len = tf.reduce_max(target_len)
            pruned_targets = tf.slice(targets, [0, 0], [-1, max_target_len])

            # Broadcast time-step and batch-row indices to
            # [batch_size, max_target_len] to pair with the token ids.
            index = tf.range(0, max_target_len, 1) * tf.ones(
                shape=[self.batch_size, 1], dtype=tf.int32)
            row_no = tf.transpose(
                tf.range(0, self.batch_size, 1) * tf.ones(
                    shape=(max_target_len, 1), dtype=tf.int32))
            indices = tf.stack([index, pruned_targets, row_no], axis=2)

            # Retrieve scores corresponding to indices
            batch_scores = tf.gather_nd(self.scores_raw, indices)
            self.mmi_scores = tf.reduce_sum(batch_scores, axis=1)