# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PPO algorithm implementation.

Based on: https://arxiv.org/abs/1707.06347
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.rl.envs.utils import get_policy

import tensorflow as tf


def get_optimiser(config):
  """Returns the optimizer specified in `config` with its learning rate."""
  if config.optimizer == "Adam":
    return tf.train.AdamOptimizer(learning_rate=config.learning_rate)
  return config.optimizer(learning_rate=config.learning_rate)
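
# Example usage (a minimal sketch; the hparams values below are hypothetical
# and assume the TF 1.x tf.contrib.training.HParams container):
#   config = tf.contrib.training.HParams(optimizer="Adam", learning_rate=1e-4)
#   optimizer = get_optimiser(config)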


def define_ppo_step(data_points, optimizer, hparams):
  """Define ppo step."""
  observation, action, discounted_reward, norm_advantage, old_pdf = data_points
  new_policy_dist, new_value, _ = get_policy(observation, hparams)
  new_pdf = new_policy_dist.prob(action)

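  # Importance sampling ratio:
  #   r_t(theta) = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t).
  # PPO's clipped surrogate objective keeps the update close to the old
  # policy:
  #   L^CLIP(theta) = E_t[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)],
  # where eps = hparams.clipping_coef and A_t is the normalized advantage.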
  ratio = new_pdf / old_pdf
  clipped_ratio = tf.clip_by_value(ratio, 1 - hparams.clipping_coef,
                                   1 + hparams.clipping_coef)

  surrogate_objective = tf.minimum(clipped_ratio * norm_advantage,
                                   ratio * norm_advantage)
  policy_loss = -tf.reduce_mean(surrogate_objective)

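  # Value loss: mean squared error between the predicted value and the
  # empirical discounted return, scaled by hparams.value_loss_coef.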
  value_error = new_value - discounted_reward
  value_loss = hparams.value_loss_coef * tf.reduce_mean(value_error ** 2)

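  # Entropy bonus: subtracting the scaled policy entropy from the total loss
  # encourages exploration by penalizing near-deterministic policies.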
  entropy = new_policy_dist.entropy()
  entropy_loss = -hparams.entropy_loss_coef * tf.reduce_mean(entropy)

  losses = [policy_loss, value_loss, entropy_loss]

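  # Gradients are computed separately per loss so that per-loss gradient
  # norms can be reported. compute_gradients() returns (gradient, variable)
  # pairs; zip(*...) transposes them into (gradients, variables) tuples.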
  gradients = [list(zip(*optimizer.compute_gradients(loss)))
               for loss in losses]

  gradients_norms = [tf.global_norm(gradient[0]) for gradient in gradients]

  gradients_flat = sum([gradient[0] for gradient in gradients], ())
  gradients_variables_flat = sum([gradient[1] for gradient in gradients], ())

  if hparams.max_gradients_norm:
    gradients_flat, _ = tf.clip_by_global_norm(gradients_flat,
                                               hparams.max_gradients_norm)

  optimize_op = optimizer.apply_gradients(zip(gradients_flat,
                                              gradients_variables_flat))

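  # The control dependency ensures that fetching the losses and gradient
  # norms below also runs the parameter update.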
  with tf.control_dependencies([optimize_op]):
    return [tf.identity(x) for x in losses + gradients_norms]


def define_ppo_epoch(memory, hparams):
  """PPO epoch."""
  observation, reward, done, action, old_pdf, value = memory

  # Stop gradients so that they are not propagated through the simulated
  # environment.
  observation = tf.stop_gradient(observation)
  action = tf.stop_gradient(action)
  reward = tf.stop_gradient(reward)
  if hasattr(hparams, "rewards_preprocessing_fun"):
    reward = hparams.rewards_preprocessing_fun(reward)
  done = tf.stop_gradient(done)
  value = tf.stop_gradient(value)
  old_pdf = tf.stop_gradient(old_pdf)

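  # Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438);
  # gae_lambda trades off bias (low values) against variance (high values)
  # of the advantage estimates.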
  advantage = calculate_generalized_advantage_estimator(
      reward, value, done, hparams.gae_gamma, hparams.gae_lambda)

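  # Regression target for the value function: advantage + value is the
  # lambda-return implied by the GAE estimator.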
  discounted_reward = tf.stop_gradient(advantage + value)

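  # Normalize advantages to zero mean and unit variance across the whole
  # batch, which stabilizes the scale of the policy updates.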
  advantage_mean, advantage_variance = tf.nn.moments(advantage, axes=[0, 1],
                                                     keep_dims=True)
  advantage_normalized = tf.stop_gradient(
      (advantage - advantage_mean)/(tf.sqrt(advantage_variance) + 1e-8))

  add_lists_elementwise = lambda l1, l2: [x + y for x, y in zip(l1, l2)]

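  # Total number of minibatch updates; assumes epoch_length *
  # optimization_epochs is divisible by optimization_batch_size.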
  number_of_batches = (hparams.epoch_length * hparams.optimization_epochs
                       / hparams.optimization_batch_size)

  dataset = tf.data.Dataset.from_tensor_slices(
      (observation, action, discounted_reward, advantage_normalized, old_pdf))
  dataset = dataset.shuffle(buffer_size=hparams.epoch_length,
                            reshuffle_each_iteration=True)
  dataset = dataset.repeat(hparams.optimization_epochs)
  dataset = dataset.batch(hparams.optimization_batch_size)
  iterator = dataset.make_initializable_iterator()
  optimizer = get_optimiser(hparams)

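  # tf.scan runs define_ppo_step once per minibatch (sequentially) and
  # accumulates the returned losses and gradient norms.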
  with tf.control_dependencies([iterator.initializer]):
    ppo_step_rets = tf.scan(
        lambda a, i: add_lists_elementwise(  # pylint: disable=g-long-lambda
            a, define_ppo_step(iterator.get_next(), optimizer, hparams)),
        tf.range(number_of_batches),
        [0., 0., 0., 0., 0., 0.],
        parallel_iterations=1)

  ppo_summaries = [tf.reduce_mean(ret) / number_of_batches
                   for ret in ppo_step_rets]
  summaries_names = ["policy_loss", "value_loss", "entropy_loss",
                     "policy_gradient", "value_gradient", "entropy_gradient"]

  summaries = [tf.summary.scalar(summary_name, summary)
               for summary_name, summary in zip(summaries_names, ppo_summaries)]
  losses_summary = tf.summary.merge(summaries)

  for summary_name, summary in zip(summaries_names, ppo_summaries):
    losses_summary = tf.Print(losses_summary, [summary], summary_name + ": ")

  return losses_summary


def calculate_generalized_advantage_estimator(
    reward, value, done, gae_gamma, gae_lambda):
  """Generalized advantage estimator."""

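  # TD residual:
  #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
  # GAE advantage (backward recursion):
  #   A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}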
  # Replace the last reward with the last value estimate; together with
  # next_value being zero at the final timestep, this makes delta (and hence
  # the advantage) equal to 0 there.
  reward = tf.concat([reward[:-1, :], value[-1:, :]], axis=0)
  next_value = tf.concat([value[1:, :], tf.zeros_like(value[-1:, :])], axis=0)
  next_not_done = 1 - tf.cast(tf.concat([done[1:, :],
                                         tf.zeros_like(done[-1:, :])], axis=0),
                              tf.float32)
  delta = reward + gae_gamma * next_value * next_not_done - value

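  # Evaluate the backward recursion by reversing time, accumulating with
  # tf.scan and reversing the result back.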
  return_ = tf.reverse(tf.scan(
      lambda agg, cur: cur[0] + cur[1] * gae_gamma * gae_lambda * agg,
      [tf.reverse(delta, [0]), tf.reverse(next_not_done, [0])],
      tf.zeros_like(delta[0, :]),
      parallel_iterations=1), [0])
  return tf.check_numerics(return_, "return")