def test_loop_tuple_input():
    """Check a while loop whose loop variables include a slice of a fed placeholder.

    Builds the graph, evaluates the loop with TensorFlow to obtain the
    reference result, and hands graph, result and input to `check_equal`.
    """
    graph = tf.Graph()

    with graph.as_default():
        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
        dname = 'data'
        data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
        split = tf.split(data, 2, axis=0)

        def body(x, y):
            return x + 2, y + 1

        start = tf.constant(0)

        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[split[1], start])
        with tf.Session() as sess:
            # Feed the placeholder with the same array that is later passed
            # to check_equal so both evaluations see identical input.
            tf_out = sess.run(r, feed_dict={data: data_np})

    check_equal(graph, tf_out, {dname: data_np})
def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
                             discretize_warmup_steps, mode):
  """Discretize through tanh and randomly flip bits while training.

  Args:
    x: Input tensor fed to a dense layer with `bottleneck_bits` units.
    bottleneck_bits: Number of units of the dense projection.
    bottleneck_noise: Threshold compared against uniform noise; positions
      whose noise sample is not above it get their sign flipped (training
      only).
    discretize_warmup_steps: Forwarded to `common_layers.mix`.
    mode: A `tf.estimator.ModeKeys` value; noise is only added in TRAIN.

  Returns:
    A pair `(d, d0)`: the mixed discretized tensor and the gradient-stopped
    sign (-1/+1) of the dense output taken before any noise is added.
  """
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  pre_activation = tf.layers.dense(
      x, bottleneck_bits, name="tanh_discrete_bottleneck")
  # Sign of the clean projection, mapped from {0, 1} to {-1, +1}.
  d0 = tf.stop_gradient(
      2.0 * tf.to_float(tf.less(0.0, pre_activation))) - 1.0

  if is_training:
    pre_activation = pre_activation + tf.truncated_normal(
        common_layers.shape_list(pre_activation), mean=0.0, stddev=0.2)

  squashed = tf.tanh(pre_activation)
  # Forward pass yields the hard sign; the gradient flows through `squashed`.
  hard_sign = 2.0 * tf.to_float(tf.less(0.0, squashed)) - 1.0
  discrete = squashed + tf.stop_gradient(hard_sign - squashed)

  if is_training:
    uniform = tf.random_uniform(common_layers.shape_list(squashed))
    flip = 2.0 * tf.to_float(tf.less(bottleneck_noise, uniform)) - 1.0
    discrete = discrete * flip

  discrete = common_layers.mix(
      discrete, squashed, discretize_warmup_steps, is_training)
  return discrete, d0
def compute_valid_mask(num_valid_elements, num_elements):
  """Computes mask of valid entries within padded context feature.

  Args:
    num_valid_elements: A int32 Tensor of shape [batch_size].
    num_elements: An int32 Tensor.

  Returns:
    A boolean Tensor of the shape [batch_size, num_elements]. True means
      valid and False means invalid.
  """
  batch_size = num_valid_elements.shape[0]
  element_idxs = tf.range(num_elements, dtype=tf.int32)
  # Repeat the index row once per batch entry: [batch_size, num_elements].
  batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [batch_size, 1])
  # Add a trailing axis so each count is compared against its whole row.
  num_valid_elements = num_valid_elements[..., tf.newaxis]
  valid_mask = tf.less(batch_element_idxs, num_valid_elements)
  return valid_mask
def unwrap(p, discont=np.pi, axis=-1):
  """Unwrap a cyclical phase tensor.

  Args:
    p: Phase tensor.
    discont: Float, size of the cyclic discontinuity.
    axis: Axis of which to unwrap.

  Returns:
    unwrapped: Unwrapped tensor of same size as input.
  """
  dd = diff(p, axis=axis)
  # Wrap successive differences into [-pi, pi).
  ddmod = tf.mod(dd + np.pi, 2.0 * np.pi) - np.pi
  # A positive jump that wrapped to exactly -pi is mapped to +pi.
  idx = tf.logical_and(tf.equal(ddmod, -np.pi), tf.greater(dd, 0))
  ddmod = tf.where(idx, tf.ones_like(ddmod) * np.pi, ddmod)
  ph_correct = ddmod - dd
  # Jumps smaller than `discont` need no correction. The mask result used to
  # be written to `ddmod`, which is never read again, so the correction was
  # applied unconditionally; zero `ph_correct` instead. For the default
  # discont == pi both forms agree, since the correction is already zero
  # whenever |dd| < pi.
  idx = tf.less(tf.abs(dd), discont)
  ph_correct = tf.where(idx, tf.zeros_like(ph_correct), ph_correct)
  ph_cumsum = tf.cumsum(ph_correct, axis=axis)

  # Prepend a zero slice so the cumulative correction lines up with `p`.
  shape = p.get_shape().as_list()
  shape[axis] = 1
  ph_cumsum = tf.concat([tf.zeros(shape, dtype=p.dtype), ph_cumsum], axis=axis)
  unwrapped = p + ph_cumsum
  return unwrapped
def get_lr(global_step, base_lr,
           decay_steps, lr_decay_factor, warmup_steps):
  """Returns a learning rate with linear warmup and optional step decay.

  Args:
    global_step: Scalar step tensor; compared against `warmup_steps` cast to
      int64, so presumably an int64 tensor - confirm against callers.
    base_lr: Learning rate used once warmup has finished and before any decay.
    decay_steps: Sequence of step boundaries for piecewise-constant decay, or
      a falsy value to disable decay.
    lr_decay_factor: The rate is multiplied by this factor at each boundary.
    warmup_steps: Number of steps over which the rate grows linearly from 0
      to `base_lr`; a non-positive value leaves the warmup rate at 0.0.

  Returns:
    A scalar tensor: the warmup rate while `global_step < warmup_steps`,
    otherwise the (possibly decayed) base rate.
  """
  warmup_lr = 0.0
  if warmup_steps > 0:
    warmup_lr = (tf.cast(global_step, tf.float32) * (base_lr / warmup_steps))

  if decay_steps:
    # len(decay_steps) boundaries delimit len(decay_steps) + 1 intervals.
    normal_lr = tf.train.piecewise_constant(
        global_step,
        list(decay_steps),
        [base_lr * (lr_decay_factor ** i) for i in range(len(decay_steps) + 1)]
    )
  else:
    normal_lr = base_lr

  lr = tf.cond(
      tf.less(global_step, tf.cast(warmup_steps, dtype=tf.dtypes.int64)),
      lambda: warmup_lr, lambda: normal_lr)

  return lr

# TODO(akolesnikov): add more logging 
def construct_latent_tower(self, images, time_axis):
    """Create the latent tower.

    Args:
      images: Input frames; when `hparams.latent_num_frames` is positive only
        the first that many entries along axis 1 are used.
      time_axis: Forwarded unchanged to `common_video.conv_latent_tower`.

    Returns:
      The result of `common_video.conv_latent_tower`.
    """
    # No latent in the first phase
    first_phase = tf.less(
        self.get_iteration_num(), self.hparams.num_iterations_1st_stage)

    # use all frames by default but this allows more
    # predicted frames at inference time
    latent_num_frames = self.hparams.latent_num_frames
    tf.logging.info("Creating latent tower with %d frames.", latent_num_frames)
    if latent_num_frames > 0:
      images = images[:, :latent_num_frames]

    # NOTE(review): the argument list of this call was truncated in the
    # reviewed text. The keyword arguments below are reconstructed from the
    # names visible in this method (`images`, `time_axis`, `first_phase`) and
    # from the usual hparams of this model family - confirm against
    # common_video.conv_latent_tower before merging.
    return common_video.conv_latent_tower(
        images=images,
        time_axis=time_axis,
        latent_channels=self.hparams.latent_channels,
        min_logvar=self.hparams.latent_std_min,
        is_training=self.is_training,
        random_latent=first_phase,
        tiny_mode=self.hparams.tiny_mode,
        small_mode=self.hparams.small_mode)
def _build(self, x, state):
    """Applies a dropout mask that is partially carried over between calls.

    Args:
      x: Input tensor to be masked.
      state: The keep mask used by the previous call.

    Returns:
      `x` multiplied by the updated keep mask, divided by the keep
      probability and multiplied by `self._scaler`. The updated mask is
      stored in `self._keep_mask` and `self._time_step` is incremented.
    """
    prev_keep_mask = state
    shape = tf.shape(x)
    noise = tf.random_uniform(shape, dtype=x.dtype)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    other_mask = tf.floor(self._keep_prob + noise)
    choice_noise = tf.random_uniform(shape, dtype=x.dtype)
    choice = tf.less(choice_noise, self._flip_prob)
    # KLUDGE(melisgl): The client has to pass the last keep_mask from
    # a batch to the next so the mask may end up next to some
    # recurrent cell state. This state is often zero at the beginning
    # and may be periodically zeroed (per example) during training.
    # While zeroing LSTM state is okay, zeroing the dropout mask is
    # not. So instead of forcing every client to deal with this common
    # (?) case, if an all zero mask is detected, then regenerate a
    # fresh mask. This is of course a major hack and won't help with
    # learnt initial states, for example.
    sum_ = tf.reduce_sum(prev_keep_mask, 1, keepdims=True)
    is_initializing = tf.equal(sum_, 0.0)

    # NOTE(review): the two branch arguments were missing from the reviewed
    # text. Taking the fresh mask where a flip is drawn or the previous mask
    # is all zero, and the previous mask elsewhere, is the only reading that
    # uses both `other_mask` and `prev_keep_mask` - confirm.
    self._keep_mask = tf.where(tf.logical_or(choice, is_initializing),
                               other_mask, prev_keep_mask)
    self._time_step += 1
    return x * self._keep_mask / self._keep_prob * self._scaler
def compute_thresholded_labels(labels, null_threshold=4):
  """Computes thresholded labels.

  Args:
    labels: <int32> [batch_size, num_annotators]
    null_threshold: If number of null annotations is greater than or equal to
      this threshold, all annotations are set to null for this example.

  Returns:
    thresholded_labels: <int32> [batch_size, num_annotators]
  """
  # A label of 0 is the null annotation.
  null_labels = tf.equal(labels, 0)

  # <int32> [batch_size]
  null_count = tf.reduce_sum(tf.to_int32(null_labels), 1)
  # True for examples that keep their annotations.
  threshold_mask = tf.less(null_count, null_threshold)

  # <bool> [batch_size, num_annotators]
  threshold_mask = tf.tile(
      tf.expand_dims(threshold_mask, -1), [1, tf.shape(labels)[1]])

  # <int32> [batch_size, num_annotators]
  thresholded_labels = tf.where(
      threshold_mask, x=labels, y=tf.zeros_like(labels))
  return thresholded_labels
def bottleneck(self, x):  # pylint: disable=arguments-differ
    """Discretizes `x` and, in training, flips bits with ordered noise.

    When `hparams.unordered` is set the parent implementation is used as is.
    Otherwise the parametrized bottleneck is run with its own noise disabled
    and position-dependent flip noise is applied afterwards in TRAIN mode.

    Args:
      x: Input tensor.

    Returns:
      A pair `(x, loss)`: the bottlenecked tensor and the loss reported by
      `discretization.parametrized_bottleneck`.
    """
    hparams = self.hparams
    if hparams.unordered:
      return super(AutoencoderOrderedDiscrete, self).bottleneck(x)
    noise = hparams.bottleneck_noise
    hparams.bottleneck_noise = 0.0  # We'll add noise below.
    try:
      x, loss = discretization.parametrized_bottleneck(x, hparams)
    finally:
      # Restore the shared hparams even if the bottleneck raises, so the
      # temporary override cannot leak to other users of `hparams`.
      hparams.bottleneck_noise = noise
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      # We want a number p such that p^bottleneck_bits = 1 - noise / 2.
      # So log(p) * bottleneck_bits = log(1 - noise / 2).
      log_p = tf.log1p(-float(noise) / 2) / float(hparams.bottleneck_bits)
      # The cumulative sum gives log(p), 2 log(p), ... along the last axis, so
      # the flip probabilities are 1 - p, 1 - p^2, ..., 1 - p^bottleneck_bits.
      noise_mask = 1.0 - tf.exp(tf.cumsum(tf.zeros_like(x) + log_p, axis=-1))
      # Having the no-noise mask, we can make noise just uniformly at random.
      ordered_noise = tf.random_uniform(tf.shape(x))
      # 1 where the position is kept, 0 where it will be flipped.
      ordered_noise = tf.to_float(tf.less(noise_mask, ordered_noise))
      # Now we flip the bits of x on the noisy positions (ordered and normal).
      x *= 2.0 * ordered_noise - 1
    return x, loss
def sample(self, features=None, shape=None):
    """Draws a random tensor of -1/+1 values.

    Args:
      features: Unused.
      shape: Optional shape of the sample; when None a shape is derived from
        the hparams.

    Returns:
      A float tensor whose entries are -1.0 or 1.0, each chosen by comparing
      a uniform sample against 0.5.
    """
    del features
    hp = self.hparams
    div_x = 2**hp.num_hidden_layers
    div_y = 1 if self.is1d else 2**hp.num_hidden_layers
    # NOTE(review): the tail of this list literal was missing from the
    # reviewed text; `hp.bottleneck_bits` as the last dimension is a
    # reconstruction (the commented example below indexes four axes) -
    # confirm against the upstream file.
    size = [
        hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y,
        hp.bottleneck_bits
    ]
    size = size if shape is None else shape
    rand = tf.random_uniform(size)
    res = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    # If you want to set some first bits to a fixed value, do this:
    # fixed = tf.zeros_like(rand) - 1.0
    # nbits = 3
    # res = tf.concat([fixed[:, :, :, :nbits], res[:, :, :, nbits:]], axis=-1)
    return res
def test_loop_conditions():
    """Check a while loop whose condition combines several comparisons."""
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(1)
        j = tf.constant(1)
        k = tf.constant(5)

        def c(i, j, k): return \
            tf.equal(tf.not_equal(tf.less(i + j, 10),
                                  tf.less(j * k, 100)),
                     tf.greater_equal(k, i + j))

        def b(i, j, k): return [i+j, j+k, k+1]
        r = tf.while_loop(c, b, loop_vars=[i, j, k])
        with tf.Session() as sess:
            # Reference result computed by TensorFlow itself.
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
def test_nested_loop():
    """Check a while loop whose body contains another while loop."""
    graph = tf.Graph()
    with graph.as_default():

        def body(x):
            def nest_body(c):
                return tf.multiply(c, 2)

            def cd(c): return tf.less(c, 10)
            c = tf.constant(2)
            res = tf.while_loop(cd, nest_body, loop_vars=[c])
            return tf.nn.relu(x + res)

        def condition(x):
            return tf.greater(x, 100)
        x = tf.constant(3)
        r = tf.while_loop(condition, body, loop_vars=[x])

        with tf.Session() as sess:
            # Reference result computed by TensorFlow itself.
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
def test_cond_fn_parameters():
    """Check tf.cond whose branch functions are called with outer tensors."""
    graph = tf.Graph()
    with graph.as_default():
        def fn1(x, y):
            return tf.multiply(5, 6)

        def fn2(x, y):
            return tf.add(3, 4)

        i = tf.constant(1)
        j = tf.constant(2)
        k = tf.constant(3)
        r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k))

        with tf.Session() as sess:
            # The fed values equal the constants' own values.
            tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3})

    check_equal(graph, tf_out)
def mask_from_lengths(lengths, max_length=None, dtype=None, name=None):
  """Convert a vector of lengths to a matrix of binary masks.

  E.g. [2, 4, 3] will become [[1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 0]]

  Args:
    lengths: a d-dimensional vector of integers corresponding to lengths.
    max_length: an optional (default: None) scalar-like or 0-dimensional tensor
      indicating the maximum length of the masks. If not provided, the maximum
      length will be inferred from the lengths vector.
    dtype: the dtype of the returned mask, if specified. If None, the dtype of
      the lengths will be used.
    name: a name for the operation (optional).

  Returns:
    A d x max_length tensor of binary masks of type `dtype` (or of the dtype
    of `lengths` when `dtype` is None). No gradient flows through it.
  """
  with tf.name_scope(name, 'mask_from_lengths'):
    dtype = lengths.dtype if dtype is None else dtype
    max_length = tf.reduce_max(lengths) if max_length is None else max_length
    indexes = tf.range(max_length, dtype=lengths.dtype)
    # Broadcast [1, max_length] against [d, 1]: position < length.
    mask = tf.less(tf.expand_dims(indexes, 0), tf.expand_dims(lengths, 1))
    cast_mask = tf.cast(mask, dtype)
  return tf.stop_gradient(cast_mask)
def _mix_tokens(p_sample, gold_targets, sampled_targets):
  """Interleave sampled and gold tokens randomly.

  Args:
    p_sample: float in [0, 1]. Probability a token will come from
      'sampled_targets'. 0 means all-gold, 1 means all-sampled.
    gold_targets: Tensor. Gold token IDs.
    sampled_targets: Tensor. Sampled token IDs. Same shape as 'gold_targets'.

  Returns:
    Tensor of same shape as 'gold_targets' containing a mix of tokens from
    'gold_targets' and 'sampled_targets'.
  """
  targets_shape = common_layers.shape_list(sampled_targets)
  # One independent uniform draw per token decides which source is used.
  return tf.where(
      tf.less(tf.random_uniform(targets_shape), p_sample),
      sampled_targets, gold_targets)
def neural_gpu_body(inputs, hparams, name=None):
  """The core Neural GPU.

  Args:
    inputs: Input tensor; transposed with permutation [1, 0, 2, 3] before
      being folded over, so it is rank 4.
    hparams: Hyperparameters; reads `dropout`, `num_hidden_layers`,
      `kernel_height` and `kernel_width`.
    name: Optional variable scope name (default "neural_gpu").

  Returns:
    The final state produced by folding `step` over the transposed inputs.
  """
  with tf.variable_scope(name, "neural_gpu"):

    def step(state, inp):  # pylint: disable=missing-docstring
      x = tf.nn.dropout(state, 1.0 - hparams.dropout)
      for layer in range(hparams.num_hidden_layers):
        x = common_layers.conv_gru(
            x, (hparams.kernel_height, hparams.kernel_width),
            name="cgru_%d" % layer)
      # Padding input is zeroed-out in the modality, we check this by summing.
      padding_inp = tf.less(tf.reduce_sum(tf.abs(inp), axis=[1, 2]), 0.00001)
      new_state = tf.where(padding_inp, state, x)  # No-op where inp is padding.
      return new_state

    # NOTE(review): this call was truncated in the reviewed text (the function
    # argument and everything after the transposed elements were missing).
    # `step` is the only candidate for the folded function; `initializer`,
    # `parallel_iterations` and `swap_memory` are reconstructed - confirm
    # against the upstream file.
    return tf.foldl(
        step,
        tf.transpose(inputs, [1, 0, 2, 3]),
        initializer=inputs,
        parallel_iterations=1,
        swap_memory=True)
def test_cond_in_loop():
    """Check a while loop whose body contains a tf.cond."""
    graph = tf.Graph()
    with graph.as_default():
        def body(x):
            # The loop variable is deliberately shadowed by a constant, and
            # `z` is created but never consumed; both are kept so the graph
            # under test stays the same.
            x = tf.constant(7)
            z = tf.constant(20)
            res = tf.cond(tf.less(x, 10), lambda: tf.add(
                10, 20), lambda: tf.square(10))
            return tf.multiply(res, x)

        x = tf.constant(21)

        def condition(x):
            return tf.less(x, 100)

        r = tf.while_loop(condition, body, loop_vars=[x])
        with tf.Session() as sess:
            # Reference result computed by TensorFlow itself.
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
def transformer_tall_finetune_tied():
  """Tied means fine-tune CNN/DM summarization as LM.

  Returns:
    The `transformer_tall()` hparams with the fine-tuning overrides applied.
  """
  overrides = (
      ("multiproblem_max_input_length", 750),
      ("multiproblem_max_target_length", 100),
      ("multiproblem_schedule_max_examples", 0),
      ("learning_rate_schedule", "linear_warmup*constant*cosdecay"),
      ("learning_rate_constant", 5e-5),
      ("learning_rate_warmup_steps", 100),
      # Set train steps to learning_rate_decay_steps or less
      ("learning_rate_decay_steps", 80000),
      ("multiproblem_target_eval_only", True),
      ("multiproblem_reweight_label_loss", True),
      ("multiproblem_label_weight", 1.0),
      ("optimizer", "true_adam"),
  )
  hparams = transformer_tall()
  for key, value in overrides:
    setattr(hparams, key, value)
  return hparams
def test_loop_in_cond():
    """Check a tf.cond whose true branch contains a while loop."""
    graph = tf.Graph()
    with graph.as_default():
        def fn1(a, b):
            i = tf.constant(0)

            def cd(i): return tf.less(i, 10)

            def bd(i): return tf.add(i, 1)
            res = tf.while_loop(cd, bd, [i])
            return tf.multiply(tf.add(20, res), 10)

        def fn2(a, b):
            return tf.add(10, 20)

        x = tf.constant(7)
        y = tf.constant(20)
        z = tf.constant(10)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

        with tf.Session() as sess:
            # Feeding `pred` directly forces the true branch.
            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})

    check_equal(graph, tf_out)
def transformer_tall_pretrain_lm():
  """Hparams for transformer on LM pretraining (with 64k vocab).

  Returns:
    The `transformer_tall()` hparams with the pretraining overrides applied.
  """
  hparams = transformer_tall()

  # Learning-rate schedule.
  hparams.learning_rate_schedule = ("linear_warmup*constant*cosdecay")
  hparams.learning_rate_constant = 2e-4
  # Set train steps to learning_rate_decay_steps or less
  hparams.learning_rate_decay_steps = 5000000

  # Optimizer; weight decay is tied to the constant learning rate set above.
  hparams.optimizer = "adam_w"
  hparams.weight_decay = 0.01 * hparams.learning_rate_constant
  for key, value in (("optimizer_adam_beta1", 0.9),
                     ("optimizer_adam_beta2", 0.999),
                     ("optimizer_adam_epsilon", 1e-8)):
    setattr(hparams, key, value)

  # Set max examples to something big when pretraining only the LM, definitely
  # something an order of magnitude bigger than number of train steps.
  hparams.multiproblem_schedule_max_examples = 5e8
  return hparams
def intensity_shift(
    image, label, per_class_intensity_scale, per_class_intensity_shift):
  """Perturb image intensity separately where label > 1.5 and label < 1.5.

  Args:
    image: Image tensor to perturb.
    label: Label tensor compared against 1.5 to select the two regions.
    per_class_intensity_scale: Second argument passed to `_truncated_normal`.
    per_class_intensity_shift: First argument passed to `_truncated_normal`;
      negated for the `label < 1.5` region.

  Returns:
    `image` unchanged when both scale and shift are below 1e-6, otherwise
    the perturbed image.
  """
  threshold = 0.000001
  nothing_to_do = (per_class_intensity_scale < threshold and
                   per_class_intensity_shift < threshold)
  if nothing_to_do:
    return image

  # Randomly change (mostly increase) intensity of non-lesion region.
  upper_noise = _truncated_normal(
      per_class_intensity_shift, per_class_intensity_scale)
  upper_region = image * tf.cast(tf.greater(label, 1.5), tf.float32)
  image = image + upper_noise * upper_region

  # Randomly change (mostly decrease) intensity of lesion region.
  lower_noise = _truncated_normal(
      -per_class_intensity_shift, per_class_intensity_scale)
  lower_region = image * tf.cast(tf.less(label, 1.5), tf.float32)
  image = image + lower_noise * lower_region

  return image
def test_vanilla_cond():
    """Check a plain tf.cond with two argument-free branch functions."""
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(1)
        j = tf.constant(4)

        def f1():
            return tf.multiply(1, 17)

        def f2():
            return tf.add(4, 23)
        r = tf.cond(tf.less(i, j), f1, f2)

    with tf.Session(graph=graph) as sess:
        # Reference result computed by TensorFlow itself.
        tf_out = sess.run(r)

    check_equal(graph, tf_out)
def decode_tf(self, ids):
    """Decode in TensorFlow.

    Args:
      ids: a 1d tf.Tensor with dtype tf.int32

    Returns:
      a tf Scalar with dtype tf.string
    """
    # Ids at or beyond the piece count are replaced by the unknown-token id
    # before detokenizing.
    ids = tf.where_v2(
        tf.less(ids, self.tokenizer.GetPieceSize()),
        ids, self.tokenizer.unk_id())

    return self.tf_tokenizer.detokenize(ids)
def tf_apply_with_probability(p, fn, x):
  """Apply `fn` to `x` with probability `p`, otherwise return `x` unchanged.

  Args:
    p: Threshold compared against a uniform sample from [0, 1).
    fn: Callable applied to `x` when the sample falls below `p`.
    x: Input passed to `fn` or returned as is.

  Returns:
    The result of a `tf.cond` selecting between `fn(x)` and `x`.
  """
  draw = tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32)
  should_apply = tf.less(draw, p)
  return tf.cond(should_apply, lambda: fn(x), lambda: x)
def adapt(self, original_inputs, adversarial_inputs, labels):
    """Runs binary search to find the first misclassified input.

    The search runs per example over the interpolation factor between
    `original_inputs` (factor 0) and `adversarial_inputs` (factor 1).

    Args:
      original_inputs: Tensor of unperturbed inputs; axis 0 is the batch.
      adversarial_inputs: Tensor of perturbed inputs of the same shape.
      labels: Not used by this method.

    Returns:
      The interpolated inputs at the factor selected by the search.
    """
    batch_size = tf.shape(original_inputs)[0]
    binary_search_iterations = 10

    def cond(i, *_):
      return tf.less(i, binary_search_iterations)

    def get(m):
      # Reshape the per-example factor so it broadcasts over all other axes.
      m = tf.reshape(m, [batch_size] + [1] * (len(original_inputs.shape) - 1))
      return (adversarial_inputs - original_inputs) * m + original_inputs

    def is_attack_successful(m):
      logits = self._eval_fn(get(m))
      return self._success_fn(self._specification.evaluate(logits))

    def loop_body(i, lower, upper):
      m = (lower + upper) * .5
      success = is_attack_successful(m)
      # Shrink the interval towards the smallest successful factor.
      new_lower = tf.where(success, lower, m)
      new_upper = tf.where(success, m, upper)
      return i + 1, new_lower, new_upper

    lower = tf.zeros(shape=[batch_size])
    upper = tf.ones(shape=[batch_size])
    # NOTE(review): the condition/body arguments and the end of this call
    # were missing from the reviewed text. `cond` and `loop_body` are the
    # functions defined above; `parallel_iterations` and `back_prop` are
    # reconstructed - confirm against the upstream file.
    _, lower, upper = tf.while_loop(
        cond, loop_body,
        loop_vars=[tf.constant(0.), lower, upper],
        parallel_iterations=1,
        back_prop=False)
    # If lower is incorrectly classified, pick lower; otherwise pick upper.
    success = is_attack_successful(lower)
    return get(tf.where(success, lower, upper))
def _get_action_type(extended_indices, output_vocab_size, model_config):
  """Returns action_type tensor.

  Each index is mapped to the action type of the range [start_index,
  end_index) that contains it; indices outside every range map to 0.

  Args:
    extended_indices: Integer tensor of indices to classify.
    output_vocab_size: Forwarded to `_get_action_types_to_range`.
    model_config: Forwarded to `_get_action_types_to_range`.

  Returns:
    An int64 tensor of action types.
  """
  action_type = tf.constant(0, dtype=tf.int64)
  # NOTE(review): the second argument of this call was cut off in the reviewed
  # text; `model_config` is the only otherwise-unused parameter - confirm.
  for action_type_range in _get_action_types_to_range(output_vocab_size,
                                                      model_config):
    index_in_range = tf.logical_and(
        tf.greater_equal(extended_indices, action_type_range.start_index),
        tf.less(extended_indices, action_type_range.end_index))
    # Adds the range's action type wherever the index falls inside it.
    action_type += (
        tf.to_int64(index_in_range) * tf.constant(
            action_type_range.action_type, dtype=tf.int64))
  return action_type
def _valid_indicator(feature_grid_y, feature_grid_x, true_feature_shapes):
  """Computes a indicator vector for valid indices.

  Computes an indicator vector which is true for points on feature map and
  false for points off feature map.

  Args:
    feature_grid_y: An int32 tensor of shape [batch, num_boxes, size_y]
      containing y coordinate vector.
    feature_grid_x: An int32 tensor of shape [batch, num_boxes, size_x]
      containing x coordinate vector.
    true_feature_shapes: A int32 tensor of shape [batch, num_boxes, 2]
      containing valid height and width of feature maps. Feature maps are
      assumed to be aligned to the left top corner.

  Returns:
    indices: A 1D bool tensor indicating valid feature indices.
  """
  height = tf.cast(true_feature_shapes[:, :, 0:1], dtype=feature_grid_y.dtype)
  width = tf.cast(true_feature_shapes[:, :, 1:2], dtype=feature_grid_x.dtype)
  # NOTE(review): the outer call was unbalanced in the reviewed text and the
  # deeper indentation of its two arguments suggests wrapper calls were lost.
  # The expand_dims wrappers below are reconstructed so the y mask
  # [batch, num_boxes, size_y, 1] broadcasts against the x mask
  # [batch, num_boxes, 1, size_x] - confirm against the upstream file.
  valid_indicator = tf.logical_and(
      tf.expand_dims(
          tf.logical_and(feature_grid_y >= 0, tf.less(feature_grid_y, height)),
          3),
      tf.expand_dims(
          tf.logical_and(feature_grid_x >= 0, tf.less(feature_grid_x, width)),
          2))
  return tf.reshape(valid_indicator, [-1])
def test_nested_cond():
    """Check a tf.cond whose true branch contains another tf.cond."""
    graph = tf.Graph()
    with graph.as_default():
        def fn1(a, b):
            def nest_fn1():
                return tf.add(1, 2)

            def nest_fn2():
                return tf.subtract(10, 5)

            res = tf.cond(tf.less(1, 2), nest_fn1, nest_fn2)
            return tf.multiply(tf.add(87, res), 10)

        def fn2(a, b):
            return tf.add(10, 10)

        x = tf.constant(5)
        y = tf.constant(6)
        z = tf.constant(7)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

        with tf.Session() as sess:
            # Feeding `pred` directly forces the true branch.
            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})

    check_equal(graph, tf_out)
def test_vanilla_loop():
    """Check a simple counting while loop starting from a named constant."""
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(0, name="while/constant")

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])

        with tf.Session() as sess:
            # Reference result computed by TensorFlow itself.
            tf_out = sess.run(r)

        check_equal(graph, tf_out)
def test_callnode_loop_vars():
    """Check a while loop whose initial loop variable is an op result."""
    graph = tf.Graph()
    with graph.as_default():
        # The loop variable is the output of an add op, not a bare constant.
        i = tf.add(tf.constant(0), 1)

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])

        with tf.Session() as sess:
            # Reference result computed by TensorFlow itself.
            tf_out = sess.run(r)

        check_equal(graph, tf_out)