Python tensorflow.executing_eagerly() Examples

The following are 30 code examples of tensorflow.executing_eagerly(), collected from open-source projects. tf.executing_eagerly() returns True when the calling thread executes operations eagerly and False during graph construction, for example inside a tf.compat.v1 graph or while tf.function traces a function; the examples below show the common patterns that branch on this check. The source file and project are listed above each example.
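Before the examples, a minimal sketch (assuming TensorFlow 2.x, where eager execution is on by default) of the True/False split that every example below branches on:

import tensorflow as tf

# At the top level of a TF 2.x program, eager execution is the default.
assert tf.executing_eagerly()

@tf.function
def traced_fn():
  # While tf.function traces this body into a graph, the check is False.
  return tf.constant(tf.executing_eagerly())

assert not bool(traced_fn())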
Example #1
Source File: expert_utils.py    From BERT with Apache License 2.0
def remove(self, x):
    """Remove padding from the given tensor.

    Args:
      x (tf.Tensor): of shape [dim_origin,...]

    Returns:
      a tensor of shape [dim_compressed,...] with dim_compressed <= dim_origin
    """
    with tf.name_scope("pad_reduce/remove"):
      x_shape = x.get_shape().as_list()
      x = tf.gather_nd(
          x,
          indices=self.nonpad_ids,
      )
      if not tf.executing_eagerly():
        # This is a hack: in graph mode, gather_nd returns a tensor of
        # undefined static shape, so the shape is restored manually.
        x.set_shape([None] + x_shape[1:])
    return x 
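The set_shape call above only matters when building a graph, where gather_nd can lose the static shape. A standalone sketch of the same repair, using hypothetical shapes and the tf.compat.v1 graph API:

import tensorflow as tf

with tf.Graph().as_default():
  x = tf.compat.v1.placeholder(tf.float32, shape=[None, 8])
  ids = tf.compat.v1.placeholder(tf.int64, shape=None)  # rank unknown
  y = tf.gather_nd(x, ids)
  print(y.shape)          # <unknown>: static shape inference gave up
  y.set_shape([None, 8])  # restore the known trailing dimension
  print(y.shape)          # (None, 8)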
Example #2
Source File: shape_test.py    From graphics with Apache License 2.0
def test_is_broadcast_compatible(self, shape_x, shape_y, broadcastable):
    """Checks if the is_broadcast_compatible function works as expected."""
    if tf.executing_eagerly():
      if (shape_x is None or shape_y is None or None in shape_x or
          None in shape_y):
        return
      shape_x = tf.compat.v1.placeholder_with_default(
          tf.zeros(shape_x, dtype=tf.float32), shape=shape_x).shape
      shape_y = tf.compat.v1.placeholder_with_default(
          tf.zeros(shape_y, dtype=tf.float32), shape=shape_y).shape
    else:
      shape_x = tf.compat.v1.placeholder(shape=shape_x, dtype=tf.float32).shape
      shape_y = tf.compat.v1.placeholder(shape=shape_y, dtype=tf.float32).shape

    self.assertEqual(
        shape.is_broadcast_compatible(shape_x, shape_y), broadcastable) 
Example #3
Source File: shape_test.py    From graphics with Apache License 2.0
def test_get_broadcasted_shape(self, shape_x, shape_y, broadcasted_shape):
    """Checks if the get_broadcasted_shape function works as expected."""
    if tf.executing_eagerly():
      if (shape_x is None or shape_y is None or None in shape_x or
          None in shape_y):
        return
      shape_x = tf.compat.v1.placeholder_with_default(
          tf.zeros(shape_x, dtype=tf.float32), shape=shape_x).shape
      shape_y = tf.compat.v1.placeholder_with_default(
          tf.zeros(shape_y, dtype=tf.float32), shape=shape_y).shape
    else:
      shape_x = tf.compat.v1.placeholder(shape=shape_x, dtype=tf.float32).shape
      shape_y = tf.compat.v1.placeholder(shape=shape_y, dtype=tf.float32).shape

    self.assertAllEqual(
        shape.get_broadcasted_shape(shape_x, shape_y), broadcasted_shape) 
Example #4
Source File: shape_test.py    From graphics with Apache License 2.0
def test_compare_batch_dimensions_raises_exceptions(self, error_msg,
                                                      tensor_shapes, last_axes,
                                                      broadcast_compatible):
    """Tests that compare_batch_dimensions raises expected exceptions."""
    if not tensor_shapes:
      tensors = 0
    else:
      if all(shape.is_static(tensor_shape) for tensor_shape in tensor_shapes):
        tensors = [tf.ones(tensor_shape) for tensor_shape in tensor_shapes]
      else:
        # Dynamic shapes are not supported in eager mode.
        if tf.executing_eagerly():
          return
        tensors = [
            tf.compat.v1.placeholder(shape=tensor_shape, dtype=tf.float32)
            for tensor_shape in tensor_shapes
        ]
    self.assert_exception_is_raised(
        shape.compare_batch_dimensions,
        error_msg,
        shapes=[],
        tensors=tensors,
        last_axes=last_axes,
        broadcast_compatible=broadcast_compatible) 
Example #5
Source File: shape_test.py    From graphics with Apache License 2.0
def test_compare_batch_dimensions_raises_no_exceptions(
      self, tensor_shapes, last_axes, broadcast_compatible, initial_axes):
    """Tests that compare_batch_dimensions works for various inputs."""
    if all(shape.is_static(tensor_shape) for tensor_shape in tensor_shapes):
      tensors = [tf.ones(tensor_shape) for tensor_shape in tensor_shapes]
    else:
      # Dynamic shapes are not supported in eager mode.
      if tf.executing_eagerly():
        return
      tensors = [
          tf.compat.v1.placeholder(shape=tensor_shape, dtype=tf.float32)
          for tensor_shape in tensor_shapes
      ]
    self.assert_exception_is_not_raised(
        shape.compare_batch_dimensions,
        shapes=[],
        tensors=tensors,
        last_axes=last_axes,
        broadcast_compatible=broadcast_compatible,
        initial_axes=initial_axes) 
Example #6
Source File: test_case.py    From graphics with Apache License 2.0
def assert_jacobian_is_finite_fn(self, f, x):
    """Tests that the Jacobian only contains valid values.

    The analytical and numerical gradients are expected to differ at points
    where f(x) is not smooth. This function can be used to check that the
    analytical gradient contains neither NaN nor Inf values.

    Args:
      f: the function.
      x: A list of arguments for the function.
    """
    if tf.executing_eagerly():
      theoretical_gradient, _ = tf.compat.v2.test.compute_gradient(f, x)
    else:
      with self.cached_session():
        theoretical_gradient, _ = tf.compat.v2.test.compute_gradient(f, x)
    self.assertNotIn(
        True, [
            np.isnan(element).any() or np.isinf(element).any()
            for element in theoretical_gradient
        ],
        msg="nan or inf elements found in theoretical jacobian.") 
Example #7
Source File: test_case.py    From graphics with Apache License 2.0
def assert_jacobian_is_correct_fn(self, f, x, atol=1e-6, delta=1e-6):
    """Tests that the gradient error of y=f(x) is small.

    Args:
      f: the function.
      x: A list of arguments for the function.
      atol: Maximum absolute tolerance in gradient error.
      delta: The amount of perturbation.
    """
    # pylint: disable=no-value-for-parameter
    if tf.executing_eagerly():
      max_error = _max_error(*tf.test.compute_gradient(f, x, delta))
    else:
      with self.cached_session():
        max_error = _max_error(*tf.test.compute_gradient(f, x, delta))
    # pylint: enable=no-value-for-parameter
    self.assertLessEqual(max_error, atol) 
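For reference, tf.test.compute_gradient in TF 2.x takes a function plus a list of arguments and returns a (theoretical, numerical) pair of Jacobians, which _max_error (a helper presumably defined elsewhere in test_case.py) reduces to a scalar error. A minimal sketch of the call itself:

import tensorflow as tf

theoretical, numerical = tf.test.compute_gradient(
    lambda x: x * x, [tf.constant([3.0])])
# Each result is a tuple of Jacobian matrices; d(x*x)/dx = 6 at x = 3.
print(theoretical[0], numerical[0])  # both close to [[6.]]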
Example #8
Source File: test_case.py    From graphics with Apache License 2.0
def assert_jacobian_is_correct(self, x, x_init, y, atol=1e-6, delta=1e-6):
    """Tests that the gradient error of y=f(x) is small.

    Args:
      x: A tensor.
      x_init: A numpy array containing the values at which to estimate the
        gradients of y.
      y: A tensor.
      atol: Maximum absolute tolerance in gradient error.
      delta: The amount of perturbation.
    """
    warnings.warn(
        "assert_jacobian_is_correct is deprecated and might be removed in a "
        "future version; please use assert_jacobian_is_correct_fn instead.",
        DeprecationWarning)
    if tf.executing_eagerly():
      self.skipTest(reason="Graph mode only test")
    max_error, _, _ = self._compute_gradient_error(x, y, x_init, delta)
    self.assertLessEqual(max_error, atol) 
Example #9
Source File: helpers_test.py    From tensorflow_constrained_optimization with Apache License 2.0
def test_get_num_columns_of_2d_tensor(self):
    """Tests the "get_num_columns_of_2d_tensor" function."""
    self.assertFalse(tf.executing_eagerly())

    # Trying to get the number of columns from a non-tensor should fail.
    with self.assertRaises(TypeError):
      _ = helpers.get_num_columns_of_2d_tensor([[1, 2], [3, 4]])

    # Trying to get the number of columns from a rank-1 tensor should fail.
    tensor = tf.convert_to_tensor([1, 2, 3, 4])
    with self.assertRaises(ValueError):
      _ = helpers.get_num_columns_of_2d_tensor(tensor)

    # Make sure that we successfully get the number of columns.
    tensor = tf.convert_to_tensor([[1, 2, 3]])
    self.assertEqual(3, helpers.get_num_columns_of_2d_tensor(tensor)) 
Example #10
Source File: graph_convolution_test.py    From graphics with Apache License 2.0
def test_dynamic_graph_convolution_keras_layer_duplicate_features(
      self, num_vertices, in_channels, out_channels):
    """Tests convolution when all vertex features are identical."""
    if not tf.executing_eagerly():
      return
    data = np.random.uniform(size=(1, in_channels))
    data = np.tile(data, (num_vertices, 1))
    # Results should be independent of 'neighbors'.
    neighbors = np.maximum(np.random.randint(
        0, 2, size=(num_vertices, num_vertices)), np.eye(num_vertices))
    neighbors = _dense_to_sparse(neighbors)
    layer = gc_layer.DynamicGraphConvolutionKerasLayer(
        num_output_channels=out_channels,
        reduction="max")

    output = layer(inputs=[data, neighbors], sizes=None)

    output_tile = tf.tile(output[:1, :], (num_vertices, 1))

    self.assertAllEqual(output, output_tile) 
Example #11
Source File: graph_convolution_test.py    From graphics with Apache License 2.0
def test_dynamic_graph_convolution_keras_layer_exception_not_raised_shapes(
      self, batch_size, num_vertices, in_channels, out_channels, reduction):
    """Check if the convolution parameters and output have correct shapes."""
    if not tf.executing_eagerly():
      return
    data, neighbors = _dummy_data(batch_size, num_vertices, in_channels)
    layer = gc_layer.DynamicGraphConvolutionKerasLayer(
        num_output_channels=out_channels,
        reduction=reduction)

    try:
      output = layer(inputs=[data, neighbors], sizes=None)
    except Exception as e:  # pylint: disable=broad-except
      self.fail("Exception raised: %s" % str(e))

    self.assertAllEqual((batch_size, num_vertices, out_channels), output.shape) 
Example #12
Source File: axis_angle_test.py    From graphics with Apache License 2.0
def test_inverse_jacobian_random(self):
    """Test the Jacobian of the inverse function."""
    x_axis_init, x_angle_init = test_helpers.generate_random_test_axis_angle()

    if tf.executing_eagerly():
      # Because the axis is returned as-is, gradient computation fails in
      # graph mode but not in eager mode. This is a side effect of graph
      # construction rather than a problem with the function.
      with self.subTest("axis"):
        self.assert_jacobian_is_correct_fn(
            lambda x: axis_angle.inverse(1.0 * x, x_angle_init)[0],
            [x_axis_init])

    with self.subTest("angle"):
      self.assert_jacobian_is_correct_fn(
          lambda x: axis_angle.inverse(x_axis_init, x)[1], [x_angle_init]) 
Example #13
Source File: rasterizer_op_test.py    From graphics with Apache License 2.0
def test_invalid_variable_inputs(self, error_msg, variable_names,
                                   variable_kinds, variable_values, error_eager,
                                   error_graph_mode):
    height = 1
    width = 1
    empty_shader_code = "#version 460\n void main() { }\n"
    if tf.executing_eagerly():
      error = error_eager
    else:
      error = error_graph_mode
    with self.assertRaisesRegexp(error, error_msg):
      self.evaluate(
          rasterizer.rasterize(
              num_points=0,
              variable_names=variable_names,
              variable_kinds=variable_kinds,
              variable_values=variable_values,
              output_resolution=(width, height),
              vertex_shader=empty_shader_code,
              geometry_shader=empty_shader_code,
              fragment_shader=empty_shader_code)) 
Example #14
Source File: levenberg_marquardt_test.py    From graphics with Apache License 2.0
def test_minimize_ill_conditioned_not_raised(self):
    """Optimizing an ill conditioned problem should not raise an exception."""
    if not tf.executing_eagerly():
      return

    def f1(x, y):
      return x * y * 10000.0

    def f2(x, y):
      return x * y * 0.0001

    x = (1.,)
    y = (1.,)
    try:
      self.evaluate(
          levenberg_marquardt.minimize(
              residuals=(f1, f2),
              variables=(x, y),
              max_iterations=1,
              regularizer=1e-20))
    except Exception as e:  # pylint: disable=broad-except
      self.fail("Exception raised: %s" % str(e)) 
Example #15
Source File: common_layers.py    From BERT with Apache License 2.0
def summarize_video(video, prefix, max_outputs=1):
  """Summarize the video using image summaries starting with prefix."""
  video_shape = shape_list(video)
  if len(video_shape) != 5:
    raise ValueError("Assuming videos given as tensors in the format "
                     "[batch, time, height, width, channels] but got one "
                     "of shape: %s" % str(video_shape))
  if tf.executing_eagerly():
    return
  if video.get_shape().as_list()[1] is None:
    tf.summary.image(
        "%s_last_frame" % prefix,
        tf.cast(video[:, -1, :, :, :], tf.uint8),
        max_outputs=max_outputs)
  else:
    for k in range(video_shape[1]):
      tf.summary.image(
          "%s_frame_%d" % (prefix, k),
          tf.cast(video[:, k, :, :, :], tf.uint8),
          max_outputs=max_outputs) 
Example #16
Source File: test_utils_test.py    From BERT with Apache License 2.0
def test_run_in_graph_and_eager_modes(self):
    l = []
    def inc(self, with_brackets):
      del self  # self argument is required by run_in_graph_and_eager_modes.
      mode = "eager" if tf.executing_eagerly() else "graph"
      with_brackets = "with_brackets" if with_brackets else "without_brackets"
      l.append((with_brackets, mode))

    f = test_utils.run_in_graph_and_eager_modes(inc)
    f(self, with_brackets=False)
    f = test_utils.run_in_graph_and_eager_modes()(inc)
    f(self, with_brackets=True)

    self.assertEqual(len(l), 4)
    self.assertEqual(set(l), {
        ("with_brackets", "graph"),
        ("with_brackets", "eager"),
        ("without_brackets", "graph"),
        ("without_brackets", "eager"),
    }) 
Example #17
Source File: test_utils_test.py    From BERT with Apache License 2.0
def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
    modes = []
    mode_name = lambda: "eager" if tf.executing_eagerly() else "graph"

    class ExampleTest(tf.test.TestCase):

      def runTest(self):
        pass

      def setUp(self):
        modes.append("setup_" + mode_name())

      @test_utils.run_in_graph_and_eager_modes
      def testBody(self):
        modes.append("run_" + mode_name())

    e = ExampleTest()
    e.setUp()
    e.testBody()

    self.assertEqual(modes[0:2], ["setup_eager", "run_eager"])
    self.assertEqual(modes[2:], ["setup_graph", "run_graph"]) 
Example #18
Source File: model.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License
def get_params(self):
        """
        Provides access to the model's parameters.
        :return: A list of all Variables defining the model parameters.
        """
        # Under eager execution this method must be overridden; older TF
        # versions lack tf.executing_eagerly, hence the AttributeError guard.
        try:
            if tf.executing_eagerly():
                raise NotImplementedError("For Eager execution - get_params "
                                          "must be overridden.")
        except AttributeError:
            pass

        # For graph-based execution
        scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       self.scope)
        return scope_vars 
Example #19
Source File: graph_convolution_test.py    From graphics with Apache License 2.0
def test_dynamic_graph_convolution_keras_layer_zero_kernel(
      self, batch_size, num_vertices, in_channels, out_channels, reduction):
    """Tests convolution with an all-zeros kernel."""
    if not tf.executing_eagerly():
      return
    data, neighbors = _dummy_data(batch_size, num_vertices, in_channels)
    data = np.random.uniform(size=data.shape).astype(np.float32)
    layer = gc_layer.DynamicGraphConvolutionKerasLayer(
        num_output_channels=out_channels,
        reduction=reduction,
        use_bias=False,
        kernel_initializer=tf.compat.v1.keras.initializers.zeros())

    output = layer(inputs=[data, neighbors], sizes=None)

    self.assertAllEqual(
        output,
        np.zeros(shape=(batch_size, num_vertices, out_channels),
                 dtype=np.float32)) 
Example #20
Source File: multistep_optimizer.py    From BERT with Apache License 2.0
def _get_iter_variable(self):
    graph = (
        None if tf.executing_eagerly() else tf.get_default_graph())
    return self._get_non_slot_variable("iter", graph=graph) 
Example #21
Source File: simplepose_coco.py    From imgclsmob with MIT License
def call(self, x, training=None):
        x = self.backbone(x, training=training)
        heatmap = self.decoder(x, training=training)
        if self.return_heatmap or not tf.executing_eagerly():
            return heatmap
        else:
            keypoints = self.heatmap_max_det(heatmap)
            return keypoints 
Example #22
Source File: alphapose_coco.py    From imgclsmob with MIT License
def call(self, x, training=None):
        x = self.backbone(x, training=training)
        heatmap = self.decoder(x, training=training)
        if self.return_heatmap or not tf.executing_eagerly():
            return heatmap
        else:
            keypoints = self.heatmap_max_det(heatmap)
            return keypoints 
Example #23
Source File: centernet.py    From imgclsmob with MIT License
def call(self, x, training=None):
        x = self.backbone(x, training=training)
        x = self.decoder(x, training=training)
        if not self.return_heatmap or not tf.executing_eagerly():
            x = self.heatmap_max_det(x)
        return x 
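Examples #21 and #22 decode heatmaps to keypoints only in eager mode and return the raw tensor whenever a graph is being built (for example during export), keeping the traced graph free of eager-only post-processing. A hypothetical minimal model with the same branch (the PoseNet name and argmax decoder are illustrative, not taken from imgclsmob):

import tensorflow as tf

class PoseNet(tf.keras.Model):
  def __init__(self, num_keypoints=17):
    super(PoseNet, self).__init__()
    self.backbone = tf.keras.layers.Conv2D(num_keypoints, 3, padding="same")

  def call(self, x, training=None):
    heatmap = self.backbone(x)
    if not tf.executing_eagerly():
      return heatmap  # graph/tracing path: keep the output graph-friendly
    # Eager-only post-processing: per-channel argmax over spatial positions.
    flat = tf.reshape(heatmap, [tf.shape(heatmap)[0], -1, heatmap.shape[-1]])
    return tf.argmax(flat, axis=1)

model = PoseNet()
keypoints = model(tf.zeros([1, 64, 48, 3]))  # eager call returns decoded indices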
Example #24
Source File: test_case.py    From graphics with Apache License 2.0
def assert_jacobian_is_finite(self, x, x_init, y):
    """Tests that the Jacobian only contains valid values.

    The analytical and numerical gradients are expected to differ at points
    where y is not smooth. This function can be used to check that the
    analytical gradient contains neither NaN nor Inf values.

    Args:
      x: A tensor.
      x_init: A numpy array containing the values at which to estimate the
        gradients of y.
      y: A tensor.
    """
    warnings.warn(
        "assert_jacobian_is_finite is deprecated and might be removed in a "
        "future version; please use assert_jacobian_is_finite_fn instead.",
        DeprecationWarning)
    if tf.executing_eagerly():
      self.skipTest(reason="Graph mode only test")
    x_shape = x.shape.as_list()
    y_shape = y.shape.as_list()
    with tf.compat.v1.Session():
      gradient = tf.compat.v1.test.compute_gradient(
          x, x_shape, y, y_shape, x_init_value=x_init)
      theoretical_gradient = gradient[0][0]
      self.assertFalse(
          np.isnan(theoretical_gradient).any() or
          np.isinf(theoretical_gradient).any()) 
Example #25
Source File: modalities.py    From BERT with Apache License 2.0
def image_bottom(x, model_hparams, vocab_size):
  del model_hparams, vocab_size  # unused arg
  with tf.variable_scope("image_modality"):
    if not tf.executing_eagerly():
      tf.summary.image(
          "inputs", common_layers.tpu_safe_image_summary(x), max_outputs=2)
    return tf.to_float(x) 
Example #26
Source File: modalities.py    From BERT with Apache License 2.0
def image_targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for target images."""
  pixel_embedding_size = 64
  inputs = x
  with tf.variable_scope("image_modality"):
    if not tf.executing_eagerly():
      tf.summary.image(
          "targets_bottom",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=1)
    inputs_shape = common_layers.shape_list(inputs)
    if len(inputs_shape) != 4:
      raise ValueError("Assuming images given as int tensors in the format "
                       "[batch, height, width, channels] (256 values).")
    # We embed each of 256=vocab_size possible pixel values.
    embedding_var = tf.get_variable(
        "pixel_embedding",
        [vocab_size, pixel_embedding_size])
    hot_inputs = tf.one_hot(tf.to_int32(inputs), vocab_size)
    hot_inputs = tf.reshape(hot_inputs, [-1, vocab_size])
    embedded = tf.matmul(hot_inputs, embedding_var)
    # Let's now merge all channels that were embedded into a single vector.
    merged_size = pixel_embedding_size * inputs_shape[3]
    embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
    merged = tf.layers.dense(
        embedded,
        model_hparams.hidden_size,
        name="merge_pixel_embedded_channels")
    return merged 
Example #27
Source File: common_layers.py    From BERT with Apache License 2.0
def embedding(x,
              vocab_size,
              dense_size,
              name=None,
              reuse=None,
              multiplier=1.0,
              symbol_dropout_rate=0.0,
              embedding_var=None,
              dtype=tf.float32):
  """Embed x of type int64 into dense vectors, reducing to max 4 dimensions."""
  with tf.variable_scope(
      name, default_name="embedding", values=[x], reuse=reuse, dtype=dtype):
    if embedding_var is None:
      embedding_var = tf.get_variable("kernel", [vocab_size, dense_size])
    # On the backwards pass, we want to convert the gradient from
    # IndexedSlices to a regular tensor before sending it back to the
    # parameter server. This avoids excess computation on the parameter server.
    if not tf.executing_eagerly():
      embedding_var = convert_gradient_to_tensor(embedding_var)
    x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
    emb_x = gather(embedding_var, x, dtype)
    if multiplier != 1.0:
      emb_x *= multiplier
    static_shape = emb_x.shape.as_list()
    if len(static_shape) < 5:
      return emb_x
    assert len(static_shape) == 5
    # If we had an extra channel dimension, assume it's 1, i.e. shape[3] == 1.
    return tf.squeeze(emb_x, 3) 
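convert_gradient_to_tensor above is a tensor2tensor helper, not a TensorFlow API. The idea, an identity op whose gradient is densified before being sent to the parameter server, can be approximated with tf.custom_gradient; this is a sketch of the concept, not the library's exact implementation:

import tensorflow as tf

@tf.custom_gradient
def convert_gradient_to_tensor_sketch(x):
  def grad(dy):
    # Densify a potentially sparse (IndexedSlices) incoming gradient.
    return tf.convert_to_tensor(dy)
  return tf.identity(x), grad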
Example #28
Source File: common_layers.py    From BERT with Apache License 2.0
def reshape_like_all_dims(a, b):
  """Reshapes a to match the shape of b."""
  ret = tf.reshape(a, tf.shape(b))
  if not tf.executing_eagerly():
    ret.set_shape(b.get_shape())
  return ret 
Example #29
Source File: common_layers.py    From BERT with Apache License 2.0
def reshape_like(a, b):
  """Reshapes a to match the shape of b in all but the last dimension."""
  ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0))
  if not tf.executing_eagerly():
    ret.set_shape(b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:])
  return ret 
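Hypothetical shapes illustrating the two helpers above, assuming both functions are in scope: reshape_like takes its batch dimensions from b but keeps a's last (depth) dimension, while reshape_like_all_dims copies b's shape wholesale:

import tensorflow as tf

a = tf.zeros([6, 4])      # e.g. a merged [batch * time, depth] tensor
b = tf.zeros([2, 3, 7])
c = reshape_like(a, b)    # shape (2, 3, 4): b's batch dims, a's depth
d = reshape_like_all_dims(tf.zeros([6, 7]), b)  # shape (2, 3, 7)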
Example #30
Source File: inputs.py    From BERT with Apache License 2.0
def _train_and_eval_dataset_v1(problem_name, data_dir):
  """Return train and evaluation datasets, feature info and supervised keys."""
  with tf.device("cpu:0"):
    problem = t2t_problems.problem(problem_name)
    train_dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, data_dir)
    train_dataset = train_dataset.map(_select_features)
    eval_dataset = problem.dataset(tf.estimator.ModeKeys.EVAL, data_dir)
    eval_dataset = eval_dataset.map(_select_features)
    hparams = problem.get_hparams()
    # We take a few training examples to guess the shapes.
    input_shapes, target_shapes, examples = [], [], []
    if tf.executing_eagerly():
      for example in _eager_dataset_iterator(train_dataset.take(3)):
        examples.append(example)
    else:
      example_tensor = train_dataset.make_one_shot_iterator().get_next()
      sess = tf.Session()
      example1 = sess.run(example_tensor)
      example2 = sess.run(example_tensor)
      example3 = sess.run(example_tensor)
      examples = [example1, example2, example3]
  # We use "inputs" as input except for purely auto-regressive tasks like
  # language models where "targets" are used as input_key.
  input_key = "inputs" if "inputs" in examples[0] else "targets"
  supervised_keys = ([input_key], ["targets"])
  for example in examples:
    input_shapes.append(list(example[input_key].shape))
    target_shapes.append(list(example["targets"].shape))
  input_vocab_size = hparams.vocab_size[input_key]
  target_vocab_size = hparams.vocab_size["targets"]
  input_dtype = examples[0][input_key].dtype
  target_dtype = examples[0]["targets"].dtype
  input_info = _make_info(input_shapes, input_vocab_size, input_dtype)
  target_info = _make_info(target_shapes, target_vocab_size, target_dtype)
  info = {input_key: input_info, "targets": target_info}
  return train_dataset, eval_dataset, info, supervised_keys
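The eager/graph split in the middle of the function above is the standard way to pull a few concrete examples out of a tf.data pipeline; a minimal sketch with a toy dataset:

import tensorflow as tf

ds = tf.data.Dataset.range(3)
if tf.executing_eagerly():
  examples = [x.numpy() for x in ds]  # datasets are Python iterables in eager mode
else:
  it = tf.compat.v1.data.make_one_shot_iterator(ds)
  nxt = it.get_next()
  with tf.compat.v1.Session() as sess:
    examples = [sess.run(nxt) for _ in range(3)]
print(examples)  # [0, 1, 2]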