# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discretization bottlenecks used to train discrete latent variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

from tensor2tensor.layers import common_layers

import tensorflow as tf

from tensorflow.python.training import moving_averages


def project_hidden(x, projection_tensors, hidden_size, num_blocks):
  """Project encoder hidden state into block_dim using projection tensors.

  Args:
    x: Encoder hidden state of shape [-1, hidden_size].
    projection_tensors: Projection tensors of shape
      [num_blocks, hidden_size, block_dim] used to project the hidden state.
    hidden_size: Dimension of the latent space.
    num_blocks: Number of blocks in DVQ.

  Returns:
    Projected states of shape [-1, num_blocks, block_dim].
  """
  x = tf.reshape(x, shape=[1, -1, hidden_size])
  x_tiled = tf.reshape(
      tf.tile(x, multiples=[num_blocks, 1, 1]),
      shape=[num_blocks, -1, hidden_size])
  x_projected = tf.matmul(x_tiled, projection_tensors)
  x_projected = tf.transpose(x_projected, perm=[1, 0, 2])
  return x_projected


def slice_hidden(x, hidden_size, num_blocks):
  """Slice encoder hidden state into block_dim.

  Args:
    x: Encoder hidden state of shape [-1, hidden_size].
    hidden_size: Dimension of the latent space.
    num_blocks: Number of blocks in DVQ.

  Returns:
    Sliced states of shape [-1, num_blocks, block_dim].
  """
  block_dim = int(hidden_size // num_blocks)
  x_sliced = tf.reshape(x, shape=[-1, num_blocks, block_dim])
  return x_sliced
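

# Illustrative usage sketch (not part of the original module; names and sizes
# are assumptions). With hidden_size=64 and num_blocks=4, slice_hidden splits
# each state into 4 blocks of dimension 16; project_hidden would instead apply
# per-block projection tensors of shape [num_blocks, hidden_size, block_dim].
def _example_slice_hidden():
  x = tf.random_normal([8, 64])  # 8 encoder states with hidden_size=64.
  x_sliced = slice_hidden(x, hidden_size=64, num_blocks=4)
  return x_sliced  # Shape: [8, 4, 16].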


def nearest_neighbor(x,
                     means,
                     block_v_size,
                     random_top_k=1,
                     soft_em=False,
                     num_samples=1):
  """Find the nearest element in means to elements in x.

  Args:
    x: Batch of encoder continuous latent states sliced/projected into shape
      [-1, num_blocks, block_dim].
    means: Embedding table of shape [num_blocks, block_v_size, block_dim].
    block_v_size: Number of table entries per block.
    random_top_k: Noisy top-k if this is bigger than 1 (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples to take in soft EM (Default: 1).

  Returns:
    Tensor with the nearest element in means encoded in one-hot notation.
  """
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keep_dims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keep_dims=True)
  scalar_prod = tf.matmul(
      tf.transpose(x, perm=[1, 0, 2]), tf.transpose(means, perm=[0, 2, 1]))
  scalar_prod = tf.transpose(scalar_prod, perm=[1, 0, 2])
  dist = x_norm_sq + tf.transpose(
      means_norm_sq, perm=[2, 0, 1]) - 2 * scalar_prod

  # Compute the cluster probabilities.
  if soft_em:
    num_blocks = common_layers.shape_list(dist)[1]
    nearest_idx = tf.stack(
        [
            tf.multinomial(-dist[:, i, :], num_samples=num_samples)
            for i in range(num_blocks)
        ],
        axis=1)
    nearest_hot = tf.one_hot(nearest_idx, depth=block_v_size)
    nearest_hot = tf.reduce_mean(nearest_hot, axis=-2)
  else:
    if random_top_k > 1:
      _, top_k_idx = tf.nn.top_k(-dist, k=random_top_k)
      nearest_idx = tf.gather(
          top_k_idx,
          tf.random_uniform(
              [1], minval=0, maxval=random_top_k - 1, dtype=tf.int32),
          axis=-1)
    else:
      nearest_idx = tf.argmax(-dist, axis=-1)
    nearest_hot = tf.one_hot(nearest_idx, block_v_size)
  return nearest_hot


def embedding_lookup(x,
                     means,
                     num_blocks,
                     block_v_size,
                     random_top_k=1,
                     soft_em=False,
                     num_samples=1):
  """Compute nearest neighbors and loss for training the embeddings via DVQ.

  Args:
    x: Batch of encoder continuous latent states sliced/projected into shape
      [-1, num_blocks, block_dim].
    means: Embedding table of shape [num_blocks, block_v_size, block_dim].
    num_blocks: Number of blocks in DVQ.
    block_v_size: Number of table entries per block.
    random_top_k: Noisy top-k if this is bigger than 1 (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples to use for soft EM (Default: 1).

  Returns:
    The nearest neighbor in one-hot form, the nearest neighbor itself, the
    embedding training loss (q_loss) and the commitment loss (e_loss).
  """
  x_means_hot = nearest_neighbor(
      x,
      means,
      block_v_size,
      random_top_k,
      soft_em=soft_em,
      num_samples=num_samples)
  x_means_hot_flat = tf.reshape(x_means_hot, [-1, num_blocks, block_v_size])
  x_means = tf.matmul(tf.transpose(x_means_hot_flat, perm=[1, 0, 2]), means)
  x_means = tf.transpose(x_means, [1, 0, 2])
  q_loss = tf.reduce_mean(tf.square((tf.stop_gradient(x) - x_means)))
  e_loss = tf.reduce_mean(tf.square(x - tf.stop_gradient(x_means)))
  return x_means_hot, x_means, q_loss, e_loss
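

# Minimal training sketch (an illustrative assumption, not the library's own
# training loop): a single-residual DVQ step where the total loss is
# q_loss + beta * e_loss and the quantized output is passed on with a
# straight-through gradient, as in VQ-VAE.
def _example_embedding_lookup(x_sliced, num_blocks=4, block_v_size=256,
                              block_dim=16, beta=0.25):
  means = tf.get_variable(
      "example_means", shape=[num_blocks, block_v_size, block_dim])
  x_means_hot, x_means, q_loss, e_loss = embedding_lookup(
      x_sliced, means, num_blocks, block_v_size)
  loss = q_loss + beta * e_loss
  # Straight-through estimator: the forward pass uses the quantized values,
  # while gradients flow to the encoder as if quantization were the identity.
  x_quantized = x_sliced + tf.stop_gradient(x_means - x_sliced)
  return x_quantized, x_means_hot, loss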


def bit_to_int(x_bit, num_bits, base=2):
  """Turn x_bit representing numbers bitwise (lower-endian) to int tensor.

  Args:
    x_bit: Tensor containing numbers in a particular base to be converted to
      int.
    num_bits: Number of bits in the representation.
    base: Base of the representation.

  Returns:
    Integer representation of this number.
  """
  x_l = tf.stop_gradient(tf.to_int32(tf.reshape(x_bit, [-1, num_bits])))
  x_labels = []
  for i in range(num_bits):
    x_labels.append(x_l[:, i] * tf.to_int32(base)**tf.to_int32(i))
  res = sum(x_labels)
  return tf.to_int32(tf.reshape(res, common_layers.shape_list(x_bit)[:-1]))


def int_to_bit(x_int, num_bits, base=2):
  """Turn x_int representing numbers into a bitwise (lower-endian) tensor.

  Args:
    x_int: Tensor containing integer to be converted into base notation.
    num_bits: Number of bits in the representation.
    base: Base of the representation.

  Returns:
    Corresponding number expressed in base.
  """
  x_l = tf.to_int32(tf.expand_dims(x_int, axis=-1))
  x_labels = []
  for i in range(num_bits):
    x_labels.append(
        tf.floormod(
            tf.floordiv(tf.to_int32(x_l),
                        tf.to_int32(base)**i), tf.to_int32(base)))
  res = tf.concat(x_labels, axis=-1)
  return tf.to_float(res)
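

# Quick consistency sketch (illustrative only): int_to_bit and bit_to_int are
# inverses for values in [0, base**num_bits), using little-endian digit order.
def _example_bit_round_trip():
  x_int = tf.constant([[3], [10], [255]])
  x_bit = int_to_bit(x_int, num_bits=8, base=2)   # Shape [3, 1, 8] of 0/1.
  x_back = bit_to_int(x_bit, num_bits=8, base=2)  # Recovers the integers.
  return x_bit, x_back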


def int_to_bit_embed(x_int, num_bits, embedding_size, base=2):
  """Turn x_int into a bitwise (lower-endian) tensor and embed densly."""
  shape = common_layers.shape_list(x_int)
  inputs = int_to_bit(x_int, num_bits, base=base)
  inputs = tf.reshape(inputs, shape[:-1] + [shape[-1] * num_bits])
  inputs = 2.0 * tf.to_float(inputs) - 1.0  # Move from 0/1 to -1/1.
  return tf.layers.dense(inputs, embedding_size, name="int_to_bit_embed")


def embed(x,
          hidden_size,
          z_size,
          filter_size,
          name,
          bottleneck_kind="dvq",
          soft_em=False,
          num_blocks=2,
          num_residuals=1,
          block_v_size=None,
          means=None):
  """Embedding function that takes discrete latent and returns embedding.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce the discrete code; discrete codes
      range from 0 to 2**z_size - 1.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    soft_em: If True then it uses a multi-sample version of EM (Default: False).
    num_blocks: Number of blocks in DVQ (Default: 2).
    num_residuals: Number of residuals (Default: 1).
    block_v_size: Number of embedding entries per block (Default: None).
    means: The embedding table for dvq (Default: None).

  Returns:
    Continuous embedding to be passed on to the decoder.

  Raises:
    ValueError: For unknown or missing arguments.
  """
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    if bottleneck_kind == "semhash":
      c = int_to_bit(x, z_size)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
    elif bottleneck_kind == "gumbel-softmax":
      hot = tf.one_hot(x, 2**z_size)
      h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
    elif bottleneck_kind == "dvq":
      if block_v_size is None:
        raise ValueError("Bottleneck kind is dvq but block_v_size is None.")

      if soft_em:
        assert num_residuals == 1
        x_hot_flat = tf.reshape(x, shape=[-1, num_blocks, block_v_size])
        h1 = tf.matmul(tf.transpose(x_hot_flat, perm=[1, 0, 2]), means[0])
        h1 = tf.transpose(h1, perm=[1, 0, 2])
        new_shape = common_layers.shape_list(x)
        new_shape[-1] = hidden_size
        h1 = tf.reshape(h1, shape=new_shape)
      else:
        shape_x = common_layers.shape_list(x)
        x_flat = tf.reshape(x, [-1, 1])
        c = int_to_bit(x_flat, num_bits=z_size, base=2)
        shape = common_layers.shape_list(c)
        new_shape = shape
        new_shape[-1] = num_residuals
        new_shape.append(num_blocks)
        new_shape.append(int(z_size / (num_residuals * num_blocks)))
        c = tf.to_int32(tf.reshape(c, shape=new_shape))
        h1_shape = shape_x
        h1_shape.append(hidden_size)
        h1 = tf.zeros(dtype=tf.float32, shape=h1_shape)
        for i in range(num_residuals):
          c_residual = bit_to_int(
              c[:, :, i, :, :],
              num_bits=int(z_size / (num_residuals * num_blocks)),
              base=2)
          c_hot = tf.one_hot(c_residual, depth=block_v_size, axis=-1)
          c_hot_flat = tf.reshape(c_hot, shape=[-1, num_blocks, block_v_size])
          h1_residual = tf.matmul(
              tf.transpose(c_hot_flat, perm=[1, 0, 2]), means[i])
          h1_residual = tf.transpose(h1_residual, perm=[1, 0, 2])
          h1_residual = tf.reshape(h1_residual, shape=h1_shape)
          h1 += h1_residual
    elif bottleneck_kind == "rounding":
      h1 = x
    else:
      raise ValueError("Unknown bottleneck kind.")

    return h1


def vae(x, name, z_size):
  """Simple variational autoencoder without discretization.

  Args:
    x: Input to the discretization bottleneck.
    name: Name for the bottleneck scope.
    z_size: Number of bits used to produce the discrete code; discrete codes
      range from 0 to 2**z_size - 1.

  Returns:
    Sampled latent z, the KL loss, mu and log_sigma.
  """
  with tf.variable_scope(name):
    mu = tf.layers.dense(x, z_size, name="mu")
    log_sigma = tf.layers.dense(x, z_size, name="log_sigma")
    shape = common_layers.shape_list(x)
    epsilon = tf.random_normal([shape[0], shape[1], 1, z_size])
    z = mu + tf.exp(log_sigma / 2) * epsilon
    kl = 0.5 * tf.reduce_mean(
        tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1)
    free_bits = z_size // 4
    kl_loss = tf.reduce_mean(tf.maximum(kl - free_bits, 0.0))
    return z, kl_loss, mu, log_sigma


def top_k_softmax(x, k):
  """Calculate softmax(x), select top-k and rescale to sum to 1.

  Args:
    x: Input to softmax over.
    k: Number of top-k to select.

  Returns:
    Renormalized top-k of softmax(x) and the maximum value of softmax(x).
  """
  x = tf.nn.softmax(x)
  top_x, _ = tf.nn.top_k(x, k=k + 1)
  min_top = tf.reduce_min(top_x, axis=-1, keep_dims=True)
  x = tf.nn.relu((x - min_top) + 1e-12)
  x /= tf.reduce_sum(x, axis=-1, keep_dims=True)
  return x, tf.reduce_max(top_x, axis=-1)


def gumbel_sample(shape):
  """Sample from the Gumbel distribution, protect from overflows.

  Args:
    shape: Shape of Gumbel samples.

  Returns:
    Noise drawn from Gumbel distribution.
  """
  uniform_samples = tf.random_uniform(shape, minval=0.00001, maxval=0.99998)
  return -tf.log(-tf.log(uniform_samples))
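

# Illustrative sketch (an assumption, not part of the original API): the
# Gumbel-max trick draws an exact categorical sample by adding Gumbel noise to
# the logits and taking the argmax.
def _example_gumbel_max_trick(logits):
  noise = gumbel_sample(common_layers.shape_list(logits))
  return tf.argmax(logits + noise, axis=-1)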


def gumbel_softmax(x,
                   name,
                   z_size,
                   mode,
                   softmax_k=0,
                   kl_warmup_steps=150000,
                   summary=True):
  """Gumbel softmax discretization bottleneck.

  Args:
    x: Input to the discretization bottleneck.
    name: Name for the bottleneck scope.
    z_size: Number of bits used to produce the discrete code; discrete codes
      range from 0 to 2**z_size - 1.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior.
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding function, discrete code and loss.
  """
  with tf.variable_scope(name):
    m = tf.layers.dense(x, 2**z_size, name="mask")
    if softmax_k > 0:
      m, kl = top_k_softmax(m, softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)

    # Gumbel-softmax sample.
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = kl_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)

    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(
        tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
        lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    kl = -tf.reduce_max(logsm, axis=-1)

    if summary:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))

    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, 2**z_size))

    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, 2**z_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
    d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
    d_dev = -tf.reduce_mean(d_variance)
    ret = s

    if mode != tf.contrib.learn.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002


def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce the discrete code; discrete codes
      range from 0 to 2**z_size - 1.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of warmup steps for mixing in discretized values
      for semhash (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector
      quantization (Default: 2).
    num_residuals: Number of residual units used to compute nearest
      neighbors (Default: 1).
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples for soft EM (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, if
    ema_count or ema_means is None when ema is True, or for unknown arguments.
  """
  block_v_size = None
  if bottleneck_kind == "dvq":
    # Define the dvq parameters
    assert means is not None

    # Check block dimensions add up
    if hidden_size % num_blocks != 0:
      raise ValueError("num_blocks does not divide hidden size")

    if z_size % num_residuals != 0:
      raise ValueError("num_residuals does not divide embedding table size")

    z_size_per_residual = int(z_size / num_residuals)

    if z_size_per_residual % num_blocks != 0:
      raise ValueError("num_blocks does not divide embedding table size")

    block_v_size = 2**(z_size_per_residual / num_blocks)
    block_v_size = int(block_v_size)

    # Set the reshape method corresponding to projections or slices
    if reshape_method == "slice":
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == "project":
      if projection_tensors is None:
        raise ValueError(
            "Projection tensors is None for reshape_method project")
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError("Unknown reshape_method")

    # Check if the ema settings make sense
    if ema:
      if ema_count is None:
        raise ValueError("ema_count is None but ema is True")
      if ema_means is None:
        raise ValueError("ema_means is None but ema is True")

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "vae":
      c, l, _, _ = vae(x, name="vae", z_size=z_size)
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == "gumbel-softmax":
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
    elif bottleneck_kind == "dvq":
      x_reshaped = reshape_fn(x)
      x_res = x_reshaped
      x_means_hot = []
      x_means = 0
      l = 0
      for i in range(num_residuals):
        x_means_hot_res, x_means_res, q_loss_res, e_loss_res = embedding_lookup(
            x_res, means[i], num_blocks, block_v_size, random_top_k, soft_em,
            num_samples)
        # Update the ema variables
        if ema:
          tf.logging.info("Using EMA with beta = {}".format(beta))
          updated_ema_count_res = moving_averages.assign_moving_average(
              ema_count[i],
              tf.reduce_sum(
                  tf.reshape(
                      x_means_hot_res, shape=[-1, num_blocks, block_v_size]),
                  axis=0),
              decay,
              zero_debias=False)

          dw = tf.matmul(
              tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
              tf.transpose(x_res, perm=[1, 0, 2]))

          updated_ema_means_res = moving_averages.assign_moving_average(
              ema_means[i], dw, decay, zero_debias=False)
          n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
          updated_ema_count_res = ((updated_ema_count_res + epsilon) /
                                   (n + 2**z_size * epsilon) * n)
          # pylint: disable=g-no-augmented-assignment
          updated_ema_means_res = updated_ema_means_res / tf.expand_dims(
              updated_ema_count_res, axis=-1)
          # pylint: enable=g-no-augmented-assignment

          with tf.control_dependencies([e_loss_res]):
            update_means_res = tf.assign(means[i], updated_ema_means_res)
            with tf.control_dependencies([update_means_res]):
              l += beta * e_loss_res
        else:
          l += q_loss_res + beta * e_loss_res

        # Update the residuals
        x_res -= x_means_res
        x_means += x_means_res
        x_means_hot.append(x_means_hot_res)

      # Get the discrete latent representation
      x_means_hot = tf.stack(x_means_hot, axis=1)
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation
      x_means_bits = int_to_bit(
          x_means_idx,
          num_bits=int(z_size / (num_residuals * num_blocks)),
          base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-2]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust shape of c
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      # If we are doing soft EM then c is x_means_hot
      if soft_em:
        c = x_means_hot
        new_shape.append(block_v_size)
        c = tf.reshape(c, new_shape)

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError("Unknown discretization method.")

    res = h1

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        soft_em=soft_em,
        num_blocks=num_blocks,
        num_residuals=num_residuals,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
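

# Variable-setup sketch for the "dvq" bottleneck (hedged: shapes are inferred
# from how discrete_bottleneck indexes means, ema_count and ema_means; the
# helper name and hyperparameter values are illustrative assumptions).
def _example_dvq_setup(x, hidden_size=64, z_size=16, num_blocks=2,
                       num_residuals=1, filter_size=128,
                       mode=tf.estimator.ModeKeys.TRAIN):
  block_dim = hidden_size // num_blocks
  block_v_size = 2**(z_size // (num_residuals * num_blocks))
  means, ema_count, ema_means = [], [], []
  for i in range(num_residuals):
    means.append(tf.get_variable(
        "means_%d" % i, shape=[num_blocks, block_v_size, block_dim]))
    ema_count.append(tf.get_variable(
        "ema_count_%d" % i, shape=[num_blocks, block_v_size],
        initializer=tf.constant_initializer(0), trainable=False))
    ema_means.append(tf.get_variable(
        "ema_means_%d" % i, initializer=means[i].initialized_value(),
        trainable=False))
  # Returns (embedding, discrete latent, loss, embed_fn) as documented above.
  return discrete_bottleneck(
      x, hidden_size, z_size, filter_size, name="example_dvq", mode=mode,
      bottleneck_kind="dvq", num_blocks=num_blocks,
      num_residuals=num_residuals, means=means,
      ema_count=ema_count, ema_means=ema_means)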


# New API for discretization bottlenecks:
# * Each method is separate and provides 2 functions:
# * The [method]_bottleneck function returns discretized state.
# * The [method]_unbottleneck function moves from discretized state to dense.


def get_vq_bottleneck(bottleneck_size, hidden_size):
  """Get lookup table for VQ bottleneck."""
  with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
    means = tf.get_variable(
        name="means",
        shape=[bottleneck_size, hidden_size],
        initializer=tf.uniform_unit_scaling_initializer())

    ema_count = tf.get_variable(
        name="ema_count",
        shape=[bottleneck_size],
        initializer=tf.constant_initializer(0),
        trainable=False)

    with tf.colocate_with(means):
      ema_means = tf.get_variable(
          name="ema_means",
          initializer=means.initialized_value(),
          trainable=False)

  return means, ema_means, ema_count


def vq_nearest_neighbor(x, means, soft_em=False, num_samples=10):
  """Find the nearest element in means to elements in x."""
  bottleneck_size = common_layers.shape_list(means)[0]
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  scalar_prod = tf.matmul(x, means, transpose_b=True)
  dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod
  if soft_em:
    x_means_idx = tf.multinomial(-dist, num_samples=num_samples)
    x_means_hot = tf.one_hot(
        x_means_idx, depth=common_layers.shape_list(means)[0])
    x_means_hot = tf.reduce_mean(x_means_hot, axis=1)
  else:
    x_means_idx = tf.argmax(-dist, axis=-1)
    x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
  x_means_hot_flat = tf.reshape(x_means_hot, [-1, bottleneck_size])
  x_means = tf.matmul(x_means_hot_flat, means)
  e_loss = tf.reduce_mean(tf.square(x - tf.stop_gradient(x_means)))
  return x_means_hot, e_loss


def vq_discrete_bottleneck(x,
                           bottleneck_bits,
                           beta=0.25,
                           decay=0.999,
                           epsilon=1e-5,
                           soft_em=False,
                           num_samples=10):
  """Simple vector quantized discrete bottleneck."""
  bottleneck_size = 2**bottleneck_bits
  x_shape = common_layers.shape_list(x)
  hidden_size = x_shape[-1]
  means, ema_means, ema_count = get_vq_bottleneck(bottleneck_size, hidden_size)
  x = tf.reshape(x, [-1, hidden_size])
  x_means_hot, e_loss = vq_nearest_neighbor(
      x, means, soft_em=soft_em, num_samples=num_samples)

  # Update the ema variables
  updated_ema_count = moving_averages.assign_moving_average(
      ema_count,
      tf.reduce_sum(
          tf.reshape(x_means_hot, shape=[-1, bottleneck_size]), axis=0),
      decay,
      zero_debias=False)

  dw = tf.matmul(x_means_hot, x, transpose_a=True)
  updated_ema_means = tf.identity(moving_averages.assign_moving_average(
      ema_means, dw, decay, zero_debias=False))
  n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
  updated_ema_count = (
      (updated_ema_count + epsilon) / (n + bottleneck_size * epsilon) * n)
  updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)
  with tf.control_dependencies([e_loss]):
    update_means = means.assign(updated_ema_means)
    with tf.control_dependencies([update_means]):
      loss = beta * e_loss

  d = tf.reshape(x_means_hot, x_shape[:-1] + [bottleneck_size])
  return d, loss


def vq_discrete_unbottleneck(x, hidden_size):
  """Simple undiscretization from vector quantized representation."""
  x_shape = common_layers.shape_list(x)
  x = tf.to_float(x)
  bottleneck_size = common_layers.shape_list(x)[-1]
  means, _, _ = get_vq_bottleneck(bottleneck_size, hidden_size)
  result = tf.matmul(tf.reshape(x, [-1, x_shape[-1]]), means)
  return tf.reshape(result, x_shape[:-1] + [hidden_size])
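

# Usage sketch (an illustrative pairing, not prescribed by the module): both
# functions build their codebook under the shared "vq" variable scope with
# AUTO_REUSE, so the unbottleneck reuses the same means table that the
# bottleneck trains.
def _example_vq_round_trip(x, bottleneck_bits=8):
  hidden_size = common_layers.shape_list(x)[-1]
  discrete, loss = vq_discrete_bottleneck(x, bottleneck_bits)
  reconstructed = vq_discrete_unbottleneck(discrete, hidden_size)
  return reconstructed, loss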


def gumbel_softmax_discrete_bottleneck(x,
                                       bottleneck_bits,
                                       beta=0.25,
                                       decay=0.999,
                                       epsilon=1e-5,
                                       startup_steps=15000,
                                       hard=False,
                                       summary=True):
  """VQ-VAE using Gumbel-Softmax.

  It differs from the `gumbel_softmax()` function above in that it calculates
  the KL using the discrete entropy instead of taking the argmax, and it uses
  an exponential moving average to update the codebook, whereas
  `gumbel_softmax()` includes no codebook update.

  Args:
    x: A `float`-like `Tensor` containing the latent vectors to be compared to
      the codebook, whose squared difference is used as the Gumbel-Softmax
      logits.
    bottleneck_bits: An `int` that sets the size of the bottleneck in `log_2`.
    beta: Beta factor for commitment loss (Default: 0.25).
    decay: Decay factor for exponential moving average (Default: 0.999).
    epsilon: Small value to avoid dividing by zero in EMA update
      (Default: 1e-5).
    startup_steps: Number of steps for KL warmup (Default: 15000).
    hard: When `True`, we use hard Gumbel-Softmax samples and force
      discrete latents by taking the argmax. When `False`, we use soft samples,
      which we treat as codebook weights (Default: False).
    summary: When `True`, we save histogram summaries of the KL term (Default:
      True).

  Returns:
    x_means_assignments: A `float`-like `Tensor` containing the codebook
      assignments. When `hard == True`, this is one-hot, containing the arg-max
      of the Gumbel-Softmax samples (and we use the straight-through gradient).
      Otherwise, it contains the Gumbel-Softmax samples exactly, which are
      values from the `(K-1)`-simplex where `K` is the bottleneck size.
    loss: The loss, which is the sum of the KL between the Gumbel-Softmax and
      the uniform prior and the commitment loss multiplied by the beta factor.
      We approximate the KL by using the entropy of a categorical distribution
      instead of the Gumbel-Softmax.
  """
  bottleneck_size = 2**bottleneck_bits
  x_shape = common_layers.shape_list(x)
  hidden_size = x_shape[-1]
  means, ema_means, ema_count = get_vq_bottleneck(bottleneck_size, hidden_size)
  x = tf.reshape(x, [-1, hidden_size])

  bottleneck_size = common_layers.shape_list(means)[0]
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  scalar_prod = tf.matmul(x, means, transpose_b=True)
  dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod

  class_probs = tf.nn.softmax(dist)
  log_class_probs = tf.nn.log_softmax(dist)
  gumbel_samples = gumbel_sample(common_layers.shape_list(dist))
  gumbel_samples *= common_layers.inverse_exp_decay(startup_steps // 5) * 0.5
  temperature = 1.2 - common_layers.inverse_lin_decay(startup_steps)

  # 10% of the time keep reasonably high temperature to keep learning.
  temperature = tf.cond(
      tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
      lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
  gumbel_softmax_samples = tf.nn.softmax(
      (log_class_probs + gumbel_samples) / temperature)

  # Calculate KL between q and a uniform prior.
  kl = tf.reduce_sum(class_probs * (log_class_probs -
                                    tf.log(1.0/bottleneck_size)), -1)
  if summary:
    tf.summary.histogram("KL", tf.reshape(kl, [-1]))

  # Straight-through gradient estimation when we're using hard assignments.
  if hard:
    x_means_idx = tf.reshape(tf.argmax(gumbel_softmax_samples, axis=-1), [-1])
    x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
    x_means_assignments = gumbel_softmax_samples + tf.stop_gradient(
        x_means_hot - gumbel_softmax_samples)
  else:
    x_means_assignments = gumbel_softmax_samples
  x_means_assignments_flat = tf.reshape(
      x_means_assignments, [-1, bottleneck_size])
  x_means = tf.matmul(x_means_assignments_flat, means)
  commitment_loss = tf.reduce_mean(tf.square(x - tf.stop_gradient(x_means)))

  # Update the ema variables.
  updated_ema_count = moving_averages.assign_moving_average(
      ema_count,
      tf.reduce_sum(
          tf.reshape(x_means_assignments, shape=[-1, bottleneck_size]), axis=0),
      decay,
      zero_debias=False)

  dw = tf.matmul(x_means_assignments, x, transpose_a=True)
  updated_ema_means = tf.identity(moving_averages.assign_moving_average(
      ema_means, dw, decay, zero_debias=False))
  n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
  updated_ema_count = (
      (updated_ema_count + epsilon) / (n + bottleneck_size * epsilon) * n)
  updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)
  with tf.control_dependencies([commitment_loss]):
    update_means = means.assign(updated_ema_means)
    with tf.control_dependencies([update_means]):
      loss = beta * commitment_loss

  # Add KL loss.
  loss += tf.reduce_mean(kl)

  x_means_assignments = tf.reshape(
      x_means_assignments, x_shape[:-1] + [bottleneck_size])
  return x_means_assignments, loss


def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
                             discretize_warmup_steps, mode):
  """Simple discretization through tanh, flip bottleneck_noise many bits."""
  x = tf.tanh(tf.layers.dense(x, bottleneck_bits,
                              name="tanh_discrete_bottleneck"))
  d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
  if mode == tf.estimator.ModeKeys.TRAIN:
    noise = tf.random_uniform(common_layers.shape_list(x))
    noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
    d *= noise
  d = common_layers.mix(d, x, discretize_warmup_steps,
                        mode == tf.estimator.ModeKeys.TRAIN)
  return d, 0.0


def tanh_discrete_unbottleneck(x, hidden_size):
  """Simple un-discretization from tanh."""
  x = tf.layers.dense(x, hidden_size, name="tanh_discrete_unbottleneck")
  return x
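

# Usage sketch (illustrative, with assumed hyperparameter values): binarize a
# hidden state with the tanh bottleneck, then map it back to hidden_size.
def _example_tanh_round_trip(x, hidden_size,
                             mode=tf.estimator.ModeKeys.TRAIN):
  d, _ = tanh_discrete_bottleneck(
      x, bottleneck_bits=64, bottleneck_noise=0.1,
      discretize_warmup_steps=16000, mode=mode)
  return tanh_discrete_unbottleneck(d, hidden_size)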


def isemhash_bottleneck(x, bottleneck_bits, bottleneck_noise,
                        discretize_warmup_steps, mode,
                        isemhash_noise_dev=0.5, isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck."""
  with tf.variable_scope("isemhash_bottleneck"):
    x = tf.layers.dense(x, bottleneck_bits, name="dense")
    y = common_layers.saturating_sigmoid(x)
    if isemhash_noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
      noise = tf.truncated_normal(
          common_layers.shape_list(x), mean=0.0, stddev=isemhash_noise_dev)
      y = common_layers.saturating_sigmoid(x + noise)
    d = tf.to_float(tf.less(0.5, y)) + y - tf.stop_gradient(y)
    d = 2.0 * d - 1.0  # Move from [0, 1] to [-1, 1].
    if mode == tf.estimator.ModeKeys.TRAIN:  # Flip some bits.
      noise = tf.random_uniform(common_layers.shape_list(x))
      noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
      d *= noise
      d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps,
                            mode == tf.estimator.ModeKeys.TRAIN,
                            max_prob=isemhash_mix_prob)
    return d, 0.0


def isemhash_unbottleneck(x, hidden_size, isemhash_filter_size_multiplier=1.0):
  """Improved semantic hashing un-bottleneck."""
  filter_size = int(hidden_size * isemhash_filter_size_multiplier)
  x = 0.5 * (x - 1.0)  # Move from [-1, 1] to [0, 1].
  with tf.variable_scope("isemhash_unbottleneck"):
    h1a = tf.layers.dense(x, filter_size, name="hidden1a")
    h1b = tf.layers.dense(1.0 - x, filter_size, name="hidden1b")
    h2 = tf.layers.dense(tf.nn.relu(h1a + h1b), filter_size, name="hidden2")
    return tf.layers.dense(tf.nn.relu(h2), hidden_size, name="final")


def parametrized_bottleneck(x, hparams):
  """Meta-function calling all the above bottlenecks with hparams."""
  if hparams.bottleneck_kind == "tanh_discrete":
    return tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise * 0.5,
        hparams.discretize_warmup_steps, hparams.mode)
  if hparams.bottleneck_kind == "isemhash":
    return isemhash_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise * 0.5,
        hparams.discretize_warmup_steps, hparams.mode,
        hparams.isemhash_noise_dev, hparams.isemhash_mix_prob)
  if hparams.bottleneck_kind == "vq":
    return vq_discrete_bottleneck(x, hparams.bottleneck_bits, hparams.vq_beta,
                                  hparams.vq_decay, hparams.vq_epsilon)
  if hparams.bottleneck_kind == "em":
    return vq_discrete_bottleneck(
        x,
        hparams.bottleneck_bits,
        hparams.vq_beta,
        hparams.vq_decay,
        hparams.vq_epsilon,
        soft_em=True,
        num_samples=hparams.vq_num_samples)
  if hparams.bottleneck_kind == "gumbel_softmax":
    return gumbel_softmax_discrete_bottleneck(x, hparams.bottleneck_bits,
                                              hparams.vq_beta, hparams.vq_decay,
                                              hparams.vq_epsilon,
                                              hparams.startup_steps, hard=False,
                                              summary=True)

  raise ValueError("Unsupported hparams.bottleneck_kind %s"
                   % hparams.bottleneck_kind)


def parametrized_unbottleneck(x, hidden_size, hparams):
  """Meta-function calling all the above un-bottlenecks with hparams."""
  if hparams.bottleneck_kind == "tanh_discrete":
    return tanh_discrete_unbottleneck(x, hidden_size)
  if hparams.bottleneck_kind == "isemhash":
    return isemhash_unbottleneck(
        x, hidden_size, hparams.isemhash_filter_size_multiplier)
  if hparams.bottleneck_kind in ["vq", "em", "gumbel_softmax"]:
    return vq_discrete_unbottleneck(x, hidden_size)
  raise ValueError("Unsupported hparams.bottleneck_kind %s"
                   % hparams.bottleneck_kind)
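

# End-to-end sketch (hedged: the hparams below only set the fields that
# parametrized_bottleneck and parametrized_unbottleneck read for the "vq"
# kind; any other fields a full model needs are not shown).
def _example_parametrized_round_trip(x, hidden_size):
  hparams = tf.contrib.training.HParams(
      bottleneck_kind="vq",
      bottleneck_bits=8,
      vq_beta=0.25,
      vq_decay=0.999,
      vq_epsilon=1e-5)
  discrete, loss = parametrized_bottleneck(x, hparams)
  return parametrized_unbottleneck(discrete, hidden_size, hparams), loss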