# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PPO algorithm implementation.

Based on: https://arxiv.org/abs/1707.06347
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.layers import common_layers
from tensor2tensor.models.research.rl import get_policy
from tensor2tensor.utils import learning_rate
from tensor2tensor.utils import optimize

import tensorflow as tf
import tensorflow_probability as tfp


def define_ppo_step(data_points, hparams, action_space, lr,
                    epoch=-1, distributional_size=1,
                    distributional_subscale=0.04):
  """Define a single PPO optimization step."""
  del distributional_subscale
  (observation, action, discounted_reward, discounted_reward_probs,
   norm_advantage, old_pdf) = data_points

  obs_shape = common_layers.shape_list(observation)
  observation = tf.reshape(
      observation, [obs_shape[0] * obs_shape[1]] + obs_shape[2:]
  )
  (logits, new_value) = get_policy(observation, hparams, action_space,
                                   epoch=epoch,
                                   distributional_size=distributional_size)
  logits = tf.reshape(logits, obs_shape[:2] + [action_space.n])
  new_policy_dist = tfp.distributions.Categorical(logits=logits)

  new_pdf = new_policy_dist.prob(action)

  ratio = new_pdf / old_pdf
  clipped_ratio = tf.clip_by_value(ratio, 1 - hparams.clipping_coef,
                                   1 + hparams.clipping_coef)

  surrogate_objective = tf.minimum(clipped_ratio * norm_advantage,
                                   ratio * norm_advantage)
  policy_loss = -tf.reduce_mean(surrogate_objective)

  if distributional_size > 1:
    new_value = tf.reshape(new_value, obs_shape[:2] + [distributional_size])
    new_value = tf.nn.log_softmax(new_value, axis=-1)
    value_shape = common_layers.shape_list(new_value)
    # The above is the new value distribution. We are also given as discounted
    # reward the value distribution and the corresponding probabilities.
    # The given discounted reward is already rounded to integers, but in a
    # range increased by 2x for greater fidelity, so increase the range of
    # new_value here to match.
    new_value_shifted = tf.concat([new_value[1:], new_value[-1:]], axis=0)
    new_value_mean = (new_value + new_value_shifted) / 2
    new_value = tf.concat([tf.expand_dims(new_value, axis=-1),
                           tf.expand_dims(new_value_mean, axis=-1)], -1)
    new_value = tf.reshape(new_value,
                           value_shape[:-1] + [2 * value_shape[-1]])
    # Cast discounted reward to integers and gather the new log-probs for them.
    discounted_reward = tf.cast(discounted_reward, tf.int32)
    value_loss = tf.batch_gather(new_value, discounted_reward)
    # Weight the gathered (new) log-probs by the old probabilities.
    discounted_reward_probs = tf.expand_dims(discounted_reward_probs, axis=1)
    value_loss = -tf.reduce_sum(value_loss * discounted_reward_probs, axis=-1)
    # Take the mean over batch and time as final loss, multiply by coefficient.
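    # Up to this averaging, value_loss is effectively a cross-entropy between
    # the new value distribution and the distribution of discounted returns
    # obtained by bootstrapping from the old value distribution at the final
    # step.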
    value_loss = hparams.value_loss_coef * tf.reduce_mean(value_loss)
  else:
    new_value = tf.reshape(new_value, obs_shape[:2])
    value_error = new_value - discounted_reward
    value_loss = hparams.value_loss_coef * tf.reduce_mean(value_error ** 2)

  entropy = new_policy_dist.entropy()
  entropy_loss = -hparams.entropy_loss_coef * tf.reduce_mean(entropy)

  losses = [policy_loss, value_loss, entropy_loss]
  loss = sum(losses)
  variables = tf.global_variables(hparams.policy_network + "/.*")
  train_op = optimize.optimize(loss, lr, hparams, variables=variables)

  with tf.control_dependencies([train_op]):
    return [tf.identity(x) for x in losses]


def _distributional_to_value(value_d, size, subscale, threshold):
  """Get a scalar value out of a value distribution in distributional RL."""
  half = size // 2
  value_range = (tf.to_float(tf.range(-half, half)) + 0.5) * subscale
  probs = tf.nn.softmax(value_d)

  if threshold == 0.0:
    return tf.reduce_sum(probs * value_range, axis=-1)

  # accumulated_probs[..., i] is the sum of probabilities in buckets up to i,
  # so it is the probability that value <= i-th bucket value.
  accumulated_probs = tf.cumsum(probs, axis=-1)
  # Zero out the probabilities of the lower buckets until the accumulated
  # probability reaches the threshold.
  probs = tf.where(accumulated_probs < threshold, tf.zeros_like(probs), probs)
  probs /= tf.reduce_sum(probs, axis=-1, keepdims=True)  # Re-normalize.
  return tf.reduce_sum(probs * value_range, axis=-1)


def define_ppo_epoch(memory, hparams, action_space, batch_size,
                     distributional_size=1, distributional_subscale=0.04,
                     distributional_threshold=0.0, epoch=-1):
  """PPO epoch."""
  observation, reward, done, action, old_pdf, value_sm = memory

  # This is to avoid propagating gradients through the simulated environment.
  observation = tf.stop_gradient(observation)
  action = tf.stop_gradient(action)
  reward = tf.stop_gradient(reward)
  if hasattr(hparams, "rewards_preprocessing_fun"):
    reward = hparams.rewards_preprocessing_fun(reward)
  done = tf.stop_gradient(done)
  value_sm = tf.stop_gradient(value_sm)
  old_pdf = tf.stop_gradient(old_pdf)

  value = value_sm
  if distributional_size > 1:
    value = _distributional_to_value(
        value_sm, distributional_size, distributional_subscale,
        distributional_threshold)

  advantage = calculate_generalized_advantage_estimator(
      reward, value, done, hparams.gae_gamma, hparams.gae_lambda)

  if distributional_size > 1:
    # Create the range of discounted-reward end values.
    half = distributional_size // 2
    value_range = tf.to_float(tf.range(-half, half)) + 0.5  # Mid-bucket value.
    value_range *= distributional_subscale
    # Acquire new discounted rewards by using the above range as end-values.
    end_values = tf.expand_dims(value_range, 0)
    discounted_reward = discounted_rewards(
        reward, done, hparams.gae_gamma, end_values)
    # Re-normalize the discounted rewards into the [0, distributional_size]
    # range; they are rounded to integers below.
    discounted_reward /= distributional_subscale
    discounted_reward += half
    discounted_reward = tf.maximum(discounted_reward, 0.0)
    discounted_reward = tf.minimum(discounted_reward, distributional_size)
    # Multiply the rewards by 2 for greater fidelity and round to integers.
    discounted_reward = tf.stop_gradient(tf.round(2 * discounted_reward))
    # The probabilities corresponding to the end values from old predictions.
    discounted_reward_prob = tf.stop_gradient(value_sm[-1])
    discounted_reward_prob = tf.nn.softmax(discounted_reward_prob, axis=-1)
  else:
    discounted_reward = tf.stop_gradient(advantage + value[:-1])
    discounted_reward_prob = discounted_reward  # Unused in this case.
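  # Worked example for the distributional branch above (hypothetical numbers):
  # with distributional_size=16 and distributional_subscale=0.04, half is 8 and
  # the bucket mid-points are (-7.5, ..., 7.5) * 0.04. A discounted reward of
  # 0.1 re-normalizes to 0.1 / 0.04 + 8 = 10.5 and, after the 2x multiplication
  # and rounding, selects index 21 of the doubled-resolution value distribution
  # built in define_ppo_step.
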
  advantage_mean, advantage_variance = tf.nn.moments(advantage, axes=[0, 1],
                                                     keep_dims=True)
  advantage_normalized = tf.stop_gradient(
      (advantage - advantage_mean) / (tf.sqrt(advantage_variance) + 1e-8))

  add_lists_elementwise = lambda l1, l2: [x + y for x, y in zip(l1, l2)]

  number_of_batches = ((hparams.epoch_length - 1) * hparams.optimization_epochs
                       // hparams.optimization_batch_size)
  epoch_length = hparams.epoch_length
  if hparams.effective_num_agents is not None:
    number_of_batches *= batch_size
    number_of_batches //= hparams.effective_num_agents
    epoch_length //= hparams.effective_num_agents

  assert number_of_batches > 0, (
      "Set the parameters so that number_of_batches > 0")
  lr = learning_rate.learning_rate_schedule(hparams)

  shuffled_indices = [tf.random.shuffle(tf.range(epoch_length - 1))
                      for _ in range(hparams.optimization_epochs)]
  shuffled_indices = tf.concat(shuffled_indices, axis=0)
  shuffled_indices = shuffled_indices[:number_of_batches *
                                      hparams.optimization_batch_size]
  indices_of_batches = tf.reshape(shuffled_indices,
                                  shape=(-1, hparams.optimization_batch_size))
  input_tensors = [observation, action, discounted_reward,
                   discounted_reward_prob, advantage_normalized, old_pdf]

  ppo_step_rets = tf.scan(
      lambda a, i: add_lists_elementwise(  # pylint: disable=g-long-lambda
          a, define_ppo_step(
              [tf.gather(t, indices_of_batches[i, :]) for t in input_tensors],
              hparams, action_space, lr,
              epoch=epoch,
              distributional_size=distributional_size,
              distributional_subscale=distributional_subscale
          )),
      tf.range(number_of_batches),
      [0., 0., 0.],
      parallel_iterations=1)

  ppo_summaries = [tf.reduce_mean(ret) / number_of_batches
                   for ret in ppo_step_rets]
  ppo_summaries.append(lr)
  summaries_names = [
      "policy_loss", "value_loss", "entropy_loss", "learning_rate"
  ]

  summaries = [tf.summary.scalar(summary_name, summary)
               for summary_name, summary in zip(summaries_names,
                                                ppo_summaries)]
  losses_summary = tf.summary.merge(summaries)

  for summary_name, summary in zip(summaries_names, ppo_summaries):
    losses_summary = tf.Print(losses_summary, [summary], summary_name + ": ")

  return losses_summary


def calculate_generalized_advantage_estimator(
    reward, value, done, gae_gamma, gae_lambda):
  # pylint: disable=g-doc-args
  """Generalized advantage estimator.

  Returns:
    GAE estimator. It will be one element shorter than the input; this is
    because to compute GAE for [0, ..., N-1] one needs V for [1, ..., N].
  """
  # pylint: enable=g-doc-args
  next_value = value[1:, :]
  next_not_done = 1 - tf.cast(done[1:, :], tf.float32)
  delta = (reward[:-1, :] + gae_gamma * next_value * next_not_done
           - value[:-1, :])

  return_ = tf.reverse(tf.scan(
      lambda agg, cur: cur[0] + cur[1] * gae_gamma * gae_lambda * agg,
      [tf.reverse(delta, [0]), tf.reverse(next_not_done, [0])],
      tf.zeros_like(delta[0, :]),
      parallel_iterations=1), [0])

  return tf.check_numerics(return_, "return")


def discounted_rewards(reward, done, gae_gamma, end_values):
  """Discounted rewards, bootstrapped from the given end values."""
  not_done = tf.expand_dims(1 - tf.cast(done, tf.float32), axis=2)
  end_values = end_values * not_done[-1, :, :]
  return_ = tf.scan(
      lambda agg, cur: cur + gae_gamma * agg,
      tf.expand_dims(reward, axis=2) * not_done,
      initializer=end_values,
      reverse=True,
      back_prop=False,
      parallel_iterations=2)
  return tf.check_numerics(return_, "return")
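

# A plain-NumPy reference sketch of the recurrence computed by
# calculate_generalized_advantage_estimator above (illustrative only; the
# helper name and the use of NumPy are assumptions, and nothing above
# depends on it):
#   delta_t = r_t + gamma * V_{t+1} * (1 - done_{t+1}) - V_t
#   A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
def _reference_gae_numpy(reward, value, done, gae_gamma, gae_lambda):
  """Reference GAE on [time, batch] NumPy arrays; one step shorter in time."""
  import numpy as np  # Local import: the TensorFlow code above does not use it.
  next_not_done = 1.0 - done[1:].astype(np.float32)
  delta = reward[:-1] + gae_gamma * value[1:] * next_not_done - value[:-1]
  advantage = np.zeros_like(delta)
  acc = np.zeros_like(delta[0])
  # Accumulate the advantage backwards in time, mirroring the tf.scan above.
  for t in reversed(range(delta.shape[0])):
    acc = delta[t] + gae_gamma * gae_lambda * next_not_done[t] * acc
    advantage[t] = acc
  return advantage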