import tensorflow as tf

from storybro.generation.gpt2 import model


def penalize_used(logits, output, penalty=0.85):
    """Make tokens that were already generated less likely to be sampled again.

    Args:
        logits: rank-2 float tensor indexed as [batch, vocab].
        output: rank-2 integer tensor indexed as [batch, sequence] holding the
            token ids seen so far (context plus sampled tokens).
        penalty: factor in (0, 1]. Positive logits of used tokens are
            multiplied by it, non-positive ones are divided by it.

    Returns:
        A tensor shaped like ``logits`` in which every token id present in
        the matching row of ``output`` has been penalized.
    """
    token_ids = tf.cast(output, tf.int32)

    # Pair every token id with the index of the batch row it belongs to, so
    # that each row is penalized from its own history rather than row 0's.
    row_ids = (
        tf.zeros_like(token_ids)
        + tf.range(tf.shape(token_ids)[0])[:, tf.newaxis]
    )
    indices = tf.stack([row_ids, token_ids], axis=-1)

    # scatter_nd sums the updates of duplicate indices, so any count above
    # zero means "this token occurred at least once in this row".
    hits = tf.scatter_nd(indices, tf.ones_like(token_ids), tf.shape(logits))
    used = hits > 0

    # Scaling a negative logit by a factor below 1 would move it towards zero
    # and make the token MORE likely; dividing instead pushes it further
    # down, so the penalty lowers the logit whatever its sign.
    penalized = tf.compat.v1.where(
        logits > 0, logits * penalty, logits / penalty
    )
    return tf.compat.v1.where(used, penalized, logits)


def top_k_logits(logits, k):
    """Keep the ``k`` largest logits of every row and mask out the rest.

    Args:
        logits: rank-2 float tensor indexed as [batch, vocab].
        k: number of candidates to keep; 0 disables truncation. The Python
            check below short-circuits a plain integer 0, while the
            ``tf.cond`` handles a ``k`` that only compares equal to 0 at
            graph execution time.

    Returns:
        ``logits`` with every entry below the row's k-th largest value
        replaced by -1e10.
    """
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        # Smallest of the k retained values, kept as a column so that it
        # broadcasts against every entry of its own row.
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )

    return tf.cond(tf.equal(k, 0), lambda: logits, lambda: _top_k())


def top_p_logits(logits, p):
    """Nucleus sampling: mask logits outside the top-``p`` probability mass.

    Args:
        logits: rank-2 float tensor indexed as [batch, vocab].
        p: cumulative probability threshold.

    Returns:
        ``logits`` with every entry below the row's cut-off value replaced
        by -1e10. At least the most likely token of each row is kept.
    """
    # Dynamic batch size, so a statically unknown batch dimension works too.
    batch = tf.shape(logits)[0]
    sorted_logits = tf.sort(logits, direction="DESCENDING", axis=-1)
    cumulative_probs = tf.cumsum(
        tf.nn.softmax(sorted_logits, axis=-1), axis=-1
    )

    # Position (in sorted order) of the last token whose cumulative
    # probability is still within p; clamped to 0 so one token survives.
    last_kept = tf.maximum(
        tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1,
        0,
    )
    indices = tf.stack([tf.range(0, batch), last_kept], axis=-1)

    # Trailing axis makes the per-row threshold broadcast across the
    # vocabulary axis; without it a [batch] vector is compared along the
    # wrong axis for any batch size other than 1.
    min_values = tf.gather_nd(sorted_logits, indices)[:, tf.newaxis]
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )


def sample_sequence(
    *,
    hparams,
    length,
    start_token=None,
    batch_size=None,
    context=None,
    temperature=1,
    top_k=0,
    top_p=1
):
    """Build the graph that samples ``length`` tokens from the language model.

    Exactly one of ``start_token`` and ``context`` must be given.

    Args:
        hparams: hyper-parameter object forwarded to ``model.model`` and
            ``model.past_shape``; its ``n_vocab`` attribute bounds the logits.
        length: number of tokens to sample.
        start_token: token id used to build a one-token context for every
            batch row when no ``context`` is supplied.
        batch_size: number of sequences sampled in parallel.
        context: integer tensor of prompt token ids, indexed as
            [batch, sequence].
        temperature: divisor applied to the logits before sampling.
        top_k: forwarded to ``top_k_logits`` (0 disables it).
        top_p: forwarded to ``top_p_logits``.

    Returns:
        Integer tensor holding the context followed by the sampled tokens.
    """
    if start_token is None:
        assert context is not None, "Specify exactly one of start_token and context!"
    else:
        assert context is None, "Specify exactly one of start_token and context!"
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        """Run the model on ``tokens`` and return logits plus the new cache."""
        lm_output = model.model(
            hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE
        )

        # Drop any logits beyond the vocabulary size.
        logits = lm_output["logits"][:, :, : hparams.n_vocab]
        presents = lm_output["present"]
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size)
        )
        return {
            "logits": logits,
            "presents": presents,
        }

    with tf.name_scope("sample_sequence"):

        def body(past, prev, output):
            """Sample one token per row and extend the loop state with it."""
            next_outputs = step(hparams, prev, past=past)

            # Only the last position predicts the next token.
            logits = next_outputs["logits"][:, -1, :] / tf.cast(
                temperature, tf.float32
            )
            logits = penalize_used(logits, output)
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)
            samples = tf.multinomial(
                logits, num_samples=1, output_dtype=tf.int32
            )
            return [
                next_outputs["presents"]
                if past is None
                else tf.concat([past, next_outputs["presents"]], axis=-2),
                samples,
                tf.concat([output, samples], axis=1),
            ]

        # The first step runs outside the loop: it consumes the whole context
        # and produces the initial cache that the loop then extends.
        past, prev, output = body(None, context, context)

        def cond(*args):
            # Termination is governed solely by maximum_iterations.
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            # One token was already sampled by the call to body() above.
            maximum_iterations=length - 1,
            loop_vars=[past, prev, output],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)
                ),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens