# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Evolved Transformer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import evolved_transformer
from tensor2tensor.models import transformer

import tensorflow.compat.v1 as tf

# Sizes shared by get_model() and the shape assertions in the tests below.
BATCH_SIZE = 3  # Leading dimension of the random inputs/targets.
INPUT_LENGTH = 5  # Sequence length of the random "inputs" feature.
TARGET_LENGTH = 7  # Sequence length of the random "targets" feature.
VOCAB_SIZE = 10  # Vocabulary size used for both inputs and targets.
DECODE_LENGTH = 3  # `decode_length` passed to the decoding methods under test.


def print_vars(all_vars=None):
  """Log the name, shape and (soft) device placement of each variable.

  Args:
    all_vars: Optional iterable of variables to describe. When None (the
      default), every variable from `tf.trainable_variables()` is logged.
      An explicitly passed empty list logs only the header line instead of
      being silently replaced by all trainable variables.
  """
  if all_vars is None:
    all_vars = tf.trainable_variables()
  tf.logging.info("Format: <name>, <shape>, <(soft) device placement>")
  for var in all_vars:
    # Hand the arguments to the logger instead of pre-formatting with `%`,
    # so the string is only built if the record is actually emitted.
    tf.logging.info("  %s, %s, %s", var.name, var.get_shape(), var.op.device)


def get_var(name):
  """Return the single trainable variable whose full name equals `name`.

  Args:
    name: Full variable name including the output suffix, e.g. "scope/w:0".

  Returns:
    The unique matching trainable variable.

  Raises:
    ValueError: If zero or more than one trainable variable has that name.
  """
  matches = []
  for candidate in tf.trainable_variables():
    if candidate.name == name:
      matches.append(candidate)
  if len(matches) != 1:
    raise ValueError("`name` must match exactly one variable. '%s' matched %d" %
                     (name, len(matches)))
  return matches[0]


def get_vars(names):
  """Look up several trainable variables by exact name.

  Args:
    names: Iterable of full variable names.

  Returns:
    A list holding one variable per name, in the same order as `names`.

  Raises:
    ValueError: If any name does not match exactly one trainable variable.
  """
  return list(map(get_var, names))


def assert_with_message(assert_method, a, b, message):
  """Run `assert_method(a, b)`, logging `message` if the assertion fails.

  Args:
    assert_method: Two-argument callable that raises AssertionError on
      failure, e.g. `self.assertAllClose`.
    a: First value passed to `assert_method`.
    b: Second value passed to `assert_method`.
    message: Extra context logged at ERROR level before re-raising.

  Raises:
    AssertionError: Re-raised unchanged from `assert_method`. Any other
      exception raised by `assert_method` propagates without logging.
  """
  try:
    assert_method(a, b)
  except AssertionError:
    tf.logging.error(message)
    # A bare `raise` keeps the original traceback. `raise e` would drop it
    # under Python 2, which this file appears to support given its
    # __future__ imports.
    raise


def get_model(hparams, has_input=True, num_decoder_layers=1):
  """Build a tiny EvolvedTransformer in TRAIN mode plus random features.

  `hparams` is mutated in place: it is shrunk to a toy size and gains a
  `problem_hparams` attribute.

  Args:
    hparams: Model hyperparameters to configure and hand to the model.
    has_input: If False, the "inputs" modality and feature are omitted,
      giving a decoder-only setup.
    num_decoder_layers: Number of decoder layers to configure.

  Returns:
    A `(model, features)` tuple, where `features` maps feature names to
    int32 constant tensors.
  """
  toy_settings = (
      ("layer_prepostprocess_dropout", 0.0),
      ("hidden_size", 4),
      ("num_heads", 1),
      ("num_encoder_layers", 1),
      ("num_decoder_layers", num_decoder_layers),
  )
  for attr_name, attr_value in toy_settings:
    setattr(hparams, attr_name, attr_value)

  problem = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE,
                                                 hparams)
  if not has_input:
    del problem.modality["inputs"]
  hparams.problem_hparams = problem

  # Inputs are drawn even when has_input is False, so the numpy random
  # stream is consumed the same way in both configurations.
  input_ids = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
  target_ids = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))

  features = {}
  features["targets"] = tf.constant(target_ids, dtype=tf.int32, name="targets")
  features["target_space_id"] = tf.constant(1, dtype=tf.int32)
  if has_input:
    features["inputs"] = tf.constant(input_ids, dtype=tf.int32, name="inputs")

  model = evolved_transformer.EvolvedTransformer(
      hparams, tf.estimator.ModeKeys.TRAIN, problem)
  return model, features


class EvolvedTransformerTest(tf.test.TestCase):

  def testEvolvedTransformer(self):
    """Training-mode logits have shape [batch, target_len, 1, 1, vocab]."""
    model, features = get_model(hparams=transformer.transformer_tiny())
    logits, _ = model(features)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      logits_value = sess.run(logits)

    expected_shape = (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)
    self.assertEqual(logits_value.shape, expected_shape)

  def testSlowVsFast(self):
    """Slow and fast greedy decoding must produce identical outputs."""
    tf.set_random_seed(1234)
    decode_length = DECODE_LENGTH

    # _create_greedy_infer_model performs exactly the build / 10-step Adam
    # training / switch-to-PREDICT sequence that used to be duplicated here.
    model, features = self._create_greedy_infer_model()

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      greedy_result = model._slow_greedy_infer(features,
                                               decode_length)["outputs"]
      greedy_result = tf.squeeze(greedy_result, axis=[2, 3])

      fast_result = model._greedy_infer(features, decode_length)["outputs"]

    with self.test_session():
      greedy_res = greedy_result.eval()
      fast_res = fast_result.eval()

    self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length))
    self.assertAllClose(greedy_res, fast_res)

  def testSlowVsFastNoInput(self):
    """Slow and fast greedy decoding agree for a decoder-only model."""
    model, features = get_model(transformer.transformer_tiny(), has_input=False)
    decode_length = DECODE_LENGTH

    # Train briefly so the decoders compare non-trivial weights.
    logits, _ = model(features)
    flat_logits = tf.reshape(tf.squeeze(logits, axis=[2, 3]), [-1, VOCAB_SIZE])
    flat_labels = tf.reshape(features["targets"], [-1])
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=flat_logits, labels=flat_labels)
    train_op = tf.train.AdamOptimizer(0.001).minimize(tf.reduce_mean(xent))

    with self.test_session():
      tf.global_variables_initializer().run()
      for _ in range(10):
        train_op.run()

    model.set_mode(tf.estimator.ModeKeys.PREDICT)

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      slow_decoded = tf.squeeze(
          model._slow_greedy_infer(features, decode_length)["outputs"],
          axis=[2, 3])
      fast_decoded = model._greedy_infer(features, decode_length)["outputs"]

    with self.test_session():
      slow_value = slow_decoded.eval()
      fast_value = fast_decoded.eval()

    self.assertEqual(slow_value.shape, (BATCH_SIZE, decode_length))
    self.assertAllClose(slow_value, fast_value)

  def testBeamVsFast(self):
    """Slow and fast beam-search decoding must produce identical outputs."""
    decode_length = DECODE_LENGTH

    # Reuse the shared build / 10-step Adam training / switch-to-PREDICT
    # helper instead of duplicating that sequence here; it performs the same
    # steps this test previously spelled out inline.
    model, features = self._create_greedy_infer_model()

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      beam_result = model._beam_decode_slow(
          features, decode_length, beam_size=4, top_beams=1,
          alpha=1.0)["outputs"]

      fast_result = model._beam_decode(
          features, decode_length, beam_size=4, top_beams=1,
          alpha=1.0)["outputs"]

    with self.test_session():
      beam_res = beam_result.eval()
      fast_res = fast_result.eval()

    self.assertAllClose(beam_res, fast_res)

  def _create_greedy_infer_model(self):
    """Builds a tiny model, trains it for 10 Adam steps, then sets PREDICT.

    Returns:
      model: A t2t model, already switched to PREDICT mode.
      features: A map of string to tensor used for training and decoding.
    """
    model, features = get_model(transformer.transformer_tiny())

    logits, _ = model(features)
    flat_logits = tf.reshape(tf.squeeze(logits, axis=[2, 3]), [-1, VOCAB_SIZE])
    flat_labels = tf.reshape(features["targets"], [-1])
    per_token_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=flat_logits, labels=flat_labels)
    train_op = tf.train.AdamOptimizer(0.001).minimize(
        tf.reduce_mean(per_token_loss))

    with self.test_session():
      tf.global_variables_initializer().run()
      for _ in range(10):
        train_op.run()

    model.set_mode(tf.estimator.ModeKeys.PREDICT)
    return model, features

  def testGreedySlowTPUVsNonTPU(self):
    """Slow greedy decoding must agree between the TPU and non-TPU paths."""
    decode_length = DECODE_LENGTH
    model, features = self._create_greedy_infer_model()

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      non_tpu_decoded = tf.squeeze(
          model._slow_greedy_infer(features, decode_length)["outputs"],
          axis=[2, 3])
      tpu_decoded = tf.squeeze(
          model._slow_greedy_infer_tpu(features, decode_length)["outputs"],
          axis=[2, 3])

    with self.test_session():
      non_tpu_value = non_tpu_decoded.eval()
      tpu_value = tpu_decoded.eval()

    expected_shape = (BATCH_SIZE, INPUT_LENGTH + decode_length)
    self.assertEqual(tpu_value.shape, expected_shape)
    self.assertAllClose(tpu_value, non_tpu_value)

  def testGreedyFastTPUVsNonTPU(self):
    """Fast greedy decoding must agree between the TPU and non-TPU paths."""
    tf.set_random_seed(1234)
    decode_length = DECODE_LENGTH
    model, features = self._create_greedy_infer_model()

    # Build the non-TPU graph first, then the TPU one.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      decoded = [
          model._greedy_infer(features, decode_length, use_tpu=tpu_flag)
          ["outputs"] for tpu_flag in (False, True)
      ]

    with self.test_session():
      non_tpu_value, tpu_value = [tensor.eval() for tensor in decoded]

    expected_shape = (BATCH_SIZE, INPUT_LENGTH + decode_length)
    self.assertEqual(tpu_value.shape, expected_shape)
    self.assertAllClose(tpu_value, non_tpu_value)

  def testGreedyTPUSlowVsFast(self):
    """On the TPU path, slow and fast greedy decoding must agree."""
    tf.set_random_seed(1234)
    decode_length = DECODE_LENGTH
    model, features = self._create_greedy_infer_model()

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      slow_outputs = model._slow_greedy_infer_tpu(features, decode_length)
      slow_decoded = tf.squeeze(slow_outputs["outputs"], axis=[2, 3])

      fast_outputs = model._greedy_infer(features, decode_length, use_tpu=True)
      fast_decoded = fast_outputs["outputs"]

    with self.test_session():
      slow_value = slow_decoded.eval()
      fast_value = fast_decoded.eval()

    expected_shape = (BATCH_SIZE, INPUT_LENGTH + decode_length)
    self.assertEqual(fast_value.shape, expected_shape)
    self.assertAllClose(fast_value, slow_value)

  def testFrozenWeightsUnchangedByTraining(self):
    """Only the top decoder layer(s) change when lower layers are frozen.

    Builds a 3-layer decoder with `num_trainable_top_decoder_layers=1` and
    runs 10 Adam steps. It then checks that every variable in `frozen_names`
    kept its value and every variable in `train_names` changed.
    """
    # Arrange.
    hparams = transformer.transformer_tiny()
    hparams.add_hparam("num_trainable_top_decoder_layers", 1)
    model, features = get_model(hparams, num_decoder_layers=3)
    out_logits, _ = model(features)
    out_logits = tf.squeeze(out_logits, axis=[2, 3])
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
        labels=tf.reshape(features["targets"], [-1]))
    loss = tf.reduce_mean(loss)
    apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss)
    # Expected to stay fixed: the shared embedding shards, the target space
    # embedding, the whole encoder, and decoder layers 0 and 1.
    frozen_names = [
        "evolved_transformer/symbol_modality_10_4/shared/weights_0:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_1:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_2:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_3:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_4:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_5:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_6:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_7:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_8:0",
        "evolved_transformer/symbol_modality_10_4/shared/weights_9:0",
        "evolved_transformer/body/target_space_embedding/kernel:0",
        "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/kernel:0",
        "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/bias:0",
        "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/kernel:0",
        "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/bias:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/dense/kernel:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/dense/bias:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/kernel:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/bias:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/depthwise_kernel:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/pointwise_kernel:0",
        "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/bias:0",
        "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/q/kernel:0",
        "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/k/kernel:0",
        "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/v/kernel:0",
        "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/encoder/layer_0/dense_layers/dense/kernel:0",
        "evolved_transformer/body/encoder/layer_0/dense_layers/dense/bias:0",
        "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/kernel:0",
        "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/bias:0",
        "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/bias:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/bias:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/bias:0",
        "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/dense/kernel:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/dense/bias:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/kernel:0",
        "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/bias:0",
        "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/bias:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/bias:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/bias:0",
        "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/dense/kernel:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/dense/bias:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/kernel:0",
        "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/bias:0",
    ]
    # Expected to change: the top decoder layer (layer_2), the final decoder
    # layer norm, and the softmax shards.
    # NOTE(review): softmax/weights_0 is not listed here, although shards 1-9
    # are; confirm whether that omission is intentional.
    train_names = [
        "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/bias:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/bias:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
        "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/bias:0",
        "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/q/kernel:0",
        "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/k/kernel:0",
        "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/v/kernel:0",
        "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/dense/kernel:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/dense/bias:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/kernel:0",
        "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/bias:0",
        "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
        "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_1:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_2:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_3:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_4:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_5:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_6:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_7:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_8:0",
        "evolved_transformer/symbol_modality_10_4/softmax/weights_9:0",
    ]
    # get_var raises ValueError unless each name matches exactly one
    # trainable variable, so a typo or renamed scope fails loudly here.
    frozen_vars = get_vars(frozen_names)
    train_vars = get_vars(train_names)
    print_vars()

    # Act.
    with self.test_session() as session:
      tf.global_variables_initializer().run()
      frozen_values_before = session.run(frozen_vars)
      train_values_before = session.run(train_vars)
      for _ in range(10):  # Arbitrary number of training steps.
        apply_grad.run()
      frozen_values_after = session.run(frozen_vars)
      train_values_after = session.run(train_vars)

    # Assert.
    # The original hparams tie embedding and softmax weights. The model's
    # effective hparams untie them but keep `shared_embedding` (presumably so
    # the softmax can train while the embedding stays frozen; TODO confirm in
    # evolved_transformer).
    self.assertTrue(
        model._original_hparams.shared_embedding_and_softmax_weights)
    self.assertFalse(model.hparams.shared_embedding_and_softmax_weights)
    self.assertTrue(model.hparams.shared_embedding)
    for name, before, after in zip(frozen_names, frozen_values_before,
                                   frozen_values_after):
      assert_with_message(
          self.assertAllClose, before, after,
          "%s should be frozen, but changed after training." % name)
    for name, before, after in zip(train_names, train_values_before,
                                   train_values_after):
      assert_with_message(
          self.assertNotAllClose, before, after,
          "%s should be trainable, but did not change after training." % name)

  def testAllWeightsTrainableByDefault(self):
    """Checks that, with default hparams, every model weight gets trained.

    Builds the model from unmodified `transformer_tiny` hparams with three
    decoder layers. It then runs ten Adam steps on the mean token
    cross-entropy. Finally it asserts that:
      * embedding and softmax weights remain shared,
      * the trainable variables are exactly the expected set, and
      * every non-empty trainable variable changed during training.
    """
    # Arrange.
    num_decoder_layers = 3
    model, features = get_model(
        transformer.transformer_tiny(), num_decoder_layers=num_decoder_layers)
    logits, _ = model(features)
    logits = tf.squeeze(logits, axis=[2, 3])
    token_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=tf.reshape(logits, [-1, VOCAB_SIZE]),
        labels=tf.reshape(features["targets"], [-1]))
    train_op = tf.train.AdamOptimizer(0.001).minimize(
        tf.reduce_mean(token_losses))

    # The expected trainable variable names are generated from per-layer
    # templates. Their order only matters for pairing each name with the
    # values fetched further below.
    def expand(scope, spec):
      """Returns "<scope>/<sub_scope>/<leaf>:0" for each (sub_scope, leaves)."""
      return ["%s/%s/%s:0" % (scope, sub_scope, leaf)
              for sub_scope, leaves in spec
              for leaf in leaves]

    norm = ("layer_norm/layer_norm_scale", "layer_norm/layer_norm_bias")
    kernel_bias = ("kernel", "bias")
    separable = ("depthwise_kernel", "pointwise_kernel", "bias")
    attention = ("q/kernel", "k/kernel", "v/kernel",
                 "output_transform/kernel")

    encoder_layer = [
        ("gated_linear_unit/layer_prepostprocess", norm),
        ("gated_linear_unit/dense", kernel_bias),
        ("gated_linear_unit/dense_1", kernel_bias),
        ("conv_branches/layer_prepostprocess", norm),
        ("conv_branches/dense", kernel_bias),
        ("conv_branches/standard_conv_3x1", kernel_bias),
        ("conv_branches/layer_prepostprocess_1", norm),
        ("conv_branches/separable_conv_9x1", separable),
        ("self_attention/layer_prepostprocess", norm),
        ("self_attention/multihead_attention", attention),
        ("dense_layers/layer_prepostprocess", norm),
        ("dense_layers/dense", kernel_bias),
        ("dense_layers/dense_1", kernel_bias),
    ]
    decoder_layer = [
        ("16_head_self_attention/layer_prepostprocess", norm),
        ("16_head_self_attention/multihead_attention", attention),
        ("first_attend_to_encoder/multihead_attention", attention),
        ("conv_branches/layer_prepostprocess", norm),
        ("conv_branches/separable_conv11x1", separable),
        ("conv_branches/separable_conv_7x1_1", separable),
        ("conv_branches/layer_prepostprocess_1", norm),
        ("conv_branches/separable_conv_7x1_2", separable),
        ("self_attention/layer_prepostprocess", norm),
        ("self_attention/multihead_attention", attention),
        ("second_attend_to_encoder/layer_prepostprocess", norm),
        ("second_attend_to_encoder/multihead_attention", attention),
        ("dense_layers/layer_prepostprocess", norm),
        ("dense_layers/dense", kernel_bias),
        ("dense_layers/layer_prepostprocess_1", norm),
        ("dense_layers/dense_1", kernel_bias),
    ]

    num_shards = 16
    shard_name = "evolved_transformer/symbol_modality_10_4/shared/weights_%d:0"
    body = "evolved_transformer/body"
    var_names = [shard_name % shard for shard in range(num_shards)]
    var_names.append(body + "/target_space_embedding/kernel:0")
    var_names += expand(body + "/encoder/layer_0", encoder_layer)
    var_names += expand(body + "/encoder", [("layer_prepostprocess", norm)])
    for layer in range(num_decoder_layers):
      var_names += expand(
          "%s/decoder/layer_%d" % (body, layer), decoder_layer)
    var_names += expand(body + "/decoder", [("layer_prepostprocess", norm)])
    variables = get_vars(var_names)
    print_vars()

    # Act.
    with self.test_session() as session:
      tf.global_variables_initializer().run()
      values_before = session.run(variables)
      for _ in range(10):  # Arbitrary number of training steps.
        train_op.run()
      values_after = session.run(variables)

    # Assert.
    self.assertTrue(
        model._original_hparams.shared_embedding_and_softmax_weights)
    self.assertTrue(model.hparams.shared_embedding_and_softmax_weights)
    self.assertFalse(model.hparams.shared_embedding)
    self.assertSameElements(var_names,
                            [var.name for var in tf.trainable_variables()])
    # Shards 10-15 of the shared embedding are asserted to be empty (size 0),
    # presumably because the 10-row vocabulary does not fill all 16 shards.
    # Empty tensors cannot change, so they get a size check instead of a
    # "changed" check.
    empty_vars = {shard_name % shard for shard in range(10, num_shards)}
    for name, before, after in zip(var_names, values_before, values_after):
      if name not in empty_vars:
        assert_with_message(
            self.assertNotAllClose, before, after,
            "%s should be trainable, but did not change after training." % name)
        continue
      self.assertEqual(before.size, after.size)
      self.assertEqual(before.size, 0)


if __name__ == "__main__":
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()