Python absl.flags.FLAGS Examples

The following are 30 code examples of `absl.flags.FLAGS`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `absl.flags`, or try the search function.
Example #1
Source File: ars_server.py    From soccer-matlab with BSD 2-Clause "Simplified" License 7 votes vote down vote up
def main(unused_argv):
  """Starts one ARS evaluation gRPC server and blocks until interrupted.

  The listening port is FLAGS.port when FLAGS.run_on_borg is set, otherwise
  20000 + FLAGS.server_id. On KeyboardInterrupt every started server is
  stopped with no grace period.
  """
  credentials = loas2.loas2_server_credentials()
  if FLAGS.run_on_borg:
    port = FLAGS.port
  else:
    port = 20000 + FLAGS.server_id

  executor = futures.ThreadPoolExecutor(max_workers=10)
  # NOTE(review): `ports=` is not part of the public grpc.server() signature;
  # presumably an internal gRPC build — confirm.
  grpc_server = grpc.server(executor, ports=(port,))
  evaluation_servicer = ars_evaluation_service.ParameterEvaluationServicer(
      FLAGS.config_name, worker_id=FLAGS.server_id)
  ars_evaluation_service_pb2_grpc.add_EvaluationServicer_to_server(
      evaluation_servicer, grpc_server)
  grpc_server.add_secure_port("[::]:{}".format(port), credentials)
  running_servers = [grpc_server]
  grpc_server.start()
  print("Start server {}".format(FLAGS.server_id))

  # Keep the main thread alive; requests are handled by the executor threads.
  try:
    while True:
      time.sleep(_ONE_DAY_IN_SECONDS)
  except KeyboardInterrupt:
    for running in running_servers:
      running.stop(0)
Example #2
Source File: oracle_agent.py    From streetlearn with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Runs the oracle agent on a CourierGame configured from FLAGS.

  Builds a StreetLearn config, applies the library defaults, opens a pygame
  window of size (FLAGS.width, 2 * FLAGS.height) and hands control to `loop`.

  Args:
    argv: unused command-line arguments.
  """
  width = FLAGS.width
  height = FLAGS.height
  depth = FLAGS.graph_depth
  start_pano = FLAGS.start_pano
  config = dict(
      width=width,
      height=height,
      field_of_view=FLAGS.field_of_view,
      graph_width=width,
      graph_height=height,
      graph_zoom=FLAGS.graph_zoom,
      goal_timeout=FLAGS.frame_cap,
      frame_cap=FLAGS.frame_cap,
      # The full graph is used only when no start pano is given.
      full_graph=(start_pano == ''),
      start_pano=start_pano,
      min_graph_depth=depth,
      max_graph_depth=depth,
      proportion_of_panos_with_coins=FLAGS.proportion_of_panos_with_coins,
      action_spec='streetlearn_fast_rotate',
      observations=['view_image', 'graph_image', 'yaw', 'pitch'])
  config = default_config.ApplyDefaults(config)

  game = courier_game.CourierGame(config)
  env = streetlearn.StreetLearn(FLAGS.dataset_path, config, game)
  env.reset()

  pygame.init()
  screen = pygame.display.set_mode((width, height * 2))
  loop(env, screen)
Example #3
Source File: preview.py    From cwavegan with MIT License 6 votes vote down vote up
def noise_input_fn(params):
    """Input function for generating samples for PREDICT mode.

    Generates a single Tensor of fixed random noise. Use tf.data.Dataset to
    signal to the estimator when to terminate the generator returned by
    predict().

    Args:
      params: param `dict` passed by TPUEstimator; must contain 'batch_size'.

    Returns:
      A `(features, labels)` tuple: a 1-element `dict` holding the randomly
      generated noise under 'random_noise', and `None` for labels.
    """
    # A private generator seeded with 0 yields exactly the same values as
    # np.random.seed(0) followed by np.random.randn, but does not reset
    # NumPy's global RNG for the rest of the process.
    rng = np.random.RandomState(0)
    noise_values = rng.randn(params['batch_size'], FLAGS.noise_dim)
    noise_dataset = tf.data.Dataset.from_tensors(
        tf.constant(noise_values, dtype=tf.float32))
    noise = noise_dataset.make_one_shot_iterator().get_next()
    return {'random_noise': noise}, None
Example #4
Source File: optimizer.py    From ffn with Apache License 2.0 6 votes vote down vote up
def optimizer_from_flags():
  """Builds the tf.train optimizer named by FLAGS.optimizer.

  Every optimizer uses FLAGS.learning_rate; momentum, adam and rmsprop also
  read their specific hyper-parameter flags.

  Returns:
    A tf.train optimizer instance.

  Raises:
    ValueError: if FLAGS.optimizer is not one of 'momentum', 'sgd', 'adagrad',
      'adam' or 'rmsprop'.
  """
  lr = FLAGS.learning_rate
  # Lambdas keep flag reads lazy: only the selected optimizer's flags are read.
  factories = {
      'momentum': lambda: tf.train.MomentumOptimizer(lr, FLAGS.momentum),
      'sgd': lambda: tf.train.GradientDescentOptimizer(lr),
      'adagrad': lambda: tf.train.AdagradOptimizer(lr),
      'adam': lambda: tf.train.AdamOptimizer(learning_rate=lr,
                                             beta1=FLAGS.adam_beta1,
                                             beta2=FLAGS.adam_beta2,
                                             epsilon=FLAGS.epsilon),
      'rmsprop': lambda: tf.train.RMSPropOptimizer(lr, FLAGS.rmsprop_decay,
                                                   momentum=FLAGS.momentum,
                                                   epsilon=FLAGS.epsilon),
  }
  name = FLAGS.optimizer
  factory = factories.get(name)
  if factory is None:
    raise ValueError('Unknown optimizer: %s' % name)
  return factory()
Example #5
Source File: generate.py    From mathematics_dataset with Apache License 2.0 6 votes vote down vote up
def init_modules(train_split=False):
  """Fills the module-level `filtered_modules` and `counts` dicts.

  Returns immediately if `filtered_modules` is already populated.

  Args:
    train_split: if True, build three training regimes ('train-easy',
      'train-medium', 'train-hard') using `_make_entropy_fn(i, 3)`; otherwise a
      single 'train' regime using `_make_entropy_fn(0, 1)`.
  """
  if filtered_modules:
    return  # Already initialized.

  difficulties = ('easy', 'medium', 'hard')
  all_modules = collections.OrderedDict()
  if train_split:
    for index, difficulty in enumerate(difficulties):
      all_modules['train-' + difficulty] = modules.train(
          _make_entropy_fn(index, 3))
  else:
    all_modules['train'] = modules.train(_make_entropy_fn(0, 1))
  all_modules['interpolate'] = modules.test()
  all_modules['extrapolate'] = modules.test_extra()

  # Counts are recorded for every regime name, whichever split is in use; each
  # split training regime receives a third of the per-module budget.
  per_train = FLAGS.per_train_module
  per_test = FLAGS.per_test_module
  counts['train'] = per_train
  for difficulty in difficulties:
    counts['train-' + difficulty] = per_train // 3
  counts['interpolate'] = per_test
  counts['extrapolate'] = per_test

  for regime, regime_modules in six.iteritems(all_modules):
    filtered_modules[regime] = _filter_and_flatten(regime_modules)
Example #6
Source File: run_inference.py    From ffn with Apache License 2.0 6 votes vote down vote up
def main(unused_argv):
  """Runs FFN inference inside the box given by FLAGS.bounding_box.

  The request comes from `inference_flags.request_from_flags()`. Counters are
  written to `counters.txt` in the segmentation output directory unless that
  file already exists.
  """
  request = inference_flags.request_from_flags()
  output_dir = request.segmentation_output_dir
  if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)

  # FLAGS.bounding_box is a BoundingBox proto in text format.
  bbox = bounding_box_pb2.BoundingBox()
  text_format.Parse(FLAGS.bounding_box, bbox)

  runner = inference.Runner()
  runner.start(request)
  # Coordinates are passed in z, y, x order.
  start_zyx = (bbox.start.z, bbox.start.y, bbox.start.x)
  size_zyx = (bbox.size.z, bbox.size.y, bbox.size.x)
  runner.run(start_zyx, size_zyx)

  counter_path = os.path.join(request.segmentation_output_dir, 'counters.txt')
  if not gfile.Exists(counter_path):
    runner.counters.dump(counter_path)
Example #7
Source File: model_compile.py    From turkish-morphology with Apache License 2.0 6 votes vote down vote up
def main(unused_argv):
  """Compiles lexicon and morphotactics rewrite rules into FST text files.

  Rules are read from FLAGS.lexicon_dir and FLAGS.morphotactics_dir, merged,
  de-duplicated, and written as `complex_symbols.syms` and `morphotactics.txt`
  under FLAGS.output_dir.
  """
  # The rule retrieval helpers may raise IOError or MorphotacticsCompilerError.
  lexicon = _get_lexicon_rules(FLAGS.lexicon_dir)
  morphotactics = _get_morphotactics_rules(FLAGS.morphotactics_dir)

  merged = _RewriteRuleSet()
  for rule_set in (lexicon, morphotactics):
    merged.rule.extend(rule_set.rule)
  _remove_duplicate_rules(merged)

  outputs = (
      ("complex_symbols.syms", _symbols_table_file_content(merged)),
      ("morphotactics.txt", _text_fst_file_content(merged)),
  )

  _make_output_directory(FLAGS.output_dir)
  for filename, content in outputs:
    _write_file(os.path.join(FLAGS.output_dir, filename), content)
Example #8
Source File: generate.py    From mathematics_dataset with Apache License 2.0 6 votes vote down vote up
def main(unused_argv):
  """Prints Q&As from modules according to FLAGS.filter.

  For every regime, `counts[regime]` problems are sampled from each filtered
  module and printed under a bold header. A warning is logged if any samples
  were dropped.
  """
  init_modules()

  wrapper = textwrap.TextWrapper(
      width=80, initial_indent=' ', subsequent_indent='  ')
  # ANSI escapes: \033[1m = bold, \033[92m = bright green, \033[0m = reset.
  header_template = '\033[1m{}/{}\033[0m'
  problem_template = '{}  \033[92m{}\033[0m'

  for regime, flat_modules in six.iteritems(filtered_modules):
    samples = counts[regime]
    for name, module in six.iteritems(flat_modules):
      print(header_template.format(regime, name))
      dropped = 0
      for _ in range(samples):
        problem, newly_dropped = sample_from_module(module)
        dropped += newly_dropped
        print(wrapper.fill(
            problem_template.format(problem.question, problem.answer)))
      if dropped > 0:
        logging.warning('Dropped %d examples', dropped)
Example #9
Source File: autobuild.py    From glazier with Apache License 2.0 6 votes vote down vote up
def RunBuild(self):
    """Perform the build.

    Sets the title and calls `self._build_info.BeyondCorp()`. If the task list
    file does not exist yet, generates it with ConfigBuilder rooted at
    FLAGS.config_root_path (or '/'). Then executes it with ConfigRunner.
    Builder and runner errors are passed to `_LogFatal`.
    """
    title.set_title()
    self._build_info.BeyondCorp()

    task_list = self._SetupTaskList()

    if not os.path.exists(task_list):
      config_root = FLAGS.config_root_path or '/'
      try:
        config_builder = builder.ConfigBuilder(self._build_info)
        config_builder.Start(out_file=task_list, in_path=config_root)
      except builder.ConfigBuilderError as err:
        _LogFatal(str(err))

    try:
      config_runner = runner.ConfigRunner(self._build_info)
      config_runner.Start(task_list=task_list)
    except runner.ConfigRunnerError as err:
      _LogFatal(str(err))
Example #10
Source File: generate.py    From mathematics_dataset with Apache License 2.0 6 votes vote down vote up
def _filter_and_flatten(modules_):
  """Returns flattened dict, filtered according to FLAGS.

  Nested dict keys are joined with '__' (e.g. {'a': {'b': f}} becomes
  {'a__b': f}). Only leaves whose full name contains FLAGS.filter as a
  substring are kept.

  Args:
    modules_: dict whose values are either leaves or further nested dicts.

  Returns:
    An OrderedDict of the kept leaves, sorted by flattened name.
  """
  collected = {}

  def visit(tree, prefix):
    for key, value in six.iteritems(tree):
      name = key if prefix is None else prefix + '__' + key
      if isinstance(value, dict):
        visit(value, name)
      elif FLAGS.filter in name:
        collected[name] = value

  visit(modules_, None)

  # Sorted order keeps generation deterministic across multiple machines.
  return collections.OrderedDict(
      (name, collected[name]) for name in sorted(collected))
Example #11
Source File: generate_to_file.py    From mathematics_dataset with Apache License 2.0 6 votes vote down vote up
def main(unused_argv):
  """Writes generated problems to <output_dir>/<regime>/<module>.txt files.

  Each file holds `counts[regime]` problems, written as a question line
  followed by an answer line. A fatal log is emitted if the expanded
  FLAGS.output_dir already exists.
  """
  generate.init_modules(FLAGS.train_split)

  output_dir = os.path.expanduser(FLAGS.output_dir)
  if os.path.exists(output_dir):
    logging.fatal('output dir %s already exists', output_dir)
  logging.info('Writing to %s', output_dir)
  os.makedirs(output_dir)

  for regime, flat_modules in six.iteritems(generate.filtered_modules):
    regime_dir = os.path.join(output_dir, regime)
    os.mkdir(regime_dir)
    samples_per_module = generate.counts[regime]
    for module_name, module in six.iteritems(flat_modules):
      path = os.path.join(regime_dir, module_name + '.txt')
      with open(path, 'w') as out:
        for _ in range(samples_per_module):
          problem, _ = generate.sample_from_module(module)
          out.write(str(problem.question) + '\n' + str(problem.answer) + '\n')
      logging.info('Written %s', path)
Example #12
Source File: fonts-subset-support.py    From gftools with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Checks that fonts in one family directory cover each subset similarly.

  For every subset in the directory's metadata except 'menu', the two fonts
  with the least similar coverage are found via `_LeastSimilarCoverage`. If
  their codepoint difference exceeds FLAGS.max_diff_cps, a failure is printed.

  Args:
    argv: command-line arguments; argv[1] must be a directory of font files.
  """
  if len(argv) != 2 or not os.path.isdir(argv[1]):
    sys.exit('Must have one argument, a directory containing font files.')

  # Silences all further stderr output, presumably to hide noisy warnings from
  # font libraries — confirm before relying on stderr in this script.
  sys.stderr = open(os.devnull, 'w')
  dirpath = argv[1]
  # Parse the metadata once rather than once per attribute read.
  metadata = fonts.Metadata(dirpath)
  files = [os.path.join(dirpath, font.filename) for font in metadata.fonts]

  result = True
  for subset in metadata.subsets:
    if subset == 'menu':
      continue
    file1, file2, diff_size = _LeastSimilarCoverage(files, subset)
    if diff_size > FLAGS.max_diff_cps:
      print('%s coverage for %s failed' % (dirpath, subset))
      print('Difference of codepoints between %s & %s is %d' % (
          file1, file2, diff_size))
      result = False

  if result:
    print('%s passed subset coverage' % (dirpath))
Example #13
Source File: gftools-compare-font.py    From gftools with Apache License 2.0 6 votes vote down vote up
def CompareDirs(font1, font2):
  """Compares fonts by assuming font1/2 are dirs containing METADATA.pb.

  Compares the regular-weight font of each family. If FLAGS.diff_coverage is
  set, the per-subset coverage difference is printed for every subset of
  either family (except 'menu') plus 'all'. The size comparison is always
  printed.

  Args:
    font1: path to the first family directory.
    font2: path to the second family directory.
  """
  m1 = fonts.Metadata(font1)
  m2 = fonts.Metadata(font2)

  # 'menu' is excluded from the comparison. It need not be present in either
  # family, and list.remove() would raise ValueError if it is missing.
  subsets_to_compare = [
      subset for subset in fonts.UniqueSort(m1.subsets, m2.subsets)
      if subset != 'menu'
  ]
  subsets_to_compare.append('all')

  font_filename1 = os.path.join(font1, fonts.RegularWeight(m1))
  font_filename2 = os.path.join(font2, fonts.RegularWeight(m2))

  if FLAGS.diff_coverage:
    print('Subset Coverage Change (codepoints)')
    for subset in subsets_to_compare:
      DiffCoverage(font_filename1, font_filename2, subset)

  print(CompareSize(font_filename1, font_filename2))
Example #14
Source File: simplify_nq_data.py    From natural-questions with Apache License 2.0 6 votes vote down vote up
def main(_):
  """Runs `text_utils.simplify_nq_example` over all shards of a split.

  Prints simplified examples to a single gzipped file in the same directory
  as the input shards. Shards are read in sorted filename order, so the
  output ordering is reproducible.
  """
  data_dir = FLAGS.data_dir
  # normpath drops a trailing separator; otherwise basename() would return ''
  # and produce "simplified-nq-.jsonl.gz".
  split = os.path.basename(os.path.normpath(data_dir))
  outpath = os.path.join(data_dir, "simplified-nq-{}.jsonl.gz".format(split))
  with gzip.open(outpath, "wb") as fout:
    num_processed = 0
    start = time.time()
    # glob.glob returns paths in arbitrary, filesystem-dependent order.
    inpaths = sorted(glob.glob(os.path.join(data_dir, "nq-*-??.jsonl.gz")))
    for inpath in inpaths:
      print("Processing {}".format(inpath))
      with gzip.open(inpath, "rb") as fin:
        for raw_line in fin:
          utf8_in = raw_line.decode("utf8", "strict")
          utf8_out = json.dumps(
              text_utils.simplify_nq_example(json.loads(utf8_in))) + u"\n"
          fout.write(utf8_out.encode("utf8"))
          num_processed += 1
          if not num_processed % 100:
            print("Processed {} examples in {}.".format(num_processed,
                                                        time.time() - start))
Example #15
Source File: scan_agent.py    From streetlearn with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Runs the scan agent over the panos listed in FLAGS.list_pano_ids_yaws.

  The list file contains one tab-separated `<pano_id>\t<yaw>` entry per line.
  Blank lines are skipped.

  Args:
    argv: unused command-line arguments.
  """
  config = {'width': FLAGS.width,
            'height': FLAGS.height,
            'field_of_view': FLAGS.field_of_view,
            'graph_width': FLAGS.width,
            'graph_height': FLAGS.height,
            'graph_zoom': 1,
            'full_graph': True,
            'proportion_of_panos_with_coins': 0.0,
            'action_spec': 'streetlearn_fast_rotate',
            'observations': ['view_image', 'graph_image', 'yaw']}
  pano_ids_yaws = []
  with open(FLAGS.list_pano_ids_yaws, 'r') as f:
    for line in f:
      if not line.strip():
        continue  # A blank line would otherwise raise IndexError below.
      fields = line.split('\t')  # Split once; float() ignores the '\n'.
      pano_ids_yaws.append((fields[0], float(fields[1])))
  config = default_config.ApplyDefaults(config)
  game = coin_game.CoinGame(config)
  env = streetlearn.StreetLearn(FLAGS.dataset_path, config, game)
  env.reset()
  pygame.init()
  screen = pygame.display.set_mode((FLAGS.width, FLAGS.height * 2))
  loop(env, screen, pano_ids_yaws)
Example #16
Source File: scan_agent.py    From streetlearn with Apache License 2.0 6 votes vote down vote up
def loop(env, screen, pano_ids_yaws):
  """Main loop of the scan agent.

  For each (pano_id, yaw) pair, jumps the environment there, draws the view
  and graph images into `screen`, and, if FLAGS.save_images is set, saves the
  screen as `scan_agent_<pano_id>_<yaw>.bmp`. Returns early when the window
  is closed or Escape is pressed.

  Args:
    env: environment exposing `goto(pano_id, yaw)` that returns observations.
    screen: pygame display surface to draw into.
    pano_ids_yaws: iterable of (pano_id, yaw) pairs.
  """
  for (pano_id, yaw) in pano_ids_yaws:

    # Retrieve the observation at a specified pano ID and heading.
    logging.info('Retrieving view at pano ID %s and yaw %f', pano_id, yaw)
    observation = env.goto(pano_id, yaw)

    view_image = interleave(observation["view_image"],
                            FLAGS.width, FLAGS.height)
    graph_image = interleave(observation["graph_image"],
                             FLAGS.width, FLAGS.height)
    # Both images go into a single buffer, concatenated along axis 1.
    screen_buffer = np.concatenate((view_image, graph_image), axis=1)
    pygame.surfarray.blit_array(screen, screen_buffer)
    pygame.display.update()

    for event in pygame.event.get():
      if (event.type == pygame.QUIT or
          (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE)):
        return
    if FLAGS.save_images:
      filename = 'scan_agent_{}_{}.bmp'.format(pano_id, yaw)
      pygame.image.save(screen, filename)
Example #17
Source File: app_constants.py    From loaner with Apache License 2.0 6 votes vote down vote up
def _get_all_constants(module=__name__, func=None):
  """Returns a dictionary of all constants.

  Every key flag declared by `module` is wrapped in a `Constant`. The flag's
  default is used as the value unless `func` is given, in which case
  `func(flag_name)` supplies it.

  Args:
    module: str, the name of the module to get the constants from.
    func: Callable, a function that returns the value of each constant given the
        name of the flag.

  Returns:
    A dictionary of all key flags in this module represented as Constants,
        keyed by the name of the constant.
  """
  key_flags = FLAGS.get_key_flags_for_module(sys.modules[module])
  constants = {}
  for flag in key_flags:
    name = flag.name
    default = FLAGS[name].default
    value = func(name) if func else default
    constants[name] = Constant(name, flag.help, value, _PARSERS.get(name))
  return constants
Example #18
Source File: gng_impl.py    From loaner with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Entry point for the Grab n Go management script.

  Builds a `_Manager` from flags and runs it. Exits with status 1 if a
  KeyboardInterrupt (Ctrl-C) escapes the manager, otherwise with status 0.
  """
  del argv  # Unused.
  utils.clear_screen()
  utils.write('Welcome to the Grab n Go management script!\n')

  exit_code = 0
  try:
    manager = _Manager.new(
        FLAGS.config_file_path,
        FLAGS.prefer_gcs,
        project_key=FLAGS.project,
        version=FLAGS.app_version,
    )
    manager.run()
  except KeyboardInterrupt as err:
    logging.error('Manager received CTRL-C, exiting: %s', err)
    exit_code = 1

  sys.exit(exit_code)
Example #19
Source File: build_docs.py    From lattice with Apache License 2.0 6 votes vote down vote up
def main(_):
  """Generates the TensorFlow Lattice API docs into FLAGS.output_dir.

  The `private_map` hides `tfl.python` and the layer class re-exported from
  each `tfl.*_layer` module. The process exits with the result of
  `doc_generator.build`.
  """
  hidden_layer_classes = [
      ('aggregation_layer', 'Aggregation'),
      ('categorical_calibration_layer', 'CategoricalCalibration'),
      ('lattice_layer', 'Lattice'),
      ('linear_layer', 'Linear'),
      ('pwl_calibration_layer', 'PWLCalibration'),
      ('parallel_combination_layer', 'ParallelCombination'),
      ('rtl_layer', 'RTL'),
  ]
  private_map = {'tfl': ['python']}
  for module_name, class_name in hidden_layer_classes:
    private_map['tfl.' + module_name] = [class_name]

  doc_generator = generate_lib.DocGenerator(
      root_title='TensorFlow Lattice 2.0',
      py_modules=[('tfl', tfl)],
      base_dir=os.path.dirname(tfl.__file__),
      code_url_prefix=FLAGS.code_url_prefix,
      search_hints=FLAGS.search_hints,
      site_path=FLAGS.site_path,
      private_map=private_map,
      callbacks=[local_definitions_filter])

  sys.exit(doc_generator.build(output_dir=FLAGS.output_dir))
Example #20
Source File: gftools-sanity-check.py    From gftools with Apache License 2.0 6 votes vote down vote up
def _SanityCheck(path):
  """Runs various sanity checks on the font family under path.

  If the metadata cannot be parsed, a single failure result is returned.
  Otherwise the license and name checks run when FLAGS.check_metadata is set,
  and the internal font-value checks run when FLAGS.check_font is set.

  Args:
    path: A directory containing a METADATA.pb file.
  Returns:
    A list of ResultMessageTuple's.
  """
  try:
    fonts.Metadata(path)
  except ValueError as e:
    # Python 3 exceptions have no `.message` attribute; str(e) is the portable
    # way to get the error text.
    return [_SadResult('Bad METADATA.pb: ' + str(e), path)]

  results = []
  if FLAGS.check_metadata:
    results.extend(_CheckLicense(path))
    results.extend(_CheckNameMatching(path))

  if FLAGS.check_font:
    results.extend(_CheckFontInternalValues(path))

  return results
Example #21
Source File: gftools-sanity-check.py    From gftools with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Sanity-checks every font directory under the paths given in argv[1:].

  Failures are always printed; passes are printed unless FLAGS.suppress_pass
  is set. If FLAGS.repair_script is set, all results are passed to
  `_WriteRepairScript`. Exits with status 1 if any check failed, otherwise 0.

  Raises:
    ValueError: if any given path is not a directory.
  """
  paths = [_DropEmptyPathSegments(os.path.expanduser(arg)) for arg in argv[1:]]
  for path in paths:
    if not os.path.isdir(path):
      raise ValueError('Not a directory: %s' % path)

  exit_status = 0
  collected = []
  for path in paths:
    for font_dir in fonts.FontDirs(path):
      results = _SanityCheck(font_dir)
      collected.extend(results)
      for result in results:
        if result.happy:
          status = 'pass'
        else:
          status = 'FAIL'
          exit_status = 1
        if result.happy and FLAGS.suppress_pass:
          continue
        print('%s: %s (%s)' % (status, result.message, font_dir))

  if FLAGS.repair_script:
    _WriteRepairScript(FLAGS.repair_script, collected)

  sys.exit(exit_status)
Example #22
Source File: gftools-ttf2cp.py    From gftools with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Prints the sorted union of codepoints in the given font files.

  Each line is `0x%04X`, optionally followed by the character and its Unicode
  name (FLAGS.show_char) and its subsets (FLAGS.show_subsets).

  Args:
    argv: command-line arguments; argv[1:] are font file paths.
  """
  filenames = argv[1:]
  if not filenames:
    sys.exit('Must specify one or more font files.')

  codepoints = set()
  for filename in filenames:
    if not os.path.isfile(filename):
      sys.exit('%s is not a file' % filename)
    codepoints |= fonts.CodepointsInFont(filename)

  for cp in sorted(codepoints):
    char_part = ''
    if FLAGS.show_char:
      # NOTE(review): `unichr` is Python 2 only; on Python 3 it must be
      # provided elsewhere in the module (e.g. aliased to chr) — confirm.
      char = unichr(cp)
      char_part = ' ' + char.strip() + ' ' + unicodedata.name(char, '')
    subset_part = ''
    if FLAGS.show_subsets:
      subset_part = ' subset:%s' % ','.join(fonts.SubsetsForCodepoint(cp))
    print(u'0x%04X%s%s' % (cp, char_part, subset_part))
Example #23
Source File: download.py    From glazier with Apache License 2.0 6 votes vote down vote up
def _SetUrl(self, url: Text):
    """Simple helper function to determine signed URL.

    If FLAGS.use_signed_url is not set, `url` is returned unchanged.
    Otherwise, the `FLAGS.config_server + '/'` prefix is removed from `url`
    when present, and the rest is passed to
    `self._beyondcorp.GetSignedUrl`.

    Args:
      url: the url we want to download from.

    Returns:
      A string with the applicable URLs

    Raises:
      DownloadError: Failed to obtain SignedURL.
    """
    if not FLAGS.use_signed_url:
      return url
    config_server = '%s%s' % (FLAGS.config_server, '/')
    # Strip the config-server prefix, if present, leaving the relative path.
    if url.startswith(config_server):
      path = url[len(config_server):]
    else:
      path = url
    try:
      return self._beyondcorp.GetSignedUrl(path)
    except beyondcorp.BCError as e:
      # Chain the original error so its traceback is preserved.
      raise DownloadError(e) from e
Example #24
Source File: app_constants.py    From loaner with Apache License 2.0 5 votes vote down vote up
def get_constants_from_flags(module=__name__):
  """Returns a dictionary of all constants from flags.

  Intended only for cases that skip user validation (e.g. scripting). Values
  are not checked against the custom parsers until they are requested; a value
  that does not satisfy its `Parser` raises at retrieval time.

  Args:
    module: str, the name of the module to get the constants from.

  Returns:
    A dictionary of all constants with the flag value as the constant value.
        The key for each constant is the name of the constant.

  Raises:
    ValueError: when any of the flag values does not meet the parsing
        requirements.
  """
  def _from_flag(name):
    """Returns the flag's parsed value, or its default if flags are unparsed.

    Args:
      name: str, the name of the flag.

    Returns:
      The value of the flag.
    """
    if not FLAGS.is_parsed():
      return FLAGS[name].default
    return getattr(FLAGS, name)

  return _get_all_constants(module=module, func=_from_flag)
Example #25
Source File: auth.py    From loaner with Apache License 2.0 5 votes vote down vote up
def _run_local_web_server_for_auth():
  """Whether or not to run the local web server for OAuth2.

  Uses the parsed value of the `automatic_oauth` flag, or its default when
  flags have not been parsed yet.

  Returns:
    A bool, True when the web server should run, otherwise False.
  """
  if not FLAGS.is_parsed():
    # Presumably avoids absl's error on reading an unparsed flag's value.
    return FLAGS['automatic_oauth'].default
  return FLAGS.automatic_oauth
Example #26
Source File: auth_test.py    From loaner with Apache License 2.0 5 votes vote down vote up
def test_run_local_web_server_for_auth(self):
    """Test whether or not to run the local web server for authentication."""
    # Unparsed: the helper falls back to the flag default, expected False.
    FLAGS.unparse_flags()
    self.assertFalse(auth._run_local_web_server_for_auth())

    # Parsed with --automatic_oauth: the helper reports True.
    argv = sys.argv[:1] + ['--automatic_oauth']
    flags.FLAGS(argv)
    FLAGS.mark_as_parsed()
    self.assertTrue(auth._run_local_web_server_for_auth())
Example #27
Source File: auth_test.py    From loaner with Apache License 2.0 5 votes vote down vote up
def test_cloud_credentials_constructor_no_local_file(
      self, expected_redirect_uri, run_web_server, mock_run_flow,
      mock_server_flow):
    """Test the creation of the CloudCredentials object with no local creds.

    The mocked run_flow returns fixed OAuth2 credentials; the test checks that
    CloudCredentials exposes them and builds the server flow with
    `expected_redirect_uri`.
    """
    FLAGS.automatic_oauth = run_web_server
    config = self._test_config
    mock_run_flow.return_value = oauth2_client.OAuth2Credentials(
        access_token='test_access_token',
        client_id=config.client_id,
        client_secret=config.client_secret,
        refresh_token='test_refresh_token',
        token_expiry=datetime.datetime(year=2018, month=1, day=1),
        token_uri='test_token_uri',
        user_agent=None,
        id_token='test_id_token',
        scopes=['test_scope1'])

    test_creds = auth.CloudCredentials(config, ['test_scope1'])
    self.assertEqual(config, test_creds._config)

    creds = test_creds._credentials
    self.assertEqual('test_access_token', creds.token)
    self.assertEqual('test_refresh_token', creds.refresh_token)
    self.assertEqual('test_id_token', creds.id_token)
    self.assertEqual('test_token_uri', creds.token_uri)
    self.assertEqual(config.client_id, creds.client_id)
    self.assertEqual(config.client_secret, creds.client_secret)
    self.assertEqual(['test_scope1'], creds.scopes)

    mock_server_flow.assert_called_once_with(
        client_id=config.client_id,
        client_secret=config.client_secret,
        scope=['test_scope1'],
        redirect_uri=expected_redirect_uri)
Example #28
Source File: train.py    From ffn with Apache License 2.0 5 votes vote down vote up
def train_eval_size(model):
  """Returns the model's prediction mask size padded for FOV movements.

  Computes `pred_mask_size + deltas * 2 * FLAGS.fov_moves` element-wise. That
  is, the mask size is extended by FLAGS.fov_moves steps of `model.deltas` on
  each side.

  Args:
    model: object with `pred_mask_size` and `deltas` sequences of equal length.

  Returns:
    A NumPy array with the padded size per axis.
  """
  mask_size = np.array(model.pred_mask_size)
  padding = np.array(model.deltas) * 2 * FLAGS.fov_moves
  return mask_size + padding
Example #29
Source File: train.py    From ffn with Apache License 2.0 5 votes vote down vote up
def _get_reflectable_axes():
  return [int(x) + 1 for x in FLAGS.reflectable_axes] 
Example #30
Source File: train.py    From ffn with Apache License 2.0 5 votes vote down vote up
def _get_offset_and_scale_map():
  if not FLAGS.image_offset_scale_map:
    return {}

  ret = {}
  for vol_def in FLAGS.image_offset_scale_map:
    vol_name, offset, scale = vol_def.split(':')
    ret[vol_name] = float(offset), float(scale)

  return ret