import copy
import logging
import traceback
from typing import Any, Callable, Iterable, List

import numpy

from fragile.core.base_classes import (
    BaseCritic,
    BaseEnvironment,
    BaseModel,
    BaseSwarm,
)
from fragile.core.states import OneWalker, StatesEnv, StatesModel, StatesWalkers
from fragile.core.tree import HistoryTree
from fragile.core.utils import running_in_ipython, Scalar
from fragile.core.walkers import Walkers


class Swarm(BaseSwarm):
    """
    The Swarm is in charge of performing a fractal evolution process.

    It contains the necessary logic to use an Environment, a Model, and a \
    Walkers instance to run the Swarm evolution algorithm.
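
    Example (illustrative sketch; ``make_env`` and ``make_model`` stand for \
    user-provided callables that return an :class:`Environment` and a \
    :class:`Model`, they are not defined in this module)::

        # make_env and make_model are hypothetical user-defined callables
        swarm = Swarm(n_walkers=32, env=make_env, model=make_model)
        swarm.run()  # run the fractal evolution process
        print(swarm.best_reward, swarm.best_state)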
    """

    _log = logging.getLogger("Swarm")

    def __init__(
        self,
        n_walkers: int,
        env: Callable[[], BaseEnvironment],
        model: Callable[[BaseEnvironment], BaseModel],
        walkers: Callable[..., Walkers] = Walkers,
        reward_scale: float = 1.0,
        distance_scale: float = 1.0,
        tree: Callable[[], HistoryTree] = None,
        report_interval: int = numpy.inf,
        show_pbar: bool = True,
        use_notebook_widget: bool = True,
        force_logging: bool = False,
        *args,
        **kwargs
    ):
        """
        Initialize a :class:`Swarm`.

        Args:
            n_walkers: Number of walkers of the swarm.
            env: A callable that returns an instance of an :class:`Environment`.
            model: A callable that returns an instance of a :class:`Model`.
            walkers: A callable that returns an instance of :class:`Walkers`.
            reward_scale: Virtual reward exponent for the reward score.
            distance_scale: Virtual reward exponent for the distance score.
            tree: A callable that returns a :class:`HistoryTree` to keep track \
                  of the visited states.
            report_interval: Display the algorithm progress every ``report_interval`` epochs.
            show_pbar: If ``True``, a progress bar will display the progress of \
                       the algorithm run.
            use_notebook_widget: If ``True`` and the class is running in an IPython \
                                 kernel, the evolution of the swarm will be displayed \
                                 in a widget.
            force_logging: If ``True``, disable all ``ipython``-related behaviour.
            *args: Additional args passed to init_swarm.
            **kwargs: Additional kwargs passed to init_swarm.
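
        Example (sketch; ``make_env`` and ``make_model`` are hypothetical \
        user-provided callables, and it is assumed that :class:`HistoryTree` \
        can be instantiated with its default arguments)::

            swarm = Swarm(
                n_walkers=64,
                env=make_env,
                model=make_model,
                tree=HistoryTree,  # callable used to build the tree of visited states
                report_interval=50,  # log progress every 50 epochs
            )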

        """
        self._prune_tree = False
        self._epoch = 0
        self.show_pbar = show_pbar
        self.report_interval = report_interval
        super(Swarm, self).__init__(
            walkers=walkers,
            env=env,
            model=model,
            n_walkers=n_walkers,
            reward_scale=reward_scale,
            distance_scale=distance_scale,
            tree=tree,
            *args,
            **kwargs
        )
        self._notebook_container = None
        self._use_notebook_widget = use_notebook_widget
        self._ipython_mode = running_in_ipython() and not force_logging
        self.setup_notebook_container()

    def __len__(self) -> int:
        return self.walkers.n

    def __repr__(self) -> str:
        walkers_data = self.walkers.__repr__()
        tree_data = self.tree.__repr__() if self.tree is not None else ""
        return walkers_data + tree_data

    @property
    def env(self) -> BaseEnvironment:
        """All the simulation code (problem specific) will be handled here."""
        return self._env

    @property
    def model(self) -> BaseModel:
        """
        All the policy and random perturbation code (problem specific) will \
        be handled here.
        """
        return self._model

    @property
    def walkers(self) -> Walkers:
        """
        Access the :class:`Walkers` in charge of implementing the FAI \
        evolution process.
        """
        return self._walkers

    @property
    def best_time(self) -> numpy.ndarray:
        """Return the state of the best walker found in the current algorithm run."""
        return self.walkers.best_time

    @property
    def best_state(self) -> numpy.ndarray:
        """Return the state of the best walker found in the current algorithm run."""
        return self.walkers.best_state

    @property
    def best_reward(self) -> Scalar:
        """Return the reward of the best walker found in the current algorithm run."""
        return self.walkers.best_reward

    @property
    def best_id(self) -> int:
        """
        Return the id (hash value of the state) of the best walker found in the \
        current algorithm run.
        """
        return self.walkers.best_id

    @property
    def best_obs(self) -> numpy.ndarray:
        """
        Return the observation corresponding to the best walker found in the \
        current algorithm run.
        """
        return self.walkers.best_obs

    @property
    def critic(self) -> BaseCritic:
        """Return the :class:`Critic` of the walkers."""
        return self._walkers.critic

    def get(self, name: str, default: Any = None) -> Any:
        """Access attributes of the :class:`Swarm` and its children."""
        if hasattr(self.walkers.states, name):
            return getattr(self.walkers.states, name)
        elif hasattr(self.walkers.env_states, name):
            return getattr(self.walkers.env_states, name)
        elif hasattr(self.walkers.model_states, name):
            return getattr(self.walkers.model_states, name)
        elif hasattr(self.walkers, name):
            return getattr(self.walkers, name)
        elif hasattr(self, name):
            return getattr(self, name)
        return default

    def init_swarm(
        self,
        env_callable: Callable[[], BaseEnvironment],
        model_callable: Callable[[BaseEnvironment], BaseModel],
        walkers_callable: Callable[..., Walkers],
        n_walkers: int,
        reward_scale: float = 1.0,
        distance_scale: float = 1.0,
        tree: Callable[[], HistoryTree] = None,
        prune_tree: bool = True,
        *args,
        **kwargs
    ):
        """
        Initialize and set up all the necessary internal variables to run the swarm.

        This process involves instantiating the Swarm, the Environment and the \
        model.

        Args:
            env_callable: A callable that returns an instance of an
                :class:`fragile.Environment`.
            model_callable: A callable that returns an instance of a
                :class:`fragile.Model`.
            walkers_callable: A callable that returns an instance of
                :class:`fragile.Walkers`.
            n_walkers: Number of walkers of the swarm.
            reward_scale: Virtual reward exponent for the reward score.
            distance_scale: Virtual reward exponent for the distance score.
            tree: A callable that returns a :class:`HistoryTree` to keep track \
                  of the visited states.
            prune_tree: Has no effect if ``tree`` is ``None``. If ``True``, \
                        store in the :class:`Tree` only the past history of the \
                        alive walkers, and discard the branches whose leaves \
                        have no walkers.
            args: Passed to ``walkers_callable``.
            kwargs: Passed to ``walkers_callable``.

        Returns:
            None.

        """
        self._env: BaseEnvironment = env_callable()
        self._model: BaseModel = model_callable(self._env)

        model_params = self._model.get_params_dict()
        env_params = self._env.get_params_dict()
        self._walkers: Walkers = walkers_callable(
            env_state_params=env_params,
            model_state_params=model_params,
            n_walkers=n_walkers,
            reward_scale=reward_scale,
            distance_scale=distance_scale,
            *args,
            **kwargs
        )
        self.tree: HistoryTree = tree() if tree is not None else None
        self._prune_tree = prune_tree
        self._epoch = 0

    def reset(
        self,
        root_walker: OneWalker = None,
        walkers_states: StatesWalkers = None,
        model_states: StatesModel = None,
        env_states: StatesEnv = None,
    ):
        """
        Reset the :class:`fragile.Walkers`, the :class:`Environment`, the \
        :class:`Model` and clear the internal data to start a new search process.

        Args:
            root_walker: Walker representing the initial state of the search. \
                         The walkers will be reset to this walker, and it will \
                         be added to the root of the :class:`StateTree` if any.
            model_states: :class:`StatesModel` that define the initial state of \
                          the :class:`Model`.
            env_states: :class:`StatesEnv` that define the initial state of \
                        the :class:`Environment`.
            walkers_states: :class:`StatesWalkers` that define the internal \
                            states of the :class:`Walkers`.

        """
        self._epoch = 0
        env_states = (
            self.env.reset(batch_size=self.walkers.n) if env_states is None else env_states
        )
        # Add the data of the root walker to the env_states
        if root_walker is not None:
            if not isinstance(root_walker, OneWalker):
                raise ValueError(
                    "Root walker needs to be an "
                    "instance of OneWalker, got %s instead." % type(root_walker)
                )
            env_states = self._update_env_with_root(root_walker=root_walker, env_states=env_states)

        model_states = (
            self.model.reset(batch_size=len(self.walkers), env_states=env_states)
            if model_states is None
            else model_states
        )
        model_states.update(init_actions=model_states.actions)
        self.walkers.reset(env_states=env_states, model_states=model_states)
        if self.tree is not None:
            root_id = (
                self.walkers.get("id_walkers")[0]
                if root_walker is None
                else copy.copy(root_walker.id_walkers)
            )
            self.tree.reset(
                root_id=root_id,
                env_states=self.walkers.env_states,
                model_states=self.walkers.model_states,
                walkers_states=self.walkers.states,
            )

    def run(
        self,
        root_walker: OneWalker = None,
        model_states: StatesModel = None,
        env_states: StatesEnv = None,
        walkers_states: StatesWalkers = None,
        report_interval: int = None,
        show_pbar: bool = None,
    ):
        """
        Run a new search process.

        Args:
            root_walker: Walker representing the initial state of the search. \
                         The walkers will be reset to this walker, and it will \
                         be added to the root of the :class:`StateTree` if any.
            model_states: :class:`StatesModel` that define the initial state of \
                          the :class:`Model`.
            env_states: :class:`StatesEnv` that define the initial state of \
                        the :class:`Environment`.
            walkers_states: :class:`StatesWalkers` that define the internal \
                            states of the :class:`Walkers`.
            report_interval: Display the algorithm progress every ``report_interval`` epochs.
            show_pbar: If ``True``, a progress bar will display the progress of the \
                       algorithm run.

        Returns:
            None.
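
        Example (sketch; ``root`` stands for a hypothetical :class:`OneWalker` \
        describing the state where the search should start)::

            # root is a hypothetical OneWalker instance, not defined here
            swarm.run(root_walker=root, report_interval=25, show_pbar=False)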

        """
        report_interval = self.report_interval if report_interval is None else report_interval
        self.reset(
            root_walker=root_walker,
            model_states=model_states,
            env_states=env_states,
            walkers_states=walkers_states,
        )

        for _ in self.get_run_loop(show_pbar=show_pbar):
            if self.calculate_end_condition():
                break
            try:
                self.run_step()
                if self.epoch % report_interval == 0 and self.epoch > 0:
                    self.report_progress()
                self.increment_epoch()
            except (KeyboardInterrupt, Exception) as e:
                if not isinstance(e, KeyboardInterrupt):
                    tb = traceback.format_exc()
                    name = e.__class__.__name__
                    self._log.warning(
                        "Stopped due to unhandled exception: %s\n %s\n %s" % (name, e, tb)
                    )
                break

    def get_run_loop(self, show_pbar: bool = None) -> Iterable[int]:
        """
        Return a tqdm progress bar or a regular range iterator.

        If the code is running in an IPython kernel it will also display the \
        internal ``_notebook_container``.

        Args:
            show_pbar: If ``False`` the progress bar will not be displayed.

        Returns:
            A tqdm progress bar over ``self.max_epochs`` iterations if ``show_pbar`` \
            is ``True`` and either the code is running in an IPython kernel or the \
            logging level is set to "INFO" or lower. Otherwise, return the same \
            iterator with the progress bar disabled, so it behaves like a plain \
            range iterator.

        """
        show_pbar = show_pbar if show_pbar is not None else self.show_pbar
        no_tqdm = not (
            show_pbar if self._ipython_mode else self._log.level < logging.WARNING and show_pbar
        )
        if self._ipython_mode:
            from tqdm.notebook import trange
        else:
            from tqdm import trange

        loop_iterable = trange(
            self.max_epochs, desc="%s" % self.__class__.__name__, disable=no_tqdm
        )

        if self._ipython_mode and self._use_notebook_widget:
            from IPython.core.display import display

            display(self._notebook_container)
        return loop_iterable

    def setup_notebook_container(self):
        """Display the display widgets if the Swarm is running in an IPython kernel."""
        if self._ipython_mode and self._use_notebook_widget:
            from ipywidgets import HTML
            from IPython.core.display import display, HTML as cell_html

            # Set font weight of tqdm progressbar
            display(cell_html("<style> .widget-label {font-weight: bold !important;} </style>"))
            self._notebook_container = HTML()

    def report_progress(self):
        """Report information of the current run."""
        if self._ipython_mode and self._use_notebook_widget:
            line_break = '<br style="line-height:1px; content: "  ";>'
            html = str(self).replace("\n\n", "\n").replace("\n", line_break)
            # Add strong formatting for headers
            html = html.replace("Walkers States", "<strong>Walkers States</strong>")
            html = html.replace("Model States", "<strong>Model States</strong>")
            html = html.replace("Environment States", "<strong>Environment Model</strong>")
            if self.tree is not None:
                tree_name = self.tree.__class__.__name__
                html = html.replace(tree_name, "<strong>%s</strong>" % tree_name)
            self._notebook_container.value = "%s" % html
        elif not self._ipython_mode:
            self._log.info(repr(self))

    def calculate_end_condition(self) -> bool:
        """Implement the logic for deciding if the algorithm has finished. \
        The algorithm will stop if it returns True."""
        return self.walkers.calculate_end_condition()

    def step_and_update_best(self) -> None:
        """
        Make the positions of the walkers evolve and keep track of the new states found.

        It also updates the best state visited so far.
        """
        self.walkers.update_best()
        self.walkers.fix_best()
        self.step_walkers()

    def balance_and_prune(self) -> None:
        """
        Calculate the virtual reward and perform the cloning process.

        It also prunes the :class:`Tree` data structure that takes care of \
        storing the visited states.
        """
        self.walkers.balance()
        self.prune_tree()

    def run_step(self) -> None:
        """
        Compute one iteration of the :class:`Swarm` evolution process and \
        update all the data structures.
        """
        self.step_and_update_best()
        self.balance_and_prune()
        self.walkers.fix_best()

    def step_walkers(self) -> None:
        """
        Make the walkers evolve to their next state by sampling an action from \
        the :class:`Model` and applying it to the :class:`Environment`.
        """
        model_states = self.walkers.model_states
        env_states = self.walkers.env_states

        parent_ids = (
            copy.deepcopy(self.walkers.states.id_walkers) if self.tree is not None else None
        )

        model_states = self.model.predict(
            env_states=env_states, model_states=model_states, walkers_states=self.walkers.states
        )
        env_states = self.env.step(model_states=model_states, env_states=env_states)
        self.walkers.update_states(
            env_states=env_states, model_states=model_states,
        )
        self.update_tree(parent_ids)

    def update_tree(self, parent_ids: List[int]) -> None:
        """
        Add the list of walker states to the :class:`Tree`, linking them to the \
        parents given by ``parent_ids``.

        Args:
            parent_ids: list containing the ids of the parents of the new states added.
        """
        if self.tree is not None:
            self.tree.add_states(
                parent_ids=parent_ids,
                env_states=self.walkers.env_states,
                model_states=self.walkers.model_states,
                walkers_states=self.walkers.states,
                n_iter=int(self.walkers.epoch),
            )

    def prune_tree(self) -> None:
        """
        Remove all the branches that do not have alive walkers at their leaf nodes.
        """
        if self.tree is not None:
            leaf_nodes = set(self.get("id_walkers"))
            self.tree.prune_tree(alive_leafs=leaf_nodes)

    def _update_env_with_root(self, root_walker, env_states) -> StatesEnv:
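        """Overwrite every row of ``env_states`` with the states, observations \
        and rewards of ``root_walker``."""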
        env_states.rewards[:] = copy.deepcopy(root_walker.rewards[0])
        env_states.observs[:] = copy.deepcopy(root_walker.observs[0])
        env_states.states[:] = copy.deepcopy(root_walker.states[0])
        return env_states


class NoBalance(Swarm):
    """Swarm that does not perform the cloning process."""

    def balance_and_prune(self):
        """Do noting."""
        pass

    def calculate_end_condition(self):
        """Finish after reaching the maximum number of epochs."""
        return self.epoch > self.walkers.max_epochs