from __future__ import absolute_import from __future__ import division from __future__ import print_function import gin import numpy as np import tensorflow as tf import tensorflow_probability as tfp from tf_agents.agents import tf_agent from tf_agents.networks import network from tf_agents.policies import tf_policy from tf_agents.specs import tensor_spec from tf_agents.trajectories import policy_step from tf_agents.trajectories import time_step as ts from tf_agents.trajectories import trajectory from tf_agents.utils import common from tf_agents.utils import eager_utils from tf_agents.utils import nest_utils from slac.utils import common as slac_common from slac.utils import gif_utils from slac.utils import nest_utils as slac_nest_utils tfd = tfp.distributions def _gif_summary(name, images, fps, saturate=False, step=None): images = tf.image.convert_image_dtype(images, tf.uint8, saturate=saturate) output = tf.concat(tf.unstack(images), axis=2)[None] gif_utils.gif_summary_v2(name, output, 1, fps, step=step) def _gif_and_image_summary(name, images, fps, saturate=False, step=None): images = tf.image.convert_image_dtype(images, tf.uint8, saturate=saturate) output = tf.concat(tf.unstack(images), axis=2)[None] gif_utils.gif_summary_v2(name, output, 1, fps, step=step) output = tf.concat(tf.unstack(images), axis=2) output = tf.concat(tf.unstack(output), axis=0)[None] tf.contrib.summary.image(name, output, step=step) def filter_before_first_step(time_steps, actions=None): flat_time_steps = tf.nest.flatten(time_steps) flat_time_steps = [tf.unstack(time_step, axis=1) for time_step in flat_time_steps] time_steps = [tf.nest.pack_sequence_as(time_steps, time_step) for time_step in zip(*flat_time_steps)] if actions is None: actions = [None] * len(time_steps) else: actions = tf.unstack(actions, axis=1) assert len(time_steps) == len(actions) time_steps = list(reversed(time_steps)) actions = list(reversed(actions)) filtered_time_steps = [] filtered_actions = [] for t, (time_step, action) in enumerate(zip(time_steps, actions)): if t == 0: reset_mask = tf.equal(time_step.step_type, ts.StepType.FIRST) else: time_step = tf.nest.map_structure(lambda x, y: tf.where(reset_mask, x, y), last_time_step, time_step) action = tf.where(reset_mask, tf.zeros_like(action), action) if action is not None else None filtered_time_steps.append(time_step) filtered_actions.append(action) reset_mask = tf.logical_or( reset_mask, tf.equal(time_step.step_type, ts.StepType.FIRST)) last_time_step = time_step filtered_time_steps = list(reversed(filtered_time_steps)) filtered_actions = list(reversed(filtered_actions)) filtered_flat_time_steps = [tf.nest.flatten(time_step) for time_step in filtered_time_steps] filtered_flat_time_steps = [tf.stack(time_step, axis=1) for time_step in zip(*filtered_flat_time_steps)] filtered_time_steps = tf.nest.pack_sequence_as(filtered_time_steps[0], filtered_flat_time_steps) if action is None: return filtered_time_steps else: actions = tf.stack(filtered_actions, axis=1) return filtered_time_steps, actions class ActorSequencePolicy(tf_policy.Base): def __init__(self, time_step_spec=None, action_spec=None, info_spec=(), actor_network=None, model_network=None, compressor_network=None, sequence_length=2, actor_input='state', control_timestep=None, num_images_per_summary=1, debug_summaries=False, name=None): if not isinstance(actor_network, network.Network): raise ValueError('actor_network must be a network.Network. Found ' '{}.'.format(type(actor_network))) self._actor_network = actor_network self._model_network = model_network self._compressor_network = compressor_network self._sequence_length = sequence_length self._actor_input = actor_input self._control_timestep = control_timestep self._num_images_per_summary = num_images_per_summary self._debug_summaries = debug_summaries def _add_time_dimension(spec): return tensor_spec.TensorSpec( (sequence_length,) + tuple(spec.shape), spec.dtype, spec.name) time_steps_spec = tf.nest.map_structure( _add_time_dimension, time_step_spec) actions_spec = tf.nest.map_structure( _add_time_dimension, action_spec) policy_state_spec = ( actor_network.state_spec, time_steps_spec, actions_spec) super(ActorSequencePolicy, self).__init__( time_step_spec=time_step_spec, action_spec=action_spec, policy_state_spec=policy_state_spec, info_spec=info_spec, name=name) def _apply_actor_network(self, time_steps, actions, network_state): states = [] for actor_input in self._actor_input.split('__'): if actor_input == 'state': state = time_steps.observation['state'][:, -1] elif actor_input == 'latent': images = tf.image.convert_image_dtype( time_steps.observation['pixels'], tf.float32) latents, _ = self._model_network.sample_posterior( images, slac_common.flatten(actions, axis=2), time_steps.step_type) if isinstance(latents, (tuple, list)): latents = tf.concat(latents, axis=-1) state = latents[:, -1] elif actor_input == 'feature': image = tf.image.convert_image_dtype( time_steps.observation['pixels'][:, -1], tf.float32) state = self._compressor_network(image) elif actor_input in ('sequence_feature', 'sequence_action_feature'): filtered_time_steps, filtered_actions = filter_before_first_step( time_steps, actions) images = tf.image.convert_image_dtype( filtered_time_steps.observation['pixels'], tf.float32) features = self._compressor_network(images) sequence_feature = slac_common.flatten(features) if actor_input == 'sequence_action_feature': sequence_action = slac_common.flatten(filtered_actions[:, :-1]) state = tf.concat([sequence_feature, sequence_action], axis=-1) else: state = sequence_feature else: raise NotImplementedError states.append(state) state = tf.concat(states, axis=-1) if self._debug_summaries: filtered_time_steps, filtered_actions = filter_before_first_step( time_steps, actions) images = tf.image.convert_image_dtype( filtered_time_steps.observation['pixels'], tf.float32) fps = 10 if self._control_timestep is None else int( np.round(1.0 / self._control_timestep)) _gif_and_image_summary('ActorSequencePolicy/images', images[:self._num_images_per_summary], fps, step=self.train_step_counter) step_type = time_steps.step_type[:, -1] return self._actor_network(state, step_type, network_state) def _variables(self): variables = list(self._actor_network.variables) actor_inputs = set(self._actor_input.split('__')) if 'latent' in actor_inputs: variables += self._model_network.variables if {'feature', 'sequence_feature', 'sequence_action_feature'} & actor_inputs: variables += self._compressor_network.variables return variables def _action(self, time_step, policy_state, seed): distribution_step = self.distribution(time_step, policy_state) action = distribution_step.action.sample(seed=seed) policy_state = distribution_step.state network_state, time_steps, actions = policy_state actions = tf.concat([actions[:, :-1], action[:, None]], axis=1) policy_state = network_state, time_steps, actions return distribution_step._replace(action=action, state=policy_state) def _distribution(self, time_step, policy_state): network_state, time_steps, actions = policy_state def _apply_sequence_update(tensors, tensor): return tf.concat([tensors, tensor[:, None]], axis=1)[:, 1:] time_steps = tf.nest.map_structure( _apply_sequence_update, time_steps, time_step) actions = tf.nest.map_structure( _apply_sequence_update, actions, tf.zeros_like(actions[:, 0])) # Actor network outputs nested structure of distributions or actions. action_or_distribution, network_state = self._apply_actor_network( time_steps, actions, network_state) policy_state = (network_state, time_steps, actions) def _to_distribution(action_or_distribution): if isinstance(action_or_distribution, tf.Tensor): # This is an action tensor, so wrap it in a deterministic distribution. return tfp.distributions.Deterministic(loc=action_or_distribution) return action_or_distribution distribution = tf.nest.map_structure(_to_distribution, action_or_distribution) return policy_step.PolicyStep(distribution, policy_state) @gin.configurable class SlacAgent(tf_agent.TFAgent): """A SLAC Agent.""" def __init__(self, time_step_spec, action_spec, critic_network, actor_network, model_network, compressor_network, actor_optimizer, critic_optimizer, alpha_optimizer, model_optimizer, sequence_length, target_update_tau=1.0, target_update_period=1, td_errors_loss_fn=tf.math.squared_difference, gamma=1.0, reward_scale_factor=1.0, initial_log_alpha=0.0, target_entropy=None, gradient_clipping=None, trainable_model=True, critic_input='state', actor_input='state', critic_input_stop_gradient=True, actor_input_stop_gradient=False, model_batch_size=None, control_timestep=None, num_images_per_summary=1, debug_summaries=False, summarize_grads_and_vars=False, train_step_counter=None, name=None): tf.Module.__init__(self, name=name) self._critic_network1 = critic_network self._critic_network2 = critic_network.copy(name='CriticNetwork2') self._target_critic_network1 = critic_network.copy( name='TargetCriticNetwork1') self._target_critic_network2 = critic_network.copy( name='TargetCriticNetwork2') self._actor_network = actor_network self._model_network = model_network self._compressor_network = compressor_network policy = ActorSequencePolicy( time_step_spec=time_step_spec, action_spec=action_spec, actor_network=self._actor_network, model_network=self._model_network, compressor_network=self._compressor_network, sequence_length=sequence_length, actor_input=actor_input, control_timestep=control_timestep, num_images_per_summary=num_images_per_summary, debug_summaries=debug_summaries) self._log_alpha = common.create_variable( 'initial_log_alpha', initial_value=initial_log_alpha, dtype=tf.float32, trainable=True) # If target_entropy was not passed, set it to negative of the total number # of action dimensions. if target_entropy is None: flat_action_spec = tf.nest.flatten(action_spec) target_entropy = -np.sum([ np.product(single_spec.shape.as_list()) for single_spec in flat_action_spec ]) self._target_update_tau = target_update_tau self._target_update_period = target_update_period self._actor_optimizer = actor_optimizer self._critic_optimizer = critic_optimizer self._alpha_optimizer = alpha_optimizer self._model_optimizer = model_optimizer self._sequence_length = sequence_length self._td_errors_loss_fn = td_errors_loss_fn self._gamma = gamma self._reward_scale_factor = reward_scale_factor self._target_entropy = target_entropy self._gradient_clipping = gradient_clipping self._trainable_model = trainable_model self._critic_input = critic_input self._actor_input = actor_input self._critic_input_stop_gradient = critic_input_stop_gradient self._actor_input_stop_gradient = actor_input_stop_gradient self._model_batch_size = model_batch_size self._control_timestep = control_timestep self._num_images_per_summary = num_images_per_summary self._debug_summaries = debug_summaries self._summarize_grads_and_vars = summarize_grads_and_vars self._update_target = self._get_target_updater( tau=self._target_update_tau, period=self._target_update_period) self._actor_time_step_spec = time_step_spec._replace( observation=actor_network.input_tensor_spec) super(SlacAgent, self).__init__( time_step_spec, action_spec, policy=policy, collect_policy=policy, train_sequence_length=sequence_length + 1, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step_counter) self._train_model_fn = common.function_in_tf1()(self._train_model) def _initialize(self): """Returns an op to initialize the agent. Copies weights from the Q networks to the target Q network. """ common.soft_variables_update( self._critic_network1.variables, self._target_critic_network1.variables, tau=1.0) common.soft_variables_update( self._critic_network2.variables, self._target_critic_network2.variables, tau=1.0) def _experience_to_transitions(self, experience): transitions = trajectory.to_transition(experience) time_steps, policy_steps, next_time_steps = transitions actions = policy_steps.action if (self.train_sequence_length is not None and self.train_sequence_length == 2): # Sequence empty time dimension if critic network is stateless. time_steps, actions, next_time_steps = tf.nest.map_structure( lambda t: tf.squeeze(t, axis=1), (time_steps, actions, next_time_steps)) return time_steps, actions, next_time_steps def _train(self, experience, weights=None): """Returns a train op to update the agent's networks. This method trains with the provided batched experience. Args: experience: A time-stacked trajectory object. weights: Optional scalar or elementwise (per-batch-entry) importance weights. Returns: A train_op. Raises: ValueError: If optimizers are None and no default value was provided to the constructor. """ time_steps, actions, next_time_steps = self._experience_to_transitions( experience) time_step, action, next_time_step = self._experience_to_transitions( tf.nest.map_structure(lambda x: x[:, -2:], experience)) time_step, action, next_time_step = tf.nest.map_structure( lambda x: tf.squeeze(x, axis=1), (time_step, action, next_time_step)) with tf.GradientTape(persistent=True) as tape: state_only = (self._actor_input == 'state' and self._critic_input == 'state' and not self._trainable_model) critic_and_actor_inputs = set( self._critic_input.split('__') + self._actor_input.split('__')) if not state_only: images = tf.image.convert_image_dtype( experience.observation['pixels'], tf.float32) features = self._compressor_network(images) if 'latent' in critic_and_actor_inputs: if self._compressor_network == self._model_network.compressor: latent_samples_and_dists = self._model_network.sample_posterior( images, actions, experience.step_type, features) else: latent_samples_and_dists = self._model_network.sample_posterior( images, actions, experience.step_type) latents, _ = latent_samples_and_dists if isinstance(latents, (tuple, list)): latents = tf.concat(latents, axis=-1) latent, next_latent = tf.unstack(latents[:, -2:], axis=1) else: latent_samples_and_dists = None if 'feature' in critic_and_actor_inputs: feature, next_feature = tf.unstack(features[:, -2:], axis=1) if {'sequence_feature', 'sequence_action_feature'} & critic_and_actor_inputs: feature_time_steps = time_steps._replace(observation=features[:, :-1]) next_feature_time_steps = next_time_steps._replace( observation=features[:, 1:]) filtered_feature_time_steps, filtered_actions = ( filter_before_first_step(feature_time_steps, actions)) filtered_next_feature_time_steps = filter_before_first_step( next_feature_time_steps) sequence_feature = slac_common.flatten( filtered_feature_time_steps.observation) next_sequence_feature = slac_common.flatten( filtered_next_feature_time_steps.observation) if 'sequence_action_feature' in critic_and_actor_inputs: sequence_action = slac_common.flatten(filtered_actions[:, :-1]) sequence_action_feature = tf.concat( [sequence_feature, sequence_action], axis=-1) next_sequence_action = slac_common.flatten(filtered_actions[:, 1:]) next_sequence_action_feature = tf.concat( [next_sequence_feature, next_sequence_action], axis=-1) if self._debug_summaries: if not state_only: image_time_steps = time_steps._replace(observation=images[:, :-1]) next_image_time_steps = next_time_steps._replace( observation=images[:, 1:]) filtered_image_time_steps, _ = filter_before_first_step( image_time_steps, actions) filtered_next_image_time_steps = filter_before_first_step( next_image_time_steps) fps = 10 if self._control_timestep is None else int( np.round(1.0 / self._control_timestep)) _gif_and_image_summary('images', filtered_image_time_steps.observation[ :self._num_images_per_summary], fps, step=self.train_step_counter) _gif_and_image_summary('next_images', filtered_next_image_time_steps.observation[ :self._num_images_per_summary], fps, step=self.train_step_counter) critic_states = [] critic_next_states = [] for critic_input in self._critic_input.split('__'): if critic_input == 'latent': critic_state = latent critic_next_state = next_latent elif critic_input == 'state': critic_state, critic_next_state = tf.unstack( experience.observation['state'][:, -2:], axis=1) elif critic_input == 'feature': critic_state = feature critic_next_state = next_feature elif critic_input == 'sequence_feature': critic_state = sequence_feature critic_next_state = next_sequence_feature elif critic_input == 'sequence_action_feature': critic_state = sequence_action_feature critic_next_state = next_sequence_action_feature else: raise NotImplementedError critic_states.append(critic_state) critic_next_states.append(critic_next_state) critic_state = tf.concat(critic_states, axis=-1) critic_next_state = tf.concat(critic_next_states, axis=-1) critic_time_step = time_step._replace(observation=critic_state) critic_next_time_step = next_time_step._replace( observation=critic_next_state) actor_states = [] actor_next_states = [] for actor_input in self._actor_input.split('__'): if actor_input == 'latent': actor_state = latent actor_next_state = next_latent elif actor_input == 'state': actor_state, actor_next_state = tf.unstack( experience.observation['state'][:, -2:], axis=1) elif actor_input == 'feature': actor_state = feature actor_next_state = next_feature elif actor_input == 'sequence_feature': actor_state = sequence_feature actor_next_state = next_sequence_feature elif actor_input == 'sequence_action_feature': actor_state = sequence_action_feature actor_next_state = next_sequence_action_feature else: raise NotImplementedError actor_states.append(actor_state) actor_next_states.append(actor_next_state) actor_state = tf.concat(actor_states, axis=-1) actor_next_state = tf.concat(actor_next_states, axis=-1) actor_time_step = time_step._replace(observation=actor_state) actor_next_time_step = next_time_step._replace(observation=actor_next_state) critic_loss = self.critic_loss( critic_time_step, action, critic_next_time_step, actor_next_time_step, td_errors_loss_fn=self._td_errors_loss_fn, gamma=self._gamma, reward_scale_factor=self._reward_scale_factor, weights=weights) actor_loss = self.actor_loss( critic_time_step, actor_time_step, weights=weights) alpha_loss = self.alpha_loss(actor_time_step, weights=weights) if self._trainable_model: model_loss = self.model_loss( images, experience.action, experience.step_type, experience.reward, experience.discount, latent_posterior_samples_and_dists=latent_samples_and_dists, weights=weights) tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.') critic_variables = ( list(self._critic_network1.variables) + list(self._critic_network2.variables) + list(self._compressor_network.variables) + list(self._model_network.variables)) assert critic_variables, 'No critic variables to optimize.' critic_grads = tape.gradient(critic_loss, critic_variables) self._apply_gradients( critic_grads, critic_variables, self._critic_optimizer) tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.') actor_variables = ( list(self._actor_network.variables) + list(self._compressor_network.variables) + list(self._model_network.variables)) assert actor_variables, 'No actor variables to optimize.' actor_grads = tape.gradient(actor_loss, actor_variables) self._apply_gradients(actor_grads, actor_variables, self._actor_optimizer) tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.') alpha_variables = [self._log_alpha] assert alpha_variables, 'No alpha variable to optimize.' alpha_grads = tape.gradient(alpha_loss, alpha_variables) self._apply_gradients(alpha_grads, alpha_variables, self._alpha_optimizer) if self._trainable_model: tf.debugging.check_numerics(model_loss, 'Model loss is inf or nan.') model_variables = list(self._model_network.variables) assert model_variables, 'No model variables to optimize.' model_grads = tape.gradient(model_loss, model_variables) self._apply_gradients(model_grads, model_variables, self._model_optimizer) with tf.name_scope('Losses'): tf.compat.v2.summary.scalar( name='critic_loss', data=critic_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='actor_loss', data=actor_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='alpha_loss', data=alpha_loss, step=self.train_step_counter) if self._trainable_model: tf.compat.v2.summary.scalar( name='model_loss', data=model_loss, step=self.train_step_counter) self.train_step_counter.assign_add(1) self._update_target() total_loss = critic_loss + actor_loss + alpha_loss if self._trainable_model: total_loss += model_loss return tf_agent.LossInfo(loss=total_loss, extra=()) def _apply_gradients(self, gradients, variables, optimizer): grads_and_vars = zip(gradients, variables) if self._gradient_clipping is not None: grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars, self._gradient_clipping) if self._summarize_grads_and_vars: eager_utils.add_variables_summaries(grads_and_vars, self.train_step_counter) eager_utils.add_gradients_summaries(grads_and_vars, self.train_step_counter) optimizer.apply_gradients(grads_and_vars) def _get_target_updater(self, tau=1.0, period=1): """Performs a soft update of the target network parameters. For each weight w_s in the original network, and its corresponding weight w_t in the target network, a soft update is: w_t = (1- tau) x w_t + tau x ws Args: tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update. period: Step interval at which the target network is updated. Returns: A callable that performs a soft update of the target network parameters. """ with tf.name_scope('update_target'): def update(): """Update target network.""" critic_update_1 = common.soft_variables_update( self._critic_network1.variables, self._target_critic_network1.variables, tau) critic_update_2 = common.soft_variables_update( self._critic_network2.variables, self._target_critic_network2.variables, tau) return tf.group(critic_update_1, critic_update_2) return common.Periodically(update, period, 'update_targets') def critic_loss(self, time_steps, actions, next_time_steps, actor_next_time_steps, td_errors_loss_fn, gamma=1.0, reward_scale_factor=1.0, weights=None): """Computes the critic loss for SAC training. Args: time_steps: A batch of timesteps for the critic. actions: A batch of actions. next_time_steps: A batch of next timesteps for the critic. actor_next_time_steps: A batch of next timesteps for the actor. td_errors_loss_fn: A function(td_targets, predictions) to compute elementwise (per-batch-entry) loss. gamma: Discount for future rewards. reward_scale_factor: Multiplicative factor to scale rewards. weights: Optional scalar or elementwise (per-batch-entry) importance weights. Returns: critic_loss: A scalar critic loss. """ with tf.name_scope('critic_loss'): if self._critic_input_stop_gradient: time_steps = tf.nest.map_structure(tf.stop_gradient, time_steps) next_time_steps = tf.nest.map_structure(tf.stop_gradient, next_time_steps) # not really necessary since there is a stop_gradient for the td_targets actor_next_time_steps = tf.nest.map_structure(tf.stop_gradient, actor_next_time_steps) next_actions_distribution, _ = self._actor_network( actor_next_time_steps.observation, actor_next_time_steps.step_type) next_actions = next_actions_distribution.sample() next_log_pis = next_actions_distribution.log_prob(next_actions) target_input_1 = (next_time_steps.observation, next_actions) target_q_values1, unused_network_state1 = self._target_critic_network1( target_input_1, next_time_steps.step_type) target_input_2 = (next_time_steps.observation, next_actions) target_q_values2, unused_network_state2 = self._target_critic_network2( target_input_2, next_time_steps.step_type) target_q_values = ( tf.minimum(target_q_values1, target_q_values2) - tf.exp(self._log_alpha) * next_log_pis) td_targets = tf.stop_gradient( reward_scale_factor * next_time_steps.reward + gamma * next_time_steps.discount * target_q_values) pred_input_1 = (time_steps.observation, actions) pred_td_targets1, unused_network_state1 = self._critic_network1( pred_input_1, time_steps.step_type) pred_input_2 = (time_steps.observation, actions) pred_td_targets2, unused_network_state2 = self._critic_network2( pred_input_2, time_steps.step_type) critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1) critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2) critic_loss = critic_loss1 + critic_loss2 if weights is not None: critic_loss *= weights # Take the mean across the batch. critic_loss = tf.reduce_mean(input_tensor=critic_loss) if self._debug_summaries: td_errors1 = td_targets - pred_td_targets1 td_errors2 = td_targets - pred_td_targets2 td_errors = tf.concat([td_errors1, td_errors2], axis=0) common.generate_tensor_summaries('td_errors', td_errors, self.train_step_counter) common.generate_tensor_summaries('td_targets', td_targets, self.train_step_counter) common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1, self.train_step_counter) common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2, self.train_step_counter) return critic_loss def actor_loss(self, time_steps, actor_time_steps, weights=None): """Computes the actor_loss for SAC training. Args: time_steps: A batch of timesteps for the critic. actor_time_steps: A batch of timesteps for the actor. weights: Optional scalar or elementwise (per-batch-entry) importance weights. Returns: actor_loss: A scalar actor loss. """ with tf.name_scope('actor_loss'): time_steps = tf.nest.map_structure(tf.stop_gradient, time_steps) if self._actor_input_stop_gradient: actor_time_steps = tf.nest.map_structure(tf.stop_gradient, actor_time_steps) actions_distribution, _ = self._actor_network( actor_time_steps.observation, actor_time_steps.step_type) actions = actions_distribution.sample() log_pis = actions_distribution.log_prob(actions) target_input_1 = (time_steps.observation, actions) target_q_values1, unused_network_state1 = self._critic_network1( target_input_1, time_steps.step_type) target_input_2 = (time_steps.observation, actions) target_q_values2, unused_network_state2 = self._critic_network2( target_input_2, time_steps.step_type) target_q_values = tf.minimum(target_q_values1, target_q_values2) actor_loss = tf.exp(self._log_alpha) * log_pis - target_q_values if weights is not None: actor_loss *= weights actor_loss = tf.reduce_mean(input_tensor=actor_loss) if self._debug_summaries: common.generate_tensor_summaries('actor_loss', actor_loss, self.train_step_counter) common.generate_tensor_summaries('actions', actions, self.train_step_counter) common.generate_tensor_summaries('log_pis', log_pis, self.train_step_counter) tf.compat.v2.summary.scalar( name='entropy_avg', data=-tf.reduce_mean(input_tensor=log_pis), step=self.train_step_counter) common.generate_tensor_summaries('target_q_values', target_q_values, self.train_step_counter) batch_size = nest_utils.get_outer_shape( time_steps, self._time_step_spec)[0] policy_state = self.policy.get_initial_state(batch_size) action_distribution = self.policy.distribution( time_steps, policy_state).action if isinstance(action_distribution, tfp.distributions.Normal): common.generate_tensor_summaries('act_mean', action_distribution.loc, self.train_step_counter) common.generate_tensor_summaries( 'act_stddev', action_distribution.scale, self.train_step_counter) elif isinstance(action_distribution, tfp.distributions.Categorical): common.generate_tensor_summaries( 'act_mode', action_distribution.mode(), self.train_step_counter) try: common.generate_tensor_summaries('entropy_action', action_distribution.entropy(), self.train_step_counter) except NotImplementedError: pass # Some distributions do not have an analytic entropy. return actor_loss def alpha_loss(self, actor_time_steps, weights=None): """Computes the alpha_loss for EC-SAC training. Args: actor_time_steps: A batch of timesteps for the actor. weights: Optional scalar or elementwise (per-batch-entry) importance weights. Returns: alpha_loss: A scalar alpha loss. """ with tf.name_scope('alpha_loss'): actions_distribution, _ = self._actor_network( actor_time_steps.observation, actor_time_steps.step_type) actions = actions_distribution.sample() log_pis = actions_distribution.log_prob(actions) alpha_loss = ( self._log_alpha * tf.stop_gradient(-log_pis - self._target_entropy)) if weights is not None: alpha_loss *= weights alpha_loss = tf.reduce_mean(input_tensor=alpha_loss) if self._debug_summaries: common.generate_tensor_summaries('alpha_loss', alpha_loss, self.train_step_counter) return alpha_loss def train_model(self, experience, weights=None): if self._enable_functions and getattr( self, "_train_model_fn", None) is None: raise RuntimeError( "Cannot find _train_model_fn. Did %s.__init__ call super?" % type(self).__name__) if not isinstance(experience, trajectory.Trajectory): raise ValueError( "experience must be type Trajectory, saw type: %s" % type(experience)) if self._enable_functions: loss_info = self._train_model_fn(experience=experience, weights=weights) else: loss_info = self._train_model(experience=experience, weights=weights) if not isinstance(loss_info, tf_agent.LossInfo): raise TypeError( "loss_info is not a subclass of LossInfo: {}".format(loss_info)) return loss_info def _train_model(self, experience, weights=None): with tf.GradientTape() as tape: images = tf.image.convert_image_dtype( experience.observation['pixels'], tf.float32) model_loss = self.model_loss( images, experience.action, experience.step_type, rewards=experience.reward, discounts=experience.discount, weights=weights) tf.debugging.check_numerics(model_loss, 'Model loss is inf or nan.') model_variables = list(self._model_network.variables) assert model_variables, 'No model variables to optimize.' model_grads = tape.gradient(model_loss, model_variables) self._apply_gradients(model_grads, model_variables, self._model_optimizer) with tf.name_scope('Losses'): tf.compat.v2.summary.scalar( name='model_loss', data=model_loss, step=self.train_step_counter) self.train_step_counter.assign_add(1) total_loss = model_loss return tf_agent.LossInfo(loss=total_loss, extra=()) def model_loss(self, images, actions, step_types, rewards, discounts, latent_posterior_samples_and_dists=None, weights=None): with tf.name_scope('model_loss'): if self._model_batch_size is not None: # Allow model batch size to be smaller than the batch size of the # other losses. This is because the model loss already gets a lot of # supervision from having a loss over all time steps. images, actions, step_types, rewards, discounts = tf.nest.map_structure( lambda x: x[:self._model_batch_size], (images, actions, step_types, rewards, discounts)) if latent_posterior_samples_and_dists is not None: latent_posterior_samples, latent_posterior_dists = latent_posterior_samples_and_dists latent_posterior_samples = tf.nest.map_structure( lambda x: x[:self._model_batch_size], latent_posterior_samples) latent_posterior_dists = slac_nest_utils.map_distribution_structure( lambda x: x[:self._model_batch_size], latent_posterior_dists) latent_posterior_samples_and_dists = ( latent_posterior_samples, latent_posterior_dists) model_loss, outputs = self._model_network.compute_loss( images, actions, step_types, rewards=rewards, discounts=discounts, latent_posterior_samples_and_dists=latent_posterior_samples_and_dists) for name, output in outputs.items(): if output.shape.ndims == 0: tf.contrib.summary.scalar(name, output) elif output.shape.ndims == 5: fps = 10 if self._control_timestep is None else int( np.round(1.0 / self._control_timestep)) if self._debug_summaries: _gif_summary(name + '/original', output[:self._num_images_per_summary], fps, step=self.train_step_counter) _gif_summary(name, output[:self._num_images_per_summary], fps, saturate=True, step=self.train_step_counter) else: raise NotImplementedError if weights is not None: model_loss *= weights model_loss = tf.reduce_mean(input_tensor=model_loss) if self._debug_summaries: common.generate_tensor_summaries('model_loss', model_loss, self.train_step_counter) return model_loss