Python tensorflow.python.ops.variables.Variable() Examples

The following are 30 code examples of tensorflow.python.ops.variables.Variable(). Each example is drawn from an open-source project; the source file and license are noted above each snippet. You may also want to check out all available functions and classes of the module tensorflow.python.ops.variables.
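As a quick orientation, here is a minimal sketch (assuming TensorFlow 1.x graph mode) showing that this internal class is the one behind the public tf.Variable alias:

import tensorflow as tf
from tensorflow.python.ops import variables

v = variables.Variable(3.0, name='v')   # the internal class, used directly
print(isinstance(tf.Variable(1.0), variables.Variable))  # True in TF 1.x

with tf.Session() as sess:
  sess.run(v.initializer)   # variables must be initialized before reading
  print(sess.run(v))        # 3.0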
Example #1
Source File: graph_builder_test.py    From DOTA_models with Apache License 2.0
def testEmbeddingOp(self):
    graph = tf.Graph()
    with self.test_session(graph=graph):
      params = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                           tf.float32)

      var = variables.Variable([self.MakeSparseFeatures([1, 2], [1.0, 1.0]),
                                self.MakeSparseFeatures([], [])])
      var.initializer.run()
      embeddings = graph_builder.EmbeddingLookupFeatures(params, var,
                                                         True).eval()
      self.assertAllClose([[8.0, 10.0], [0.0, 0.0]], embeddings)

      var = variables.Variable([self.MakeSparseFeatures([], []),
                                self.MakeSparseFeatures([0, 2],
                                                        [0.5, 2.0])])
      var.initializer.run()
      embeddings = graph_builder.EmbeddingLookupFeatures(params, var,
                                                         True).eval()
      self.assertAllClose([[0.0, 0.0], [10.5, 13.0]], embeddings) 
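The expected values above can be checked by hand, assuming the lookup computes a weighted sum of embedding rows: ids [1, 2] with weights [1.0, 1.0] select rows [3, 4] and [5, 6], which sum to [8, 10], while an empty feature list yields zeros. A NumPy sketch of the same arithmetic:

import numpy as np

params = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

def weighted_lookup(ids, weights):
  # Weighted sum of the selected embedding rows; zeros when empty.
  return sum((w * params[i] for i, w in zip(ids, weights)),
             np.zeros(params.shape[1]))

print(weighted_lookup([1, 2], [1.0, 1.0]))   # [ 8. 10.]
print(weighted_lookup([0, 2], [0.5, 2.0]))   # [10.5 13. ]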
Example #2
Source File: factorization_ops.py    From lambda-packs with MIT License
def _prepare_gramian(self, factors, gramian):
    """Helper function to create ops to prepare/calculate gramian.

    Args:
      factors: Variable or list of Variable representing (sharded) factors.
        Used to compute the updated corresponding gramian value.
      gramian: Variable storing the gramian calculated from the factors.

    Returns:
      An op that updates the gramian with the value calculated from the factors.
    """
    partial_gramians = []
    for f in factors:
      with ops.colocate_with(f):
        partial_gramians.append(math_ops.matmul(f, f, transpose_a=True))

    with ops.colocate_with(gramian):
      prep_gramian = state_ops.assign(gramian,
                                      math_ops.add_n(partial_gramians)).op

    return prep_gramian 
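A numerical sketch of what this op computes: the Gramian of sharded factors is the sum of the per-shard Gramians, which equals Fᵀ F for the stacked factor matrix. In NumPy (shard shapes chosen for illustration):

import numpy as np

f1 = np.array([[1.0, 2.0], [3.0, 4.0]])   # shard 1
f2 = np.array([[5.0, 6.0]])               # shard 2

partial = f1.T.dot(f1) + f2.T.dot(f2)     # sum of per-shard Gramians
stacked = np.vstack([f1, f2])
print(np.allclose(partial, stacked.T.dot(stacked)))  # True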
Example #3
Source File: variable_scope.py    From lambda-packs with MIT License
def variable(initial_value=None,
             trainable=True,
             collections=None,
             validate_shape=True,
             caching_device=None,
             name=None,
             dtype=None):
  if get_variable_scope().use_resource:
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value, trainable=trainable,
        collections=collections, validate_shape=validate_shape,
        caching_device=caching_device, name=name, dtype=dtype)
  else:
    return variables.Variable(
        initial_value=initial_value, trainable=trainable,
        collections=collections, validate_shape=validate_shape,
        caching_device=caching_device, name=name, dtype=dtype) 
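A hedged usage sketch for the wrapper above: under a variable scope opened with use_resource=True (available in TF 1.x), it returns a ResourceVariable; otherwise it falls back to the classic ref-based variables.Variable.

import tensorflow as tf

with tf.variable_scope('resource_scope', use_resource=True):
  v1 = variable(1.0, name='v1')   # ResourceVariable
with tf.variable_scope('ref_scope', use_resource=False):
  v2 = variable(1.0, name='v2')   # classic variables.Variable
print(type(v1).__name__, type(v2).__name__)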
Example #4
Source File: imperative_graph.py    From lambda-packs with MIT License
def record_variable_inits(self):
    """Context manager to record Variable initializations.

    Sets _in_variable_creation to True before a Variable is initialized.

    NOTE(keveman): This is used for recording the list of assign ops
    that are used to initialize variables. It relies on the fact that
    the constructor of Variable class creates exactly one assign op that is
    used for initializing the variable. Variable ops not created using the
    variables.Variable class are not added to _init_ops and hence not
    initialized automatically.

    """
    old_init = getattr(variables.Variable, '__init__')

    def record(*args, **kwargs):
      self._in_variable_creation = True
      old_init(*args, **kwargs)
      self._in_variable_creation = False

    setattr(variables.Variable, '__init__', record)
    yield
    setattr(variables.Variable, '__init__', old_init)
  # pylint: enable=g-doc-return-or-yield 
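The same record-by-monkey-patching pattern, sketched as a self-contained context manager on a plain Python class (the try/finally, which restores the original __init__ even if the body raises, is an added hardening not present in the snippet above):

import contextlib

class Tracker(object):
  def __init__(self):
    self.in_creation = False

  @contextlib.contextmanager
  def record_inits(self, cls):
    old_init = cls.__init__

    def record(*args, **kwargs):
      self.in_creation = True        # visible to code run inside __init__
      old_init(*args, **kwargs)
      self.in_creation = False

    cls.__init__ = record
    try:
      yield
    finally:
      cls.__init__ = old_init        # always restore the original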
Example #5
Source File: optimizer.py    From lambda-packs with MIT License
def _apply_sparse(self, grad, var):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError() 
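A hedged sketch of overriding this hook in a tf.train.Optimizer subclass (TF 1.x): plain SGD applied to the sparse slices via tf.scatter_sub. The class name and the choice of update rule are illustrative, not part of the source above.

import tensorflow as tf

class SparseSGD(tf.train.Optimizer):
  def __init__(self, learning_rate, use_locking=False, name='SparseSGD'):
    super(SparseSGD, self).__init__(use_locking, name)
    self._lr = learning_rate

  def _apply_dense(self, grad, var):
    return tf.assign_sub(var, self._lr * grad).op

  def _apply_sparse(self, grad, var):
    # grad arrives as IndexedSlices with duplicate indices already summed.
    return tf.scatter_sub(var, grad.indices, self._lr * grad.values,
                          use_locking=self._use_locking).op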
Example #6
Source File: cli_shared.py    From lambda-packs with MIT License
def _get_fetch_names(fetches):
  """Get a flattened list of the names in run() call fetches.

  Args:
    fetches: Fetches of the `Session.run()` call. It may be a Tensor, an
      Operation or a Variable. It may also be nested lists, tuples or
      dicts. See doc of `Session.run()` for more details.

  Returns:
    (list of str) A flattened list of fetch names from `fetches`.
  """

  lines = []
  if isinstance(fetches, (list, tuple)):
    for fetch in fetches:
      lines.extend(_get_fetch_names(fetch))
  elif isinstance(fetches, dict):
    for key in fetches:
      lines.extend(_get_fetch_names(fetches[key]))
  else:
    # This ought to be a Tensor, an Operation or a Variable, for which the name
    # attribute should be available. (Bottom-out condition of the recursion.)
    lines.append(_get_fetch_name(fetches))

  return lines 
Example #7
Source File: utils.py    From lambda-packs with MIT License
def constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Arguments:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: if `pred` is not a Variable, Tensor or bool.
  """
  if isinstance(pred, bool):
    pred_value = pred
  elif isinstance(pred, variables.Variable):
    pred_value = None
  elif isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
  else:
    raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.')
  return pred_value 
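A usage sketch for the helper above: Python bools pass through, statically known tensors are folded to their value, and variables (whose values are only known at run time) yield None.

import tensorflow as tf

print(constant_value(True))                 # True
print(constant_value(tf.constant(False)))   # False (folded statically)
print(constant_value(tf.Variable(True)))    # None (dynamic value)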
Example #8
Source File: optimizer.py    From lambda-packs with MIT License
def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      named_slots[_var_key(var)] = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
    return named_slots[_var_key(var)] 
Example #9
Source File: optimizer.py    From lambda-packs with MIT License
def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      named_slots[_var_key(var)] = slot_creator.create_zeros_slot(var, op_name)
    return named_slots[_var_key(var)] 
Example #10
Source File: training_util.py    From lambda-packs with MIT License
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.
  """
  if not (isinstance(global_step_tensor, variables.Variable) or
          isinstance(global_step_tensor, ops.Tensor) or
          isinstance(global_step_tensor,
                     resource_variable_ops.ResourceVariable)):
    raise TypeError(
        'Existing "global_step" must be a Variable or Tensor: %s.' %
        global_step_tensor)

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    global_step_tensor.dtype)

  if global_step_tensor.get_shape().ndims != 0:
    raise TypeError('Existing "global_step" is not scalar: %s' %
                    global_step_tensor.get_shape()) 
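A usage sketch via the public TF 1.x alias tf.train.assert_global_step: the check passes silently for a scalar integer variable and raises TypeError otherwise.

import tensorflow as tf

gs = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
tf.train.assert_global_step(gs)     # OK: scalar integer Variable

bad = tf.Variable(0.0, name='not_a_step')
try:
  tf.train.assert_global_step(bad)
except TypeError as e:
  print(e)                           # complains about non-integer type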
Example #11
Source File: training_util.py    From lambda-packs with MIT License
def global_step(sess, global_step_tensor):
  """Small helper to get the global step.

  ```python
  # Creates a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Creates a session.
  sess = tf.Session()
  # Initializes the variable.
  sess.run(global_step_tensor.initializer)
  print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object.
    global_step_tensor:  `Tensor` or the `name` of the operation that contains
      the global step.

  Returns:
    The global step value.
  """
  return int(sess.run(global_step_tensor)) 
Example #12
Source File: checkpoint_utils.py    From lambda-packs with MIT License
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets the variable's initializer to an assign op that loads the tensor's value
  from the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  restore_op = io_ops.restore_v2(
      ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
  variable._initializer_op = state_ops.assign(variable, restore_op)  # pylint:disable=protected-access 
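The public entry point built on this machinery in TF 1.x is tf.train.init_from_checkpoint, which maps checkpointed tensor names onto variables in the current graph. A hedged sketch; the checkpoint path and tensor names here are hypothetical:

import tensorflow as tf

w = tf.get_variable('new_scope/w', shape=[10, 10])
# Initialize 'new_scope/w' from the checkpointed tensor 'old_scope/w'.
tf.train.init_from_checkpoint('/tmp/model.ckpt',
                              {'old_scope/w': 'new_scope/w'})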
Example #13
Source File: session_debug_testlib.py    From lambda-packs with MIT License
def testDebugCondWatchingWholeGraphWorks(self):
    with session.Session() as sess:
      x = variables.Variable(10.0, name="x")
      y = variables.Variable(20.0, name="y")
      cond = control_flow_ops.cond(
          x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))

      sess.run(variables.global_variables_initializer())

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(run_options,
                              sess.graph,
                              debug_urls=self._debug_urls())
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(
          21, sess.run(cond, options=run_options, run_metadata=run_metadata))

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)
      self.assertAllClose(
          [21.0], dump.get_tensors("cond/Merge", 0, "DebugIdentity")) 
Example #14
Source File: local_cli_wrapper_test.py    From auto-alt-text-lambda-api with MIT License
def setUp(self):
    self._tmp_dir = tempfile.mktemp()

    self.v = variables.Variable(10.0, name="v")
    self.delta = constant_op.constant(1.0, name="delta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")

    self.ph = array_ops.placeholder(dtypes.float32, name="ph")
    self.xph = array_ops.transpose(self.ph, name="xph")
    self.m = constant_op.constant(
        [[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=dtypes.float32, name="m")
    self.y = math_ops.matmul(self.m, self.xph, name="y")

    self.sess = session.Session()

    # Initialize variable.
    self.sess.run(self.v.initializer) 
Example #15
Source File: stepper_cli_test.py    From auto-alt-text-lambda-api with MIT License
def setUp(self):
    self.a = variables.Variable(10.0, name="a")
    self.b = variables.Variable(20.0, name="b")

    self.c = math_ops.add(self.a, self.b, name="c")  # Should be 30.0.
    self.d = math_ops.subtract(self.a, self.c, name="d")  # Should be -20.0.
    self.e = math_ops.multiply(self.c, self.d, name="e")  # Should be -600.0.

    self.ph = array_ops.placeholder(dtypes.float32, shape=(2, 2), name="ph")
    self.f = math_ops.multiply(self.e, self.ph, name="f")

    self.opt = gradient_descent.GradientDescentOptimizer(0.1).minimize(
        self.e, name="opt")

    self.sess = session.Session()

    self.sess.run(self.a.initializer)
    self.sess.run(self.b.initializer) 
Example #16
Source File: stepper_cli_test.py    From auto-alt-text-lambda-api with MIT License
def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
    output = cli.cont(["opt/update_a/ApplyGradientDescent"])

    # After cont() call on .../update_a/..., Variable a should have been marked
    # as dirty, whereas b should not have.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("a")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("b")])

    output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])

    # After cont() call on .../update_b/... with the -r flag, Variable b should
    # have been marked as dirty, whereas Variable a should not be because it
    # should have been restored.
    output = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                  stat_labels[node_names.index("b")])
    self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
                     stat_labels[node_names.index("a")]) 
Example #17
Source File: metrics.py    From seglink with GNU General Public License v3.0
def _create_local(name, shape, collections=None, validate_shape=True,
                  dtype=tf.float32):
    """Creates a new local variable.
    Args:
        name: The name of the new or existing variable.
        shape: Shape of the new or existing variable.
        collections: A list of collection names to which the Variable will be added.
        validate_shape: Whether to validate the shape of the variable.
        dtype: Data type of the variables.
    Returns:
        The created variable.
    """
    # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
    collections = list(collections or [])
    collections += [ops.GraphKeys.LOCAL_VARIABLES]
    return variables.Variable(
            initial_value=array_ops.zeros(shape, dtype=dtype),
            name=name,
            trainable=False,
            collections=collections,
            validate_shape=validate_shape) 
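A hedged usage sketch for the helper above: accumulate a running sum in a local variable. Note that local variables are initialized by tf.local_variables_initializer(), not the global initializer.

import tensorflow as tf

values = tf.placeholder(tf.float32, shape=[None])
total = _create_local('total', shape=[])
update_total = tf.assign_add(total, tf.reduce_sum(values))

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_total, feed_dict={values: [1.0, 2.0, 3.0]})
  print(sess.run(total))   # 6.0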
Example #18
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License
def parse_op_and_node(line):
  """Parse a line containing an op node followed by a node name.

  For example, if the line is
    "  [Variable] hidden/weights",
  this function will return ("Variable", "hidden/weights")

  Args:
    line: The line to be parsed, as a str.

  Returns:
    A 2-tuple: the parsed op type and the parsed node name, both as strs.
  """

  op_type = line.strip().split(" ")[0].replace("[", "").replace("]", "")

  # Not using [-1], to tolerate any other items that might be present after
  # the node name.
  node_name = line.strip().split(" ")[1]

  return op_type, node_name 
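Calling it on the docstring's example line:

op_type, node_name = parse_op_and_node("  [Variable] hidden/weights")
print(op_type, node_name)   # Variable hidden/weights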
Example #19
Source File: cli_shared.py    From auto-alt-text-lambda-api with MIT License
def _get_fetch_names(fetches):
  """Get a flattened list of the names in run() call fetches.

  Args:
    fetches: Fetches of the `Session.run()` call. It may be a Tensor, an
      Operation or a Variable. It may also be nested lists, tuples or
      dicts. See doc of `Session.run()` for more details.

  Returns:
    (list of str) A flattened list of fetch names from `fetches`.
  """

  lines = []
  if isinstance(fetches, (list, tuple)):
    for fetch in fetches:
      lines.extend(_get_fetch_names(fetch))
  elif isinstance(fetches, dict):
    for key in fetches:
      lines.extend(_get_fetch_names(fetches[key]))
  else:
    # This ought to be a Tensor, an Operation or a Variable, for which the name
    # attribute should be available. (Bottom-out condition of the recursion.)
    lines.append(_get_fetch_name(fetches))

  return lines 
Example #20
Source File: rev_block_lib.py    From tensornets with MIT License
def _underlying_variable_ref(t):
  """Find the underlying variable ref.

  Traverses through Identity, ReadVariableOp, and Enter ops.
  Stops when op type has Variable or VarHandle in name.

  Args:
    t: a Tensor

  Returns:
    a Tensor that is a variable ref, or None on error.
  """
  while t.op.type in ["Identity", "ReadVariableOp", "Enter"]:
    t = t.op.inputs[0]

  op_type = t.op.type
  if "Variable" in op_type or "VarHandle" in op_type:
    return t
  else:
    return None 
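A hedged usage sketch (TF 1.x, classic ref variables): reading a variable goes through one or more Identity ops, and the helper walks back through them to the underlying ref.

import tensorflow as tf

v = tf.Variable(1.0, name='v')    # classic ref variable in TF 1.x
t = tf.identity(v)                # Identity over the variable's read op
ref = _underlying_variable_ref(t)
print(ref.op.type)                # e.g. 'VariableV2'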
Example #21
Source File: session_debug_testlib.py    From auto-alt-text-lambda-api with MIT License
def testDebugNumericSummaryOnUninitializedTensorGivesCorrectResult(self):
    with session.Session() as sess:
      a = variables.Variable(
          [42], dtype=np.float32, name="numeric_summary_uninit/a")

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugNumericSummary"],
          debug_urls=self._debug_urls())

      sess.run(a.initializer, options=run_options, run_metadata=run_metadata)

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)
      self.assertTrue(dump.loaded_partition_graphs())

      # DebugNumericSummary output should reflect the uninitialized state of
      # the watched tensor.
      numeric_summary = dump.get_tensors("numeric_summary_uninit/a", 0,
                                         "DebugNumericSummary")[0]
      self.assertAllClose([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                          numeric_summary[0:8])
      self.assertTrue(np.isinf(numeric_summary[8]))
      self.assertGreater(numeric_summary[8], 0.0)
      self.assertTrue(np.isinf(numeric_summary[9]))
      self.assertLess(numeric_summary[9], 0.0)
      self.assertTrue(np.isnan(numeric_summary[10]))
      self.assertTrue(np.isnan(numeric_summary[11])) 
Example #22
Source File: debug_utils_test.py    From auto-alt-text-lambda-api with MIT License
def testWatchGraph_opTypeBlacklist(self):
    debug_utils.watch_graph_with_blacklists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_blacklist="(Variable|Identity|Assign|Const)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(sorted(["p1", "s"]), sorted(node_names)) 
Example #23
Source File: debug_utils_test.py    From auto-alt-text-lambda-api with MIT License
def testWatchGraph_opTypeWhitelist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_whitelist="(Variable|MatMul)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names)) 
Example #24
Source File: debug_utils_test.py    From auto-alt-text-lambda-api with MIT License
def setUpClass(cls):
    cls._sess = session.Session()
    with cls._sess:
      cls._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
      cls._b_init_val = np.array([[2.0], [-1.0]])
      cls._c_val = np.array([[-4.0], [np.nan]])

      cls._a_init = constant_op.constant(
          cls._a_init_val, shape=[2, 2], name="a1_init")
      cls._b_init = constant_op.constant(
          cls._b_init_val, shape=[2, 1], name="b_init")

      cls._a = variables.Variable(cls._a_init, name="a1")
      cls._b = variables.Variable(cls._b_init, name="b")
      cls._c = constant_op.constant(cls._c_val, shape=[2, 1], name="c")

      # Matrix product of a and b.
      cls._p = math_ops.matmul(cls._a, cls._b, name="p1")

      # Sum of two vectors.
      cls._s = math_ops.add(cls._p, cls._c, name="s")

    cls._graph = cls._sess.graph

    # These are all the expected nodes in the graph:
    #   Two variables (a, b), each with four nodes (Variable, init, Assign,
    #       read).
    #   One constant (c).
    #   One add operation and one matmul operation.
    cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1 
Example #25
Source File: utils.py    From lambda-packs with MIT License
def constant_value(value_or_tensor_or_var, dtype=None):
  """Returns value if value_or_tensor_or_var has a constant value.

  Args:
    value_or_tensor_or_var: A value, a `Tensor` or a `Variable`.
    dtype: Optional `tf.dtype`, if set it would check it has the right
      dtype.

  Returns:
    The constant value, or None if it is not constant.

  Raises:
    ValueError: if value_or_tensor_or_var is None or the tensor or variable has
      the wrong dtype.
  """
  if value_or_tensor_or_var is None:
    raise ValueError('value_or_tensor_or_var cannot be None')
  value = value_or_tensor_or_var
  if isinstance(value_or_tensor_or_var, (ops.Tensor, variables.Variable)):
    if dtype and value_or_tensor_or_var.dtype != dtype:
      raise ValueError('It has the wrong type %s instead of %s' % (
          value_or_tensor_or_var.dtype, dtype))
    if isinstance(value_or_tensor_or_var, variables.Variable):
      value = None
    else:
      value = tensor_util.constant_value(value_or_tensor_or_var)
  return value 
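A quick usage sketch for this variant: plain Python values and constant tensors resolve to concrete values, while variables resolve to None.

import tensorflow as tf

print(constant_value(3.0))                # 3.0 (plain value passes through)
print(constant_value(tf.constant(3.0)))   # 3.0 (statically known)
print(constant_value(tf.Variable(3.0)))   # None (dynamic)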
Example #26
Source File: factorization_ops.py    From lambda-packs with MIT License
def _transient_var(name):
    """Helper function to create a Variable."""
    return variable_scope.variable(
        1.0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        validate_shape=False,
        name=name) 
Example #27
Source File: subscribe.py    From lambda-packs with MIT License
def _recursive_apply(tensors, apply_fn):
  """Helper method to recursively apply a function to structure of tensors.

  The structure of the tensors should take a form similar to fetches in
  `tf.Session.run` and may include a single `Tensor`, a `list`, a nested
  `list`, a `tuple`, a `namedtuple`, or a `dict`.

  Args:
    tensors: Single `Tensor`, `list`, nested `list`, `tuple`,
      `namedtuple`, or `dict`.
    apply_fn: Function to apply to each `Tensor` and should return a `Tensor`.
  Returns:
    Returns the modified tensors with the same structure.
  Raises:
    `TypeError` if an undefined type is found in the tensors structure.
  """
  tensors_type = type(tensors)
  if tensors_type is ops.Tensor:
    return apply_fn(tensors)
  elif tensors_type is variables.Variable:
    return apply_fn(tensors.value())
  elif isinstance(tensors, (list, tuple)):
    tensors = [_recursive_apply(t, apply_fn) for t in tensors]
    if tensors_type is list:
      return list(tensors)
    elif tensors_type is tuple:
      return tuple(tensors)
    return tensors_type(*tensors)  # collections.namedtuple
  elif tensors_type is dict:
    return dict([(k, _recursive_apply(v, apply_fn))
                 for k, v in tensors.items()])
  else:
    raise TypeError('_recursive_apply argument %r has invalid type %r' %
                    (tensors, tensors_type)) 
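A usage sketch: applying a doubling function to a nested fetch-like structure preserves the structure (dict, tuple, and list here).

import tensorflow as tf

fetches = {'a': tf.constant(1.0),
           'pair': (tf.constant(2.0), [tf.constant(3.0)])}
doubled = _recursive_apply(fetches, lambda t: 2.0 * t)

with tf.Session() as sess:
  print(sess.run(doubled))   # {'a': 2.0, 'pair': (4.0, [6.0])}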
Example #28
Source File: slot_creator.py    From lambda-packs with MIT License
def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
  """Create a slot initialized to 0 with same shape as the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable.  Defaults to the type of `primary`.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  if dtype is None:
    dtype = primary.dtype
  slot_shape = primary.get_shape()
  slot_shape = (slot_shape if slot_shape.is_fully_defined()
                else array_ops.shape(primary.initialized_value()))
  if slot_shape.is_fully_defined():
    initializer = init_ops.zeros_initializer(dtype)
    return create_slot_with_initializer(
        primary, initializer, slot_shape, dtype, name,
        colocate_with_primary=colocate_with_primary)
  else:
    val = array_ops.zeros(slot_shape, dtype=dtype)
    return create_slot(primary, val, name,
                       colocate_with_primary=colocate_with_primary) 
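A hedged usage sketch (TF 1.x): creating a zero-filled slot, e.g. a momentum accumulator, with the same shape as its primary variable.

import tensorflow as tf
from tensorflow.python.training import slot_creator

w = tf.Variable(tf.ones([3, 2]), name='w')
m = slot_creator.create_zeros_slot(w, name='momentum')
print(m.get_shape())   # (3, 2), matching the primary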
Example #29
Source File: checkpoint_utils.py    From lambda-packs with MIT License
def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None 
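The helper only needs a name-to-variable mapping, so a sketch with string stand-ins for the variables makes the walk over part_0, part_1, ... explicit:

all_vars = {'w/part_0': 'v0', 'w/part_1': 'v1', 'b': 'v2'}
print(_collect_partitioned_variable('w', all_vars))   # ['v0', 'v1']
print(_collect_partitioned_variable('b', all_vars))   # None (not partitioned)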
Example #30
Source File: checkpoint_utils.py    From lambda-packs with MIT License
def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: if the objects in `variable_or_list` are not all partitions of
      the same large variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")