# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for latent variable models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import transformer_layers
from tensor2tensor.utils import beam_search

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

DO_SUMMARIES = True


def compress_self_attention_layer(x, hparams, name=None):
  """Applies multi-head attention over `x` with pre/post-processing.

  The input is passed through `cia.maybe_reshape_4d_to_3d` before attention
  and the result is reshaped back to the shape reported by that helper.

  Args:
    x: Tensor to attend over.
    hparams: HParams. Reads attention_key_channels, attention_value_channels,
      hidden_size, num_heads and attention_dropout.
    name: string, variable scope.

  Returns:
    Tensor with the shape `x` had on entry.
  """
  with tf.variable_scope(name, default_name="compress_self_attention"):
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams),
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)


def compute_nats_and_bits_per_dim(data_dim,
                                  latent_dim,
                                  average_reconstruction,
                                  average_prior):
  """Computes negative ELBO, which is an upper bound on the negative likelihood.

  Args:
    data_dim: int-like indicating data dimensionality.
    latent_dim: int-like indicating latent dimensionality.
    average_reconstruction: Scalar Tensor indicating the reconstruction cost
      averaged over all data dimensions and any data batches.
    average_prior: Scalar Tensor indicating the negative log-prior probability
      averaged over all latent dimensions and any data batches.

  Returns:
    Tuple of scalar Tensors, representing the nats and bits per data dimension
    (e.g., subpixels) respectively.
  """
  with tf.name_scope(None, default_name="compute_nats_per_dim"):
    data_dim = tf.cast(data_dim, average_reconstruction.dtype)
    latent_dim = tf.cast(latent_dim, average_prior.dtype)
    negative_log_likelihood = data_dim * average_reconstruction
    negative_log_prior = latent_dim * average_prior
    negative_elbo = negative_log_likelihood + negative_log_prior
    nats_per_dim = tf.divide(negative_elbo, data_dim, name="nats_per_dim")
    bits_per_dim = tf.divide(nats_per_dim, tf.log(2.), name="bits_per_dim")
    return nats_per_dim, bits_per_dim


def multinomial_sample(x, vocab_size=None, sampling_method="random",
                       temperature=1.0):
  """Multinomial sampling from a n-dimensional tensor.

  Args:
    x: Tensor of shape [..., vocab_size]. Parameterizes logits of multinomial.
    vocab_size: Number of classes in multinomial distribution.
    sampling_method: String, "random" or otherwise deterministic.
    temperature: Positive float.

  Returns:
    Tensor of shape [...].
  """
  vocab_size = vocab_size or common_layers.shape_list(x)[-1]
  if sampling_method == "random" and temperature > 0.0:
    samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]) / temperature, 1)
  else:
    # Deterministic fallback: take the most likely class.
    samples = tf.argmax(x, axis=-1)
  reshaped_samples = tf.reshape(samples, common_layers.shape_list(x)[:-1])
  return reshaped_samples


def ae_latent_softmax(latents_pred, latents_discrete_hot, vocab_size, hparams):
  """Latent prediction and loss.

  Args:
    latents_pred: Tensor of shape [..., depth].
    latents_discrete_hot: Tensor of shape [..., vocab_size].
    vocab_size: an int representing the vocab size.
    hparams: HParams.

  Returns:
    sample: Tensor of shape [...], a sample from a multinomial distribution.
    loss: Tensor of shape [...], the softmax cross-entropy.
  """
  with tf.variable_scope("latent_logits"):
    latents_logits = tf.layers.dense(latents_pred, vocab_size,
                                     name="logits_dense")
    if hparams.logit_normalization:
      # Rescale logits by the inverse root-mean-square; the epsilon guards
      # against division by zero.
      latents_logits *= tf.rsqrt(1e-8 +
                                 tf.reduce_mean(tf.square(latents_logits)))
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=latents_discrete_hot, logits=latents_logits)

    # TODO(trandustin): tease this out from ae_latent_softmax.
    # we use just the loss portion to anchor prior / encoder on text.
    sample = multinomial_sample(latents_logits,
                                vocab_size,
                                hparams.sampling_method,
                                hparams.sampling_temp)
    return sample, loss


def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Samples from the latent space in the autoencoder.

  Args:
    latents_dense_in: Tensor of shape [batch, length_q, ...]. Only the shape of
      its first two dimensions are used. length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Encodings
      to attend to in decoder.
    ed: Tensor which broadcasts with shape [batch, hparams.num_heads, length_q,
      length_kv]. Encoder-decoder attention bias.
    embed: Callable which embeds discrete latent hot-vectors and a hidden size
      and returns dense vectors.
    hparams: HParams.

  Returns:
    Tensor of shape [batch, length].
  """

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(
          tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits),
          hparams.hidden_size)
      latents_pred = transformer_latent_decoder(
          latents_dense, inputs, ed, hparams, name="latent_prediction")
      logits = tf.layers.dense(
          latents_pred, 2**hparams.bottleneck_bits, name="logits_dense")
      # Only the logits at the most recently generated position are needed.
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :]
    return logits

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      1,
      length,
      2**hparams.bottleneck_bits,
      alpha=0.0,
      eos_id=-1,
      stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.


def residual_block_layer(inputs, hparams):
  """Residual block over inputs.

  Runs a residual block consisting of
    conv: kernel_size x kernel_size
    conv: 1x1
    dropout, add and normalize according to hparams.layer_postprocess_sequence.

  Args:
    inputs: Tensor of shape [batch, height, width, hparams.hidden_size].
    hparams: HParams.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  kernel = (hparams.res_kernel_size, hparams.res_kernel_size)
  x = inputs
  for i in range(hparams.num_res_layers):
    with tf.variable_scope("res_conv_%d" % i):
      # kernel_size x kernel_size conv block
      y = common_layers.conv_block(
          common_layers.layer_norm(x, hparams.hidden_size, name="lnorm"),
          hparams.hidden_size,
          [((1, 1), kernel)],
          strides=(1, 1),
          padding="SAME",
          name="residual_conv")
      # 1x1 conv block
      y = common_layers.conv_block(
          y,
          hparams.hidden_size,
          [((1, 1), (1, 1))],
          strides=(1, 1),
          padding="SAME",
          name="residual_dense")
      x = common_layers.layer_postprocess(x, y, hparams)
  return x


def compress_encoder(inputs,
                     hparams,
                     strides=(2, 2),
                     kernel_size=(3, 3),
                     name=None):
  """Encoder that compresses 2-D inputs by 2**num_compress_steps.

  Args:
    inputs: Tensor of shape [batch, height, width, channels].
    hparams: HParams.
    strides: Tuple, strides for conv block.
    kernel_size: Tuple, kernel window size for conv block.
    name: string, variable scope. When None, "compress" is used as the prefix
      for the final dense layer's name.

  Returns:
    Tensor of shape [batch, latent_length, hparams.hidden_size], where
      latent_length is
      hparams.num_latents * (height*width) / 2**(hparams.num_compress_steps).
  """
  # `name` defaults to None and `None + "_dense"` would raise a TypeError, so
  # fall back to the scope's default name when naming the dense layer.
  dense_name = (name or "compress") + "_dense"
  with tf.variable_scope(name, default_name="compress"):
    x = inputs
    for i in range(hparams.num_compress_steps // 2):
      with tf.variable_scope("compress_conv_%d" % i):
        y = common_layers.conv_block(
            common_layers.layer_norm(
                x, hparams.hidden_size, name="lnorm"),
            hparams.hidden_size,
            dilation_rates_and_kernel_sizes=[((1, 1), kernel_size)],
            strides=strides,
            padding="SAME",
            name="compress_conv_%d" % i)
        y = tf.nn.dropout(y, 1.0 - hparams.dropout)
        if hparams.do_compress_attend:
          y = compress_self_attention_layer(
              x, hparams, name="compress_selfatt_%d" % i)
          y += x
        x = y

    x = residual_block_layer(x, hparams)

    # If using multiple copies of latents, blow up the hidden size and then
    # reshape to increase by num_latents.
    shape_x = common_layers.shape_list(x)
    x = tf.layers.dense(x,
                        hparams.num_latents * hparams.hidden_size,
                        name=dense_name)
    return tf.reshape(x, [shape_x[0],
                          shape_x[1] * shape_x[2] * hparams.num_latents,
                          hparams.hidden_size])


def compress_encoder_2d(x, hparams, name=None):
  """Encoder that compresses 2-D inputs by 2**num_compress_steps.

  Args:
    x: Tensor of shape [batch, height, width, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, latent_length, hparams.hidden_size], where
      latent_length is
      hparams.num_latents * (height*width) / 2**(hparams.num_compress_steps).
  """
  return compress_encoder(
      x,
      hparams,
      strides=(2, 2),
      kernel_size=(hparams.kernel_size, hparams.kernel_size),
      name=name)


def compress_encoder_1d(x, hparams, name=None):
  """Encoder that compresses 1-D inputs by 2**num_compress_steps.

  Args:
    x: Tensor of shape [batch, length, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, latent_length, hparams.hidden_size], where
      latent_length is
      hparams.num_latents * length / 2**hparams.num_compress_steps.
  """
  # Insert a singleton axis so the 2-D encoder can be reused.
  x = tf.expand_dims(x, axis=2)
  return compress_encoder(x,
                          hparams,
                          strides=(2, 1),
                          kernel_size=(hparams.kernel_size, 1),
                          name=name)


def decompress_decoder(inputs,
                       hparams,
                       strides=(2, 2),
                       kernel=(3, 3),
                       name=None):
  """Decoder that decompresses 2-D inputs by 2**num_compress_steps.

  Args:
    inputs: Tensor of shape [batch, compress_height, compress_width, channels].
    hparams: HParams.
    strides: Tuple, strides for conv block.
    kernel: Tuple, kernel window size for conv block.
    name: string, variable scope. When None, "decompress" is used as the
      prefix for the inner layer and scope names.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  # `name` defaults to None and concatenating None with a str would raise a
  # TypeError, so fall back to the scope's default name for inner names.
  base_name = name or "decompress"
  with tf.variable_scope(name, default_name="decompress"):
    x = inputs
    x = tf.layers.dense(x, hparams.hidden_size, name=base_name + "_dense")
    x = residual_block_layer(x, hparams)
    for i in range(hparams.num_compress_steps // 2):
      # Scopes are numbered in reverse order of application.
      j = hparams.num_compress_steps // 2 - i - 1
      with tf.variable_scope(base_name + "_%d" % j):
        if hparams.do_decompress_attend:
          y = compress_self_attention_layer(
              x, hparams, name="decompress_selfatt")
          x += y
        y = tf.layers.conv2d_transpose(
            x,
            hparams.hidden_size,
            kernel,
            strides=strides,
            padding="SAME",
            activation=tf.nn.relu if i > 0 else None,
            name="decompress_conv")
        x = y
    return x


def decompress_decoder_2d(x, hparams, name=None):
  """Decoder that decompresses 2-D inputs by 2**num_compress_steps.

  Args:
    x: Tensor of shape [batch, compress_height, compress_width, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  return decompress_decoder(x,
                            hparams,
                            strides=(2, 2),
                            kernel=(hparams.kernel_size, hparams.kernel_size),
                            name=name)


def decompress_decoder_1d(x, hparams, name=None):
  """Decoder that decompresses 1-D inputs by 2**num_compress_steps.

  Args:
    x: Tensor of shape [batch, compress_length, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length, hparams.hidden_size].
  """
  # Insert a singleton axis so the 2-D decoder can be reused, then drop it.
  x = tf.expand_dims(x, axis=2)
  output = decompress_decoder(x,
                              hparams,
                              strides=(2, 1),
                              kernel=(hparams.kernel_size, 1),
                              name=name)
  return tf.squeeze(output, axis=2)


def transformer_text_encoder(inputs, target_space, hparams, name=None):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size].
    target_space: int. Used for encoding inputs under a target space id.
    hparams: HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name, default_name="transformer_text_encoder"):
    inputs = common_layers.flatten4d3d(inputs)
    [
        encoder_input,
        encoder_self_attention_bias,
        ed,
    ] = transformer_layers.transformer_prepare_encoder(
        inputs, target_space=target_space, hparams=hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer_layers.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed


def transformer_image_decoder(targets,
                              encoder_output,
                              ed_attention_bias,
                              hparams,
                              name=None):
  """Transformer image decoder over targets with local attention.

  Args:
    targets: Tensor of shape [batch, ...], and whose size is batch * height *
      width * hparams.num_channels * hparams.hidden_size.
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, height, width * hparams.num_channels,
    hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_dec"):
    batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size,
                                   hparams.img_len,
                                   hparams.img_len,
                                   hparams.num_channels * hparams.hidden_size])
    decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_decoder_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    decoder_output = tf.reshape(decoder_output,
                                [batch_size,
                                 hparams.img_len,
                                 hparams.img_len * hparams.num_channels,
                                 hparams.hidden_size])
    return decoder_output


def transformer_latent_decoder(x,
                               encoder_output,
                               ed_attention_bias,
                               hparams,
                               name=None):
  """Transformer decoder over latents using latent_attention_type.

  Args:
    x: Tensor of shape [batch, length_q, hparams.hidden_size]. length_q is the
      latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_latent_dec"):
    batch_size = common_layers.shape_list(x)[0]
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    x = tf.reshape(x, [batch_size,
                       compressed_img_len,
                       compressed_img_len * hparams.num_latents,
                       hparams.hidden_size])
    decoder_input, _, _ = cia.prepare_decoder(x, hparams)
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    decoder_output = tf.reshape(decoder_output,
                                [batch_size,
                                 compressed_img_len**2 * hparams.num_latents,
                                 hparams.hidden_size])
    return decoder_output


def bottleneck_layer(inputs, hparams, name="discrete_bottleneck"):
  """Computes latents given inputs (typically, compressed targets).

  Delegates to the `hparams.bottleneck` callable and, when DO_SUMMARIES is
  set, records a histogram of the discrete latents.

  Args:
    inputs: Tensor passed to `hparams.bottleneck` as `inputs`.
    hparams: HParams. Reads bottleneck, compress_filter_size and mode.
    name: string, forwarded to `hparams.bottleneck` as `name`.

  Returns:
    Tuple (latents_dense, latents_discrete, extra_loss, embed_fn): the first
    four of the five values returned by `hparams.bottleneck`.
  """
  [
      latents_dense,
      latents_discrete,
      extra_loss,
      embed_fn,
      _,
  ] = hparams.bottleneck(inputs=inputs,
                         filter_size=hparams.compress_filter_size,
                         name=name,
                         mode=hparams.mode)
  if DO_SUMMARIES:
    tf.summary.histogram("discrete_latents",
                         tf.reshape(latents_discrete, [-1]))
  return latents_dense, latents_discrete, extra_loss, embed_fn


def latent_prediction_model(inputs,
                            ed_attention_bias,
                            latents_discrete,
                            latents_dense,
                            hparams,
                            vocab_size=None,
                            name=None):
  """Transformer-based latent prediction model.

  It is an autoregressive decoder over latents_discrete given inputs.

  The outputs are only built when hparams.mode is not PREDICT, so this
  function must not be called in PREDICT mode.

  Args:
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Inputs to
      attend to for the decoder on latents.
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    latents_discrete: Tensor of shape [batch, length_q, vocab_size].
      One-hot latents to compute log-probability of given inputs.
    latents_dense: Tensor of shape [batch, length_q, hparams.hidden_size].
      length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    hparams: HParams.
    vocab_size: int or None. If None, it is 2**hparams.bottleneck_bits.
    name: string, variable scope.

  Returns:
    latents_pred: Tensor of shape [batch, length_q, hparams.hidden_size].
    latents_pred_loss: Tensor of shape [batch, length_q].
  """
  with tf.variable_scope(name, default_name="latent_prediction"):
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Gradients are stopped so the prediction loss does not flow back into
      # the dense latents.
      latents_pred = transformer_latent_decoder(tf.stop_gradient(latents_dense),
                                                inputs,
                                                ed_attention_bias,
                                                hparams,
                                                name)
      if vocab_size is None:
        vocab_size = 2**hparams.bottleneck_bits
      if not hparams.soft_em:
        # TODO(trandustin): latents_discrete is not one-hot from
        # discrete_bottleneck unless hparams.soft_em is True. Refactor.
        latents_discrete = tf.one_hot(latents_discrete, depth=vocab_size)
      _, latent_pred_loss = ae_latent_softmax(
          latents_pred, tf.stop_gradient(latents_discrete), vocab_size, hparams)
  return latents_pred, latent_pred_loss


def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using a Transformer decoder and a prior over latent sequences.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
    targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
      dimensions denoting sequence length.
    target_space: int. Used for encoding inputs under a target space id.
    hparams: HParams.
    cache: Tensor of shape [batch, length] or None.
    predict_mask: Tensor masking whether to use gold targets or predictions.

  Returns:
    decoder_output: Tensor of shape [batch, ..., hparams.hidden_size] presenting
      pre-logit activations. After a transformation (`top` in `T2TModel`), it is
      used with targets to compute the "training" (reconstruction) loss.
    losses: dict of str to Tensors. There are three loss terms: "extra",
      "extra_loss", and "latent_pred". The first is hard-coded to 0. The latter
      two are Tensors of shape [batch].
    cache: Tensor of shape [batch, length], either the same as cache, or newly
      computed if the cache input is None.
  """
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  # 4-D targets use the 2-D (image) path; anything else uses the 1-D path.
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  ed_attention_bias = None
  if inputs is not None:
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, name="input_encoder")

  losses = {"extra": 0.,
            "extra_loss": 0.,
            "latent_pred": 0.}
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    targets_compressed = compress_fn(targets, hparams, name="compress")

    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      scale = common_layers.inverse_exp_decay(hparams.startup_steps)
    else:
      scale = 1.0
    # Per-example 0/1 gate: 1 with probability `scale`.
    scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

    latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
        targets_compressed, hparams)
    extra_loss = scale * tf.reduce_mean(extra_loss)

    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    # The latent prediction loss only counts once the global step has passed
    # hparams.mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
    latents_pred_loss *= tf.to_float(latent_time)

    # Apply dropout noise for each data point and time step.
    latents_dense_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense,
        keep_prob=1 - hparams.latent_dropout,
        noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

    # TODO(trandustin): Can we combine extra and extra_loss?
    losses = {"extra": 0.,
              "extra_loss": extra_loss,
              "latent_pred": latents_pred_loss}
  else:
    # Set the latent length, which is num_latents times the number of latent
    # pixels. The number of latent pixels is determined by a compression factor
    # on the number of image pixels.
    # Floor division keeps this an int: with `from __future__ import division`
    # a plain `/` yields a float, which is not a valid tensor dimension.
    latent_len = ((hparams.img_len * hparams.img_len * hparams.num_latents) //
                  (2**hparams.num_compress_steps))
    # Targets are never compressed in PREDICT mode, and only `embed_fn` is
    # used from the bottleneck here, so feed an all-zeros stand-in shaped like
    # the compress encoder's output.
    # NOTE(review): assumes hparams.bottleneck accepts a
    # [batch, latent_len, hidden_size] input in PREDICT mode -- confirm.
    targets_compressed = tf.zeros(
        [batch_size, latent_len, hparams.hidden_size])
    _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed_attention_bias,
                                    embed_fn, hparams)
    cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

  if len(original_targets_shape) == 4:
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    latents_dense = tf.reshape(latents_dense,
                               [batch_size,
                                compressed_img_len,
                                compressed_img_len,
                                hparams.num_latents * hparams.hidden_size])

  latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
  latents_dense = tf.reshape(
      latents_dense,
      [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    else:
      masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
    # mask == 1 selects the gold target, mask == 0 the decompressed latent.
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    latents_dense = mask * targets + (1.0 - mask) * latents_dense

  latents_dense = tf.reshape(latents_dense, original_targets_shape)
  if hparams.decode_autoregressive:
    decoder_output = transformer_image_decoder(
        latents_dense, inputs, ed_attention_bias, hparams, name="decoder")
  else:
    decoder_output = latents_dense
  return decoder_output, losses, cache


def iaf_flow(one_hot_assignments,
             scale_weights,
             scale_bias,
             num_codes,
             summary=True,
             name=None):
  """Performs a single IAF flow using scale and normalization transformations.

  Args:
    one_hot_assignments: Assignments Tensor with shape [num_samples, batch_size,
      latent_size, num_codes].
    scale_weights: Tensor corresponding to lower triangular matrix used to
      autoregressively generate scale matrix from assignments. To ensure the
      lower-triangular matrix has length of latent_size, scale_weights should
      be a rank-one tensor with size latent_size * (latent_size + 1) / 2.
    scale_bias: Bias tensor to be added to scale tensor, with shape
      [latent_size, num_codes]. If scale weights are zero, initialize scale_bias
      to be log(exp(1.) / 2. - 1) so initial transformation is identity.
    num_codes: Number of codes in codebook.
    summary: Whether to save summaries.
    name: String used for name scope.

  Returns:
    flow_output: Transformed one-hot assignments.
    inverse_log_det_jacobian: Inverse log determinant of Jacobian corresponding
      to transformation.
  """
  with tf.name_scope(name, default_name="iaf"):
    # Pad the one_hot_assignments by zeroing out the first latent dimension and
    # shifting the rest down by one (and removing the last dimension).
    padded_assignments = tf.pad(
        one_hot_assignments, [[0, 0], [0, 0], [1, 0], [0, 0]])[:, :, :-1, :]
    scale_bijector = tfp.distributions.bijectors.Affine(
        scale_tril=tfp.math.fill_triangular(scale_weights))
    scale = scale_bijector.forward(
        tf.transpose(padded_assignments, [0, 1, 3, 2]))
    # Transpose the bijector output since it performs a batch matmul.
    scale = tf.transpose(scale, [0, 1, 3, 2])
    scale = tf.nn.softplus(scale)
    scale = scale + tf.nn.softplus(scale_bias[tf.newaxis, tf.newaxis, ...])
    # Don't need last dimension since the transformation keeps it constant.
    scale = scale[..., :-1]

    z = one_hot_assignments[..., :-1]
    unnormalized_probs = tf.concat([z * scale,
                                    one_hot_assignments[..., -1, tf.newaxis]],
                                   axis=-1)
    normalizer = tf.reduce_sum(unnormalized_probs, axis=-1)
    flow_output = unnormalized_probs / (normalizer[..., tf.newaxis])
    inverse_log_det_jacobian = (-tf.reduce_sum(tf.log(scale), axis=-1)
                                + num_codes * tf.log(normalizer))
    if summary:
      tf.summary.histogram("iaf/scale", tf.reshape(scale, [-1]))
      tf.summary.histogram("iaf/inverse_log_det_jacobian",
                           tf.reshape(inverse_log_det_jacobian, [-1]))
    return flow_output, inverse_log_det_jacobian