Python tensorflow.get_collection_ref() Examples

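The following are 30 code examples of tensorflow.get_collection_ref(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function. Note that tf.get_collection_ref() returns the graph's live list for a collection (creating an empty one if it does not exist yet), so mutating the returned list — as many examples below do — changes the collection itself, whereas tf.get_collection() returns a copy.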
Example #1
Source File: ptb_word_lm.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def import_ops(self):
    """Imports ops from collections.

    Re-binds this model's attributes to tensors/ops previously stored in graph
    collections (presumably by a matching export step). Each lookup takes the
    first entry of its collection, so a missing or empty collection raises
    IndexError.

    When training, it also fetches the learning-rate ops and, if ``self._cell``
    is set and the ``rnn_params`` collection is non-empty, registers an
    ``RNNParamsSaveable`` in the ``SAVEABLE_OBJECTS`` collection.
    """
    def first_in(key):
      return tf.get_collection_ref(key)[0]

    if self._is_training:
      self._train_op = first_in("train_op")
      self._lr = first_in("lr")
      self._new_lr = first_in("new_lr")
      self._lr_update = first_in("lr_update")
      rnn_params = tf.get_collection_ref("rnn_params")
      if self._cell and rnn_params:
        saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
            self._cell,
            self._cell.params_to_canonical,
            self._cell.canonical_to_params,
            rnn_params,
            base_variable_scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first_in(util.with_prefix(self._name, "cost"))

    # Only the model named "Train" uses FLAGS.num_gpus replicas; others use one.
    if self._name == "Train":
      num_replicas = FLAGS.num_gpus
    else:
      num_replicas = 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #2
Source File: graph_search_test.py    From kfac with Apache License 2.0 6 votes vote down vote up
def test_tied_weights_untied_bias_registered_bias(self):
    """Linking only b_1 must make graph search raise AmbiguousRegistrationError.

    The model from _build_model admits several possible registrations. Marking
    b_1 as a linked parameter settles one branch, but the other branch of the
    graph stays ambiguous, so register_layers has to fail.
    """
    graph = tf.Graph()
    with graph.as_default():
      tensors = _build_model()

      layer_collection = lc.LayerCollection()
      for out_key in ('out_0', 'out_1'):
        layer_collection.register_squared_error_loss(tensors[out_key])

      layer_collection.define_linked_parameters(tensors['b_1'])

      with self.assertRaises(gs.AmbiguousRegistrationError):
        all_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        gs.register_layers(layer_collection, all_vars)
Example #3
Source File: ptb_word_lm.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 6 votes vote down vote up
def import_ops(self):
    """Imports ops from collections.

    Re-binds this model's attributes to tensors/ops previously stored in graph
    collections (presumably by a matching export step). Each lookup takes the
    first entry of its collection, so a missing or empty collection raises
    IndexError.

    When training, it also fetches the learning-rate ops and, if ``self._cell``
    is set and the ``rnn_params`` collection is non-empty, registers an
    ``RNNParamsSaveable`` in the ``SAVEABLE_OBJECTS`` collection.
    """
    def first_in(key):
      return tf.get_collection_ref(key)[0]

    if self._is_training:
      self._train_op = first_in("train_op")
      self._lr = first_in("lr")
      self._new_lr = first_in("new_lr")
      self._lr_update = first_in("lr_update")
      rnn_params = tf.get_collection_ref("rnn_params")
      if self._cell and rnn_params:
        saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
            self._cell,
            self._cell.params_to_canonical,
            self._cell.canonical_to_params,
            rnn_params,
            base_variable_scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first_in(util.with_prefix(self._name, "cost"))

    # Only the model named "Train" uses FLAGS.num_gpus replicas; others use one.
    if self._name == "Train":
      num_replicas = FLAGS.num_gpus
    else:
      num_replicas = 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #4
Source File: ptb_word_lm.py    From object_detection_kitti with Apache License 2.0 6 votes vote down vote up
def import_ops(self):
    """Imports ops from collections.

    Re-binds this model's attributes to tensors/ops previously stored in graph
    collections (presumably by a matching export step). Each lookup takes the
    first entry of its collection, so a missing or empty collection raises
    IndexError.

    When training, it also fetches the learning-rate ops and, if ``self._cell``
    is set and the ``rnn_params`` collection is non-empty, registers an
    ``RNNParamsSaveable`` in the ``SAVEABLE_OBJECTS`` collection.
    """
    def first_in(key):
      return tf.get_collection_ref(key)[0]

    if self._is_training:
      self._train_op = first_in("train_op")
      self._lr = first_in("lr")
      self._new_lr = first_in("new_lr")
      self._lr_update = first_in("lr_update")
      rnn_params = tf.get_collection_ref("rnn_params")
      if self._cell and rnn_params:
        saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
            self._cell,
            self._cell.params_to_canonical,
            self._cell.canonical_to_params,
            rnn_params,
            base_variable_scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first_in(util.with_prefix(self._name, "cost"))

    # Only the model named "Train" uses FLAGS.num_gpus replicas; others use one.
    if self._name == "Train":
      num_replicas = FLAGS.num_gpus
    else:
      num_replicas = 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #5
Source File: block_base.py    From object_detection_kitti with Apache License 2.0 6 votes vote down vote up
def MarkAsNonTrainable(self):
    """Remove every variable owned by this block from TRAINABLE_VARIABLES.

    "Owned" means whatever self.VariableList() returns, i.e. the block's own
    variables and those of its sub-blocks. Combined with CheckpointInitOp this
    allows loading a pretrained model that covers only one part of the whole
    graph. The block must already have been called (``self._called``).
    """
    assert self._called

    owned = self.VariableList()
    trainable = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    for var in owned:
      if var not in trainable:
        continue
      # get_collection_ref returns the graph's live list, so this edits the
      # collection itself.
      trainable.remove(var)
Example #6
Source File: config.py    From tensorflow-tbcnn with MIT License 6 votes vote down vote up
def initialize_tbcnn_weights(clz):
    """Create the TBCNN weights and biases on ``clz``.

    First initializes the embedding weights, then excludes ``We`` from
    training by removing it from the graph's TRAINABLE_VARIABLES collection.
    All remaining parameters are created, in a fixed order, with a
    uniform(-0.2, 0.2) initializer.

    Raises:
        ValueError: from ``list.remove`` if ``We`` is not in the trainable
            collection.
    """
    clz.initialize_embedding_weights()
    # Don't train We: optimizers default to the TRAINABLE_VARIABLES
    # collection, so dropping it there freezes the embedding.
    tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).remove(clz.get('We'))

    # (name, shape) of every trained parameter, in creation order.
    # Wcomb1 used to be built with tf.constant_initializer(-.2, .2), which
    # passes .2 as the *dtype* argument; it now uses the same uniform
    # initializer as every other weight.
    specs = (
        ('Wcomb1', (hyper.word_dim, hyper.word_dim)),
        ('Wcomb2', (hyper.word_dim, hyper.word_dim)),
        ('Wconvt', (hyper.word_dim, hyper.conv_dim)),
        ('Wconvl', (hyper.word_dim, hyper.conv_dim)),
        ('Wconvr', (hyper.word_dim, hyper.conv_dim)),
        ('Bconv', (hyper.conv_dim,)),
        ('FC1/weight', (hyper.conv_dim, hyper.fc_dim)),
        ('FC1/bias', (hyper.fc_dim,)),
        ('FC2/weight', (hyper.fc_dim, hyper.output_dim)),
        ('FC2/bias', (hyper.output_dim,)),
    )
    for name, shape in specs:
        clz.create_variable(name, shape, tf.random_uniform_initializer(-.2, .2))
Example #7
Source File: graph_search_test.py    From kfac with Apache License 2.0 6 votes vote down vote up
def mixed_usage_test(self):
    """Graph search must reject a tensor that feeds ops of different types.

    A tensor may legitimately appear in several places of a graph (recurrent
    models, parallel towers), but each use must be the same kind of operation.
    Here W feeds both a matmul and a broadcast add, so register_layers has to
    raise a ValueError mentioning 'mixed record types'.

    NOTE(review): this name does not start with ``test``, so the default
    unittest loader will not collect it. Rename it (e.g. ``test_mixed_usage``)
    if it is meant to run; confirm it was not disabled on purpose.
    """
    graph = tf.Graph()
    with graph.as_default():
      weights = tf.get_variable('W', [10, 10])
      inputs_2d = tf.placeholder(tf.float32, shape=(32, 10))
      inputs_3d = tf.placeholder(tf.float32, shape=(32, 10, 10))

      # Two different consumers of the same variable; results are unused on
      # purpose, only the graph structure matters.
      tf.matmul(inputs_2d, weights)
      inputs_3d + weights  # pylint: disable=pointless-statement

      layer_collection = lc.LayerCollection()

      with self.assertRaises(ValueError) as ctx:
        all_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        gs.register_layers(layer_collection, all_vars)

      self.assertIn('mixed record types', str(ctx.exception))
Example #8
Source File: ptb_word_lm.py    From yolo_v2 with Apache License 2.0 6 votes vote down vote up
def import_ops(self):
    """Imports ops from collections.

    Re-binds this model's attributes to tensors/ops previously stored in graph
    collections (presumably by a matching export step). Each lookup takes the
    first entry of its collection, so a missing or empty collection raises
    IndexError.

    When training, it also fetches the learning-rate ops and, if ``self._cell``
    is set and the ``rnn_params`` collection is non-empty, registers an
    ``RNNParamsSaveable`` in the ``SAVEABLE_OBJECTS`` collection.
    """
    def first_in(key):
      return tf.get_collection_ref(key)[0]

    if self._is_training:
      self._train_op = first_in("train_op")
      self._lr = first_in("lr")
      self._new_lr = first_in("new_lr")
      self._lr_update = first_in("lr_update")
      rnn_params = tf.get_collection_ref("rnn_params")
      if self._cell and rnn_params:
        saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
            self._cell,
            self._cell.params_to_canonical,
            self._cell.canonical_to_params,
            rnn_params,
            base_variable_scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first_in(util.with_prefix(self._name, "cost"))

    # Only the model named "Train" uses FLAGS.num_gpus replicas; others use one.
    if self._name == "Train":
      num_replicas = FLAGS.num_gpus
    else:
      num_replicas = 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #9
Source File: bayesian_rnn.py    From BayesianRecurrentNN with MIT License 6 votes vote down vote up
def import_ops(self):
    """Imports ops from collections.

    Re-binds this model's attributes (including the ``kl_div`` tensor) to
    tensors/ops previously stored in graph collections. Each lookup takes the
    first entry of its collection, so a missing or empty collection raises
    IndexError.

    When training, it also fetches the learning-rate ops and, if ``self._cell``
    is set and the ``rnn_params`` collection is non-empty, registers an
    ``RNNParamsSaveable`` in the ``SAVEABLE_OBJECTS`` collection.
    """
    def first_in(key):
        return tf.get_collection_ref(key)[0]

    if self._is_training:
        self._train_op = first_in("train_op")
        self._lr = first_in("lr")
        self._new_lr = first_in("new_lr")
        self._lr_update = first_in("lr_update")
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first_in(tf_util.with_prefix(self._name, "cost"))
    self._kl_div = first_in(tf_util.with_prefix(self._name, "kl_div"))

    # States are always imported for a single replica here.
    num_replicas = 1
    self._initial_state = tf_util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = tf_util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #10
Source File: ptb_word_lm.py    From object_detection_with_tensorflow with MIT License 6 votes vote down vote up
def import_ops(self):
    """Imports ops from collections.

    Re-binds this model's attributes to tensors/ops previously stored in graph
    collections (presumably by a matching export step). Each lookup takes the
    first entry of its collection, so a missing or empty collection raises
    IndexError.

    When training, it also fetches the learning-rate ops and, if ``self._cell``
    is set and the ``rnn_params`` collection is non-empty, registers an
    ``RNNParamsSaveable`` in the ``SAVEABLE_OBJECTS`` collection.
    """
    def first_in(key):
      return tf.get_collection_ref(key)[0]

    if self._is_training:
      self._train_op = first_in("train_op")
      self._lr = first_in("lr")
      self._new_lr = first_in("new_lr")
      self._lr_update = first_in("lr_update")
      rnn_params = tf.get_collection_ref("rnn_params")
      if self._cell and rnn_params:
        saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
            self._cell,
            self._cell.params_to_canonical,
            self._cell.canonical_to_params,
            rnn_params,
            base_variable_scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first_in(util.with_prefix(self._name, "cost"))

    # Only the model named "Train" uses FLAGS.num_gpus replicas; others use one.
    if self._name == "Train":
      num_replicas = FLAGS.num_gpus
    else:
      num_replicas = 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #11
Source File: ptb_word_lm.py    From Gun-Detector with Apache License 2.0 6 votes vote down vote up
def import_ops(self):
    """Imports ops from collections.

    Re-binds this model's attributes to tensors/ops previously stored in graph
    collections (presumably by a matching export step). Each lookup takes the
    first entry of its collection, so a missing or empty collection raises
    IndexError.

    When training, it also fetches the learning-rate ops and, if ``self._cell``
    is set and the ``rnn_params`` collection is non-empty, registers an
    ``RNNParamsSaveable`` in the ``SAVEABLE_OBJECTS`` collection.
    """
    def first_in(key):
      return tf.get_collection_ref(key)[0]

    if self._is_training:
      self._train_op = first_in("train_op")
      self._lr = first_in("lr")
      self._new_lr = first_in("new_lr")
      self._lr_update = first_in("lr_update")
      rnn_params = tf.get_collection_ref("rnn_params")
      if self._cell and rnn_params:
        saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
            self._cell,
            self._cell.params_to_canonical,
            self._cell.canonical_to_params,
            rnn_params,
            base_variable_scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first_in(util.with_prefix(self._name, "cost"))

    # Only the model named "Train" uses FLAGS.num_gpus replicas; others use one.
    if self._name == "Train":
      num_replicas = FLAGS.num_gpus
    else:
      num_replicas = 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #12
Source File: block_base.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def MarkAsNonTrainable(self):
    """Remove every variable owned by this block from TRAINABLE_VARIABLES.

    "Owned" means whatever self.VariableList() returns, i.e. the block's own
    variables and those of its sub-blocks. Combined with CheckpointInitOp this
    allows loading a pretrained model that covers only one part of the whole
    graph. The block must already have been called (``self._called``).
    """
    assert self._called

    owned = self.VariableList()
    trainable = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    for var in owned:
      if var not in trainable:
        continue
      # get_collection_ref returns the graph's live list, so this edits the
      # collection itself.
      trainable.remove(var)
Example #13
Source File: context.py    From texar with Apache License 2.0 5 votes vote down vote up
def global_mode():
    """Returns the Tensor of global mode.

    This is a placeholder with default value of
    :tf_main:`tf.estimator.ModeKeys.TRAIN <estimator/ModeKeys>`. It is
    created on the first call and memoized in the ``_GLOBAL_MODE_KEY``
    collection of the default graph, so later calls in the same graph return
    the same tensor.

    Example:

        .. code-block:: python

            mode = session.run(global_mode())
            # mode == tf.estimator.ModeKeys.TRAIN

            mode = session.run(
                global_mode(),
                feed_dict={global_mode(): tf.estimator.ModeKeys.PREDICT})
            # mode == tf.estimator.ModeKeys.PREDICT
    """
    mode = tf.get_collection_ref(_GLOBAL_MODE_KEY)
    if not mode:
        # First call for this graph: build a feedable scalar that defaults to
        # TRAIN and store it in the (live) collection list.
        mode_tensor = tf.placeholder_with_default(
            input=tf.estimator.ModeKeys.TRAIN,
            shape=(),
            name="global_mode")
        mode.append(mode_tensor)
    return mode[0]
Example #14
Source File: common.py    From ternarynet with Apache License 2.0 5 votes vote down vote up
def restore_collection(backup):
    """Overwrite graph collections with previously saved contents.

    Args:
        backup: mapping from collection key to the items that collection
            should hold. Each collection is emptied in place and refilled.
    """
    for key, items in six.iteritems(backup):
        collection = tf.get_collection_ref(key)
        del collection[:]
        collection.extend(items)
Example #15
Source File: util.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def import_state_tuples(state_tuples, name, num_replicas):
  """Rebuild LSTMStateTuples from a flat graph collection.

  The collection ``name`` is read as consecutive pairs: entry ``2*k`` becomes
  ``c`` and entry ``2*k + 1`` becomes ``h`` of the k-th tuple.
  ``len(state_tuples) * num_replicas`` tuples are rebuilt; the contents of
  ``state_tuples`` are not used, only its length.

  Returns:
    A tuple of tf.contrib.rnn.LSTMStateTuple.

  Raises:
    IndexError: if the collection holds too few tensors.
  """
  indices = range(len(state_tuples) * num_replicas)
  if not indices:
    return ()
  flat = tf.get_collection_ref(name)
  return tuple(
      tf.contrib.rnn.LSTMStateTuple(flat[2 * k], flat[2 * k + 1])
      for k in indices)
Example #16
Source File: models.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               graph_collection_name="R_TILDE_VARS"):
    """Create one ``snt.Linear`` per timestep and track its variables.

    Args:
      state_size: each Linear module outputs ``2 * state_size`` units.
      num_timesteps: number of Linear modules (``r_tilde_0`` ...) to create.
      sigma_min: stored as ``self.sigma_min``; used outside this method.
      dtype: stored as ``self.dtype``.
      random_seed: seed for the truncated-normal weight initializer.
      graph_collection_name: graph collection that receives every variable
        fetched through these modules' custom getter.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      # Delegate variable creation, then record the variable (once) in the
      # live collection list so it can be retrieved by name later.
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    # range (not the Python-2-only xrange) so this works on Python 2 and 3.
    self.fns = [
        snt.Linear(output_size=2*state_size,
                   initializers=initializers,
                   name="r_tilde_%d" % t,
                   custom_getter=custom_getter)
        for t in range(num_timesteps)
    ]
Example #17
Source File: util.py    From object_detection_with_tensorflow with MIT License 5 votes vote down vote up
def import_state_tuples(state_tuples, name, num_replicas):
  """Rebuild LSTMStateTuples from a flat graph collection.

  The collection ``name`` is read as consecutive pairs: entry ``2*k`` becomes
  ``c`` and entry ``2*k + 1`` becomes ``h`` of the k-th tuple.
  ``len(state_tuples) * num_replicas`` tuples are rebuilt; the contents of
  ``state_tuples`` are not used, only its length.

  Returns:
    A tuple of tf.contrib.rnn.LSTMStateTuple.

  Raises:
    IndexError: if the collection holds too few tensors.
  """
  indices = range(len(state_tuples) * num_replicas)
  if not indices:
    return ()
  flat = tf.get_collection_ref(name)
  return tuple(
      tf.contrib.rnn.LSTMStateTuple(flat[2 * k], flat[2 * k + 1])
      for k in indices)
Example #18
Source File: tf_utils.py    From rltf with MIT License 5 votes vote down vote up
def normalize(x, training, momentum=0.0):
  """Batch-normalize a rank-2 tensor along the batch dimension.

  Uses the current batch statistics in training mode and the running mean and
  variance in inference mode. The moving-statistics update ops created by
  batch normalization are taken out of the global UPDATE_OPS collection and
  attached directly to the returned tensor instead.

  Args:
    x: tf.Tensor, shape.ndims == 2. Input tensor
    training: tf.Tensor or bool. Whether to return the output in training mode (normalized with
      statistics of the current batch) or in inference mode (normalized with moving statistics)
    momentum: float. Momentum for the moving average.
  """
  assert x.shape.ndims == 2

  bn_kwargs = {
      "axis": -1,
      "center": False,
      "scale": False,
      "trainable": True,
      "training": training,
      "momentum": momentum,
  }

  # Remember how many update ops existed so the ones added below can be
  # sliced off afterwards.
  update_collection = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
  start = len(update_collection)

  normalized = tf.layers.batch_normalization(x, **bn_kwargs)

  # Detach the new batch-norm update ops from the live global list.
  bn_updates = update_collection[start:]
  del update_collection[start:]

  # Run the moving mean/variance updates before the output is produced.
  with tf.control_dependencies(bn_updates):
    return tf.identity(normalized)
Example #19
Source File: common.py    From ternarynet with Apache License 2.0 5 votes vote down vote up
def clear_collection(keys):
    """Empty, in place, every graph collection named in ``keys``."""
    for key in keys:
        collection = tf.get_collection_ref(key)
        collection[:] = []
Example #20
Source File: model_lib.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def filter_trainable_variables(trainable_scopes):
  """Keep only trainable variables which are prefixed with given scopes.

  Args:
    trainable_scopes: either list of trainable scopes or string with comma
      separated list of trainable scopes.

  This function removes all variables which are not prefixed with given
  trainable_scopes from collection of trainable variables.
  Useful during network fine tuning, when you only need to train subset of
  variables. Does nothing if no non-empty scope is given.
  """
  if not trainable_scopes:
    return
  if isinstance(trainable_scopes, six.string_types):
    trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
  # Drop empty entries (e.g. from "a,,b" or a trailing comma) and duplicates;
  # str.startswith accepts a tuple of prefixes.
  prefixes = tuple({scope for scope in trainable_scopes if scope})
  if not prefixes:
    return
  trainable_collection = tf.get_collection_ref(
      tf.GraphKeys.TRAINABLE_VARIABLES)
  # Rewrite the graph's live list in one pass instead of calling
  # list.remove() per variable (quadratic, and relies on variable __eq__).
  trainable_collection[:] = [
      v for v in trainable_collection if v.op.name.startswith(prefixes)
  ]
Example #21
Source File: common.py    From VDAIC2017 with MIT License 5 votes vote down vote up
def restore_collection(backup):
    """Overwrite graph collections with previously saved contents.

    Args:
        backup: mapping from collection key to the items that collection
            should hold. Each collection is emptied in place and refilled.
    """
    for key, items in six.iteritems(backup):
        collection = tf.get_collection_ref(key)
        del collection[:]
        collection.extend(items)
Example #22
Source File: model_lib.py    From adversarial-logit-pairing-analysis with Apache License 2.0 5 votes vote down vote up
def filter_trainable_variables(trainable_scopes):
  """Keep only trainable variables which are prefixed with given scopes.

  Args:
    trainable_scopes: either list of trainable scopes or string with comma
      separated list of trainable scopes.

  This function removes all variables which are not prefixed with given
  trainable_scopes from collection of trainable variables.
  Useful during network fine tuning, when you only need to train subset of
  variables. Does nothing if no non-empty scope is given.
  """
  if not trainable_scopes:
    return
  if isinstance(trainable_scopes, six.string_types):
    trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
  # Drop empty entries (e.g. from "a,,b" or a trailing comma) and duplicates;
  # str.startswith accepts a tuple of prefixes.
  prefixes = tuple({scope for scope in trainable_scopes if scope})
  if not prefixes:
    return
  trainable_collection = tf.get_collection_ref(
      tf.GraphKeys.TRAINABLE_VARIABLES)
  # Rewrite the graph's live list in one pass instead of calling
  # list.remove() per variable (quadratic, and relies on variable __eq__).
  trainable_collection[:] = [
      v for v in trainable_collection if v.op.name.startswith(prefixes)
  ]
Example #23
Source File: common.py    From VDAIC2017 with MIT License 5 votes vote down vote up
def clear_collection(keys):
    """Empty, in place, every graph collection named in ``keys``."""
    for key in keys:
        collection = tf.get_collection_ref(key)
        collection[:] = []
Example #24
Source File: tf_util.py    From BayesianRecurrentNN with MIT License 5 votes vote down vote up
def import_state_tuples(state_tuples, name, num_replicas):
  """Rebuild LSTMStateTuples from a flat graph collection.

  The collection ``name`` is read as consecutive pairs: entry ``2*k`` becomes
  ``c`` and entry ``2*k + 1`` becomes ``h`` of the k-th tuple.
  ``len(state_tuples) * num_replicas`` tuples are rebuilt; the contents of
  ``state_tuples`` are not used, only its length.

  Returns:
    A tuple of tf.contrib.rnn.LSTMStateTuple.

  Raises:
    IndexError: if the collection holds too few tensors.
  """
  indices = range(len(state_tuples) * num_replicas)
  if not indices:
    return ()
  flat = tf.get_collection_ref(name)
  return tuple(
      tf.contrib.rnn.LSTMStateTuple(flat[2 * k], flat[2 * k + 1])
      for k in indices)
Example #25
Source File: model.py    From aapm_thoracic_challenge with MIT License 5 votes vote down vote up
def __init__(self, sess, checkpoint_dir, log_dir, training_paths, testing_paths, roi, im_size, nclass,
                 batch_size=1, layers=3, features_root=32, conv_size=3, dropout=0.5, testing_gt_available=True,
                 loss_type='cross_entropy', class_weights=None):
    """Store the configuration, build the graph and create a checkpoint Saver.

    Every argument is kept as a same-named attribute (``roi`` is a
    ``(roi_order, roi_name)`` pair). After ``self.build_model()`` runs,
    ``self.saver`` is created over the trainable variables plus the
    ``'bn_collections'`` graph collection.
    """
    self.sess = sess

    # Output locations.
    self.checkpoint_dir, self.log_dir = checkpoint_dir, log_dir

    # Data sources.
    self.training_paths, self.testing_paths = training_paths, testing_paths
    self.testing_gt_available = testing_gt_available

    # Problem definition.
    self.nclass, self.im_size = nclass, im_size
    self.roi = roi  # (roi_order, roi_name)

    # Network / loss hyper-parameters.
    self.batch_size, self.layers = batch_size, layers
    self.features_root, self.conv_size = features_root, conv_size
    self.dropout, self.loss_type = dropout, loss_type
    self.class_weights = class_weights

    self.build_model()

    to_save = tf.trainable_variables() + tf.get_collection_ref('bn_collections')
    self.saver = tf.train.Saver(to_save)
Example #26
Source File: common.py    From Distributed-BA3C with Apache License 2.0 5 votes vote down vote up
def clear_collection(keys):
    """Empty, in place, every graph collection named in ``keys``."""
    for key in keys:
        collection = tf.get_collection_ref(key)
        collection[:] = []
Example #27
Source File: common.py    From Distributed-BA3C with Apache License 2.0 5 votes vote down vote up
def restore_collection(backup):
    """Overwrite graph collections with previously saved contents.

    Args:
        backup: mapping from collection key to the items that collection
            should hold. Each collection is emptied in place and refilled.
    """
    for key, items in six.iteritems(backup):
        collection = tf.get_collection_ref(key)
        del collection[:]
        collection.extend(items)
Example #28
Source File: utils.py    From dynamic-training-bench with Mozilla Public License 2.0 5 votes vote down vote up
def variables_to_save(add_list=None):
    """Return the variables a checkpoint should contain.

    The result concatenates, in this order: the trainable variables, the
    REQUIRED_NON_TRAINABLES collection, ``add_list`` and the output of
    training_process_variables(). A new list is returned; no input is mutated.

    Args:
        add_list: optional list of extra variables that are always included.
    Returns:
        list: list of tensors to save
    """
    extra = [] if add_list is None else add_list
    trainable = tf.trainable_variables()
    required = tf.get_collection_ref(REQUIRED_NON_TRAINABLES)
    return trainable + required + extra + training_process_variables()
Example #29
Source File: transfer_elmo_model.py    From delta with Apache License 2.0 5 votes vote down vote up
def transfer_elmo_model(vocab_file, options_file, weight_file, token_embedding_file,
                        output_elmo_model):
  """Export an ELMo bilm with named input/output tensors as a checkpoint.

  Steps:
    1. Call ``dump_token_embeddings`` on the vocab/options/weights; the file it
       is given (``token_embedding_file``) is then used as the model's
       ``embedding_weight_file``.
    2. In a fresh graph, build a ``BidirectionalLanguageModel`` fed by the
       int32 placeholder ``input_x`` of shape [None, None], combine its layers
       with ``weight_layers('output', ..., l2_coef=0.0)`` and expose the
       result as the tensor ``input_x_elmo``.
    3. Initialize all global variables, save the session to
       ``output_elmo_model`` and log the name of every global variable.

  Args:
    vocab_file: vocabulary file passed to dump_token_embeddings.
    options_file: bilm options file.
    weight_file: bilm weight file.
    token_embedding_file: where the dumped token embeddings are written/read.
    output_elmo_model: checkpoint path prefix for tf.train.Saver.save.
  """
  dump_token_embeddings(
      vocab_file, options_file, weight_file, token_embedding_file
  )
  logging.info("finish dump_token_embeddings")
  tf.reset_default_graph()

  # Session.__enter__ also makes this new graph the default, so the ops
  # below are created in it.
  with tf.Session(graph=tf.Graph()) as sess:
    bilm = BidirectionalLanguageModel(
      options_file,
      weight_file,
      use_character_inputs=False,
      embedding_weight_file=token_embedding_file
    )
    input_x = tf.placeholder(tf.int32, shape=[None, None],
                             name='input_x')
    train_embeddings_op = bilm(input_x)
    input_x_elmo_op = weight_layers(
      'output', train_embeddings_op, l2_coef=0.0
    )['weighted_op']
    input_x_elmo = tf.identity(input_x_elmo_op, name="input_x_elmo")
    # Log the static shape, as the message says (not the whole tensor repr).
    logging.info("input_x_elmo shape: {}".format(input_x_elmo.shape))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, output_elmo_model)
    logging.info("finish saving!")

    all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
    for v in all_variables:
      logging.info("variable name: {}".format(v.name))
Example #30
Source File: util.py    From vae-seq with Apache License 2.0 5 votes vote down vote up
def dynamic_hparam(key, value):
    """Returns a memoized, non-constant Tensor that allows feeding.

    On the first call for ``key`` in the default graph, ``value`` is converted
    to a tensor named ``<key>_default`` and wrapped in a
    ``placeholder_with_default`` named ``key`` with the same static shape,
    both created under the root name scope. The placeholder is stored in the
    ``HPARAMS_<key>`` collection. Later calls return that stored tensor and
    ignore ``value``.

    Raises:
        ValueError: if the ``HPARAMS_<key>`` collection holds more than one
            item.
    """
    collection = tf.get_collection_ref("HPARAMS_" + key)
    if len(collection) > 1:
        raise ValueError("Dynamic hparams collection should contain one item.")
    if not collection:
        # An empty name scope resets to the top level, so the op names do not
        # depend on the caller's current scope.
        with tf.name_scope(""):
            default_value = tf.convert_to_tensor(value, name=key + "_default")
            tensor = tf.placeholder_with_default(
                default_value,
                default_value.get_shape(),
                name=key)
            collection.append(tensor)
    return collection[0]