# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for RL training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math
import random

from gym.spaces import Box
import numpy as np
import six

from tensor2tensor.data_generators.gym_env import T2TGymEnv
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import common_video
from tensor2tensor.models.research import rl
from tensor2tensor.rl.dopamine_connector import DQNLearner
from tensor2tensor.rl.envs.simulated_batch_env import PIL_Image
from tensor2tensor.rl.envs.simulated_batch_env import PIL_ImageDraw
from tensor2tensor.rl.envs.simulated_batch_gym_env import SimulatedBatchGymEnv
from tensor2tensor.rl.ppo_learner import PPOLearner
from tensor2tensor.utils import misc_utils
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf


def compute_mean_reward(rollouts, clipped):
  """Calculate mean rewards from a given epoch."""
  reward_name = "reward" if clipped else "unclipped_reward"
  rewards = []
  for rollout in rollouts:
    if rollout[-1].done:
      rollout_reward = sum(getattr(frame, reward_name) for frame in rollout)
      rewards.append(rollout_reward)
  if rewards:
    mean_rewards = np.mean(rewards)
  else:
    mean_rewards = 0
  return mean_rewards


def get_metric_name(sampling_temp, max_num_noops, clipped):
  return "mean_reward/eval/sampling_temp_{}_max_noops_{}_{}".format(
      sampling_temp, max_num_noops, "clipped" if clipped else "unclipped"
  )


def _eval_fn_with_learner(
    env, hparams, policy_hparams, policy_dir, sampling_temp
):
  env_fn = rl.make_real_env_fn(env)
  learner = LEARNERS[hparams.base_algo](
      hparams.frame_stack_size, base_event_dir=None,
      agent_model_dir=policy_dir, total_num_epochs=1
  )
  learner.evaluate(env_fn, policy_hparams, sampling_temp)


def evaluate_single_config(
    hparams, sampling_temp, max_num_noops, agent_model_dir,
    eval_fn=_eval_fn_with_learner
):
  """Evaluate the PPO agent in the real environment."""
  tf.logging.info("Evaluating metric %s", get_metric_name(
      sampling_temp, max_num_noops, clipped=False
  ))

  eval_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
  env = setup_env(
      hparams, batch_size=hparams.eval_batch_size,
      max_num_noops=max_num_noops,
      rl_env_max_episode_steps=hparams.eval_rl_env_max_episode_steps,
      env_name=hparams.rl_env_name)
  env.start_new_epoch(0)
  eval_fn(env, hparams, eval_hparams, agent_model_dir, sampling_temp)
  rollouts = env.current_epoch_rollouts()
  env.close()

  return tuple(
      compute_mean_reward(rollouts, clipped) for clipped in (True, False)
  )


def evaluate_all_configs(
    hparams, agent_model_dir, eval_fn=_eval_fn_with_learner
):
  """Evaluate the agent with multiple eval configurations."""
  metrics = {}
  # Iterate over all combinations of sampling temperatures and whether to do
  # initial no-ops.
  for sampling_temp in hparams.eval_sampling_temps:
    # Iterate over a set so if eval_max_num_noops == 0 then it's 1 iteration.
    for max_num_noops in set([hparams.eval_max_num_noops, 0]):
      scores = evaluate_single_config(
          hparams, sampling_temp, max_num_noops, agent_model_dir, eval_fn
      )
      for (score, clipped) in zip(scores, (True, False)):
        metric_name = get_metric_name(sampling_temp, max_num_noops, clipped)
        metrics[metric_name] = score

  return metrics


def evaluate_world_model(
    real_env, hparams, world_model_dir, debug_video_path,
    split=tf.estimator.ModeKeys.EVAL,
):
  """Evaluate the world model (reward accuracy)."""
  frame_stack_size = hparams.frame_stack_size
  rollout_subsequences = []
  def initial_frame_chooser(batch_size):
    assert batch_size == len(rollout_subsequences)
    return np.stack([
        [frame.observation.decode() for frame in subsequence[:frame_stack_size]]  # pylint: disable=g-complex-comprehension
        for subsequence in rollout_subsequences
    ])

  env_fn = rl.make_simulated_env_fn_from_hparams(
      real_env, hparams, batch_size=hparams.wm_eval_batch_size,
      initial_frame_chooser=initial_frame_chooser, model_dir=world_model_dir
  )
  sim_env = env_fn(in_graph=False)
  subsequence_length = int(
      max(hparams.wm_eval_rollout_ratios) * hparams.simulated_rollout_length
  )
  rollouts = real_env.current_epoch_rollouts(
      split=split,
      minimal_rollout_frames=(subsequence_length + frame_stack_size)
  )

  video_writer = common_video.WholeVideoWriter(
      fps=10, output_path=debug_video_path, file_format="avi"
  )

  reward_accuracies_by_length = {
      int(ratio * hparams.simulated_rollout_length): []
      for ratio in hparams.wm_eval_rollout_ratios
  }
  for _ in range(hparams.wm_eval_num_batches):
    rollout_subsequences[:] = random_rollout_subsequences(
        rollouts, hparams.wm_eval_batch_size,
        subsequence_length + frame_stack_size
    )

    eval_subsequences = [
        subsequence[(frame_stack_size - 1):]
        for subsequence in rollout_subsequences
    ]

    # Check that the initial observation is the same in the real and simulated
    # rollout.
    sim_init_obs = sim_env.reset()
    def decode_real_obs(index):
      return np.stack([
          subsequence[index].observation.decode()
          for subsequence in eval_subsequences  # pylint: disable=cell-var-from-loop
      ])
    real_init_obs = decode_real_obs(0)
    assert np.all(sim_init_obs == real_init_obs)

    debug_frame_batches = []
    def append_debug_frame_batch(sim_obs, real_obs, sim_cum_rews,
                                 real_cum_rews, sim_rews, real_rews):
      """Add a debug frame."""
      rews = [[sim_cum_rews, sim_rews], [real_cum_rews, real_rews]]
      headers = []
      for j in range(len(sim_obs)):
        local_nps = []
        for i in range(2):
          img = PIL_Image().new("RGB", (sim_obs.shape[-2], 11),)
          draw = PIL_ImageDraw().Draw(img)
          draw.text((0, 0), "c:{:3}, r:{:3}".format(int(rews[i][0][j]),
                                                    int(rews[i][1][j])),
                    fill=(255, 0, 0))
          local_nps.append(np.asarray(img))
        local_nps.append(np.zeros_like(local_nps[0]))
        headers.append(np.concatenate(local_nps, axis=1))
      errs = absolute_hinge_difference(sim_obs, real_obs)
      headers = np.stack(headers)
      debug_frame_batches.append(  # pylint: disable=cell-var-from-loop
          np.concatenate([headers,
                          np.concatenate([sim_obs, real_obs, errs], axis=2)],
                         axis=1)
      )
    append_debug_frame_batch(sim_init_obs, real_init_obs,
                             np.zeros(hparams.wm_eval_batch_size),
                             np.zeros(hparams.wm_eval_batch_size),
                             np.zeros(hparams.wm_eval_batch_size),
                             np.zeros(hparams.wm_eval_batch_size))

    (sim_cum_rewards, real_cum_rewards) = (
        np.zeros(hparams.wm_eval_batch_size) for _ in range(2)
    )
    for i in range(subsequence_length):
      actions = [subsequence[i].action for subsequence in eval_subsequences]
      (sim_obs, sim_rewards, _) = sim_env.step(actions)
      sim_cum_rewards += sim_rewards

      real_rewards = np.array([
          subsequence[i + 1].reward for subsequence in eval_subsequences
      ])
      real_cum_rewards += real_rewards
      for (length, reward_accuracies) in six.iteritems(
          reward_accuracies_by_length
      ):
        if i + 1 == length:
          reward_accuracies.append(
              np.sum(sim_cum_rewards == real_cum_rewards) /
              len(real_cum_rewards)
          )

      real_obs = decode_real_obs(i + 1)
      append_debug_frame_batch(sim_obs, real_obs, sim_cum_rewards,
                               real_cum_rewards, sim_rewards, real_rewards)

    for debug_frames in np.stack(debug_frame_batches, axis=1):
      debug_frame = None
      for debug_frame in debug_frames:
        video_writer.write(debug_frame)

      if debug_frame is not None:
        # Append two black frames for aesthetics.
        for _ in range(2):
          video_writer.write(np.zeros_like(debug_frame))

  video_writer.finish_to_disk()

  return {
      "reward_accuracy/at_{}".format(length): np.mean(reward_accuracies)
      for (length, reward_accuracies) in six.iteritems(
          reward_accuracies_by_length
      )
  }


def summarize_metrics(eval_metrics_writer, metrics, epoch):
  """Write metrics to summary."""
  for (name, value) in six.iteritems(metrics):
    summary = tf.Summary()
    summary.value.add(tag=name, simple_value=value)
    eval_metrics_writer.add_summary(summary, epoch)
  eval_metrics_writer.flush()


LEARNERS = {
    "ppo": PPOLearner,
    "dqn": DQNLearner,
}


ATARI_GAME_MODE = "NoFrameskip-v4"


def full_game_name(short_name):
  """CamelCase game name with mode suffix.

  Args:
    short_name: snake_case name without mode, e.g. "crazy_climber".

  Returns:
"CrazyClimberNoFrameskip-v4" """ camel_game_name = misc_utils.snakecase_to_camelcase(short_name) full_name = camel_game_name + ATARI_GAME_MODE return full_name def should_apply_max_and_skip_env(hparams): """MaxAndSkipEnv doesn't make sense for some games, so omit it if needed.""" return hparams.game != "tictactoe" def setup_env(hparams, batch_size, max_num_noops, rl_env_max_episode_steps=-1, env_name=None): """Setup.""" if not env_name: env_name = full_game_name(hparams.game) maxskip_envs = should_apply_max_and_skip_env(hparams) env = T2TGymEnv( base_env_name=env_name, batch_size=batch_size, grayscale=hparams.grayscale, should_derive_observation_space=hparams .rl_should_derive_observation_space, resize_width_factor=hparams.resize_width_factor, resize_height_factor=hparams.resize_height_factor, rl_env_max_episode_steps=rl_env_max_episode_steps, max_num_noops=max_num_noops, maxskip_envs=maxskip_envs, sticky_actions=hparams.sticky_actions ) return env def update_hparams_from_hparams(target_hparams, source_hparams, prefix): """Copy a subset of hparams to target_hparams.""" for (param_name, param_value) in six.iteritems(source_hparams.values()): if param_name.startswith(prefix): target_hparams.set_hparam(param_name[len(prefix):], param_value) def random_rollout_subsequences(rollouts, num_subsequences, subsequence_length): """Chooses a random frame sequence of given length from a set of rollouts.""" def choose_subsequence(): # TODO(koz4k): Weigh rollouts by their lengths so sampling is uniform over # frames and not rollouts. rollout = random.choice(rollouts) try: from_index = random.randrange(len(rollout) - subsequence_length + 1) except ValueError: # Rollout too short; repeat. return choose_subsequence() return rollout[from_index:(from_index + subsequence_length)] return [choose_subsequence() for _ in range(num_subsequences)] def make_initial_frame_chooser( real_env, frame_stack_size, simulation_random_starts, simulation_flip_first_random_for_beginning, split=tf.estimator.ModeKeys.TRAIN, ): """Make frame chooser. Args: real_env: T2TEnv to take initial frames from. frame_stack_size (int): Number of consecutive frames to extract. simulation_random_starts (bool): Whether to choose frames at random. simulation_flip_first_random_for_beginning (bool): Whether to flip the first frame stack in every batch for the frames at the beginning. split (tf.estimator.ModeKeys or None): Data split to take the frames from, None means use all frames. Returns: Function batch_size -> initial_frames. """ initial_frame_rollouts = real_env.current_epoch_rollouts( split=split, minimal_rollout_frames=frame_stack_size, ) def initial_frame_chooser(batch_size): """Frame chooser.""" deterministic_initial_frames =\ initial_frame_rollouts[0][:frame_stack_size] if not simulation_random_starts: # Deterministic starts: repeat first frames from the first rollout. initial_frames = [deterministic_initial_frames] * batch_size else: # Random starts: choose random initial frames from random rollouts. initial_frames = random_rollout_subsequences( initial_frame_rollouts, batch_size, frame_stack_size ) if simulation_flip_first_random_for_beginning: # Flip first entry in the batch for deterministic initial frames. 
        initial_frames[0] = deterministic_initial_frames

    return np.stack([
        [frame.observation.decode() for frame in initial_frame_stack]  # pylint: disable=g-complex-comprehension
        for initial_frame_stack in initial_frames
    ])
  return initial_frame_chooser


def absolute_hinge_difference(arr1, arr2, min_diff=10, dtype=np.uint8):
  """Point-wise, hinge-loss-like difference between arrays.

  Args:
    arr1: integer array to compare.
    arr2: integer array to compare.
    min_diff: minimal difference taken into consideration.
    dtype: dtype of returned array.

  Returns:
    Array of clipped absolute differences, cast to `dtype`.
  """
  diff = np.abs(arr1.astype(np.int64) - arr2, dtype=np.int64)
  return np.maximum(diff - min_diff, 0).astype(dtype)


# TODO(koz4k): Use this function in player and all debug videos.
def augment_observation(
    observation, reward, cum_reward, frame_index, bar_color=None,
    header_height=27
):
  """Augments an observation with debug info."""
  img = PIL_Image().new(
      "RGB", (observation.shape[1], header_height,)
  )
  draw = PIL_ImageDraw().Draw(img)
  draw.text(
      (1, 0), "c:{:3}, r:{:3}".format(int(cum_reward), int(reward)),
      fill=(255, 0, 0)
  )
  draw.text(
      (1, 15), "f:{:3}".format(int(frame_index)),
      fill=(255, 0, 0)
  )
  header = np.copy(np.asarray(img))
  del img
  if bar_color is not None:
    header[0, :, :] = bar_color
  return np.concatenate([header, observation], axis=0)


def run_rollouts(
    env, agent, initial_observations, step_limit=None, discount_factor=1.0,
    log_every_steps=None, video_writers=(), color_bar=False,
    many_rollouts_from_each_env=False
):
  """Runs a batch of rollouts from given initial observations."""
  assert step_limit is not None or not many_rollouts_from_each_env, (
      "When collecting many rollouts from each environment, time limit must "
      "be set."
  )

  num_dones = 0
  first_dones = np.array([False] * env.batch_size)
  observations = initial_observations
  step_index = 0
  cum_rewards = np.zeros(env.batch_size)

  for (video_writer, obs_stack) in zip(video_writers, initial_observations):
    for (i, ob) in enumerate(obs_stack):
      debug_frame = augment_observation(
          ob, reward=0, cum_reward=0,
          frame_index=(-len(obs_stack) + i + 1),
          bar_color=((0, 255, 0) if color_bar else None)
      )
      video_writer.write(debug_frame)

  def proceed():
    if step_index < step_limit:
      return num_dones < env.batch_size or many_rollouts_from_each_env
    else:
      return False

  while proceed():
    act_kwargs = {}
    if agent.needs_env_state:
      act_kwargs["env_state"] = env.state
    actions = agent.act(observations, **act_kwargs)
    (observations, rewards, dones) = env.step(actions)
    observations = list(observations)
    now_done_indices = []
    for (i, done) in enumerate(dones):
      if done and (not first_dones[i] or many_rollouts_from_each_env):
        now_done_indices.append(i)
        first_dones[i] = True
        num_dones += 1
    if now_done_indices:
      # Unless many_rollouts_from_each_env, reset only envs done the first time
      # in this timestep to ensure that we collect exactly 1 rollout from each
      # env.
      reset_observations = env.reset(now_done_indices)
      for (i, observation) in zip(now_done_indices, reset_observations):
        observations[i] = observation
    observations = np.array(observations)
    cum_rewards[~first_dones] = (
        cum_rewards[~first_dones] * discount_factor + rewards[~first_dones]
    )
    step_index += 1

    for (video_writer, obs_stack, reward, cum_reward, done) in zip(
        video_writers, observations, rewards, cum_rewards, first_dones
    ):
      if done:
        continue
      ob = obs_stack[-1]
      debug_frame = augment_observation(
          ob, reward=reward, cum_reward=cum_reward,
          frame_index=step_index,
          bar_color=((255, 0, 0) if color_bar else None)
      )
      video_writer.write(debug_frame)

    # TODO(afrozm): Clean this up with tf.logging.log_every_n
    if log_every_steps is not None and step_index % log_every_steps == 0:
      tf.logging.info("Step %d, mean_score: %f", step_index,
                      cum_rewards.mean())

  return (observations, cum_rewards)


class BatchAgent(object):
  """Python API for agents.

  Runs a batch of parallel agents. Operates on Numpy arrays.
  """

  needs_env_state = False
  records_own_videos = False

  def __init__(self, batch_size, observation_space, action_space):
    self.batch_size = batch_size
    self.observation_space = observation_space
    self.action_space = action_space

  def act(self, observations, env_state=None):
    """Picks actions based on observations.

    Args:
      observations: A batch of observations.
      env_state: State.

    Returns:
      A batch of actions.
    """
    raise NotImplementedError

  def estimate_value(self, observations):
    """Estimates values of states based on observations.

    Used for temporal-difference planning.

    Args:
      observations: A batch of observations.

    Returns:
      A batch of values.
    """
    raise NotImplementedError

  def action_distribution(self, observations):
    """Calculates action distribution based on observations.

    Used for temporal-difference planning.

    Args:
      observations: A batch of observations.

    Returns:
      A batch of action probabilities.
""" raise NotImplementedError class RandomAgent(BatchAgent): """Random agent, sampling actions from the uniform distribution.""" def act(self, observations, env_state=None): del env_state return np.array([ self.action_space.sample() for _ in range(observations.shape[0]) ]) def estimate_value(self, observations): return np.zeros(observations.shape[0]) def action_distribution(self, observations): return np.full( (observations.shape[0], self.action_space.n), 1.0 / self.action_space.n ) class PolicyAgent(BatchAgent): """Agent based on a policy network.""" def __init__( self, batch_size, observation_space, action_space, policy_hparams, policy_dir, sampling_temp ): super(PolicyAgent, self).__init__( batch_size, observation_space, action_space ) self._sampling_temp = sampling_temp with tf.Graph().as_default(): self._observations_t = tf.placeholder( shape=((batch_size,) + self.observation_space.shape), dtype=self.observation_space.dtype ) (logits, self._values_t) = rl.get_policy( self._observations_t, policy_hparams, self.action_space ) actions = common_layers.sample_with_temperature(logits, sampling_temp) self._probs_t = tf.nn.softmax(logits / sampling_temp) self._actions_t = tf.cast(actions, tf.int32) model_saver = tf.train.Saver( tf.global_variables(policy_hparams.policy_network + "/.*") # pylint: disable=unexpected-keyword-arg ) self._sess = tf.Session() self._sess.run(tf.global_variables_initializer()) trainer_lib.restore_checkpoint(policy_dir, model_saver, self._sess) def _run(self, observations): return self._sess.run( [self._actions_t, self._values_t, self._probs_t], feed_dict={self._observations_t: observations} ) def act(self, observations, env_state=None): del env_state (actions, _, _) = self._run(observations) return actions def estimate_value(self, observations): (_, values, _) = self._run(observations) return values def action_distribution(self, observations): (_, _, probs) = self._run(observations) return probs class PlannerAgent(BatchAgent): """Agent based on temporal difference planning.""" needs_env_state = True records_own_videos = True def __init__( self, batch_size, rollout_agent, sim_env, wrapper_fn, num_rollouts, planning_horizon, discount_factor=1.0, uct_const=0, uniform_first_action=True, normalizer_window_size=30, normalizer_epsilon=0.001, video_writers=(), ): super(PlannerAgent, self).__init__( batch_size, rollout_agent.observation_space, rollout_agent.action_space ) self._rollout_agent = rollout_agent self._sim_env = sim_env self._wrapped_env = wrapper_fn(sim_env) self._num_rollouts = num_rollouts self._num_batches = num_rollouts // rollout_agent.batch_size self._discount_factor = discount_factor self._planning_horizon = planning_horizon self._uct_const = uct_const self._uniform_first_action = uniform_first_action self._normalizer_window_size = normalizer_window_size self._normalizer_epsilon = normalizer_epsilon self._video_writers = video_writers self._best_mc_values = [[] for _ in range(self.batch_size)] def act(self, observations, env_state=None): def run_batch_from(observation, planner_index, batch_index): """Run a batch of actions.""" repeated_observation = np.array( [observation] * self._wrapped_env.batch_size ) actions = self._get_first_actions(repeated_observation) self._wrapped_env.set_initial_state( initial_state=[ copy.deepcopy(env_state[planner_index]) for _ in range(self._sim_env.batch_size) ], initial_frames=repeated_observation ) self._wrapped_env.reset() (initial_observations, initial_rewards, _) = self._wrapped_env.step( actions ) video_writers = () if 
      if planner_index < len(self._video_writers) and batch_index == 0:
        video_writers = (self._video_writers[planner_index],)
      (final_observations, cum_rewards) = run_rollouts(
          self._wrapped_env, self._rollout_agent, initial_observations,
          discount_factor=self._discount_factor,
          step_limit=self._planning_horizon,
          video_writers=video_writers, color_bar=True)
      values = self._rollout_agent.estimate_value(final_observations)
      total_values = (
          initial_rewards + self._discount_factor * cum_rewards +
          self._discount_factor ** (self._planning_horizon + 1) * values
      )
      return list(zip(actions, total_values))

    def run_batches_from(observation, planner_index):
      sums = {a: 0 for a in range(self.action_space.n)}
      counts = copy.copy(sums)
      for i in range(self._num_batches):
        for (action, total_value) in run_batch_from(
            observation, planner_index, i
        ):
          sums[action] += total_value
          counts[action] += 1
      return {a: (sums[a], counts[a]) for a in sums}

    def choose_best_action(observation, planner_index):
      """Choose the best action, update best Monte Carlo values."""
      best_mc_values = self._best_mc_values[planner_index]
      action_probs = self._rollout_agent.action_distribution(
          np.array([observation] * self._rollout_agent.batch_size)
      )[0, :]
      sums_and_counts = run_batches_from(observation, planner_index)
      def monte_carlo_value(action):
        (value_sum, count) = sums_and_counts[action]
        if count > 0:
          mean_value = value_sum / count
        else:
          mean_value = -np.inf
        return mean_value
      mc_values = np.array(
          [monte_carlo_value(action) for action in range(self.action_space.n)]
      )
      best_mc_values.append(mc_values.max())
      normalizer = max(
          np.std(best_mc_values[-self._normalizer_window_size:]),
          self._normalizer_epsilon
      )
      normalized_mc_values = mc_values / normalizer
      uct_bonuses = np.array(
          [self._uct_bonus(sums_and_counts[action][1], action_probs[action])
           for action in range(self.action_space.n)]
      )
      values = normalized_mc_values + uct_bonuses
      return np.argmax(values)

    return np.array([
        choose_best_action(observation, i)
        for (i, observation) in enumerate(observations)
    ])

  def _uct_bonus(self, count, prob):
    return self._uct_const * prob * math.sqrt(
        math.log(self._num_rollouts) / (1 + count)
    )

  def _get_first_actions(self, observations):
    if self._uniform_first_action:
      return np.array([
          int(x) for x in np.linspace(
              0, self.action_space.n, self._rollout_agent.batch_size + 1
          )
      ])[:self._rollout_agent.batch_size]
    else:
      return list(sorted(self._rollout_agent.act(observations)))


# TODO(koz4k): Unify interfaces of batch envs.
class BatchWrapper(object):
  """Base class for batch env wrappers."""

  def __init__(self, env):
    self.env = env
    self.batch_size = env.batch_size
    self.observation_space = env.observation_space
    self.action_space = env.action_space
    self.reward_range = env.reward_range

  def reset(self, indices=None):
    return self.env.reset(indices)

  def step(self, actions):
    return self.env.step(actions)

  def close(self):
    self.env.close()


class BatchStackWrapper(BatchWrapper):
  """Out-of-graph batch stack wrapper.

  Its behavior is consistent with tf_atari_wrappers.StackWrapper.
""" def __init__(self, env, stack_size): super(BatchStackWrapper, self).__init__(env) self.stack_size = stack_size inner_space = env.observation_space self.observation_space = Box( low=np.array([inner_space.low] * self.stack_size), high=np.array([inner_space.high] * self.stack_size), dtype=inner_space.dtype, ) self._history_buffer = np.zeros( (self.batch_size,) + self.observation_space.shape, dtype=inner_space.dtype ) self._initial_frames = None @property def state(self): """Gets the current state.""" return self.env.state def set_initial_state(self, initial_state, initial_frames): """Sets the state that will be used on next reset.""" self.env.set_initial_state(initial_state, initial_frames) self._initial_frames = initial_frames def reset(self, indices=None): if indices is None: indices = range(self.batch_size) observations = self.env.reset(indices) try: # If we wrap the simulated env, take the initial frames from there. assert self.env.initial_frames.shape[1] == self.stack_size self._history_buffer[...] = self.env.initial_frames except AttributeError: # Otherwise, check if set_initial_state was called and we can take the # frames from there. if self._initial_frames is not None: for (index, observation) in zip(indices, observations): assert (self._initial_frames[index, -1, ...] == observation).all() self._history_buffer[index, ...] = self._initial_frames[index, ...] else: # Otherwise, repeat the first observation stack_size times. for (index, observation) in zip(indices, observations): self._history_buffer[index, ...] = [observation] * self.stack_size return self._history_buffer def step(self, actions): (observations, rewards, dones) = self.env.step(actions) self._history_buffer = np.roll(self._history_buffer, shift=-1, axis=1) self._history_buffer[:, -1, ...] = observations return (self._history_buffer, rewards, dones) class SimulatedBatchGymEnvWithFixedInitialFrames(BatchWrapper): """Wrapper for SimulatedBatchGymEnv that allows to fix initial frames.""" def __init__(self, *args, **kwargs): self.initial_frames = None def initial_frame_chooser(batch_size): assert batch_size == self.initial_frames.shape[0] return self.initial_frames env = SimulatedBatchGymEnv( *args, initial_frame_chooser=initial_frame_chooser, **kwargs ) super(SimulatedBatchGymEnvWithFixedInitialFrames, self).__init__(env) @property def state(self): """Gets the current state.""" return [None] * self.batch_size def set_initial_state(self, initial_state, initial_frames): """Sets the state that will be used on next reset.""" del initial_state self.initial_frames = initial_frames