import os
import gym
import numpy as np
from gym.spaces.box import Box
from gym import wrappers, logger
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from .wrappers import VisdomMonitor, ProcessObservationWrapper, SkipWrapper, SensorEnvWrapper
from .distributed_factory import DistributedEnv
# Gibson envs are imported lazily inside EnvFactory.call_to_run (the gibson
# module may not be installed):
# from evkit.env.gibson.gibsonenv import GibsonEnv, DummyGibsonEnv
from evkit.env.habitat.habitatenv import make_habitat_vector_env

try:
    from evkit.env.vizdoom import *
except ImportError:
    pass
try:
    import dm_control2gym
except ImportError:
    pass
try:
    import roboschool
except ImportError:
    pass
try:
    import pybullet_envs
except ImportError:
    pass
import torch

DEFAULT_SENSOR_NAME = "DEFAULT"


class EnvFactory(object):
    """Builds single or vectorized environments from an ``env_id`` string."""

    @staticmethod
    def vectorized(env_id, seed, num_processes, log_dir, add_timestep,
                   sensors={DEFAULT_SENSOR_NAME: None},  # read-only, never mutated
                   addl_repeat_count=0, preprocessing_fn=None,
                   env_specific_kwargs=None,
                   vis_interval=20, visdom_name='main', visdom_log_file=None,
                   visdom_server='localhost', visdom_port='8097',
                   num_val_processes=0, gae_gamma=None):
        '''Returns vectorized environment.

        Either the simulator implements this (habitat) or 'vectorized' uses the
        call_to_run helper to build one environment thunk per process.

        Args:
            env_id: '<simulator>_<scenario>' style id. Ids starting with
                'habitat' (case-insensitive) before the first '_' are
                vectorized by make_habitat_vector_env; everything else goes
                through call_to_run.
            num_processes: total number of environments (train + val).
            num_val_processes: how many of those (the highest ranks) are
                validation environments.
            env_specific_kwargs: extra keyword arguments for the simulator
                (None means no extra arguments).
            gae_gamma: forwarded to DistributedEnv.new (multi-process,
                non-habitat only).
            Remaining arguments are forwarded as described in call_to_run.

        Returns:
            make_habitat_vector_env(...) for habitat ids; otherwise a
            DummyVecEnv when num_processes == 1, else DistributedEnv.new(...).
        '''
        if env_specific_kwargs is None:
            env_specific_kwargs = {}
        # partition (not split) so ids without exactly one '_' do not raise.
        simulator, _, _ = env_id.partition('_')
        if simulator.lower() in ['habitat']:
            # These simulators internally handle vectorization/distribution
            env = make_habitat_vector_env(
                num_processes=num_processes,
                preprocessing_fn=preprocessing_fn,
                log_dir=log_dir,
                num_val_processes=num_val_processes,
                vis_interval=vis_interval,
                visdom_name=visdom_name,
                visdom_log_file=visdom_log_file,
                visdom_server=visdom_server,
                visdom_port=visdom_port,
                seed=seed,
                **env_specific_kwargs)
        else:
            # These simulators must be manually vectorized
            envs = [
                EnvFactory.call_to_run(
                    env_id, seed, rank, log_dir, add_timestep,
                    sensors=sensors,
                    addl_repeat_count=addl_repeat_count,
                    preprocessing_fn=preprocessing_fn,
                    env_specific_kwargs=env_specific_kwargs,
                    vis_interval=vis_interval,
                    visdom_name=visdom_name,
                    visdom_log_file=visdom_log_file,
                    visdom_server=visdom_server,
                    visdom_port=visdom_port,
                    num_val_processes=num_val_processes,
                    num_processes=num_processes)
                for rank in range(num_processes)
            ]
            if num_processes == 1:
                # Imported lazily: only the single-process path needs it.
                from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
                env = DummyVecEnv(envs)
            else:
                env = DistributedEnv.new(
                    envs,
                    gae_gamma=gae_gamma,
                    distribution_method=DistributedEnv.distribution_schemes.vectorize)
        return env

    @staticmethod
    def call_to_run(env_id, seed, rank, log_dir, add_timestep,
                    sensors={DEFAULT_SENSOR_NAME: None},  # read-only, never mutated
                    addl_repeat_count=0, preprocessing_fn=None,
                    gibson_config=None, blank_sensor=False,
                    start_locations_file=None, target_dim=16, blind=False,
                    env_specific_kwargs=None,
                    vis_interval=20, visdom_name='main', visdom_log_file=None,
                    visdom_server='localhost', visdom_port='8097',
                    num_val_processes=0, num_processes=1):
        '''Returns a function which can be called to instantiate a new environment.

        Args:
            env_id: Name of the ID to make
            seed: random seed for environment (the env is seeded with seed + rank)
            rank: environment number (i of k)
            log_dir: directory to log to (None disables the monitor wrappers)
            add_timestep: currently unsupported; if True, instantiating a
                non-habitat environment raises NotImplementedError.
            sensors: A configuration of sensor names -> specs (for now, just none).
                None skips the SensorEnvWrapper.
            addl_repeat_count: number of additional times each action is repeated
            preprocessing_fn(env): function which returns (transform, obs_shape)
                transform(obs): a function that is run on every obs
                obs_shape: the final shape of transform(obs)
            gibson_config: If using gibson, which config to use
            env_specific_kwargs: extra keyword arguments for the simulator.
                The dict is copied, never modified.
            visdom_name: If using visdom, what to name the visdom environment
            visdom_log_file: Where to store visdom logging entries. This allows
                replaying training back to visdom. If this is set to none,
                then disable visdom logging.
            visdom_server: visdom server ip (http:// is automatically appended)
            visdom_port: Which port the visdom server is listening on
            num_val_processes: the last num_val_processes ranks are validation envs
            num_processes: total number of environments (train + val)

        Returns:
            A callable function (no parameters) which instantiates an environment.
        '''
        # partition (not split) so ids without exactly one '_' do not raise.
        _, _, scenario = env_id.partition('_')
        if env_specific_kwargs is None:
            env_specific_kwargs = {}

        def _thunk():
            preprocessing_fn_implemented_inside_env = False
            logging_implemented_inside_env = False
            already_distributed = False
            # Per-thunk copy: the Doom branch writes keys into it, and writing
            # into the caller's dict would leak them into other envs/calls.
            kwargs = dict(env_specific_kwargs)

            if env_id.startswith("dm"):
                _, domain, task = env_id.split('.')
                env = dm_control2gym.make(domain_name=domain, task_name=task)
            elif env_id.startswith("Gibson"):
                from evkit.env.gibson.gibsonenv import GibsonEnv
                env = GibsonEnv(env_id=env_id,
                                gibson_config=gibson_config,
                                blind=blind,
                                blank_sensor=blank_sensor,
                                start_locations_file=start_locations_file,
                                target_dim=target_dim,
                                **kwargs)
            elif env_id.startswith("DummyGibson"):
                from evkit.env.gibson.gibsonenv import DummyGibsonEnv
                env = DummyGibsonEnv(env_id=env_id,
                                     gibson_config=gibson_config,
                                     blind=blind,
                                     blank_sensor=blank_sensor,
                                     start_locations_file=start_locations_file,
                                     target_dim=target_dim,
                                     **kwargs)
            elif env_id.startswith("Doom"):
                kwargs['repeat_count'] = addl_repeat_count + 1
                num_train_processes = num_processes - num_val_processes
                # 1 (train only), 2 test only
                kwargs['randomize_textures'] = 1 if rank < num_train_processes else 2
                # NOTE(review): eval() of a name derived from env_id; only safe
                # while env_id comes from trusted configuration.
                vizdoom_class = eval(scenario.split('.')[0])
                env = vizdoom_class(**kwargs)
            elif env_id.startswith("Habitat"):
                # NOTE(review): passes `rank` as `num_processes` -- confirm intended.
                env = make_habitat_vector_env(
                    num_processes=rank,
                    target_dim=target_dim,
                    preprocessing_fn=preprocessing_fn,
                    log_dir=log_dir,
                    num_val_processes=num_val_processes,
                    visdom_name=visdom_name,
                    visdom_log_file=visdom_log_file,
                    visdom_server=visdom_server,
                    visdom_port=visdom_port,
                    seed=seed,
                    **kwargs)
                already_distributed = True
                preprocessing_fn_implemented_inside_env = True
                logging_implemented_inside_env = True
            else:
                env = gym.make(env_id)

            if already_distributed:
                # Env is now responsible for logging, preprocessing, repeat_count
                return env

            is_atari = hasattr(gym.envs, 'atari') and isinstance(
                env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
            if is_atari:
                env = make_atari(env_id)

            if add_timestep:
                # AddTimestep cannot yet wrap SensorDict observations.
                raise NotImplementedError("AddTimestep not implemented for SensorDict")

            if not (logging_implemented_inside_env or log_dir is None):
                os.makedirs(os.path.join(log_dir, visdom_name), exist_ok=True)
                print("Visdom log file", visdom_log_file)
                first_val_process = num_processes - num_val_processes
                # Only the first train env and the first val env log to visdom.
                if (rank == 0 or rank == first_val_process) and visdom_log_file is not None:
                    env = VisdomMonitor(env,
                                        directory=os.path.join(log_dir, visdom_name),
                                        video_callable=lambda x: x % vis_interval == 0,
                                        uid=str(rank),
                                        server=visdom_server,
                                        port=visdom_port,
                                        visdom_log_file=visdom_log_file,
                                        visdom_env=visdom_name)
                else:
                    print("Not using visdom")
                    env = wrappers.Monitor(env,
                                           directory=os.path.join(log_dir, visdom_name),
                                           uid=str(rank))

            if is_atari:
                env = wrap_deepmind(env)

            if addl_repeat_count > 0:
                # Skip wrapping when the env already implements action repeat itself.
                if not hasattr(env, 'repeat_count') and not hasattr(env.unwrapped, 'repeat_count'):
                    env = SkipWrapper(addl_repeat_count)(env)

            if sensors is not None:
                if hasattr(env, 'is_embodied') or hasattr(env.unwrapped, 'is_embodied'):
                    pass
                else:
                    assert len(sensors) == 1, 'Can only handle one sensor'
                    sensor_name = list(sensors.keys())[0]
                    env = SensorEnvWrapper(env, name=sensor_name)

            if not (preprocessing_fn_implemented_inside_env or preprocessing_fn is None):
                transform, space = preprocessing_fn(env.observation_space)
                env = ProcessObservationWrapper(env, transform, space)

            env.seed(seed + rank)
            return env

        return _thunk


class AddTimestep(gym.ObservationWrapper):
    """Appends the wrapped env's ``_elapsed_steps`` to each 1-D observation."""

    def __init__(self, env=None):
        super(AddTimestep, self).__init__(env)
        # Bounds are taken from the first element of the original space.
        self.observation_space = Box(
            self.observation_space.low[0],
            self.observation_space.high[0],
            [self.observation_space.shape[0] + 1],
            dtype=self.observation_space.dtype)

    def observation(self, observation):
        return np.concatenate((observation, [self.env._elapsed_steps]))


class WrapPyTorch(gym.ObservationWrapper):
    """Transposes observations from (d0, d1, d2) to (d2, d0, d1), i.e. channels first."""

    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        obs_shape = self.observation_space.shape
        # Bounds are taken from the [0, 0, 0] element of the original space.
        self.observation_space = Box(
            self.observation_space.low[0, 0, 0],
            self.observation_space.high[0, 0, 0],
            [obs_shape[2], obs_shape[0], obs_shape[1]],
            dtype=self.observation_space.dtype)

    def observation(self, observation):
        return observation.transpose(2, 0, 1)