# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Multi Q-Network DQN agent."""

import copy
import os

from batch_rl.multi_head import atari_helpers
from dopamine.agents.dqn import dqn_agent
import gin
import tensorflow.compat.v1 as tf


@gin.configurable
class MultiNetworkDQNAgent(dqn_agent.DQNAgent):
  """DQN agent with multiple heads."""

  def __init__(self,
               sess,
               num_actions,
               num_networks=1,
               transform_strategy='IDENTITY',
               num_convex_combinations=1,
               network=atari_helpers.MultiNetworkQNetwork,
               init_checkpoint_dir=None,
               use_deep_exploration=False,
               **kwargs):
    """Initializes the agent and constructs the components of its graph.

    Args:
      sess: tf.Session, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      num_networks: int, number of different Q-functions.
      transform_strategy: str, possible options include (1) 'STOCHASTIC' for
        multiplication with a left stochastic matrix, and (2) 'IDENTITY', in
        which case the Q-networks are not transformed.
      num_convex_combinations: int, if transform_strategy is 'STOCHASTIC',
        the number of random convex combinations to be created. If None,
        `num_networks` convex combinations are created.
      network: tf.keras.Model, a call to this object will return an
        instantiation of the network provided. The network returned can be run
        with different inputs to create different outputs. See
        atari_helpers.MultiNetworkQNetwork as an example.
      init_checkpoint_dir: str, directory from which initial checkpoint before
        training is loaded if there doesn't exist any checkpoint in the current
        agent directory. If None, no initial checkpoint is loaded.
      use_deep_exploration: bool, whether to use an adaptation of
        Bootstrapped DQN's deep exploration for data collection with REM,
        i.e. acting with a randomly sampled Q-function in each episode.
      **kwargs: Arbitrary keyword arguments.
    """
    tf.logging.info('Creating MultiNetworkDQNAgent with following parameters:')
    tf.logging.info('\t num_networks: %d', num_networks)
    tf.logging.info('\t transform_strategy: %s', transform_strategy)
    tf.logging.info('\t num_convex_combinations: %d', num_convex_combinations)
    tf.logging.info('\t init_checkpoint_dir: %s', init_checkpoint_dir)
    tf.logging.info('\t use_deep_exploration: %s', use_deep_exploration)
    self.num_networks = num_networks
    if init_checkpoint_dir is not None:
      self._init_checkpoint_dir = os.path.join(init_checkpoint_dir,
                                               'checkpoints')
    else:
      self._init_checkpoint_dir = None
    # The transform matrix should be created on the device specified by
    # tf_device if the transform_strategy is 'UNIFORM_STOCHASTIC' or
    # 'STOCHASTIC'.
    self._q_networks_transform = None
    self._num_convex_combinations = num_convex_combinations
    self.transform_strategy = transform_strategy
    self.use_deep_exploration = use_deep_exploration
    super(MultiNetworkDQNAgent, self).__init__(
        sess, num_actions, network=network, **kwargs)

  def _create_network(self, name):
    """Builds a multi-network Q-network that outputs Q-values for each network.

    Args:
      name: str, this name is passed to the tf.keras.Model and used to create
        variable scope under the hood by the tf.keras.Model.

    Returns:
      network: tf.keras.Model, the network instantiated by the Keras model.
    """
    # Pass the device_fn to place Q-networks on different devices
    kwargs = {'device_fn': lambda i: '/gpu:{}'.format(i // 4)}
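    # With the default device_fn above, sub-networks 0-3 land on '/gpu:0',
    # sub-networks 4-7 on '/gpu:1', and so on (four Q-networks per GPU).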
    if self._q_networks_transform is None:
      if self.transform_strategy == 'STOCHASTIC':
        tf.logging.info('Creating q_networks transformation matrix...')
        self._q_networks_transform = atari_helpers.random_stochastic_matrix(
            self.num_networks, num_cols=self._num_convex_combinations)
    if self._q_networks_transform is not None:
      kwargs.update({'transform_matrix': self._q_networks_transform})
    return self.network(
        num_actions=self.num_actions,
        num_networks=self.num_networks,
        transform_strategy=self.transform_strategy,
        name=name,
        **kwargs)

  def _build_target_q_op(self):
    """Build an op used as a target for the Q-value.

    Returns:
      target_q_op: An op calculating the Q-value.
    """
    # Get the maximum Q-value across the actions dimension for each Q-network.
    replay_next_qt_max = tf.reduce_max(
        self._replay_next_target_net_outputs.q_networks, axis=1)
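    # Shape sketch (assuming q_networks is [batch_size, num_actions,
    # num_networks]): replay_next_qt_max is [batch_size, num_networks], i.e.
    # one bootstrap value per Q-network; rewards and the terminal mask are
    # expanded below so they broadcast against that shape.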
    is_non_terminal = 1. - tf.cast(self._replay.terminals, tf.float32)
    is_non_terminal = tf.expand_dims(is_non_terminal, axis=-1)
    rewards = tf.expand_dims(self._replay.rewards, axis=-1)
    return rewards + (
        self.cumulative_gamma * replay_next_qt_max * is_non_terminal)

  def begin_episode(self, observation):
    """Returns the agent's first action for this episode.

    Args:
      observation: numpy array, the environment's initial observation.

    Returns:
      int, the selected action.
    """
    if self.use_deep_exploration:
      # For online experiments, randomly pick one of the Q-functions for data
      # collection in each episode, similar to the deep exploration strategy
      # proposed by Bootstrapped DQN.
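      # For example, with transform_strategy='IDENTITY' and num_networks=4,
      # this samples a Q-network index uniformly from {0, 1, 2, 3}; with a
      # strategy ending in 'STOCHASTIC', a fresh random convex combination of
      # the Q-networks is drawn instead.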
      self._sess.run(self._update_episode_q_function)
    return super(MultiNetworkDQNAgent, self).begin_episode(observation)

  def _build_networks(self):
    super(MultiNetworkDQNAgent, self)._build_networks()
    # q_argmax is only used for picking an action
    self._q_argmax_eval = tf.argmax(self._net_outputs.q_values, axis=1)[0]
    if self.use_deep_exploration:
      if self.transform_strategy.endswith('STOCHASTIC'):
        q_transform = atari_helpers.random_stochastic_matrix(
            self.num_networks, num_cols=1)
        self._q_episode_transform = tf.get_variable(
            trainable=False,
            dtype=tf.float32,
            shape=q_transform.get_shape().as_list(),
            name='q_episode_transform')
        self._update_episode_q_function = self._q_episode_transform.assign(
            q_transform)
        episode_q_function = tf.tensordot(
            self._net_outputs.unordered_q_networks,
            self._q_episode_transform, axes=[[2], [0]])
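        # Shape sketch (assuming unordered_q_networks is
        # [batch_size, num_actions, num_networks]): contracting against the
        # [num_networks, 1] transform yields [batch_size, num_actions, 1],
        # i.e. a single mixed Q-function used for acting in this episode.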
        self._q_argmax_train = tf.argmax(episode_q_function[:, :, 0], axis=1)[0]
      elif self.transform_strategy == 'IDENTITY':
        self._q_function_index = tf.Variable(
            initial_value=0,
            trainable=False,
            dtype=tf.int32,
            shape=(),
            name='q_head_episode')
        self._update_episode_q_function = self._q_function_index.assign(
            tf.random.uniform(
                shape=(), maxval=self.num_networks, dtype=tf.int32))
        q_function = self._net_outputs.unordered_q_networks[
            :, :, self._q_function_index]
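        # Indexing with the sampled scalar picks out one Q-network's values,
        # of shape [batch_size, num_actions], under the same shape assumption
        # as above.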
        # This is only used for picking an action
        self._q_argmax_train = tf.argmax(q_function, axis=1)[0]
    else:
      self._q_argmax_train = self._q_argmax_eval

  def _select_action(self):
    if self.eval_mode:
      self._q_argmax = self._q_argmax_eval
    else:
      self._q_argmax = self._q_argmax_train
    return super(MultiNetworkDQNAgent, self)._select_action()

  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    actions = self._replay.actions
    indices = tf.stack([tf.range(actions.shape[0]), actions], axis=-1)
    replay_chosen_q = tf.gather_nd(
        self._replay_net_outputs.q_networks, indices=indices)
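    # Under the assumed output shape [batch_size, num_actions, num_networks],
    # replay_chosen_q is [batch_size, num_networks]: the Q-value of the
    # action actually taken, under each Q-network.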
    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
    q_head_losses = tf.reduce_mean(loss, axis=0)
    final_loss = tf.reduce_mean(q_head_losses)
    if self.summary_writer is not None:
      with tf.variable_scope('Losses'):
        tf.summary.scalar('HuberLoss', final_loss)
    self.optimizers = [
        copy.deepcopy(self.optimizer) for _ in range(self.num_networks)]
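    # Each sub-network gets its own deep-copied optimizer, so optimizer state
    # is not shared across sub-networks; the i-th train op below updates only
    # the variables under the 'Online/subnet_{i}' scope.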
    train_ops = []
    for i in range(self.num_networks):
      var_list = tf.trainable_variables(scope='Online/subnet_{}'.format(i))
      train_op = self.optimizers[i].minimize(final_loss, var_list=var_list)
      train_ops.append(train_op)
    return tf.group(*train_ops, name='merged_train_op')
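

# A hypothetical gin configuration for this agent (values shown here are
# illustrative, not taken from the repository's config files):
#   MultiNetworkDQNAgent.num_networks = 4
#   MultiNetworkDQNAgent.transform_strategy = 'STOCHASTIC'
#   MultiNetworkDQNAgent.num_convex_combinations = 1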