# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PPO learner."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os

from tensor2tensor.layers import common_layers
from tensor2tensor.models.research.rl import get_policy
from tensor2tensor.rl import ppo
from tensor2tensor.rl.envs.tf_atari_wrappers import StackWrapper
from tensor2tensor.rl.envs.tf_atari_wrappers import WrapperBase
from tensor2tensor.rl.policy_learner import PolicyLearner
from tensor2tensor.rl.restarter import Restarter
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp


class PPOLearner(PolicyLearner):
  """PPO for policy learning."""

  def __init__(self, frame_stack_size, base_event_dir, agent_model_dir,
               total_num_epochs, **kwargs):
    super(PPOLearner, self).__init__(
        frame_stack_size, base_event_dir, agent_model_dir, total_num_epochs)
    self._num_completed_iterations = 0
    self._lr_decay_start = None
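    # The distributional RL settings below arrive through **kwargs; e.g. a
    # hypothetical distributional_size=51 gives the value function 51 buckets
    # (see _rollout_metadata) and is forwarded to get_policy and
    # ppo.define_ppo_epoch.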
    self._distributional_size = kwargs.get("distributional_size", 1)
    self._distributional_subscale = kwargs.get("distributional_subscale", 0.04)
    self._distributional_threshold = kwargs.get("distributional_threshold", 0.0)

  def train(self,
            env_fn,
            hparams,
            simulated,
            save_continuously,
            epoch,
            sampling_temp=1.0,
            num_env_steps=None,
            env_step_multiplier=1,
            eval_env_fn=None,
            report_fn=None,
            model_save_fn=None):
    assert sampling_temp == 1.0 or hparams.learning_rate == 0.0, (
        "Sampling with a temperature other than 1.0 does not make sense "
        "during training.")

    if not save_continuously:
      # We do not save the model, as that resets frames that we need at
      # restarts. But we do need to save at the last step, so we set the save
      # interval very high.
      hparams.save_models_every_epochs = 1000000

    if simulated:
      simulated_str = "sim"
    else:
      simulated_str = "real"
    name_scope = "ppo_{}{}".format(simulated_str, epoch + 1)
    event_dir = os.path.join(self.base_event_dir, "ppo_summaries",
                             str(epoch) + simulated_str)

    with tf.Graph().as_default():
      with tf.name_scope(name_scope):
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
          env = env_fn(in_graph=True)
          (train_summary_op, eval_summary_op, initializers) = (
              _define_train(
                  env,
                  hparams,
                  eval_env_fn,
                  sampling_temp,
                  distributional_size=self._distributional_size,
                  distributional_subscale=self._distributional_subscale,
                  distributional_threshold=self._distributional_threshold,
                  epoch=epoch if simulated else -1,
                  frame_stack_size=self.frame_stack_size,
                  force_beginning_resets=simulated))

        if num_env_steps is None:
          iteration_increment = hparams.epochs_num
        else:
          iteration_increment = int(
              math.ceil(
                  num_env_steps / (env.batch_size * hparams.epoch_length)))
        iteration_increment *= env_step_multiplier
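        # Illustrative arithmetic (hypothetical numbers): with
        # num_env_steps=20000, env.batch_size=16, hparams.epoch_length=50 and
        # env_step_multiplier=1 this yields ceil(20000 / (16 * 50)) = 25
        # iterations.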

        self._num_completed_iterations += iteration_increment

        restarter = Restarter(
            "policy", self.agent_model_dir, self._num_completed_iterations
        )
        if restarter.should_skip:
          return

        if hparams.lr_decay_in_final_epoch:
          if epoch != self.total_num_epochs - 1:
            # Extend the warmup period to the end of this epoch.
            hparams.learning_rate_warmup_steps = restarter.target_global_step
          else:
            if self._lr_decay_start is None:
              # Stop the warmup at the beginning of this epoch.
              self._lr_decay_start = \
                  restarter.target_global_step - iteration_increment
            hparams.learning_rate_warmup_steps = self._lr_decay_start
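        # Net effect (illustrative numbers): with total_num_epochs=10 and 25
        # iterations per epoch, epochs 0-8 keep the learning rate in warmup,
        # and decay can only start at global step 225, the beginning of the
        # final epoch.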

        _run_train(
            hparams,
            event_dir,
            self.agent_model_dir,
            restarter,
            train_summary_op,
            eval_summary_op,
            initializers,
            epoch,
            report_fn=report_fn,
            model_save_fn=model_save_fn)

  def evaluate(self, env_fn, hparams, sampling_temp):
    with tf.Graph().as_default():
      with tf.name_scope("rl_eval"):
        eval_env = env_fn(in_graph=True)
        (collect_memory, _, collect_init) = _define_collect(
            eval_env,
            hparams,
            "ppo_eval",
            eval_phase=True,
            frame_stack_size=self.frame_stack_size,
            force_beginning_resets=False,
            sampling_temp=sampling_temp,
            distributional_size=self._distributional_size,
        )
        model_saver = tf.train.Saver(
            tf.global_variables(hparams.policy_network + "/.*")
            # tf.global_variables("clean_scope.*")  # Needed for sharing params.
        )

        with tf.Session() as sess:
          sess.run(tf.global_variables_initializer())
          collect_init(sess)
          trainer_lib.restore_checkpoint(self.agent_model_dir, model_saver,
                                         sess)
          sess.run(collect_memory)


def _define_train(
    train_env,
    ppo_hparams,
    eval_env_fn=None,
    sampling_temp=1.0,
    distributional_size=1,
    distributional_subscale=0.04,
    distributional_threshold=0.0,
    epoch=-1,
    **collect_kwargs
):
  """Define the training setup."""
  memory, collect_summary, train_initialization = (
      _define_collect(
          train_env,
          ppo_hparams,
          "ppo_train",
          eval_phase=False,
          sampling_temp=sampling_temp,
          distributional_size=distributional_size,
          **collect_kwargs))
  ppo_summary = ppo.define_ppo_epoch(
      memory, ppo_hparams, train_env.action_space, train_env.batch_size,
      distributional_size=distributional_size,
      distributional_subscale=distributional_subscale,
      distributional_threshold=distributional_threshold,
      epoch=epoch)
  train_summary = tf.summary.merge([collect_summary, ppo_summary])

  if ppo_hparams.eval_every_epochs:
    # TODO(koz4k): Do we need this at all?
    assert eval_env_fn is not None
    eval_env = eval_env_fn(in_graph=True)
    (_, eval_collect_summary, eval_initialization) = (
        _define_collect(
            eval_env,
            ppo_hparams,
            "ppo_eval",
            eval_phase=True,
            sampling_temp=0.0,
            distributional_size=distributional_size,
            **collect_kwargs))
    return (train_summary, eval_collect_summary, (train_initialization,
                                                  eval_initialization))
  else:
    return (train_summary, None, (train_initialization,))


def _run_train(ppo_hparams,
               event_dir,
               model_dir,
               restarter,
               train_summary_op,
               eval_summary_op,
               initializers,
               epoch,
               report_fn=None,
               model_save_fn=None):
  """Train."""
  summary_writer = tf.summary.FileWriter(
      event_dir, graph=tf.get_default_graph(), flush_secs=60)

  model_saver = tf.train.Saver(
      tf.global_variables(ppo_hparams.policy_network + "/.*") +
      tf.global_variables("training/" + ppo_hparams.policy_network + "/.*") +
      # tf.global_variables("clean_scope.*") +  # Needed for sharing params.
      tf.global_variables("global_step") +
      tf.global_variables("losses_avg.*") +
      tf.global_variables("train_stats.*")
  )

  global_step = tf.train.get_or_create_global_step()
  with tf.control_dependencies([tf.assign_add(global_step, 1)]):
    train_summary_op = tf.identity(train_summary_op)
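  # Each evaluation of train_summary_op in the loop below therefore advances
  # the global step by one; this is the step count that the Restarter tracks
  # and from which checkpoint names are derived.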

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for initializer in initializers:
      initializer(sess)
    trainer_lib.restore_checkpoint(model_dir, model_saver, sess)

    num_target_iterations = restarter.target_local_step
    num_completed_iterations = num_target_iterations - restarter.steps_to_go
    with restarter.training_loop():
      for epoch_index in range(num_completed_iterations, num_target_iterations):
        summary = sess.run(train_summary_op)
        if summary_writer:
          summary_writer.add_summary(summary, epoch_index)

        if (ppo_hparams.eval_every_epochs and
            epoch_index % ppo_hparams.eval_every_epochs == 0):
          eval_summary = sess.run(eval_summary_op)
          if summary_writer:
            summary_writer.add_summary(eval_summary, epoch_index)
          if report_fn:
            summary_proto = tf.Summary()
            summary_proto.ParseFromString(eval_summary)
            for elem in summary_proto.value:
              if "mean_score" in elem.tag:
                report_fn(elem.simple_value, epoch_index)
                break

        if (model_saver and ppo_hparams.save_models_every_epochs and
            (epoch_index % ppo_hparams.save_models_every_epochs == 0 or
             (epoch_index + 1) == num_target_iterations)):
          ckpt_name = "model.ckpt-{}".format(
              tf.train.global_step(sess, global_step)
          )
          # Keep the last checkpoint from each epoch in a separate directory.
          epoch_dir = os.path.join(model_dir, "epoch_{}".format(epoch))
          tf.gfile.MakeDirs(epoch_dir)
          for ckpt_dir in (model_dir, epoch_dir):
            model_saver.save(sess, os.path.join(ckpt_dir, ckpt_name))
          if model_save_fn:
            model_save_fn(model_dir)


def _rollout_metadata(batch_env, distributional_size=1):
  """Metadata for rollouts."""
  batch_env_shape = batch_env.observ.get_shape().as_list()
  batch_size = [batch_env_shape[0]]
  value_size = batch_size
  if distributional_size > 1:
    value_size = batch_size + [distributional_size]
  shapes_types_names = [
      # TODO(piotrmilos): possibly retrieve the observation type for batch_env
      (batch_size + batch_env_shape[1:], batch_env.observ_dtype, "observation"),
      (batch_size, tf.float32, "reward"),
      (batch_size, tf.bool, "done"),
      (batch_size + list(batch_env.action_shape), batch_env.action_dtype,
       "action"),
      (batch_size, tf.float32, "pdf"),
      (value_size, tf.float32, "value_function"),
  ]
  return shapes_types_names
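
# Illustrative example (hypothetical shapes): for 16 environments with
# 84x84x3 observations and a scalar discrete action space, the metadata is
#   ([16, 84, 84, 3], observ_dtype, "observation"),
#   ([16], tf.float32, "reward"), ([16], tf.bool, "done"),
#   ([16], action_dtype, "action"), ([16], tf.float32, "pdf"),
#   ([16], tf.float32, "value_function"),
# and with distributional_size=51 the last entry becomes ([16, 51], ...).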


class _MemoryWrapper(WrapperBase):
  """Memory wrapper."""

  def __init__(self, batch_env):
    super(_MemoryWrapper, self).__init__(batch_env)
    infinity = 10000000
    meta_data = list(zip(*_rollout_metadata(batch_env)))
    # In the memory wrapper we collect neither pdfs nor value functions,
    # so we only need the first 4 entries of meta_data.
    shapes = meta_data[0][:4]
    dtypes = meta_data[1][:4]
    self.speculum = tf.FIFOQueue(infinity, shapes=shapes, dtypes=dtypes)
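    # The queue ("speculum") carries (observation, reward, done, action)
    # tuples from the in-graph environment step to the collect loop in
    # _define_collect, which dequeues them into the rollout memory.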
    observs_shape = batch_env.observ.shape
    # TODO(piotrmilos): possibly retrieve the observation type for batch_env
    self._observ = tf.Variable(
        tf.zeros(observs_shape, self.observ_dtype), trainable=False)

  def __str__(self):
    return "MemoryWrapper(%s)" % str(self._batch_env)

  def simulate(self, action):

    # There is a subtlety here. We need to collect the data in the order
    #   obs, action = policy(obs), reward, done = env(obs, action),
    # so we must enqueue the data before assigning the new observation.

    reward, done = self._batch_env.simulate(action)

    with tf.control_dependencies([reward, done]):
      enqueue_op = self.speculum.enqueue(
          [self._observ.read_value(), reward, done, action])

    with tf.control_dependencies([enqueue_op]):
      assign = self._observ.assign(self._batch_env.observ)

    with tf.control_dependencies([assign]):
      return tf.identity(reward), tf.identity(done)


def _define_collect(batch_env, ppo_hparams, scope, frame_stack_size, eval_phase,
                    sampling_temp, force_beginning_resets,
                    distributional_size=1):
  """Collect trajectories.

  Args:
    batch_env: Batch environment.
    ppo_hparams: PPO hparams, defined in tensor2tensor.models.research.rl.
    scope: var scope.
    frame_stack_size: Number of last observations to feed into the policy.
    eval_phase: Whether this is an evaluation rollout. If True, collection runs
      until every agent finishes an episode rather than for epoch_length steps.
    sampling_temp: Sampling temperature for the policy.
    force_beginning_resets: Whether to force a reset of all environments at the
      beginning of collection.
    distributional_size: optional, number of buckets in distributional RL.

  Returns:
    A tuple (memory, summaries, initialization_lambda), where memory holds the
    rollout (observations, rewards, dones, actions, pdfs, value_functions)
    collected from the wrapped environment.
  """
  epoch_length = ppo_hparams.epoch_length

  to_initialize = []
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    num_agents = batch_env.batch_size

    to_initialize.append(batch_env)
    wrappers = [(StackWrapper, {
        "history": frame_stack_size
    }), (_MemoryWrapper, {})]
    rollout_metadata = None
    speculum = None
    for w in wrappers:
      tf.logging.info("Applying wrapper %s(%s) to env %s." % (str(
          w[0]), str(w[1]), str(batch_env)))
      batch_env = w[0](batch_env, **w[1])
      to_initialize.append(batch_env)

    rollout_metadata = _rollout_metadata(batch_env, distributional_size)
    speculum = batch_env.speculum

    def initialization_lambda(sess):
      for batch_env in to_initialize:
        batch_env.initialize(sess)

    memory = [
        tf.get_variable(  # pylint: disable=g-complex-comprehension
            "collect_memory_%d_%s" % (epoch_length, name),
            shape=[epoch_length] + shape,
            dtype=dtype,
            initializer=tf.zeros_initializer(),
            trainable=False) for (shape, dtype, name) in rollout_metadata
    ]

    cumulative_rewards = tf.get_variable(
        "cumulative_rewards", len(batch_env), trainable=False)

    eval_phase_t = tf.convert_to_tensor(eval_phase)
    should_reset_var = tf.Variable(True, trainable=False)
    zeros_tensor = tf.zeros(len(batch_env))

  force_beginning_resets = tf.convert_to_tensor(force_beginning_resets)

  def reset_ops_group():
    return tf.group(
        batch_env.reset(tf.range(len(batch_env))),
        tf.assign(cumulative_rewards, zeros_tensor))

  reset_op = tf.cond(
      tf.logical_or(should_reset_var.read_value(), force_beginning_resets),
      reset_ops_group, tf.no_op)

  with tf.control_dependencies([reset_op]):
    reset_once_op = tf.assign(should_reset_var, False)

  with tf.control_dependencies([reset_once_op]):

    def step(index, scores_sum, scores_num):
      """Single step."""
      index %= epoch_length  # Only needed in eval runs.
      # Note: the only way to ensure that a copy of a tensor is made is to run
      # a simple operation on it. We are waiting for tf.copy:
      # https://github.com/tensorflow/tensorflow/issues/11186
      obs_copy = batch_env.observ + 0
      value_fun_shape = (num_agents,)
      if distributional_size > 1:
        value_fun_shape = (num_agents, distributional_size)

      def env_step(arg1, arg2, arg3):  # pylint: disable=unused-argument
        """Step of the environment."""

        (logits, value_function) = get_policy(
            obs_copy, ppo_hparams, batch_env.action_space, distributional_size
        )
        action = common_layers.sample_with_temperature(logits, sampling_temp)
        action = tf.cast(action, tf.int32)
        action = tf.reshape(action, shape=(num_agents,))

        reward, done = batch_env.simulate(action)

        pdf = tfp.distributions.Categorical(logits=logits).prob(action)
        pdf = tf.reshape(pdf, shape=(num_agents,))
        value_function = tf.reshape(value_function, shape=value_fun_shape)
        done = tf.reshape(done, shape=(num_agents,))

        with tf.control_dependencies([reward, done]):
          return tf.identity(pdf), tf.identity(value_function), \
                 tf.identity(done)

      # TODO(piotrmilos): while_body is executed at most once, so it should be
      # replaced with tf.cond.
      pdf, value_function, top_level_done = tf.while_loop(
          lambda _1, _2, _3: tf.equal(speculum.size(), 0),
          env_step,
          [
              tf.constant(0.0, shape=(num_agents,)),
              tf.constant(0.0, shape=value_fun_shape),
              tf.constant(False, shape=(num_agents,))
          ],
          parallel_iterations=1,
          back_prop=False,
      )

      with tf.control_dependencies([pdf, value_function]):
        obs, reward, done, action = speculum.dequeue()
        to_save = [obs, reward, done, action, pdf, value_function]
        save_ops = [
            tf.scatter_update(memory_slot, index, value)
            for memory_slot, value in zip(memory, to_save)
        ]
        cumulate_rewards_op = cumulative_rewards.assign_add(reward)

        agent_indices_to_reset = tf.where(top_level_done)[:, 0]
      with tf.control_dependencies([cumulate_rewards_op]):
        # TODO(piotrmilos): possibly we need cumulative_rewards.read_value()
        scores_sum_delta = tf.reduce_sum(
            tf.gather(cumulative_rewards.read_value(), agent_indices_to_reset))
        scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
      with tf.control_dependencies(save_ops +
                                   [scores_sum_delta, scores_num_delta]):
        reset_env_op = batch_env.reset(agent_indices_to_reset)
        reset_cumulative_rewards_op = tf.scatter_update(
            cumulative_rewards, agent_indices_to_reset,
            tf.gather(zeros_tensor, agent_indices_to_reset))
      with tf.control_dependencies([reset_env_op, reset_cumulative_rewards_op]):
        return [
            index + 1, scores_sum + scores_sum_delta,
            scores_num + scores_num_delta
        ]

    def stop_condition(i, _, resets):
      return tf.cond(eval_phase_t, lambda: resets < num_agents,
                     lambda: i < epoch_length)

    init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
    index, scores_sum, scores_num = tf.while_loop(
        stop_condition, step, init, parallel_iterations=1, back_prop=False)

  # We handle force_beginning_resets differently. We assume that all envs are
  # reset at the end of an episode (though it actually happens at the beginning
  # of the next one).
  scores_num = tf.cond(force_beginning_resets,
                       lambda: scores_num + len(batch_env), lambda: scores_num)

  with tf.control_dependencies([scores_sum]):
    scores_sum = tf.cond(
        force_beginning_resets,
        lambda: scores_sum + tf.reduce_sum(cumulative_rewards.read_value()),
        lambda: scores_sum)

  mean_score = tf.cond(
      tf.greater(scores_num, 0),
      lambda: scores_sum / tf.cast(scores_num, tf.float32), lambda: 0.)
  printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
  with tf.control_dependencies([index, printing]):
    memory = [mem.read_value() for mem in memory]
    # When generating real data together with PPO training we must use a
    # single agent. For PPO to work we reshape the history as if it were
    # generated by real_ppo_effective_num_agents agents.
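    # Illustrative shapes (hypothetical sizes): with a single real agent,
    # epoch_length=200 and effective_num_agents=10, an "observation" slot of
    # shape [200, 1, 84, 84, 3] is transposed to [1, 200, 84, 84, 3], reshaped
    # to [10, 20, 84, 84, 3], and transposed back to [20, 10, 84, 84, 3],
    # i.e. a 20-step rollout of 10 agents.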
    if ppo_hparams.effective_num_agents is not None and not eval_phase:
      new_memory = []
      effective_num_agents = ppo_hparams.effective_num_agents
      assert epoch_length % ppo_hparams.effective_num_agents == 0, (
          "The rollout of ppo_hparams.epoch_length will be distributed amongst"
          " effective_num_agents agents.")
      new_epoch_length = int(epoch_length / effective_num_agents)
      for mem, info in zip(memory, rollout_metadata):
        shape, _, name = info
        new_shape = [effective_num_agents, new_epoch_length] + shape[1:]
        perm = list(range(len(shape) + 1))
        perm[0] = 1
        perm[1] = 0
        mem = tf.transpose(mem, perm=perm)
        mem = tf.reshape(mem, shape=new_shape)
        mem = tf.transpose(
            mem,
            perm=perm,
            name="collect_memory_%d_%s" % (new_epoch_length, name))
        new_memory.append(mem)
      memory = new_memory

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
      mean_score_summary = tf.cond(
          tf.greater(scores_num, 0),
          lambda: tf.summary.scalar("mean_score_this_iter", mean_score), str)
      summaries = tf.summary.merge([
          mean_score_summary,
          tf.summary.scalar("episodes_finished_this_iter", scores_num)
      ])
      return memory, summaries, initialization_lambda
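

# A minimal end-to-end usage sketch (the `env_fn` and `ppo_hparams` below are
# placeholders; see tensor2tensor.models.research.rl for the hparams sets):
#
#   learner = PPOLearner(frame_stack_size=4, base_event_dir=event_dir,
#                        agent_model_dir=model_dir, total_num_epochs=1)
#   learner.train(env_fn, ppo_hparams, simulated=False,
#                 save_continuously=True, epoch=0)
#   learner.evaluate(env_fn, ppo_hparams, sampling_temp=0.5)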