import unittest
import logging
from collections import OrderedDict

import torch
import numpy as np
import gym
from gym import Wrapper
from gym import ObservationWrapper
from gym import RewardWrapper
from PIL import Image
from termcolor import colored as clr

from utils.torch_types import TorchTypes

logger = logging.getLogger(__name__)


class SqueezeRewards(RewardWrapper):
    """Squeezes every reward to -1, 0 or +1 by taking its sign."""

    def __init__(self, env):
        super(SqueezeRewards, self).__init__(env)
        print("[Reward Wrapper] for clamping rewards to -+1")

    def _reward(self, reward):
        return float(np.sign(reward))


class PreprocessFrames(ObservationWrapper):
    """Converts RGB frames to grayscale, resizes them to `state_dims` and
    stacks the last `hist_len` frames into a single state tensor."""

    def __init__(self, env, env_type, hist_len, state_dims, cuda=None):
        super(PreprocessFrames, self).__init__(env)

        self.env_type = env_type
        self.state_dims = state_dims
        self.hist_len = hist_len
        self.env_wh = self.env.observation_space.shape[0:2]
        self.env_ch = self.env.observation_space.shape[2]
        self.wxh = self.env_wh[0] * self.env_wh[1]

        # need to find a better way
        if self.env_type == "atari":
            self._preprocess = self._atari_preprocess
        elif self.env_type == "catch":
            self._preprocess = self._catch_preprocess
        else:
            raise ValueError("Unknown env_type: %s" % env_type)
        print("[Preprocess Wrapper] for %s with state history of %d frames."
              % (self.env_type, hist_len))

        self.cuda = False if cuda is None else cuda
        self.dtype = dtype = TorchTypes(self.cuda)
        # ITU-R BT.709 luma coefficients for the RGB -> Y conversion
        self.rgb = dtype.FT([.2126, .7152, .0722])

        # Earlier tensor-based history buffer, kept for reference;
        # torch.size([1, 4, 24, 24])
        """
        self.hist_state = torch.FloatTensor(1, hist_len, *state_dims)
        self.hist_state.fill_(0)
        """

        self.d = OrderedDict(
            {i: torch.FloatTensor(1, 1, *state_dims).fill_(0)
             for i in range(hist_len)})

    def _observation(self, o):
        return self._preprocess(o)

    def _reset(self):
        # self.hist_state.fill_(0)
        self.d = OrderedDict(
            {i: torch.FloatTensor(1, 1, *self.state_dims).fill_(0)
             for i in range(self.hist_len)})
        observation = self.env.reset()
        return self._observation(observation)

    def _catch_preprocess(self, o):
        return self._get_concatenated_state(self._rgb2y(o))

    def _atari_preprocess(self, o):
        img = Image.fromarray(self._rgb2y(o).numpy())
        img = np.array(img.resize(self.state_dims, resample=Image.NEAREST))
        th_img = torch.from_numpy(img)
        return self._get_concatenated_state(th_img)

    def _rgb2y(self, o):
        o = torch.from_numpy(o).type(self.dtype.FT)
        s = o.view(self.wxh, 3).mv(self.rgb).view(*self.env_wh) / 255
        return s.cpu()

    def _get_concatenated_state(self, o):
        hist_len = self.hist_len
        # shift the existing frames one slot towards the past
        for i in range(hist_len - 1):
            self.d[i] = self.d[i + 1]
        # place the newest frame in the last slot
        self.d[hist_len - 1] = o.unsqueeze(0).unsqueeze(0)
        return torch.cat(list(self.d.values()), 1)

    # Earlier tensor-based implementation, kept for reference:
    """
    def _get_concatenated_state(self, o):
        hist_len = self.hist_len  # eg. 4
        # move frames already existent one position below
        if hist_len > 1:
            self.hist_state[0][0:hist_len - 1] = \
                self.hist_state[0][1:hist_len]
        # concatenate the newest frame to the top of the augmented state
        self.hist_state[0][self.hist_len - 1] = o
        return self.hist_state
    """
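
# Illustrative sketch, not part of the original API: one way SqueezeRewards
# and PreprocessFrames are meant to compose. The env id, `hist_len` and
# `state_dims` values below are assumptions, not values prescribed by this
# module.
def _example_preprocess_pipeline():
    env = gym.make("Pong-v0")
    env = SqueezeRewards(env)          # rewards squeezed to -1, 0 or +1
    env = PreprocessFrames(env, env_type="atari",
                           hist_len=4, state_dims=(84, 84))
    state = env.reset()                # FloatTensor of shape (1, 4, 84, 84)
    return state
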
class DoneAfterLostLife(gym.Wrapper):
    """Marks the episode as done whenever a life is lost, but only resets
    the underlying environment once all lives are gone."""

    def __init__(self, env):
        super(DoneAfterLostLife, self).__init__(env)

        self.no_more_lives = True
        self.crt_live = env.unwrapped.ale.lives()
        self.has_many_lives = self.crt_live != 0

        if self.has_many_lives:
            self._step = self._many_lives_step
        else:
            self._step = self._one_live_step

        not_a = clr("not a", attrs=['bold'])
        print("[DoneAfterLostLife Wrapper] %s is %s many-lives game."
              % (env.env.spec.id, "a" if self.has_many_lives else not_a))

    def _reset(self):
        if self.no_more_lives:
            obs = self.env.reset()
            self.crt_live = self.env.unwrapped.ale.lives()
            return obs
        else:
            return self.__obs

    def _many_lives_step(self, action):
        obs, reward, done, info = self.env.step(action)
        crt_live = self.env.unwrapped.ale.lives()
        if crt_live < self.crt_live:
            # just lost a life
            done = True
            self.crt_live = crt_live

        if crt_live == 0:
            self.no_more_lives = True
        else:
            self.no_more_lives = False
            self.__obs = obs
        return obs, reward, done, info

    def _one_live_step(self, action):
        return self.env.step(action)


class EvaluationMonitor(Wrapper):
    """Accumulates rewards over a fixed number of evaluation steps, logs and
    optionally plots the mean episodic reward, and saves the stats to disk."""

    def __init__(self, env, cmdl):
        super(EvaluationMonitor, self).__init__(env)

        self.freq = cmdl.eval_frequency  # in steps
        self.eval_steps = cmdl.eval_steps
        self.cmdl = cmdl

        if self.cmdl.display_plots:
            from visdom import Visdom
            self.vis = Visdom()
            self.plot = self.vis.line(
                Y=np.array([0]), X=np.array([0]),
                opts=dict(
                    title=cmdl.label,
                    caption="Episodic reward per %d steps." % self.eval_steps)
            )

        self.eval_cnt = 0
        self.crt_training_step = 0
        self.step_cnt = 0
        self.ep_cnt = 1
        self.total_rw = 0
        self.max_mean_rw = -1000

        no_of_evals = (cmdl.training_steps // cmdl.eval_frequency
                       - (cmdl.eval_start - 1) // cmdl.eval_frequency)

        self.eval_frame_idx = torch.LongTensor(no_of_evals).fill_(0)
        self.eval_rw_per_episode = torch.FloatTensor(no_of_evals).fill_(0)
        self.eval_rw_per_frame = torch.FloatTensor(no_of_evals).fill_(0)
        self.eval_eps_per_eval = torch.LongTensor(no_of_evals).fill_(0)

    def get_crt_step(self, crt_training_step):
        self.crt_training_step = crt_training_step

    def _reset_monitor(self):
        self.step_cnt, self.ep_cnt, self.total_rw = 0, 0, 0

    def _step(self, action):
        # self._before_step(action)
        observation, reward, done, info = self.env.step(action)
        done = self._after_step(observation, reward, done, info)
        return observation, reward, done, info

    def _reset(self):
        observation = self.env.reset()
        self._after_reset(observation)
        return observation

    def _after_step(self, o, r, done, info):
        self.total_rw += r
        self.step_cnt += 1
        # Evaluation ends here
        if self.step_cnt == self.eval_steps:
            self._update()
            self._reset_monitor()
        return done

    def _after_reset(self, observation):
        if self.step_cnt != self.eval_steps:
            self.ep_cnt += 1

    def _update(self):
        mean_rw = self.total_rw / (self.ep_cnt - 1)
        max_mean_rw = self.max_mean_rw
        self.max_mean_rw = max(mean_rw, max_mean_rw)

        self._update_plot(self.crt_training_step, mean_rw)
        self._display_logs(mean_rw, max_mean_rw)
        self._update_reports(mean_rw)
        self.eval_cnt += 1

    def _update_reports(self, mean_rw):
        idx = self.eval_cnt

        self.eval_frame_idx[idx] = self.crt_training_step
        self.eval_rw_per_episode[idx] = mean_rw
        self.eval_rw_per_frame[idx] = self.total_rw / self.step_cnt
        self.eval_eps_per_eval[idx] = (self.ep_cnt - 1)

        torch.save({
            'eval_frame_idx': self.eval_frame_idx,
            'eval_rw_per_episode': self.eval_rw_per_episode,
            'eval_rw_per_frame': self.eval_rw_per_frame,
            'eval_eps_per_eval': self.eval_eps_per_eval
        }, self.cmdl.results_path + "/eval_stats.torch")

    def _update_plot(self, crt_training_step, mean_rw):
        if self.cmdl.display_plots:
            self.vis.line(
                X=np.array([crt_training_step]),
                Y=np.array([mean_rw]),
                win=self.plot,
                update='append'
            )

    def _display_logs(self, mean_rw, max_mean_rw):
        bg_color = 'on_magenta' if mean_rw > max_mean_rw else 'on_blue'
        print(clr("[Evaluator] done in %5d steps. " % self.step_cnt,
                  attrs=['bold'])
              + clr(" rw/ep=%3.2f " % mean_rw, 'white', bg_color,
                    attrs=['bold']))
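

# Illustrative sketch, assuming a training script similar to the one this
# monitor is built for: `policy` and the way `eval_env` is constructed are
# hypothetical; only `get_crt_step` and `eval_steps` come from the
# EvaluationMonitor class above.
def _example_evaluation(eval_env, policy, crt_training_step):
    # eval_env is assumed to be wrapped in EvaluationMonitor; after
    # `eval_steps` steps the monitor updates its plot and saved stats.
    eval_env.get_crt_step(crt_training_step)
    obs = eval_env.reset()
    for _ in range(eval_env.eval_steps):
        obs, reward, done, _ = eval_env.step(policy(obs))
        if done:
            obs = eval_env.reset()
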
" % self.step_cnt, attrs=['bold']) + clr(" rw/ep=%3.2f " % mean_rw, 'white', bg_color, attrs=['bold'])) class VisdomMonitor(Wrapper): def __init__(self, env, cmdl): super(VisdomMonitor, self).__init__(env) self.freq = cmdl.report_freq # in steps self.cmdl = cmdl if self.cmdl.display_plots: from visdom import Visdom self.vis = Visdom() self.plot = self.vis.line( Y=np.array([0]), X=np.array([0]), opts=dict( title=cmdl.label, caption="Episodic reward per 1200 steps.") ) self.step_cnt = 0 self.ep_cnt = -1 self.ep_rw = [] self.last_reported_ep = 0 def _step(self, action): # self._before_step(action) observation, reward, done, info = self.env.step(action) done = self._after_step(observation, reward, done, info) return observation, reward, done, info def _reset(self): self._before_reset() observation = self.env.reset() self._after_reset(observation) return observation def _after_step(self, o, r, done, info): self.ep_rw[self.ep_cnt] += r self.step_cnt += 1 if self.step_cnt % self.freq == 0: self._update_plot() return done def _before_reset(self): self.ep_rw.append(0) def _after_reset(self, observation): self.ep_cnt += 1 # print("[%2d][%4d] RESET" % (self.ep_cnt, self.step_cnt)) def _update_plot(self): # print(self.last_reported_ep, self.ep_cnt + 1) completed_eps = self.ep_rw[self.last_reported_ep:self.ep_cnt + 1] ep_mean_reward = sum(completed_eps) / len(completed_eps) if self.cmdl.display_plots: self.vis.line( X=np.array([self.step_cnt]), Y=np.array([ep_mean_reward]), win=self.plot, update='append' ) self.last_reported_ep = self.ep_cnt + 1 class TestAtariWrappers(unittest.TestCase): def _test_env(self, env_name): env = gym.make(env_name) env = DoneAfterLostLife(env) o = env.reset() for i in range(10000): o, r, d, _ = env.step(env.action_space.sample()) if d: o = env.reset() print("%3d, %s, %d" % (i, env_name, env.unwrapped.ale.lives())) def test_pong(self): print("Testing Pong") self._test_env("Pong-v0") def test_frostbite(self): print("Testing Frostbite") self._test_env("Frostbite-v0") if __name__ == "__main__": import unittest unittest.main()