import tensorflow as tf


def collect_vars(fn):
  """Collect all new variables created within `fn`.

  Args:
    fn: a function that takes no arguments and creates trainable
      tf.Variable objects.

  Returns:
    outputs: the outputs of `fn()`.
    new_vars: a list of the newly created variables.
  """
  previous_vars = set(tf.trainable_variables())
  outputs = fn()
  current_vars = set(tf.trainable_variables())
  new_vars = current_vars.difference(previous_vars)
  return outputs, list(new_vars)


def get_embedding_var(params, name="input_emb", reuse=False):
  """Creates the input embedding matrix, or fetches it by name if `reuse`."""
  if reuse:
    return tf.contrib.framework.get_unique_variable("cycle_gan/" + name)
  else:
    return tf.Variable(
        tf.random_normal(
            [params.vocab_size, params.hidden_size], mean=0.0, stddev=1.0),
        name=name)


def embed_inputs(inputs, params, name="input_emb", reuse=False):
  """Looks up embedding vectors for the integer token ids in `inputs`."""
  assert inputs.dtype in (tf.int64, tf.int32), (
      "Embedding lookup indices must be of integer type.")
  w = get_embedding_var(params, name, reuse)
  return tf.gather(w, inputs)


def softmax_to_embedding(x, params):
  """Maps softmax distributions over the vocabulary to expected embeddings."""
  o_shape = x.shape.as_list()
  # Fall back to dynamic shapes for dimensions unknown at graph-build time.
  if o_shape[0] is None:
    o_shape[0] = tf.shape(x)[0]
  if o_shape[1] is None:
    o_shape[1] = tf.shape(x)[1]
  output_dist = tf.reshape(x, [o_shape[0] * o_shape[1], params.vocab_size])
  w_emb = get_embedding_var(params, reuse=True)
  # Expected embedding: a distribution-weighted average of the embedding rows.
  output = tf.matmul(output_dist, w_emb)
  output = tf.reshape(output, [o_shape[0], o_shape[1], params.hidden_size])
  return output


def construct_vocab_lookup_table(vocab):
  """Builds an index-to-string table over `vocab` with <UNK> as the default."""
  mapping_string = tf.constant(vocab)
  return tf.contrib.lookup.index_to_string_table_from_tensor(
      mapping_string, default_value="<UNK>")


def log_text(F, G, params):
  """Logs the vocabulary mappings learned by `F` and `G` as text summaries."""
  lookup_table = construct_vocab_lookup_table(params.vocab)

  X_vocab = tf.expand_dims(tf.range(params.vocab_size), axis=0)
  if params.use_embeddings:
    X = embed_inputs(X_vocab, params, reuse=True)
  else:
    X = tf.one_hot(X_vocab, depth=params.vocab_size)
  X_map_distribution = F(X, params.F, params)
  X_map_indices = tf.argmax(X_map_distribution, axis=-1)
  # X_vocab = tf.Print(X_vocab, [X_vocab], message="X_vocab", summarize=10)
  # X_map_indices = tf.Print(
  #     X_map_indices, [X_map_indices], message="X_map_indices", summarize=10)
  X_map_text = lookup_table.lookup(tf.to_int64(X_map_indices))
  X_vocab_text = lookup_table.lookup(tf.to_int64(X_vocab))
  X_text = tf.string_join([X_vocab_text, "->", X_map_text])
  tf.summary.text("F_map", X_text)

  Y_vocab = tf.expand_dims(tf.range(params.vocab_size), axis=0)
  if params.use_embeddings:
    Y = embed_inputs(Y_vocab, params, reuse=True)
  else:
    Y = tf.one_hot(Y_vocab, depth=params.vocab_size)
  Y_map_distribution = G(Y, params.G, params)
  Y_map_indices = tf.argmax(Y_map_distribution, axis=-1)
  # Y_vocab = tf.Print(Y_vocab, [Y_vocab], message="Y_vocab", summarize=10)
  # Y_map_indices = tf.Print(
  #     Y_map_indices, [Y_map_indices], message="Y_map_indices", summarize=10)
  Y_map_text = lookup_table.lookup(tf.to_int64(Y_map_indices))
  Y_vocab_text = lookup_table.lookup(tf.to_int64(Y_vocab))
  Y_text = tf.string_join([Y_vocab_text, "->", Y_map_text])
  tf.summary.text("G_map", Y_text)


def groundtruth_accuracy(A, A_groundtruth, mask):
  """Mean per-example accuracy of `A` against `A_groundtruth` under `mask`."""
  groundtruth_mask = tf.to_float(mask)
  groundtruth_equalities = tf.to_float(tf.equal(A, A_groundtruth))
  groundtruth_accs = tf.reduce_sum(
      groundtruth_equalities * groundtruth_mask, axis=1) / tf.reduce_sum(
          groundtruth_mask, axis=1)
  return tf.reduce_mean(groundtruth_accs)


def sample_along_line(A_true, A_fake, params):
  """Samples uniformly on the line segment between real and fake batches."""
  # One interpolation coefficient per example, broadcast over time and depth.
  A_unif = tf.tile(
      tf.random_uniform([params.batch_size, 1, 1]),
      [1, tf.shape(A_fake)[1], tf.shape(A_fake)[2]])
  return A_unif * A_fake + (1 - A_unif) * A_true
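# Illustrative usage of `sample_along_line` (a minimal sketch; not part of the
# training graph). The `Params` namedtuple is a hypothetical stand-in for the
# hyperparameter object used throughout this file; only `batch_size` is read
# by this function.
def _example_sample_along_line():
  import collections
  Params = collections.namedtuple("Params", ["batch_size"])
  params = Params(batch_size=4)
  # Real and generated batches of shape [batch, time, features].
  A_true = tf.random_normal([4, 7, 32])
  A_fake = tf.random_normal([4, 7, 32])
  # Each example gets a single uniform coefficient, so the result lies on the
  # straight line between the real and fake sample, which is where the
  # WGAN-GP penalty below evaluates the discriminator's gradient.
  return sample_along_line(A_true, A_fake, params)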
def wasserstein_penalty(discriminator, A_true, A_fake, params,
                        discriminator_params):
  """WGAN-GP gradient penalty on interpolates of `A_true` and `A_fake`."""
  A_interp = sample_along_line(A_true, A_fake, params)
  if params.use_embeddings:
    A_interp = softmax_to_embedding(A_interp, params)
  discrim_A_interp = discriminator(A_interp, discriminator_params, params)
  # tf.gradients returns a list with one gradient per tensor in the second
  # argument; index it directly rather than squeezing, which would also drop
  # any legitimate size-1 dimensions of the gradient itself.
  discrim_A_grads = tf.gradients(discrim_A_interp, [A_interp])[0]
  if params.original_l2:
    # Per-example gradient norm, reduced over the time and feature axes.
    l2_loss = tf.sqrt(tf.reduce_sum(discrim_A_grads**2, axis=[1, 2]))
    if params.true_lipschitz:
      # One-sided penalty: only norms exceeding 1 are penalized.
      loss = params.wasserstein_loss * tf.reduce_mean(
          tf.nn.relu(l2_loss - 1)**2)
    else:
      loss = params.wasserstein_loss * tf.reduce_mean((l2_loss - 1)**2)
  else:
    # tf.nn.l2_loss sums 0.5 * g**2 over the entire batch at once.
    loss = params.wasserstein_loss * (tf.nn.l2_loss(discrim_A_grads) - 1)**2
  return loss
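# Illustrative usage of `wasserstein_penalty` (a minimal sketch under assumed
# shapes). Both `_ExampleParams` and `toy_discriminator` are hypothetical
# stand-ins; the real discriminator and params objects live elsewhere in the
# codebase.
def _example_wasserstein_penalty():
  class _ExampleParams(object):
    batch_size = 4
    use_embeddings = False
    original_l2 = True
    true_lipschitz = False
    wasserstein_loss = 10.0  # gradient-penalty weight (lambda in WGAN-GP)

  def toy_discriminator(x, discriminator_params, params):
    # Collapses each [time, features] example to a single critic score.
    return tf.reduce_sum(x, axis=[1, 2])

  params = _ExampleParams()
  A_true = tf.random_normal([4, 7, 32])
  A_fake = tf.random_normal([4, 7, 32])
  # With original_l2=True and true_lipschitz=False this is the standard
  # WGAN-GP term: lambda * E[(||grad D(x_interp)|| - 1)**2].
  return wasserstein_penalty(toy_discriminator, A_true, A_fake, params, None)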