Python absl.app.run() Examples

The following are 30 code examples of absl.app.run(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module absl.app, or try the search function.
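
Before the examples, here is a minimal, self-contained sketch of the canonical absl.app.run() entry point. The flag name and default below are purely illustrative:

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_integer('num_epochs', 10, 'Number of training epochs.')  # illustrative flag


def main(argv):
  # argv holds the positional arguments left over after flag parsing;
  # argv[0] is the program name.
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  print('Running for %d epochs' % FLAGS.num_epochs)


if __name__ == '__main__':
  app.run(main)  # parses sys.argv, then calls main with the remaining args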
Example #1
Source File: create_kitti_crop_dataset.py    From lingvo with Apache License 2.0    7 votes
def process(self, value):
    self._create_graph()
    elem_str = value.SerializeToString()

    b, bucket = self._sess.run([self._filtered_data, self._bucket],
                               feed_dict={self._elem: elem_str})
    if bucket > input_extractor.BUCKET_UPPER_BOUND:
      return
    b = py_utils.NestedMap(b)

    # Flatten the batch.
    flatten = b.FlattenItems()
    if not flatten:
      return

    num_boxes = b.bboxes_3d.shape[0]

    # For each box, get the pointcloud and write it as an example.
    for bbox_id in range(num_boxes):
      tf_example = self._ToTFExampleProto(b, bbox_id)
      yield tf_example 
Example #2
Source File: ncf_main.py    From training_results_v0.5 with Apache License 2.0    6 votes
def run_graph(master, graph_spec, epoch):
  """Run graph_spec.graph with master."""
  tf.logging.info("Running graph for epoch {}...".format(epoch))
  with tf.Session(master, graph_spec.graph) as sess:
    tf.logging.info("Initializing system for epoch {}...".format(epoch))
    sess.run(tpu.initialize_system(
        embedding_config=graph_spec.embedding.config_proto))

    tf.logging.info("Running before hook for epoch {}...".format(epoch))
    graph_spec.hook_before(sess, epoch)

    tf.logging.info("Running infeed for epoch {}...".format(epoch))
    infeed_thread_fn = graph_spec.get_infeed_thread_fn(sess)
    infeed_thread = threading.Thread(target=infeed_thread_fn)
    tf.logging.info("Staring infeed thread...")
    infeed_thread.start()

    tf.logging.info("Running TPU loop for epoch {}...".format(epoch))
    graph_spec.run_tpu_loop(sess, epoch)

    tf.logging.info("Joining infeed thread...")
    infeed_thread.join()

    tf.logging.info("Running after hook for epoch {}...".format(epoch))
    graph_spec.hook_after(sess, epoch) 
Example #3
Source File: run_inference.py    From ffn with Apache License 2.0    6 votes
def main(unused_argv):
  request = inference_flags.request_from_flags()

  if not gfile.Exists(request.segmentation_output_dir):
    gfile.MakeDirs(request.segmentation_output_dir)

  bbox = bounding_box_pb2.BoundingBox()
  text_format.Parse(FLAGS.bounding_box, bbox)

  runner = inference.Runner()
  runner.start(request)
  runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
             (bbox.size.z, bbox.size.y, bbox.size.x))

  counter_path = os.path.join(request.segmentation_output_dir, 'counters.txt')
  if not gfile.Exists(counter_path):
    runner.counters.dump(counter_path) 
Example #4
Source File: train.py    From ffn with Apache License 2.0    6 votes
def run_training_step(sess, model, fetch_summary, feed_dict):
  """Runs one training step for a single FFN FOV."""
  ops_to_run = [model.train_op, model.global_step, model.logits]

  if fetch_summary is not None:
    ops_to_run.append(fetch_summary)

  results = sess.run(ops_to_run, feed_dict)
  step, prediction = results[1:3]

  if fetch_summary is not None:
    summ = results[-1]
  else:
    summ = None

  return prediction, step, summ 
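
A brief call-site sketch for the function above; the session, model, summary op, writer, and feed dict are assumptions, not part of the original source:

# Hedged usage sketch; `sess`, `model`, `summary_op`, `writer`, and `feed`
# are assumed to exist in the surrounding training loop.
prediction, step, summ = run_training_step(
    sess, model, fetch_summary=summary_op, feed_dict=feed)
if summ is not None:
  writer.add_summary(summ, step)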
Example #5
Source File: gng_impl.py    From loaner with Apache License 2.0    6 votes
def run(self):
    """Runs the Grab n Go manager."""
    try:
      while True:
        utils.clear_screen()
        utils.write('Which of the following actions would you like to take?\n')
        for opt in self._options.values():
          utils.write('Action: {!r}\nDescription: {}\n'.format(
              opt.name, opt.description))
        action = utils.prompt_enum(
            '', accepted_values=list(self._options.keys()),
            case_sensitive=False).strip().lower()
        callback = self._options[action].callback
        if callback is None:
          break
        # Each callback returns a (possibly new) manager instance; rebind
        # self so the loop continues with the updated state.
        self = callback()
    finally:
      utils.write(
          'Done managing Grab n Go for Cloud Project {!r}.'.format(
              self._config.project)) 
Example #6
Source File: inspect_dataset.py    From mixmatch with Apache License 2.0    6 votes
def main(argv):
    del argv
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    dataset = data.DATASETS[FLAGS.dataset]()
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
              for name, ds in groups]
    with tf.train.MonitoredSession() as sess:
        for group, train_data in groups:
            stats = np.zeros(dataset.nclass, np.int32)
            minmax = [], []
            for _ in trange(nbatch, leave=False, unit='img', unit_scale=FLAGS.batch, desc=group):
                v = sess.run(train_data)['label']
                for u in v:
                    stats[u] += 1
                minmax[0].append(v.min())
                minmax[1].append(v.max())
            print(group)
            print('  Label range', min(minmax[0]), max(minmax[1]))
            print('  Stats', ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())])) 
Example #7
Source File: gen_structure_test_case.py    From moonlight with Apache License 2.0    6 votes
def main(argv):
  pages = argv[1:]
  assert pages, 'Pass one or more PNG files'
  omr = engine.OMREngine()
  for i, filename in enumerate(pages):
    # Escape quotes/backslashes; \1 (not \0) refers to the captured character.
    escaped_filename = re.sub(r'([\'\\])', r'\\\1', filename)
    page = omr.run(filename).page[0]
    # TODO(ringw): Use a real templating system (e.g. jinja or mako).
    if i > 0:
      print('')
    print('  def test%s_structure(self):' % _sanitized_basename(filename))
    print('    page = engine.OMREngine().run(')
    print('        \'%s\').page[0]' % escaped_filename)
    print('    self.assertEqual(len(page.system), %d)' % len(page.system))
    for j, system in enumerate(page.system):
      print('')
      print('    self.assertEqual(len(page.system[%d].staff), %d)' %
            (j, len(system.staff)))
      print('    self.assertEqual(len(page.system[%d].bar), %d)' %
            (j, len(system.bar))) 
Example #8
Source File: tf_cnn_benchmarks.py    From benchmarks with Apache License 2.0    6 votes
def main(positional_arguments):
  # Command-line arguments like '--distortions False' are equivalent to
  # '--distortions=True False', where False is a positional argument. To prevent
  # this from silently running with distortions, we do not allow positional
  # arguments.
  assert len(positional_arguments) >= 1
  if len(positional_arguments) > 1:
    raise ValueError('Received unknown positional arguments: %s'
                     % positional_arguments[1:])

  params = benchmark_cnn.make_params_from_flags()
  with mlperf.mlperf_logger(absl_flags.FLAGS.ml_perf_compliance_logging,
                            params.model):
    params = benchmark_cnn.setup(params)
    bench = benchmark_cnn.BenchmarkCNN(params)

    tfversion = cnn_util.tensorflow_version_tuple()
    log_fn('TensorFlow:  %i.%i' % (tfversion[0], tfversion[1]))

    bench.print_info()
    bench.run() 
Example #9
Source File: reinitialize_argfile.py    From lottery-ticket-hypothesis with Apache License 2.0    6 votes
def main(argv):
  del argv  # Unused.
  line_format = ('--masks={masks} --output_dir={output_dir}')
  name = FLAGS.experiment

  for trial in range(1, 21):
    for level in range(0, 31):
      for run in range(1, 11):
        masks = paths.masks(constants.run(trial, level))
        output = constants.run(trial, level, name, run)
        result = line_format.format(masks=masks, output_dir=output)

        if FLAGS.experiment in ('reuse', 'reuse_sign'):
          result += (' --initialization_distribution=' +
                     constants.initialization(level))

        if FLAGS.experiment == 'reuse_sign':
          presets = paths.initial(constants.run(trial, level))
          result += ' --same_sign={}'.format(presets)

        print(result) 
Example #10
Source File: config.py    From trax with Apache License 2.0    6 votes
def config_with_absl(self):
    # Run this before calling `app.run(main)` etc.
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags
    absl_defs = { bool: absl_flags.DEFINE_bool,
                  int:  absl_flags.DEFINE_integer,
                  str:  absl_flags.DEFINE_string,
                  'enum': absl_flags.DEFINE_enum }

    for name, val in self.values.items():
      flag_type, meta_args, meta_kwargs = self.meta[name]
      absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

    app.call_after_init(lambda: self.complete_absl_config(absl_flags)) 
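
As the comment above notes, this must run before app.run() parses the command line, since absl flags can only be defined prior to parsing. A hypothetical call site (the import path and config object name are assumptions):

from absl import app
from trax import config as config_lib  # assumed import path

config_lib.config.config_with_absl()  # define flags before parsing


def main(argv):
  del argv  # Unused.


if __name__ == '__main__':
  app.run(main)  # flags are parsed here; call_after_init callbacks then fire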
Example #11
Source File: imagenet_to_gcs.py    From tpu_models with Apache License 2.0    6 votes
def __init__(self):
    # Create a single Session to run all image coding calls.
    self._sess = tf.Session()

    # Initializes function that converts PNG to JPEG data.
    self._png_data = tf.placeholder(dtype=tf.string)
    image = tf.image.decode_png(self._png_data, channels=3)
    self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

    # Initializes function that converts CMYK JPEG data to RGB JPEG data.
    self._cmyk_data = tf.placeholder(dtype=tf.string)
    image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
    self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100)

    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) 
Example #12
Source File: cyclic_bag_log_reg.py    From federated with Apache License 2.0    6 votes
def create_model(self):
    """Creates a TF model and returns ops necessary to run training/eval."""
    features = tf.compat.v1.placeholder(tf.float32, [None, self.input_dim])
    labels = tf.compat.v1.placeholder(tf.float32, [None, self.num_classes])

    w = tf.Variable(tf.random.normal(shape=[self.input_dim, self.num_classes]))
    b = tf.Variable(tf.random.normal(shape=[self.num_classes]))

    pred = tf.nn.softmax(tf.matmul(features, w) + b)

    loss = tf.reduce_mean(-tf.reduce_sum(labels * tf.math.log(pred), axis=1))
    train_op = self.optimizer.minimize(
        loss=loss, global_step=tf.train.get_or_create_global_step())

    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(labels, 1))
    eval_metric_op = tf.count_nonzero(correct_pred)

    return features, labels, train_op, loss, eval_metric_op 
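
A hedged sketch of consuming the returned ops in a feed-dict training loop; the model instance and the random batch below are assumptions:

import numpy as np
import tensorflow as tf

# `model` is assumed to be an instance of the class defining create_model().
features, labels, train_op, loss, eval_metric_op = model.create_model()

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  # Illustrative batch of 32 random examples with one-hot labels.
  x = np.random.rand(32, model.input_dim).astype(np.float32)
  y = np.eye(model.num_classes, dtype=np.float32)[
      np.random.randint(model.num_classes, size=32)]
  _, batch_loss = sess.run([train_op, loss],
                           feed_dict={features: x, labels: y})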
Example #13
Source File: cyclic_bag_log_reg.py    From federated with Apache License 2.0    6 votes
def log_config(logger):
  """Logs the configuration of this run, so it can be used in the analysis phase."""
  logger.log('== Configuration ==')
  logger.log('task_id=%d' % FLAGS.task_id)
  logger.log('lr=%f' % FLAGS.lr)
  logger.log('vocab_size=%s' % FLAGS.vocab_size)
  logger.log('batch_size=%s' % FLAGS.batch_size)
  logger.log('bow_limit=%s' % FLAGS.bow_limit)
  logger.log('training_data=%s' % FLAGS.training_data)
  logger.log('test_data=%s' % FLAGS.test_data)
  logger.log('num_groups=%d' % FLAGS.num_groups)
  logger.log('num_days=%d' % FLAGS.num_days)
  logger.log('num_train_examples_per_day=%d' % FLAGS.num_train_examples_per_day)
  logger.log('mode=%s' % FLAGS.mode)
  logger.log('bias=%f' % FLAGS.bias)
  logger.log('replica=%d' % FLAGS.replica) 
Example #14
Source File: imagenet_to_gcs.py    From training_results_v0.5 with Apache License 2.0    6 votes
def __init__(self):
    # Create a single Session to run all image coding calls.
    self._sess = tf.Session()

    # Initializes function that converts PNG to JPEG data.
    self._png_data = tf.placeholder(dtype=tf.string)
    image = tf.image.decode_png(self._png_data, channels=3)
    self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

    # Initializes function that converts CMYK JPEG data to RGB JPEG data.
    self._cmyk_data = tf.placeholder(dtype=tf.string)
    image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
    self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100)

    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) 
Example #15
Source File: inspect_dataset.py    From realmix with Apache License 2.0    6 votes
def main(argv):
    del argv
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    dataset = data.DATASETS[FLAGS.dataset]()
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
              for name, ds in groups]
    with tf.train.MonitoredSession() as sess:
        for group, train_data in groups:
            stats = np.zeros(dataset.nclass, np.int32)
            minmax = [], []
            for _ in trange(nbatch, leave=False, unit='img', unit_scale=FLAGS.batch, desc=group):
                v = sess.run(train_data)['label']
                for u in v:
                    stats[u] += 1
                minmax[0].append(v.min())
                minmax[1].append(v.max())
            print(group)
            print('  Label range', min(minmax[0]), max(minmax[1]))
            print('  Stats', ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())])) 
Example #16
Source File: play.py    From pysc2 with Apache License 2.0    5 votes
def entry_point():  # Needed so setup.py scripts work.
  app.run(main) 
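
The indirection exists because setuptools console_scripts entry points call a zero-argument function, which then hands control to absl. A setup.py stanza might look like this; the command name is illustrative:

# setup.py (illustrative)
from setuptools import setup

setup(
    name='pysc2',
    entry_points={
        'console_scripts': [
            # The installed command calls entry_point(), which in turn
            # invokes app.run(main) to parse flags and dispatch.
            'pysc2_play = pysc2.bin.play:entry_point',
        ],
    },
)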
Example #17
Source File: nsfw_scratch.py    From nsfw with Apache License 2.0    5 votes
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
                           examples_per_epoch=None, dtype=tf.float32):
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    dataset = dataset.repeat(num_epochs)

    if is_training and num_gpus and examples_per_epoch:
        total_examples = num_epochs * examples_per_epoch
        # Force the number of batches to be divisible by the number of devices.
        # This prevents some devices from receiving batches while others do not,
        # which can lead to a lockup. This case will soon be handled directly by
        # distribution strategies, at which point this .take() operation will no
        # longer be needed.
        total_batches = total_examples // batch_size // num_gpus * num_gpus
        dataset = dataset.take(total_batches * batch_size)

    # Parse the raw records into images and labels. Testing has shown that setting
    # num_parallel_batches > 1 produces no improvement in throughput, since
    # batch_size is almost always much greater than the number of CPU cores.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            lambda value: parse_record_fn(value, is_training),
            batch_size=batch_size,
            num_parallel_batches=1,
            drop_remainder=False))

    # Operations between the final prefetch and the get_next call to the iterator
    # will happen synchronously during run time. We prefetch here again to
    # background all of the above processing work and keep it out of the
    # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
    # allows DistributionStrategies to adjust how many batches to fetch based
    # on how many devices are present.
    dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    return dataset 
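
A minimal sketch of wiring this into an input function; the filenames and the parse_record helper are assumptions:

def input_fn(is_training, filenames, batch_size):
  # Hypothetical call site; parse_record is assumed to map a serialized
  # tf.Example to an (image, label) pair.
  dataset = tf.data.TFRecordDataset(filenames)
  return process_record_dataset(
      dataset, is_training, batch_size, shuffle_buffer=10000,
      parse_record_fn=parse_record, num_epochs=1)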
Example #18
Source File: nsfw_main_finetune.py    From nsfw with Apache License 2.0    5 votes
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
                           examples_per_epoch=None, dtype=tf.float32):
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    dataset = dataset.repeat(num_epochs)

    if is_training and num_gpus and examples_per_epoch:
        total_examples = num_epochs * examples_per_epoch
        # Force the number of batches to be divisible by the number of devices.
        # This prevents some devices from receiving batches while others do not,
        # which can lead to a lockup. This case will soon be handled directly by
        # distribution strategies, at which point this .take() operation will no
        # longer be needed.
        total_batches = total_examples // batch_size // num_gpus * num_gpus
        dataset = dataset.take(total_batches * batch_size)

    # Parse the raw records into images and labels. Testing has shown that setting
    # num_parallel_batches > 1 produces no improvement in throughput, since
    # batch_size is almost always much greater than the number of CPU cores.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            lambda value: parse_record_fn(value, is_training, dtype),
            batch_size=batch_size,
            num_parallel_batches=1,
            drop_remainder=False))

    # Operations between the final prefetch and the get_next call to the iterator
    # will happen synchronously during run time. We prefetch here again to
    # background all of the above processing work and keep it out of the
    # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
    # allows DistributionStrategies to adjust how many batches to fetch based
    # on how many devices are present.
    dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    return dataset 
Example #19
Source File: export_checkpoints.py    From albert with Apache License 2.0    5 votes
def build_model(sess):
  """Module function."""
  input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
  input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
  segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
  mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")

  albert_config_path = os.path.join(
      FLAGS.albert_directory, "albert_config.json")
  albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=False)

  get_mlm_logits(model.get_sequence_output(), albert_config,
                 mlm_positions, model.get_embedding_table())
  get_sentence_order_logits(model.get_pooled_output(), albert_config)

  checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
  tvars = tf.trainable_variables()
  (assignment_map, initialized_variable_names
  ) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
  init = tf.global_variables_initializer()
  sess.run(init)
  return sess 
Example #20
Source File: beam_reshuffle.py    From exoplanet-ml with Apache License 2.0    5 votes
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  def pipeline(root):
    """Beam pipeline for preprocessing open images."""
    assert FLAGS.input_file_patterns
    assert FLAGS.output_dir
    assert FLAGS.output_name
    assert FLAGS.num_shards

    # Create Pipeline.
    tfrecords = []
    for i, file_pattern in enumerate(FLAGS.input_file_patterns.split(",")):
      logging.info("Reading TFRecords from %s", file_pattern)
      stage_name = "read_tfrecords_{}".format(i)
      tfrecords.append(root | stage_name >> beam.io.tfrecordio.ReadFromTFRecord(
          file_pattern, coder=beam.coders.ProtoCoder(tf.train.Example)))

    # pylint: disable=expression-not-assigned
    (tfrecords
     | "flatten" >> beam.Flatten()
     | "count_labels" >> beam.ParDo(CountLabelsDoFn())
     | "reshuffle" >> beam.Reshuffle()
     | "write_tfrecord" >> beam.io.tfrecordio.WriteToTFRecord(
         os.path.join(FLAGS.output_dir, FLAGS.output_name),
         coder=beam.coders.ProtoCoder(tf.train.Example),
         num_shards=FLAGS.num_shards))
    # pylint: enable=expression-not-assigned

  # `pipeline` is a function of the pipeline root, so run it through a
  # beam.Pipeline rather than calling .run() on the function itself.
  with beam.Pipeline() as root:
    pipeline(root)
  logging.info("Processing complete.") 
Example #21
Source File: imagenet_to_gcs.py    From tpu_models with Apache License 2.0    5 votes
def decode_jpeg(self, image_data):
    image = self._sess.run(self._decode_jpeg,
                           feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image 
Example #22
Source File: imagenet_to_gcs.py    From training_results_v0.5 with Apache License 2.0    5 votes
def decode_jpeg(self, image_data):
    image = self._sess.run(self._decode_jpeg,
                           feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image 
Example #23
Source File: imagenet_to_gcs.py    From training_results_v0.5 with Apache License 2.0    5 votes
def cmyk_to_rgb(self, image_data):
    return self._sess.run(self._cmyk_to_rgb,
                          feed_dict={self._cmyk_data: image_data}) 
Example #24
Source File: imagenet_to_gcs.py    From training_results_v0.5 with Apache License 2.0    5 votes
def png_to_jpeg(self, image_data):
    return self._sess.run(self._png_to_jpeg,
                          feed_dict={self._png_data: image_data}) 
Example #25
Source File: tf_hub.py    From training_results_v0.5 with Apache License 2.0    5 votes
def _check_shapes_of_restored_variables(session, variables_to_restore):
  """Raises TypeError if restored variables have unexpected shapes."""
  num_errors = 0
  for variable_name, variable in variables_to_restore.items():
    graph_shape = variable.value().shape
    # Fetching the value is expensive, but tf.shape(...) would just echo
    # graph_shape if it is fully defined, so read the actual value's shape.
    checkpoint_shape = session.run(variable.value()).shape
    if not graph_shape.is_compatible_with(checkpoint_shape):
      tf.logging.error('Shape mismatch for variable %s: '
                       'graph expects %s but checkpoint has %s' %
                       (variable_name, graph_shape, checkpoint_shape))
      num_errors += 1
  if num_errors:
    raise TypeError(
        'Shape mismatch for %d variables, see error log for list.' % num_errors) 
Example #26
Source File: onsets_frames_transcription_realtime.py    From magenta with Apache License 2.0    5 votes
def console_entry_point():
  app.run(main) 
Example #27
Source File: deploy_impl.py    From loaner with Apache License 2.0    5 votes
def DeployWebApp(self):
    """Bundle then deploy (or run locally) the web application."""
    self._BundleWebApp()

    if self.on_local:
      print('Run locally...')
    else:
      cmds = [
          'gcloud', 'app', 'deploy', '--no-promote', '--project={}'.format(
              self.project_id), '--version={}'.format(self.version)]
      for yaml_filename in self._yaml_files:
        cmds.append(self._GetYamlFile(yaml_filename))
      logging.info(
          'Deploying to the Google Cloud project: %s using gcloud...',
          self.project_id)
      _ExecuteCommand(cmds)

    if self.on_google_cloud_shell:
      self._CleanWebAppBackend() 
Example #28
Source File: run.py    From training_results_v0.5 with Apache License 2.0    5 votes
def main(_):
  run_lib.run() 
Example #29
Source File: ncf_main.py    From training_results_v0.5 with Apache License 2.0    5 votes
def build_hooks(mode, embedding, params, train_record_dir):
  """Build `hook_before` and `hook_after` for `graph_spec`."""
  saver = tf.train.Saver()
  if mode == tpu_embedding.TRAINING:
    def hook_before(sess, epoch):
      if epoch == 0:
        sess.run(tf.global_variables_initializer())
      else:
        saver.restore(sess,
                      "{}/model.ckpt.{}".format(
                          params["model_dir"], epoch-1))
      sess.run(embedding.init_ops)

    def hook_after(sess, epoch):
      sess.run(embedding.retrieve_parameters_ops)
      ckpt_path = saver.save(sess,
                             "{}/model.ckpt.{}".format(
                                 params["model_dir"], epoch))
      tf.logging.info("Model saved in path: {}."
                      .format(ckpt_path))
      # must delete; otherwise the first epoch's data will always be used.
      tf.gfile.DeleteRecursively(train_record_dir)
  else:
    def hook_before(sess, epoch):
      saver.restore(sess,
                    "{}/model.ckpt.{}".format(
                        params["model_dir"], epoch))
      sess.run(embedding.init_ops)

    def hook_after(sess, epoch):
      del sess, epoch

  return hook_before, hook_after 
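
For context, these are the hook_before/hook_after callbacks that run_graph in Example #2 invokes around the TPU loop. A hypothetical wiring, assuming graph_spec exposes them as attributes:

hook_before, hook_after = build_hooks(
    tpu_embedding.TRAINING, embedding, params, train_record_dir)
# graph_spec construction is an assumption; run_graph expects it to carry
# these callbacks plus the graph and infeed helpers.
graph_spec.hook_before = hook_before
graph_spec.hook_after = hook_after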
Example #30
Source File: run.py    From tpu_models with Apache License 2.0    5 votes
def main(_):
  run_lib.run()