Python tensorflow.flags() Examples

The following are 30 code examples of tensorflow.flags(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
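Before the examples, a note on the API they all share: in TensorFlow 1.x, tf.flags is a thin alias for absl.flags. Flags are defined once at module level with typed DEFINE_* calls, and their parsed values are read as attributes of the global FLAGS object, typically after tf.app.run() has parsed sys.argv and invoked main. The minimal sketch below is illustrative only; the flag names mode and batch_size are placeholders, not taken from any project on this page.

import tensorflow as tf

flags = tf.flags
flags.DEFINE_string("mode", "train", "One of [train/test].")
flags.DEFINE_integer("batch_size", 32, "Batch size.")
FLAGS = flags.FLAGS


def main(_):
    # tf.app.run() parses sys.argv before calling main, so the parsed
    # flag values are available here as attributes of FLAGS.
    print("mode=%s, batch_size=%d" % (FLAGS.mode, FLAGS.batch_size))


if __name__ == "__main__":
    tf.app.run()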
Example #1
Source File: sample.py    From IMPLEMENTATION_Variational-Auto-Encoder with MIT License
def main():
    flags = tf.flags
    flags.DEFINE_integer("latent_dim", 64, "Dimension of latent space.")
    flags.DEFINE_integer("obs_dim", 12288, "Dimension of observation space.")
    flags.DEFINE_integer("batch_size", 60, "Batch size.")
    flags.DEFINE_integer("epochs", 500, "As it said")
    flags.DEFINE_integer("updates_per_epoch", 100, "Really just can set to 1 if you don't like mini-batch.")
    FLAGS = flags.FLAGS

    kwargs = {
        'latent_dim': FLAGS.latent_dim,
        'observation_dim': FLAGS.obs_dim,
        'generator': conv_anime_decoder,
        'obs_distrib': 'Gaussian'
    }
    g = GENERATOR(**kwargs)
    g.load_pretrained("weights/vae_anime/generator")

    z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
    samples = g.e2x(z)
    print(samples.shape)
    show_samples(samples, 4, 15, [64, 64, 3], name='small_samples', shift=True) 
Example #2
Source File: config.py    From Question_Answering_Models with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode, you must choose mode from [train/prepro/debug/test]")
        exit(0) 
Example #3
Source File: train.py    From text-gan-tensorflow with MIT License
def get_supervisor(model):
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)

    supervisor = tf.train.Supervisor(
        logdir=FLAGS.model_dir,
        is_chief=True,
        saver=saver,
        init_op=set_initial_ops(),
        summary_op=tf.summary.merge_all(),
        summary_writer=summary_writer,
        save_summaries_secs=100,  # TODO: add as flags
        save_model_secs=1000,
        global_step=model.global_step,
    )

    return supervisor 
Example #4
Source File: config.py    From QANet with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "demo":
        demo(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #5
Source File: config.py    From AIchallenger2018_MachineReadingComprehension with MIT License
def main(_):
    config = flags.FLAGS
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu  # 选择一块gpu
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        data_process.prepro(config)
    elif config.mode == "debug":
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "examine":
        examine_dev(config)
    elif config.mode == "save_dev":
        save_dev(config)
    elif config.mode == "save_test":
        save_test(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #6
Source File: config.py    From AIchallenger2018_MachineReadingComprehension with MIT License
def main(_):
    config = flags.FLAGS
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu  # 选择一块gpu
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        data_process_addAnswer.prepro(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "examine":
        examine_dev(config)
    elif config.mode == "save_dev":
        save_dev(config)
    elif config.mode == "save_test":
        save_test(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #7
Source File: config.py    From Question_Answering_Models with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode, you must choose mode from [train/prepro/debug/test]")
        exit(0) 
Example #8
Source File: config.py    From Question_Answering_Models with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode, you must choose mode from [train/prepro/debug/test]")
        exit(0) 
Example #9
Source File: config.py    From AmusingPythonCodes with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        if config.use_cudnn:
            print("Warning: Due to a known bug in Tensorlfow, the parameters of CudnnGRU may not be properly restored.")
        test(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #10
Source File: config.py    From R-Net with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #11
Source File: config.py    From QGforQA with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "get_vocab":
        get_vocab(config)
    elif config.mode == "prepare":
        prepare(config)
    elif config.mode == "train":
        train(config)
    elif config.mode == "train_rl":
        train_rl(config)
    elif config.mode == "train_qpp":
        train_qpp(config)
    elif config.mode == "train_qap":
        train_qap(config)
    elif config.mode == "train_qqp_qap":
        train_qqp_qap(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #12
Source File: vae_train_anime.py    From IMPLEMENTATION_Variational-Auto-Encoder with MIT License
def main():
    flags = tf.flags
    flags.DEFINE_integer("latent_dim", 64, "Dimension of latent space.")
    flags.DEFINE_integer("obs_dim", 12288, "Dimension of observation space.")
    flags.DEFINE_integer("batch_size", 64, "Batch size.")
    flags.DEFINE_integer("epochs", 500, "As it said")
    flags.DEFINE_integer("updates_per_epoch", 100, "Really just can set to 1 if you don't like mini-batch.")
    FLAGS = flags.FLAGS

    kwargs = {
        'latent_dim': FLAGS.latent_dim,
        'batch_size': FLAGS.batch_size,
        'observation_dim': FLAGS.obs_dim,
        'encoder': conv_anime_encoder,
        'decoder': conv_anime_decoder,
        'observation_distribution': 'Gaussian'
    }
    vae = VAE(**kwargs)
    provider = Anime()
    tbar = tqdm(range(FLAGS.epochs))
    for epoch in tbar:
        training_loss = 0.

        for _ in range(FLAGS.updates_per_epoch):
            x = provider.next_batch(FLAGS.batch_size)
            loss = vae.update(x)
            training_loss += loss

        training_loss /= FLAGS.updates_per_epoch
        s = "Loss: {:.4f}".format(training_loss)
        tbar.set_description(s)

    z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
    samples = vae.z2x(z)[0]
    show_samples(samples, 8, 8, [64, 64, 3], name='samples')

    vae.save_generator('weights/vae_anime/generator') 
Example #13
Source File: pythonLanguageModel.py    From pycodesuggest with MIT License
def print_flags(flags):
    for flag in flags.__flags:
        val = getattr(flags, flag)
        if not isinstance(val, bool) or val:
            print("%s=%s" % (flag, val))
    print()
    print() 
Example #14
Source File: vae_train.py    From IMPLEMENTATION_Variational-Auto-Encoder with MIT License
def main():
    flags = tf.flags
    flags.DEFINE_integer("latent_dim", 2, "Dimension of latent space.")
    flags.DEFINE_integer("batch_size", 128, "Batch size.")
    flags.DEFINE_integer("epochs", 500, "As it said")
    flags.DEFINE_integer("updates_per_epoch", 100, "Really just can set to 1 if you don't like mini-batch.")
    flags.DEFINE_string("data_dir", 'mnist', "Tensorflow demo data download position.")
    FLAGS = flags.FLAGS

    kwargs = {
        'latent_dim': FLAGS.latent_dim,
        'batch_size': FLAGS.batch_size,
        'encoder': fc_mnist_encoder,
        'decoder': fc_mnist_decoder
    }
    vae = VAE(**kwargs)
    mnist = input_data.read_data_sets(train_dir=FLAGS.data_dir)
    tbar = tqdm(range(FLAGS.epochs))
    for epoch in tbar:
        training_loss = 0.

        for _ in range(FLAGS.updates_per_epoch):
            x, _ = mnist.train.next_batch(FLAGS.batch_size)
            loss = vae.update(x)
            training_loss += loss

        training_loss /= FLAGS.updates_per_epoch
        s = "Loss: {:.4f}".format(training_loss)
        tbar.set_description(s)

    z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
    samples = vae.z2x(z)[0]
    show_samples(samples, 10, 10, [28, 28], name='samples')
    show_latent_scatter(vae, mnist, name='latent')

    vae.save_generator('weights/vae_mnist/generator') 
Example #15
Source File: config.py    From QGforQA with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "train_for_qg":
        train_for_qg(config)
    elif config.mode == "test_qa_for_qg":
        test_qa_for_qg(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #16
Source File: config.py    From QGforQA with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "prepare":
        prepare(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "train":
        train(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #17
Source File: train.py    From HMEAE with MIT License
def main(_):
    config = flags.FLAGS
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu
    extractor = utils.Extractor()
    extractor.Extract()
    loader = utils.Loader()
    t_data = loader.load_trigger()
    a_data = loader.load_argument()
    trigger = DMCNN(t_data, a_data, loader.maxlen, loader.max_argument_len, loader.wordemb)
    a_data_process = trigger.train_trigger()
    argument = DMCNN(t_data, a_data_process, loader.maxlen, loader.max_argument_len, loader.wordemb,
                     stage=config.mode, classify=config.classify)
    argument.train_argument() 
Example #18
Source File: inference_demo.py    From Gun-Detector with Apache License 2.0
def _validate_flags():
  flags.register_validator('checkpoint_path', bool,
                           'Must provide `checkpoint_path`.')
  flags.register_validator(
      'generated_x_dir',
      lambda x: False if (FLAGS.image_set_y_glob and not x) else True,
      'Must provide `generated_x_dir`.')
  flags.register_validator(
      'generated_y_dir',
      lambda x: False if (FLAGS.image_set_x_glob and not x) else True,
      'Must provide `generated_y_dir`.') 
Example #19
Source File: train.py    From text-gan-tensorflow with MIT License
def get_sess_config():
    # gpu_options = tf.GPUOptions(
    # per_process_gpu_memory_fraction=self.gpu_memory_fraction,
    # allow_growth=True) # seems to be not working

    sess_config = tf.ConfigProto(
        # log_device_placement=True,
        inter_op_parallelism_threads=8,  # TODO: add as flags
        # allow_soft_placement=True,
        # gpu_options=gpu_options)
    )

    return sess_config 
Example #20
Source File: config.py    From QGforQA with MIT License
def main(_):
    config = flags.FLAGS
    if config.mode == "prepare":
        prepare(config)
    elif config.mode == "train":
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        exit(0) 
Example #21
Source File: tf_t2t.py    From sgnmt with Apache License 2.0
    def vocab_size(self):
        return self._vocab_size

    # Define flags from the t2t binaries
Example #22
Source File: query.py    From training_results_v0.5 with Apache License 2.0
def validate_flags():
  """Validates flags are set to acceptable values."""
  if FLAGS.cloud_mlengine_model_name:
    assert not FLAGS.server
    assert not FLAGS.servable_name
  else:
    assert FLAGS.server
    assert FLAGS.servable_name 
Example #23
Source File: t2t_trainer.py    From fine-lm with MIT License
def save_metadata(hparams):
  """Saves FLAGS and hparams to output_dir."""
  output_dir = os.path.expanduser(FLAGS.output_dir)
  if not tf.gfile.Exists(output_dir):
    tf.gfile.MakeDirs(output_dir)

  # Save FLAGS in txt file
  if hasattr(FLAGS, "flags_into_string"):
    flags_str = FLAGS.flags_into_string()
    t2t_flags_str = "\n".join([
        "--%s=%s" % (f.name, f.value)
        for f in FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
    ])
  else:
    flags_dict = FLAGS.__dict__["__flags"]
    flags_str = "\n".join(
        ["--%s=%s" % (name, str(f)) for (name, f) in flags_dict.items()])
    t2t_flags_str = None

  flags_txt = os.path.join(output_dir, "flags.txt")
  with tf.gfile.Open(flags_txt, "w") as f:
    f.write(flags_str)

  if t2t_flags_str:
    t2t_flags_txt = os.path.join(output_dir, "flags_t2t.txt")
    with tf.gfile.Open(t2t_flags_txt, "w") as f:
      f.write(t2t_flags_str)

  # Save hparams as hparams.json
  hparams_fname = os.path.join(output_dir, "hparams.json")
  with tf.gfile.Open(hparams_fname, "w") as f:
    f.write(hparams.to_json(indent=0, sort_keys=True)) 
Example #24
Source File: transformer_model.py    From fine-lm with MIT License
def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    transformer_config = processor_configuration["transformer"]
    FLAGS.output_dir = transformer_config["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(transformer_config["data_dir"])

    # Create the basic hyper parameters.
    self.hparams = trainer_lib.create_hparams(
        transformer_config["hparams_set"],
        transformer_config["hparams"],
        data_dir=data_dir,
        problem_name=transformer_config["problem"])

    decode_hp = decoding.decode_hparams()
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = trainer_lib.create_estimator(
        transformer_config["model"],
        self.hparams,
        t2t_trainer.create_run_config(self.hparams),
        decode_hparams=decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
    self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir) 
Example #25
Source File: query.py    From fine-lm with MIT License
def validate_flags():
  """Validates flags are set to acceptable values."""
  if FLAGS.cloud_mlengine_model_name:
    assert not FLAGS.server
    assert not FLAGS.servable_name
  else:
    assert FLAGS.server
    assert FLAGS.servable_name 
Example #26
Source File: t2t_translate_all.py    From BERT with Apache License 2.0
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  # pylint: disable=unused-variable
  model_dir = os.path.expanduser(FLAGS.model_dir)
  translations_dir = os.path.expanduser(FLAGS.translations_dir)
  source = os.path.expanduser(FLAGS.source)
  tf.gfile.MakeDirs(translations_dir)
  translated_base_file = os.path.join(translations_dir, FLAGS.problem)

  # Copy flags.txt with the original time, so t2t-bleu can report correct
  # relative time.
  flags_path = os.path.join(translations_dir, FLAGS.problem + "-flags.txt")
  if not os.path.exists(flags_path):
    shutil.copy2(os.path.join(model_dir, "flags.txt"), flags_path)

  locals_and_flags = {"FLAGS": FLAGS}
  for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes,
                                            FLAGS.min_steps):
    tf.logging.info("Translating " + model.filename)
    out_file = translated_base_file + "-" + str(model.steps)
    locals_and_flags.update(locals())
    if os.path.exists(out_file):
      tf.logging.info(out_file + " already exists, so skipping it.")
    else:
      tf.logging.info("Translating " + out_file)
      params = (
          "--t2t_usr_dir={FLAGS.t2t_usr_dir} --output_dir={model_dir} "
          "--data_dir={FLAGS.data_dir} --problem={FLAGS.problem} "
          "--decode_hparams=beam_size={FLAGS.beam_size},alpha={FLAGS.alpha} "
          "--model={FLAGS.model} --hparams_set={FLAGS.hparams_set} "
          "--checkpoint_path={model.filename} --decode_from_file={source} "
          "--decode_to_file={out_file} --keep_timestamp"
      ).format(**locals_and_flags)
      command = FLAGS.decoder_command.format(**locals())
      tf.logging.info("Running:\n" + command)
      os.system(command)
  # pylint: enable=unused-variable 
Example #27
Source File: t2t_trainer.py    From BERT with Apache License 2.0
def save_metadata(hparams):
  """Saves FLAGS and hparams to output_dir."""
  output_dir = os.path.expanduser(FLAGS.output_dir)
  if not tf.gfile.Exists(output_dir):
    tf.gfile.MakeDirs(output_dir)

  # Save FLAGS in txt file
  if hasattr(FLAGS, "flags_into_string"):
    flags_str = FLAGS.flags_into_string()
    t2t_flags_str = "\n".join([
        "--%s=%s" % (f.name, f.value)
        for f in FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
    ])
  else:
    flags_dict = FLAGS.__dict__["__flags"]
    flags_str = "\n".join(
        ["--%s=%s" % (name, str(f)) for (name, f) in flags_dict.items()])
    t2t_flags_str = None

  flags_txt = os.path.join(output_dir, "flags.txt")
  with tf.gfile.Open(flags_txt, "w") as f:
    f.write(flags_str)

  if t2t_flags_str:
    t2t_flags_txt = os.path.join(output_dir, "flags_t2t.txt")
    with tf.gfile.Open(t2t_flags_txt, "w") as f:
      f.write(t2t_flags_str)

  # Save hparams as hparams.json
  new_hparams = hparams_lib.copy_hparams(hparams)
  # Modality class is not JSON serializable so remove.
  new_hparams.del_hparam("modality")

  hparams_fname = os.path.join(output_dir, "hparams.json")
  with tf.gfile.Open(hparams_fname, "w") as f:
    f.write(new_hparams.to_json(indent=0, sort_keys=True)) 
Example #28
Source File: t2t_eval.py    From BERT with Apache License 2.0
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set, FLAGS.hparams, data_dir=FLAGS.data_dir,
      problem_name=FLAGS.problem)

  # set appropriate dataset-split, if flags.eval_use_test_set.
  dataset_split = "test" if FLAGS.eval_use_test_set else None
  dataset_kwargs = {"dataset_split": dataset_split}
  eval_input_fn = hparams.problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams, dataset_kwargs=dataset_kwargs)
  config = t2t_trainer.create_run_config(hparams)

  # summary-hook in tf.estimator.EstimatorSpec requires
  # hparams.model_dir to be set.
  hparams.add_hparam("model_dir", config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model, hparams, config, use_tpu=FLAGS.use_tpu)
  ckpt_iter = trainer_lib.next_checkpoint(
      hparams.model_dir, FLAGS.eval_timeout_mins)
  for ckpt_path in ckpt_iter:
    predictions = estimator.evaluate(
        eval_input_fn, steps=FLAGS.eval_steps, checkpoint_path=ckpt_path)
    tf.logging.info(predictions) 
Example #29
Source File: transformer_model.py    From BERT with Apache License 2.0
def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    transformer_config = processor_configuration["transformer"]
    FLAGS.output_dir = transformer_config["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(transformer_config["data_dir"])

    # Create the basic hyper parameters.
    self.hparams = trainer_lib.create_hparams(
        transformer_config["hparams_set"],
        transformer_config["hparams"],
        data_dir=data_dir,
        problem_name=transformer_config["problem"])

    decode_hp = decoding.decode_hparams()
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = trainer_lib.create_estimator(
        transformer_config["model"],
        self.hparams,
        t2t_trainer.create_run_config(self.hparams),
        decode_hparams=decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
    self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir) 
Example #30
Source File: query.py    From BERT with Apache License 2.0
def validate_flags():
  """Validates flags are set to acceptable values."""
  if FLAGS.cloud_mlengine_model_name:
    assert not FLAGS.server
    assert not FLAGS.servable_name
  else:
    assert FLAGS.server
    assert FLAGS.servable_name