# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modalities, which specify a feature's domain.

T2TModel applies a default transformation to each feature according to its
modality. Override these defaults by specifying a model's
hparams.{bottom,loss,top,weights_fn}.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_audio
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import common_video
from tensor2tensor.layers import discretization

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp


class ModalityType(object):
  """Types of modalities."""

  AUDIO = "audio"
  AUDIO_SPECTRAL = "audio_spectral"
  CLASS_LABEL = "class_label"
  CTC_SYMBOL = "ctc_symbol"  # symbol with CTC loss
  GENERIC_L2_LOSS = "generic_l2"  # identity modality with L2 loss
  IDENTITY = "identity"  # identity top and bottom
  IDENTITY_SYMBOL = "identity_symbol"  # symbol with identity top and bottom
  IMAGE = "image"
  # images using channel compression for the bottom, with an identity top
  IMAGE_CHANNEL_BOTTOM_IDENTITY = "image_channel_bottom_identity"
  # images using channel compression for generation
  IMAGE_CHANNEL_COMPRESS = "image_channel_compress"
  IMAGE_CHANNEL_EMBEDDINGS_BOTTOM = "image_channel_embeddings_bottom"
  MULTI_LABEL = "multi_label"
  ONE_HOT_CLASS_LABEL = "one_hot_class_label"
  REAL = "real"  # real vectors
  REAL_L2_LOSS = "real_l2"  # real vectors with L2 as loss
  # real vectors with log Poisson regression loss
  REAL_LOG_POISSON_LOSS = "real_log_poisson"
  SIGMOID_CLASS_LABEL = "sigmoid_class_label"  # sigmoid cross-entropy loss
  # sigmoid cross-entropy applied on max-pooling over timesteps
  SIGMOID_MAX_POOLING_CLASS_LABEL = "sigmoid_max_pooling_class_label"
  # softmax cross-entropy applied on average-pooling over timesteps
  SOFTMAX_AVERAGE_POOLING_CLASS_LABEL = "softmax_average_pooling_class_label"
  # softmax cross-entropy applied on last-timestep encoding
  SOFTMAX_LAST_TIMESTEP_CLASS_LABEL = "softmax_last_timestep_class_label"
  # softmax cross-entropy applied on max-pooling over timesteps
  SOFTMAX_MAX_POOLING_CLASS_LABEL = "softmax_max_pooling_class_label"
  SPEECH_RECOGNITION = "speech_recognition"
  SYMBOL = "symbol"
  SYMBOL_WEIGHTS_ALL = "symbol_weights_all"  # symbol for features w/o 0-padding
  SYMBOL_ONE_HOT = "symbol_one_hot"  # symbol with one hot as embeddings
  VIDEO = "video"
  VIDEO_BITWISE = "video_bitwise"  # video where bottom embeds pixels bitwise
  VIDEO_IDENTITY = "video_identity"  # video with identity top and bottom
  VIDEO_L1 = "video_l1"  # video with L2 loss
  VIDEO_L2 = "video_l2"  # video with L1 loss
  # video with L1 loss and raw input (sequences of frames)
  VIDEO_L1_RAW = "video_l1_raw"
  # video with L2 loss and raw input (sequences of frames)
  VIDEO_L2_RAW = "video_l2_raw"
  # video with pixel noise on input during training
  VIDEO_PIXEL_NOISE = "video_pixel_noise"

  @staticmethod
  def get_choices():
    return [
        ModalityType.AUDIO,
        ModalityType.AUDIO_SPECTRAL,
        ModalityType.CLASS_LABEL,
        ModalityType.CTC_SYMBOL,
        ModalityType.GENERIC_L2_LOSS,
        ModalityType.IDENTITY,
        ModalityType.IDENTITY_SYMBOL,
        ModalityType.IMAGE,
        ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
        ModalityType.IMAGE_CHANNEL_COMPRESS,
        ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM,
        ModalityType.MULTI_LABEL,
        ModalityType.ONE_HOT_CLASS_LABEL,
        ModalityType.REAL,
        ModalityType.REAL_L2_LOSS,
        ModalityType.REAL_LOG_POISSON_LOSS,
        ModalityType.SIGMOID_CLASS_LABEL,
        ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
        ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL,
        ModalityType.SPEECH_RECOGNITION,
        ModalityType.SYMBOL,
        ModalityType.SYMBOL_ONE_HOT,
        ModalityType.SYMBOL_WEIGHTS_ALL,
        ModalityType.VIDEO,
        ModalityType.VIDEO_BITWISE,
        ModalityType.VIDEO_IDENTITY,
        ModalityType.VIDEO_L1,
        ModalityType.VIDEO_L2,
        ModalityType.VIDEO_L1_RAW,
        ModalityType.VIDEO_L2_RAW,
        ModalityType.VIDEO_PIXEL_NOISE,
    ]
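
# Example (illustrative sketch): validating a user-supplied modality string
# against the known types before looking up its default transformations.
#
#   modality_type = "symbol"
#   if modality_type not in ModalityType.get_choices():
#     raise ValueError("Unknown modality type: %s" % modality_type)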


# Bottom transformations, applied to all features


def audio_bottom(x, model_hparams, vocab_size):
  """Transform input from data space to model space.

  Args:
    x: A Tensor with shape [batch, ...]
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    body_input: A Tensor with shape [batch, ?, ?,
      model_hparams.hidden_size].
  """
  del vocab_size  # unused arg
  inputs = x
  with tf.variable_scope("audio_modality"):
    # TODO(aidangomez): Will need to sort out a better audio pipeline
    def xnet_resblock(x, filters, res_relu, name):
      """Xception block."""
      with tf.variable_scope(name):
        # Typically audio samples are >100k samples in length and have a width
        # of 2 or 4. Mono audio has a single channel while stereo has 2.
        y = common_layers.separable_conv_block(
            x,
            filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
            first_relu=True,
            padding="SAME",
            force2d=True,
            name="sep_conv_block")
        y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 2))
        return y + common_layers.conv_block(
            x,
            filters, [((1, 1), (1, 1))],
            padding="SAME",
            strides=(2, 2),
            first_relu=res_relu,
            force2d=True,
            name="res_conv0")

    x = tf.to_float(inputs) / 255.
    x.set_shape([None, None, None, 1])
    for i in range(model_hparams.audio_compression):
      x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
    return xnet_resblock(x,
                         model_hparams.hidden_size,
                         False,
                         "compress_block_final")


def audio_spectral_bottom(x, model_hparams, vocab_size):
  """Transform input from data space to model space.

  Args:
    x: A Tensor with shape [batch, ...]
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    body_input: A Tensor with shape [batch, ?, ?,
      model_hparams.hidden_size].
  """
  del vocab_size  # unused arg
  inputs = x
  with tf.variable_scope("audio_spectral_modality"):
    # TODO(aidangomez): Will need to sort out a better audio pipeline
    def xnet_resblock(x, filters, res_relu, name):
      """Xception-like block."""
      with tf.variable_scope(name):
        # We only stride along the length dimension to preserve the spectral
        # bins (which are tiny in dimensionality relative to length)
        y = common_layers.separable_conv_block(
            x,
            filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
            first_relu=True,
            padding="SAME",
            force2d=True,
            name="sep_conv_block")
        y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
        return y + common_layers.conv_block(
            x,
            filters, [((1, 1), (1, 1))],
            padding="SAME",
            strides=(2, 1),
            first_relu=res_relu,
            force2d=True,
            name="res_conv0")

    # Bitcast back from int32
    x = tf.bitcast(inputs, tf.float32)
    x.set_shape([None, None, None, 1])
    for i in range(model_hparams.audio_compression):
      x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
    return xnet_resblock(x,
                         model_hparams.hidden_size,
                         False,
                         "compress_block_final")


def class_label_bottom(x, model_hparams, vocab_size):
  with tf.variable_scope("class_label_modality_%d_%d" % (
      vocab_size, model_hparams.hidden_size)):
    multiplier = 1.0
    if model_hparams.multiply_embedding_mode == "sqrt_depth":
      multiplier = model_hparams.hidden_size**0.5
    return common_layers.embedding(x,
                                   vocab_size,
                                   model_hparams.hidden_size,
                                   multiplier=multiplier)


def class_label_targets_bottom(x, model_hparams, vocab_size):
  with tf.variable_scope("class_label_modality_%d_%d" % (
      vocab_size, model_hparams.hidden_size)):
    return tf.zeros([common_layers.shape_list(x)[0],
                     1,
                     1,
                     model_hparams.hidden_size])


def identity_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  return tf.to_float(x)


def image_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  with tf.variable_scope("image_modality"):
    if not tf.executing_eagerly():
      tf.summary.image(
          "inputs", common_layers.tpu_safe_image_summary(x), max_outputs=2)
    return tf.to_float(x)


def image_targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for target images."""
  pixel_embedding_size = 64
  inputs = x
  with tf.variable_scope("image_modality"):
    if not tf.executing_eagerly():
      tf.summary.image(
          "targets_bottom",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=1)
    inputs_shape = common_layers.shape_list(inputs)
    if len(inputs_shape) != 4:
      raise ValueError("Assuming images given as int tensors in the format "
                       "[batch, height, width, channels] (256 values).")
    # We embed each of 256=vocab_size possible pixel values.
    embedding_var = tf.get_variable(
        "pixel_embedding",
        [vocab_size, pixel_embedding_size])
    hot_inputs = tf.one_hot(tf.to_int32(inputs), vocab_size)
    hot_inputs = tf.reshape(hot_inputs, [-1, vocab_size])
    embedded = tf.matmul(hot_inputs, embedding_var)
    # Let's now merge all channels that were embedded into a single vector.
    merged_size = pixel_embedding_size * inputs_shape[3]
    embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
    merged = tf.layers.dense(
        embedded,
        model_hparams.hidden_size,
        name="merge_pixel_embedded_channels")
    return merged
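
# Illustrative shape walkthrough for image_targets_bottom, assuming
# vocab_size=256 and 3-channel inputs of shape [batch, h, w, 3]:
#   one-hot:  [batch * h * w * 3, 256]
#   embedded: [batch * h * w * 3, 64]   (pixel_embedding_size = 64)
#   reshaped: [batch, h, w, 3 * 64]
#   merged:   [batch, h, w, hidden_size] after the dense projection.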


def _image_channel_compress_bottom(inputs, model_hparams, name="bottom"):
  """Compresses channel-wise input pixels into whole pixel representions.

  Perform conversion of RGB pixel values to a real number in the range -1 to
  1. This combines pixel channels to form a representation of shape
  [img_len, img_len].

  Args:
    inputs: Tensor representing RGB pixel intensities as integers, of shape
      [batch, img_len, img_len, channels].
    model_hparams: HParams, model hyperparameters.
    name: string, scope.

  Returns:
    body_input: Tensor of shape
      [batch, img_len, img_len, model_hparams.hidden_size].
  """
  num_channels = 3
  with tf.variable_scope(name):
    inputs = tf.to_float(inputs)
    hp = model_hparams
    if hp.mode != tf.estimator.ModeKeys.PREDICT:
      tf.summary.image(
          "inputs",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=2)
    inputs = common_layers.convert_rgb_to_symmetric_real(inputs)

    # Reshape inputs to apply convolutions across [img_len, img_len*channels].
    inputs_shape = common_layers.shape_list(inputs)
    inputs = tf.reshape(
        inputs, [-1, inputs_shape[1], inputs_shape[2] * inputs_shape[3], 1])

    # Compress RGB intensities for each pixel using a convolution.
    outputs = tf.layers.conv2d(
        inputs,
        model_hparams.hidden_size,
        kernel_size=(1, num_channels),
        padding="VALID",
        strides=(1, num_channels),
        activation=tf.nn.relu,
        name="conv_input")
    return outputs
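
# Illustrative shape walkthrough for _image_channel_compress_bottom, assuming
# RGB inputs of shape [batch, img_len, img_len, 3]:
#   after the reshape:            [batch, img_len, img_len * 3, 1]
#   after the stride-(1, 3) conv: [batch, img_len, img_len, hidden_size]
# i.e. the three channel intensities of each pixel are compressed into one
# hidden vector.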


def image_channel_compress_bottom(x, model_hparams, vocab_size):
  del vocab_size  # unused arg
  return _image_channel_compress_bottom(x, model_hparams, "input_bottom")


def image_channel_compress_targets_bottom(x, model_hparams, vocab_size):
  del vocab_size  # unused arg
  return _image_channel_compress_bottom(x, model_hparams, "output_bottom")


def image_channel_embeddings_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for image targets."""
  del vocab_size  # unused arg
  inputs = tf.to_int32(x)
  io_depth = model_hparams.num_channels
  tshape = common_layers.shape_list(inputs)
  hidden_size = model_hparams.hidden_size
  target_embeddings = cia.get_channel_embeddings(
      io_depth, inputs, hidden_size, "input_bottom")
  return tf.reshape(target_embeddings,
                    [tshape[0], tshape[1], tshape[2] * io_depth, hidden_size])


def make_targets_bottom(bottom):
  def targets_bottom(x, model_hparams, vocab_size):
    with tf.variable_scope("targets_bottom"):
      return bottom(x, model_hparams, vocab_size)
  return targets_bottom


def real_bottom(x, model_hparams, vocab_size):
  del vocab_size  # unused arg
  with tf.variable_scope("real"):
    return tf.layers.dense(
        tf.to_float(x), model_hparams.hidden_size, name="bottom")


def speech_recognition_bottom(x, model_hparams, vocab_size):
  """Use batchnorm instead of CMVN and shorten the stft with strided convs.

  Args:
    x: float32 tensor with shape [batch_size, len, 1, freqs * channels]
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    float32 tensor with shape [batch_size, shorter_len, 1, hidden_size]
  """
  del vocab_size  # unused arg
  inputs = x
  p = model_hparams

  num_mel_bins = p.audio_num_mel_bins
  num_channels = 3 if p.audio_add_delta_deltas else 1

  with tf.variable_scope("speech_recognition_modality"):
    if p.audio_preproc_in_bottom:
      # Compute filterbanks
      with tf.variable_scope("fbanks"):
        waveforms = tf.squeeze(inputs, [2, 3])
        mel_fbanks = common_audio.compute_mel_filterbank_features(
            waveforms,
            sample_rate=p.audio_sample_rate,
            dither=p.audio_dither,
            preemphasis=p.audio_preemphasis,
            frame_length=p.audio_frame_length,
            frame_step=p.audio_frame_step,
            lower_edge_hertz=p.audio_lower_edge_hertz,
            upper_edge_hertz=p.audio_upper_edge_hertz,
            num_mel_bins=p.audio_num_mel_bins,
            apply_mask=True)
        if p.audio_add_delta_deltas:
          mel_fbanks = common_audio.add_delta_deltas(mel_fbanks)
        x = tf.reshape(mel_fbanks,
                       common_layers.shape_list(mel_fbanks)[:2] +
                       [num_mel_bins, num_channels])

        nonpadding_mask = 1. - common_attention.embedding_to_padding(x)
        num_of_nonpadding_elements = tf.reduce_sum(
            nonpadding_mask) * num_mel_bins * num_channels

        # This replaces CMVN estimation on data
        var_epsilon = 1e-09
        mean = tf.reduce_sum(
            x, axis=[1], keepdims=True) / num_of_nonpadding_elements
        variance = (num_of_nonpadding_elements * mean**2. -
                    2. * mean * tf.reduce_sum(x, axis=[1], keepdims=True) +
                    tf.reduce_sum(x**2, axis=[1], keepdims=True)
                   ) / num_of_nonpadding_elements
        x = (x - mean) * tf.rsqrt(variance + var_epsilon) * tf.expand_dims(
            nonpadding_mask, -1)
    else:
      x = inputs

    # The convention is that the models are flattened along the spatial
    # dimensions; thus the speech preprocessor treats frequencies and
    # channels as image colors (the last axis).
    x.set_shape([None, None, num_mel_bins, num_channels])

    # TODO(chorowski): how to specify bottom's hparams and avoid hardcoding?
    x = tf.pad(x, [[0, 0], [0, 8], [0, 0], [0, 0]])
    for _ in range(2):
      x = tf.layers.conv2d(
          x, 128, (3, 3), (2, 2), use_bias=False)
      x = common_layers.layer_norm(x)
      x = tf.nn.relu(x)

    xshape = common_layers.shape_list(x)
    # Apply a conv that removes all frequencies and, at the same time,
    # projects the output into the desired hidden_size.
    x = tf.pad(x, [[0, 0], [0, 2], [0, 0], [0, 0]])
    x = tf.layers.conv2d(x, p.hidden_size, (3, xshape[2]), use_bias=False)

    assert common_layers.shape_list(x)[2] == 1
    x = common_layers.layer_norm(x)
    x = tf.nn.relu(x)
  return x
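
# Note (illustrative): the two stride-(2, 2) convolutions above shorten the
# time axis by roughly a factor of 4, and the final (3, xshape[2]) convolution
# collapses the remaining frequency axis to 1, which is why the output has
# shape [batch_size, shorter_len, 1, hidden_size].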


def get_weights(model_hparams, vocab_size, hidden_dim=None):
  """Create or get concatenated embedding or softmax variable.

  Args:
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.
    hidden_dim: dim of the variable. Defaults to model_hparams' hidden_size.

  Returns:
    a Tensor with shape [vocab_size, hidden_dim], the concatenation of
    num_shards weight shards.
  """
  if hidden_dim is None:
    hidden_dim = model_hparams.hidden_size
  num_shards = model_hparams.symbol_modality_num_shards
  shards = []
  for i in range(num_shards):
    shard_size = (vocab_size // num_shards) + (
        1 if i < vocab_size % num_shards else 0)
    var_name = "weights_%d" % i
    shards.append(
        tf.get_variable(
            var_name, [shard_size, hidden_dim],
            initializer=tf.random_normal_initializer(0.0, hidden_dim**-0.5)))
  if num_shards == 1:
    ret = shards[0]
  else:
    ret = tf.concat(shards, 0)
  # Convert ret to tensor.
  if not tf.executing_eagerly():
    ret = common_layers.convert_gradient_to_tensor(ret)
  return ret
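
# Illustrative sharding example: with vocab_size=10 and
# symbol_modality_num_shards=3, the shard sizes are 4, 3 and 3, and the
# returned (concatenated) variable has shape [10, hidden_dim].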


def _symbol_bottom_simple(x, model_hparams, vocab_size, name, reuse):
  """Bottom transformation for symbols."""
  with tf.variable_scope(name, reuse=reuse):
    # Ensure the inputs are 3-D
    if len(x.get_shape()) == 4:
      x = tf.squeeze(x, axis=3)
    while len(x.get_shape()) < 3:
      x = tf.expand_dims(x, axis=-1)

    var = get_weights(model_hparams, vocab_size)
    x = common_layers.dropout_no_scaling(
        x, 1.0 - model_hparams.symbol_dropout)
    ret = common_layers.gather(var, x)
    if model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= model_hparams.hidden_size**0.5
    ret *= tf.expand_dims(
        common_layers.cast_like(tf.not_equal(x, 0), ret), -1)
    return ret
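
# Illustrative note: with multiply_embedding_mode="sqrt_depth" and
# hidden_size=512, embeddings are scaled by 512**0.5 (about 22.6); the final
# multiplication by tf.not_equal(x, 0) zeroes the embeddings at positions
# holding the padding id 0, so padded positions contribute nothing.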


def symbol_bottom(x, model_hparams, vocab_size):
  if (model_hparams.shared_embedding_and_softmax_weights or
      model_hparams.get("shared_embedding")):
    return _symbol_bottom_simple(
        x, model_hparams, vocab_size, "shared", reuse=None)
  return _symbol_bottom_simple(
      x, model_hparams, vocab_size, "input_emb", reuse=None)


def symbol_targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for target symbols."""
  if (model_hparams.shared_embedding_and_softmax_weights or
      model_hparams.get("shared_embedding")):
    try:
      return _symbol_bottom_simple(
          x, model_hparams, vocab_size, "shared", reuse=True)
    except ValueError:
      # perhaps there were no inputs, and this is a new variable.
      return _symbol_bottom_simple(
          x, model_hparams, vocab_size, "shared", reuse=None)
  else:
    return _symbol_bottom_simple(
        x, model_hparams, vocab_size, "target_emb", reuse=None)


def symbol_one_hot_bottom(x, model_hparams, vocab_size):
  del model_hparams  # unused arg
  return tf.one_hot(x, vocab_size)


def video_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("inputs", x, max_outputs=1)
  x = common_layers.standardize_images(x)
  return x


def video_targets_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("targets", x, max_outputs=1)
  x = common_layers.standardize_images(x)
  return x


def video_bitwise_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for embedding video bitwise."""
  pixel_embedding_size = 64
  inputs = x
  with tf.variable_scope("video_modality_bitwise", reuse=tf.AUTO_REUSE):
    common_layers.summarize_video(inputs, "bottom")
    # Embed bitwise.
    assert vocab_size == 256
    embedded = discretization.int_to_bit_embed(inputs, 8,
                                               pixel_embedding_size)
    # Project.
    return tf.layers.dense(
        embedded,
        model_hparams.hidden_size,
        name="merge_pixel_embedded_frames")


def video_bitwise_targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for embedding target video bitwise."""
  pixel_embedding_size = 64
  inputs = x
  with tf.variable_scope("video_modality_bitwise", reuse=tf.AUTO_REUSE):
    common_layers.summarize_video(inputs, "targets_bottom")
    # Embed bitwise.
    assert vocab_size == 256
    embedded = discretization.int_to_bit_embed(inputs, 8,
                                               pixel_embedding_size)
    # Transpose and project.
    transposed = common_layers.time_to_channels(embedded)
    return tf.layers.dense(
        transposed,
        model_hparams.hidden_size,
        name="merge_pixel_embedded_frames")


def video_identity_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("inputs", x, max_outputs=1)
  return x


def video_identity_targets_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("targets", x, max_outputs=1)
  return x


def video_pixel_noise_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for video."""
  input_noise = getattr(model_hparams, "video_modality_input_noise", 0.25)
  inputs = x
  if model_hparams.mode == tf.estimator.ModeKeys.TRAIN:
    background = tfp.stats.percentile(inputs, 50., axis=[0, 1, 2, 3])
    input_shape = common_layers.shape_list(inputs)
    input_size = tf.reduce_prod(input_shape[:-1])
    input_mask = tf.multinomial(
        tf.log([[input_noise, 1.-input_noise]]), input_size)
    input_mask = tf.reshape(tf.cast(input_mask, tf.int32),
                            input_shape[:-1]+[1])
    inputs = inputs * input_mask + background * (1 - input_mask)
  return video_bottom(inputs, model_hparams, vocab_size)
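
# Illustrative note: with the default input_noise=0.25, each input pixel is
# independently replaced by the per-channel median "background" value with
# probability 0.25 during training; eval and inference leave inputs untouched.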


def convert_rgb_to_real(prediction, targets):
  """Convert prediction and target from rgb to real."""
  prediction = tf.squeeze(prediction, axis=-1)
  prediction = common_layers.convert_rgb_to_real(prediction)
  targets = common_layers.convert_rgb_to_real(targets)
  return prediction, targets


def video_raw_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("inputs", x)
  return common_layers.convert_rgb_to_real(x)


def video_raw_targets_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("targets_bottom", x)
  return common_layers.convert_rgb_to_real(x)


# Loss transformations, applied to target features


def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weight_fn):
  """Compute the CTC loss."""
  del model_hparams, vocab_size  # unused arg
  logits = top_out
  with tf.name_scope("ctc_loss", values=[logits, targets]):
    # For CTC we assume targets are 1d, [batch, length, 1, 1] here.
    targets_shape = targets.get_shape().as_list()
    assert len(targets_shape) == 4
    assert targets_shape[2] == 1
    assert targets_shape[3] == 1
    targets = tf.squeeze(targets, axis=[2, 3])
    logits = tf.squeeze(logits, axis=[2, 3])
    targets_mask = 1 - tf.to_int32(tf.equal(targets, 0))
    targets_lengths = tf.reduce_sum(targets_mask, axis=1)
    sparse_targets = tf.keras.backend.ctc_label_dense_to_sparse(
        targets, targets_lengths)
    xent = tf.nn.ctc_loss(
        sparse_targets,
        logits,
        targets_lengths,
        time_major=False,
        preprocess_collapse_repeated=False,
        ctc_merge_repeated=False)
    weights = weight_fn(targets)
    return tf.reduce_sum(xent), tf.reduce_sum(weights)


def generic_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Compute loss numerator and denominator for one shard of output."""
  del vocab_size  # unused arg
  logits = top_out
  logits = common_attention.maybe_upcast(logits, hparams=model_hparams)
  cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.0)
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      model_hparams.label_smoothing,
      cutoff=cutoff,
      weights_fn=weights_fn)


def generic_l2_loss(body_output,
                    targets,
                    model_hparams,
                    vocab_size,
                    weights_fn):
  del model_hparams, vocab_size, weights_fn  # unused arg
  loss = tf.squared_difference(body_output, tf.to_float(targets))
  return tf.reduce_mean(loss), tf.constant(1.0)


def multi_label_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Average loss over the labels."""
  del vocab_size  # unused arg
  logits = top_out
  num_labels = tf.shape(targets)[1]
  logits = tf.tile(logits, [1, num_labels, 1, 1, 1])

  xent, weights = common_layers.padded_cross_entropy(
      logits,
      targets,
      model_hparams.label_smoothing,
      weights_fn=weights_fn,
      reduce_sum=False,
  )
  xent = tf.squeeze(xent, [2, 3])
  weights = tf.squeeze(weights, [2, 3])
  # average loss over all labels
  loss = tf.reduce_sum(xent, axis=1)
  weights = tf.reduce_sum(weights, axis=1)
  loss /= (weights + 1e-8)
  weights = tf.to_float(tf.greater(weights, 0.))

  return tf.reduce_sum(loss*weights), tf.reduce_sum(weights)
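
# Illustrative note: with targets of shape [batch, num_labels, 1, 1], the
# single prediction in `logits` is tiled across the num_labels axis so every
# label is scored against the same logits; the per-example loss is the average
# cross-entropy over that example's non-padding labels.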


def one_hot_class_label_loss(top_out,
                             targets,
                             model_hparams,
                             vocab_size,
                             weights_fn):
  """Apply softmax cross-entropy between outputs and targets.

  Args:
    top_out: logits Tensor with shape [batch, ?, ?, num_classes]
    targets: one-hot encoding Tensor with shape [batch, ?, ?, num_classes]
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.
    weights_fn: Function that takes targets and returns loss weights.

  Returns:
    loss_scale (cross-entropy), loss_denom
  """
  del model_hparams, vocab_size  # unused arg
  loss_scale = tf.losses.softmax_cross_entropy(
      onehot_labels=targets, logits=top_out)
  weights = weights_fn(targets)
  loss_denom = tf.reduce_sum(weights)
  return loss_scale, loss_denom


def real_l2_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  del model_hparams, vocab_size  # unused arg
  predictions = top_out
  if (len(common_layers.shape_list(top_out)) != len(
      common_layers.shape_list(targets))):
    predictions = tf.squeeze(top_out, axis=[-1])
  with tf.name_scope("l2"):
    weights = weights_fn(targets)
    l2 = tf.pow(predictions - targets, 2)
    return tf.reduce_sum(l2 * weights), tf.reduce_sum(weights)


def real_log_poisson_loss(top_out,
                          targets,
                          model_hparams,
                          vocab_size,
                          weights_fn):
  """Poisson loss for real."""
  del model_hparams, vocab_size  # unused arg
  predictions = top_out
  if (len(common_layers.shape_list(top_out)) != len(
      common_layers.shape_list(targets))):
    predictions = tf.squeeze(top_out, axis=[-1])
  with tf.name_scope("log_possion"):
    weights = weights_fn(targets)
    lp_loss = tf.nn.log_poisson_loss(targets, predictions)
    return tf.reduce_sum(lp_loss * weights), tf.reduce_sum(weights)


def sigmoid_class_label_loss(top_out,
                             targets,
                             model_hparams,
                             vocab_size,
                             weights_fn):
  """Loss for class label."""
  # Expect inputs of size [batch-size, timesteps, 1, num-classes], where the
  # last dimension of num-classes represents logits for binary labels
  del model_hparams, vocab_size  # unused arg
  loss_scale = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=targets, logits=top_out)
  weights = weights_fn(targets)
  loss_denom = tf.reduce_sum(weights)
  return loss_scale, loss_denom


def sigmoid_max_pooling_class_label_loss(top_out,
                                         targets,
                                         model_hparams,
                                         vocab_size,
                                         weights_fn):
  """Loss for class label."""
  # Expect inputs of size [batch-size, 1, 1, num-classes], where the
  # last dimension of num-classes represents logits for binary labels
  del model_hparams, vocab_size  # unused arg
  loss_scale = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=targets, logits=top_out)
  weights = weights_fn(targets)
  loss_denom = tf.reduce_sum(weights)
  return loss_scale, loss_denom


def symbol_one_hot_loss(top_out,
                        targets,
                        model_hparams,
                        vocab_size,
                        weights_fn):
  del model_hparams, weights_fn  # unused arg
  labels = tf.one_hot(targets, vocab_size)
  loss = tf.nn.softmax_cross_entropy_with_logits(
      logits=top_out, labels=labels)
  return tf.reduce_mean(loss), tf.constant(1.0)


def video_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Compute loss numerator and denominator for one shard of output."""
  del vocab_size  # unused arg
  logits = top_out
  logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:])
  targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
  cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.01)
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      model_hparams.label_smoothing,
      cutoff=cutoff,
      weights_fn=weights_fn)


def video_identity_loss(top_out,
                        targets,
                        model_hparams,
                        vocab_size,
                        weights_fn):
  """Compute loss numerator and denominator for one shard of output."""
  del vocab_size  # unused arg
  # TODO(nikip): Try L2 loss
  logits = top_out
  logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:])
  targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
  cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.01)
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      model_hparams.label_smoothing,
      cutoff=cutoff,
      weights_fn=weights_fn)


def video_l1_internal_loss(logits, targets, model_hparams):
  cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.2)
  return tf.nn.relu(tf.abs(logits - targets) - cutoff)


def video_l1_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Compute loss numerator and denominator for one shard of output."""
  del vocab_size  # unused arg
  logits = top_out
  logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:-1])
  targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
  weights = weights_fn(targets)
  # Shift targets by 0.5 so later just casting to int gives the prediction.
  # So for int targets, say 0 and 7, we actually train to predict 0.5 and 7.5.
  # Later (in metrics or infer) this is cast to int anyway. Also, we have no
  # loss beyond cutoff = 0.2 as these are already correct predictions.
  targets = tf.to_float(targets) + 0.5
  loss = video_l1_internal_loss(logits, targets, model_hparams)
  return tf.reduce_sum(loss * weights), tf.reduce_sum(weights)


def video_l2_internal_loss(logits, targets, model_hparams):
  cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.2)
  return tf.nn.relu(
      tf.squared_difference(logits, targets) - cutoff * cutoff)


def video_l2_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Compute loss numerator and denominator for one shard of output."""
  del vocab_size  # unused arg
  logits = top_out
  logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:-1])
  targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
  weights = weights_fn(targets)
  # Shift targets by 0.5 so later just casting to int gives the prediction.
  # So for int targets, say 0 and 7, we actually train to predict 0.5 and 7.5.
  # Later (in metrics or infer) this is cast to int anyway. Also, we have no
  # loss beyond cutoff = 0.2 as these are already correct predictions.
  targets = tf.to_float(targets) + 0.5
  loss = video_l2_internal_loss(logits, targets, model_hparams)
  return tf.reduce_sum(loss * weights), tf.reduce_sum(weights)


def video_l2_raw_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  del model_hparams, vocab_size, weights_fn  # unused arg
  prediction, groundtruth = convert_rgb_to_real(top_out, targets)
  loss = tf.losses.mean_squared_error(prediction, groundtruth)
  return loss, tf.constant(1.0)


def video_l1_raw_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  del model_hparams, vocab_size, weights_fn  # unused arg
  prediction, groundtruth = convert_rgb_to_real(top_out, targets)
  loss = tf.losses.absolute_difference(prediction, groundtruth)
  return loss, tf.constant(1.0)


# Top transformations, applied to target features


def is_pointwise(func):
  """Decorator for whether the function is pointwise.

  An example of a pointwise function is a linear layer followed by
  a softmax. Given a tensor [batch, length, height, depth] it operates
  only on the last axis, on every point in [batch, length, height] fully
  independently. In contrast, a classifier that first averages over length
  and height is not pointwise, as it depends on the whole field. It is useful
  to know if top functions are pointwise to speed up decoding in certain models.

  Args:
    func: Function to decorate.

  Returns:
    Original function with an attribute pointwise set to True.
  """
  func.pointwise = True
  return func
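
# Example (illustrative): callers can check the flag with
#   getattr(top_fn, "pointwise", False)
# which is True for the decorated tops below (e.g. symbol_top, real_top) and
# False otherwise.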


def class_label_top(body_output, targets, model_hparams, vocab_size):
  """Transform inputs from model space to target space.

  Averages over the inner dimensions and applies a linear layer to get logits.

  Args:
    body_output: A Tensor with shape [batch, ?, ?, body_output_size].
    targets: Unused.
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    A Tensor with shape [batch_size, 1, 1, 1, vocab_size].
  """
  del targets  # unused arg
  with tf.variable_scope("class_label_modality_%d_%d" % (
      vocab_size, model_hparams.hidden_size)):
    x = body_output
    x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    res = tf.layers.dense(x, vocab_size)
    return tf.expand_dims(res, 3)


def identity_top(body_output, targets, model_hparams, vocab_size):
  del targets, model_hparams, vocab_size  # unused arg
  return body_output


def image_top(body_output, targets, model_hparams, vocab_size):
  """Top transformation for images."""
  del targets  # unused arg
  # TODO(lukaszkaiser): is this a universal enough way to get channels?
  num_channels = model_hparams.problem.num_channels
  with tf.variable_scope("rgb_softmax"):
    body_output_shape = common_layers.shape_list(body_output)
    reshape_shape = body_output_shape[:3]
    reshape_shape.extend([num_channels, vocab_size])
    res = tf.layers.dense(body_output, vocab_size * num_channels)
    res = tf.reshape(res, reshape_shape)
    if not tf.get_variable_scope().reuse:
      res_argmax = tf.argmax(res, axis=-1)
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(res_argmax),
          max_outputs=1)
    return res


def image_channel_compress_top(body_output, targets, model_hparams, vocab_size):
  """Transforms body output to return logits.

  Args:
    body_output: Tensor of shape [batch, img_len, img_len, depth].
    targets: Unused.
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    Tensor of shape [batch, img_len, img_len, channels, vocab_size].
  """
  del targets  # unused arg
  with tf.variable_scope("image_channel_compress_modality"):
    hidden_size = model_hparams.hidden_size
    img_len = model_hparams.img_len
    channels = 3  # RGB
    batch = common_layers.shape_list(body_output)[0]
    x = tf.layers.conv2d(
        body_output,
        hidden_size * channels,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        activation=tf.nn.relu,
        name="decompress_conv")
    x = tf.reshape(x, [batch, img_len, img_len * channels, hidden_size])
    x = common_layers.layer_preprocess(x, model_hparams)
    x = tf.layers.dense(x,
                        vocab_size,
                        use_bias=True,
                        activation=None,
                        name="output_conv")
    x = tf.reshape(
        x, [batch, img_len, img_len, channels, vocab_size])
    return x


def image_channel_embeddings_top(body_output,
                                 targets,
                                 model_hparams,
                                 vocab_size):
  """Top transformation for images."""
  del targets  # unused arg
  with tf.variable_scope("image_channel_embeddings_bottom"):
    img_len = model_hparams.img_len
    channels = model_hparams.num_channels
    x = tf.layers.dense(
        body_output, 256, use_bias=True, activation=None, name="output_conv")
    x = tf.reshape(x,
                   [-1, img_len, img_len, channels, vocab_size])
    return x


@is_pointwise
def real_top(body_output, targets, model_hparams, vocab_size):
  del targets, model_hparams  # unused arg
  with tf.variable_scope("real"):
    return tf.layers.dense(body_output, vocab_size, name="top")


def sigmoid_max_pooling_class_label_top(body_output,
                                        targets,
                                        model_hparams,
                                        vocab_size):
  """Transform inputs from model space to target space.

  Max-pools over timesteps and applies a linear layer to get logits.

  Args:
    body_output: A Tensor with shape [batch, timesteps, 1, body_output_size].
    targets: Unused.
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    A Tensor with shape [batch_size, 1, 1, vocab_size].
  """
  del targets  # unused arg
  with tf.variable_scope(
      "sigmoid_max_pooling_class_symbol_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)):
    x = body_output
    x = tf.reduce_max(x, axis=1, keepdims=True)
    return tf.layers.dense(x, vocab_size)


def softmax_average_pooling_class_label_top(body_output,
                                            targets,
                                            model_hparams,
                                            vocab_size):
  """Loss for class label."""
  del targets  # unused arg
  with tf.variable_scope(
      "softmax_average_pooling_onehot_class_label_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)):
    x = body_output
    x = tf.reduce_mean(x, axis=1, keepdims=True)
    return tf.layers.dense(x, vocab_size)


def softmax_last_timestep_class_label_top(body_output,
                                          targets,
                                          model_hparams,
                                          vocab_size):
  """Loss for class label."""
  del targets  # unused arg
  with tf.variable_scope(
      "softmax_last_timestep_onehot_class_label_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)):
    x = body_output
    x = tf.expand_dims(x[:, -1], 1)  # Pick the last timestep
    return tf.layers.dense(x, vocab_size)


def softmax_max_pooling_class_label_top(body_output,
                                        targets,
                                        model_hparams,
                                        vocab_size):
  """Loss for class label."""
  del targets  # unused arg
  with tf.variable_scope(
      "softmax_max_pooling_onehot_class_label_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)):
    x = body_output
    x = tf.reduce_max(x, axis=1, keepdims=True)
    return tf.layers.dense(x, vocab_size)


@is_pointwise
def symbol_top(body_output, targets, model_hparams, vocab_size):
  """Generate logits.

  Args:
    body_output: A Tensor with shape
      [batch, p0, p1, model_hparams.hidden_size].
    targets: Unused.
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    logits: A Tensor with shape [batch, p0, p1, ?, vocab_size].
  """
  del targets  # unused arg
  if model_hparams.shared_embedding_and_softmax_weights:
    scope_name = "shared"
    reuse = tf.AUTO_REUSE
  else:
    scope_name = "softmax"
    reuse = False
  with tf.variable_scope(scope_name, reuse=reuse):
    body_output_shape = common_layers.shape_list(body_output)
    var = get_weights(model_hparams, vocab_size, body_output_shape[-1])
    if (model_hparams.factored_logits and
        model_hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # insert channels dimension
      body_output = tf.expand_dims(body_output, 3)
      return common_layers.FactoredTensor(body_output, var)
    else:
      body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
      logits = tf.matmul(body_output, var, transpose_b=True)
      return tf.reshape(logits,
                        body_output_shape[:-1] + [1, vocab_size])


@is_pointwise
def symbol_one_hot_top(body_output, targets, model_hparams, vocab_size):
  del targets, model_hparams, vocab_size  # unused arg
  return body_output


def video_top(body_output, targets, model_hparams, vocab_size):
  """Top transformation for video."""
  del targets  # unused arg
  num_channels = model_hparams.problem.num_channels
  shape = common_layers.shape_list(body_output)
  reshape_shape = shape[:-1] + [num_channels, vocab_size]
  res = tf.reshape(body_output, reshape_shape)
  # Calculate argmax so as to have a summary with the produced images.
  x = tf.argmax(tf.reshape(res, [-1, vocab_size]), axis=-1)
  x = tf.reshape(x, shape[:-1] + [num_channels])
  common_video.gif_summary("results", x, max_outputs=1)
  return res


def video_l1_top(body_output, targets, model_hparams, vocab_size):
  """Top transformation for video."""
  del targets, vocab_size  # unused arg
  num_channels = model_hparams.problem.num_channels
  num_frames = model_hparams.video_num_target_frames
  with tf.variable_scope("rgb"):
    body_output_shape = common_layers.shape_list(body_output)
    res = tf.layers.dense(body_output, num_channels * num_frames, name="cast")
    res = tf.reshape(res, body_output_shape[:3] + [num_channels, num_frames])
    res = tf.transpose(res, [0, 4, 1, 2, 3])  # Move frames next to batch.
    if not tf.get_variable_scope().reuse:
      res_argmax = res[:, -1, :, :, :]
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(res_argmax),
          max_outputs=1)
    return tf.expand_dims(res, axis=-1)  # Add an axis like in perplexity.


def video_raw_top(body_output, targets, model_hparams, vocab_size):
  del targets, model_hparams, vocab_size  # unused arg
  frames = body_output
  if isinstance(body_output, list):
    frames = tf.stack(body_output, axis=1)
  rgb_frames = common_layers.convert_real_to_rgb(frames)
  common_video.gif_summary("body_output", rgb_frames)
  return tf.expand_dims(rgb_frames, axis=-1)


# Utility functions similar to tf.keras for default transformations


def get_bottom(modality_type, value=None):
  """Gets default bottom transformation; if none available, return value."""
  if modality_type == ModalityType.AUDIO:
    return audio_bottom
  elif modality_type == ModalityType.AUDIO_SPECTRAL:
    return audio_spectral_bottom
  elif modality_type in (ModalityType.CLASS_LABEL,
                         ModalityType.MULTI_LABEL,
                         ModalityType.ONE_HOT_CLASS_LABEL,
                         ModalityType.SIGMOID_CLASS_LABEL,
                         ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,
                         ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
                         ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
                         ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL):
    return class_label_bottom
  elif modality_type in (ModalityType.CTC_SYMBOL,
                         ModalityType.SYMBOL,
                         ModalityType.SYMBOL_WEIGHTS_ALL):
    return symbol_bottom
  elif modality_type in (ModalityType.GENERIC_L2_LOSS,
                         ModalityType.IDENTITY,
                         ModalityType.IDENTITY_SYMBOL,
                         ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM):
    return identity_bottom
  elif modality_type == ModalityType.IMAGE:
    return image_bottom
  elif modality_type in (ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
                         ModalityType.IMAGE_CHANNEL_COMPRESS):
    return image_channel_compress_bottom
  elif modality_type in (ModalityType.REAL,
                         ModalityType.REAL_L2_LOSS,
                         ModalityType.REAL_LOG_POISSON_LOSS):
    return real_bottom
  elif modality_type == ModalityType.SPEECH_RECOGNITION:
    return speech_recognition_bottom
  elif modality_type == ModalityType.SYMBOL_ONE_HOT:
    return symbol_one_hot_bottom
  elif modality_type in (ModalityType.VIDEO,
                         ModalityType.VIDEO_L1,
                         ModalityType.VIDEO_L2):
    return video_bottom
  elif modality_type == ModalityType.VIDEO_BITWISE:
    return video_bitwise_bottom
  elif modality_type == ModalityType.VIDEO_IDENTITY:
    return video_identity_bottom
  elif modality_type in (ModalityType.VIDEO_L1_RAW,
                         ModalityType.VIDEO_L2_RAW):
    return video_raw_bottom
  elif modality_type == ModalityType.VIDEO_PIXEL_NOISE:
    return video_pixel_noise_bottom
  return value


def get_loss(modality_type, value=None):
  """Gets default loss transformation; if none available, return value."""
  if modality_type in (ModalityType.AUDIO,
                       ModalityType.AUDIO_SPECTRAL,
                       ModalityType.CLASS_LABEL,
                       ModalityType.IDENTITY,
                       ModalityType.IDENTITY_SYMBOL,
                       ModalityType.IMAGE,
                       ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
                       ModalityType.IMAGE_CHANNEL_COMPRESS,
                       ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM,
                       ModalityType.REAL,
                       ModalityType.SPEECH_RECOGNITION,
                       ModalityType.SYMBOL,
                       ModalityType.SYMBOL_WEIGHTS_ALL):
    return generic_loss
  elif modality_type == ModalityType.CTC_SYMBOL:
    return ctc_symbol_loss
  elif modality_type == ModalityType.GENERIC_L2_LOSS:
    return generic_l2_loss
  elif modality_type == ModalityType.MULTI_LABEL:
    return multi_label_loss
  elif modality_type in (ModalityType.ONE_HOT_CLASS_LABEL,
                         ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
                         ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
                         ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL):
    return one_hot_class_label_loss
  elif modality_type == ModalityType.REAL_L2_LOSS:
    return real_l2_loss
  elif modality_type == ModalityType.REAL_LOG_POISSON_LOSS:
    return real_log_poisson_loss
  elif modality_type == ModalityType.SIGMOID_CLASS_LABEL:
    return sigmoid_class_label_loss
  elif modality_type == ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL:
    return sigmoid_max_pooling_class_label_loss
  elif modality_type == ModalityType.SYMBOL_ONE_HOT:
    return symbol_one_hot_loss
  elif modality_type in (ModalityType.VIDEO,
                         ModalityType.VIDEO_BITWISE,
                         ModalityType.VIDEO_PIXEL_NOISE):
    return video_loss
  elif modality_type == ModalityType.VIDEO_IDENTITY:
    return video_identity_loss
  elif modality_type == ModalityType.VIDEO_L1:
    return video_l1_loss
  elif modality_type == ModalityType.VIDEO_L1_RAW:
    return video_l1_raw_loss
  elif modality_type == ModalityType.VIDEO_L2:
    return video_l2_loss
  elif modality_type == ModalityType.VIDEO_L2_RAW:
    return video_l2_raw_loss
  return value


def get_name(modality_type, value=None):
  """Gets default name for transformations; if none available, return value."""
  # For legacy reasons, modalities vary in their naming scheme. Future plans are
  # to remove any need for get_name. We do not recommend using it.
  if modality_type == ModalityType.AUDIO:
    return lambda model_hparams, vocab_size: "audio_modality"
  elif modality_type == ModalityType.AUDIO_SPECTRAL:
    return lambda model_hparams, vocab_size: "audio_spectral_modality"
  elif modality_type == ModalityType.GENERIC_L2_LOSS:
    return lambda model_hparams, vocab_size: "generic_l2_loss_modality"
  elif modality_type == ModalityType.IDENTITY:
    return lambda model_hparams, vocab_size: "identity_modality"
  elif modality_type == ModalityType.IMAGE:
    return lambda model_hparams, vocab_size: "image_modality"
  elif modality_type == ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY:
    return (lambda model_hparams, vocab_size:  # pylint: disable=g-long-lambda
            "image_channel_bottom_identity_modality")
  elif modality_type == ModalityType.IMAGE_CHANNEL_COMPRESS:
    return lambda model_hparams, vocab_size: "image_channel_compress_modality"
  elif modality_type == ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM:
    return lambda model_hparams, vocab_size: "image_channel_embeddings_bottom"
  elif modality_type == ModalityType.REAL:
    return lambda model_hparams, vocab_size: "real_modality"
  elif modality_type == ModalityType.REAL_L2_LOSS:
    return lambda model_hparams, vocab_size: "real_l2_loss_modality"
  elif modality_type == ModalityType.REAL_LOG_POISSON_LOSS:
    return lambda model_hparams, vocab_size: "real_log_poisson_loss_modality"
  elif modality_type == ModalityType.SPEECH_RECOGNITION:
    return lambda model_hparams, vocab_size: "speech_recognition_modality"
  elif modality_type == ModalityType.VIDEO:
    return lambda model_hparams, vocab_size: "video_modality"
  elif modality_type == ModalityType.VIDEO_BITWISE:
    return lambda model_hparams, vocab_size: "video_modality_bitwise"
  elif modality_type == ModalityType.VIDEO_IDENTITY:
    return lambda model_hparams, vocab_size: "video_modality_identity"
  elif modality_type == ModalityType.VIDEO_L1:
    return lambda model_hparams, vocab_size: "video_modality_l1"
  elif modality_type == ModalityType.VIDEO_L1_RAW:
    return lambda model_hparams, vocab_size: "video_modality_l1_raw"
  elif modality_type == ModalityType.VIDEO_L2:
    return lambda model_hparams, vocab_size: "video_modality_l2"
  elif modality_type == ModalityType.VIDEO_L2_RAW:
    return lambda model_hparams, vocab_size: "video_modality_l2_raw"
  elif modality_type == ModalityType.VIDEO_PIXEL_NOISE:
    return lambda model_hparams, vocab_size: "video_modality_pixel_noise"
  elif modality_type in (ModalityType.CLASS_LABEL,
                         ModalityType.MULTI_LABEL,
                         ModalityType.ONE_HOT_CLASS_LABEL):
    def name(model_hparams, vocab_size):
      return "class_label_modality_%d_%d" % (vocab_size,
                                             model_hparams.hidden_size)
    return name
  elif modality_type in (ModalityType.CTC_SYMBOL,
                         ModalityType.IDENTITY_SYMBOL,
                         ModalityType.SYMBOL,
                         ModalityType.SYMBOL_WEIGHTS_ALL,
                         ModalityType.SYMBOL_ONE_HOT):
    def name(model_hparams, vocab_size):
      return "symbol_modality_%d_%d" % (vocab_size, model_hparams.hidden_size)
    return name
  elif modality_type == ModalityType.SIGMOID_CLASS_LABEL:
    def name(model_hparams, vocab_size):
      return "sigmoid_class_symbol_modality_%d_%d" % (vocab_size,
                                                      model_hparams.hidden_size)
    return name
  elif modality_type == ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL:
    def name(model_hparams, vocab_size):
      return "sigmoid_max_pooling_class_symbol_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)
    return name
  elif modality_type == ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL:
    def name(model_hparams, vocab_size):
      return "softmax_average_pooling_onehot_class_label_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)
    return name
  elif modality_type == ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL:
    def name(model_hparams, vocab_size):
      return "softmax_last_timestep_onehot_class_label_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)
    return name
  elif modality_type == ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL:
    def name(model_hparams, vocab_size):
      return "softmax_max_pooling_onehot_class_label_modality_%d_%d" % (
          vocab_size, model_hparams.hidden_size)
    return name
  return value


def get_targets_bottom(modality_type, value=None):
  """Gets default bottom transformation for targets; if none, return value."""
  if modality_type == ModalityType.AUDIO:
    return make_targets_bottom(audio_bottom)
  elif modality_type == ModalityType.AUDIO_SPECTRAL:
    return make_targets_bottom(audio_spectral_bottom)
  elif modality_type in (ModalityType.CLASS_LABEL,
                         ModalityType.MULTI_LABEL,
                         ModalityType.ONE_HOT_CLASS_LABEL,
                         ModalityType.SIGMOID_CLASS_LABEL,
                         ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,
                         ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
                         ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
                         ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL):
    return class_label_targets_bottom
  elif modality_type in (ModalityType.CTC_SYMBOL,
                         ModalityType.SYMBOL,
                         ModalityType.SYMBOL_WEIGHTS_ALL):
    return symbol_targets_bottom
  elif modality_type in (ModalityType.GENERIC_L2_LOSS,
                         ModalityType.IDENTITY_SYMBOL):
    return identity_bottom
  elif modality_type == ModalityType.IDENTITY:
    return make_targets_bottom(identity_bottom)
  elif modality_type == ModalityType.IMAGE:
    return image_targets_bottom
  elif modality_type in (ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
                         ModalityType.IMAGE_CHANNEL_COMPRESS):
    return image_channel_compress_targets_bottom
  elif modality_type == ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM:
    return image_channel_embeddings_bottom
  elif modality_type in (ModalityType.REAL,
                         ModalityType.REAL_L2_LOSS,
                         ModalityType.REAL_LOG_POISSON_LOSS):
    return make_targets_bottom(real_bottom)
  elif modality_type == ModalityType.SPEECH_RECOGNITION:
    return make_targets_bottom(speech_recognition_bottom)
  elif modality_type == ModalityType.SYMBOL_ONE_HOT:
    return symbol_one_hot_bottom
  elif modality_type in (ModalityType.VIDEO,
                         ModalityType.VIDEO_L1,
                         ModalityType.VIDEO_L2):
    return video_targets_bottom
  elif modality_type == ModalityType.VIDEO_BITWISE:
    return video_bitwise_targets_bottom
  elif modality_type == ModalityType.VIDEO_IDENTITY:
    return video_identity_targets_bottom
  elif modality_type in (ModalityType.VIDEO_L1_RAW,
                         ModalityType.VIDEO_L2_RAW):
    return video_raw_targets_bottom
  elif modality_type == ModalityType.VIDEO_PIXEL_NOISE:
    return make_targets_bottom(video_pixel_noise_bottom)
  return value


def get_top(modality_type, value=None):
  """Gets default top transformation; if none available, return value."""
  if modality_type in (ModalityType.AUDIO,
                       ModalityType.AUDIO_SPECTRAL,
                       ModalityType.GENERIC_L2_LOSS,
                       ModalityType.IDENTITY,
                       ModalityType.IDENTITY_SYMBOL,
                       ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
                       ModalityType.SPEECH_RECOGNITION,
                       ModalityType.VIDEO_IDENTITY):
    return identity_top
  elif modality_type in (ModalityType.CLASS_LABEL,
                         ModalityType.MULTI_LABEL,
                         ModalityType.ONE_HOT_CLASS_LABEL,
                         ModalityType.SIGMOID_CLASS_LABEL):
    return class_label_top
  elif modality_type in (ModalityType.CTC_SYMBOL,
                         ModalityType.SYMBOL,
                         ModalityType.SYMBOL_WEIGHTS_ALL):
    return symbol_top
  elif modality_type == ModalityType.IMAGE:
    return image_top
  elif modality_type == ModalityType.IMAGE_CHANNEL_COMPRESS:
    return image_channel_compress_top
  elif modality_type == ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM:
    return image_channel_embeddings_top
  elif modality_type in (ModalityType.REAL,
                         ModalityType.REAL_L2_LOSS,
                         ModalityType.REAL_LOG_POISSON_LOSS):
    return real_top
  elif modality_type == ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL:
    return sigmoid_max_pooling_class_label_top
  elif modality_type == ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL:
    return softmax_average_pooling_class_label_top
  elif modality_type == ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL:
    return softmax_last_timestep_class_label_top
  elif modality_type == ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL:
    return softmax_max_pooling_class_label_top
  elif modality_type == ModalityType.SYMBOL_ONE_HOT:
    return symbol_one_hot_top
  elif modality_type in (ModalityType.VIDEO,
                         ModalityType.VIDEO_BITWISE,
                         ModalityType.VIDEO_PIXEL_NOISE):
    return video_top
  elif modality_type in (ModalityType.VIDEO_L1,
                         ModalityType.VIDEO_L2):
    return video_l1_top
  elif modality_type in (ModalityType.VIDEO_L1_RAW,
                         ModalityType.VIDEO_L2_RAW):
    return video_raw_top
  return value


def get_weights_fn(modality_type, value=None):
  """Gets default weights function; if none available, return value."""
  if modality_type in (ModalityType.CTC_SYMBOL,
                       ModalityType.IDENTITY_SYMBOL,
                       ModalityType.MULTI_LABEL,
                       ModalityType.SYMBOL,
                       ModalityType.SYMBOL_ONE_HOT):
    return common_layers.weights_nonzero
  elif modality_type in ModalityType.get_choices():
    return common_layers.weights_all
  return value
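

# Example (illustrative sketch, not executed): resolving a feature's default
# transformations and wiring them around a model body. `features` is assumed
# to be the feature dict and `hparams` a T2T HParams object.
#
#   bottom = get_bottom(ModalityType.SYMBOL)
#   targets_bottom = get_targets_bottom(ModalityType.SYMBOL)
#   top = get_top(ModalityType.SYMBOL)
#   loss = get_loss(ModalityType.SYMBOL)
#   weights_fn = get_weights_fn(ModalityType.SYMBOL)
#
#   emb_inputs = bottom(features["inputs"], hparams, vocab_size)
#   emb_targets = targets_bottom(features["targets"], hparams, vocab_size)
#   # ... the model body maps the embeddings to `body_output` ...
#   logits = top(body_output, features["targets"], hparams, vocab_size)
#   loss_num, loss_den = loss(logits, features["targets"], hparams,
#                             vocab_size, weights_fn)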