Python tensorflow.tables_initializer() Examples

The following are 30 code examples of tensorflow.tables_initializer(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
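tf.tables_initializer() returns an op that initializes all lookup tables registered in the default graph; it is a TensorFlow 1.x (graph-mode) API, available in TensorFlow 2 as tf.compat.v1.tables_initializer(). Before the project examples, here is a minimal, self-contained sketch of the typical pattern (not taken from any project below; the vocabulary and values are illustrative): build a lookup table, then run the initializer before using it.

import tensorflow as tf

# Map strings to their index in a small vocabulary; unknown tokens map to -1.
vocab = tf.constant(["a", "b", "c"])
table = tf.contrib.lookup.index_table_from_tensor(vocab, default_value=-1)
ids = table.lookup(tf.constant(["b", "z", "a"]))

with tf.Session() as sess:
    # Lookup tables are not variables, so global_variables_initializer()
    # does not touch them; they need their own initializer op.
    sess.run(tf.tables_initializer())
    print(sess.run(ids))  # => [ 1 -1  0]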
Example #1
Source File: model_test.py    From yolo_v2 with Apache License 2.0
def test_create_summaries_is_runnable(self):
    ocr_model = self.create_model()
    data = data_provider.InputEndpoints(
        images=self.fake_images,
        images_orig=self.fake_images,
        labels=self.fake_labels,
        labels_one_hot=slim.one_hot_encoding(self.fake_labels,
                                             self.num_char_classes))
    endpoints = ocr_model.create_base(
        images=self.fake_images, labels_one_hot=None)
    charset = create_fake_charset(self.num_char_classes)
    summaries = ocr_model.create_summaries(
        data, endpoints, charset, is_training=False)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.tables_initializer().run()
      sess.run(summaries)  # just check it is runnable 
Example #2
Source File: estimator_test.py    From training_results_v0.5 with Apache License 2.0
def testTrainInputFn(self):
    nmt_parser = argparse.ArgumentParser()
    nmt.add_arguments(nmt_parser)
    flags, _ = nmt_parser.parse_known_args()
    update_flags(flags, "input_fn_test")
    default_hparams = nmt.create_hparams(flags)
    hparams = nmt.extend_hparams(default_hparams)

    with self.test_session() as sess:
      input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
      outputs = input_fn({})
      sess.run(tf.tables_initializer())
      iterator = outputs.make_initializable_iterator()
      sess.run(iterator.initializer)
      features = sess.run(iterator.get_next())
      tf.logging.info("source: %s", features["source"])
      tf.logging.info("target_input: %s", features["target_input"])
      tf.logging.info("target_output: %s", features["target_output"])
      tf.logging.info("source_sequence_length: %s",
                      features["source_sequence_length"])
      tf.logging.info("target_sequence_length: %s",
                      features["target_sequence_length"]) 
Example #3
Source File: utils.py    From TransE-Knowledge-Graph-Embedding with MIT License
def load_model(sess, ckpt):
    with sess.as_default():
        with sess.graph.as_default():
            init_ops = [tf.global_variables_initializer(),
                        tf.local_variables_initializer(), tf.tables_initializer()]
            sess.run(init_ops)
            # load saved model
            ckpt_path = tf.train.latest_checkpoint(ckpt)
            if ckpt_path:
                print("Loading saved model: " + ckpt_path)
            else:
                raise ValueError("No checkpoint found in {}".format(ckpt))
            # reader = tf.train.NewCheckpointReader(ckpt+'model.ckpt_0.876-580500')
            # variables = reader.get_variable_to_shape_map()
            # for v in variables:
            #     print(v)
            saver = tf.train.Saver()
            saver.restore(sess, ckpt_path) 
Example #4
Source File: test_case.py    From Person-Detection-and-Tracking with MIT License
def execute_cpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on CPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph.
    """
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
      results = graph_fn(*placeholders)
      sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
                tf.local_variables_initializer()])
      materialized_results = sess.run(results, feed_dict=dict(zip(placeholders,
                                                                  inputs)))
      if (len(materialized_results) == 1
          and (isinstance(materialized_results, list)
               or isinstance(materialized_results, tuple))):
        materialized_results = materialized_results[0]
    return materialized_results 
Example #5
Source File: estimator_test.py    From training_results_v0.5 with Apache License 2.0
def testTrainInputFn(self):
    nmt_parser = argparse.ArgumentParser()
    nmt.add_arguments(nmt_parser)
    flags, _ = nmt_parser.parse_known_args()
    update_flags(flags, "input_fn_test")
    default_hparams = nmt.create_hparams(flags)
    hparams = nmt.extend_hparams(default_hparams)

    with self.test_session() as sess:
      input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
      outputs = input_fn({})
      sess.run(tf.tables_initializer())
      iterator = outputs.make_initializable_iterator()
      sess.run(iterator.initializer)
      features = sess.run(iterator.get_next())
      tf.logging.info("source: %s", features["source"])
      tf.logging.info("target_input: %s", features["target_input"])
      tf.logging.info("target_output: %s", features["target_output"])
      tf.logging.info("source_sequence_length: %s",
                      features["source_sequence_length"])
      tf.logging.info("target_sequence_length: %s",
                      features["target_sequence_length"]) 
Example #6
Source File: model_test.py    From DOTA_models with Apache License 2.0
def test_create_summaries_is_runnable(self):
    ocr_model = self.create_model()
    data = data_provider.InputEndpoints(
        images=self.fake_images,
        images_orig=self.fake_images,
        labels=self.fake_labels,
        labels_one_hot=slim.one_hot_encoding(self.fake_labels,
                                             self.num_char_classes))
    endpoints = ocr_model.create_base(
        images=self.fake_images, labels_one_hot=None)
    charset = create_fake_charset(self.num_char_classes)
    summaries = ocr_model.create_summaries(
        data, endpoints, charset, is_training=False)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.tables_initializer().run()
      sess.run(summaries)  # just check it is runnable 
Example #7
Source File: utils.py    From realmix with Apache License 2.0
def make_set_filter_fn(elements):
    """Constructs a TensorFlow "set" data structure.

    Note that sets returned by this function are uninitialized. Initialize them
    by calling `sess.run(tf.tables_initializer())`

    Args:
        elements: A list of non-Tensor elements.

    Returns:
        A function that when called with a single tensor argument, returns
        a boolean tensor if the argument is in the set.
    """
    table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.KeyValueTensorInitializer(
            elements, tf.tile([1], [len(elements)])
        ),
        default_value=0,
    )

    return lambda x: tf.equal(table.lookup(tf.dtypes.cast(x, tf.int32)), 1) 
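A hypothetical usage sketch for the helper above (the name is_prime_digit and the values are illustrative, not part of realmix); it assumes the definition above and a TF 1.x session:

is_prime_digit = make_set_filter_fn([2, 3, 5, 7])
mask = is_prime_digit(tf.constant([1, 2, 3, 4]))

with tf.Session() as sess:
    # The HashTable inside the closure is uninitialized until this runs.
    sess.run(tf.tables_initializer())
    print(sess.run(mask))  # => [False  True  True False]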
Example #8
Source File: test_case.py    From ros_people_object_detection_tensorflow with Apache License 2.0
def execute_cpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on CPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph.
    """
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
      results = graph_fn(*placeholders)
      sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
                tf.local_variables_initializer()])
      materialized_results = sess.run(results, feed_dict=dict(zip(placeholders,
                                                                  inputs)))
      if (len(materialized_results) == 1
          and (isinstance(materialized_results, list)
               or isinstance(materialized_results, tuple))):
        materialized_results = materialized_results[0]
    return materialized_results 
Example #9
Source File: test_case.py    From Traffic-Rule-Violation-Detection-System with MIT License
def execute_cpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on CPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph.
    """
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
      results = graph_fn(*placeholders)
      sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
                tf.local_variables_initializer()])
      materialized_results = sess.run(results, feed_dict=dict(zip(placeholders,
                                                                  inputs)))
      if len(materialized_results) == 1:
        materialized_results = materialized_results[0]
    return materialized_results 
Example #10
Source File: test_case.py    From Traffic-Rule-Violation-Detection-System with MIT License
def execute_tpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on TPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph.
    """
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
      tpu_computation = tpu.rewrite(graph_fn, placeholders)
      sess.run(tpu.initialize_system())
      sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
                tf.local_variables_initializer()])
      materialized_results = sess.run(tpu_computation,
                                      feed_dict=dict(zip(placeholders, inputs)))
      sess.run(tpu.shutdown_system())
      if len(materialized_results) == 1:
        materialized_results = materialized_results[0]
    return materialized_results 
Example #11
Source File: model_helper.py    From training_results_v0.5 with Apache License 2.0
def create_or_load_model(model, model_dir, session, name):
  """Create translation model and initialize or load parameters in session."""
  latest_ckpt = tf.train.latest_checkpoint(model_dir)
  if latest_ckpt:
    model = load_model(model, latest_ckpt, session, name)
  else:
    start_time = time.time()
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    utils.print_out("  created %s model with fresh parameters, time %.2fs" %
                    (name, time.time() - start_time))

  global_step = model.global_step.eval(session=session)
  return model, global_step 
Example #12
Source File: estimator.py    From training_results_v0.5 with Apache License 2.0
def _get_tgt_sos_eos_id(hparams):
  with tf.Session() as sess:
    _, tgt_vocab_table = vocab_utils.create_vocab_tables(
        hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab)
    tgt_sos_id = tf.cast(
        tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
    tgt_eos_id = tf.cast(
        tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
    sess.run(tf.tables_initializer())
    tgt_sos_id = sess.run(tgt_sos_id, {})
    tgt_eos_id = sess.run(tgt_eos_id, {})
    return tgt_sos_id, tgt_eos_id 
Example #13
Source File: iterator_utils_test.py    From training_results_v0.5 with Apache License 2.0
def testGetInferIterator(self):
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        eos="eos",
        sos="sos")
    batch_size = 2
    src_max_len = 3
    dataset = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=batch_size,
        eos=hparams.eos,
        src_max_len=src_max_len)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)
      features = sess.run(get_next)

      self.assertAllEqual(
          [
              [2, 2, 0],  # c c a
              [2, 0, 3]
          ],  # c a eos
          features["source"])
      self.assertAllEqual([3, 2], features["source_sequence_length"]) 
Example #14
Source File: estimator.py    From training_results_v0.5 with Apache License 2.0
def _get_tgt_sos_eos_id(hparams):
  with tf.Session() as sess:
    _, tgt_vocab_table = vocab_utils.create_vocab_tables(
        hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab)
    tgt_sos_id = tf.cast(
        tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
    tgt_eos_id = tf.cast(
        tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
    sess.run(tf.tables_initializer())
    tgt_sos_id = sess.run(tgt_sos_id, {})
    tgt_eos_id = sess.run(tgt_eos_id, {})
    return tgt_sos_id, tgt_eos_id 
Example #15
Source File: estimator.py    From training_results_v0.5 with Apache License 2.0
def _convert_ids_to_strings(tgt_vocab_file, ids):
  """Convert prediction ids to words."""
  with tf.Session() as sess:
    reverse_target_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)
    sess.run(tf.tables_initializer())
    translations = sess.run(
        reverse_target_vocab_table.lookup(
            tf.to_int64(tf.convert_to_tensor(np.asarray(ids)))))
  return translations 
Example #16
Source File: process.py    From DeepRNN with MIT License
def load_model(model, ckpt_path, session, name):
    """Load model from a checkpoint."""
    try:
        model.saver.restore(session, ckpt_path)
    except tf.errors.NotFoundError as e:
        utils.print_out("Can't load checkpoint")
        utils.print_out("%s" % str(e))
    # session.run(tf.tables_initializer())  ## why table still need to be initialized even model loaded??
    utils.print_out("  loaded %s model parameters from %s" % (name, ckpt_path))
    return model 
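A note on the commented-out initializer above: in TF 1.x, a tf.train.Saver checkpoint stores variables but not the static lookup tables built from vocabulary data, so the table initializer still has to be run after restoring (Example #26 below does exactly that). A minimal, self-contained illustration written for this page (toy variable, vocabulary, and checkpoint path; not from the DeepRNN project):

import tensorflow as tf

v = tf.get_variable("v", initializer=tf.zeros([2]))
table = tf.contrib.lookup.index_table_from_tensor(tf.constant(["a", "b"]))
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    ckpt_path = saver.save(sess, "/tmp/toy_model.ckpt")

with tf.Session() as sess:
    saver.restore(sess, ckpt_path)     # restores the variable v only
    sess.run(tf.tables_initializer())  # the lookup table must still be initialized
    print(sess.run(table.lookup(tf.constant(["b"]))))  # => [1]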
Example #17
Source File: iterator_utils_test.py    From training_results_v0.5 with Apache License 2.0
def testGetInferIterator(self):
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        eos="eos",
        sos="sos")
    batch_size = 2
    src_max_len = 3
    dataset = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=batch_size,
        eos=hparams.eos,
        src_max_len=src_max_len)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)
      features = sess.run(get_next)

      self.assertAllEqual(
          [
              [2, 2, 0],  # c c a
              [2, 0, 3]
          ],  # c a eos
          features["source"])
      self.assertAllEqual([3, 2], features["source_sequence_length"]) 
Example #18
Source File: iterator_utils_test.py    From training_results_v0.5 with Apache License 2.0
def testGetInferIterator(self):
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        eos="eos",
        sos="sos")
    batch_size = 2
    dataset = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=batch_size,
        eos=hparams.eos)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)
      features = sess.run(get_next)

      self.assertAllEqual(
          [
              [2, 2, 0],  # c c a
              [2, 0, 3]
          ],  # c a eos
          features["source"])
      self.assertAllEqual([3, 2], features["source_sequence_length"]) 
Example #19
Source File: vocab.py    From DualRL with MIT License
def test_vocab():
    import tensorflow as tf
    import numpy as np
    import os
    from common_options import load_common_arguments

    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # Load global vocab
    args = load_common_arguments()
    global_vocab, global_vocab_size = load_vocab(args.global_vocab_file)

    vocab, vocab_size = load_vocab_dict(args.global_vocab_file)

    assert global_vocab_size == vocab_size

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        i = 0
        ks = vocab.keys()
        vs = vocab.values()

        v1 = sess.run(global_vocab.lookup(tf.convert_to_tensor(ks)))
        for i in range(len(vs)):
            assert vs[i] == v1[i] 
Example #20
Source File: tf_example_decoder_test.py    From Traffic-Rule-Violation-Detection-System with MIT License
def testDecodeObjectLabelNoText(self):
    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image_tensor)
    bbox_classes = [1, 2]
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': self._BytesFeature(encoded_jpeg),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/class/label': self._Int64Feature(bbox_classes),
    })).SerializeToString()
    label_map_string = """
      item {
        id:1
        name:'cat'
      }
      item {
        id:2
        name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)

    example_decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_path)
    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

    self.assertAllEqual((tensor_dict[
        fields.InputDataFields.groundtruth_classes].get_shape().as_list()),
                        [None])

    init = tf.tables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      tensor_dict = sess.run(tensor_dict)

    self.assertAllEqual(bbox_classes,
                        tensor_dict[fields.InputDataFields.groundtruth_classes]) 
Example #21
Source File: model_test.py    From yolo_v2 with Apache License 2.0
def test_predicted_text_has_correct_shape_w_charset(self):
    charset = create_fake_charset(self.num_char_classes)
    ocr_model = self.create_model(charset=charset)

    with self.test_session() as sess:
      endpoints_tf = ocr_model.create_base(
          images=self.fake_images, labels_one_hot=None)

      sess.run(tf.global_variables_initializer())
      tf.tables_initializer().run()
      endpoints = sess.run(endpoints_tf)

      self.assertEqual(endpoints.predicted_text.shape, (self.batch_size,))
      self.assertEqual(len(endpoints.predicted_text[0]), self.seq_length) 
Example #22
Source File: process.py    From DeepRNN with MIT License
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session."""
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model._replace(model=load_model(model.model, latest_ckpt, session, name))
        utils.print_out("checkpoint found, load checkpoint\n %s" % latest_ckpt)
    else:
        utils.print_out("  checkpoint not found in %s" % (model_dir))
        utils.print_out("  created %s model with fresh parameters" % (name))
        session.run(tf.global_variables_initializer())
        # session.run(tf.tables_initializer())
    global_step = model.model.global_step.eval(session=session)
    epoch_num = model.model.epoch_num.eval(session=session)
    return model, global_step, epoch_num 
Example #23
Source File: train.py    From TransE-Knowledge-Graph-Embedding with MIT License
def train():
    # Training
    with tf.Session() as sess:
        init_ops = [tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()]
        sess.run(init_ops)
        writer = tf.summary.FileWriter("summary", sess.graph)  # graph

        for epoch in range(FLAGS.max_epoch):
            sess.run(iterator.initializer)
            model.train(sess)
            if not os.path.exists(FLAGS.model_dir):
                os.mkdir(FLAGS.model_dir)
            save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
            model.save(sess, save_path)

            print('-----Start training-----')
            epoch_loss = 0.0
            step = 0
            while True:
                try:
                    batch_loss, _, summary = model.train(sess)
                    epoch_loss += batch_loss
                    step += 1
                    writer.add_summary(summary)
                except tf.errors.OutOfRangeError:
                    print('-----Finish training an epoch avg epoch loss={}-----'.format(epoch_loss / step))
                    break
                # show train batch metrics
                if step % FLAGS.stats_per_steps == 0:
                    time_str = datetime.datetime.now().isoformat()
                    print('{}\tepoch {:2d}\tstep {:3d}\ttrain loss={:.6f}'.format(
                        time_str, epoch + 1, step, batch_loss))

            if (epoch+1) % FLAGS.save_per_epochs == 0:
                if not os.path.exists(FLAGS.model_dir):
                    os.mkdir(FLAGS.model_dir)
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                model.save(sess, save_path)
                print("Epoch {}, saved checkpoint to {}".format(epoch+1, save_path)) 
Example #24
Source File: tfr2wav.py    From vqvae-speech with MIT License
def main(_):
  tf.gfile.MkDir(args.output_dir)

  data = ByteWavWholeReader(
    speaker_list=txt2list(args.speaker_list),
    filenames=tf.gfile.Glob(args.file_pattern),
    num_epoch=1)

  XNOM = data.f[0]
  XWAV = tf.expand_dims(mu_law_decode(data.x[0, :]), -1)
  XBIN = tf.contrib.ffmpeg.encode_audio(XWAV, 'wav', 16000)

  sess_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True))
  with tf.Session(config=sess_config) as sess:
    sess.run(tf.tables_initializer())
    sess.run(data.iterator.initializer)
    csv = open('vctk.csv', 'w')
    counter = 1
    while True:
      try:
        fetch = {'xbin': XBIN, 'xwav': XWAV, 'wav_name': XNOM}  
        result = sess.run(fetch)
        wav_name = result['wav_name'].decode('utf8')
        print('\rFile {:05d}: Processing {}'.format(counter, wav_name), end='')
        csv.write('{}, {:d}\n'.format(wav_name, len(result['xwav'])))
        filename = os.path.join(args.output_dir, wav_name) + '.wav'
        with open(filename, 'wb') as fp:
          fp.write(result['xbin'])
        counter += 1
      except tf.errors.OutOfRangeError:
        print('\nEpoch complete')
        break
    print()
    csv.close() 
Example #25
Source File: model_helper.py    From nslt with Apache License 2.0
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session."""
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        utils.print_out("  created %s model with fresh parameters, time %.2fs" % (name, time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step 
Example #26
Source File: model_helper.py    From nslt with Apache License 2.0
def load_model(model, ckpt, session, name):
    start_time = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    utils.print_out("  loaded %s model parameters from %s, time %.2fs" % (name, ckpt, time.time() - start_time))
    return model 
Example #27
Source File: nmt.py    From DualRL with MIT License
def create_model(sess, args, src_vocab_size, tgt_vocab_size, src_vocab_rev, tgt_vocab_rev, mode=constants.TRAIN,
                 reuse=None, load_pretrained_model=False, direction="", model_save_dir=None):
    sess.run(tf.tables_initializer())

    with tf.variable_scope(constants.NMT_VAR_SCOPE + direction, reuse=reuse):
        with tf.variable_scope("src"):
            src_emb = tf.get_variable("embedding", shape=[src_vocab_size, args.emb_dim])
        with tf.variable_scope("dst"):
            tgt_emb = tf.get_variable("embedding", shape=[tgt_vocab_size, args.emb_dim])

        model = NMT(mode, args.__dict__, src_vocab_size, tgt_vocab_size, src_emb, tgt_emb,
                    src_vocab_rev, tgt_vocab_rev, direction)

    if load_pretrained_model:
        if model_save_dir is None:
            model_save_dir = args.nmt_model_save_dir
            if direction not in model_save_dir:
                if direction[::-1] in model_save_dir:
                    model_save_dir = re.sub(direction[::-1], direction, model_save_dir)
                else:
                    model_save_dir = os.path.join(model_save_dir, direction)
        print(model_save_dir)
        try:
            print("Loading nmt model from", model_save_dir)
            model.saver.restore(sess, model_save_dir)
        except Exception as e:
            print("Error! Loading nmt model from", model_save_dir)
            print("Again! Loading nmt model from", tf.train.latest_checkpoint(model_save_dir))
            model.saver.restore(sess, tf.train.latest_checkpoint(model_save_dir))
    else:
        if reuse is None:
            print("Creating model with new parameters.")
            sess.run(tf.global_variables_initializer())
        else:
            print("Reuse parameters.")
    return model 
Example #28
Source File: dataset_util_test.py    From Person-Detection-and-Tracking with MIT License
def test_make_initializable_iterator_with_hashTable(self):
    keys = [1, 0, -1]
    dataset = tf.data.Dataset.from_tensor_slices([[1, 2, -1, 5]])
    table = tf.contrib.lookup.HashTable(
        initializer=tf.contrib.lookup.KeyValueTensorInitializer(
            keys=keys,
            values=list(reversed(keys))),
        default_value=100)
    dataset = dataset.map(table.lookup)
    data = dataset_util.make_initializable_iterator(dataset).get_next()
    init = tf.tables_initializer()

    with self.test_session() as sess:
      sess.run(init)
      self.assertAllEqual(sess.run(data), [-1, 100, 1, 100]) 
Example #29
Source File: tf_example_decoder_test.py    From Person-Detection-and-Tracking with MIT License
def testDecodeObjectLabelWithMapping(self):
    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image_tensor)
    bbox_classes_text = ['cat', 'dog']
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded':
                    self._BytesFeature(encoded_jpeg),
                'image/format':
                    self._BytesFeature('jpeg'),
                'image/object/class/text':
                    self._BytesFeature(bbox_classes_text),
            })).SerializeToString()

    label_map_string = """
      item {
        id:3
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)
    example_decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_path)
    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

    self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_classes]
                         .get_shape().as_list()), [None])

    with self.test_session() as sess:
      sess.run(tf.tables_initializer())
      tensor_dict = sess.run(tensor_dict)

    self.assertAllEqual([3, 1],
                        tensor_dict[fields.InputDataFields.groundtruth_classes]) 
Example #30
Source File: tf_example_decoder_test.py    From Person-Detection-and-Tracking with MIT License
def testDecodeObjectLabelUnrecognizedName(self):
    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image_tensor)
    bbox_classes_text = ['cat', 'cheetah']
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded':
                    self._BytesFeature(encoded_jpeg),
                'image/format':
                    self._BytesFeature('jpeg'),
                'image/object/class/text':
                    self._BytesFeature(bbox_classes_text),
            })).SerializeToString()

    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)
    example_decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_path)
    tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

    self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_classes]
                         .get_shape().as_list()), [None])

    with self.test_session() as sess:
      sess.run(tf.tables_initializer())
      tensor_dict = sess.run(tensor_dict)

    self.assertAllEqual([2, -1],
                        tensor_dict[fields.InputDataFields.groundtruth_classes])