# coding=utf-8
# Copyright 2020 The Tensor2Robot Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as python3
"""Integration tests for training pose_env models."""

import os

from absl.testing import absltest
from absl.testing import parameterized
import gin
from tensor2robot.input_generators import default_input_generator
from tensor2robot.meta_learning import meta_policies
from tensor2robot.meta_learning import preprocessors
from tensor2robot.predictors import checkpoint_predictor
from tensor2robot.research.pose_env import pose_env
from tensor2robot.research.pose_env import pose_env_maml_models
from tensor2robot.research.pose_env import pose_env_models
from tensor2robot.utils import train_eval
from tensor2robot.utils import train_eval_test_utils
import tensorflow.compat.v1 as tf  # tf

# Batch size used by every input generator built in setUp.
BATCH_SIZE = 1
# Step counts forwarded to train_eval_test_utils.test_train_eval_gin only.
# The direct train_eval_model tests instead use the gin bindings made in
# setUp (3 train steps, 2 eval steps).
MAX_TRAIN_STEPS = 1
EVAL_STEPS = 1
# Not referenced by any test in this file.
NUM_TRAIN_SAMPLES_PER_TASK = 1
NUM_VAL_SAMPLES_PER_TASK = 1

FLAGS = tf.app.flags.FLAGS


class PoseEnvModelsTest(parameterized.TestCase):
  """Smoke tests that train/evaluate pose_env models and drive a policy."""

  def setUp(self):
    """Prepares a clean model dir, gin bindings and input generators."""
    super().setUp()
    base_dir = 'tensor2robot'
    test_data = os.path.join(FLAGS.test_srcdir, base_dir,
                             'test_data/pose_env_test_data.tfrecord')
    # Each test gets a freshly emptied model directory so checkpoints from a
    # previous test (possibly a different model class) are never restored.
    self._train_log_dir = FLAGS.test_tmpdir
    if tf.io.gfile.exists(self._train_log_dir):
      tf.io.gfile.rmtree(self._train_log_dir)
    # Keep the direct train_eval_model runs short.
    gin.bind_parameter('train_eval_model.max_train_steps', 3)
    gin.bind_parameter('train_eval_model.eval_steps', 2)

    # Reads the recorded pose_env examples from the test tfrecord.
    self._record_input_generator = (
        default_input_generator.DefaultRecordInputGenerator(
            batch_size=BATCH_SIZE, file_patterns=test_data))
    # Despite the "record" in their names, the meta-learning generators
    # produce random inputs (DefaultRandomInputGenerator).
    self._meta_record_input_generator_train = (
        default_input_generator.DefaultRandomInputGenerator(
            batch_size=BATCH_SIZE))
    self._meta_record_input_generator_eval = (
        default_input_generator.DefaultRandomInputGenerator(
            batch_size=BATCH_SIZE))

  def tearDown(self):
    """Drops gin bindings so they cannot leak into the next test case."""
    # setUp binds parameters and test_train_eval_gin parses gin files; without
    # clearing, bindings from one case would persist into the following one.
    gin.clear_config()
    super().tearDown()

  def _train_eval(self, t2r_model, input_generator_train,
                  input_generator_eval):
    """Trains and evaluates `t2r_model` inside this test's model directory.

    `model_dir` is passed explicitly so that the run writes into the
    per-test directory emptied by setUp, rather than into train_eval_model's
    default directory, which may be shared across tests and outside the
    test tmpdir. NOTE(review): assumes train_eval_model accepts a `model_dir`
    keyword argument; confirm this against utils/train_eval.py.

    Args:
      t2r_model: The model to train and evaluate.
      input_generator_train: Input generator used for training.
      input_generator_eval: Input generator used for evaluation.
    """
    train_eval.train_eval_model(
        t2r_model=t2r_model,
        input_generator_train=input_generator_train,
        input_generator_eval=input_generator_eval,
        model_dir=self._train_log_dir,
        create_exporters_fn=None)

  def test_mc(self):
    self._train_eval(
        t2r_model=pose_env_models.PoseEnvContinuousMCModel(),
        input_generator_train=self._record_input_generator,
        input_generator_eval=self._record_input_generator)

  def test_regression(self):
    self._train_eval(
        t2r_model=pose_env_models.PoseEnvRegressionModel(),
        input_generator_train=self._record_input_generator,
        input_generator_eval=self._record_input_generator)

  def test_regression_maml(self):
    maml_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
        base_model=pose_env_models.PoseEnvRegressionModel())
    self._train_eval(
        t2r_model=maml_model,
        input_generator_train=self._meta_record_input_generator_train,
        input_generator_eval=self._meta_record_input_generator_eval)

  def _test_policy_interface(self, policy, restore=True):
    """Runs one act/adapt/act cycle of `policy` in a PoseToyEnv.

    Args:
      policy: A meta policy exposing restore, reset_task, SelectAction and
        adapt.
      restore: If True, calls `policy.restore()` before acting.
    """
    urdf_root = pose_env.get_pybullet_urdf_root()
    self.assertTrue(os.path.exists(urdf_root))
    env = pose_env.PoseToyEnv(urdf_root=urdf_root, render_mode='DIRECT')
    env.reset_task()
    obs = env.reset()
    if restore:
      policy.restore()
    policy.reset_task()
    # The third SelectAction argument is presumably the timestep
    # (0 here, 1 below); confirm against the policy API.
    action = policy.SelectAction(obs, None, 0)

    new_obs, rew, done, env_debug = env.step(action)
    # One episode containing one transition tuple.
    episode_data = [[(obs, action, rew, new_obs, done, env_debug)]]
    policy.adapt(episode_data)
    policy.SelectAction(new_obs, None, 1)

  def test_regression_maml_policy_interface(self):
    t2r_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
        base_model=pose_env_models.PoseEnvRegressionModel(),
        preprocessor_cls=preprocessors.FixedLenMetaExamplePreprocessor)
    predictor = checkpoint_predictor.CheckpointPredictor(t2r_model=t2r_model)
    # Random weights: no checkpoint is restored (restore=False below).
    predictor.init_randomly()
    policy = meta_policies.MAMLRegressionPolicy(t2r_model, predictor=predictor)
    self._test_policy_interface(policy, restore=False)

  @parameterized.parameters(
      ('run_train_reg_maml.gin',),
      ('run_train_reg.gin',))
  def test_train_eval_gin(self, gin_file):
    """Runs train/eval as configured by a pose_env gin config file."""
    base_dir = 'tensor2robot'
    full_gin_path = os.path.join(
        FLAGS.test_srcdir, base_dir, 'research/pose_env/configs', gin_file)
    model_dir = os.path.join(FLAGS.test_tmpdir, 'test_train_eval_gin', gin_file)
    train_eval_test_utils.test_train_eval_gin(
        test_case=self,
        model_dir=model_dir,
        full_gin_path=full_gin_path,
        max_train_steps=MAX_TRAIN_STEPS,
        eval_steps=EVAL_STEPS)


if __name__ == '__main__':
  absltest.main()