import os
import time
from .wrapper import Wrapper
import imageio
from multiprocessing import Process, Queue
import numpy as np

# File taken from 5d4757b9d2789c34a1120f4c6e4452e4cb7a11fc by Zihua


def save_video(frame_queue, filename, fps):
    '''
    Target function for the video-writing process.

    Opens ``filename`` with imageio and appends every frame pulled from
    ``frame_queue`` until a ``None`` sentinel is received.

    Args:
        frame_queue: queue of frames to be stored. A ``None`` item marks
            the end of the capture.
        filename: path of the file the capture is written to
        fps: frame rate handed to the imageio writer

    Note that Queue.get() is a blocking call and thus will wait until a new
    frame (or the ``None`` sentinel) is added to frame_queue.
    '''
    writer = imageio.get_writer(filename, fps=fps)
    try:
        # An explicit ``is None`` test is used on purpose: iter(get, None)
        # would compare frames with ``==``, which is ambiguous for arrays.
        while True:
            frame = frame_queue.get()
            if frame is None:
                break
            writer.append_data(frame)
    finally:
        # Always release the output file, even if appending a frame fails.
        writer.close()


class VideoWrapper(Wrapper):
    '''
    Environment wrapper for automatically rendering and saving test runs.

    Attributes:
        public attributes
        env (Env): environment to be wrapped
        env_category (str): part of env_config.env_name before the first ':'
        capture_interval (int): number of episodes between captures,
            read from env_config.video.record_every
        frame_interval (int): number of frames between each recorded frame,
            default to 10.
            NOTE(review): stored but not consulted by _step at the moment;
            every step of a recorded episode is written to the video.
        fps (int): frame rate
            default to 30
        ext (str): file extension. Either .gif or .mp4 depending on use_gif
            default to .mp4
        save_folder (str): directory to which the files are to be saved,
            read from env_config.video.save_folder; falls back to
            '<session_config.folder>/videos' when that is empty
        max_videos (int): maximum number of videos kept in the directory,
            read from env_config.video.max_videos

        helper attributes
        num_eps (int): number of episodes executed
        num_steps (int): number of steps executed in the current episode
        is_recording (bool): whether the ongoing episode is being recorded
        path_queue (Queue): keeps track of saved videos to maintain the
            maximum number of captures
        num_paths (int): number of paths stored in path_queue. Necessary
            because Queue.qsize() is not implemented on Mac.
        video_process (Process): separate process that writes images to file
        video_queue (Queue): queue of frames to be written to file
    '''

    def __init__(self, env, env_config, session_config,
                 frame_interval=10, fps=30, use_gif=False):
        '''
        Constructor for VideoWrapper. Also creates the save directory if it
        is not present.

        Args:
            env (Env): environment to be wrapped
            env_config: configuration object; this constructor reads
                env_name, video.max_videos, video.record_every and
                video.save_folder from it
            session_config: configuration object; its ``folder`` is used to
                build the default save directory
            frame_interval (int): number of frames between each recorded frame
            fps (int): frame rate
            use_gif (bool): boolean flag to use either gif or mp4
        '''
        super().__init__(env)
        self.env = env
        self.env_category = env_config.env_name.split(':')[0]
        self.max_videos = env_config.video.max_videos
        self.capture_interval = env_config.video.record_every
        self.frame_interval = frame_interval
        self.fps = fps
        self.ext = '.gif' if use_gif else '.mp4'

        self.save_folder = env_config.video.save_folder
        if not self.save_folder:
            self.save_folder = os.path.join(session_config.folder, 'videos')
        self.save_folder = os.path.expanduser(self.save_folder)
        # exist_ok avoids the check-then-create race when several
        # processes set up the same folder at the same time.
        os.makedirs(self.save_folder, exist_ok=True)

        self.num_eps = 0
        self.num_steps = 0
        self.is_recording = False
        self.path_queue = Queue()
        self.num_paths = 0  # work around, qsize() not implemented on mac

    def _reset(self, **kwargs):
        '''
        Overwrites reset method.

        In addition to resetting the environment, this method also manages
        the video process and decides when to start writing a video or gif
        to file.

        Args:
            **kwargs: forwarded unchanged to the wrapped env's reset()

        Returns:
            whatever the wrapped environment's reset() returns
        '''
        self.num_steps = 0
        # Finish the capture of the previous episode before starting anew.
        if self.is_recording:
            self.stop_record()

        if self.num_eps % self.capture_interval == 0:
            self.video_queue = Queue()
            self.is_recording = True
            path = os.path.join(
                self.save_folder,
                'video_eps_{}{}'.format(self.num_eps, self.ext))

            # Evict the oldest capture once the limit is reached.
            if self.num_paths >= self.max_videos:
                dep_path = self.path_queue.get()
                try:
                    os.remove(dep_path)
                except FileNotFoundError:
                    # The file is already gone (never written or removed
                    # externally); nothing left to clean up.
                    pass
                self.num_paths -= 1

            self.path_queue.put(path)
            self.video_process = Process(
                target=save_video,
                args=(self.video_queue, path, self.fps))
            self.video_process.start()
            self.num_paths += 1

        state = self.env.reset(**kwargs)
        self.num_eps += 1
        return state

    def _step(self, action):
        '''
        Overwrites _step function.

        In addition to taking an action, if the episode is being recorded,
        the current frame is rendered via env.render() and put into the
        video_queue.

        Args:
            action: action forwarded to the wrapped env's step()

        Returns:
            the (state, step_reward, terminal, info) tuple from the wrapped env
        '''
        state, step_reward, terminal, info = self.env.step(action)
        self.num_steps += 1
        if self.is_recording:
            # For dm_control videos, the video will be transposed when we save to disk
            # We correct for that here
            ob = self.env.render()
            if self.env_category == 'dm_control':
                ob = ob.transpose(1, 0, 2)
            else:  # mujocomanip
                ob = np.rot90(ob, 2)
            self.video_queue.put(ob)
        return state, step_reward, terminal, info

    def stop_record(self):
        '''
        Stops recording and waits for the video_process to finish writing
        to file.

        First puts a None (End of Video) frame into the frame queue, then
        joins the writing process.
        '''
        self.video_queue.put(None)
        self.video_process.join()
        self.is_recording = False
        # NOTE: when called from _reset, num_eps has already been incremented
        # since the capture started, so this number is one higher than the
        # index in the video's file name.
        print('finished recording video {}'.format(self.num_eps))