import torch
import numpy as np

from .torch_trainer import TorchTrainer
from rltime.policies.torch.dqn import DQNPolicy


class DQN(TorchTrainer):
    """The basic (non-distributional) DQN training algorithm

    By default this used a replay buffer but this can be changed to use
    priortized replay, or 'online' mode for online nstep q-learning
    """

    def _train(self, double_q=False, loss_mode="huber", huber_kappa=1.0,
               loss_aggregation="mean", loss_timestep_aggregation=None,
               history_mode={"type": "replay"}, **kwargs):
        """ Entry point for DQN training

        Args:
            double_q: Whether to use the 'double-Q' method for target-value
                calculation, i.e. use the online policy to select the greedy
                action and take that action's value estimate from the target
                policy (instead of using the target policy for both)
            loss_mode: How to calculate the loss on the td-errors. Either
                'mse' or 'huber' is supported. Huber loss clips the gradient
                of the loss to a fixed value (e.g. 1.0 if huber_kappa=1.0)
                rather than letting it grow with the td-errors, while still
                behaving like a squared error for small errors
            huber_kappa: Kappa value to use in the huber-loss formula if
                loss_mode=huber.
            loss_aggregation: How to aggregate the losses across the batch.
                Supported values are 'mean' or 'sum'
            loss_timestep_aggregation: Optionally use a different loss
                aggregation across timesteps, for example 'mean' across
                timesteps and 'sum' across the batch (relevant only for
                multi-timestep training, i.e. nstep_train>1)
            history_mode: The history/replay buffer configuration to use
                (a dict with at least a 'type' key); defaults to a plain
                replay buffer (see the class docstring)
        """
        self.double_q = double_q
        self.loss_mode = loss_mode
        self.huber_kappa = huber_kappa
        self.loss_aggregation = self._get_aggregator(loss_aggregation)
        self.loss_timestep_aggregation = None \
            if not loss_timestep_aggregation \
            else self._get_aggregator(loss_timestep_aggregation)
        super()._train(history_mode=history_mode, **kwargs)

    @staticmethod
    def create_policy(**kwargs):
        return DQNPolicy.create(**kwargs)

    def _get_bootstrap_target_value(self, target_states, timesteps):
        # Target values are taken from the target policy
        target_values = self.target_policy.predict(
            target_states, timesteps=timesteps)
        if not self.double_q:
            # Without double-Q, actions are chosen using the same target policy
            # values
            action_selection_values = target_values
        else:
            # For double Q, we use the online policy to select the best target
            # action, and take the value of that action from the target policy
            # This reduces the overestimation bias of standard Q-learning
            # targets, at the cost of an additional forward pass
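            # i.e. the target value becomes
            #     Q_target(s', argmax_a Q_online(s', a))
            # rather than max_a Q_target(s', a)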
            action_selection_values = self.policy.predict(
                target_states, timesteps=timesteps)

        target_actions = action_selection_values.argmax(dim=-1, keepdim=True)
        target_values = target_values.gather(
            dim=-1, index=target_actions).squeeze(-1)
        return target_values

    def _report_losses_if_needed(self, losses, extra_train_data):
        """Notifies the history buffer about the losses for each index, if
        requested"""
        if "loss_indices" not in extra_train_data:
            return
        notify_errors = losses.data.cpu().numpy()
        loss_indices = extra_train_data["loss_indices"]
        assert(notify_errors.shape == loss_indices.shape[:1])
        self.history_buffer.update_losses(loss_indices, notify_errors)

    def _apply_importance_weights_if_needed(self, losses, extra_train_data):
        """Applies importance weights to the losses, if supplied"""
        if "importance_weights" not in extra_train_data:
            return losses

        importance_weights = self.policy.make_tensor(
            extra_train_data["importance_weights"])
        assert(importance_weights.shape == losses.shape)
        losses = losses * importance_weights
        # Add the average importance weights to the result log
        self.value_log.log(
            "importance_weights",
            np.mean(extra_train_data["importance_weights"]), group="train")
        return losses

    def _calc_loss(self, errors):
        """Calculates the losses given the batch-wise 'td-errors'

        This is either squared-error or huber loss
        """
        if self.loss_mode == "mse":
            return errors.pow(2)
        elif self.loss_mode == "huber":
            # Huber loss element-wise
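            # i.e. loss(e) = 0.5 * e^2                     if |e| <= kappa
            #                kappa * (|e| - 0.5 * kappa)   otherwise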
            abs_errors = torch.abs(errors)
            return torch.where(
                abs_errors <= self.huber_kappa,
                0.5 * errors.pow(2),
                self.huber_kappa * (abs_errors - (0.5 * self.huber_kappa)))
        else:
            assert(False), \
                f"{self.loss_mode} is not a valid q-learning loss mode"

    def _get_aggregator(self, name):
        assert (name in ['mean', 'sum'])
        return torch.mean if name == 'mean' else torch.sum

    def _aggregate_losses(self, losses, timesteps):
        """Aggregates the losses to a single value using the requested
        aggregation modes
        """
        assert(len(losses.shape) == 1)
        if self.loss_timestep_aggregation:
            # Optionally use a different aggregation across timesteps
            # Timesteps are on the first dimension
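            # e.g. with timesteps=3 and a per-timestep batch of size B, the
            # flat (3*B,) loss vector is viewed as (3, B) and reduced over
            # dim 0, leaving B values for the final aggregation below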
            losses = self.loss_timestep_aggregation(
                losses.view(timesteps, -1), dim=0)
        return self.loss_aggregation(losses)

    def _compute_grads(self, states, targets, policy_outputs, extra_data,
                       timesteps):
        # The Q-learning loss is basically a huber/mean-squared error between
        # the q-values of the acted actions and the bootstrapped
        # nstep-discounted reward return starting from that state/action
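        # (The targets here are assumed to already be those n-step
        # bootstrapped returns, roughly
        #     r_t + gamma*r_{t+1} + ... + gamma^n * Q_target(s_{t+n}, a*),
        # computed by the base trainer using _get_bootstrap_target_value())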
        # Forward training pass to return all the qvalues
        all_qvalues = self.policy.predict(states, timesteps)

        # Take the actions that were chosen at acting time, when these
        # transitions were generated
        actions = self.policy.make_tensor(
            policy_outputs['actions']).long().unsqueeze(-1)
        assert(actions.shape[:-1] == all_qvalues.shape[:-1])
        # Gather only the qvalues for the acted actions, to train only those
        # action-q-values
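        # e.g. with all_qvalues of shape (batch, num_actions) and actions of
        # shape (batch, 1), the gathered and squeezed result has shape
        # (batch,)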
        chosen_qvalues = torch.gather(
            all_qvalues, dim=-1, index=actions).squeeze(-1)

        assert(targets.shape == chosen_qvalues.shape)
        td_errors = (chosen_qvalues - targets)

        loss = self._calc_loss(td_errors)

        # If importance weights were supplied (e.g. from prioritized replay),
        # multiply the losses element-wise by them
        loss = self._apply_importance_weights_if_needed(loss, extra_data)

        # Aggregate the losses and do the backprop pass
        loss = self._aggregate_losses(loss, timesteps)
        loss.backward()

        # If the history buffer requested the td-errors back (e.g. for
        # prioritized replay priority updates), report them to it
        # NOTE: This needs to happen after backward() to avoid stalling the
        # CUDA pipeline, since it copies the losses back to the CPU
        self._report_losses_if_needed(td_errors, extra_data)

        # Log some stats
        self.value_log.log("qloss", loss.item(), group="train")
        self.value_log.log(
            "qvalue", chosen_qvalues.mean().item(), group="train")