# Copyright 2018/2019 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from rlgraph import get_backend
from rlgraph.components.loss_functions import LossFunction
from rlgraph.spaces import IntBox
from rlgraph.spaces.space_utils import sanity_check_space
from rlgraph.utils import pytorch_one_hot
from rlgraph.utils.decorators import rlgraph_api, graph_fn
from rlgraph.utils.pytorch_util import pytorch_tile
from rlgraph.utils.util import get_rank

if get_backend() == "tf":
    import tensorflow as tf
elif get_backend() == "pytorch":
    import torch


class DQNLossFunction(LossFunction):
    """
    The classic 2015 DQN Loss Function [1] with options for "double" Q-losses [2], Huber loss [3], and container
    actions [4]:

    L = E_(uniform batch)[(r + gamma * max_a' Qt(s',a') - Qn(s,a))^2]
    where Qn is the "normal" (main) Q-network and Qt is the "target" network (which lags a little behind Qn
    for stability purposes).

    [1] Human-level control through deep reinforcement learning. Mnih, Kavukcuoglu, Silver et al. - 2015
    [2] Deep Reinforcement Learning with Double Q-learning. van Hasselt, Guez, Silver - 2015
    [3] https://en.wikipedia.org/wiki/Huber_loss
    [4] Action Branching Architectures for Deep Reinforcement Learning. Tavakoli, Pardo, and Kormushev - 2017
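
    Illustrative example (made-up numbers; single discrete action, gamma=0.99, n_step=1):
        reward r=1.0, s' non-terminal, target-net values Qt(s',.) = [0.5, 2.0]
        -> TD target = 1.0 + 0.99 * max(0.5, 2.0) = 2.98
        With the main net's estimate Qn(s,a) = 2.5 for the action actually taken, this sample contributes
        (2.98 - 2.5)^2 = 0.2304 to the squared TD error (the code additionally scales by 0.5 and may apply
        the Huber correction and importance weights).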
    """
    def __init__(self, double_q=False, huber_loss=False, importance_weights=False, n_step=1,
                 shared_container_action_target=True, scope="dqn-loss-function", **kwargs):
        """
        Args:
            double_q (bool): Whether to use the double DQN loss function ([2]).
            huber_loss (bool): Whether to apply a Huber loss correction ([3]).
            importance_weights (bool): Whether to use importance weights from a prioritized replay.
            n_step (int): The number of steps (n) used for n-step Q-learning; the bootstrap term Qt(s',a') is
                discounted by gamma^n.

            shared_container_action_target (bool): Whether - only in the case of container actions - the target term
                should be shared (averaged) over all action components' individual loss terms. Default: True.
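
        Example (hypothetical usage; construction details depend on the surrounding agent code):
            loss_function = DQNLossFunction(double_q=True, huber_loss=True, n_step=3)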
        """
        self.double_q = double_q
        self.huber_loss = huber_loss
        assert n_step >= 1, "Number of steps for n-step learning must be >= 1, is {}".format(n_step)
        # TODO reward must be preprocessed to work correctly for n-step.
        # For Apex, this is done in the worker - do we want to move this as an in-graph option too?
        self.n_step = n_step
        self.shared_container_action_target = shared_container_action_target

        # Clip value, see: https://en.wikipedia.org/wiki/Huber_loss
        self.huber_delta = kwargs.get("huber_delta", 1.0)
        self.importance_weights = importance_weights

        super(DQNLossFunction, self).__init__(scope=scope, **kwargs)

        self.flat_action_space = None
        self.ranks_to_reduce = 0  # How many ranks do we have to reduce to get down to the final loss per batch item?

    def check_input_spaces(self, input_spaces, action_space=None):
        """
        Checks the incoming Spaces for sanity: the action Space must be a (container of) IntBox with
        known num_categories.
        """
        assert action_space is not None
        self.action_space = action_space
        self.flat_action_space = action_space.flatten()
        # Check for IntBox and num_categories.
        sanity_check_space(self.action_space, must_have_categories=True, allowed_sub_types=[IntBox])
        self.ranks_to_reduce = len(self.action_space.get_shape(with_batch_rank=True)) - 1

    @rlgraph_api
    def loss(self, q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp=None, importance_weights=None):
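        """
        Calculates the total loss (scalar, averaged over the batch) and the loss per batch item.

        Returns:
            Tuple[SingleDataOp,SingleDataOp]: The total loss and the loss values per batch item.
        """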
        loss_per_item = self.loss_per_item(
            q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp, importance_weights
        )
        total_loss = self.loss_average(loss_per_item)
        return total_loss, loss_per_item

    @rlgraph_api
    def loss_per_item(self, q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp=None,
                      importance_weights=None):
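        """
        Calculates the loss per batch item: Builds TD targets from the target-net's Q-values for s',
        optionally averages them over container action components, forms TD-deltas against Q(s,a), and
        applies the Huber correction if configured.
        """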
        # Get the targets per action.
        td_targets = self._graph_fn_get_td_targets(rewards, terminals, qt_values_sp, q_values_sp)
        # Average over container sub-actions.
        if self.shared_container_action_target is True:
            td_targets = self._graph_fn_average_over_container_keys(td_targets)

        # Calculate the loss per item.
        loss_per_item = self._graph_fn_loss_per_item(td_targets, q_values_s, actions, importance_weights)
        # Average over container sub-actions.
        loss_per_item = self._graph_fn_average_over_container_keys(loss_per_item)

        # Apply huber loss.
        loss_per_item = self._graph_fn_apply_huber_loss_if_necessary(loss_per_item)

        return loss_per_item

    @graph_fn(flatten_ops=True, split_ops=True, add_auto_key_as_first_param=True)
    def _graph_fn_get_td_targets(self, key, rewards, terminals, qt_values_sp, q_values_sp=None):
        """
        Args:
            rewards (SingleDataOp): The batch of rewards that we received after having taken action a in state s
                (from a memory).
            terminals (SingleDataOp): The batch of terminal signals that we received after having taken action a in
                state s (from a memory).
            qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
                returns (estimated by the target net) when in s' and taking different actions a'.
            q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing the
                expected accumulated discounted returns (estimated by the (main) policy net) when in s' and taking
                different actions a'.

        Returns:
            SingleDataOp: The target values vector.
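
        Illustrative example (made-up numbers): With discount=0.99, n_step=1, rewards=[1.0], terminals=[False]
        and qt_values_sp=[[0.5, 2.0]], the (non-double-Q) target is [1.0 + 0.99 * max(0.5, 2.0)] = [2.98].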
        """
        qt_sp_ap_values = None

        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot
            if self.double_q:
                a_primes = np.argmax(q_values_sp, axis=-1)
                a_primes_one_hot = one_hot(a_primes, depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot, axis=-1)
            else:
                qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

            for _ in range(qt_sp_ap_values.ndim - 1):
                rewards = np.expand_dims(rewards, axis=1)

            qt_sp_ap_values = np.where(terminals, np.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

        elif get_backend() == "tf":
            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = tf.stop_gradient(qt_values_sp)

            if self.double_q:
                # For double-Q, we no longer use the max_a' Qt(s',a') value.
                # Instead, a' is chosen via argmax_a' Q(s',a') using the main Q-network (not the target net)
                # and is then only evaluated with the target net.
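                # E.g. (made-up numbers): If Q(s',.) = [1.0, 3.0] and Qt(s',.) = [2.5, 2.0], double-Q picks
                # a'=1 (the main net's argmax) and uses Qt(s',1) = 2.0, whereas vanilla DQN would use
                # max_a' Qt(s',a') = 2.5.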
                a_primes = tf.argmax(input=q_values_sp, axis=-1)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = tf.one_hot(indices=a_primes, depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp * one_hot), axis=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp, axis=-1)

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = tf.expand_dims(rewards, axis=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by R. Sutton and A. Barto for a detailed analysis.
            qt_sp_ap_values = tf.where(
                condition=terminals, x=tf.zeros_like(qt_sp_ap_values), y=qt_sp_ap_values
            )

        elif get_backend() == "pytorch":
            if not isinstance(terminals, torch.ByteTensor):
                terminals = terminals.byte()
            # Add batch dim in case of single sample.
            if qt_values_sp.dim() == 1:
                rewards = rewards.unsqueeze(-1)
                terminals = terminals.unsqueeze(-1)
                q_values_sp = q_values_sp.unsqueeze(-1)
                qt_values_sp = qt_values_sp.unsqueeze(-1)

            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = qt_values_sp.detach()
            if self.double_q:
                # For double-Q, we no longer use the max_a' Qt(s',a') value.
                # Instead, a' is chosen via argmax_a' Q(s',a') using the main Q-network (not the target net)
                # and is then only evaluated with the target net.
                a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = pytorch_one_hot(a_primes, depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = torch.sum(qt_values_sp * one_hot.squeeze(), dim=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = torch.unsqueeze(rewards, dim=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by R. Sutton and A. Barto for a detailed analysis.
            # torch.where cannot broadcast here, so tile and reshape `terminals` to the same shape.
            if qt_sp_ap_values.dim() > 1:
                num_tiles = np.prod(qt_sp_ap_values.shape[1:])
                terminals = pytorch_tile(terminals, num_tiles, -1).reshape(qt_sp_ap_values.shape)
            qt_sp_ap_values = torch.where(
                terminals, torch.zeros_like(qt_sp_ap_values), qt_sp_ap_values
            )
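        # TD target: (possibly n-step accumulated) reward plus the discounted target-net value of (s',a')
        # (0.0 if s' is terminal).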
        td_targets = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values)
        return td_targets

    @graph_fn(flatten_ops=True, split_ops=True, add_auto_key_as_first_param=True)
    def _graph_fn_loss_per_item(self, key, td_targets, q_values_s, actions, importance_weights=None):
        """
        Args:
            td_targets (SingleDataOp): The already calculated TD-target terms (r + gamma * max_a' Qt(s',a'),
                OR for double-Q: r + gamma * Qt(s', argmax_a' Q(s',a'))).

            q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted returns
                when in s and taking different actions a.

            actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).

            importance_weights (Optional[SingleDataOp]): If `self.importance_weights` is True: The batch of weights to
                apply to the losses.

        Returns:
            SingleDataOp: The loss values vector (one single value for each batch item).
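
        Illustrative example (made-up numbers): With td_targets=[2.98], q_values_s=[[2.5, 0.1]] and actions=[0],
        the returned value is the TD-delta [2.98 - 2.5] = [0.48] (importance weights, if enabled, are multiplied
        in here; squaring and the optional Huber correction happen in a separate graph_fn).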
        """
        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot

            actions_one_hot = one_hot(actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)

            td_delta = td_targets - q_s_a_values

            # Reduce over the composite actions, if any (importance weights are applied once, at the very
            # end, to match the other backends).
            if td_delta.ndim > 1:
                td_delta = np.mean(td_delta, axis=tuple(range(1, self.ranks_to_reduce + 1)))

        elif get_backend() == "tf":
            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = tf.one_hot(indices=actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot), axis=-1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = td_targets - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = tf.reduce_mean(input_tensor=td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

        elif get_backend() == "pytorch":
            # Add batch dim in case of single sample.
            if q_values_s.dim() == 1:
                q_values_s = q_values_s.unsqueeze(-1)
                actions = actions.unsqueeze(-1)
                if self.importance_weights:
                    importance_weights = importance_weights.unsqueeze(-1)

            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = pytorch_one_hot(actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = torch.sum((q_values_s * one_hot), -1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = td_targets - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = torch.mean(td_delta, tuple(range(1, self.ranks_to_reduce + 1)), keepdim=False)

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * td_delta
        else:
            return td_delta

    @graph_fn
    def _graph_fn_apply_huber_loss_if_necessary(self, td_delta):
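        """
        Applies the Huber loss to the TD-deltas if `self.huber_loss` is True, otherwise the plain squared error.

        Huber loss: l(d) = 0.5 * d^2 if |d| <= delta, else delta * (|d| - 0.5 * delta), with delta=`self.huber_delta`.
        Illustrative example (made-up number): With huber_delta=1.0, a TD-delta of 2.0 yields 1.0 * (2.0 - 0.5) = 1.5
        instead of the squared-error value 0.5 * 2.0^2 = 2.0.

        Args:
            td_delta (SingleDataOp): The TD-delta values (TD target minus current Q-estimate), one per batch item.

        Returns:
            SingleDataOp: The loss values per batch item.
        """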
        if self.backend == "python" or get_backend() == "python":
            if self.huber_loss:
                # Note: np.where and np.square do not accept their array inputs as keyword arguments.
                return np.where(
                    np.abs(td_delta) <= self.huber_delta,
                    0.5 * np.square(td_delta),
                    self.huber_delta * (np.abs(td_delta) - 0.5 * self.huber_delta)
                )
            else:
                return 0.5 * np.square(td_delta)
        elif get_backend() == "tf":
            if self.huber_loss:
                return tf.where(
                    condition=tf.abs(x=td_delta) <= self.huber_delta,
                    x=0.5 * tf.square(x=td_delta),
                    y=self.huber_delta * (tf.abs(x=td_delta) - 0.5 * self.huber_delta)
                )
            else:
                return 0.5 * tf.square(x=td_delta)
        elif get_backend() == "pytorch":
            if self.huber_loss:
                # Not certain whether the arithmetic here needs to be expressed via torch operators.
                return torch.where(
                    torch.abs(td_delta) <= self.huber_delta,
                    # PyTorch has no `square`
                    0.5 * torch.pow(td_delta, 2),
                    self.huber_delta * (torch.abs(td_delta) - 0.5 * self.huber_delta)
                )
            else:
                return 0.5 * td_delta * td_delta