# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for openai gym."""

from collections import deque

import gym
import numpy as np


# pylint: disable=method-hidden
class WarmupWrapper(gym.Wrapper):
  """Warmup wrapper."""

  def __init__(self, env, warm_up_examples=0, warmup_action=0):
    gym.Wrapper.__init__(self, env)
    self.warm_up_examples = warm_up_examples
    self.warm_up_action = warmup_action
    self.observation_space = gym.spaces.Box(
        low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)

  def get_starting_data(self, num_frames):
    self.reset()
    starting_observations, starting_actions, starting_rewards = [], [], []
    for _ in range(num_frames):
      observation, rew, _, _ = self.env.step(self.warm_up_action)
      starting_observations.append(observation)
      starting_rewards.append(rew)
      starting_actions.append(self.warm_up_action)
    return starting_observations, starting_actions, starting_rewards

  def step(self, action):
    return self.env.step(action)

  def reset(self, **kwargs):
    del kwargs
    self.env.reset()
    observation = None
    for _ in range(self.warm_up_examples):
      observation, _, _, _ = self.env.step(self.warm_up_action)
    return observation


class PongWrapper(WarmupWrapper):
  """Pong Wrapper."""

  def __init__(self, env, warm_up_examples=0,
               action_space_reduction=False,
               reward_skip_steps=0,
               big_ball=False):
    super(PongWrapper, self).__init__(env, warm_up_examples=warm_up_examples)
    self.action_space_reduction = action_space_reduction
    if self.action_space_reduction:
      self.action_space = gym.spaces.Discrete(2)
    self.warm_up_examples = warm_up_examples
    self.observation_space = gym.spaces.Box(
        low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)
    self.reward_skip_steps = reward_skip_steps
    self.big_ball = big_ball

  def step(self, action):
    if self.action_space_reduction:
      action = 2 if int(action) == 0 else 5
    ob, rew, done, info = self.env.step(action)
    ob = self.process_observation(ob)
    if rew != 0 and self.reward_skip_steps != 0:
      for _ in range(self.reward_skip_steps):
        self.env.step(0)
    return ob, rew, done, info

  def reset(self, **kwargs):
    observation = super(PongWrapper, self).reset(**kwargs)
    observation = self.process_observation(observation)
    return observation

  def process_observation(self, obs):
    if self.big_ball:
      pos = PongWrapper.find_ball(obs)
      if pos is not None:
        x, y = pos
        obs[x-5:x+5, y-5:y+5, :] = 255
    return obs

  @staticmethod
  def find_ball(obs, default=None):
    ball_area = obs[37:193, :, 0]
    res = np.argwhere(ball_area == 236)
    # Truth-testing a non-empty argwhere result is ambiguous; check size.
    if not res.size:
      return default
    else:
      x, y = res[0]
      x += 37
      return x, y


def wrapped_pong_factory(warm_up_examples=0, action_space_reduction=False,
                         reward_skip_steps=0, big_ball=False):
  """Wrapped pong games."""
  env = gym.make("PongDeterministic-v4")
  env = env.env  # Remove time_limit wrapper.
  env = PongWrapper(env, warm_up_examples=warm_up_examples,
                    action_space_reduction=action_space_reduction,
                    reward_skip_steps=reward_skip_steps,
                    big_ball=big_ball)
  return env


gym.envs.register(id="T2TPongWarmUp20RewSkip200Steps-v1",
                  entry_point=lambda: wrapped_pong_factory(  # pylint: disable=g-long-lambda
                      warm_up_examples=20, reward_skip_steps=15),
                  max_episode_steps=200)

gym.envs.register(id="T2TPongWarmUp20RewSkip2000Steps-v1",
                  entry_point=lambda: wrapped_pong_factory(  # pylint: disable=g-long-lambda
                      warm_up_examples=20, reward_skip_steps=15),
                  max_episode_steps=2000)


class BreakoutWrapper(WarmupWrapper):
  """Breakout Wrapper."""

  FIRE_ACTION = 1

  def __init__(self, env, warm_up_examples=0,
               ball_down_skip=0, big_ball=False,
               include_direction_info=False,
               reward_clipping=True):
    super(BreakoutWrapper, self).__init__(
        env, warm_up_examples=warm_up_examples,
        warmup_action=BreakoutWrapper.FIRE_ACTION)
    self.warm_up_examples = warm_up_examples
    self.observation_space = gym.spaces.Box(
        low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)
    self.ball_down_skip = ball_down_skip
    self.big_ball = big_ball
    self.reward_clipping = reward_clipping
    self.include_direction_info = include_direction_info
    self.direction_info = deque([], maxlen=2)
    self.points_gained = False
    msg = ("ball_down_skip should be at least 9 for "
           "include_direction_info to work correctly")
    assert not self.include_direction_info or ball_down_skip >= 9, msg

  def step(self, action):
    ob, rew, done, info = self.env.step(action)
    if BreakoutWrapper.find_ball(ob) is None and self.ball_down_skip != 0:
      for _ in range(self.ball_down_skip):
        # We assume that nothing interesting happens during ball_down_skip
        # and discard all information.
        # We fire all the time to start a new game.
        ob, _, _, _ = self.env.step(BreakoutWrapper.FIRE_ACTION)
    self.direction_info.append(BreakoutWrapper.find_ball(ob))
    ob = self.process_observation(ob)
    self.points_gained = self.points_gained or rew > 0
    if self.reward_clipping:
      rew = np.sign(rew)
    return ob, rew, done, info

  def reset(self, **kwargs):
    observation = super(BreakoutWrapper, self).reset(**kwargs)
    self.env.step(BreakoutWrapper.FIRE_ACTION)
    self.direction_info = deque([], maxlen=2)
    observation = self.process_observation(observation)
    return observation

  @staticmethod
  def find_ball(ob, default=None):
    off_x = 63
    clipped_ob = ob[off_x:-21, :, 0]
    pos = np.argwhere(clipped_ob == 200)
    if not pos.size:
      return default
    x = off_x + pos[0][0]
    y = pos[0][1]
    return x, y

  def process_observation(self, obs):
    if self.big_ball:
      pos = BreakoutWrapper.find_ball(obs)
      if pos is not None:
        x, y = pos
        obs[x-5:x+5, y-5:y+5, :] = 255
    if self.include_direction_info:
      for point in list(self.direction_info):
        if point is not None:
          x, y = point
          obs[x-2:x+2, y-2:y+2, 1] = 255
    return obs


def wrapped_breakout_factory(warm_up_examples=0,
                             ball_down_skip=0,
                             big_ball=False,
                             include_direction_info=False,
                             reward_clipping=True):
  """Wrapped breakout games."""
  env = gym.make("BreakoutDeterministic-v4")
  env = env.env  # Remove time_limit wrapper.
  env = BreakoutWrapper(env, warm_up_examples=warm_up_examples,
                        ball_down_skip=ball_down_skip,
                        big_ball=big_ball,
                        include_direction_info=include_direction_info,
                        reward_clipping=reward_clipping)
  return env


gym.envs.register(id="T2TBreakoutWarmUp20RewSkip500Steps-v1",
                  entry_point=lambda: wrapped_breakout_factory(  # pylint: disable=g-long-lambda
                      warm_up_examples=1,
                      ball_down_skip=9,
                      big_ball=False,
                      include_direction_info=True,
                      reward_clipping=True
                  ),
                  max_episode_steps=500)


class FreewayWrapper(WarmupWrapper):
  """Wrapper for Freeway."""

  def __init__(self, env, warm_up_examples=0,
               reward_clipping=True,
               easy_freeway=False):
    super(FreewayWrapper, self).__init__(env, warm_up_examples)
    self.easy_freeway = easy_freeway
    self.half_way_reward = 1.0  # Probably not needed here, set just in case.
    self.reward_clipping = reward_clipping

  def chicken_height(self, image):
    # Needed only for easy_freeway; a subclass must locate the chicken in the
    # observation and return its vertical position.
    raise NotImplementedError()

  def step(self, action):
    ob, rew, done, info = self.env.step(action)
    if self.easy_freeway:
      if rew > 0:
        self.half_way_reward = 1
      chicken_height = self.chicken_height(ob)
      if chicken_height < 105:
        rew += self.half_way_reward
        self.half_way_reward = 0
    if self.reward_clipping:
      rew = np.sign(rew)
    return ob, rew, done, info

  def reset(self, **kwargs):
    self.half_way_reward = 1.0
    observation = super(FreewayWrapper, self).reset(**kwargs)
    return observation


def wrapped_freeway_factory(warm_up_examples=0,
                            reward_clipping=True,
                            easy_freeway=False):
  """Wrapped freeway games."""
  env = gym.make("FreewayDeterministic-v4")
  env = env.env  # Remove time_limit wrapper.
  env = FreewayWrapper(env, warm_up_examples=warm_up_examples,
                       reward_clipping=reward_clipping,
                       easy_freeway=easy_freeway)
  return env


gym.envs.register(id="T2TFreewayWarmUp20RewSkip500Steps-v1",
                  entry_point=lambda: wrapped_freeway_factory(  # pylint: disable=g-long-lambda
                      warm_up_examples=1,
                      reward_clipping=True,
                      easy_freeway=False
                  ),
                  max_episode_steps=500)
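

# --- Example usage (illustrative sketch, not part of the original module) ---
# The environment ids registered above can be instantiated through gym.make
# like any other Atari environment. The helper below is hypothetical and only
# shows the intended flow; it assumes the classic gym step/reset API
# (reset() returning an observation and step() returning a 4-tuple), which is
# the API these wrappers implement.
def _example_rollout(env_id="T2TPongWarmUp20RewSkip200Steps-v1", num_steps=10):
  """Creates a registered environment and takes a few random steps."""
  env = gym.make(env_id)
  observation = env.reset()
  total_reward = 0.0
  for _ in range(num_steps):
    action = env.action_space.sample()  # Random policy, for illustration only.
    observation, reward, done, _ = env.step(action)
    total_reward += reward
    if done:
      observation = env.reset()
  env.close()
  return observation, total_reward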