import csv import collections import contextlib import tqdm import numpy as np import six from abc import ABCMeta, abstractmethod from contextlib import contextmanager import tensorflow as tf gfile = tf.gfile @six.add_metaclass(ABCMeta) class ParasolEnvironment(object): def __init__(self, sliding_window=0): self.recording = False self.episode_number = 1 self.currently_logging = False self.sliding_window = sliding_window self.prev_obs = None @abstractmethod def reset(self): pass @abstractmethod def step(self, action): pass @abstractmethod def render(self, mode='human'): pass @contextmanager def video(self, video_path): self.recording = True self.start_recording(video_path) yield self.recording = False self.stop_recording() @abstractmethod def start_recording(self, video_path): pass @abstractmethod def grab_frame(self): pass @abstractmethod def stop_recording(self): pass @abstractmethod def config(self): pass @abstractmethod def state_dim(self): pass @abstractmethod def action_dim(self): pass def cost_fn(self, s, a): return np.zeros(s.shape[0]) def get_state_dim(self): return self.state_dim() * (1 + self.sliding_window) def get_action_dim(self): return self.action_dim() def is_recording(self): return self.recording def is_image(self): return False def image_size(self): return None @abstractmethod def make_summary(self, observations, name): pass @abstractmethod def _observe(self): pass def observe(self): if self.sliding_window == 0: return self._observe() curr_obs = self._observe() if self.prev_obs is None: self.prev_obs = [curr_obs] * self.sliding_window obs = [curr_obs] + self.prev_obs self.prev_obs = obs[:-1] return np.concatenate(obs, 0) def rollout(self, num_steps, policy=None, render=False, show_progress=False, init_std=1, noise=None): if policy is None: def policy(_, t, noise=None): return np.random.normal(size=self.get_action_dim(), scale=init_std) states, actions, costs = ( np.zeros([num_steps] + [self.get_state_dim()]), np.zeros([num_steps] + [self.get_action_dim()]), np.zeros([num_steps]) ) infos = collections.defaultdict(list) current_state = self.reset() times = tqdm.trange(num_steps, desc='Rollout') if show_progress else range(num_steps) for t in times: states[t] = current_state if render: self.render(mode='human') if self.is_recording(): self.render(mode='rgb_array') self.grab_frame() n = None if noise is not None: n = noise[t] actions[t] = policy(states, actions, t, noise=n) current_state, costs[t], done, info = self.step(actions[t]) for k, v in info.items(): infos[k].append(v) if self.currently_logging: log_entry = collections.OrderedDict() log_entry['episode_number'] = self.episode_number log_entry['mean_cost'] = costs.mean() log_entry['total_cost'] = costs.sum() log_entry['final_cost'] = costs[-1] for k, v in infos.items(): v = np.array(v) log_entry['mean_%s' % k] = v.mean() log_entry['total_%s' % k] = v.sum() log_entry['final_%s' % k] = v[-1] self.log_entry(log_entry) self.episode_number += 1 return states, actions, costs, infos def rollouts(self, num_rollouts, num_steps, show_progress=False, noise=None, callback=lambda x: None, **kwargs): states, actions, costs = ( np.empty([num_rollouts, num_steps] + [self.get_state_dim()]), np.empty([num_rollouts, num_steps] + [self.get_action_dim()]), np.empty([num_rollouts, num_steps]) ) infos = [None] * num_rollouts rollouts = tqdm.trange(num_rollouts, desc='Rollouts') if show_progress else range(num_rollouts) for i in rollouts: with contextlib.ExitStack() as stack: context = callback(i) if context is not None: stack.enter_context(callback(i)) n = None if noise is not None: n = noise() states[i], actions[i], costs[i], infos[i] = \ self.rollout(num_steps, noise=n,**kwargs) return states, actions, costs, infos def get_config(self): config = self.config().copy() config['environment_name'] = self.environment_name return config @contextmanager def logging(self, log_file, **kwargs): self.start_logging(log_file, **kwargs) yield self.stop_logging() def log_entry(self, entry): self.log_entries.append(entry) def start_logging(self, log_file, verbose=False): self.log_file = log_file self.log_entries = [] self.currently_logging = True self.verbose_logging = verbose def stop_logging(self): if len(self.log_entries) > 0: with gfile.GFile(self.log_file, 'a+') as fp: log_writer = csv.writer(fp) if (self.episode_number - len(self.log_entries) - 1) == 0: log_writer.writerow(self.log_entries[0].keys()) for entry in self.log_entries: log_writer.writerow(entry.values()) if self.verbose_logging: for k in self.log_entries[0].keys(): if k == 'episode_number': continue metric = np.array([l[k] for l in self.log_entries]) print("Average %s: %.3f +/- %.3f" % (k, metric.mean(), metric.std())) self.log_entries = None self.log_file = None self.currently_logging = False