Python absl.flags.DEFINE_integer() Examples

The following are 30 code examples of absl.flags.DEFINE_integer(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module absl.flags, or try the search function.
Example #1
Source File: flags.py    From benchmarks with Apache License 2.0 7 votes vote down vote up
def define_flags(specs=None):
  """Register one absl command line flag for each ParamSpec.

  Args:
    specs: Mapping from flag name to ParamSpec. When omitted (or otherwise
      falsy), the module-level `param_specs` mapping is used.

  Raises:
    ValueError: If a spec's `flag_type` has no corresponding absl definer.
  """
  if not specs:
    specs = param_specs
  # Maps each supported ParamSpec.flag_type to the absl function defining it.
  definers_by_type = {
      'boolean': absl_flags.DEFINE_boolean,
      'float': absl_flags.DEFINE_float,
      'integer': absl_flags.DEFINE_integer,
      'string': absl_flags.DEFINE_string,
      'enum': absl_flags.DEFINE_enum,
      'list': absl_flags.DEFINE_list
  }
  for flag_name, spec in six.iteritems(specs):
    flag_type = spec.flag_type
    if flag_type not in definers_by_type:
      raise ValueError('Unknown flag_type %s' % flag_type)
    definer = definers_by_type[flag_type]
    definer(flag_name, spec.default_value,
            help=spec.description,
            **spec.kwargs)
Example #2
Source File: lamb_flags.py    From lamb with Apache License 2.0 6 votes vote down vote up
def _define_flags(options):
  """Define an absl flag for every option returned by `_filter_options`.

  Each option is a sequence whose first three items are the flag name, a type
  tag and the default value. All flags are defined with an empty help string.

  Args:
    options: Option descriptions, passed through `_filter_options` first.

  Raises:
    ValueError: If an option carries a type tag that has no flag mapping.
  """
  for option in _filter_options(options):
    name, type_, default_ = option[:3]
    if type_ == 'boolean':
      flags.DEFINE_boolean(name, default_, '')
    elif type_ == 'integer':
      flags.DEFINE_integer(name, default_, '')
    elif type_ == 'float':
      flags.DEFINE_float(name, default_, '')
    elif type_ == 'string':
      flags.DEFINE_string(name, default_, '')
    elif type_ == 'list_of_ints':
      # Lists of ints are carried on the command line as a plain string flag.
      flags.DEFINE_string(name, default_, '')
    else:
      # The previous `assert '<message>'` tested a non-empty string, which is
      # always truthy, so unknown types were silently skipped.
      raise ValueError('Unexpected option type %s' % type_)


# Define command-line flags for all options (unless `internal`). 
Example #3
Source File: app.py    From clgen with GNU General Public License v3.0 6 votes vote down vote up
def DEFINE_integer(
  name: str,
  default: Optional[int],
  help: str,
  required: bool = False,
  lower_bound: Optional[int] = None,
  upper_bound: Optional[int] = None,
  validator: Callable[[int], bool] = None,
):
  """Define an integer flag, optionally required and/or validated.

  Args:
    name: Name of the flag.
    default: Default value of the flag; may be None.
    help: Help text of the flag.
    required: If true, the flag is marked as required after definition.
    lower_bound: Forwarded to absl as the flag's lower bound.
    upper_bound: Forwarded to absl as the flag's upper bound.
    validator: If truthy, registered for the flag via RegisterFlagValidator.
  """
  bounds = {"lower_bound": lower_bound, "upper_bound": upper_bound}
  # The module name is resolved from this frame, so the lookup stays inline.
  absl_flags.DEFINE_integer(
    name, default, help, module_name=get_calling_module_name(), **bounds
  )
  if required:
    absl_flags.mark_flag_as_required(name)
  if validator:
    RegisterFlagValidator(name, validator)
Example #4
Source File: argparse_flags_test.py    From abseil-py with Apache License 2.0 6 votes vote down vote up
def setUp(self):
    """Create a private FlagValues and define one flag of each basic type."""
    fv = flags.FlagValues()
    self._absl_flags = fv

    flags.DEFINE_bool(
        'absl_bool', None, 'help for --absl_bool.',
        short_name='b', flag_values=fv)
    # A boolean flag whose name starts with "no", to verify that the "no"
    # prefixes of boolean flags are handled correctly.
    flags.DEFINE_bool(
        'notice', None, 'help for --notice.', flag_values=fv)
    flags.DEFINE_string(
        'absl_string', 'default', 'help for --absl_string=%.',
        short_name='s', flag_values=fv)
    flags.DEFINE_integer(
        'absl_integer', 1, 'help for --absl_integer.', flag_values=fv)
    # NOTE(review): the help text below names --absl_integer although the flag
    # is absl_float; kept as-is because other tests may match on it - confirm.
    flags.DEFINE_float(
        'absl_float', 1, 'help for --absl_integer.', flag_values=fv)
    flags.DEFINE_enum(
        'absl_enum', 'apple', ['apple', 'orange'], 'help for --absl_enum.',
        flag_values=fv)
Example #5
Source File: module_bar.py    From abseil-py with Apache License 2.0 6 votes vote down vote up
def define_flags(flag_values=FLAGS):
  """Define the sample flags of this module.

  Args:
    flag_values: The FlagValues object the flags are registered with.
  """
  # The 'tmod_bar_' prefix (short for 'test_module_bar') keeps these names
  # from clashing with the existing flags.
  flag_table = (
      (flags.DEFINE_boolean, 'tmod_bar_x', True, 'Boolean flag.'),
      (flags.DEFINE_string, 'tmod_bar_y', 'default', 'String flag.'),
      (flags.DEFINE_boolean, 'tmod_bar_z', False,
       'Another boolean flag from module bar.'),
      (flags.DEFINE_integer, 'tmod_bar_t', 4, 'Sample int flag.'),
      (flags.DEFINE_integer, 'tmod_bar_u', 5, 'Sample int flag.'),
      (flags.DEFINE_integer, 'tmod_bar_v', 6, 'Sample int flag.'),
  )
  for definer, flag_name, default, description in flag_table:
    definer(flag_name, default, description, flag_values=flag_values)
Example #6
Source File: config.py    From trax with Apache License 2.0 6 votes vote down vote up
def config_with_absl(self):
    """Define an absl flag for every entry in `self.values`.

    Run this before calling `app.run(main)` etc.

    For each name in `self.values`, `self.meta[name]` supplies a
    `(flag_type, meta_args, meta_kwargs)` triple; `flag_type` selects the absl
    definer and the remaining items are forwarded to it after the name and
    the current value. `self.complete_absl_config` is registered to be called
    with the absl flags module through `app.call_after_init`.
    """
    # A single import is enough; the former `import absl.flags as absl_FLAGS`
    # alias was never used.
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags
    # Keys are the flag_type values stored in self.meta.
    absl_defs = {
        bool: absl_flags.DEFINE_bool,
        int: absl_flags.DEFINE_integer,
        str: absl_flags.DEFINE_string,
        'enum': absl_flags.DEFINE_enum,
    }

    for name, val in self.values.items():
        flag_type, meta_args, meta_kwargs = self.meta[name]
        absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

    app.call_after_init(lambda: self.complete_absl_config(absl_flags))
Example #7
Source File: flag_util_test.py    From PerfKitBenchmarker with Apache License 2.0 6 votes vote down vote up
def testReadAndWrite(self):
    """Checks flag state before, inside and after an OverrideFlags block."""
    flag_values = flags.FlagValues()
    flags.DEFINE_integer(
        'test_flag', 0, 'Test flag.', flag_values=flag_values)
    flag_values([sys.argv[0]])
    flag_values_overrides = {'test_flag': 1}

    # Before the override: default value, and the overrides dict is intact.
    self.assertFlagState(flag_values, 0, False)
    self.assertEqual(flag_values_overrides['test_flag'], 1)

    with flag_util.OverrideFlags(flag_values, flag_values_overrides):
      self.assertFlagState(flag_values, 1, True)
      self.assertEqual(flag_values_overrides['test_flag'], 1)

    # After leaving the context the flag is back at its default.
    self.assertFlagState(flag_values, 0, False)
    self.assertEqual(flag_values_overrides['test_flag'], 1)

    # A direct assignment changes the flag but not the overrides dict.
    flag_values.test_flag = 3
    self.assertFlagState(flag_values, 3, False)
    self.assertEqual(flag_values_overrides['test_flag'], 1)
Example #8
Source File: flags_helpxml_test.py    From abseil-py with Apache License 2.0 6 votes vote down vote up
def test_flag_help_in_xml_int_with_bounds(self):
    """Checks the XML help of an integer flag that has both bounds set."""
    flags.DEFINE_integer(
        'nb_iters', 17, 'An integer flag',
        lower_bound=5, upper_bound=27, flag_values=self.fv)
    xml_lines = [
        '<flag>',
        '  <key>yes</key>',
        '  <file>module.name</file>',
        '  <name>nb_iters</name>',
        '  <meaning>An integer flag</meaning>',
        '  <default>17</default>',
        '  <current>17</current>',
        '  <type>int</type>',
        '  <lower_bound>5</lower_bound>',
        '  <upper_bound>27</upper_bound>',
        '</flag>',
    ]
    # Every line, including the last one, is terminated by a newline.
    expected_output = ''.join(line + '\n' for line in xml_lines)
    self._check_flag_help_in_xml(
        'nb_iters', 'module.name', expected_output, is_key=True)
Example #9
Source File: flags.py    From dlcookbook-dlbs with Apache License 2.0 6 votes vote down vote up
def define_flags():
  """Register one absl command line flag per ParamSpec in `param_specs`.

  Raises:
    ValueError: If a spec's `flag_type` has no corresponding absl definer.
  """
  # Maps each supported ParamSpec.flag_type to the absl function defining it.
  definers_by_type = {
      'boolean': absl_flags.DEFINE_boolean,
      'float': absl_flags.DEFINE_float,
      'integer': absl_flags.DEFINE_integer,
      'string': absl_flags.DEFINE_string,
      'enum': absl_flags.DEFINE_enum,
      'list': absl_flags.DEFINE_list
  }
  for flag_name, spec in six.iteritems(param_specs):
    flag_type = spec.flag_type
    if flag_type not in definers_by_type:
      raise ValueError('Unknown flag_type %s' % flag_type)
    definer = definers_by_type[flag_type]
    definer(flag_name, spec.default_value,
            help=spec.description,
            **spec.kwargs)
Example #10
Source File: mnist_eager.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def define_mnist_eager_flags():
  """Register the flags and default values for MNIST in eager mode."""
  flags_core.define_base_eager()
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  wrap = flags_core.help_wrap

  flags.DEFINE_integer(
      name='log_interval',
      short_name='li',
      default=10,
      help=wrap('batches between logging training status'),
  )
  flags.DEFINE_string(
      name='output_dir',
      short_name='od',
      default=None,
      help=wrap('Directory to write TensorBoard summaries'),
  )
  flags.DEFINE_float(
      name='learning_rate',
      short_name='lr',
      default=0.01,
      help=wrap('Learning rate.'),
  )
  flags.DEFINE_float(
      name='momentum',
      short_name='m',
      default=0.5,
      help=wrap('SGD momentum.'),
  )
  flags.DEFINE_bool(
      name='no_gpu',
      short_name='nogpu',
      default=False,
      help=wrap('disables GPU usage even if a GPU is available'),
  )

  # Override the defaults of flags that were defined by flags_core above.
  defaults = dict(
      data_dir='/tmp/tensorflow/mnist/input_data',
      model_dir='/tmp/tensorflow/mnist/checkpoints/',
      batch_size=100,
      train_epochs=10,
  )
  flags_core.set_defaults(**defaults)
Example #11
Source File: train_higgs.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def define_train_higgs_flags():
  """Register tree-related flags plus the training/eval configuration."""
  flags_core.define_base(clean=False, stop_threshold=False, batch_size=False,
                         num_gpu=False)
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)

  # (name, default, help text) of every integer flag, in definition order.
  integer_flags = (
      ("train_start", 0, "Start index of train examples within the data."),
      ("train_count", 1000000, "Number of train examples within the data."),
      ("eval_start", 10000000,
       "Start index of eval examples within the data."),
      ("eval_count", 1000000, "Number of eval examples within the data."),
      ("n_trees", 100, "Number of trees to build."),
      ("max_depth", 6, "Maximum depths of each tree."),
  )
  for flag_name, default_value, description in integer_flags:
    flags.DEFINE_integer(
        name=flag_name, default=default_value, help=help_wrap(description))

  flags.DEFINE_float(
      "learning_rate", default=0.1, help=help_wrap("The learning rate."))

  flags_core.set_defaults(data_dir="/tmp/higgs_data",
                          model_dir="/tmp/higgs_model")
Example #12
Source File: _device.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu", default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            # Trailing space added: the adjacent literals used to render as
            # "use theCPU".
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    # "tpu" is the only flag of this function reported as a key flag.
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone", default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project", default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(name="num_tpu_shards", default=8,
                         help=help_wrap("Number of shards (TPU chips)."))

  return key_flags
Example #13
Source File: train_higgs.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def define_train_higgs_flags():
  """Register tree-related flags plus the training/eval configuration."""
  flags_core.define_base(clean=False, stop_threshold=False, batch_size=False,
                         num_gpu=False)
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)

  # (name, default, help text) of every integer flag, in definition order.
  integer_flags = (
      ("train_start", 0, "Start index of train examples within the data."),
      ("train_count", 1000000, "Number of train examples within the data."),
      ("eval_start", 10000000,
       "Start index of eval examples within the data."),
      ("eval_count", 1000000, "Number of eval examples within the data."),
      ("n_trees", 100, "Number of trees to build."),
      ("max_depth", 6, "Maximum depths of each tree."),
  )
  for flag_name, default_value, description in integer_flags:
    flags.DEFINE_integer(
        name=flag_name, default=default_value, help=help_wrap(description))

  flags.DEFINE_float(
      "learning_rate", default=0.1, help=help_wrap("The learning rate."))

  flags_core.set_defaults(data_dir="/tmp/higgs_data",
                          model_dir="/tmp/higgs_model")
Example #14
Source File: flag_util_test.py    From PerfKitBenchmarker with Apache License 2.0 5 votes vote down vote up
def testFlagChangesAreNotReflectedInConfigDict(self):
    """Checks that setting a flag inside OverrideFlags leaves the dict alone."""
    flag_values = flags.FlagValues()
    flags.DEFINE_integer(
        'test_flag', 0, 'Test flag.', flag_values=flag_values)
    flag_values([sys.argv[0]])
    flag_values_overrides = {'test_flag': 1}

    # Before the override: default value, and the overrides dict is intact.
    self.assertFlagState(flag_values, 0, False)
    self.assertEqual(flag_values_overrides['test_flag'], 1)

    with flag_util.OverrideFlags(flag_values, flag_values_overrides):
      self.assertFlagState(flag_values, 1, True)
      # Changing the flag inside the context must not touch the dict.
      flag_values.test_flag = 2
      self.assertFlagState(flag_values, 2, True)
      self.assertEqual(flag_values_overrides['test_flag'], 1)
Example #15
Source File: common.py    From models with Apache License 2.0 5 votes vote down vote up
def define_pruning_flags():
  """Define flags for pruning methods."""
  flags.DEFINE_string('pruning_method', None,
                      # Trailing space added: the adjacent literals used to
                      # render as "Pruning method.None (no pruning) ...".
                      'Pruning method. '
                      'None (no pruning) or polynomial_decay.')
  flags.DEFINE_float('pruning_initial_sparsity', 0.0,
                     'Initial sparsity for pruning.')
  flags.DEFINE_float('pruning_final_sparsity', 0.5,
                     'Final sparsity for pruning.')
  flags.DEFINE_integer('pruning_begin_step', 0,
                       'Begin step for pruning.')
  flags.DEFINE_integer('pruning_end_step', 100000,
                       'End step for pruning.')
  flags.DEFINE_integer('pruning_frequency', 100,
                       'Frequency for pruning.')
Example #16
Source File: classifier_trainer.py    From models with Apache License 2.0 5 votes vote down vote up
def define_classifier_flags():
  """Define the flags shared by the image classification entry points."""
  hyperparams_flags.initialize_common_flags()

  # (definer, name, default, help text), in definition order.
  flag_table = (
      (flags.DEFINE_string, 'data_dir', None,
       'The location of the input data.'),
      (flags.DEFINE_string, 'mode', None,
       'Mode to run: `train`, `eval`, `train_and_eval` or `export`.'),
      (flags.DEFINE_bool, 'run_eagerly', None,
       'Use eager execution and disable autograph for debugging.'),
      (flags.DEFINE_string, 'model_type', None,
       'The type of the model, e.g. EfficientNet, etc.'),
      (flags.DEFINE_string, 'dataset', None,
       'The name of the dataset, e.g. ImageNet, etc.'),
      (flags.DEFINE_integer, 'log_steps', 100,
       'The interval of steps between logging of batch level stats.'),
  )
  for definer, flag_name, default_value, description in flag_table:
    definer(flag_name, default=default_value, help=description)
Example #17
Source File: keras_common.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def define_keras_flags():
  """Define the eager/eval switches and the step-count flags."""
  train_steps_help = (
      'The number of steps to run for training. If it is larger than '
      '# batches per epoch, then use # batches per epoch. When this flag is '
      'set, only one epoch is going to run for training.')
  log_steps_help = (
      'For every log_steps, we log the timing information such as '
      'examples per second. Besides, for every log_steps, we store the '
      'timestamp of a batch end.')

  flags.DEFINE_boolean(
      name='enable_eager', default=False, help='Enable eager?')
  flags.DEFINE_boolean(
      name='skip_eval', default=False, help='Skip evaluation?')
  flags.DEFINE_integer(
      name='train_steps', default=None, help=train_steps_help)
  flags.DEFINE_integer(
      name='log_steps', default=100, help=log_steps_help)
Example #18
Source File: _device.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu", default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            # Trailing space added: the adjacent literals used to render as
            # "use theCPU".
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    # "tpu" is the only flag of this function reported as a key flag.
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone", default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project", default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(name="num_tpu_shards", default=8,
                         help=help_wrap("Number of shards (TPU chips)."))

  return key_flags
Example #19
Source File: mnist_eager.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def define_mnist_eager_flags():
  """Register the flags and default values for MNIST in eager mode."""
  flags_core.define_base_eager()
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  wrap = flags_core.help_wrap

  flags.DEFINE_integer(
      name='log_interval',
      short_name='li',
      default=10,
      help=wrap('batches between logging training status'),
  )
  flags.DEFINE_string(
      name='output_dir',
      short_name='od',
      default=None,
      help=wrap('Directory to write TensorBoard summaries'),
  )
  flags.DEFINE_float(
      name='learning_rate',
      short_name='lr',
      default=0.01,
      help=wrap('Learning rate.'),
  )
  flags.DEFINE_float(
      name='momentum',
      short_name='m',
      default=0.5,
      help=wrap('SGD momentum.'),
  )
  flags.DEFINE_bool(
      name='no_gpu',
      short_name='nogpu',
      default=False,
      help=wrap('disables GPU usage even if a GPU is available'),
  )

  # Override the defaults of flags that were defined by flags_core above.
  defaults = dict(
      data_dir='/tmp/tensorflow/mnist/input_data',
      model_dir='/tmp/tensorflow/mnist/checkpoints/',
      batch_size=100,
      train_epochs=10,
  )
  flags_core.set_defaults(**defaults)
Example #20
Source File: mnist_eager.py    From models with Apache License 2.0 5 votes vote down vote up
def define_mnist_eager_flags():
  """Register the flags and default values for MNIST in eager mode."""
  flags_core.define_base_eager()
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  wrap = flags_core.help_wrap

  flags.DEFINE_integer(
      name='log_interval',
      short_name='li',
      default=10,
      help=wrap('batches between logging training status'),
  )
  flags.DEFINE_string(
      name='output_dir',
      short_name='od',
      default=None,
      help=wrap('Directory to write TensorBoard summaries'),
  )
  flags.DEFINE_float(
      name='learning_rate',
      short_name='lr',
      default=0.01,
      help=wrap('Learning rate.'),
  )
  flags.DEFINE_float(
      name='momentum',
      short_name='m',
      default=0.5,
      help=wrap('SGD momentum.'),
  )
  flags.DEFINE_bool(
      name='no_gpu',
      short_name='nogpu',
      default=False,
      help=wrap('disables GPU usage even if a GPU is available'),
  )

  # Override the defaults of flags that were defined by flags_core above.
  defaults = dict(
      data_dir='/tmp/tensorflow/mnist/input_data',
      model_dir='/tmp/tensorflow/mnist/checkpoints/',
      batch_size=100,
      train_epochs=10,
  )
  flags_core.set_defaults(**defaults)
Example #21
Source File: _distribution.py    From models with Apache License 2.0 5 votes vote down vote up
def define_distribution(worker_hosts=True, task_index=True):
  """Register the flags used for distributed execution.

  Args:
    worker_hosts: Create a flag for specifying comma-separated list of workers.
    task_index: Create a flag for specifying index of task.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  if worker_hosts:
    hosts_help = help_wrap(
        'Comma-separated list of worker ip:port pairs for running '
        'multi-worker models with DistributionStrategy.  The user would '
        'start the program on each host with identical value for this '
        'flag.')
    flags.DEFINE_string(name='worker_hosts', default=None, help=hosts_help)

  if task_index:
    index_help = help_wrap('If multi-worker training, the task_index of this '
                           'worker.')
    flags.DEFINE_integer(name='task_index', default=-1, help=index_help)

  # Nothing in this function is appended to the list, so it is returned empty.
  key_flags = []
  return key_flags
Example #22
Source File: flags.py    From dlcookbook-dlbs with Apache License 2.0 5 votes vote down vote up
def DEFINE_integer(name, default, help, lower_bound=None, upper_bound=None):  # pylint: disable=invalid-name,redefined-builtin
  """Store an 'integer' ParamSpec for `name` in `param_specs`.

  The bounds are kept in the spec's kwargs dictionary.
  """
  bounds = dict(lower_bound=lower_bound, upper_bound=upper_bound)
  spec = ParamSpec('integer', default, help, bounds)
  param_specs[name] = spec
Example #23
Source File: train_higgs.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 5 votes vote down vote up
def define_train_higgs_flags():
  """Register tree-related flags plus the training/eval configuration."""
  flags_core.define_base(clean=False, stop_threshold=False, batch_size=False,
                         num_gpu=False)
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)

  # (name, default, help text) of every integer flag, in definition order.
  integer_flags = (
      ("train_start", 0, "Start index of train examples within the data."),
      ("train_count", 1000000, "Number of train examples within the data."),
      ("eval_start", 10000000,
       "Start index of eval examples within the data."),
      ("eval_count", 1000000, "Number of eval examples within the data."),
      ("n_trees", 100, "Number of trees to build."),
      ("max_depth", 6, "Maximum depths of each tree."),
  )
  for flag_name, default_value, description in integer_flags:
    flags.DEFINE_integer(
        name=flag_name, default=default_value, help=help_wrap(description))

  flags.DEFINE_float(
      "learning_rate", default=0.1, help=help_wrap("The learning rate."))

  flags_core.set_defaults(data_dir="/tmp/higgs_data",
                          model_dir="/tmp/higgs_model")
Example #24
Source File: mnist_eager.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 5 votes vote down vote up
def define_mnist_eager_flags():
  """Register the flags and default values for MNIST in eager mode."""
  flags_core.define_base_eager(clean=True, train_epochs=True)
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  wrap = flags_core.help_wrap

  flags.DEFINE_integer(
      name='log_interval',
      short_name='li',
      default=10,
      help=wrap('batches between logging training status'),
  )
  flags.DEFINE_string(
      name='output_dir',
      short_name='od',
      default=None,
      help=wrap('Directory to write TensorBoard summaries'),
  )
  flags.DEFINE_float(
      name='learning_rate',
      short_name='lr',
      default=0.01,
      help=wrap('Learning rate.'),
  )
  flags.DEFINE_float(
      name='momentum',
      short_name='m',
      default=0.5,
      help=wrap('SGD momentum.'),
  )
  flags.DEFINE_bool(
      name='no_gpu',
      short_name='nogpu',
      default=False,
      help=wrap('disables GPU usage even if a GPU is available'),
  )

  # Override the defaults of flags that were defined by flags_core above.
  defaults = dict(
      data_dir='/tmp/tensorflow/mnist/input_data',
      model_dir='/tmp/tensorflow/mnist/checkpoints/',
      batch_size=100,
      train_epochs=10,
  )
  flags_core.set_defaults(**defaults)
Example #25
Source File: _distribution.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 5 votes vote down vote up
def define_distribution(worker_hosts=True, task_index=True):
  """Register the flags used for distributed execution.

  Args:
    worker_hosts: Create a flag for specifying comma-separated list of workers.
    task_index: Create a flag for specifying index of task.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  if worker_hosts:
    hosts_help = help_wrap(
        'Comma-separated list of worker ip:port pairs for running '
        'multi-worker models with DistributionStrategy.  The user would '
        'start the program on each host with identical value for this '
        'flag.')
    flags.DEFINE_string(name='worker_hosts', default=None, help=hosts_help)

  if task_index:
    index_help = help_wrap('If multi-worker training, the task_index of this '
                           'worker.')
    flags.DEFINE_integer(name='task_index', default=-1, help=index_help)

  # Nothing in this function is appended to the list, so it is returned empty.
  key_flags = []
  return key_flags
Example #26
Source File: _device.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 5 votes vote down vote up
def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu", default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            # Trailing space added: the adjacent literals used to render as
            # "use theCPU".
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    # "tpu" is the only flag of this function reported as a key flag.
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone", default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project", default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(name="num_tpu_shards", default=8,
                         help=help_wrap("Number of shards (TPU chips)."))

  return key_flags
Example #27
Source File: shakespeare_main.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 5 votes vote down vote up
def define_flags():
  """Register the flags of the Shakespeare character LSTM."""
  flags_core.define_base(
      data_dir=False,
      clean=False,
      train_epochs=True,
      epochs_between_evals=False,
      stop_threshold=False,
      hooks=False,
      export_dir=False,
      run_eagerly=True)

  flags_core.define_performance(
      num_parallel_calls=False,
      inter_op=False,
      intra_op=False,
      synthetic_data=False,
      max_train_steps=False,
      dtype=False,
      enable_xla=True,
      force_v2_in_keras_compile=True)

  flags_core.set_defaults(train_epochs=43, batch_size=64)

  # (definer, name, default, help text), in definition order.
  flag_table = (
      (flags.DEFINE_boolean, 'enable_eager', True, 'Enable eager?'),
      (flags.DEFINE_boolean, 'train', True, 'If true trains the model.'),
      (flags.DEFINE_string, 'predict_context', None,
       'If set, makes a prediction with the given context.'),
      (flags.DEFINE_integer, 'predict_length', 1000,
       'Length of the predicted text including the context.'),
      (flags.DEFINE_integer, 'log_steps', 100,
       'For every log_steps, we log the timing information such as '
       'examples per second.'),
      (flags.DEFINE_string, 'training_data', None,
       'Path to file containing the training data.'),
      (flags.DEFINE_boolean, 'cudnn', True, 'Use CuDNN LSTM.'),
  )
  for definer, flag_name, default_value, description in flag_table:
    definer(name=flag_name, default=default_value, help=description)
Example #28
Source File: mnist_eager.py    From models with Apache License 2.0 5 votes vote down vote up
def define_mnist_eager_flags():
  """Register the flags and default values for MNIST in eager mode."""
  flags_core.define_base_eager()
  flags_core.define_image()
  flags.adopt_module_key_flags(flags_core)

  wrap = flags_core.help_wrap

  flags.DEFINE_integer(
      name='log_interval',
      short_name='li',
      default=10,
      help=wrap('batches between logging training status'),
  )
  flags.DEFINE_string(
      name='output_dir',
      short_name='od',
      default=None,
      help=wrap('Directory to write TensorBoard summaries'),
  )
  flags.DEFINE_float(
      name='learning_rate',
      short_name='lr',
      default=0.01,
      help=wrap('Learning rate.'),
  )
  flags.DEFINE_float(
      name='momentum',
      short_name='m',
      default=0.5,
      help=wrap('SGD momentum.'),
  )
  flags.DEFINE_bool(
      name='no_gpu',
      short_name='nogpu',
      default=False,
      help=wrap('disables GPU usage even if a GPU is available'),
  )

  # Override the defaults of flags that were defined by flags_core above.
  defaults = dict(
      data_dir='/tmp/tensorflow/mnist/input_data',
      model_dir='/tmp/tensorflow/mnist/checkpoints/',
      batch_size=100,
      train_epochs=10,
  )
  flags_core.set_defaults(**defaults)
Example #29
Source File: _device.py    From models with Apache License 2.0 5 votes vote down vote up
def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu", default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            # Trailing space added: the adjacent literals used to render as
            # "use theCPU".
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    # "tpu" is the only flag of this function reported as a key flag.
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone", default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project", default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(name="num_tpu_shards", default=8,
                         help=help_wrap("Number of shards (TPU chips)."))

  return key_flags
Example #30
Source File: _device.py    From models with Apache License 2.0 5 votes vote down vote up
def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu", default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            # Trailing space added: the adjacent literals used to render as
            # "use theCPU".
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    # "tpu" is the only flag of this function reported as a key flag.
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone", default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project", default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(name="num_tpu_shards", default=8,
                         help=help_wrap("Number of shards (TPU chips)."))

  return key_flags