# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Scheduled Sampling.

This module implements scheduled sampling as described in (Bengio et al.,
2015). The entry points are two functions:

`sequential_scheduled_sampling_for_t2tmodel()`: scheduled sampling adapted to
  instances of T2TModel.

`sequential_scheduled_sampling()`: raw implementation of scheduled sampling.
  May be used independently of T2T.

**WARNING** This code is VERY slow. Its runtime is at least O(n^2) for
sequences of length n. For models with self-attention, its runtime is O(n^3).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensor2tensor.layers import common_layers
import tensorflow.compat.v1 as tf

from tensorflow.python.ops import inplace_ops  # pylint: disable=g-direct-tensorflow-import


def sequential_scheduled_sampling_for_t2tmodel(t2tmodel, features):
  """Scheduled Sampling for T2TModels.

  Args:
    t2tmodel: T2TModel instance.
    features: {str: Tensor}. Input features.

  Returns:
    ss_logits: [batch_size, seq_len, 1, 1, vocab_size].
    losses_dict: {str: scalar Tensor}. Losses to minimize.
  """
  targets = features["targets"]
  targets_size = common_layers.shape_list(targets)
  batch_size = targets_size[0]
  seq_len = targets_size[1]
  targets = tf.reshape(targets, [batch_size, seq_len])

  adapter = ScheduledSamplingAdapter(t2tmodel, features)
  ss_tokens, ss_logits, losses_dict = sequential_scheduled_sampling(
      infer_fn=adapter.infer_fn,
      mix_fn=adapter.mix_fn,
      loss_fn=adapter.loss_fn,
      targets=targets)
  _ = ss_tokens  # unused.

  targets_vocab_size = t2tmodel.problem_hparams.vocab_size["targets"]
  ss_logits = tf.reshape(
      ss_logits, [batch_size, seq_len, 1, 1, targets_vocab_size])
  return ss_logits, losses_dict
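# Illustrative usage sketch (an assumption, not exercised anywhere in this
# module): during training, a T2TModel subclass could swap its usual logits
# and losses for the scheduled-sampling versions computed above. The class
# name, the override point, and the `transformer.Transformer` base class
# below are hypothetical.
#
#   class MyScheduledSamplingTransformer(transformer.Transformer):
#
#     def model_fn(self, features):
#       if self.hparams.mode == tf.estimator.ModeKeys.TRAIN:
#         return sequential_scheduled_sampling_for_t2tmodel(self, features)
#       return super(MyScheduledSamplingTransformer, self).model_fn(features)
#
# The mixing schedule is driven by hparams.scheduled_sampling_warmup_schedule,
# hparams.scheduled_sampling_gold_mixin_prob and
# hparams.scheduled_sampling_warmup_steps (see ScheduledSamplingAdapter.mix_fn
# below).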
def sequential_scheduled_sampling(infer_fn, mix_fn, loss_fn, targets):
  """Scheduled Sampling.

  Args:
    infer_fn: Function. Computes logits for all timesteps.
    mix_fn: Function. Mixes gold and sampled tokens.
    loss_fn: Function. Computes loss between gold tokens and logits.
    targets: Tensor of shape [batch_size, seq_len]. Gold tokens.

  Returns:
    ss_tokens: Tensor of shape [batch_size, seq_len]. Scheduled sampling
      tokens.
    ss_logits: Tensor of shape [batch_size, seq_len, vocab_size]. Logits for
      the next token when conditioning on ss_tokens.
    losses_dict: {str: scalar Tensor}. Losses to optimize.
  """
  targets_shape = common_layers.shape_list(targets)
  batch_size = targets_shape[0]
  seq_len = targets_shape[1]

  if not targets.shape.is_fully_defined():
    # TODO(duckworthd): When running on GPU, I get the following error. Solve
    # it to enable use on other devices.
    #
    #   Cannot use 'Identity_186' as input to
    #   'transformer/parallel_0_7/transformer/transformer/symbol_modality_16282_512/shared/convert_gradient_to_tensor_HBc3xYw22Mw'
    #   because 'Identity_186' is in a while loop.
    raise ValueError(
        "The following code only works on TPU. As targets.shape isn't fully "
        "defined, I am assuming you are using a different device.")

  def cond_fn(i, ss_tokens):
    """True if i < seq_len."""
    _ = ss_tokens
    return i < seq_len

  def body_fn(i, ss_tokens):
    """Constructs conditioning tokens for scheduled sampling."""
    # next_token_logits depends on timesteps 0...i-1.
    #
    # [batch_size, seq_len] -> [batch_size, seq_len, vocab_size]
    ss_tokens_logits = infer_fn(ss_tokens)

    # Same as 'next_token_logits = ss_tokens_logits[:, i, :]'.
    vocab_size = common_layers.shape_list(ss_tokens_logits)[2]
    next_token_logits = tf.slice(
        ss_tokens_logits, begin=[0, i, 0], size=[batch_size, 1, vocab_size])
    next_token_logits = tf.squeeze(next_token_logits, axis=[1])

    # [batch_size, vocab_size] -> [batch_size]
    sampled_next_tokens = _sample_next_tokens(next_token_logits)

    # Same as 'gold_next_tokens = targets[:, i]'.
    gold_next_tokens = tf.slice(targets, begin=[0, i], size=[batch_size, 1])
    gold_next_tokens = tf.squeeze(gold_next_tokens, axis=[1])

    next_tokens = mix_fn(gold_next_tokens, sampled_next_tokens)
    ss_tokens = _update_timestep(ss_tokens, timestep=i, values=next_tokens)
    return i + 1, tf.stop_gradient(ss_tokens)

  # tf.while_loop() over all timesteps. Generate scheduled sampling tokens.
  i = 0
  ss_tokens = tf.zeros([batch_size, seq_len], dtype=tf.int32)
  i, ss_tokens = tf.while_loop(cond_fn, body_fn, [i, ss_tokens])

  ss_logits = infer_fn(ss_tokens)
  return ss_tokens, ss_logits, loss_fn(targets, ss_logits)


def _mix_tokens(p_sample, gold_targets, sampled_targets):
  """Interleaves sampled and gold tokens randomly.

  Args:
    p_sample: float in [0, 1]. Probability a token will come from
      'sampled_targets'. 0 means all-gold, 1 means all-sampled.
    gold_targets: Tensor. Gold token IDs.
    sampled_targets: Tensor. Sampled token IDs. Same shape as 'gold_targets'.

  Returns:
    Tensor of the same shape as 'gold_targets' containing a mix of tokens from
    'gold_targets' and 'sampled_targets'.
  """
  targets_shape = common_layers.shape_list(sampled_targets)
  return tf.where(
      tf.less(tf.random_uniform(targets_shape), p_sample),
      sampled_targets, gold_targets)


def _sample_next_tokens(logits):
  """Samples tokens for the next timestep."""
  batch_size = common_layers.shape_list(logits)[0]
  next_tokens = tf.random.categorical(logits, 1)
  next_tokens = tf.cast(next_tokens, tf.int32)
  next_tokens = tf.reshape(next_tokens, [batch_size])
  return next_tokens
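# A minimal sketch (an assumption, not used by the library) of driving
# sequential_scheduled_sampling() without a T2TModel. The embedding and
# projection variables, the fixed p_sample, and the function names here are
# stand-ins chosen for illustration only; real callers supply
# infer_fn/mix_fn/loss_fn for their own model, as ScheduledSamplingAdapter
# does further down. Assumes `targets` is an int32 Tensor with statically
# known shape [batch_size, seq_len].
def _toy_sequential_scheduled_sampling_example(targets, vocab_size=16,
                                               hidden_size=8, p_sample=0.25):
  """Runs scheduled sampling with stand-in model functions."""
  # Create variables up front so that none are created inside the
  # tf.while_loop() in sequential_scheduled_sampling().
  embeddings = tf.get_variable(
      "toy_embeddings", [vocab_size, hidden_size], dtype=tf.float32)
  projection = tf.get_variable(
      "toy_projection", [hidden_size, vocab_size], dtype=tf.float32)

  def infer_fn(partial_targets):
    # [batch_size, seq_len] -> [batch_size, seq_len, vocab_size].
    hidden = tf.nn.embedding_lookup(embeddings, partial_targets)
    return tf.tensordot(hidden, projection, axes=1)

  def mix_fn(gold_tokens, sampled_tokens):
    # Fixed mixing probability; a real schedule would warm p_sample up over
    # training steps (see inverse_decay_mix_prob() below).
    return _mix_tokens(p_sample, gold_tokens, sampled_tokens)

  def loss_fn(gold_targets, logits):
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=gold_targets, logits=logits)
    return {"training": tf.reduce_mean(xent)}

  return sequential_scheduled_sampling(infer_fn, mix_fn, loss_fn, targets)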
""" perm = range(x.shape.ndims) perm[0], perm[1] = perm[1], perm[0] x = tf.transpose(x, perm) x = inplace_ops.alias_inplace_update(x, timestep, values) x = tf.transpose(x, perm) return x def inverse_decay_mix_prob(warmup_schedule_name, p_max, num_warmup_steps): """Interpolate from 0.001 to 'p_max' over 'num_warmup_steps'.""" warmup_schedule_fn = { "exp": common_layers.inverse_exp_decay, "linear": common_layers.inverse_lin_decay, "sigmoid": common_layers.inverse_sigmoid_decay, }[warmup_schedule_name] return p_max * warmup_schedule_fn(num_warmup_steps, min_value=0.001) class ScheduledSamplingAdapter(object): """Adapts T2TModel for sequential_scheduled_sampling().""" def __init__(self, t2tmodel, features): self._t2tmodel = t2tmodel self._features = features hparams = self._t2tmodel.hparams assert hparams.mode == tf.estimator.ModeKeys.TRAIN, hparams.mode def infer_fn(self, partial_targets): """Computes logits for all timesteps. Args: partial_targets: [batch_size, seq_len]. Targets to condition on. Returns: next_token_logits: [batch_size, seq_len, vocab_size] """ batch_size, seq_len = common_layers.shape_list(partial_targets) partial_targets = tf.reshape(partial_targets, [batch_size, seq_len, 1, 1]) features = copy.copy(self._features) features["targets"] = partial_targets with tf.variable_scope(tf.get_variable_scope(), reuse=True): transformed_features = self._t2tmodel.bottom(features) with tf.variable_scope("body"): body_outputs, losses = self._t2tmodel._normalize_body_output( # pylint: disable=protected-access self._t2tmodel.body(transformed_features)) assert losses == {"extra": 0.0}, ( "Auxiliary losses are not propagated in this code. %s" % (losses,)) logits = self._t2tmodel.top(body_outputs, features) vocab_size = self._t2tmodel.problem_hparams.vocab_size["targets"] logits = tf.reshape(logits, [batch_size, seq_len, vocab_size]) return logits def mix_fn(self, gold_tokens, sampled_tokens): """Mixes gold and sampled tokens randomly.""" hparams = self._t2tmodel.hparams p_sample = inverse_decay_mix_prob( hparams.scheduled_sampling_warmup_schedule, hparams.scheduled_sampling_gold_mixin_prob, hparams.scheduled_sampling_warmup_steps) return _mix_tokens( p_sample=p_sample, gold_targets=gold_tokens, sampled_targets=sampled_tokens) def loss_fn(self, targets, logits): """Constructs loss dict. Args: targets: [batch_size, seq_len] logits: [batch_size, seq_len, vocab_size] Returns: {str: Tensor of shape []}. Losses. """ batch_size, seq_len, vocab_size = common_layers.shape_list(logits) targets = tf.reshape(targets, [batch_size, seq_len, 1, 1]) logits = tf.reshape(logits, [batch_size, seq_len, 1, 1, vocab_size]) features = copy.copy(self._features) features["targets"] = targets with tf.variable_scope(tf.get_variable_scope(), reuse=True): losses = { "training": self._t2tmodel.loss(logits, features), } return losses