r"""Running a pre-trained ppo agent on rex environments"""
import logging
import os
import site
import time

import tensorflow.compat.v1 as tf

from rex_gym.agents.scripts import utility
from rex_gym.agents.ppo import simple_ppo_agent
from rex_gym.util import action_mapper


class PolicyPlayer:
    """Load a stored PPO policy for an environment id and run one episode."""

    def __init__(self, env_id: str, args: dict) -> None:
        """Remember which environment to play and how to configure it.

        Args:
            env_id: Key used to index ``action_mapper.ENV_ID_TO_POLICY``.
            args: Extra keyword arguments forwarded to the environment
                constructor when ``play`` is called.
        """
        # Policy files are resolved relative to the first site-packages
        # directory reported by the interpreter.
        self.gym_dir_path = str(site.getsitepackages()[0])
        self.env_id = env_id
        self.args = args

    def play(self) -> None:
        """Run the pre-trained policy until the environment reports ``done``.

        The cumulative reward is logged at INFO level after every step.

        Raises:
            KeyError: If ``env_id`` is not present in
                ``action_mapper.ENV_ID_TO_POLICY``.
        """
        # Look the mapping entry up once: element 0 is joined onto the
        # site-packages path to form the policy directory, element 1 is joined
        # onto that directory to form the checkpoint path.
        policy_entry = action_mapper.ENV_ID_TO_POLICY[self.env_id]
        policy_dir = os.path.join(self.gym_dir_path, policy_entry[0])
        checkpoint = os.path.join(policy_dir, policy_entry[1])

        config = utility.load_config(policy_dir)
        policy_layers = config.policy_layers
        value_layers = config.value_layers
        env = config.env(render=True, **self.args)
        network = config.network

        # NOTE(review): the environment is never closed explicitly here;
        # confirm whether it exposes close() and whether callers expect the
        # rendered view to stay open once the episode ends.
        with tf.Session() as sess:
            agent = simple_ppo_agent.SimplePPOPolicy(
                sess,
                env,
                network,
                policy_layers=policy_layers,
                value_layers=value_layers,
                checkpoint=checkpoint,
            )
            sum_reward = 0
            observation = env.reset()
            while True:
                # The observation is passed as a one-element list and the
                # first returned action is applied (presumably a batch of
                # one -- confirm against SimplePPOPolicy).
                action = agent.get_action([observation])
                observation, reward, done, _ = env.step(action[0])
                # Short pause between steps (presumably to pace rendering).
                time.sleep(0.002)
                sum_reward += reward
                # Lazy %-formatting: the message is only built when INFO
                # logging is enabled.
                logging.info("Reward=%s", sum_reward)
                if done:
                    break