# -*- coding: utf-8 -*-
"""RNN cell construction and masked sequence-loss helpers (TensorFlow 1.x)."""

import tensorflow as tf
from tensorflow.contrib.rnn import (
    BasicRNNCell,
    GRUCell,
    LSTMCell,
    MultiRNNCell,
    RNNCell,
)
from tensorflow.contrib.rnn import DropoutWrapper


def _may_drop(prob):
    """Return True when a keep probability can actually drop units.

    Plain Python numbers are compared against 1.0.  Anything else (for
    example a tensor fed at run time) cannot be evaluated while the graph
    is being built, so it is conservatively treated as "may drop".
    """
    if isinstance(prob, (int, float)):
        return prob < 1.0
    return True


def define_rnn_cell(cell_class, num_units, num_layers=1, keep_prob=1.0,
                    input_keep_prob=None, output_keep_prob=None):
    """Build a (possibly stacked, possibly dropout-wrapped) recurrent cell.

    Args:
        cell_class: 'GRU' selects ``GRUCell``, 'LSTM' selects ``LSTMCell``;
            any other value falls back to a vanilla ``BasicRNNCell``.
        num_units: Number of hidden units in every layer.
        num_layers: Number of stacked layers; must be at least 1.
        keep_prob: Default keep probability used for whichever of
            ``input_keep_prob`` / ``output_keep_prob`` is left as ``None``.
        input_keep_prob: Keep probability applied to each layer's input.
        output_keep_prob: Keep probability applied to each layer's output.

    Returns:
        The single cell when ``num_layers == 1``, otherwise a
        ``MultiRNNCell`` wrapping all layers.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    if num_layers < 1:
        raise ValueError(
            'num_layers must be >= 1, got {!r}'.format(num_layers))

    if input_keep_prob is None:
        input_keep_prob = keep_prob
    if output_keep_prob is None:
        output_keep_prob = keep_prob

    if cell_class == 'GRU':
        cell_factory = GRUCell
    elif cell_class == 'LSTM':
        cell_factory = LSTMCell
    else:
        # RNNCell is the abstract base class and cannot be instantiated with
        # num_units; BasicRNNCell is the concrete vanilla implementation.
        cell_factory = BasicRNNCell

    # Decide from the effective probabilities, not from keep_prob alone, so
    # an explicit input/output keep probability is honoured even when
    # keep_prob itself is left at 1.0.
    use_dropout = _may_drop(input_keep_prob) or _may_drop(output_keep_prob)

    cells = []
    for _ in range(num_layers):
        cell = cell_factory(num_units=num_units)
        if use_dropout:
            cell = DropoutWrapper(cell=cell,
                                  input_keep_prob=input_keep_prob,
                                  output_keep_prob=output_keep_prob)
        cells.append(cell)

    if len(cells) > 1:
        return MultiRNNCell(cells)
    return cells[0]


def sequence_loss(num_symbols, output_logits, targets, masks):
    """Return the mask-weighted mean cross-entropy over all positions.

    The inputs are flattened, so any layout works as long as
    ``output_logits`` has ``num_symbols`` as its last dimension and
    ``targets`` / ``masks`` have one element per position.

    Args:
        num_symbols: Size of the output vocabulary.
        output_logits: Unnormalised scores, reshaped to ``[-1, num_symbols]``.
        targets: Integer class ids, reshaped to ``[-1]``.
        masks: Per-position weights, reshaped to ``[-1]``.
            NOTE(review): presumably a float tensor of 0/1 values - confirm.

    Returns:
        Scalar tensor: sum of masked losses divided by the sum of masks.
    """
    logits = tf.reshape(output_logits, [-1, num_symbols])
    local_labels = tf.reshape(targets, [-1])
    local_masks = tf.reshape(masks, [-1])

    local_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=local_labels, logits=logits)
    local_loss = local_loss * local_masks

    loss = tf.reduce_sum(local_loss)
    total_size = tf.reduce_sum(local_masks)
    total_size += 1e-12  # to avoid division by 0 for all-0 weights
    return loss / total_size


def ppx_loss(num_symbols, output_logits, targets, masks):
    """Return the mask-weighted mean negative log-probability of the targets.

    The target probability is picked out of the softmax with a one-hot
    selector summed over axis 2, so ``output_logits`` must be rank 3 with
    ``num_symbols`` last and ``targets`` rank 2.

    Args:
        num_symbols: Size of the output vocabulary.
        output_logits: Rank-3 unnormalised scores.
        targets: Rank-2 integer class ids.
        masks: Per-position weights with one element per target.

    Returns:
        Scalar tensor: masked sum of ``-log(p(target))`` divided by the
        sum of masks.
    """
    local_masks = tf.reshape(masks, [-1])
    one_hot_targets = tf.one_hot(targets, num_symbols)

    # Probability the model assigns to each target symbol.
    ppx_prob = tf.reduce_sum(
        tf.nn.softmax(output_logits) * one_hot_targets, axis=2)
    # 1e-12 keeps the log finite when the probability underflows to 0.
    neg_log_prob = tf.reshape(-tf.log(1e-12 + ppx_prob), [-1])

    masked_total = tf.reduce_sum(neg_log_prob * local_masks)
    total_size = tf.reduce_sum(local_masks)
    total_size += 1e-12  # to avoid division by 0 for all-0 weights
    return masked_total / total_size


def sentence_ppx(num_symbols, output_logits, targets, masks):
    """Return the per-sentence mean negative log-probability of the targets.

    Note that the result is the average negative log-likelihood of each
    sentence; no exponential is applied here.

    Args:
        num_symbols: Size of the output vocabulary.
        output_logits: Rank-3 unnormalised scores; the first dimension is
            used as the batch size and ``num_symbols`` must be last.
        targets: Rank-2 integer class ids.
        masks: Per-position weights, summed over axis 1 to get each
            sentence's length.

    Returns:
        Tensor with one value per batch row.
    """
    batch_size = tf.shape(output_logits)[0]
    local_masks = tf.reshape(masks, [-1])
    one_hot_targets = tf.one_hot(targets, num_symbols)

    # Log-probability the model assigns to each target symbol.
    target_log_prob = tf.reduce_sum(
        tf.nn.log_softmax(output_logits) * one_hot_targets, axis=2)
    masked_nll = tf.reshape(
        tf.reshape(-target_log_prob, [-1]) * local_masks, [batch_size, -1])

    sent_ppx = tf.reduce_sum(masked_nll, axis=1)
    sent_size = tf.reduce_sum(masks, axis=1)
    sent_size += 1e-12  # to avoid division by 0 for fully masked sentences
    return sent_ppx / sent_size