Python gin.query_parameter() Examples

The following are 3 code examples of gin.query_parameter(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module gin, or try the search function.
Example #1
Source File: trainer.py — from BERT (Apache License 2.0), 5 votes
def _default_output_dir():
  """Build a default output directory path under ``~/trax``.

  The directory name is ``<model>_<dataset>_<YYYYmmdd_HHMM>``. The model
  name comes from the gin-bound ``train.model``. The dataset name comes from
  ``inputs.dataset_name``, or is ``"random"`` when that parameter is not bound
  in the gin config.

  Returns:
    The output directory path with ``~`` expanded to the user's home
    directory, so it can be handed directly to filesystem APIs.

  Raises:
    ValueError: propagated from ``gin.query_parameter`` when ``train.model``
      is not bound in the gin config.
  """
  try:
    dataset_name = gin.query_parameter("inputs.dataset_name")
  except ValueError:
    # No dataset bound in gin; fall back to a placeholder name.
    dataset_name = "random"
  dir_name = "{model_name}_{dataset_name}_{timestamp}".format(
      model_name=gin.query_parameter("train.model").configurable.name,
      dataset_name=dataset_name,
      timestamp=datetime.datetime.now().strftime("%Y%m%d_%H%M"),
  )
  # os.path.join does not expand "~". Without expanduser the caller receives a
  # path containing a literal "~" directory, which most filesystem APIs would
  # create relative to the current working directory.
  dir_path = os.path.expanduser(os.path.join("~", "trax", dir_name))
  print()
  trax.log("No --output_dir specified")
  # Tell the user where output is going, since they did not choose it.
  trax.log("Using default output_dir: {}".format(dir_path))
  return dir_path
Example #2
Source File: trainer.py — from trax (Apache License 2.0), 5 votes
def _output_dir_or_default():
  """Resolve the directory that training output should be written to.

  When ``--output_dir`` is set it wins. Otherwise a name of the form
  ``<model>_<dataset>_<YYYYmmdd_HHMM>`` is synthesized from the gin config
  and placed under ``~/trax``.

  Returns:
    The chosen directory, with ``~`` expanded to the user's home directory.
  """
  explicit_dir = FLAGS.output_dir
  if explicit_dir:
    trainer_lib.log('Using --output_dir {}'.format(explicit_dir))
    return os.path.expanduser(explicit_dir)

  # No flag given: derive the run name from the gin-configured model/dataset.
  try:
    dataset = gin.query_parameter('data_streams.dataset_name')
  except ValueError:
    # Dataset is not bound in gin; label the run as using random data.
    dataset = 'random'
  model = gin.query_parameter('train.model').configurable.name
  stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
  run_name = '{model_name}_{dataset_name}_{timestamp}'.format(
      model_name=model, dataset_name=dataset, timestamp=stamp)
  default_dir = os.path.expanduser(os.path.join('~', 'trax', run_name))

  print()
  trainer_lib.log('No --output_dir specified')
  trainer_lib.log('Using default output_dir: {}'.format(default_dir))
  return default_dir


# TODO(afrozm): Share between trainer.py and rl_trainer.py 
Example #3
Source File: run.py — from reaver (MIT License), 4 votes
def main(argv):
    """Entry point: configure TF and gin, build the env and agent, then run it.

    All settings are read from ``flags.FLAGS``. Some of them (``env``,
    ``n_envs``, ``log_freq``, ``restore``, ``gin_bindings``) are mutated in
    place before use.

    Args:
        argv: Positional command-line arguments. Unused here; presumably
            supplied by absl's ``app.run`` (confirm against the caller).
    """
    # Switch TensorFlow to TF1 graph mode; the code below relies on
    # tf.Session / tf.ConfigProto.
    tf.disable_eager_execution()
    tf.disable_v2_behavior()

    args = flags.FLAGS
    # Silence TensorFlow's C++ logging and restrict the visible GPUs.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    # NOTE(review): os.environ values must be str; this raises TypeError if
    # the --gpu flag can be None. Confirm the flag's default.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Expand short SC2 minigame aliases into their full environment names.
    if args.env in rvr.utils.config.SC2_MINIGAMES_ALIASES:
        args.env = rvr.utils.config.SC2_MINIGAMES_ALIASES[args.env]

    # Test mode: one environment, log_freq of 1, and always restore the
    # existing experiment rather than starting a new one.
    if args.test:
        args.n_envs = 1
        args.log_freq = 1
        args.restore = True

    expt = rvr.utils.Experiment(args.results_dir, args.env, args.agent, args.experiment, args.restore)

    # Assemble gin config files in this order: the env-specific configs found
    # next to this script, then the restored experiment's saved config (if
    # restoring), then user-supplied files. Later files presumably override
    # earlier bindings, per gin semantics.
    gin_files = rvr.utils.find_configs(args.env, os.path.dirname(os.path.abspath(__file__)))
    if args.restore:
        gin_files += [expt.config_path]
    gin_files += args.gin_files

    # No GPU selected: force a channels_last data format for both network
    # builders (presumably because channels_first convs are GPU-only in TF;
    # confirm).
    if not args.gpu:
        args.gin_bindings.append("build_cnn_nature.data_format = 'channels_last'")
        args.gin_bindings.append("build_fully_conv.data_format = 'channels_last'")

    # Gin must be parsed before any query_parameter call below.
    gin.parse_config_files_and_bindings(gin_files, args.gin_bindings)
    # Cap the number of parallel envs at the agent's gin-configured batch
    # size. gin raises ValueError if ACAgent.batch_sz is not bound.
    args.n_envs = min(args.n_envs, gin.query_parameter('ACAgent.batch_sz'))

    # allow_soft_placement lets TF fall back to another device when an op
    # has no kernel for the requested one.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess_mgr = rvr.utils.tensorflow.SessionManager(sess, expt.path, args.ckpt_freq, training_enabled=not args.test)

    # Environment ids containing "-v" (e.g. Gym's "CartPole-v0") are treated
    # as Gym envs; everything else is treated as an SC2 map.
    env_cls = rvr.envs.GymEnv if '-v' in args.env else rvr.envs.SC2Env
    env = env_cls(args.env, args.render, max_ep_len=args.max_ep_len)

    agent = rvr.agents.registry[args.agent](env.obs_spec(), env.act_spec(), sess_mgr=sess_mgr, n_envs=args.n_envs)
    agent.logger = rvr.utils.StreamLogger(args.n_envs, args.log_freq, args.log_eps_avg, sess_mgr, expt.log_path)

    # Persist the effective gin config and a model summary only when training
    # (i.e. not in --test mode).
    if sess_mgr.training_enabled:
        expt.save_gin_config()
        expt.save_model_summary(agent.model)

    # Second argument is n_updates * traj_len * batch_sz // n_envs; presumably
    # the number of steps to run per env. Confirm against the agent's run().
    agent.run(env, args.n_updates * agent.traj_len * agent.batch_sz // args.n_envs)