Python absl.app.run() Examples
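The following are 30 code examples of `absl.app.run()`. Note that many of the examples below do not call `absl.app.run()` itself; they call other methods named `run` (for example `sess.run`).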
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `absl.app`, or try the search function.
Example #1
Source File: create_kitti_crop_dataset.py From lingvo with Apache License 2.0 | 7 votes |
def process(self, value):
    """Yields one example per 3D bounding box found in `value`.

    `value` is serialized and fed through the lazily created graph. Nothing is
    yielded when the fetched bucket exceeds
    `input_extractor.BUCKET_UPPER_BOUND`, or when the filtered batch flattens
    to no items.

    Args:
      value: A proto message exposing `SerializeToString()`.

    Yields:
      The result of `self._ToTFExampleProto(batch, box_index)` for every box
      index in `range(batch.bboxes_3d.shape[0])`.
    """
    self._create_graph()
    serialized = value.SerializeToString()
    fetches = [self._filtered_data, self._bucket]
    raw_batch, bucket = self._sess.run(fetches, feed_dict={self._elem: serialized})
    # Elements that fall above the largest bucket are dropped entirely.
    if bucket > input_extractor.BUCKET_UPPER_BOUND:
        return
    batch = py_utils.NestedMap(raw_batch)
    # An empty flattened batch means there is nothing to emit.
    if not batch.FlattenItems():
        return
    box_count = batch.bboxes_3d.shape[0]
    for box_index in range(box_count):
        yield self._ToTFExampleProto(batch, box_index)
Example #2
Source File: ncf_main.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def run_graph(master, graph_spec, epoch):
    """Run graph_spec.graph with master.

    Within one session on `master` this initializes the TPU system, runs
    `graph_spec.hook_before`, starts the infeed on a background thread, runs
    the TPU loop, joins the infeed thread and finally runs
    `graph_spec.hook_after`.

    Args:
      master: Target passed to `tf.Session`.
      graph_spec: Object exposing `graph`, `embedding.config_proto`,
        `hook_before`, `get_infeed_thread_fn`, `run_tpu_loop` and
        `hook_after`.
      epoch: Epoch index; used in log messages and passed to the hooks and
        the TPU loop.
    """
    tf.logging.info("Running graph for epoch {}...".format(epoch))
    with tf.Session(master, graph_spec.graph) as sess:
        tf.logging.info("Initializing system for epoch {}...".format(epoch))
        sess.run(tpu.initialize_system(
            embedding_config=graph_spec.embedding.config_proto))

        tf.logging.info("Running before hook for epoch {}...".format(epoch))
        graph_spec.hook_before(sess, epoch)

        tf.logging.info("Running infeed for epoch {}...".format(epoch))
        # The infeed runs on its own thread, concurrently with the TPU loop.
        infeed_thread_fn = graph_spec.get_infeed_thread_fn(sess)
        infeed_thread = threading.Thread(target=infeed_thread_fn)
        tf.logging.info("Starting infeed thread...")
        infeed_thread.start()

        tf.logging.info("Running TPU loop for epoch {}...".format(epoch))
        graph_spec.run_tpu_loop(sess, epoch)

        tf.logging.info("Joining infeed thread...")
        infeed_thread.join()

        tf.logging.info("Running after hook for epoch {}...".format(epoch))
        graph_spec.hook_after(sess, epoch)
Example #3
Source File: run_inference.py From ffn with Apache License 2.0 | 6 votes |
def main(unused_argv):
    """Runs FFN inference over the box given by the --bounding_box flag.

    Ensures the segmentation output directory exists, runs an inference
    `Runner` over the box (start and size passed in z, y, x order), and dumps
    the runner's counters to `counters.txt` in the output directory. An
    existing counters file is left untouched.
    """
    request = inference_flags.request_from_flags()
    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    runner = inference.Runner()
    runner.start(request)
    start_zyx = (bbox.start.z, bbox.start.y, bbox.start.x)
    size_zyx = (bbox.size.z, bbox.size.y, bbox.size.x)
    runner.run(start_zyx, size_zyx)

    counter_path = os.path.join(request.segmentation_output_dir, 'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
Example #4
Source File: train.py From ffn with Apache License 2.0 | 6 votes |
def run_training_step(sess, model, fetch_summary, feed_dict):
    """Runs one training step for a single FFN FOV.

    Args:
      sess: Session whose `run(fetches, feed_dict)` executes the step.
      model: Object providing `train_op`, `global_step` and `logits`.
      fetch_summary: Optional extra fetch (e.g. a summary op); None skips it.
      feed_dict: Feed passed straight to `sess.run`.

    Returns:
      A `(prediction, step, summ)` tuple: the fetched logits, the fetched
      global step, and the fetched summary (None when `fetch_summary` is None).
    """
    want_summary = fetch_summary is not None
    fetches = [model.train_op, model.global_step, model.logits]
    if want_summary:
        fetches.append(fetch_summary)
    outputs = sess.run(fetches, feed_dict)
    _, step, prediction = outputs[:3]
    summ = outputs[-1] if want_summary else None
    return prediction, step, summ
Example #5
Source File: gng_impl.py From loaner with Apache License 2.0 | 6 votes |
def run(self):
    """Runs the Grab n Go manager.

    Repeatedly lists the available actions, prompts for one, and invokes its
    callback. Whatever the callback returns becomes the manager used for the
    next iteration and for the closing message. The loop stops when the
    chosen action has no callback. The closing message naming the current
    manager's Cloud project is written in a `finally`, so it also appears if
    an exception escapes.
    """
    manager = self
    try:
        while True:
            utils.clear_screen()
            utils.write('Which of the following actions would you like to take?\n')
            for option in manager._options.values():
                utils.write('Action: {!r}\nDescription: {}\n'.format(
                    option.name, option.description))
            choice = utils.prompt_enum(
                '', accepted_values=list(manager._options.keys()),
                case_sensitive=False)
            callback = manager._options[choice.strip().lower()].callback
            if callback is None:
                break
            manager = callback()
    finally:
        utils.write(
            'Done managing Grab n Go for Cloud Project {!r}.'.format(
                manager._config.project))
Example #6
Source File: inspect_dataset.py From mixmatch with Apache License 2.0 | 6 votes |
def main(argv):
    """Prints label statistics for the labeled, unlabeled and test splits.

    For each split, `FLAGS.samples // FLAGS.batch` batches are read. The
    function prints the smallest and largest label seen, then each class's
    count as a percentage of the most frequent class's count.

    Raises:
      ValueError: if `FLAGS.samples < FLAGS.batch`, i.e. no batch would be
        read and the label range would be undefined.
    """
    del argv
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    if nbatch < 1:
        raise ValueError('--samples (%d) must be at least --batch (%d).'
                         % (FLAGS.samples, FLAGS.batch))
    dataset = data.DATASETS[FLAGS.dataset]()
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
              for name, ds in groups]

    with tf.train.MonitoredSession() as sess:
        for group, train_data in groups:
            stats = np.zeros(dataset.nclass, np.int32)
            label_mins, label_maxs = [], []
            for _ in trange(nbatch, leave=False, unit='img', unit_scale=FLAGS.batch,
                            desc=group):
                v = sess.run(train_data)['label']
                # Unbuffered scatter-add: a label repeated within the batch is
                # counted every time, like `stats[u] += 1` for each element.
                # Assumes `v` is a 1-D integer label vector -- TODO confirm.
                np.add.at(stats, v, 1)
                label_mins.append(v.min())
                label_maxs.append(v.max())
            print(group)
            print(' Label range', min(label_mins), max(label_maxs))
            print(' Stats', ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())]))
Example #7
Source File: gen_structure_test_case.py From moonlight with Apache License 2.0 | 6 votes |
def _escape_single_quoted(text):
    """Backslash-escapes `'` and backslash so `text` fits in a '...' literal."""
    # \g<0> re-inserts the matched character. A bare \0 in a re.sub template
    # is an octal escape (a NUL character), not a reference to the match.
    return re.sub(r"(['\\])", r"\\\g<0>", text)


def main(argv):
    """Prints unittest methods asserting the structure of each given page.

    For every path in argv[1:], runs the OMR engine and prints a
    `test<name>_structure` method. The method checks the number of systems
    on the page and the number of staves and bars in each system.

    Raises:
      ValueError: if no page paths were given.
    """
    pages = argv[1:]
    if not pages:
        raise ValueError('Pass one or more PNG files')
    omr = engine.OMREngine()
    for page_index, filename in enumerate(pages):
        escaped_filename = _escape_single_quoted(filename)
        page = omr.run(filename).page[0]
        # TODO(ringw): Use a real templating system (e.g. jinja or mako).
        if page_index > 0:
            print('')
        print(' def test%s_structure(self):' % _sanitized_basename(filename))
        print(' page = engine.OMREngine().run(')
        print(' \'%s\').page[0]' % escaped_filename)
        print(' self.assertEqual(len(page.system), %d)' % len(page.system))
        # Separate index name so the outer page counter is not shadowed.
        for system_index, system in enumerate(page.system):
            print('')
            print(' self.assertEqual(len(page.system[%d].staff), %d)' %
                  (system_index, len(system.staff)))
            print(' self.assertEqual(len(page.system[%d].bar), %d)' %
                  (system_index, len(system.bar)))
Example #8
Source File: tf_cnn_benchmarks.py From benchmarks with Apache License 2.0 | 6 votes |
def main(positional_arguments):
    """Runs the CNN benchmark configured entirely through flags.

    Raises:
      ValueError: if any positional argument other than the program name is
        present.
    """
    # Command-line arguments like '--distortions False' are equivalent to
    # '--distortions=True False', where False is a positional argument. To
    # prevent this from silently running with distortions, positional
    # arguments are rejected.
    assert len(positional_arguments) >= 1
    extra_arguments = positional_arguments[1:]
    if extra_arguments:
        raise ValueError('Received unknown positional arguments: %s' %
                         extra_arguments)

    params = benchmark_cnn.make_params_from_flags()
    compliance_logging = absl_flags.FLAGS.ml_perf_compliance_logging
    with mlperf.mlperf_logger(compliance_logging, params.model):
        params = benchmark_cnn.setup(params)
        bench = benchmark_cnn.BenchmarkCNN(params)

    version = cnn_util.tensorflow_version_tuple()
    log_fn('TensorFlow: %i.%i' % (version[0], version[1]))

    bench.print_info()
    bench.run()
Example #9
Source File: reinitialize_argfile.py From lottery-ticket-hypothesis with Apache License 2.0 | 6 votes |
def main(argv):
    """Prints one argfile line per (trial, level, run) combination.

    Trials 1-20, levels 0-30 and runs 1-10 are enumerated. Every line carries
    --masks and --output_dir. Experiments 'reuse' and 'reuse_sign' also get
    --initialization_distribution, and 'reuse_sign' additionally gets
    --same_sign.
    """
    del argv  # Unused.
    name = FLAGS.experiment
    line_format = '--masks={masks} --output_dir={output_dir}'
    for trial in range(1, 21):
        for level in range(0, 31):
            for run_index in range(1, 11):
                masks = paths.masks(constants.run(trial, level))
                output = constants.run(trial, level, name, run_index)
                line = line_format.format(masks=masks, output_dir=output)
                if name in ('reuse', 'reuse_sign'):
                    line += (' --initialization_distribution=' +
                             constants.initialization(level))
                if name == 'reuse_sign':
                    presets = paths.initial(constants.run(trial, level))
                    line += ' --same_sign={}'.format(presets)
                print(line)
Example #10
Source File: config.py From trax with Apache License 2.0 | 6 votes |
def config_with_absl(self):
    """Registers every config value as an absl flag.

    Run this before calling `app.run(main)` etc. Each entry of `self.values`
    is defined as an absl flag. The flag kind is taken from the first element
    of `self.meta[name]` (bool, int, float, str or 'enum'), and the remaining
    positional/keyword meta is forwarded to the DEFINE_* call.
    `self.complete_absl_config` is scheduled through `app.call_after_init`.
    """
    from absl import app
    from absl import flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags
    absl_defs = {
        bool: absl_flags.DEFINE_bool,
        int: absl_flags.DEFINE_integer,
        float: absl_flags.DEFINE_float,
        str: absl_flags.DEFINE_string,
        'enum': absl_flags.DEFINE_enum,
    }
    for name, val in self.values.items():
        flag_type, meta_args, meta_kwargs = self.meta[name]
        absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
    # Deferred until absl has finished initializing (runs immediately if it
    # already has).
    app.call_after_init(lambda: self.complete_absl_config(absl_flags))
Example #11
Source File: imagenet_to_gcs.py From tpu_models with Apache License 2.0 | 6 votes |
def __init__(self):
    """Builds the TF ops used for image format conversion.

    All conversions share one `tf.Session`; each op is fed through its own
    string placeholder.
    """
    # Create a single Session to run all image coding calls.
    self._sess = tf.Session()

    # PNG bytes -> 3-channel RGB JPEG bytes (quality 100).
    self._png_data = tf.placeholder(dtype=tf.string)
    png_image = tf.image.decode_png(self._png_data, channels=3)
    self._png_to_jpeg = tf.image.encode_jpeg(png_image, format='rgb', quality=100)

    # CMYK JPEG bytes -> RGB JPEG bytes. channels=0 decodes with the channel
    # count stored in the JPEG itself.
    self._cmyk_data = tf.placeholder(dtype=tf.string)
    cmyk_image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
    self._cmyk_to_rgb = tf.image.encode_jpeg(cmyk_image, format='rgb', quality=100)

    # RGB JPEG bytes -> decoded 3-channel image.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
Example #12
Source File: cyclic_bag_log_reg.py From federated with Apache License 2.0 | 6 votes |
def create_model(self):
    """Creates a TF model and returns ops necessary to run training/eval.

    The model is a softmax regression from `self.input_dim` features to
    `self.num_classes` classes.

    Returns:
      A tuple `(features, labels, train_op, loss, eval_metric_op)`:
        features: float32 placeholder of shape [None, input_dim].
        labels: float32 placeholder of shape [None, num_classes]; presumably
          one-hot rows -- TODO confirm against the input pipeline.
        train_op: one `self.optimizer` step that also bumps the global step.
        loss: scalar mean cross-entropy over the batch.
        eval_metric_op: count of rows whose predicted argmax equals the label
          argmax.
    """
    features = tf.compat.v1.placeholder(tf.float32, [None, self.input_dim])
    labels = tf.compat.v1.placeholder(tf.float32, [None, self.num_classes])
    w = tf.Variable(tf.random.normal(shape=[self.input_dim, self.num_classes]))
    b = tf.Variable(tf.random.normal(shape=[self.num_classes]))
    logits = tf.matmul(features, w) + b
    # Computing cross-entropy from logits is numerically stable; the explicit
    # log(softmax(x)) form hits log(0) = -inf (NaN loss/gradients) whenever a
    # class probability underflows.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    train_op = self.optimizer.minimize(
        loss=loss, global_step=tf.compat.v1.train.get_or_create_global_step())
    # softmax is monotonic, so argmax(logits) == argmax(softmax(logits)).
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    eval_metric_op = tf.math.count_nonzero(correct_pred)
    return features, labels, train_op, loss, eval_metric_op
Example #13
Source File: cyclic_bag_log_reg.py From federated with Apache License 2.0 | 6 votes |
def log_config(logger):
    """Logs the configuration of this run, so it can be used in the analysis phase.

    Writes a header line followed by one `name=value` line per flag, each
    value formatted with the printf conversion listed below.
    """
    # (flag name, printf conversion) in logging order.
    settings = (
        ('task_id', '%d'),
        ('lr', '%f'),
        ('vocab_size', '%s'),
        ('batch_size', '%s'),
        ('bow_limit', '%s'),
        ('training_data', '%s'),
        ('test_data', '%s'),
        ('num_groups', '%d'),
        ('num_days', '%d'),
        ('num_train_examples_per_day', '%d'),
        ('mode', '%s'),
        ('bias', '%f'),
        ('replica', '%d'),
    )
    logger.log('== Configuration ==')
    for flag_name, conversion in settings:
        logger.log((flag_name + '=' + conversion) % getattr(FLAGS, flag_name))
Example #14
Source File: imagenet_to_gcs.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def __init__(self):
    """Builds the TF ops used for image format conversion.

    All conversions share one `tf.Session`; each op is fed through its own
    string placeholder.
    """
    # Create a single Session to run all image coding calls.
    self._sess = tf.Session()

    # PNG bytes -> 3-channel RGB JPEG bytes (quality 100).
    self._png_data = tf.placeholder(dtype=tf.string)
    png_image = tf.image.decode_png(self._png_data, channels=3)
    self._png_to_jpeg = tf.image.encode_jpeg(png_image, format='rgb', quality=100)

    # CMYK JPEG bytes -> RGB JPEG bytes. channels=0 decodes with the channel
    # count stored in the JPEG itself.
    self._cmyk_data = tf.placeholder(dtype=tf.string)
    cmyk_image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
    self._cmyk_to_rgb = tf.image.encode_jpeg(cmyk_image, format='rgb', quality=100)

    # RGB JPEG bytes -> decoded 3-channel image.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
Example #15
Source File: inspect_dataset.py From realmix with Apache License 2.0 | 6 votes |
def main(argv):
    """Prints label statistics for the labeled, unlabeled and test splits.

    For each split, `FLAGS.samples // FLAGS.batch` batches are read. The
    function prints the smallest and largest label seen, then each class's
    count as a percentage of the most frequent class's count.

    Raises:
      ValueError: if `FLAGS.samples < FLAGS.batch`, i.e. no batch would be
        read and the label range would be undefined.
    """
    del argv
    utils.setup_tf()
    nbatch = FLAGS.samples // FLAGS.batch
    if nbatch < 1:
        raise ValueError('--samples (%d) must be at least --batch (%d).'
                         % (FLAGS.samples, FLAGS.batch))
    dataset = data.DATASETS[FLAGS.dataset]()
    groups = [('labeled', dataset.train_labeled),
              ('unlabeled', dataset.train_unlabeled),
              ('test', dataset.test.repeat())]
    groups = [(name, ds.batch(FLAGS.batch).prefetch(16).make_one_shot_iterator().get_next())
              for name, ds in groups]

    with tf.train.MonitoredSession() as sess:
        for group, train_data in groups:
            stats = np.zeros(dataset.nclass, np.int32)
            label_mins, label_maxs = [], []
            for _ in trange(nbatch, leave=False, unit='img', unit_scale=FLAGS.batch,
                            desc=group):
                v = sess.run(train_data)['label']
                # Unbuffered scatter-add: a label repeated within the batch is
                # counted every time, like `stats[u] += 1` for each element.
                # Assumes `v` is a 1-D integer label vector -- TODO confirm.
                np.add.at(stats, v, 1)
                label_mins.append(v.min())
                label_maxs.append(v.max())
            print(group)
            print(' Label range', min(label_mins), max(label_maxs))
            print(' Stats', ' '.join(['%.2f' % (100 * x) for x in (stats / stats.max())]))
Example #16
Source File: play.py From pysc2 with Apache License 2.0 | 5 votes |
def entry_point():
    """Console entry point that hands `main` to `app.run`.

    Needed so setup.py scripts work.
    """
    app.run(main=main)
Example #17
Source File: nsfw_scratch.py From nsfw with Apache License 2.0 | 5 votes |
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
                           examples_per_epoch=None, dtype=tf.float32):
    """Turns a dataset of raw records into a parsed, batched input pipeline.

    Args:
      dataset: A `tf.data.Dataset` of raw records.
      is_training: Enables shuffling and the device-divisible truncation.
      batch_size: Number of examples per batch.
      shuffle_buffer: Buffer size passed to `Dataset.shuffle`.
      parse_record_fn: Called as `parse_record_fn(record, is_training)`.
      num_epochs: Number of times the dataset is repeated.
      num_gpus: Number of devices; together with `examples_per_epoch` it is
        used to truncate training data so every device gets whole batches.
      examples_per_epoch: Number of examples in one epoch.
      dtype: Accepted for interface compatibility; unused in this version.

    Returns:
      The processed `tf.data.Dataset`.
    """
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.repeat(num_epochs)

    if is_training and num_gpus and examples_per_epoch:
        total_examples = num_epochs * examples_per_epoch
        # Force the number of batches to be divisible by the number of devices.
        # This prevents some devices from receiving batches while others do
        # not, which can lead to a lockup.
        total_batches = total_examples // batch_size // num_gpus * num_gpus
        # tf.data transformations return a new dataset; without reassignment
        # the truncation is silently discarded.
        dataset = dataset.take(total_batches * batch_size)

    # Parse the raw records into images and labels. num_parallel_batches > 1
    # was measured to give no throughput gain, since batch_size is almost
    # always much larger than the number of CPU cores.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            lambda value: parse_record_fn(value, is_training),
            batch_size=batch_size,
            num_parallel_batches=1,
            drop_remainder=False))

    # Prefetch again so the processing above runs in the background, off the
    # critical training path. AUTOTUNE lets DistributionStrategies adjust how
    # many batches to fetch based on the number of devices.
    dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return dataset
Example #18
Source File: nsfw_main_finetune.py From nsfw with Apache License 2.0 | 5 votes |
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
                           examples_per_epoch=None, dtype=tf.float32):
    """Turns a dataset of raw records into a parsed, batched input pipeline.

    Args:
      dataset: A `tf.data.Dataset` of raw records.
      is_training: Enables shuffling and the device-divisible truncation.
      batch_size: Number of examples per batch.
      shuffle_buffer: Buffer size passed to `Dataset.shuffle`.
      parse_record_fn: Called as `parse_record_fn(record, is_training, dtype)`.
      num_epochs: Number of times the dataset is repeated.
      num_gpus: Number of devices; together with `examples_per_epoch` it is
        used to truncate training data so every device gets whole batches.
      examples_per_epoch: Number of examples in one epoch.
      dtype: Forwarded to `parse_record_fn`.

    Returns:
      The processed `tf.data.Dataset`.
    """
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.repeat(num_epochs)

    if is_training and num_gpus and examples_per_epoch:
        total_examples = num_epochs * examples_per_epoch
        # Force the number of batches to be divisible by the number of devices.
        # This prevents some devices from receiving batches while others do
        # not, which can lead to a lockup.
        total_batches = total_examples // batch_size // num_gpus * num_gpus
        # tf.data transformations return a new dataset; without reassignment
        # the truncation is silently discarded.
        dataset = dataset.take(total_batches * batch_size)

    # Parse the raw records into images and labels. num_parallel_batches > 1
    # was measured to give no throughput gain, since batch_size is almost
    # always much larger than the number of CPU cores.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            lambda value: parse_record_fn(value, is_training, dtype),
            batch_size=batch_size,
            num_parallel_batches=1,
            drop_remainder=False))

    # Prefetch again so the processing above runs in the background, off the
    # critical training path. AUTOTUNE lets DistributionStrategies adjust how
    # many batches to fetch based on the number of devices.
    dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return dataset
Example #19
Source File: export_checkpoints.py From albert with Apache License 2.0 | 5 votes |
def build_model(sess):
    """Module function.

    Builds the ALBERT graph and restores its weights into `sess`. The steps are:
      1. Creates int32 [None, None] placeholders named "input_ids",
         "input_mask", "segment_ids" and "mlm_positions".
      2. Builds an inference-mode `modeling.AlbertModel` from the
         albert_config.json file in FLAGS.albert_directory.
      3. Attaches the MLM and sentence-order heads.
      4. Maps trainable variables onto the FLAGS.checkpoint_name checkpoint.
      5. Runs the global variable initializer.

    Returns:
      The same `sess` that was passed in.
    """
    placeholder_names = ("input_ids", "input_mask", "segment_ids", "mlm_positions")
    input_ids, input_mask, segment_ids, mlm_positions = [
        tf.placeholder(tf.int32, [None, None], placeholder_name)
        for placeholder_name in placeholder_names]

    albert_config = modeling.AlbertConfig.from_json_file(
        os.path.join(FLAGS.albert_directory, "albert_config.json"))
    model = modeling.AlbertModel(
        config=albert_config,
        is_training=False,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=False)
    get_mlm_logits(model.get_sequence_output(), albert_config, mlm_positions,
                   model.get_embedding_table())
    get_sentence_order_logits(model.get_pooled_output(), albert_config)

    checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
    tvars = tf.trainable_variables()
    assignment_map, initialized_variable_names = (
        modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path))

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = (", *INIT_FROM_CKPT*"
                       if var.name in initialized_variable_names else "")
        tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                        init_string)

    tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
    sess.run(tf.global_variables_initializer())
    return sess
Example #20
Source File: beam_reshuffle.py From exoplanet-ml with Apache License 2.0 | 5 votes |
def main(argv):
    """Reads TFRecords, counts labels, reshuffles, and writes sharded TFRecords.

    Raises:
      app.UsageError: on extra positional arguments or a missing required flag.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    # Validate up front so bad flags fail before any pipeline is built.
    for flag_name in ("input_file_patterns", "output_dir", "output_name",
                      "num_shards"):
        if not getattr(FLAGS, flag_name):
            raise app.UsageError("--{} is required.".format(flag_name))

    def pipeline(root):
        """Adds the read -> count -> reshuffle -> write transforms to `root`."""
        tfrecords = []
        for i, file_pattern in enumerate(FLAGS.input_file_patterns.split(",")):
            logging.info("Reading TFRecords from %s", file_pattern)
            stage_name = "read_tfrecords_{}".format(i)
            tfrecords.append(root | stage_name >> beam.io.tfrecordio.ReadFromTFRecord(
                file_pattern, coder=beam.coders.ProtoCoder(tf.train.Example)))

        # pylint: disable=expression-not-assigned
        (tfrecords
         | "flatten" >> beam.Flatten()
         | "count_labels" >> beam.ParDo(CountLabelsDoFn())
         | "reshuffle" >> beam.Reshuffle()
         | "write_tfrecord" >> beam.io.tfrecordio.WriteToTFRecord(
             os.path.join(FLAGS.output_dir, FLAGS.output_name),
             coder=beam.coders.ProtoCoder(tf.train.Example),
             num_shards=FLAGS.num_shards))
        # pylint: enable=expression-not-assigned

    # `pipeline` above is a plain function with no `.run()`. Build a Beam
    # pipeline, apply it, and let the context manager run it to completion on
    # exit. NOTE(review): default options select Beam's default runner; pass
    # PipelineOptions here if another runner is intended.
    with beam.Pipeline() as root:
        pipeline(root)

    logging.info("Processing complete.")
Example #21
Source File: imagenet_to_gcs.py From tpu_models with Apache License 2.0 | 5 votes |
def decode_jpeg(self, image_data):
    """Decodes JPEG bytes with `self._decode_jpeg` and sanity-checks the result.

    Args:
      image_data: Encoded JPEG bytes, fed into `self._decode_jpeg_data`.

    Returns:
      The decoded image returned by the session.

    Raises:
      AssertionError: (when assertions are enabled) if the image is not rank 3
        with a last dimension of 3.
    """
    feed = {self._decode_jpeg_data: image_data}
    image = self._sess.run(self._decode_jpeg, feed_dict=feed)
    shape = image.shape
    assert len(shape) == 3
    assert shape[2] == 3
    return image
Example #22
Source File: imagenet_to_gcs.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def decode_jpeg(self, image_data):
    """Decodes JPEG bytes with `self._decode_jpeg` and sanity-checks the result.

    Args:
      image_data: Encoded JPEG bytes, fed into `self._decode_jpeg_data`.

    Returns:
      The decoded image returned by the session.

    Raises:
      AssertionError: (when assertions are enabled) if the image is not rank 3
        with a last dimension of 3.
    """
    feed = {self._decode_jpeg_data: image_data}
    image = self._sess.run(self._decode_jpeg, feed_dict=feed)
    shape = image.shape
    assert len(shape) == 3
    assert shape[2] == 3
    return image
Example #23
Source File: imagenet_to_gcs.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def cmyk_to_rgb(self, image_data):
    """Converts CMYK JPEG bytes to RGB JPEG bytes by running `self._cmyk_to_rgb`."""
    feed = {self._cmyk_data: image_data}
    return self._sess.run(self._cmyk_to_rgb, feed_dict=feed)
Example #24
Source File: imagenet_to_gcs.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def png_to_jpeg(self, image_data):
    """Converts PNG bytes to RGB JPEG bytes by running `self._png_to_jpeg`."""
    feed = {self._png_data: image_data}
    return self._sess.run(self._png_to_jpeg, feed_dict=feed)
Example #25
Source File: tf_hub.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def _check_shapes_of_restored_variables(session, variables_to_restore):
    """Raises TypeError if restored variables have unexpected shapes.

    Every mismatch is logged before raising, so all offending variables are
    reported in one pass.

    Args:
      session: Session used to evaluate each variable's value.
      variables_to_restore: Mapping from variable name to variable.

    Raises:
      TypeError: if any evaluated shape is incompatible with the variable's
        graph shape.
    """
    mismatches = 0
    for variable_name, variable in variables_to_restore.items():
        expected = variable.value().shape
        # Evaluating the value is costly for big variables, but a tf.shape(..)
        # op would just echo the graph shape when that is fully defined.
        actual = session.run(variable.value()).shape
        if expected.is_compatible_with(actual):
            continue
        tf.logging.error('Shape mismatch for variable %s: '
                         'graph expects %s but checkpoint has %s' %
                         (variable_name, expected, actual))
        mismatches += 1
    if mismatches:
        raise TypeError(
            'Shape mismatch for %d variables, see error log for list.' % mismatches)
Example #26
Source File: onsets_frames_transcription_realtime.py From magenta with Apache License 2.0 | 5 votes |
def console_entry_point():
    """Console entry point that hands `main` to `app.run`."""
    app.run(main=main)
Example #27
Source File: deploy_impl.py From loaner with Apache License 2.0 | 5 votes |
def DeployWebApp(self):
    """Bundle then deploy (or run locally) the web application.

    The app is always bundled first. When `self.on_local` is set, only a
    message is printed. Otherwise `gcloud app deploy --no-promote` is executed
    for the configured project, version and YAML files. On Google Cloud Shell
    the web app backend is cleaned afterwards.
    """
    self._BundleWebApp()
    if self.on_local:
        print('Run locally...')
    else:
        deploy_command = [
            'gcloud', 'app', 'deploy', '--no-promote',
            '--project={}'.format(self.project_id),
            '--version={}'.format(self.version),
        ]
        deploy_command.extend(
            self._GetYamlFile(yaml_filename) for yaml_filename in self._yaml_files)
        logging.info(
            'Deploying to the Google Cloud project: %s using gcloud...',
            self.project_id)
        _ExecuteCommand(deploy_command)
    # NOTE(review): the source was whitespace-collapsed; this check is placed
    # after both branches -- confirm it should not live inside the deploy
    # branch only.
    if self.on_google_cloud_shell:
        self._CleanWebAppBackend()
Example #28
Source File: run.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def main(_):
    """Program entry point; ignores argv and delegates to `run_lib.run()`."""
    del _  # Unused.
    run_lib.run()
Example #29
Source File: ncf_main.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def build_hooks(mode, embedding, params, train_record_dir):
    """Build `hook_before` and `hook_after` for `graph_spec`.

    Both hooks are called as `hook(sess, epoch)`. Checkpoints live at
    "{params['model_dir']}/model.ckpt.{epoch}".

    In training mode (`tpu_embedding.TRAINING`):
      * hook_before initializes all variables on epoch 0, otherwise restores
        the previous epoch's checkpoint; then runs `embedding.init_ops`.
      * hook_after runs `embedding.retrieve_parameters_ops`, saves this
        epoch's checkpoint, and deletes `train_record_dir`.
    In any other mode:
      * hook_before restores this epoch's checkpoint, then runs
        `embedding.init_ops`.
      * hook_after does nothing.

    Returns:
      The tuple `(hook_before, hook_after)`.
    """
    saver = tf.train.Saver()

    def checkpoint_for(epoch):
        return "{}/model.ckpt.{}".format(params["model_dir"], epoch)

    if mode == tpu_embedding.TRAINING:
        def hook_before(sess, epoch):
            if epoch == 0:
                sess.run(tf.global_variables_initializer())
            else:
                saver.restore(sess, checkpoint_for(epoch - 1))
            sess.run(embedding.init_ops)

        def hook_after(sess, epoch):
            sess.run(embedding.retrieve_parameters_ops)
            ckpt_path = saver.save(sess, checkpoint_for(epoch))
            tf.logging.info("Model saved in path: {}.".format(ckpt_path))
            # must delete; otherwise the first epoch's data will always be used.
            tf.gfile.DeleteRecursively(train_record_dir)
    else:
        def hook_before(sess, epoch):
            saver.restore(sess, checkpoint_for(epoch))
            sess.run(embedding.init_ops)

        def hook_after(sess, epoch):
            del sess, epoch  # Nothing to do after an evaluation epoch.

    return hook_before, hook_after
Example #30
Source File: run.py From tpu_models with Apache License 2.0 | 5 votes |
def main(_):
    """Program entry point; ignores argv and delegates to `run_lib.run()`."""
    del _  # Unused.
    run_lib.run()