# -*- coding: utf-8 -*-
# SPDX-License-Identifier: Apache-2.0

"""
:Author: FMR LLC
:Email: mabwiser@fmr.com

This module defines the abstract base class for contextual multi-armed bandit algorithms.
"""

import abc
from itertools import chain
from typing import Callable, Dict, List, Optional
import multiprocessing as mp

from joblib import Parallel, delayed
import numpy as np

from mabwiser.utils import Arm, Num, _NumpyRNG


class BaseMAB(metaclass=abc.ABCMeta):
    """Abstract base class for multi-armed bandits.

    This class is not intended to be used directly; instead, it declares
    the basic skeleton of multi-armed bandits together with a set of parameters
    that are common to every bandit algorithm.

    It declares abstract methods that sub-classes must implement to
    provide specific bandit policies:

        - ``__init__`` constructor to initialize the bandit
        - ``add_arm`` method to add a new arm
        - ``fit`` method for training
        - ``partial_fit`` method for online learning
        - ``predict_expectations`` method to retrieve the expectation of each arm
        - ``predict`` method for testing to retrieve the best arm based on the policy

    Attributes
    ----------
    rng: _NumpyRNG
        The random number generator.
    arms: List[Arm]
        The list of all arms.
    n_jobs: int
        The number of concurrent processes/threads to use for parallelized routines.
        Default value is 1.
        If set to -1, all CPUs are used.
        If set to -2, all CPUs but one are used, and so on.
    backend: str, optional
        Specify a parallelization backend implementation supported in the joblib library. Supported options are:
        - "loky" used by default, can induce some communication and memory overhead when exchanging input and
          output data with the worker Python processes.
        - "multiprocessing" previous process-based backend based on multiprocessing.Pool. Less robust than loky.
        - "threading" is a very low-overhead backend but it suffers from the Python Global Interpreter Lock if the
          called function relies a lot on Python objects.
        Default value is None. In this case the default backend selected by joblib will be used.
    arm_to_expectation: Dict[Arm, float]
        The dictionary of arms (keys) to their expected rewards (values).
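
    Examples
    --------
    A minimal, hypothetical sub-class sketch; the ``_MeanPolicy`` name and
    its logic are illustrative only and not part of the library. It fills in
    every abstract method with the simplest behavior that satisfies the
    contract::

        import numpy as np

        class _MeanPolicy(BaseMAB):

            def __init__(self, rng, arms, n_jobs=1, backend=None):
                super().__init__(rng, arms, n_jobs, backend)

            def fit(self, decisions, rewards, contexts=None):
                self._parallel_fit(decisions, rewards, contexts)

            def partial_fit(self, decisions, rewards, contexts=None):
                self.fit(decisions, rewards, contexts)

            def predict(self, contexts=None):
                # Greedy: the arm with the highest expected reward
                return max(self.arm_to_expectation, key=self.arm_to_expectation.get)

            def predict_expectations(self, contexts=None):
                return self.arm_to_expectation.copy()

            def _fit_arm(self, arm, decisions, rewards, contexts=None):
                # Mean reward observed for this arm (zero if unseen)
                arm_rewards = rewards[decisions == arm]
                self.arm_to_expectation[arm] = arm_rewards.mean() if arm_rewards.size else 0

            def _predict_contexts(self, contexts, is_predict, seeds=None, start_index=None):
                return [self.predict() for _ in contexts]

            def _uptake_new_arm(self, arm, binarizer=None, scaler=None):
                pass

        rng = np.random.RandomState(7)      # stands in for _NumpyRNG here
        mab = _MeanPolicy(rng, arms=["A", "B"])
        mab.fit(np.array(["A", "A", "B"]), np.array([0, 1, 1]))
        mab.predict()                       # -> "B" (mean reward 1.0 vs. 0.5)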
    """

    @abc.abstractmethod
    def __init__(self, rng: _NumpyRNG, arms: List[Arm], n_jobs: int, backend: Optional[str] = None):
        """Abstract method.

        Creates a multi-armed bandit policy with the given arms.
        """
        self.rng: _NumpyRNG = rng
        self.arms: List[Arm] = arms
        self.n_jobs: int = n_jobs
        self.backend: Optional[str] = backend

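        # Expected reward per arm; every arm starts at zero and concrete
        # policies update these values during training.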
        self.arm_to_expectation: Dict[Arm, float] = dict.fromkeys(self.arms, 0)

    def add_arm(self, arm: Arm, binarizer: Optional[Callable] = None, scaler: Optional[Callable] = None) -> None:
        """Introduces a new arm to the bandit.

        Adds the new arm with an expectation of zero and
        calls the ``_uptake_new_arm()`` method of the sub-class.
        """
        self.arm_to_expectation[arm] = 0
        self._uptake_new_arm(arm, binarizer, scaler)

    @abc.abstractmethod
    def fit(self, decisions: np.ndarray, rewards: np.ndarray,
            contexts: Optional[np.ndarray] = None) -> None:
        """Abstract method.

        Fits the multi-armed bandit to the given
        decision and reward history and corresponding contexts if any.
        """
        pass

    @abc.abstractmethod
    def partial_fit(self, decisions: np.ndarray, rewards: np.ndarray,
                    contexts: Optional[np.ndarray] = None) -> None:
        """Abstract method.

        Updates the multi-armed bandit with the given
        decision and reward history and corresponding contexts if any.
        """
        pass

    @abc.abstractmethod
    def predict(self, contexts: Optional[np.ndarray] = None) -> Arm:
        """Abstract method.

        Returns the predicted arm.
        """
        pass

    @abc.abstractmethod
    def predict_expectations(self, contexts: Optional[np.ndarray] = None) -> Dict[Arm, Num]:
        """Abstract method.

        Returns a dictionary from arms (keys) to their expected rewards (values).
        """
        pass

    @abc.abstractmethod
    def _uptake_new_arm(self, arm: Arm, binarizer: Optional[Callable] = None, scaler: Optional[Callable] = None) -> None:
        """Abstract method.

        Updates the multi-armed bandit with the new arm.
        """
        pass

    @abc.abstractmethod
    def _fit_arm(self, arm: Arm, decisions: np.ndarray, rewards: np.ndarray,
                 contexts: Optional[np.ndarray] = None) -> None:
        """Abstract method.

        Fit operation for an individual arm.
        """
        pass

    @abc.abstractmethod
    def _predict_contexts(self, contexts: np.ndarray, is_predict: bool,
                          seeds: Optional[np.ndarray] = None, start_index: Optional[int] = None) -> List:
        """Abstract method.

        Predict operation for a set of contexts.
        """
        pass

    def _parallel_fit(self, decisions: np.ndarray, rewards: np.ndarray,
                      contexts: Optional[np.ndarray] = None):
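        """Fits each arm in parallel.

        Runs one ``_fit_arm`` call per arm via joblib; ``require='sharedmem'``
        forces a shared-memory (threading) backend so that each call can
        update this object's state in place.
        """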

        # Compute effective number of jobs
        n_jobs = self._effective_jobs(len(self.arms), self.n_jobs)

        # Perform parallel fit
        Parallel(n_jobs=n_jobs, require='sharedmem')(
                          delayed(self._fit_arm)(
                              arm, decisions, rewards, contexts)
                          for arm in self.arms)

    def _parallel_predict(self, contexts: np.ndarray, is_predict: bool):
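        """Runs predictions over the given contexts in parallel.

        Splits the contexts into contiguous chunks, one per effective job,
        draws a random seed per context, and concatenates the per-job results
        back into a single list (or a single value for one context).
        """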

        # Total number of contexts to predict
        n_contexts = len(contexts)

        # Partition contexts by job
        n_jobs, n_contexts_per_job, starts = self._partition_contexts(n_contexts)
        total_contexts = sum(n_contexts_per_job)

        # Get seed value for each context
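        # One independent seed per context keeps stochastic policies
        # reproducible regardless of how contexts are split across jobs.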
        seeds = self.rng.randint(np.iinfo(np.int32).max, size=total_contexts)

        # Perform parallel predictions
        predictions = Parallel(n_jobs=n_jobs, backend=self.backend)(
                          delayed(self._predict_contexts)(
                              contexts[starts[i]:starts[i + 1]],
                              is_predict,
                              seeds[starts[i]:starts[i + 1]],
                              starts[i])
                          for i in range(n_jobs))

        # Reduce: flatten the per-job result lists into a single list
        predictions = list(chain.from_iterable(predictions))

        return predictions if len(predictions) > 1 else predictions[0]

    def _partition_contexts(self, n_contexts: int):
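        """Splits ``n_contexts`` into contiguous chunks, one per effective job.

        Earlier jobs receive one extra context when the split is uneven.
        For example, 10 contexts over 4 jobs gives chunk sizes [3, 3, 2, 2]
        and start offsets [0, 3, 6, 8, 10].

        Returns the effective number of jobs, the chunk sizes, and the start
        offsets of each chunk.
        """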

        # Compute effective number of jobs
        n_jobs = self._effective_jobs(n_contexts, self.n_jobs)

        # Partition contexts between jobs
        n_contexts_per_job = np.full(n_jobs, n_contexts // n_jobs, dtype=int)
        n_contexts_per_job[:n_contexts % n_jobs] += 1
        starts = np.cumsum(n_contexts_per_job)

        return n_jobs, n_contexts_per_job.tolist(), [0] + starts.tolist()

    @staticmethod
    def _effective_jobs(size: int, n_jobs: int):
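        """Returns the effective number of parallel jobs.

        Follows the joblib convention for negative values: -1 means all CPUs,
        -2 all CPUs but one, and so on. The result is capped at ``size`` so
        that no job is left without work; e.g., on an 8-CPU machine,
        ``_effective_jobs(size=4, n_jobs=-1)`` yields 4.
        """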
        if n_jobs < 0:
            n_jobs = max(mp.cpu_count() + 1 + n_jobs, 1)
        n_jobs = min(n_jobs, size)
        return n_jobs