Python gym.Env() Examples

The following are 30 code examples of gym.Env(), collected from open-source projects. Each example notes its original project, source file, and license. You may also want to check out the other available functions and classes of the gym module.
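Before diving into the examples, here is a minimal sketch of what a gym.Env subclass can look like. It is purely illustrative (the environment name and reward logic are made up, not taken from any project below) and follows the classic gym API in which reset() returns an observation and step() returns a 4-tuple:

import gym
import numpy as np
from gym import spaces


class SignGuessEnv(gym.Env):
    """Toy single-step environment: guess whether the observation is positive."""

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)
        self._state = np.zeros(1, dtype=np.float32)

    def reset(self):
        self._state = self.observation_space.sample()
        return self._state

    def step(self, action):
        # reward 1.0 if the action matches the sign of the observation, else 0.0
        reward = float(action == int(self._state[0] > 0))
        return self._state, reward, True, {}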
Example #1
Source File: reward_function.py    From irl-benchmark with GNU General Public License v3.0
def __init__(self,
                 env: gym.Env,
                 parameters: Union[None, str, np.ndarray] = None,
                 action_in_domain: bool = False,
                 next_state_in_domain: bool = False):
        """

        Parameters
        ----------
        env: gym.Env
            A gym environment. The environment has to be wrapped in a FeatureWrapper.
        parameters: Union[None, str, np.ndarray]
            The parameters of the reward function. One parameter for each feature.
            If value is 'random', initializes with random parameters (mean 0, standard deviation 1).
        """

        assert utils.wrapper.is_unwrappable_to(env, FeatureWrapper)
        super(FeatureBasedRewardFunction, self).__init__(
            env, parameters, action_in_domain, next_state_in_domain)

        if parameters == 'random':
            parameters_shape = utils.wrapper.unwrap_env(
                self.env, FeatureWrapper).feature_dimensionality()
            self.parameters = np.random.standard_normal(parameters_shape) 
Example #2
Source File: cmd_util.py    From HardRLWithYoutube with MIT License
def make_mujoco_env(env_id, seed, reward_scale=1.0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    myseed = seed + 1000 * rank if seed is not None else None
    set_global_seeds(myseed)
    env = gym.make(env_id)
    logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
    env = Monitor(env, logger_path, allow_early_resets=True)
    env.seed(seed)

    if reward_scale != 1.0:
        from baselines.common.retro_wrappers import RewardScaler
        env = RewardScaler(env, reward_scale)

    return env 
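A hedged usage sketch of the helper above, assuming baselines, mpi4py and MuJoCo are installed and the function is importable; the environment id and reward scale are arbitrary:

env = make_mujoco_env('HalfCheetah-v2', seed=0, reward_scale=0.1)
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())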
Example #3
Source File: rl2.py    From garage with MIT License
def step(self, action):
        """gym.Env step function.

        Args:
            action (int): action taken.

        Returns:
            np.ndarray: augmented observation.
            float: reward.
            bool: terminal signal.
            dict: environment info.

        """
        next_obs, reward, done, info = self.env.step(action)
        next_obs = np.concatenate([next_obs, action, [reward], [done]])
        return next_obs, reward, done, info 
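The concatenation above assumes the action is already array-like (for example one-hot encoded). A standalone numpy sketch of the same augmentation, with made-up shapes:

import numpy as np

next_obs = np.array([0.1, -0.2, 0.3, 0.0])  # plain observation, shape (4,)
action = np.array([0.0, 1.0])               # one-hot action, shape (2,)
reward, done = 1.0, False

# the observation is augmented with the action, the reward and the terminal flag
augmented = np.concatenate([next_obs, action, [reward], [done]])
assert augmented.shape == (4 + 2 + 1 + 1,)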
Example #4
Source File: misc_util.py    From HardRLWithYoutube with MIT License
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname) 
Example #5
Source File: misc_util.py    From Reinforcement_Learning_for_Traffic_Light_Control with Apache License 2.0
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname) 
Example #6
Source File: misc_util.py    From Reinforcement_Learning_for_Traffic_Light_Control with Apache License 2.0
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname) 
Example #7
Source File: misc_util.py    From Reinforcement_Learning_for_Traffic_Light_Control with Apache License 2.0
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname) 
Example #8
Source File: misc_util.py    From lirpg with MIT License
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname) 
Example #9
Source File: runners.py    From stable-baselines with MIT License
def __init__(self, *, env: Union[gym.Env, VecEnv], model: 'BaseRLModel', n_steps: int):
        """
        Collect experience by running `n_steps` in the environment.
        Note: if this is a `VecEnv`, the total number of steps will
        be `n_steps * n_envs`.

        :param env: (Union[gym.Env, VecEnv]) The environment to learn from
        :param model: (BaseRLModel) The model to learn
        :param n_steps: (int) The number of steps to run for each environment
        """
        self.env = env
        self.model = model
        n_envs = env.num_envs
        self.batch_ob_shape = (n_envs * n_steps,) + env.observation_space.shape
        self.obs = np.zeros((n_envs,) + env.observation_space.shape, dtype=env.observation_space.dtype.name)
        self.obs[:] = env.reset()
        self.n_steps = n_steps
        self.states = model.initial_state
        self.dones = [False for _ in range(n_envs)]
        self.callback = None  # type: Optional[BaseCallback]
        self.continue_training = True
        self.n_envs = n_envs 
Example #10
Source File: misc_util.py    From rl_graph_generation with BSD 3-Clause "New" or "Revised" License
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname) 
Example #11
Source File: cmd_util.py    From stable-baselines with MIT License
def make_robotics_env(env_id, seed, rank=0, allow_early_resets=True):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.

    :param env_id: (str) the environment ID
    :param seed: (int) the initial seed for RNG
    :param rank: (int) the rank of the environment (for logging)
    :param allow_early_resets: (bool) allows early reset of the environment
    :return: (Gym Environment) The robotic environment
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    keys = ['observation', 'desired_goal']
    # TODO: remove try-except once most users are running modern Gym
    try:  # for modern Gym (>=0.15.4)
        from gym.wrappers import FilterObservation, FlattenObservation
        env = FlattenObservation(FilterObservation(env, keys))
    except ImportError:  # for older gym (<=0.15.3)
        from gym.wrappers import FlattenDictWrapper  # pytype:disable=import-error
        env = FlattenDictWrapper(env, keys)
    env = Monitor(
        env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
        info_keywords=('is_success',), allow_early_resets=allow_early_resets)
    env.seed(seed)
    return env 
Example #12
Source File: base.py    From nevergrad with MIT License
def run(self, *agent: Agent, **agents: Agent) -> Union[float, Dict[str, float]]:
        """Run one agent or multiple named agents

        Parameters
        ----------
        *agent: Agent (optional)
            the agent to play a single-agent environment
        **agents: Agent
            the named agents to play a multi-agent environment

        Returns
        -------
        float:
            the mean reward (possibly for each agent)
        """
        san = "single_agent_name"
        sum_rewards: Dict[str, float] = {name: 0.0 for name in agents} if agents else {san: 0.0}
        for _ in range(self.num_repetitions):
            rewards = self._run_once(*agent, **agents)
            for name, value in rewards.items():
                sum_rewards[name] += value
        mean_rewards = {name: float(value) / self.num_repetitions for name, value in sum_rewards.items()}
        if isinstance(self.env, gym.Env):
            return mean_rewards[san]
        return mean_rewards 
Example #13
Source File: __init__.py    From stable-baselines with MIT License
def sync_envs_normalization(env: Union[gym.Env, VecEnv], eval_env: Union[gym.Env, VecEnv]) -> None:
    """
    Sync eval and train environments when using VecNormalize

    :param env: (Union[gym.Env, VecEnv])
    :param eval_env: (Union[gym.Env, VecEnv])
    """
    env_tmp, eval_env_tmp = env, eval_env
    # Special case for the _UnvecWrapper
    # Avoid circular import
    from stable_baselines.common.base_class import _UnvecWrapper
    if isinstance(env_tmp, _UnvecWrapper):
        return
    while isinstance(env_tmp, VecEnvWrapper):
        if isinstance(env_tmp, VecNormalize):
            # sync reward and observation scaling
            eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
            eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
        env_tmp = env_tmp.venv
        # Make pytype happy, in theory env and eval_env have the same type
        assert isinstance(eval_env_tmp, VecEnvWrapper), "the second env differs from the first env"
        eval_env_tmp = eval_env_tmp.venv 
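A rough usage sketch, assuming stable-baselines 2.x and that sync_envs_normalization above is importable; the environment id is arbitrary:

import gym
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize

train_env = VecNormalize(DummyVecEnv([lambda: gym.make('CartPole-v1')]))
eval_env = VecNormalize(DummyVecEnv([lambda: gym.make('CartPole-v1')]), training=False)

# after (or during) training, copy the running observation/return statistics
# from the training env to the evaluation env before evaluating:
sync_envs_normalization(train_env, eval_env)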
Example #14
Source File: callbacks.py    From stable-baselines with MIT License
def __init__(self, verbose: int = 0):
        super(BaseCallback, self).__init__()
        # The RL model
        self.model = None  # type: Optional[BaseRLModel]
        # An alias for self.model.get_env(), the environment used for training
        self.training_env = None  # type: Union[gym.Env, VecEnv, None]
        # Number of times the callback was called
        self.n_calls = 0  # type: int
        # n_envs * n times env.step() was called
        self.num_timesteps = 0  # type: int
        self.verbose = verbose
        self.locals = None  # type: Optional[Dict[str, Any]]
        self.globals = None  # type: Optional[Dict[str, Any]]
        self.logger = None  # type: Optional[logger.Logger]
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        self.parent = None  # type: Optional[BaseCallback]

    # Type hint as string to avoid circular import 
Example #15
Source File: envs.py    From imitation with MIT License
def test_model_based(env: gym.Env) -> None:
    """Smoke test for each of the ModelBasedEnv methods with type checks.

    Raises:
        AssertionError if test fails.
    """
    state = env.initial_state()
    assert env.state_space.contains(state)

    action = env.action_space.sample()
    new_state = env.transition(state, action)
    assert env.state_space.contains(new_state)

    reward = env.reward(state, action, new_state)
    assert isinstance(reward, float)

    done = env.terminal(state, 0)
    assert isinstance(done, bool)

    obs = env.obs_from_state(state)
    assert env.observation_space.contains(obs)
    next_obs = env.obs_from_state(new_state)
    assert env.observation_space.contains(next_obs) 
Example #16
Source File: util.py    From imitation with MIT License
def init_rl(
    env: Union[gym.Env, VecEnv],
    model_class: Type[BaseRLModel] = stable_baselines.PPO2,
    policy_class: Type[BasePolicy] = MlpPolicy,
    **model_kwargs,
):
    """Instantiates a policy for the provided environment.

    Args:
        env: The (vector) environment.
        model_class: A Stable Baselines RL algorithm.
        policy_class: A Stable Baselines compatible policy network class.
        model_kwargs (dict): kwargs passed through to the algorithm.
          Note: anything specified in `policy_kwargs` is passed through by the
          algorithm to the policy network.

    Returns:
      An RL algorithm.
    """
    return model_class(
        policy_class, env, **model_kwargs
    )  # pytype: disable=not-instantiable 
Example #17
Source File: __init__.py    From irl-benchmark with GNU General Public License v3.0
def make_env(env_id: str):
    """Make a basic gym environment, without any special wrappers.

    Parameters
    ----------
    env_id: str
        The environment's id, e.g. 'FrozenLake-v0'.
    Returns
    -------
    gym.Env
        A gym environment.
    """
    assert env_id in ENV_IDS
    if env_id not in ENV_IDS_NON_GYM:
        env = gym.make(env_id)
    else:
        if env_id == 'MazeWorld0-v0':
            env = TimeLimit(MazeWorld(map_id=0), max_episode_steps=200)
        elif env_id == 'MazeWorld1-v0':
            env = TimeLimit(MazeWorld(map_id=1), max_episode_steps=200)
        else:
            raise NotImplementedError()
    return env 
Example #18
Source File: utils_wrapper_test.py    From irl-benchmark with GNU General Public License v3.0
def test_is_unwrappable_to():
    assert is_unwrappable_to(make_env('FrozenLake-v0'), TimeLimit)
    assert is_unwrappable_to(make_env('FrozenLake-v0'), DiscreteEnv)
    assert is_unwrappable_to(
        feature_wrapper.make('FrozenLake-v0'), FrozenLakeFeatureWrapper)
    assert is_unwrappable_to(
        feature_wrapper.make('FrozenLake8x8-v0'), FrozenLakeFeatureWrapper)
    assert is_unwrappable_to(
        feature_wrapper.make('FrozenLake-v0'), feature_wrapper.FeatureWrapper)
    env = feature_wrapper.make('FrozenLake-v0')
    reward_function = FeatureBasedRewardFunction(env, 'random')
    env = RewardWrapper(env, reward_function)
    assert is_unwrappable_to(env, RewardWrapper)
    assert is_unwrappable_to(env, feature_wrapper.FeatureWrapper)
    assert is_unwrappable_to(env, DiscreteEnv)
    assert is_unwrappable_to(env, gym.Env) 
Example #19
Source File: wrapper.py    From irl-benchmark with GNU General Public License v3.0
def unwrap_env(env: gym.Env,
               until_class: Union[None, gym.Env] = None) -> gym.Env:
    """Unwrap wrapped env until we get an instance that is a until_class.

    If until_class is None, env will be unwrapped until the lowest layer.
    """
    if until_class is None:
        while hasattr(env, 'env'):
            env = env.env
        return env

    while hasattr(env, 'env') and not isinstance(env, until_class):
        env = env.env

    if not isinstance(env, until_class):
        raise ValueError(
            "Unwrapping env did not yield an instance of class {}".format(
                until_class))
    return env 
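A quick usage sketch, assuming unwrap_env above is in scope and an older gym where FrozenLakeEnv is importable from gym.envs.toy_text.frozen_lake:

from gym.wrappers import TimeLimit
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv

env = TimeLimit(FrozenLakeEnv(), max_episode_steps=100)

# unwrap all the way down to the raw environment:
assert isinstance(unwrap_env(env), FrozenLakeEnv)

# stop as soon as an instance of the requested class is reached:
assert unwrap_env(env, TimeLimit) is env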
Example #20
Source File: reward_function.py    From irl-benchmark with GNU General Public License v3.0
def __init__(self,
                 env: gym.Env,
                 parameters: Union[None, str, np.ndarray] = None,
                 action_in_domain: bool = False,
                 next_state_in_domain: bool = False):
        """ The abstract base class for reward functions

        Parameters
        ----------
        env: gym.Env
            A gym environment for which the reward function is defined.
        parameters: Union[None, str, np.ndarray]
            A numpy ndarray containing the parameters. If value is 'random',
            initializes with random parameters (mean 0, standard deviation 1).
        action_in_domain: bool
            Indicates whether actions are in the domain, i.e. R(s, a) or R(s, a, s')
        next_state_in_domain: bool
            Indicates whether next states are in the domain, i.e. R(s, a, s')
        """
        self.env = env
        self.action_in_domain = action_in_domain
        if next_state_in_domain:
            assert action_in_domain
        self.next_state_in_domain = next_state_in_domain
        self.parameters = parameters 
Example #21
Source File: rl2.py    From garage with MIT License
def reset(self):
        """gym.Env reset function.""" 
Example #22
Source File: gym_utils_test.py    From BERT with Apache License 2.0
def test_gym_registration(self):
    reg_id, env = gym_utils.register_gym_env(
        "tensor2tensor.rl.gym_utils_test:SimpleEnv")

    self.assertEqual("T2TEnv-SimpleEnv-v0", reg_id)

    # Most basic check.
    self.assertTrue(isinstance(env, gym.Env))

    # Just make sure we got the same environment.
    self.assertTrue(
        np.allclose(env.reset(), np.zeros(shape=(3, 3), dtype=np.uint8)))

    _, _, done, _ = env.step(1)
    self.assertTrue(done) 
Example #23
Source File: rl2.py    From garage with MIT License
def reset(self):
        """gym.Env reset function."""
        self._policy._prev_hiddens = self._initial_hiddens 
Example #24
Source File: appr_irl.py    From irl-benchmark with GNU General Public License v3.0
def __init__(self, env: gym.Env, expert_trajs: List[Dict[str, list]],
                 rl_alg_factory: Callable[[gym.Env], BaseRLAlgorithm],
                 metrics: List[BaseMetric], config: dict):
        """

        Parameters
        ----------
        env: gym.Env
            The gym environment to be trained on.
            Needs to be wrapped in a RewardWrapper to prevent leaking the true reward function.
        expert_trajs: List[dict]
            A list of trajectories.
            Each trajectory is a dictionary with keys
            ['states', 'actions', 'rewards', 'true_rewards', 'features'].
            The values of each dictionary are lists.
            See :func:`irl_benchmark.irl.collect.collect_trajs`.
        rl_alg_factory: Callable[[gym.Env], BaseRLAlgorithm]
            A function which returns a new RL algorithm when called.
        config: dict
            A dictionary containing hyper-parameters for the algorithm.
            The fields are:
            * 'gamma': discount factor between 0. and 1.
            * 'epsilon': small positive value, stopping criterion.
            * 'mode': which variant of the algorithm to use, either 'svm' or 'projection'.
        """
        super(ApprIRL, self).__init__(env, expert_trajs, rl_alg_factory,
                                      metrics, config)

        # calculate the feature counts of expert trajectories:
        self.expert_feature_count = self.feature_count(self.expert_trajs,
                                                       self.config['gamma'])

        print('EXPERT FEATURE COUNT:')
        print(self.expert_feature_count)

        # create list of feature counts:
        self.feature_counts = [self.expert_feature_count]
        # for SVM mode: create list of labels:
        self.labels = [1.]

        self.distances = [] 
Example #25
Source File: base_algorithm.py    From irl-benchmark with GNU General Public License v3.0
def __init__(self,
                 env: gym.Env,
                 expert_trajs: List[Dict[str, list]],
                 rl_alg_factory: Callable[[gym.Env], BaseRLAlgorithm],
                 metrics: List[BaseMetric] = None,
                 config: Union[dict, None] = None):
        """

        Parameters
        ----------
        env: gym.Env
            The gym environment to be trained on.
            Needs to be wrapped in a RewardWrapper to prevent leaking the true reward function.
        expert_trajs: List[dict]
            A list of trajectories.
            Each trajectory is a dictionary with keys
            ['states', 'actions', 'rewards', 'true_rewards', 'features'].
            The values of each dictionary are lists.
            See :func:`irl_benchmark.irl.collect.collect_trajs`.
        rl_alg_factory: Callable[[gym.Env], BaseRLAlgorithm]
            A function which returns a new RL algorithm when called.
        config: dict
            A dictionary containing algorithm-specific parameters.
        """

        assert is_unwrappable_to(env, RewardWrapper)

        if IRL_ALG_REQUIREMENTS[type(self)]['requires_features']:
            assert is_unwrappable_to(env, FeatureWrapper)
        if IRL_ALG_REQUIREMENTS[type(self)]['requires_transitions']:
            assert is_unwrappable_to(env, BaseWorldModelWrapper)

        self.env = env
        self.expert_trajs = expert_trajs
        self.rl_alg_factory = rl_alg_factory
        if metrics is None:
            metrics = []
        self.metrics = metrics
        self.metric_results = [[] for _ in metrics]  # one independent result list per metric
        self.config = preprocess_config(self, IRL_CONFIG_DOMAINS, config) 
Example #26
Source File: me_irl.py    From irl-benchmark with GNU General Public License v3.0
def __init__(self, env: gym.Env, expert_trajs: List[Dict[str, list]],
                 rl_alg_factory: Callable[[gym.Env], BaseRLAlgorithm],
                 metrics: List[BaseMetric], config: dict):
        """See :class:`irl_benchmark.irl.algorithms.base_algorithm.BaseIRLAlgorithm`."""

        super(MaxEntIRL, self).__init__(env, expert_trajs, rl_alg_factory,
                                        metrics, config)
        # get transition matrix (with absorbing state)
        self.transition_matrix = unwrap_env(
            env, BaseWorldModelWrapper).get_transition_array()
        self.n_states, self.n_actions, _ = self.transition_matrix.shape

        # get map of features for all states:
        feature_wrapper = unwrap_env(env, FeatureWrapper)
        self.feat_map = feature_wrapper.feature_array() 
Example #27
Source File: dagger.py    From imitation with MIT License
def __init__(
        self,
        env: gym.Env,
        get_robot_act: Callable[[np.ndarray], np.ndarray],
        beta: float,
        save_dir: str,
    ):
        """Trajectory collector constructor.

        Args:
          env: environment to sample trajectories from.
          get_robot_act: get a single robot action that can be substituted for
              human action. Takes a single observation as input & returns a
              single action.
          beta: fraction of the time to use action given to .step() instead of
              robot action.
          save_dir: directory to save collected trajectories in.
        """
        super().__init__(env)
        self.get_robot_act = get_robot_act
        assert 0 <= beta <= 1
        self.beta = beta
        self.traj_accum = None
        self.save_dir = save_dir
        self._last_obs = None
        self._done_before = True
        self._is_reset = False 
Example #28
Source File: dagger.py    From imitation with MIT License
def __init__(
        self,
        env: gym.Env,
        scratch_dir: str,
        beta_schedule: Callable[[int], float] = None,
        **bc_kwargs,
    ):
        """Trainer constructor.

        Args:
          env: environment to train in.
          scratch_dir: directory to use to store intermediate training
              information (e.g. for resuming training).
          beta_schedule: provides a value of `beta` (the probability of taking
              expert action in any given state) at each round of training. If
              `None`, then `linear_beta_schedule` will be used instead.
          **bc_kwargs: additional arguments for constructing the `BC` that
              will be used to train the underlying policy.
        """
        # for pickling
        self._init_args = locals()
        self._init_args.update(bc_kwargs)
        del self._init_args["self"]
        del self._init_args["bc_kwargs"]

        if beta_schedule is None:
            beta_schedule = linear_beta_schedule(15)
        self.beta_schedule = beta_schedule
        self.scratch_dir = scratch_dir
        self.env = env
        self.round_num = 0
        self.bc_kwargs = bc_kwargs
        self._last_loaded_round = -1
        self._all_demos = []

        self._build_graph() 
Example #29
Source File: rl2.py    From garage with MIT License
def reset(self, **kwargs):
        """gym.Env reset function.

        Args:
            kwargs: Keyword arguments.

        Returns:
            np.ndarray: augmented observation.

        """
        del kwargs
        obs = self.env.reset()
        return np.concatenate(
            [obs, np.zeros(self.env.action_space.shape), [0], [0]]) 
Example #30
Source File: utils_wrapper_test.py    From irl-benchmark with GNU General Public License v3.0
def test_unwrap():
    env = make_env('FrozenLake-v0')
    assert env.env is unwrap_env(env, DiscreteEnv)

    # No unwrapping needed:
    assert env is unwrap_env(env, gym.Env)

    # Unwrap all the way:
    assert env.env is unwrap_env(env)

    env = FrozenLakeFeatureWrapper(env)
    assert env.env.env is unwrap_env(env, DiscreteEnv)

    # No unwrapping needed:
    assert env is unwrap_env(env, FrozenLakeFeatureWrapper)

    # Unwrap all the way:
    assert env.env.env is unwrap_env(env)

    # check types:
    assert isinstance(unwrap_env(env, DiscreteEnv), DiscreteEnv)
    assert isinstance(
        unwrap_env(env, feature_wrapper.FeatureWrapper),
        feature_wrapper.FeatureWrapper)
    assert isinstance(
        unwrap_env(env, FrozenLakeFeatureWrapper), FrozenLakeFeatureWrapper)
    assert isinstance(
        unwrap_env(env, FrozenLakeFeatureWrapper),
        feature_wrapper.FeatureWrapper)
    assert isinstance(unwrap_env(env), gym.Env)