# coding=utf-8
# Copyright 2020 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers implemented in Mesh TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import gin

from mesh_tensorflow import ops_with_redefined_builtins as mtf

import tensorflow.compat.v1 as tf


@gin.configurable
def unit_scaling_convention(value=False):
  """Turn this on with gin to enable the unit-scaling convention.

  TODO(noam): turn this comment into a position paper and post to arxiv

  Under the unit-scaling convention, all weights are initialized with unit
  variance, and the outputs of most contractions (matmul/einsum operations) are
  divided by the square-root of the sizes of the contracting dimensions.

  This differs from the typical inverse-square-root weight-initialization
  convention often attributed to
  http://proceedings.mlr.press/v9/glorot10a.html
  in which weights are typically initialized according to a distribution with
  mean zero and standard-deviation equal to the inverse-square-root of the
  contracting dimension(s).

  Under both conventions, the purpose of the inverse-square-root scaling is to
  keep the activations in a layer scaled similarly to the activations in the
  previous layer.  (Typically, models are initialized so that activations in
  all layers have RMS=O(1).)

  The difference between the two conventions is whether this scaling happens in
  the parameters (their way), or as an explicit multiplier on the activations
  (our way).

  In our opinion, parameter-scaling (their way) has three main disadvantages:

  1. Optimizers need to be aware of differently-scaled parameters.  This is
  because the learning-rates of adaptive optimizers represent target step-sizes
  for the parameters.  The desired step size for a parameter logically depends
  on the scale of the parameter itself, and so one typically needs to lower the
  learning-rate when the layers get bigger and the parameters get consequently
  smaller.  Under the unit-scaling convention, this is unnecessary, since all
  parameters are on the same unit scale.

  2. It is often unwieldy from an engineering standpoint to communicate to both
  the variable initializer and the optimizer what the scale of a variable
  should be.  Typically, the variable initializer guesses this by inferring
  from the dimension order which dimensions of the variable represent
  contracting dimensions.  This is highly error-prone.

  3. Sometimes contractions happen without being associated with parameters, as
  in neural attention.  Here, too, it may be important to divide by the square
  root of the size of the contracting dimensions in order to maintain
  activation scale.  See the discussion in section 3.2.1 of
  https://arxiv.org/abs/1706.03762
  Being in the habit of scaling the outputs of contractions in this way makes
  it more likely that we remember to do the same thing in these circumstances.

  Note: When switching to the unit-scaling convention, it is probably necessary
  to raise the learning rate, since larger parameters need larger updates.  An
  exception is when using Adafactor, which by default scales the updates
  relative to the scale of the current parameter values.

  Args:
    value: a boolean
  Returns:
    a boolean
  """
  return value
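

# A minimal sketch (not part of the library) of how the convention is
# typically enabled, either from a gin config file or programmatically:
#
#   gin.bind_parameter("unit_scaling_convention.value", True)
#
# Depending on how gin registers this module, the binding key may need a
# module prefix.  Once enabled, us_einsum() and VarianceScalingInitializer()
# below pick up the convention automatically.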


def us_einsum(xs, *args, **kwargs):
  """Einsum with optional unit-scaling convention.

  If the unit-scaling convention is enabled, then the output is divided by
  the square root of the product of the sizes of the contracting dimensions.

  Args:
    xs: a list of mtf.Tensor
    *args: arguments to mtf.einsum
    **kwargs: keyword arguments to mtf.einsum
  Returns:
    a mtf.Tensor
  """
  y = mtf.einsum(xs, *args, **kwargs)
  if unit_scaling_convention():
    all_input_dims = set(sum([x.shape.dims for x in xs], []))
    reduced_dims = [d for d in all_input_dims if d not in y.shape.dims]
    y *= mtf.Shape(reduced_dims).size ** -0.5
  return y
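

# Worked example (hypothetical shapes): if x has shape [batch, d_model] and w
# has shape [d_model, d_ff] with d_model.size == 1024, then
#
#   y = us_einsum([x, w], mtf.Shape([batch, d_ff]))
#
# contracts over d_model, and with the unit-scaling convention enabled the
# result is additionally multiplied by 1024 ** -0.5.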


def dense(x,
          new_dims,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          variable_dtype=None,
          kernel_initializer=None,
          kernel_weights=None,
          name=None):
  """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    new_dims: a list of mtf.Dimension.
    reduced_dims: a list of mtf.Dimensions of x to be reduced.
      If omitted (deprecated interface), we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    master_dtype: a tf.dtype (deprecated - use variable_dtype)
    slice_dtype: a tf.dtype (deprecated - use variable_dtype)
    variable_dtype: a mtf.VariableDType
    kernel_initializer: an initializer for kernel variable.
    kernel_weights: mtf.Tensor weights matrix to use for dense computation
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor of shape [..., new_dims].
  """
  if not isinstance(new_dims, list):
    new_dims = [new_dims]

  if variable_dtype is None:
    variable_dtype = mtf.VariableDType(master_dtype, slice_dtype, x.dtype)

  if expert_dims is None:
    expert_dims = []
  if reduced_dims is None:
    tf.logging.warning(
        "Deprecation warning - it is recommended to pass reduced_dims "
        "explicitly to mtf.layers.dense() so as not to depend on dimension "
        "order. To silence this warning, explicitly pass "
        "reduced_dims=x.shape.dims[-1:] (in scope %s)"
        %  tf.get_variable_scope().name)
    reduced_dims = x.shape.dims[-1:]
  # If any reduced dims have the same names as new dims, first rename these
  # dimensions in the input so as to avoid a name conflict in the weight
  # matrix.
  reduced_dims = reduced_dims[:]
  for i in range(len(reduced_dims)):
    if reduced_dims[i] in new_dims:
      original_name = reduced_dims[i].name
      tmp_name = "_" + original_name
      reduced_dims[i] = mtf.Dimension(tmp_name, reduced_dims[i].size)
      x = mtf.rename_dimension(x, original_name, tmp_name)
  output_shape = mtf.Shape([d for d in x.shape.dims if d not in reduced_dims] +
                           new_dims)
  if kernel_weights is None:
    kernel_weights = get_dense_kernel_weights(x, new_dims, reduced_dims,
                                              expert_dims, kernel_initializer,
                                              name, variable_dtype,
                                              master_dtype, slice_dtype)

  with tf.variable_scope(name, default_name="dense"):
    y = us_einsum([x, kernel_weights], output_shape)
    if use_bias:
      b = mtf.get_variable(
          x.mesh,
          "bias",
          mtf.Shape(expert_dims + new_dims),
          initializer=tf.zeros_initializer(),
          dtype=variable_dtype)
      y += b
    if activation is not None:
      y = activation(y)
    return y
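

# Illustrative usage sketch; the dimension names and sizes are hypothetical:
#
#   model_dim = mtf.Dimension("d_model", 1024)
#   hidden_dim = mtf.Dimension("d_ff", 4096)
#   # x is a mtf.Tensor whose shape ends in d_model.
#   h = mtf.layers.dense(x, new_dims=[hidden_dim], reduced_dims=[model_dim],
#                        activation=mtf.relu, name="wi")
#
# The kernel variable has shape [d_model, d_ff]; passing reduced_dims
# explicitly avoids the deprecation warning above.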


def get_dense_kernel_weights(x,
                             new_dims,
                             reduced_dims,
                             expert_dims,
                             kernel_initializer,
                             name=None,
                             variable_dtype=None,
                             master_dtype=tf.float32,
                             slice_dtype=tf.float32):
  """Create w matrix variable.

  Args:
    x: a mtf.Tensor.
    new_dims: a list of mtf.Dimension.
    reduced_dims: a list of mtf.Dimensions of x to be reduced.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    kernel_initializer: an initializer for kernel variable.
    name: a string used for tf.variable_scope.
    variable_dtype: a mtf.VariableDType
    master_dtype: a tf.dtype (deprecated - use variable_dtype)
    slice_dtype: a tf.dtype (deprecated - use variable_dtype)

  Returns:
    a mtf.Tensor.
  """
  if variable_dtype is None:
    variable_dtype = mtf.VariableDType(master_dtype, slice_dtype, x.dtype)
  w_shape = mtf.Shape(expert_dims + reduced_dims + new_dims)

  with tf.variable_scope(name, default_name="dense"):
    if kernel_initializer is None:
      kernel_initializer = VarianceScalingInitializer()
    if isinstance(kernel_initializer, DenseInitializer):
      kernel_initializer = kernel_initializer(reduced_dims, new_dims)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        w_shape,
        initializer=kernel_initializer,
        dtype=variable_dtype)
    w = mtf.cast(w, x.dtype)
  return w


def dense_product(x,
                  reduced_dims,
                  new_dims,
                  activation_functions=None,
                  name="dense_product",
                  **kwargs):
  """Component-wise product of multiple dense layers.

  e.g. if activation_functions=["linear", "sigmoid"], then this implements
  Gated Linear Units https://arxiv.org/pdf/1612.08083.pdf

  Args:
    x: a Tensor
    reduced_dims: a list of Dimensions.
    new_dims: a list of Dimensions.
    activation_functions: a list of activation functions (or a singleton).
      Each can be either:
        - a callable function from Tensor to Tensor
        - a string function name from namespace mtf
        - None or "linear", meaning no activation function
    name: an optional string
    **kwargs: additional kwargs for mtf.layers.dense()

  Returns:
    a mtf.Tensor.
  """
  if not isinstance(activation_functions, list):
    activation_functions = [activation_functions]
  num_factors = len(activation_functions)
  factors = []
  for i, activation in enumerate(activation_functions):
    if activation == "linear":
      activation = None
    elif isinstance(activation, str):
      activation = getattr(mtf, activation)
    factors.append(
        dense(x,
              reduced_dims=reduced_dims,
              new_dims=new_dims,
              activation=activation,
              name="%s_%d" % (name, i) if num_factors > 1 else name,
              **kwargs))
  return functools.reduce(mtf.multiply, factors)
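

# Sketch of the Gated Linear Unit mentioned above (dimensions hypothetical):
#
#   glu = mtf.layers.dense_product(
#       x,
#       reduced_dims=[model_dim],
#       new_dims=[hidden_dim],
#       activation_functions=["linear", "sigmoid"])
#
# This computes dense_0(x) * sigmoid(dense_1(x)) using two separately-named
# kernels (variable scopes "dense_product_0" and "dense_product_1").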


class DenseInitializer(object):
  """Initializer that can be passed to dense().

  The __call__ function takes reduced_dims and new_dims and returns a
  tf initializer class.
  """

  def __call__(self, reduced_dims, new_dims):
    raise NotImplementedError("not implemented")


@gin.configurable
class VarianceScalingInitializer(DenseInitializer):
  """Initializer capable of adapting its scale to the shape of weights.

  With `distribution="normal"`, samples are drawn from a truncated normal
  distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:

    1.0 if unit_scaling_convention() is turned on
    otherwise:
      number of input units in the weight tensor, if mode = "fan_in"
      number of output units, if mode = "fan_out"
      average of the numbers of input and output units, if mode = "fan_avg"

  With `distribution="uniform"`,
  samples are drawn from a uniform distribution
  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

  Args:
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: Random distribution to use. One of "normal", "uniform".
  """

  def __init__(self, scale=1.0,
               mode="fan_in",
               distribution="normal"):
    self.scale = scale
    self.mode = mode.lower()
    self.distribution = distribution.lower()

  def __call__(self, reduced_dims, new_dims):
    fan_in = mtf.list_product(d.size for d in reduced_dims)
    fan_out = mtf.list_product(d.size for d in new_dims)
    scale = self.scale
    if self.mode == "fan_in":
      if not unit_scaling_convention():
        scale /= max(1., fan_in)
    elif self.mode == "fan_out":
      if unit_scaling_convention():
        raise ValueError("Unit scaling convention only works with \"fan_in\"")
      scale /= max(1., fan_out)
    elif self.mode == "fan_avg":
      if unit_scaling_convention():
        raise ValueError("Unit scaling convention only works with \"fan_in\"")
      scale /= max(1., float(fan_in + fan_out) / 2)
    else:
      raise ValueError(
          "Invalid `mode` argument: "
          "expected on of {\"fan_in\", \"fan_out\", \"fan_avg\"} "
          "but got %s" % (self.mode,))
    stddev = scale ** 0.5
    if self.distribution == "normal":
      return tf.truncated_normal_initializer(stddev=stddev)
    elif self.distribution == "uniform":
      limit = stddev * 3. ** 0.5
      return tf.random_uniform_initializer(minval=-limit, maxval=limit)
    else:
      raise ValueError("Invalid `distribution` argument: "
                       "expected one of {\"normal\", \"uniform\"} "
                       "but got %s" % (self.distribution,))


def conv1d(x, output_dim, filter_size=3, stride=1, **kw_args):
  """1D Convolution.

  Args:
    x: a mtf.Tensor of format NWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a positive integer, the filter width.
    stride: a positive integer, the stride.
    **kw_args: optional keyword arguments to mtf.layers.conv2d.

  Returns:
    a mtf.Tensor of format NWO, where O is the output dimension.
  """
  fake_height_dim = mtf.Dimension("fake_height", 1)
  x = mtf.reshape(
      x, mtf.Shape(x.shape.dims[:-2] + [fake_height_dim] + x.shape.dims[-2:]))
  output = conv2d(
      x,
      output_dim,
      filter_size=(1, filter_size),
      strides=(1, stride),
      **kw_args)
  return mtf.reshape(
      output,
      mtf.Shape([
          d for d in x.shape.dims
          if d != fake_height_dim and d != x.shape.dims[-1]
      ] + [output_dim]))


def _depthwise_conv1d_hack(x,
                           depth_dim,
                           length_dim,
                           min_relative_pos=-1,
                           max_relative_pos=1,
                           name=None,
                           use_bias=True,
                           initializer_scale=1.0,
                           kernel_depth_weights=None):
  """Hacky version of a 1d depthwise convolution.

  Args:
    x: a mtf.Tensor
    depth_dim: a mtf.Dimension, the depth (channel) dimension.
    length_dim: a mtf.Dimension, the length dimension to convolve over.
    min_relative_pos: an int, the minimum relative position.
    max_relative_pos: an int, the maximum relative position.
    name: a str, the variable_scope name.
    use_bias: a bool, whether to use bias.
    initializer_scale: a float, the initializer scale.
    kernel_depth_weights: an optional list of kernel weight tensors. The list
      contains one element for each relative position in the kernel. Each
      element has a width equal to the depth over which the separable conv
      operation is being "separated".

  Returns:
    an mtf.Tensor
  """

  ret = 0
  kernel_size = max_relative_pos - min_relative_pos + 1

  with tf.variable_scope(name, default_name="depthwise_conv_hack"):
    for i in range(kernel_size):
      relative_pos = min_relative_pos + i
      shifted_input = mtf.shift(x, -relative_pos, length_dim, wrap=False)
      ret += dense(
          shifted_input,
          new_dims=[],
          reduced_dims=[],
          expert_dims=[depth_dim],
          kernel_weights=kernel_depth_weights[i]
          if kernel_depth_weights else None,
          name="depthwise_dense_%d" % i,
          use_bias=use_bias and (i == 0),
          kernel_initializer=VarianceScalingInitializer(
              scale=initializer_scale / kernel_size))

  return ret


def separable_conv1d(x,
                     output_dim,
                     min_relative_pos=-1,
                     max_relative_pos=1,
                     depthwise_filter_initializer_scale=1.0,
                     pointwise_filter_initializer_scale=1.0,
                     name=None,
                     use_bias=True,
                     kernel_depth_weights=None):
  """1-D convolution with separable filters.

  The filter size will be `max_relative_pos - min_relative_pos + 1`.

  Args:
    x: a mtf.Tensor of format NWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    min_relative_pos: an integer, the inclusive minimum relative position of the
      depthwise filter, where a relative position of zero means the left end of
      the filter aligns with the left end of the input.
    max_relative_pos: an integer, the inclusive maximum relative position of the
      depthwise filter, where a relative position of zero means the right end of
      the filter aligns with the right end of the input.
    depthwise_filter_initializer_scale: a positive float, the scale of the
      initializer for the depthwise filter.
    pointwise_filter_initializer_scale: a positive float, the scale of the
      initializer for the pointwise filter.
    name: a string used for tf.variable_scope.
    use_bias: a bool, whether to use bias in the convolutions.
    kernel_depth_weights: an optional list of kernel weight tensors. The list
      contains one element for each relative position in the kernel. Each
      element has a width equal to the dimension over which the separable conv
      operation is being "separated".

  Returns:
    a mtf.Tensor of format NWO, where O is the output dimension.
  """
  depth_dim = x.shape.dims[-1]
  length_dim = x.shape.dims[-2]
  with tf.variable_scope(name, default_name="separable_conv1d"):
    depthwise = _depthwise_conv1d_hack(
        x,
        depth_dim=depth_dim,
        length_dim=length_dim,
        min_relative_pos=min_relative_pos,
        max_relative_pos=max_relative_pos,
        use_bias=use_bias,
        initializer_scale=depthwise_filter_initializer_scale,
        kernel_depth_weights=kernel_depth_weights)
    return dense(
        depthwise,
        new_dims=[output_dim],
        reduced_dims=[depth_dim],
        name="pointwise_dense",
        use_bias=use_bias,
        kernel_initializer=VarianceScalingInitializer(
            scale=pointwise_filter_initializer_scale))


def conv2d(x, output_dim, filter_size=(3, 3),
           strides=(1, 1), padding="SAME", filter_initializer=None,
           variable_dtype=None, name=None):
  """2D Convolution.

  Args:
    x: a mtf.Tensor of format NHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format [filter_height, filter_width].
    strides: a list or tuple in format [stride_height, stride_width].
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor.
  """
  fh_dim = mtf.Dimension("fh", filter_size[0])
  fw_dim = mtf.Dimension("fw", filter_size[1])
  input_dim = x.shape[-1]
  with tf.variable_scope(name, default_name="conv2d"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    conv_filter = mtf.get_variable(
        x.mesh, "kernel", [fh_dim, fw_dim, input_dim, output_dim],
        initializer=filter_initializer, dtype=variable_dtype)
    # Pad stride in batch and channel dimensions.
    strides = [1] + list(strides) + [1]

    return mtf.Conv2dOperation(x, conv_filter, strides, padding).outputs[0]
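

# Usage sketch (hypothetical dims): x is a mtf.Tensor of shape
# [batch, height, width, channels]:
#
#   out_channels = mtf.Dimension("out_channels", 64)
#   y = mtf.layers.conv2d(x, out_channels, filter_size=(3, 3))
#
# With the default "SAME" padding and unit strides, y has shape
# [batch, height, width, out_channels].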


def conv2d_with_blocks(
    x, output_dim, filter_size=(3, 3),
    strides=(1, 1), padding="SAME",
    h_blocks_dim=None, w_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """2D Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  Block dimensions represented as h_blocks_dim and w_blocks_dim can be split
  along the mesh axis. If split, then we do a halo exchange where each block
  receives the part of the image from its left and right neighbors necessary to
  do the convolution. Exchange can involve complete or partial blocks depending
  on the filter height and width.

  Currently, only "SAME" padding with dilation rate of 1 is supported.

  Args:
    x: a Tensor of shape
        [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, in_channels_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format [filter_height, filter_width].
    strides: a list or tuple in format [stride_height, stride_width].
    padding: string, "SAME". The type of padding algorithm to use.
        "VALID" is not currently supported.
    h_blocks_dim: Dimension representing number of height blocks.
    w_blocks_dim: Dimension representing number of width blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).

  Returns:
    A Tensor of shape
      [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, out_channels_dim]
  """
  # If h_blocks_dim and w_blocks_dim are not split, directly call conv2d.
  if h_blocks_dim is None and w_blocks_dim is None:
    return conv2d(x, output_dim,
                  filter_size, strides, padding, filter_initializer,
                  variable_dtype, name)

  assert filter_size[0] % 2 == 1
  assert filter_size[1] % 2 == 1

  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError("conv2d_with_blocks requires padding=SAME")

  # Halo exchange for h_blocks and w_blocks.
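  # Each block needs filter_size // 2 extra pixels on each side (e.g. a halo
  # of 1 for a 3x3 filter).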
  h_dim, w_dim = x.shape.dims[-3:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (h_blocks_dim, h_dim, filter_size[0] // 2),
      (w_blocks_dim, w_dim, filter_size[1] // 2)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)
  return conv2d(x, output_dim,
                filter_size, strides, "VALID", filter_initializer,
                variable_dtype, name)


def conv2d_transpose(x, output_dim,
                     filter_size=(2, 2), strides=(2, 2),
                     padding="SAME", filter_initializer=None,
                     variable_dtype=None, name=None):
  """2D Transposed Convolution.

  Args:
    x: a mtf.Tensor of format NHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_height, filter_width]. Only filter_size of (2, 2) is tested.
    strides: a list or tuple in format
        [stride_height, stride_width]. Only strides of (2, 2) is tested.
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor.
  """
  fh_dim = mtf.Dimension("fh", filter_size[0])
  fw_dim = mtf.Dimension("fw", filter_size[1])
  input_dim = x.shape[-1]
  with tf.variable_scope(name, default_name="conv2d_transpose"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    conv_filter = mtf.get_variable(
        x.mesh, "kernel", [fh_dim, fw_dim, output_dim, input_dim],
        initializer=filter_initializer, dtype=variable_dtype)
    # Pad stride in batch and channel dimensions.
    strides = [1] + list(strides) + [1]

    return mtf.Conv2dTransposeOperation(
        x, conv_filter, strides, padding).outputs[0]


def conv2d_transpose_with_blocks(
    x, output_dim, filter_size=(2, 2),
    strides=(2, 2), padding="SAME",
    h_blocks_dim=None, w_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """2D Transposed Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  Block dimensions represented as h_blocks_dim and w_blocks_dim can be split
  along the mesh axis. If split, then we do a halo exchange where each block
  receives the part of the image from its left and right neighbors necessary to
  do the convolution. Exchange can involve complete or partial blocks depending
  on the filter height and width.

  Currently, only "SAME" padding with dilation rate of 1 is supported. Only
  splitting along the height and width dimensions is supported.

  Args:
    x: a Tensor of shape
        [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, in_channel_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_height, filter_width]. Only filter_size of (2, 2) is tested.
    strides: a list or tuple in format
        [stride_height, stride_width]. Only strides of (2, 2) is tested.
    padding: string, "SAME". The type of padding algorithm to use.
        "Valid" is not currently supported.
    h_blocks_dim: Dimension representing number of height blocks.
    w_blocks_dim: Dimension representing number of width blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).

  Returns:
    A Tensor of shape
      [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, out_channels_dim]
  """
  # If h_blocks_dim and w_blocks_dim are not split, directly call conv2d_trans.
  if h_blocks_dim is None and w_blocks_dim is None:
    return conv2d_transpose(
        x, output_dim, filter_size, strides, padding, filter_initializer,
        variable_dtype, name)

  # Now only supports even-sized filters.
  assert filter_size[0] % 2 == 0
  assert filter_size[1] % 2 == 0

  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError(
        "conv2d_transpose_with_blocks requires padding=SAME")

  # Halo exchange for h_blocks and w_blocks.
  # TODO(lehou): figure out the halo_size in general cases.
  h_dim, w_dim = x.shape.dims[-3:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (h_blocks_dim, h_dim, filter_size[0] // 2 - 1),
      (w_blocks_dim, w_dim, filter_size[1] // 2 - 1)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)

  return conv2d_transpose(
      x, output_dim, filter_size, strides, "VALID", filter_initializer,
      variable_dtype, name)


def conv3d(x, output_dim, filter_size=(3, 3, 3),
           strides=(1, 1, 1), padding="SAME",
           filter_initializer=None,
           variable_dtype=None, name=None):
  """3D Convolution.

  Args:
    x: a mtf.Tensor of format NDHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_depth, filter_height, filter_width].
    strides: a list or tuple in format
        [stride_depth, stride_height, stride_width].
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor.
  """
  fd_dim = mtf.Dimension("fd", filter_size[0])
  fh_dim = mtf.Dimension("fh", filter_size[1])
  fw_dim = mtf.Dimension("fw", filter_size[2])
  input_dim = x.shape[-1]
  with tf.variable_scope(name, default_name="conv3d"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    conv_filter = mtf.get_variable(
        x.mesh, "kernel", [fd_dim, fh_dim, fw_dim, input_dim, output_dim],
        initializer=filter_initializer, dtype=variable_dtype)
    # Pad stride in batch and channel dimensions.
    strides = [1] + list(strides) + [1]

    return mtf.Conv3dOperation(x, conv_filter, strides, padding).outputs[0]


def conv3d_with_blocks(
    x, output_dim, filter_size=(3, 3, 3),
    strides=(1, 1, 1), padding="SAME",
    d_blocks_dim=None, h_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """3D Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  Block dimensions represented as d_blocks_dim and h_blocks_dim can be split
  along the mesh axis. If split, then we do a halo exchange where each block
  receives the part of the image from its left and right neighbors necessary to
  do the convolution. Exchange can involve complete or partial blocks depending
  on the filter depth and height.

  Currently, only "SAME" padding with dilation rate of 1 is supported. Only
  splitting along the depth and height dimensions is supported.

  Args:
    x: a Tensor of shape
        [batch, d_blocks_dim, h_blocks_dim, d_dim, h_dim, w_dim, in_channel_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_depth, filter_height, filter_width].
    strides: a list or tuple in format
        [stride_depth, stride_height, stride_width].
    padding: string, "SAME". The type of padding algorithm to use.
        "Valid" is not currently supported.
    d_blocks_dim: Dimension representing number of depth blocks.
    h_blocks_dim: Dimension representing number of height blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).

  Returns:
    A Tensor of shape
      [batch, d_blocks_dim, h_blocks_dim,
       d_dim, h_dim, w_dim, out_channels_dim]
  """
  # If d_blocks_dim and h_blocks_dim are not split, directly call conv3d.
  if d_blocks_dim is None and h_blocks_dim is None:
    return conv3d(x, output_dim,
                  filter_size, strides, padding, filter_initializer,
                  variable_dtype, name)

  assert filter_size[0] % 2 == 1
  assert filter_size[1] % 2 == 1
  assert filter_size[2] % 2 == 1

  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError("conv3d_with_blocks requires padding=SAME")

  # Halo exchange for d_blocks and h_blocks.
  d_dim, h_dim, w_dim = x.shape.dims[-4:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (d_blocks_dim, d_dim, filter_size[0] // 2),
      (h_blocks_dim, h_dim, filter_size[1] // 2)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)

  # Pad w dimension with zeros.
  x = mtf.pad(x, [filter_size[2] // 2, filter_size[2] // 2],
              dim_name=w_dim.name, name="conv3d_pad_w_dim")
  return conv3d(x, output_dim,
                filter_size, strides, "VALID", filter_initializer,
                variable_dtype, name)


def conv3d_transpose(x, output_dim,
                     filter_size=(2, 2, 2), strides=(2, 2, 2),
                     padding="SAME", filter_initializer=None,
                     variable_dtype=None, name=None):
  """3D Transposed Convolution.

  Args:
    x: a mtf.Tensor of format NDHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_depth, filter_height, filter_width].
        Only filter_size of (2, 2, 2) is tested.
    strides: a list or tuple in format
        [stride_depth, stride_height, stride_width].
        Only strides of (2, 2, 2) is tested.
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor.
  """
  fd_dim = mtf.Dimension("fd", filter_size[0])
  fh_dim = mtf.Dimension("fh", filter_size[1])
  fw_dim = mtf.Dimension("fw", filter_size[2])
  input_dim = x.shape[-1]
  with tf.variable_scope(name, default_name="conv3d_transpose"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    conv_filter = mtf.get_variable(
        x.mesh, "kernel", [fd_dim, fh_dim, fw_dim, output_dim, input_dim],
        initializer=filter_initializer, dtype=variable_dtype)
    # Pad stride in batch and channel dimensions.
    strides = [1] + list(strides) + [1]

    return mtf.Conv3dTransposeOperation(
        x, conv_filter, strides, padding).outputs[0]


def conv3d_transpose_with_blocks(
    x, output_dim, filter_size=(2, 2, 2),
    strides=(2, 2, 2), padding="SAME",
    d_blocks_dim=None, h_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """3D Transposed Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  Block dimensions represented as d_blocks_dim and h_blocks_dim can be split
  along the mesh axis. If split, then we do a halo exchange where each block
  receives the part of the image from its left and right neighbors necessary to
  do the convolution. Exchange can involve complete or partial blocks depending
  on the filter depth and height.

  Currently, only "SAME" padding with dilation rate of 1 is supported. Only
  splitting along the depth and height dimensions is supported.

  Args:
    x: a Tensor of shape
        [batch, d_blocks_dim, h_blocks_dim, d_dim, h_dim, w_dim, in_channel_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_depth, filter_height, filter_width].
        Only filter_size of (2, 2, 2) is tested.
    strides: a list or tuple in format
        [stride_depth, stride_height, stride_width].
        Only strides of (2, 2, 2) is tested.
    padding: string, "SAME". The type of padding algorithm to use.
        "Valid" is not currently supported.
    d_blocks_dim: Dimension representing number of depth blocks.
    h_blocks_dim: Dimension representing number of height blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).

  Returns:
    A Tensor of shape
      [batch, d_blocks_dim, h_blocks_dim,
       d_dim, h_dim, w_dim, out_channels_dim]
  """
  # If d_blocks_dim and h_blocks_dim are not split, directly call conv3d_trans.
  if d_blocks_dim is None and h_blocks_dim is None:
    return conv3d_transpose(
        x, output_dim, filter_size, strides, padding, filter_initializer,
        variable_dtype, name)

  # Now only supports even-sized filters.
  assert filter_size[0] % 2 == 0
  assert filter_size[1] % 2 == 0
  assert filter_size[2] % 2 == 0

  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError(
        "conv3d_transpose_with_blocks requires padding=SAME")

  # Halo exchange for d_blocks and h_blocks.
  # TODO(lehou): figure out the halo_size in general cases.
  d_dim, h_dim, w_dim = x.shape.dims[-4:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (d_blocks_dim, d_dim, filter_size[0] // 2 - 1),
      (h_blocks_dim, h_dim, filter_size[1] // 2 - 1)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)

  # Pad w dimension with zeros.
  x = mtf.pad(x, [filter_size[2] // 2 - 1, filter_size[2] // 2 - 1],
              dim_name=w_dim.name, name="conv3d_trans_pad_w_dim")
  return conv3d_transpose(
      x, output_dim, filter_size, strides, "VALID", filter_initializer,
      variable_dtype, name)


def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
  """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name + "/layer_norm"):
    scale = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "layer_norm_bias",
        mtf.Shape([dim]),
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)
    reduced_shape = x.shape - dim
    mean = mtf.reduce_mean(x, output_shape=reduced_shape)
    variance = mtf.reduce_mean(mtf.square(x - mean), output_shape=reduced_shape)
    norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
    return norm_x * scale + bias
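

# Usage sketch (hypothetical dims): normalizing over the model dimension of an
# activation tensor:
#
#   x = mtf.layers.layer_norm(x, dim=model_dim)
#
# The learned scale and bias each have shape [model_dim]; the mean and
# variance are computed over model_dim separately at every other position.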


def batch_norm(x, is_training, momentum, epsilon=1e-9,
               dims_idx_start=0, dims_idx_end=-1,
               init_zero=False, name=None):
  """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    dims_idx_start: an integer. Dimensions with indices in
      [dims_idx_start, dims_idx_end - 1] will be normalized.
    dims_idx_end: an integer. Dimensions with indices in
      [dims_idx_start, dims_idx_end - 1] will be normalized.
    init_zero: a boolean, whether to initialize scale with 0's or 1's.
    name: a string used for tf.variable_scope.

  Returns:
    a pair (output, bn_stats_update_ops) where output is a mtf.Tensor with the
    same shape as x, and bn_stats_update_ops is a list of ops that update the
    moving mean and variance (empty when is_training is False).
  """
  with tf.variable_scope(name, default_name="batch_norm", values=[x]):
    if init_zero:
      gamma_initializer = tf.zeros_initializer()
    else:
      gamma_initializer = tf.ones_initializer()

    norm_dim = x.shape.dims[dims_idx_start:dims_idx_end]
    reduced_shape = x.shape - norm_dim

    scale = mtf.get_variable(
        x.mesh,
        "batch_norm_scale",
        reduced_shape,
        initializer=gamma_initializer,
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "batch_norm_bias",
        reduced_shape,
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    moving_mean = mtf.get_variable(
        x.mesh, "bn_moving_mean", reduced_shape,
        initializer=tf.random_normal_initializer(stddev=1.0),
        activation_dtype=x.dtype,
        trainable=False)
    moving_variance = mtf.get_variable(
        x.mesh, "bn_moving_variance",
        reduced_shape, initializer=tf.ones_initializer(),
        activation_dtype=x.dtype,
        trainable=False)

    # At training time, calculate mean and variance and normalize across batch
    # dim.
    if is_training:
      mean = mtf.reduce_mean(x, output_shape=reduced_shape)
      variance = mtf.reduce_mean(
          mtf.square(x - mean), output_shape=reduced_shape)

      norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

      # Update running mean and running variance.
      # TODO(lehou): do not return update_ops; handle them inside MTF.
      bn_stats_update_ops = []
      bn_stats_update_ops.append(mtf.assign(
          moving_mean, momentum * moving_mean + (1 - momentum) * mean,
          name="{}/bn_mean_update".format(name)))
      bn_stats_update_ops.append(mtf.assign(
          moving_variance,
          momentum * moving_variance + (1 - momentum) * variance,
          name="{}/bn_var_update".format(name)))
    else:
      # At eval and test time, use the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
      bn_stats_update_ops = []

    return (norm_x * scale) + bias, bn_stats_update_ops
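

# Usage sketch: unlike most layers in this file, batch_norm returns a
# (tensor, update_ops) pair:
#
#   y, update_ops = mtf.layers.batch_norm(x, is_training=True, momentum=0.999)
#
# The caller is responsible for running the returned update ops so that the
# moving mean and variance actually get updated (see the TODO above).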


def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
  """Per-example softmax loss.

  `logits` is a Tensor with floating-point dtype, containing the predicted
  relative log probabilities of the classes.

  Either hard targets or soft targets are supported.

  In the case of hard targets, `targets` is a Tensor with integer dtype whose
  values are in the range [0, vocab_dim.size).  `targets` should have the same
  set of dimensions as `logits`, but without `vocab_dim`.

  In the case of soft targets, `targets` is a Tensor with floating point dtype
  and the same dimensions as `logits`.  Reducing `targets` along `vocab_dim`
  should result in all ones.

  If z_loss is nonzero, we add a loss equal to z_loss*log(z)^2, where z is the
  partition function.  Example value: z_loss=1e-4.  Two uses of z_loss are:
  - To keep the logits from drifting too far from zero, which can cause
    unacceptable roundoff errors in bfloat16.
  - To encourage the logits to be normalized log-probabilities.

  Args:
    logits: a mtf.Tensor whose shape contains vocab_dim
    targets: a mtf.Tensor representing hard or soft targets (see comments)
    vocab_dim: a mtf.Dimension
    z_loss: a float

  Returns:
    a mtf.Tensor whose shape is equal to logits.shape - vocab_dim

  Raises:
    ValueError: if the shapes do not match.
  """
  if targets.dtype.is_integer:
    # hard targets
    if (set(targets.shape.dims)
        != set(logits.shape.dims).difference([vocab_dim])):
      raise ValueError(
          "softmax_cross_entropy_with_logits with hard targets "
          "dims in targets=%s should be dims in logits=%s other than "
          "vocab_dim=%s" % (targets, logits, vocab_dim))
    targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
  elif set(targets.shape.dims) != set(logits.shape.dims):
    raise ValueError(
        "softmax_cross_entropy_with_logits with soft targets "
        "dims in targets=%s should be dims in logits=%s"% (targets, logits))
  if vocab_dim not in logits.shape.dims:
    raise ValueError("vocab_dim must be in logits.shape.dims")
  log_z = mtf.reduce_logsumexp(logits, vocab_dim)
  log_softmax = logits - log_z
  loss = mtf.negative(
      mtf.reduce_sum(log_softmax * targets, reduced_dim=vocab_dim))
  if z_loss != 0:
    loss += z_loss * mtf.square(log_z)
  return loss
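

# Usage sketch (hypothetical dims) with hard targets: logits has shape
# [batch, length, vocab] and targets is an integer Tensor of shape
# [batch, length]:
#
#   loss = mtf.layers.softmax_cross_entropy_with_logits(
#       logits, targets, vocab_dim=vocab_dim, z_loss=1e-4)
#   loss = mtf.reduce_mean(loss)
#
# The per-example loss has shape [batch, length]; z_loss=1e-4 is the example
# value suggested in the docstring above.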


def sigmoid_cross_entropy_with_logits(logits, targets):
  """Sigmoid cross-entropy loss.

  Args:
    logits: a mtf.Tensor
    targets: a mtf.Tensor with the same shape as logits

  Returns:
    a mtf.Tensor whose shape is equal to logits.shape

  Raises:
    ValueError: if the shapes do not match.
  """
  if logits.shape != targets.shape:
    raise ValueError(
        "logits shape must equal targets shape"
        "logits=%s targets=%s" % (logits.to_string, targets.to_string))
  x = logits
  z = targets
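  # Numerically stable form of
  #   z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) = softplus(x) - x*z
  # written so that exp() never overflows.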
  return mtf.relu(x) - x * z + mtf.log(1 + mtf.exp(-mtf.abs(x)))


def weights_nonzero(targets, dtype=tf.float32):
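  """Returns a dtype Tensor: 1 where targets is nonzero, 0 elsewhere."""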
  def my_fn(x):
    return tf.cast(tf.not_equal(x, 0), dtype)
  return mtf.cwise(my_fn, [targets], output_dtype=dtype, name="weights_nonzero")


def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     master_dtype=tf.float32,
                     slice_dtype=tf.float32, name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    h = dense(x, hidden_channels,
              use_bias=False, activation=mtf.relu,
              master_dtype=master_dtype, slice_dtype=slice_dtype, name="wi")
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return dense(h, io_channels, use_bias=False, activation=None,
                 master_dtype=master_dtype, slice_dtype=slice_dtype,
                 name="wo")


def local_1d_halo_exchange(k, v, num_w_blocks, w_dim, mask_right):
  """Halo exchange for keys and values for Local 1D attention."""
  if num_w_blocks is not None:
    if mask_right:
      k = mtf.left_halo_exchange(k, num_w_blocks, w_dim, w_dim.size)
      v = mtf.left_halo_exchange(v, num_w_blocks, w_dim, w_dim.size)
    else:
      k = mtf.halo_exchange(k, num_w_blocks, w_dim, w_dim.size)
      v = mtf.halo_exchange(v, num_w_blocks, w_dim, w_dim.size)
  else:
    if mask_right:
      k = mtf.pad(k, [w_dim.size, None], w_dim.name)
      v = mtf.pad(v, [w_dim.size, None], w_dim.name)
    else:
      k = mtf.pad(k, [w_dim.size, w_dim.size], w_dim.name)
      v = mtf.pad(v, [w_dim.size, w_dim.size], w_dim.name)
  return k, v


def local_self_attention_spatial_blocks(
    query_antecedent,
    kv_channels,
    heads,
    memory_w_dim=None,
    mask_right=False,
    master_dtype=tf.float32,
    slice_dtype=tf.float32,
    name=None):
  """Attention to the source position and a neighborhood to the left or right.

  The sequence is divided into blocks of length block_size.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, num_w_blocks, w_dim, io_channels].
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    memory_w_dim: mtf Dimension, for the memory width block.
    mask_right: bool, flag specifying whether we mask out attention to the right
      for the decoder.
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string.

  Returns:
    a Tensor of shape
        [batch, num_w_blocks, w_dim, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
  with tf.variable_scope(
      name, default_name="multihead_attention",
      values=[query_antecedent]):

    w_dim, io_channels = query_antecedent.shape.dims[-2:]
    batch, num_w_blocks = query_antecedent.shape.dims[:2]
    wq, wk, wv, wo = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        master_dtype, slice_dtype, query_antecedent.dtype)

    # Rename the width dimension to create the memory antecedent.
    memory_antecedent = mtf.rename_dimension(
        query_antecedent, w_dim.name, "memory_" + w_dim.name)
    memory_w_dim = memory_antecedent.shape.dims[-2]

    # Call einsum over the query and memory to get query q, keys k and values v.
    q = mtf.einsum(
        [query_antecedent, wq],
        mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, wk],
        mtf.Shape([batch, heads, num_w_blocks, memory_w_dim, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, wv],
        mtf.Shape([batch, heads, num_w_blocks, memory_w_dim, kv_channels]))

    # Halo exchange for memory blocks.
    k, v = local_1d_halo_exchange(k, v, num_w_blocks, memory_w_dim, mask_right)

    # Calculate the causal mask to avoid peeking into the future. We compute
    # this once and reuse it for all blocks since the block_size is known.
    mask = None
    if mask_right:
      mask = attention_bias_local_block(
          query_antecedent.mesh, w_dim, memory_w_dim)

    output = dot_product_attention(q, k, v, mask=mask)

    return mtf.einsum(
        [output, wo], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))


def masked_local_attention_1d(x,
                              kv_channels,
                              heads,
                              window_size=128,
                              master_dtype=tf.float32,
                              slice_dtype=tf.float32,
                              length_per_split=None,
                              return_kv=None,
                              params=None,
                              name=None):
  """Attention to the source position and a neighborhood to the left of it.

  Attention for a given query position p can only see memory positions
  in the range (p - window_size, p].

  Args:
    x: a mtf.Tensor with shape batch_dims + [length, io_channels]
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    window_size: an integer
    master_dtype: a tf.dtype (deprecated - use params arg)
    slice_dtype: a tf.dtype (deprecated - use params arg)
    length_per_split: an optional integer indicating the size of the
      per-processor part of the length dimension.  You can omit this if the
      length dimension is not split.
    return_kv: an optional list onto which to append the computed k and v.
    params: an optional quadruple of Tensors (see multihead_attention_params())
    name: an optional string.

  Returns:
    a Tensor with the same shape as x

  Raises:
    ValueError: if channels or depth don't match.
  """
  with tf.variable_scope(
      name, default_name="masked_local_attention_1d", values=[x]):

    batch_dims = x.shape.dims[:-2]
    length, io_channels = x.shape.dims[-2:]
    if params is None:
      wq, wk, wv, wo = multihead_attention_vars(
          x.mesh, heads, io_channels, kv_channels,
          master_dtype, slice_dtype, x.dtype)
    else:
      wq, wk, wv, wo = params

    # Get query q, keys k and values v.
    qkv_shape = mtf.Shape(batch_dims + [heads, length, kv_channels])
    q = mtf.einsum([x, wq], qkv_shape)
    k = mtf.einsum([x, wk], qkv_shape)
    v = mtf.einsum([x, wv], qkv_shape)
    if return_kv is not None:
      return_kv.extend([k, v])

    # Choose a suitable block size.
    # We choose the greatest divisor of length_per_split less than or equal
    # to max(window_size, 128)
    if length_per_split is None:
      length_per_split = length.size
    block_length = max(window_size, 128)
    while length_per_split % block_length != 0:
      block_length -= 1
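    # For example, length_per_split=1000 and window_size=128 give
    # block_length=125.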

    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(length.name, length.size // block_length)
    q_shape = batch_dims + [heads, num_blocks, query_block_length, kv_channels]
    kv_shape = batch_dims + [
        heads, num_blocks, memory_block_length, kv_channels]
    q = mtf.reshape(q, q_shape)
    k = mtf.reshape(k, kv_shape)
    v = mtf.reshape(v, kv_shape)
    # augment the keys and values for each block with keys and values for
    # the previous window_size timesteps.
    k = mtf.left_halo_exchange(k, num_blocks, memory_block_length, window_size)
    v = mtf.left_halo_exchange(v, num_blocks, memory_block_length, window_size)
    padded_memory_block_length = mtf.Dimension(
        "memory_block_length", window_size + block_length)
    mpos = mtf.range(x.mesh, padded_memory_block_length, tf.float32)
    qpos = mtf.range(x.mesh, query_block_length, tf.float32) + window_size
    # prevent looking forward
    mask = mtf.cast(mtf.greater(mpos, qpos), x.dtype) * -1e9
    # prevent looking >=block_length timesteps backward
    mask += mtf.cast(mtf.less_equal(mpos, qpos - block_length), x.dtype) * -1e9
    # Note: The first window_size-1 positions can see back into pre-time
    # where all the keys and values are zero.  We could mask this out, but we
    # don't.
    o = dot_product_attention(q, k, v, mask=mask)
    o = mtf.reshape(o, batch_dims + [heads, length, kv_channels])
    return mtf.einsum([o, wo], mtf.Shape(batch_dims + [length, io_channels]))


def masked_local_attention_1d_incremental(x,
                                          prev_k,
                                          prev_v,
                                          step_num,
                                          master_dtype=None,
                                          slice_dtype=None,
                                          params=None,
                                          name=None):
  """Incremental local self-attention (one decode step).

  Incremental version of masked_local_attention_1d()

  Args:
    x: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]
    prev_v: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    master_dtype: a tf.dtype (deprecated)
    slice_dtype: a tf.dtype (deprecated)
    params: a quadruple of Tensors (see multihead_attention_params())
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]
    new_v: mtf.Tensor with shape
       [batch..., heads, window_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = x.shape.dims[:-1]
  io_channels = x.shape.dims[-1]
  heads, window_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="masked_local_attention_1d"):
    if params is None:
      wq, wk, wv, wo = multihead_attention_vars(
          x.mesh, heads, io_channels, kv_channels,
          master_dtype, slice_dtype, x.dtype)
    else:
      wq, wk, wv, wo = params
    q = mtf.einsum([x, wq], mtf.Shape(batch_dims + [heads, kv_channels]))
    k = mtf.einsum([x, wk], mtf.Shape(batch_dims + [heads, kv_channels]))
    v = mtf.einsum([x, wv], mtf.Shape(batch_dims + [heads, kv_channels]))
    current_position = mtf.equal(
        mtf.range(x.mesh, window_length, dtype=tf.int32),
        mtf.mod(step_num, window_length.size))
    k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
    v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
    o = dot_product_attention(q, k, v, mask=None)
    y = mtf.einsum([o, wo], x.shape)
    return y, k, v


def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,
                           num_w_blocks, w_dim, mask_right):
  """Halo exchange for keys and values for Local 2D attention."""
  for blocks_dim, block_size_dim, halo_size in [
      (num_h_blocks, h_dim, h_dim.size),
      (num_w_blocks, w_dim, w_dim.size)]:
    # shape of k is [num_h_blocks, num_w_blocks, h_dim, w_dim, kv_channels]
    if halo_size > 0:
      if blocks_dim is not None:
        if mask_right:
          k = mtf.left_halo_exchange(k, blocks_dim, block_size_dim, halo_size)
          v = mtf.left_halo_exchange(v, blocks_dim, block_size_dim, halo_size)
        else:
          k = mtf.halo_exchange(k, blocks_dim, block_size_dim, halo_size)
          v = mtf.halo_exchange(v, blocks_dim, block_size_dim, halo_size)
      else:
        if mask_right:
          k = mtf.pad(k, [halo_size, None], block_size_dim.name)
          v = mtf.pad(v, [halo_size, None], block_size_dim.name)
        else:
          k = mtf.pad(k, [halo_size, halo_size], block_size_dim.name)
          v = mtf.pad(v, [halo_size, halo_size], block_size_dim.name)
  return k, v


def local_2d_self_attention_spatial_blocks(query_antecedent,
                                           kv_channels,
                                           heads,
                                           memory_h_dim=None,
                                           memory_w_dim=None,
                                           mask_right=False,
                                           master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           name=None):
  """Attention to the source position and a neighborhood to the left or right.

  The sequence is divided into blocks of length block_size.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels].
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    memory_h_dim: mtf Dimension, for the memory height block.
    memory_w_dim: mtf Dimension, for the memory width block.
    mask_right: bool, flag specifying whether we mask out attention to the right
      for the decoder.
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string.

  Returns:
    a Tensor of shape
        [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
  with tf.variable_scope(
      name, default_name="multihead_attention", values=[query_antecedent]):

    h_dim, w_dim, io_channels = query_antecedent.shape.dims[-3:]
    batch, num_h_blocks, num_w_blocks = query_antecedent.shape.dims[:3]
    wq, wk, wv, wo = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        master_dtype, slice_dtype, query_antecedent.dtype)

    # Rename dimensions for the memory height and width.
    memory_antecedent = mtf.rename_dimension(query_antecedent, h_dim.name,
                                             "memory_" + h_dim.name)
    memory_antecedent = mtf.rename_dimension(memory_antecedent, w_dim.name,
                                             "memory_" + w_dim.name)
    memory_h_dim, memory_w_dim = memory_antecedent.shape.dims[-3:-1]

    # Call einsum over the query and memory to get query q, keys k and values v.
    q = mtf.einsum([query_antecedent, wq],
                   mtf.Shape([
                       batch, heads, num_h_blocks, num_w_blocks, h_dim, w_dim,
                       kv_channels
                   ]))
    k = mtf.einsum([memory_antecedent, wk],
                   mtf.Shape([batch, heads, num_h_blocks, num_w_blocks,
                              memory_h_dim, memory_w_dim, kv_channels]))
    v = mtf.einsum([memory_antecedent, wv],
                   mtf.Shape([batch, heads, num_h_blocks, num_w_blocks,
                              memory_h_dim, memory_w_dim, kv_channels]))

    # Halo exchange for memory blocks.
    k, v = local_2d_halo_exchange(k, v, num_h_blocks, memory_h_dim,
                                  num_w_blocks, memory_w_dim, mask_right)

    # Calculate the causal mask to avoid peeking into the future. We compute
    # this once and reuse it for all blocks since the block dimensions are
    # fixed.
    mask = None
    if mask_right:
      mask = attention_bias_local_2d_block(query_antecedent.mesh, h_dim, w_dim,
                                           memory_h_dim, memory_w_dim)

    output = dot_product_attention(q, k, v, mask=mask)

    return mtf.einsum(
        [output, wo],
        mtf.Shape(
            [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]))


def rename_length_to_memory_length(
    x, length_name="length", memory_length_name="memory_length"):
  return mtf.rename_dimension(x, length_name, memory_length_name)


def multihead_attention_vars(
    mesh, heads, io_channels, kv_channels,
    master_dtype, slice_dtype, activation_dtype):
  """Deprecated version of multihead_attention_params with combine=True."""
  return multihead_attention_params(
      mesh, heads, io_channels, kv_channels,
      mtf.VariableDType(master_dtype, slice_dtype, activation_dtype),
      combine=True)


def multihead_attention_params(mesh, heads, io_channels, kv_channels,
                               variable_dtype, combine=False):
  """Create Parameters for Multihead Attention.

  If the combine flag is set to True, then we create only one variable
  which stacks together all of the parameters.  Otherwise, we create four
  separate variables.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    variable_dtype: a mtf.VariableDType
    combine: a boolean

  Returns:
    wq: a Tensor with shape [heads, io_channels, kv_channels]
    wk: a Tensor with shape [heads, io_channels, kv_channels]
    wv: a Tensor with shape [heads, io_channels, kv_channels]
    wo: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  # TODO(noam): should be: o_stddev = (kv_channels.size * heads.size) ** -0.5
  #   verify that this still works and change it.
  o_stddev = (io_channels.size * heads.size) ** -0.5
  if combine:
    def qkvo_initializer(shape,
                         dtype=None,
                         partition_info=None,
                         verify_shape=None):
      del partition_info, verify_shape
      return tf.random_normal(shape, dtype=dtype) * tf.reshape(
          tf.cast([qk_stddev, qk_stddev, v_stddev, o_stddev],
                  dtype or tf.float32), [4, 1, 1, 1])
    var = mtf.get_variable(
        mesh, "qkvo", mtf.Shape([qkvo, heads, io_channels, kv_channels]),
        initializer=qkvo_initializer, dtype=variable_dtype)
    return mtf.unstack(var, qkvo)
  else:
    return [mtf.get_variable(  # pylint: disable=g-complex-comprehension
        mesh, name, mtf.Shape([heads, io_channels, kv_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        dtype=variable_dtype) for name, stddev in zip(
            ["q", "k", "v", "o"],
            [qk_stddev, qk_stddev, v_stddev, o_stddev])]
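

# Usage sketch (illustrative only, not called anywhere in this module): with
# combine=False, multihead_attention_params returns four separate
# [heads, io_channels, kv_channels] tensors.  The mesh name and dimension
# sizes below are arbitrary placeholders.
def _example_multihead_attention_params():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  heads = mtf.Dimension("heads", 8)
  io_channels = mtf.Dimension("io_channels", 512)
  kv_channels = mtf.Dimension("kv_channels", 64)
  variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
  # Each returned tensor has shape [heads, io_channels, kv_channels].
  wq, wk, wv, wo = multihead_attention_params(
      mesh, heads, io_channels, kv_channels, variable_dtype, combine=False)
  return wq, wk, wv, wo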


def dot_product_attention(q,
                          k,
                          v,
                          mask,
                          dropout=0.0,
                          dropout_broadcast_dims=None,
                          extra_logit=None):
  """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k]. Typically leading dimensions
      are [batch, heads].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match with q.
    mask: mask Tensor (see attention_mask())
    dropout: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  length_kv = k.shape.dims[-2]
  logits_shape = mtf.Shape(q.shape.dims[:-1] + [length_kv])
  logits = mtf.einsum([q, k], logits_shape)
  if mask is not None:
    logits += mask
  weights = mtf.softmax(logits, length_kv, extra_logit=extra_logit)
  if dropout != 0.0:
    weights = mtf.dropout(
        weights, 1.0 - dropout,
        noise_shape=weights.shape - dropout_broadcast_dims)
  depth_v = v.shape.dims[-1]
  outputs_shape = mtf.Shape(q.shape.dims[:-1] + [depth_v])
  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
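

# Usage sketch (illustrative only, not called anywhere in this module): wires
# random q/k/v tensors through dot_product_attention with no mask.  Dimension
# names and sizes are arbitrary placeholders; note that this helper does not
# itself rescale the logits by 1/sqrt(depth_k).
def _example_dot_product_attention():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  heads = mtf.Dimension("heads", 4)
  length_q = mtf.Dimension("length_q", 16)
  length_kv = mtf.Dimension("length_kv", 16)
  depth_k = mtf.Dimension("depth_k", 32)
  depth_v = mtf.Dimension("depth_v", 32)
  q = mtf.random_uniform(mesh, [batch, heads, length_q, depth_k])
  k = mtf.random_uniform(mesh, [batch, heads, length_kv, depth_k])
  v = mtf.random_uniform(mesh, [batch, heads, length_kv, depth_v])
  # The output has shape [batch, heads, length_q, depth_v].
  return dot_product_attention(q, k, v, mask=None)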


def multihead_attention(query_antecedent,
                        memory_antecedent,
                        mask,
                        kv_channels,
                        heads,
                        dropout=0.0,
                        dropout_broadcast_dims=None,
                        master_dtype=tf.float32,
                        slice_dtype=tf.float32,
                        name="multihead_attention"):
  """Multihead scaled-dot-product attention with input/output transformations.

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape
      [<batch_dims>, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [<batch_dims>, memory_length, io_channels] (optional).  If None, we use
      query_antecedent (self-attention).
    mask: mask Tensor (see attention_mask())
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    dropout: a floating point value
    dropout_broadcast_dims: an optional list of mtf.Dimension
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string.

  Returns:
    A mtf.Tensor with shape [<batch_dims>, query_length, io_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = query_antecedent.shape.dims[:-2]
  query_length, io_channels = query_antecedent.shape.dims[-2:]
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    wq, wk, wv, wo = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        master_dtype, slice_dtype, query_antecedent.dtype)
    if memory_antecedent is None:
      memory_antecedent = rename_length_to_memory_length(
          query_antecedent, query_length.name)
    memory_batch_dims = memory_antecedent.shape.dims[:-2]
    memory_length, memory_channels = memory_antecedent.shape.dims[-2:]
    if memory_batch_dims != batch_dims:
      raise ValueError("memory batch must equal query batch")
    if memory_channels != io_channels:
      raise ValueError("memory channels must equal query channels")
    q = mtf.einsum(
        [query_antecedent, wq],
        mtf.Shape(batch_dims + [heads, query_length, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, wk],
        mtf.Shape(batch_dims + [heads, memory_length, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, wv],
        mtf.Shape(batch_dims + [heads, memory_length, kv_channels]))
    o = dot_product_attention(
        q, k, v, mask, dropout, dropout_broadcast_dims)
    return mtf.einsum(
        [o, wo], mtf.Shape(batch_dims + [query_length, io_channels]))
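

# Usage sketch (illustrative only, not called anywhere in this module):
# autoregressive self-attention over a [batch, length, io_channels] activation.
# Passing memory_antecedent=None selects the self-attention path; the
# dimension names and sizes are arbitrary placeholders.
def _example_multihead_self_attention():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 16)
  io_channels = mtf.Dimension("io_channels", 512)
  kv_channels = mtf.Dimension("kv_channels", 64)
  heads = mtf.Dimension("heads", 8)
  x = mtf.random_uniform(mesh, [batch, length, io_channels])
  query_pos = mtf.range(mesh, length, dtype=tf.int32)
  mask = attention_mask_autoregressive(query_pos, dtype=x.dtype)
  # The output has the same shape as x: [batch, length, io_channels].
  return multihead_attention(x, None, mask, kv_channels, heads)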


def multihead_self_attention_incremental(query_antecedent,
                                         prev_k,
                                         prev_v,
                                         step_num,
                                         master_dtype,
                                         slice_dtype,
                                         name="multihead_attention"):
  """Incremental self-attention (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    prev_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    new_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = query_antecedent.shape.dims[:-1]
  io_channels = query_antecedent.shape.dims[-1]
  heads, memory_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="multihead_attention"):
    wq, wk, wv, wo = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        master_dtype, slice_dtype, query_antecedent.dtype)
    memory_antecedent = query_antecedent
    q = mtf.einsum(
        [query_antecedent, wq],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, wk],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, wv],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    k = prev_k + mtf.multiply(
        k, mtf.one_hot(step_num, memory_length, dtype=prev_k.dtype),
        output_shape=prev_k.shape)
    v = prev_v + mtf.multiply(
        v, mtf.one_hot(step_num, memory_length, dtype=prev_v.dtype),
        output_shape=prev_v.shape)

    mask = mtf.cast(
        mtf.greater(mtf.range(
            query_antecedent.mesh, memory_length, dtype=tf.int32), step_num),
        q.dtype) * -1e9
    o = dot_product_attention(q, k, v, mask)
    y = mtf.einsum([o, wo], query_antecedent.shape)
    return y, k, v
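

# Usage sketch (illustrative only, not called anywhere in this module): one
# decode step through multihead_self_attention_incremental, with zero-filled
# key/value caches and step_num == 0.  Dimension names and sizes are arbitrary
# placeholders; in a real decoder the returned new_k/new_v are fed back in as
# prev_k/prev_v at the next step.
def _example_self_attention_decode_step():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  io_channels = mtf.Dimension("io_channels", 512)
  heads = mtf.Dimension("heads", 8)
  memory_length = mtf.Dimension("memory_length", 16)
  kv_channels = mtf.Dimension("kv_channels", 64)
  x = mtf.random_uniform(mesh, [batch, io_channels])
  prev_k = mtf.zeros(mesh, [batch, heads, memory_length, kv_channels])
  prev_v = mtf.zeros(mesh, [batch, heads, memory_length, kv_channels])
  step_num = mtf.constant(mesh, 0, shape=[], dtype=tf.int32)
  y, new_k, new_v = multihead_self_attention_incremental(
      x, prev_k, prev_v, step_num,
      master_dtype=tf.float32, slice_dtype=tf.float32)
  return y, new_k, new_v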


def multihead_encdec_attention_incremental(query_antecedent,
                                           wq, wo, k, v,
                                           mask,
                                           name="multihead_attention"):
  """Incremental attention over encoder (one decode step).

  The parameters wq and wo, as well as the encoder keys k and values v, are
  computed once outside this function (see multihead_attention_params for the
  parameter shapes) and reused at every decode step.

  memory_dims is a subset of query_dims.

  Args:
    query_antecedent: a mtf.Tensor with shape query_dims + [io_channels]
    wq: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    wo: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    k: memory_dims + [heads, memory_length, kv_channels]
    v: memory_dims + [heads, memory_length, kv_channels]
    mask: mask Tensor (see attention_mask())
    name: an optional string.

  Returns:
    A mtf.Tensor with shape query_dims + [io_channels]
  """
  heads, _, kv_channels = k.shape.dims[-3:]
  query_dims = query_antecedent.shape.dims[:-1]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q = mtf.einsum(
        [query_antecedent, wq],
        mtf.Shape(query_dims + [heads, kv_channels]))
    o = dot_product_attention(q, k, v, mask)
    return mtf.einsum([o, wo], query_antecedent.shape)
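

# Usage sketch (illustrative only, not called anywhere in this module): for
# incremental encoder-decoder attention, wq/wo and the encoder keys/values are
# computed once and reused at every decode step.  Dimension names and sizes
# are arbitrary placeholders, and mask=None assumes an unpadded encoder.
def _example_encdec_attention_decode_step():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  io_channels = mtf.Dimension("io_channels", 512)
  heads = mtf.Dimension("heads", 8)
  kv_channels = mtf.Dimension("kv_channels", 64)
  memory_length = mtf.Dimension("memory_length", 16)
  variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
  wq, wk, wv, wo = multihead_attention_params(
      mesh, heads, io_channels, kv_channels, variable_dtype)
  encoder_output = mtf.random_uniform(
      mesh, [batch, memory_length, io_channels])
  # k and v are precomputed from the encoder output.
  k = mtf.einsum([encoder_output, wk],
                 mtf.Shape([batch, heads, memory_length, kv_channels]))
  v = mtf.einsum([encoder_output, wv],
                 mtf.Shape([batch, heads, memory_length, kv_channels]))
  # One decode step: query_antecedent has no length dimension.
  query = mtf.random_uniform(mesh, [batch, io_channels])
  return multihead_encdec_attention_incremental(query, wq, wo, k, v, mask=None)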


def attention_mask_ignore_padding(inputs, dtype=tf.float32):
  """Bias for encoder-decoder attention.

  Args:
    inputs: a mtf.Tensor with shape [..., length_dim]
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [..., memory_length_dim]
  """
  inputs = rename_length_to_memory_length(inputs)
  return mtf.cast(mtf.equal(inputs, 0), dtype) * -1e9
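

# Usage sketch (illustrative only, not called anywhere in this module): builds
# the encoder-decoder padding bias from integer token ids, under the
# convention that id 0 means padding.  The ids here are placeholder zeros, so
# every memory position ends up masked with -1e9.
def _example_attention_mask_ignore_padding():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 16)
  token_ids = mtf.zeros(mesh, [batch, length], dtype=tf.int32)
  # The bias has shape [batch, memory_length].
  return attention_mask_ignore_padding(token_ids, dtype=tf.float32)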


def attention_mask_autoregressive(query_pos, dtype=tf.float32):
  """Bias for self-attention where attention to the right is disallowed.

  Args:
    query_pos: a mtf.Tensor with shape [..., length_dim]
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [..., length_dim, memory_length_dim]
  """
  memory_pos = rename_length_to_memory_length(query_pos)
  return mtf.cast(mtf.less(query_pos, memory_pos), dtype) * -1e9


def attention_mask_same_segment(
    query_segment, memory_segment=None, dtype=tf.float32):
  """Bias for attention where attention between segments is disallowed.

  Args:
    query_segment: a mtf.Tensor with shape [..., length_dim]
    memory_segment: a mtf.Tensor with shape [..., memory_length_dim]
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [..., length_dim, memory_length_dim]
  """
  memory_segment = rename_length_to_memory_length(
      memory_segment or query_segment)
  return mtf.cast(mtf.not_equal(query_segment, memory_segment), dtype) * -1e9
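

# Usage sketch (illustrative only, not called anywhere in this module): for
# packed examples, a decoder bias is commonly formed by adding the
# autoregressive bias and the same-segment bias, both of which broadcast to
# [..., length, memory_length].  The positions and segment ids below are
# placeholders; in practice they come from the input pipeline.
def _example_packed_decoder_bias():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 16)
  positions = mtf.range(mesh, length, dtype=tf.int32)
  segment_ids = mtf.zeros(mesh, [batch, length], dtype=tf.int32)
  return (attention_mask_autoregressive(positions, dtype=tf.float32) +
          attention_mask_same_segment(segment_ids, dtype=tf.float32))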


def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
  """Bias for attention for local blocks where attention to right is disallowed.

  The bias matrix is built from two separate masks: one for the memory part
  that doesn't overlap with the query, and a second for the part that interacts
  with the query, where attention to the right of the current query position is
  disallowed.

  Args:
    mesh: a MeshTensorflow object
    block_length: a mtf.Dimension
    memory_length: a mtf.Dimension
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [block_length, memory_length]
  """
  memory_length = mtf.Dimension(memory_length.name, block_length.size)
  memory_mask = mtf.zeros(mesh, [block_length, memory_length], dtype=dtype)

  mask = mtf.cast(mtf.less(mtf.range(mesh, block_length, dtype=dtype),
                           mtf.range(mesh, memory_length, dtype=dtype)),
                  dtype=dtype)
  mask = mtf.cast(
      mtf.concat([memory_mask, mask], memory_length.name),
      dtype=tf.float32) * -1e9
  return mask


def attention_bias_local_2d_block(mesh,
                                  h_dim,
                                  w_dim,
                                  memory_h_dim,
                                  memory_w_dim,
                                  dtype=tf.int32):
  """Bias for attention for local blocks where attention to right is disallowed.

  The bias matrix is built from two separate masks: one for the memory part
  that doesn't overlap with the query, and a second for the part that interacts
  with the query, where attention to the right of the current query position is
  disallowed.

  Args:
    mesh: a MeshTensorflow object
    h_dim: a mtf.Dimension
    w_dim: a mtf.Dimension
    memory_h_dim: a mtf.Dimension
    memory_w_dim: a mtf.Dimension
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor attention bias with -1e9 at disallowed positions
  """
  memory_height = mtf.Dimension(memory_h_dim.name, h_dim.size)
  memory_width = mtf.Dimension(memory_w_dim.name, w_dim.size)
  mask_top_visible = mtf.zeros(mesh, [h_dim, memory_height], dtype=dtype)
  mask_left_visible = mtf.zeros(mesh, [w_dim, memory_width], dtype=dtype)
  mask_query = mtf.greater(
      mtf.range(mesh, memory_height, dtype=tf.int32),
      mtf.range(mesh, memory_width, dtype=dtype))
  width_mask = mtf.concat([mask_left_visible, mask_query], memory_width.name)
  mask = mtf.cast(
      mtf.concat([mask_top_visible, width_mask], memory_height.name),
      dtype=tf.float32) * -1e9
  return mask


def multiplicative_jitter(x, epsilon=1e-2):
  """Multiply values by a random number between 1-epsilon and 1+epsilon.

  Makes models more resilient to rounding errors introduced by bfloat16.
  This seems particularly important for logits.

  Args:
    x: a mtf.Tensor
    epsilon: a floating point value

  Returns:
    a mtf.Tensor with the same type and shape as x.
  """
  if epsilon == 0:
    return x
  return x * mtf.random_uniform(
      x.mesh, x.shape, minval=1.0 - epsilon, maxval=1.0+epsilon, dtype=x.dtype)


def multihead_self_attention_memory_compressed(x,
                                               mask_right,
                                               compression_factor,
                                               kv_channels,
                                               heads,
                                               dropout=0.0,
                                               dropout_broadcast_dims=None,
                                               master_dtype=tf.float32,
                                               slice_dtype=tf.float32,
                                               name="multihead_attention"):
  """Memory-compressed self-attention.

  The memory is first average-pooled (strided) to make it shorter by
  a factor of compression_factor.

  Args:
    x: a mtf.Tensor with shape
      [<batch_dims>, query_length, io_channels]
    mask_right: a boolean
    compression_factor: an integer
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    dropout: a floating point value
    dropout_broadcast_dims: an optional list of mtf.Dimension
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string.

  Returns:
    A mtf.Tensor with shape [batch, query_length, io_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = x.shape.dims[:-2]
  length, io_channels = x.shape.dims[-2:]
  with tf.variable_scope(name,
                         default_name="compressed_attention",
                         values=[x]):
    wq, wk, wv, wo = multihead_attention_vars(
        x.mesh, heads, io_channels, kv_channels,
        master_dtype, slice_dtype, x.dtype)
    memory_antecedent = compress_mean(x, length, compression_factor)
    memory_antecedent = rename_length_to_memory_length(memory_antecedent)
    memory_length = memory_antecedent.shape.dims[-2]
    q = mtf.einsum(
        [x, wq],
        mtf.Shape(batch_dims + [heads, length, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, wk],
        mtf.Shape(batch_dims + [heads, memory_length, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, wv],
        mtf.Shape(batch_dims + [heads, memory_length, kv_channels]))
    if mask_right:
      query_pos = mtf.range(x.mesh, length, dtype=tf.int32)
      memory_pos = (
          mtf.range(x.mesh, memory_length, dtype=tf.int32) * compression_factor
          + (compression_factor - 1))
      mask = mtf.cast(mtf.greater(memory_pos, query_pos), x.dtype) * -1e9
    else:
      mask = None
    o = dot_product_attention(
        q, k, v, mask, dropout, dropout_broadcast_dims, extra_logit=0.0)
    return mtf.einsum(
        [o, wo], mtf.Shape(batch_dims + [length, io_channels]))
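

# Usage sketch (illustrative only, not called anywhere in this module):
# memory-compressed self-attention over a length-16 sequence with
# compression_factor=4, so keys and values cover a memory of length 4.
# Dimension names and sizes are arbitrary placeholders; the length must be
# divisible by the compression factor.
def _example_memory_compressed_attention():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 16)
  io_channels = mtf.Dimension("io_channels", 512)
  kv_channels = mtf.Dimension("kv_channels", 64)
  heads = mtf.Dimension("heads", 8)
  x = mtf.random_uniform(mesh, [batch, length, io_channels])
  # The output has the same shape as x: [batch, length, io_channels].
  return multihead_self_attention_memory_compressed(
      x, mask_right=True, compression_factor=4,
      kv_channels=kv_channels, heads=heads)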


def compress_mean(x, dim, compression_factor):
  """Compress by taking group means.

  Args:
    x: a Tensor
    dim: a dimension in x.shape
    compression_factor: an integer

  Returns:
    a Tensor
  """
  dims = x.shape.dims
  pos = dims.index(dim)
  compressed_dim = mtf.Dimension(dim.name, dim.size // compression_factor)
  compression_factor_dim = mtf.Dimension(
      "compression_factor", compression_factor)
  new_shape = (
      dims[:pos] + [compressed_dim, compression_factor_dim] + dims[pos + 1:])
  x = mtf.reshape(x, new_shape)
  x = mtf.reduce_mean(x, reduced_dim=compression_factor_dim)
  return x
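

# Usage sketch (illustrative only, not called anywhere in this module):
# compressing a length-16 dimension by a factor of 4 leaves a length-4
# dimension of group means, with all other dimensions unchanged.  Names and
# sizes are arbitrary placeholders.
def _example_compress_mean():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 16)
  io_channels = mtf.Dimension("io_channels", 8)
  x = mtf.random_uniform(mesh, [batch, length, io_channels])
  # The result has shape [batch, mtf.Dimension("length", 4), io_channels].
  return compress_mean(x, length, compression_factor=4)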


def embedding_weights(mesh,
                      vocab_dim,
                      output_dim,
                      variable_dtype,
                      name="embedding",
                      ensemble_dim=None,
                      initializer=None):
  """Embedding weights."""
  if not ensemble_dim:
    ensemble_dim = []
  elif not isinstance(ensemble_dim, list):
    ensemble_dim = [ensemble_dim]
  shape = mtf.Shape(ensemble_dim) + [vocab_dim, output_dim]
  if initializer is None:
    initializer = tf.random_normal_initializer()
  ret = mtf.get_variable(
      mesh, name, shape, dtype=variable_dtype, initializer=initializer)
  return ret


def embedding(indices, vocab_dim, output_dim, variable_dtype, name="embedding"):
  """Embedding layer."""
  weights = embedding_weights(
      indices.mesh, vocab_dim, output_dim, variable_dtype, name)
  return mtf.gather(weights, indices, vocab_dim)
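

# Usage sketch (illustrative only, not called anywhere in this module):
# embedding integer ids of shape [batch, length] into [batch, length, d_model]
# vectors.  The ids here are placeholder zeros and the dimension names and
# sizes are arbitrary.
def _example_embedding():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 16)
  vocab = mtf.Dimension("vocab", 100)
  d_model = mtf.Dimension("d_model", 32)
  ids = mtf.zeros(mesh, [batch, length], dtype=tf.int32)
  variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
  return embedding(ids, vocab, d_model, variable_dtype)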


def max_pool2d(x, ksize=(2, 2), name="max_pool2d"):
  """2D max pooling.

  Pooling is applied on the HW dimensions. We assume the dimensions of x are
  [NHWC]. There can be multiple batch dimensions, e.g., [10, 4, 4, 10, 10, 3].
  Currently we only support non-overlapping pooling: strides == ksize. Also the
  input HW dimensions must be divisible by ksize.

  Args:
    x: a Tensor
    ksize: kernel size. A list or tuple
    name: an optional string

  Returns:
    a Tensor
  """
  return x if tuple(ksize) == (1, 1) else mtf.PoolOperation(
      x, ksize, strides=ksize, pool_fn_string="MAX_2D", name=name).outputs[0]


def max_pool3d(x, ksize=(2, 2, 2), name="max_pool3d"):
  """3D max pooling.

  Pooling is applied on the DHW dimensions. We assume the dimensions of x are
  [NDHWC]. There can be multiple batch dimensions, e.g.,
  [10, 4, 4, 10, 10, 10, 3].
  Currently we only support non-overlapping pooling: strides == ksize. Also the
  input DHW dimensions must be divisible by ksize.

  Args:
    x: a Tensor
    ksize: kernel size. A list or tuple
    name: an optional string

  Returns:
    a Tensor
  """
  return x if tuple(ksize) == (1, 1, 1) else mtf.PoolOperation(
      x, ksize, strides=ksize, pool_fn_string="MAX_3D", name=name).outputs[0]


def avg_pool2d(x, ksize=(2, 2), name="avg_pool2d"):
  """2D average pooling.

  Pooling is applied on the HW dimensions. We assume the dimensions of x are
  [NHWC]. There can be multiple batch dimensions, e.g., [10, 4, 4, 10, 10, 3].
  Currently we only support non-overlapping pooling: strides == ksize. Also the
  input HW dimensions must be divisible by ksize.

  Args:
    x: a Tensor
    ksize: kernel size. A list or tuple
    name: an optional string

  Returns:
    a Tensor
  """
  return x if tuple(ksize) == (1, 1) else mtf.PoolOperation(
      x, ksize, strides=ksize, pool_fn_string="AVG_2D", name=name).outputs[0]


def avg_pool3d(x, ksize=(2, 2, 2), name="avg_pool3d"):
  """3D average pooling.

  Pooling is applied on the DHW dimensions. We assume the dimensions of x are
  [NDHWC]. There can be multiple batch dimensions, e.g.,
  [10, 4, 4, 10, 10, 10, 3].
  Currently we only support non-overlapping pooling: strides == ksize. Also the
  input DHW dimensions must be divisible by ksize.

  Args:
    x: a Tensor
    ksize: kernel size. A list or tuple
    name: an optional string

  Returns:
    a Tensor
  """
  return x if tuple(ksize) == (1, 1, 1) else mtf.PoolOperation(
      x, ksize, strides=ksize, pool_fn_string="AVG_3D", name=name).outputs[0]


def _reversible_half_residual_grad(
    explicit_inputs, all_inputs, forward_operations, outputs, output_grads):
  """Backpropagation function for a revnet."""
  x1, _, x2, _ = explicit_inputs
  extra_inputs = all_inputs[len(explicit_inputs):]
  _, _, y1, _ = outputs
  dy2, dy2_backwards, dy1, dy1_backwards = output_grads
  # last operation should be an addition to produce y1
  if not isinstance(forward_operations[-1], mtf.AddOperation):
    raise ValueError("expected an addition here")
  f_ops = forward_operations[:-1]
  orig_fx2 = f_ops[-1].outputs[0]
  orig_x2 = x2
  if dy2_backwards is not None:
    x2 = dy2_backwards
  if dy1_backwards is not None:
    y1 = dy1_backwards
  graph = all_inputs[0].graph
  f_again_ops, mapping = graph.clone_operations(f_ops, {orig_x2: x2})
  fx2 = mapping[orig_fx2]
  x1 = y1 - fx2
  grads = mtf.gradients(ys=[fx2], xs=[x2] + extra_inputs, grad_ys=[dy1],
                        operations=f_again_ops)
  dx2 = dy2 + grads[0]
  extra_inputs_grads = grads[1:]
  dx1 = dy1
  return [dx1, x1, dx2, x2] + extra_inputs_grads


def _half_residual_and_swap(x1, x1_backwards, x2, x2_backwards, f=None):
  return x2, x2_backwards, x1 + f(x2), x1_backwards


def reversible_half_residual_and_swap(x1,
                                      x1_backwards,
                                      x2,
                                      x2_backwards,
                                      f,
                                      recompute_grads=True):
  """Building block of a revnet.

  https://arxiv.org/abs/1707.04585

  All the inputs and output Tensors have the same shape and dtype.

  The forward computation is:
    y1 = x1 + f(x2)
    y2 = x2

  The x1_backwards and x2_backwards tensors are used by backpropagation.
  None should be passed for the first layer, then the outputs of each layer
  should be passed to the next.

  Example usage:
  x1, x1_backwards, x2, x2_backwards = x, None, x, None
  for f in my_functions:
    x1, x1_backwards, x2, x2_backwards = (
        mtf.layers.reversible_half_residual_and_swap(
            x1, x1_backwards, x2, x2_backwards, f))
  y = (x1 + x2) / 2

  Args:
    x1: a Tensor
    x1_backwards: a Tensor or None
    x2: a Tensor
    x2_backwards: a Tensor or None
    f: a function from Tensor to Tensor
    recompute_grads: a boolean
  Returns:
    y2: a Tensor
    y2_backwards: a Tensor
    y1: a Tensor
    y1_backwards: a Tensor
  """
  if recompute_grads:
    if x1_backwards is None:
      x1_backwards = mtf.zeros_like(x1)
    if x2_backwards is None:
      x2_backwards = mtf.zeros_like(x2)
    return mtf.custom_gradient(
        functools.partial(_half_residual_and_swap, f=f),
        _reversible_half_residual_grad,
        [x1, x1_backwards, x2, x2_backwards])
  else:
    return _half_residual_and_swap(x1, x1_backwards, x2, x2_backwards, f)
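

# Usage sketch (illustrative only, not called anywhere in this module):
# chaining two reversible half-residual layers, following the pattern in the
# docstring above.  A trivial shape-preserving lambda stands in for a real
# layer function f; dimension names and sizes are arbitrary placeholders.
def _example_reversible_stack():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  io_channels = mtf.Dimension("io_channels", 8)
  x = mtf.random_uniform(mesh, [batch, io_channels])
  x1, x1_backwards, x2, x2_backwards = x, None, x, None
  for f in [lambda t: t * 0.5, lambda t: t * 0.5]:
    x1, x1_backwards, x2, x2_backwards = reversible_half_residual_and_swap(
        x1, x1_backwards, x2, x2_backwards, f)
  return (x1 + x2) * 0.5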