"""Observation and action wrappers for Atari gym environments.

The wrappers override the underscore-prefixed hooks (``_step``, ``_reset``,
``_observation``, ``_action``, ``_reverse_action``) used by older gym releases.
"""
from collections import deque

import cv2
import numpy as np
from gym import ActionWrapper, ObservationWrapper, Wrapper
from gym.spaces.box import Box


def _process_frame42(frame):
    """Convert a raw 210x160x3 frame into a 42x42x1 grayscale float32 image.

    The frame is averaged over its three colour channels, resized to 84 wide by
    110 tall, cropped to rows 18..101 (an 84x84 region), downsampled to 42x42
    and multiplied by 1/255.

    Args:
        frame: Array-like holding 210 * 160 * 3 values. Assumed to be uint8
            pixels in [0, 255] (TODO confirm against the wrapped env), in which
            case the result lies in [0, 1].

    Returns:
        ``np.ndarray`` of dtype float32 and shape (42, 42, 1).
    """
    gray = np.reshape(frame, [210, 160, 3]).astype(np.float32).mean(2)
    # cv2.resize takes dsize as (width, height): the result is 110 rows x 84 cols.
    resized = cv2.resize(gray, (84, 110))
    # Rows 18..101 give an 84x84 square.
    cropped = resized[18:102, :]
    x_t = cv2.resize(cropped, (42, 42))
    x_t *= 1.0 / 255.0
    return np.reshape(x_t, [42, 42, 1])


class AtariRescale42x42Env(ObservationWrapper):
    """Observation wrapper that emits frames processed by ``_process_frame42``."""

    def __init__(self, env=None):
        super(AtariRescale42x42Env, self).__init__(env)
        # _process_frame42 divides by 255, so emitted values lie in [0, 1]
        # (for [0, 255] input frames), not [0, 255].
        self.observation_space = Box(0.0, 1.0, [42, 42, 1])

    def _observation(self, observation):
        return _process_frame42(observation)


def _process_frame84(frame):
    """Convert a raw 210x160x3 frame into an 84x84x1 luma float32 image.

    The frame is reduced to luma with the ITU-R BT.709 weights, resized to 84
    wide by 110 tall, cropped to rows 18..101 (an 84x84 region) and divided by
    255.

    Args:
        frame: Array-like holding 210 * 160 * 3 values. Assumed to be uint8 RGB
            pixels in [0, 255] with channel order R, G, B (TODO confirm), in
            which case the result lies in [0, 1].

    Returns:
        ``np.ndarray`` of dtype float32 and shape (84, 84, 1).
    """
    img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
    # BT.709 luma: 0.2126 R + 0.7152 G + 0.0722 B. The weights sum to 1, so
    # the value range of the input is preserved.
    img = img[:, :, 0] * 0.2126 + img[:, :, 1] * 0.7152 + img[:, :, 2] * 0.0722
    # cv2.resize takes dsize as (width, height): the result is 110 rows x 84 cols.
    resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_LINEAR)
    x_t = resized_screen[18:102, :]
    x_t /= 255.0
    return np.reshape(x_t, [84, 84, 1])


class AtariRescale84x84Env(ObservationWrapper):
    """Observation wrapper that emits frames processed by ``_process_frame84``."""

    def __init__(self, env=None):
        super(AtariRescale84x84Env, self).__init__(env)
        # _process_frame84 divides by 255, so emitted values lie in [0, 1]
        # (for [0, 255] input frames), not [0, 255].
        self.observation_space = Box(0.0, 1.0, [84, 84, 1])

    def _observation(self, observation):
        return _process_frame84(observation)


class RandomizedResetEnv(Wrapper):
    """Randomize the initial state by taking random action-0 steps after reset.

    Args:
        env: Environment to wrap.
        no_op_max: Inclusive upper bound on the number of action-0 steps taken
            after each reset. A value <= 0 disables the randomization.
    """

    def __init__(self, env, no_op_max=7):
        super(RandomizedResetEnv, self).__init__(env)
        self._no_op_max = no_op_max

    def _reset(self):
        ob = self.env.reset()
        # NOTE(review): action 0 is assumed to be NOOP for the wrapped game;
        # confirm against the env's action meanings.
        noop_action = 0
        if self._no_op_max > 0:
            no_op = np.random.randint(0, self._no_op_max + 1)
            for _ in range(no_op):
                ob, _, done, _ = self.env.step(noop_action)
                # Never keep stepping a finished episode; start a fresh one.
                if done:
                    ob = self.env.reset()
        return ob


class OneLiveResetEnv(Wrapper):
    """Signal ``done=True`` whenever the ALE life counter changes during a step."""

    def _step(self, action):
        lives_before = self.env.unwrapped.ale.lives()
        observation, reward, done, info = self.env.step(action)
        if lives_before != self.env.unwrapped.ale.lives():
            done = True
        return observation, reward, done, info


class UnstuckPolicyEnv(ActionWrapper):
    """Replace an action that the policy has repeated too many times in a row.

    A sliding window holds the last ``max_repeats`` actions that were actually
    sent (after any replacement). If every entry in a full window equals the
    incoming action, that action is replaced by ``unstuck_action``. The window
    is not cleared on reset.

    Args:
        env: Environment to wrap.
        max_repeats: Window length and repeat threshold (must be >= 1).
        unstuck_action: Action substituted when the threshold is reached.
            NOTE(review): presumably FIRE for Atari games; confirm.

    Raises:
        ValueError: If ``max_repeats`` is less than 1.
    """

    def __init__(self, env=None, max_repeats=30, unstuck_action=1):
        super(UnstuckPolicyEnv, self).__init__(env)
        if max_repeats < 1:
            raise ValueError("max_repeats must be >= 1, got %r" % (max_repeats,))
        self._max_repeats = max_repeats
        self._unstuck_action = unstuck_action
        # Per-instance history. A class-level deque would be shared by every
        # wrapper instance in the process.
        self.actions = deque(maxlen=max_repeats)

    def _action(self, action):
        if self.actions.count(action) == self._max_repeats:
            action = self._unstuck_action
        self.actions.append(action)
        return action

    def _reverse_action(self, action):
        return action


class ObservationBuffer(Wrapper):
    """Stack the last ``buffer_size`` observations along axis 2 (channels).

    The wrapped env's observation space must be 3-D. The stacked observation
    has the same first two dimensions and ``buffer_size`` times as many
    channels, with the oldest observation first.

    Args:
        env: Environment to wrap.
        buffer_size: Number of observations to stack.
    """

    def __init__(self, env, buffer_size=4):
        super(ObservationBuffer, self).__init__(env)
        self.buffer_size = buffer_size
        self.buffer = deque(maxlen=self.buffer_size)
        inner_space = self.env.observation_space
        assert len(inner_space.shape) == 3, (
            "ObservationBuffer needs a 3-D observation space, got shape %r"
            % (inner_space.shape,)
        )
        self._shape = list(inner_space.shape)
        self._num_channels = self._shape[2]
        self._shape[2] *= self.buffer_size
        # Repeat the wrapped env's bounds once per stacked observation, so the
        # declared range matches whatever range the inner observations use.
        self.observation_space = Box(
            np.concatenate([inner_space.low] * self.buffer_size, axis=2),
            np.concatenate([inner_space.high] * self.buffer_size, axis=2),
        )

    def _step(self, action):
        observation, reward, done, info = self.env.step(action)
        self.buffer.append(observation)
        return np.concatenate(self.buffer, axis=2), reward, done, info

    def _reset(self):
        obs = self.env.reset()
        # Fill the whole window with the first observation of the episode.
        for _ in range(self.buffer_size):
            self.buffer.append(obs)
        return np.concatenate(self.buffer, axis=2)