Python gin.clear_config() Examples

The following are 26 code examples of gin.clear_config(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module gin, or try the search function.
Example #1
Source File: data_loading_test.py    From rl-reliability-metrics with Apache License 2.0 6 votes vote down vote up
def setUp(self):
    """Reset gin, load the eval-metrics test config, and list fake runs."""
    super(DataLoadingTest, self).setUp()

    # Begin from an empty gin state, then load the shared test config.
    gin.clear_config()
    config_path = os.path.join(
        './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin')
    gin.parse_config_file(config_path)

    # Fake set of training curves used to exercise the analysis code.
    data_dir = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    run_dirs = []
    for run_index in range(3):
        run_dirs.append(os.path.join(data_dir, 'run%d' % run_index, 'train'))
    self.run_dirs = run_dirs
Example #2
Source File: eval_metrics_test.py    From rl-reliability-metrics with Apache License 2.0 6 votes vote down vote up
def setUp(self):
    """Reset gin, load the eval-metrics test config, and list fake runs."""
    super(EvalMetricsTest, self).setUp()

    # Begin from an empty gin state, then load the shared test config.
    gin.clear_config()
    config_path = os.path.join(
        './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin')
    gin.parse_config_file(config_path)

    # Fake set of training curves used to exercise the analysis code; the
    # containing directory is also exposed as self.test_data_dir.
    self.test_data_dir = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    run_dirs = []
    for run_index in range(3):
        run_dirs.append(
            os.path.join(self.test_data_dir, 'run%d' % run_index, 'train'))
    self.run_dirs = run_dirs
Example #3
Source File: inputs_test.py    From BERT with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run the base-class setUp, then start the test from an empty gin config."""
    # Call the parent setUp so whatever fixture logic the base TestCase
    # defines still runs before this test's own preparation.
    super().setUp()
    gin.clear_config()
Example #4
Source File: test_utils.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Make the test hermetic: fake data, empty gin config, fake Inception."""
    super(CompareGanTestCase, self).setUp()
    # Read fake datasets rather than real files on disk.
    FLAGS.data_fake_dataset = True
    # Drop any gin bindings left over from a previous test.
    gin.clear_config()
    # Replace the Inception graph loader with one returning a fake graph.
    # NOTE(review): the patcher is started but not stopped in this method;
    # presumably a tearDown elsewhere calls mock.patch.stopall() -- confirm.
    graph_def = create_fake_inception_graph()
    patcher = mock.patch.object(
        eval_utils, "get_inception_graph_def", return_value=graph_def)
    self.inception_graph_def_mock = patcher.start()
Example #5
Source File: resnet_init_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then wipe any gin bindings from earlier tests."""
    parent = super(ResNetInitTest, self)
    parent.setUp()
    gin.clear_config()
Example #6
Source File: resnet_biggan_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then wipe any gin bindings from earlier tests."""
    parent = super(ResNet5BigGanTest, self)
    parent.setUp()
    gin.clear_config()
Example #7
Source File: eval_gan_lib_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Reset gin, switch to fake data, and stub the Inception graph loader."""
    super(EvalGanLibTest, self).setUp()
    gin.clear_config()
    FLAGS.data_fake_dataset = True
    patcher = mock.patch.object(eval_utils, "get_inception_graph_def")
    self.mock_get_graph = patcher.start()
    # The stub returns a small fake graph instead of loading Inception.
    self.mock_get_graph.return_value = create_fake_inception_graph()
Example #8
Source File: heteroscedastic_q_network_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then start from an empty gin configuration."""
    parent = super(SingleObservationSingleActionTest, self)
    parent.setUp()
    gin.clear_config()
Example #9
Source File: q_network_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then start from an empty gin configuration."""
    parent = super(SingleObservationSingleActionTest, self)
    parent.setUp()
    gin.clear_config()
Example #10
Source File: categorical_q_network_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def tearDown(self):
    """Discard gin bindings made by this test, then run base teardown."""
    gin.clear_config()
    parent = super(CategoricalQNetworkTest, self)
    parent.tearDown()
Example #11
Source File: test_utils.py    From agents with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Shared fixture: enable resource variables and wipe gin configuration."""
    parent = super(TestCase, self)
    parent.setUp()
    tf.compat.v1.enable_resource_variables()
    # A test that calls gin.parse_config() without a matching
    # gin.clear_config() leaks bindings into unrelated tests, producing
    # confusing failures far from the cause (b/139088071 is one example).
    # Resetting here protects every test built on this base class.
    gin.clear_config()
Example #12
Source File: suite_bsuite_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def tearDown(self):
    """Discard gin bindings made by this test, then run base teardown."""
    gin.clear_config()
    parent = super(SuiteBsuiteTest, self)
    parent.tearDown()
Example #13
Source File: suite_mujoco_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def tearDown(self):
    """Discard gin bindings made by this test, then run base teardown."""
    gin.clear_config()
    parent = super(SuiteMujocoTest, self)
    parent.tearDown()
Example #14
Source File: suite_gym_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def tearDown(self):
    """Discard gin bindings made by this test, then run base teardown."""
    gin.clear_config()
    parent = super(SuiteGymTest, self)
    parent.tearDown()
Example #15
Source File: suite_pybullet_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def tearDown(self):
    """Discard gin bindings made by this test, then run base teardown."""
    gin.clear_config()
    parent = super(SuitePybulletTest, self)
    parent.tearDown()
Example #16
Source File: checkpoint_predictor_test.py    From tensor2robot with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Start from an empty gin config with per-step checkpoint saving bound."""
    parent = super(CheckpointPredictorTest, self)
    parent.setUp()
    gin.clear_config()
    # Bind tf.estimator.RunConfig.save_checkpoints_steps to 1.
    gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1')
Example #17
Source File: exported_savedmodel_predictor_test.py    From tensor2robot with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Start from an empty gin config with per-step checkpoint saving bound."""
    parent = super(ExportedSavedmodelPredictorTest, self)
    parent.setUp()
    gin.clear_config()
    # Bind tf.estimator.RunConfig.save_checkpoints_steps to 1.
    gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1')
Example #18
Source File: train_eval_test.py    From tensor2robot with Apache License 2.0 5 votes vote down vote up
def tearDown(self):
    """Discard gin bindings made by this test, then run base teardown."""
    gin.clear_config()
    parent = super(TrainEvalTest, self)
    parent.tearDown()
Example #19
Source File: runner_lib_test.py    From ml-fairness-gym with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then start from an empty gin configuration."""
    parent = super(RunnerLibTest, self)
    parent.setUp()
    gin.clear_config()
Example #20
Source File: reformer_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then start from an empty gin configuration."""
    base = super()
    base.setUp()
    gin.clear_config()
Example #21
Source File: reformer_oom_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then start from an empty gin configuration."""
    base = super()
    base.setUp()
    gin.clear_config()
Example #22
Source File: reformer_e2e_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Reset gin and register _CONFIG_DIR as a gin config search path."""
    base = super()
    base.setUp()
    gin.clear_config()
    # Lets tests refer to config files relative to _CONFIG_DIR.
    gin.add_config_file_search_path(_CONFIG_DIR)
Example #23
Source File: tf_inputs_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def test_get_t5_preprocessor_by_name(self):
    """A gin-configured 'rekey' preprocessor maps other->inputs, text->targets."""
    gin.clear_config()

    gin.parse_config("""
      get_t5_preprocessor_by_name.name = 'rekey'
      get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'other', 'targets': 'text'}}
    """)
    rekey_fn = tf_inputs.get_t5_preprocessor_by_name()

    source = tf.data.Dataset.from_tensors(
        {'text': 'That is good.', 'other': 'That is bad.'})
    # Second positional argument is the `training` flag.
    rekeyed = rekey_fn(source, True)

    expected = {'inputs': 'That is bad.', 'targets': 'That is good.'}
    t5_test_utils.assert_dataset(rekeyed, expected)
Example #24
Source File: tf_inputs_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run base setup, then start from an empty gin configuration."""
    base = super()
    base.setUp()
    gin.clear_config()
Example #25
Source File: backend_test.py    From BERT with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Run the base-class setUp, then start the test from an empty gin config."""
    # Call the parent setUp so whatever fixture logic the base TestCase
    # defines still runs before this test's own preparation.
    super().setUp()
    gin.clear_config()
Example #26
Source File: aicrowd_utils.py    From disentanglement-pytorch with GNU General Public License v3.0 4 votes vote down vote up
def evaluate_disentanglement_metric(model, metric_names=('mig',), dataset_name='mpi3d_toy'):
    """Score ``model`` with one or more disentanglement metrics.

    For each name in ``metric_names``, the matching gin evaluation config is
    looked up with ``get_gin_config``. Candidates are disentanglement_lib's
    ``UnsupervisedStudyV1`` eval configs plus
    ``$PWD/extra_metrics_configs/irs.gin``. The config is then parsed together
    with ``evaluation.random_seed = 0`` and ``evaluation.name = '<metric>'``.
    The model's encoder is exported to ``<ckpt_dir>/pytorch_model.pt`` and
    ``aicrowd.evaluate.evaluate`` is run on ``model.ckpt_dir``. The global gin
    configuration is cleared after every metric, including when parsing,
    export or evaluation raises.

    Args:
        model: Object providing ``ckpt_dir``, ``model.encoder``,
            ``num_channels`` and ``image_size`` (the attributes read below).
        metric_names: Iterable of metric names to evaluate. Any iterable,
            such as a list, is accepted; the default is an immutable tuple.
        dataset_name: Not used by this function.

    Returns:
        A dict mapping ``'eval_<metric_name>'`` to the last value in the
        evaluation results whose key is not ``elapsed_time``, ``uuid`` or
        ``num_active_dims`` (0 when there is no such value). If a metric has
        no matching config, a warning is logged and ``0`` is returned instead.
    """
    # These imports are included only inside this function for code base to run on systems without
    # proper installation of tensorflow and libcublas
    from aicrowd import utils_pytorch
    from aicrowd.evaluate import evaluate
    from disentanglement_lib.config.unsupervised_study_v1 import sweep as unsupervised_study_v1

    # Result entries that are bookkeeping rather than the metric score.
    ignored_keys = {'elapsed_time', 'uuid', 'num_active_dims'}

    study = unsupervised_study_v1.UnsupervisedStudyV1()
    evaluation_configs = sorted(study.get_eval_config_files())
    evaluation_configs.append(os.path.join(os.getenv("PWD", ""), "extra_metrics_configs/irs.gin"))

    results_dict_all = {}
    for metric_name in metric_names:
        eval_bindings = [
            "evaluation.random_seed = {}".format(0),
            "evaluation.name = '{}'".format(metric_name)
        ]

        # Get the correct config file for this metric.
        my_config = get_gin_config(evaluation_configs, metric_name)
        if my_config is None:
            logging.warning('metric %s not among available configs: %s', metric_name, evaluation_configs)
            # NOTE(review): returning 0 discards scores already collected for
            # earlier metrics and is not a dict; kept for compatibility with
            # existing callers -- confirm whether skipping is intended.
            return 0

        try:
            gin.parse_config_files_and_bindings([my_config], eval_bindings)

            model_path = os.path.join(model.ckpt_dir, 'pytorch_model.pt')
            utils_pytorch.export_model(utils_pytorch.RepresentationExtractor(model.model.encoder, 'mean'),
                                       input_shape=(1, model.num_channels, model.image_size, model.image_size),
                                       path=model_path)

            output_dir = os.path.join(model.ckpt_dir, 'eval_results', metric_name)
            os.makedirs(os.path.join(model.ckpt_dir, 'results'), exist_ok=True)

            results_dict = evaluate(model.ckpt_dir, output_dir, True)
        finally:
            # Gin bindings are process-global: always clear them so a failure
            # here cannot leak this metric's config into later code.
            gin.clear_config()

        results = 0
        for key, value in results_dict.items():
            if key not in ignored_keys:
                results = value
        logging.info('Evaluation   %s=%s', metric_name, results)
        results_dict_all['eval_{}'.format(metric_name)] = results
    return results_dict_all