import os

import numpy as np
import tensorflow as tf

from .base import BaseWrapper


class AutoResetWrapper(BaseWrapper):
    """Wrapper that resets every finished batch entry right after a step.

    ``step`` delegates to the wrapped environment. It then calls
    ``self.reset`` on the indices whose ``done`` flag is set, forwarding
    ``max_frames``. The returned reward/done tensors carry a control
    dependency on that reset, so they are only produced after it has run.
    """

    def __init__(self, env, max_frames=None):
        """Wrap *env*.

        Args:
            env: Environment to wrap; handed straight to
                ``BaseWrapper.__init__``.
            max_frames: Passed unchanged as the ``max_frames`` keyword to
                every ``self.reset`` call this wrapper makes.
        """
        super(AutoResetWrapper, self).__init__(env)
        self.max_frames = max_frames

    def step(self, action, indices=None, name=None):
        """Step the wrapped environment, then reset the entries that finished.

        Args:
            action: Forwarded to ``self.env.step``.
            indices: Batch indices to step. Forwarded to ``self.env.step``
                as given (including ``None``). When ``None``, the
                full range ``0 .. self.batch_size - 1`` (int32) is used to
                select which entries to reset.
                NOTE(review): ``self.batch_size`` is assumed to come from
                ``BaseWrapper`` — confirm.
            name: Forwarded to ``self.env.step`` only.

        Returns:
            ``(reward, done)`` from the wrapped env, re-emitted through
            ``tf.identity`` under a control dependency on the reset op.
        """
        reward, terminal = self.env.step(action=action, indices=indices, name=name)

        # Default to the whole batch only *after* the env call, so the wrapped
        # env still sees the caller's original ``indices`` argument.
        batch_indices = (
            np.arange(self.batch_size, dtype=np.int32) if indices is None else indices
        )

        # ``terminal`` is used as a boolean mask over ``batch_indices``; this
        # presumably requires it to be a bool tensor aligned with them.
        finished_indices = tf.boolean_mask(batch_indices, terminal)

        # Build the reset before entering the dependency scope (the original
        # also evaluates it as the context-manager argument, i.e. beforehand).
        reset_op = self.reset(finished_indices, max_frames=self.max_frames)

        with tf.control_dependencies([reset_op]):
            # tf.identity creates new ops inside the scope, so evaluating the
            # returned tensors forces the reset to execute first.
            return tf.identity(reward), tf.identity(terminal)