from abc import ABC, abstractmethod
from functools import partial

import autograd.numpy as np
from autograd.extend import defvjp, primitive

from .frame import Frame
from .parameter import Parameter
from . import fft
from . import interpolation
from .bbox import Box, overlapped_slices


class Component(ABC):
    """A single component in a blend.

    This class acts as base for building a complex :class:`scarlet.blend.Blend`.

    Parameters
    ----------
    model_frame: `~scarlet.Frame`
        The spectral and spatial characteristics of the model.
    bbox: `~scarlet.Box`
        Hyper-spectral bounding box of this component
    parameters: list of `~scarlet.Parameter`
        Optimization parameters of this component. Every element must be a
        `~scarlet.Parameter` and contain only finite values.
    kwargs: dict
        Auxiliary information attached to this component.
    """

    def __init__(self, model_frame, bbox, *parameters, **kwargs):
        self.bbox = bbox
        self.set_model_frame(model_frame)

        # `*parameters` always arrives as a tuple, so every element is checked
        # and the tuple is stored as-is.
        for p in parameters:
            assert isinstance(p, Parameter)
        self._parameters = parameters
        self.check_parameters()

        # additional non-optimization parameters of component
        self.kwargs = kwargs

        # Properties used for indexing in the ComponentTree
        self._index = None
        self._parent = None

    @property
    def shape(self):
        """Shape of the image (Channel, Height, Width)
        """
        return self.bbox.shape

    @property
    def coord(self):
        """The coordinate in a `~scarlet.component.ComponentTree`.

        Returns `None` when this component has not been placed in a tree.
        """
        if self._index is None:
            return None
        if self._parent._index is not None:
            return tuple(self._parent.coord) + (self._index,)
        return (self._index,)

    @property
    def parameters(self):
        """The list of non-fixed parameters

        Returns
        -------
        list of parameters available for optimization
        If `parameter.fixed == True`, the parameter will not returned here.
        """
        return [p for p in self._parameters if not p.fixed]

    @abstractmethod
    def get_model(self, *parameters, frame=None):
        """Get the model for this component

        Parameters
        ----------
        parameters: tuple of optimimzation parameters

        frame: `~scarlet.frame.Frame`
            Frame to project the model into. If `frame` is `None`
            then the model contained in `bbox` is returned.

        Returns
        -------
        model: array
            (Channels, Height, Width) image of the model
        """
        pass

    def set_model_frame(self, model_frame):
        """Sets the frame for this component.

        Each component needs to know the properties of the Frame and,
        potentially, the subvolume it covers. The overlap slices between
        the frame and `self.bbox` are cached for later projections.

        Parameters
        ----------
        model_frame: `~scarlet.Frame`
            Frame of the model
        """
        self.model_frame = model_frame
        self.model_frame_slices, self.model_slices = overlapped_slices(model_frame, self.bbox)

    def check_parameters(self):
        """Check that all parameters have finite elements

        Raises
        ------
        `ArithmeticError` when non-finite elements are present
        """
        for k, p in enumerate(self._parameters):
            if not np.isfinite(p).all():
                msg = "Component {} Parameter {} is not finite:\n{}".format(self, k, p)
                raise ArithmeticError(msg)

    def model_to_frame(self, frame=None, model=None):
        """Project a model into a frame

        Parameters
        ----------
        model: array
            Image of the model to project.
            This must be the same shape as `self.bbox`.
            If `model` is `None` then `self.get_model()` is used.
        frame: `~scarlet.frame.Frame`
            The frame to project the model into.
            If `frame` is `None` then the model is projected
            into `self.model_frame`.

        Returns
        -------
        projected_model: array
            (Channels, Height, Width) image of the model
        """
        # Use the current model by default
        if model is None:
            model = self.get_model()
        # Use the full model frame by default; its slices are already cached
        if frame is None or frame == self.model_frame:
            frame = self.model_frame
            frame_slices = self.model_frame_slices
            model_slices = self.model_slices
        else:
            frame_slices, model_slices = overlapped_slices(frame, self.bbox)

        dtype = frame.dtype if hasattr(frame, "dtype") else model.dtype
        result = np.zeros(frame.shape, dtype=dtype)
        # only the overlapping region is copied; the rest stays zero
        result[frame_slices] = model[model_slices]
        return result


class FactorizedComponent(Component):
    """A single component in a blend.

    The model is the outer product of an SED (one value per channel) and a
    2D morphology, optionally translated by a sub-pixel shift applied in
    Fourier space.

    Parameters
    ----------
    model_frame: `~scarlet.Frame`
        The spectral and spatial characteristics of the full model.
    bbox: `~scarlet.Box`
        Hyper-spectral bounding box of this component.
    sed: `~scarlet.Parameter`
        1D array (channels) of the initial SED.
    morph: `~scarlet.Parameter`
        Image (Height, Width) of the initial morphology.
    shift: `~scarlet.Parameter`
        2D position for the shift of the center
    """

    def __init__(self, model_frame, bbox, sed, morph, shift=None, **kwargs):
        params = (sed, morph) if shift is None else (sed, morph, shift)
        super().__init__(model_frame, bbox, *params, **kwargs)

        # The Fourier-space shift operators only depend on the padded
        # morphology shape, so they are built once here.
        if shift is not None:
            self.fft_shape = fft._get_fft_shape(morph, morph, padding=10)
            self.shifter_y, self.shifter_x = interpolation.mk_shifter(self.fft_shape)

    @property
    def sed(self):
        """Numpy view of the component SED
        """
        return self._parameters[0]._data

    @property
    def morph(self):
        """Numpy view of the component morphology (with the shift applied)
        """
        return self._shift_morph(self.shift, self._parameters[1]._data)

    @property
    def shift(self):
        """Numpy view of the component shift, or `None` without a shift
        """
        return self._parameters[2]._data if len(self._parameters) == 3 else None

    def get_model(self, *parameters, frame=None):
        """Get the model for this component.

        Parameters
        ----------
        parameters: tuple of optimimzation parameters
            Values traced by autograd; each one is matched to this
            component's parameters via its wrapped `_value`. Parameters not
            supplied fall back to their current stored values.
        frame: `~scarlet.frame.Frame`
            Frame to project the model into. If `frame` is `None`
            then the model contained in `bbox` is returned.

        Returns
        -------
        model: array
            (Channels, Height, Width) image of the model
        """
        has_shift = len(self._parameters) == 3
        sed = morph = shift = None

        # traced parameters are autograd ArrayBoxes; the wrapped Parameter
        # is reachable through `_value`
        for candidate in parameters:
            wrapped = candidate._value
            if wrapped is self._parameters[0]:
                sed = candidate
            if wrapped is self._parameters[1]:
                morph = candidate
            if has_shift and wrapped is self._parameters[2]:
                shift = candidate

        sed = self.sed if sed is None else sed
        morph = self._parameters[1]._data if morph is None else morph
        shift = self.shift if shift is None else shift

        shifted = self._shift_morph(shift, morph)
        model = sed[:, None, None] * shifted[None, :, :]
        return model if frame is None else self.model_to_frame(frame, model)

    def _shift_morph(self, shift, morph):
        # Without a shift the morphology is used unchanged.
        if shift is None:
            return morph

        transform = fft.Fourier(morph)
        transform_fft = transform.fft(self.fft_shape, (0, 1))

        # A translation is a phase ramp in Fourier space, one per axis.
        phase_y = np.exp(self.shifter_y[:, None] * shift[0])
        phase_x = np.exp(self.shifter_x[None, :] * shift[1])
        shifted_fft = transform_fft * phase_y * phase_x

        shifted = fft.Fourier.from_fft(shifted_fft, self.fft_shape, transform.shape, [0, 1])
        return np.real(shifted.image)


class FunctionComponent(FactorizedComponent):
    """A single component in a blend.

    The model is an SED times a morphology, where the morphology is
    produced by calling a user-supplied function on a parameter vector.

    Parameters
    ----------
    model_frame: `~scarlet.Frame`
        The spectral and spatial characteristics of the full model.
    bbox: `~scarlet.Box`
        Hyper-spectral bounding box of this component.
    sed: `~scarlet.Parameter`
        1D array (channels) of the initial SED.
    fparams: `~scarlet.Parameter`
        Parameters of the initial morphology.
    func: `autograd` function
        Signature: func(*fparams, y=None, x=None) -> Image (Height, Width)
    """

    def __init__(self, model_frame, bbox, sed, fparams, func):
        # `fparams` occupies the morphology slot of the parent class;
        # `func` is kept in `self.kwargs`.
        super().__init__(model_frame, bbox, sed, fparams, func=func)

    @property
    def morph(self):
        """Numpy view of the component morphology

        Computed lazily from the stored function parameters and cached.
        `get_model` refreshes the cache whenever it receives traced fparams.
        """
        if not hasattr(self, "_morph"):
            self._morph = self._func(*self._parameters[1])
        return self._morph

    def _func(self, *parameters):
        # Evaluate the user-supplied morphology function.
        return self.kwargs["func"](*parameters)

    def get_model(self, *parameters, frame=None):
        """Get the model for this component.

        Parameters
        ----------
        parameters: tuple of optimimzation parameters
            Autograd-traced values matched to this component's SED and
            function parameters through their wrapped `_value`.
        frame: `~scarlet.frame.Frame`
            Frame to project the model into. If `frame` is `None`
            then the model contained in `bbox` is returned.

        Returns
        -------
        model: array
            (Channels, Height, Width) image of the model
        """
        matched = [None, None]
        for candidate in parameters:
            for slot in range(2):
                if candidate._value is self._parameters[slot]:
                    matched[slot] = candidate
        sed, fparams = matched

        if sed is None:
            sed = self.sed

        if fparams is not None:
            morph = self._func(*fparams)
            # cache the plain (unboxed) morphology for the `morph` property
            self._morph = morph._value
        else:
            morph = self.morph

        model = sed[:, None, None] * morph[None, :, :]
        return model if frame is None else self.model_to_frame(frame, model)


class CubeComponent(Component):
    """A single component in a blend.

    Uses full cube parameterization: every voxel of the component is a
    free parameter.

    Parameters
    ----------
    frame: `~scarlet.Frame`
        The spectral and spatial characteristics of the model.
    bbox: `~scarlet.Box`
        Hyper-spectral bounding box of this component.
    cube: `~scarlet.Parameter`
        3D array (C, Height, Width) of the initial data cube.
    """

    def __init__(self, frame, bbox, cube):
        parameters = (cube,)
        super().__init__(frame, bbox, *parameters)

    @property
    def cube(self):
        """Numpy view of the component data cube
        """
        return self._parameters[0]._data

    def get_model(self, *parameters, frame=None):
        """Get the model for this component.

        Parameters
        ----------
        parameters: tuple of optimimzation parameters
            Autograd-traced values; the one whose wrapped `_value` is this
            component's cube parameter replaces the stored cube.
        frame: `~scarlet.frame.Frame`
            Frame to project the model into. If `frame` is `None`
            then the model contained in `bbox` is returned.

        Returns
        -------
        model: array
            (Channels, Height, Width) image of the model
        """
        cube = None
        # traced parameters are autograd ArrayBoxes wrapping the Parameter
        for p in parameters:
            if p._value is self._parameters[0]:
                cube = p

        if cube is None:
            cube = self.cube
        if frame is not None:
            cube = self.model_to_frame(frame, cube)
        return cube


@primitive
def _add_models(*models, full_model, slices):
    """Insert the models into the full model

    `slices` holds one `(full_model_slice, model_slices)` pair per model;
    the first selects the destination region in `full_model`, the second
    the matching region of the model. Overlapping contributions are summed
    in place into `full_model`, which is also returned.
    """
    for i, model in enumerate(models):
        destination, source = slices[i][0], slices[i][1]
        cutout = model[source]
        # still-boxed autograd values are unwrapped before the in-place add
        if hasattr(model, "_value"):
            cutout = cutout._value
        full_model[destination] += cutout
    return full_model


def _grad_add_models(upstream_grad, *models, full_model, slices, index):
    """Gradient for a single model

    The full model is just the sum of the models,
    so the gradient is 1 for each model,
    we just have to slice it appropriately.
    """
    model = models[index]
    full_model_slices = slices[index][0]
    model_slices = slices[index][1]

    def result(upstream_grad):
        _result = np.zeros(model.shape, dtype=model.dtype)
        _result[model_slices] = upstream_grad[full_model_slices]
        return _result
    return result


class ComponentTree:
    """Base class for hierarchical collections of Components.

    Each node of the tree is either a `~scarlet.component.Component` or
    another `ComponentTree`. Every direct child stores its position in
    `_index` and a reference to this tree in `_parent`; `coord` and
    `__getitem__` rely on these links being consistent with `_tree`.
    """

    def __init__(self, components):
        """Constructor

        Group a list of `~scarlet.component.Component`s in a hierarchy.

        Parameters
        ----------
        components: list of `~scarlet.component.Component` or `~scarlet.component.ComponentTree`

        Raises
        ------
        NotImplementedError
            If an element is neither a Component nor a ComponentTree.
        AssertionError
            If the elements do not all share the same model frame.
        """
        if not hasattr(components, "__iter__"):
            components = (components,)

        # check type and set coords of subordinate nodes in tree
        self._tree = tuple(components)
        self._index = None
        self._parent = None
        for i, c in enumerate(self._tree):
            if not isinstance(c, (ComponentTree, Component)):
                raise NotImplementedError(
                    "argument needs to be list of Components or ComponentTrees"
                )
            assert c.model_frame is self.model_frame, "All components need to share the same model Frame"
            c._index = i
            c._parent = self

    def _link_nodes(self, nodes, start=0):
        """Point `_index`/`_parent` of `nodes` at their slots in this tree."""
        for i, node in enumerate(nodes, start=start):
            node._index = i
            node._parent = self

    @property
    def bbox(self):
        """Union of all the component `~scarlet.bbox.Box`es
        """
        try:
            return self._bbox
        except AttributeError:
            # Make the bbox of the tree the union of the component bboxes
            box = self.components[0].bbox
            self._bbox = Box(box.shape, box.origin)
            for component in self.components:
                self._bbox |= component.bbox
        return self._bbox

    @property
    def shape(self):
        """Shape of this model
        """
        # go through the property so the union bbox is built on first use
        return self.bbox.shape

    @property
    def components(self):
        """Flattened tuple of all components in the tree.

        CAUTION: Each component in a tree can only be a leaf of a single node.
        While one can construct trees that hold the same component multiple
        times, this method will only return that component at its first
        encountered location
        """
        try:
            return self._components
        except AttributeError:
            self._components = self._tree_to_components()
            return self._components

    def _tree_to_components(self):
        """Depth-first flattening of the tree, keeping first occurrences."""
        components = []
        for c in self._tree:
            if isinstance(c, ComponentTree):
                _c = c.components
            else:
                _c = [c]
            # check uniqueness
            for __c in _c:
                if __c not in components:
                    components.append(__c)
        return tuple(components)

    @property
    def n_components(self):
        """Number of components.
        """
        return len(self.components)

    @property
    def K(self):
        """Number of components.
        """
        return self.n_components

    @property
    def model_frame(self):
        """Frame of the components.
        """
        return self._tree[0].model_frame

    @property
    def sources(self):
        """Initial list of components or sources that generate the tree.

        This can be different than `self.components` because sources can
        have multiple components.

        Returns
        -------
        The arguments of `__init__`
        """
        return self._tree

    @property
    def n_sources(self):
        """Number of initial sources or components.

        This can be different than `self.n_components` because sources can
        have multiple components.

        Returns
        -------
        int: number of initial sources
        """
        return len(self._tree)

    @property
    def coord(self):
        """The coordinate in tree.

        The coordinate can be used to traverse the tree and for `__getitem__`.
        Returns `None` for a tree that is not a child of another tree.
        """
        if self._index is None:
            return None
        if self._parent._index is not None:
            return tuple(self._parent.coord) + (self._index,)
        return (self._index,)

    @property
    def parameters(self):
        """The list of non-fixed parameters

        Returns
        -------
        list of parameters available for optimization
        If `parameter.fixed == True`, the parameter will not returned here.
        """
        return [p for c in self.components for p in c.parameters]

    def check_parameters(self):
        """Check that all parameters have finite elements

        Raises
        ------
        `ArithmeticError` when non-finite elements are present
        """
        for c in self.components:
            c.check_parameters()

    def get_model(self, *params, frame=None):
        """Get the model of this component tree

        Parameters
        ----------
        params: tuple of optimization parameters
            Consumed in order: each component takes as many entries as it
            has non-fixed parameters.
        frame: `~scarlet.frame.Frame`
            Frame to insert the model into. If `frame` is `None`, a frame
            covering `self.bbox` (with the model frame's dtype, PSF and
            channels) is used.

        Returns
        -------
        model: array
            (Bands, Height, Width) data cube
        """
        if frame is None:
            frame = Frame(
                self.bbox, dtype=self.model_frame.dtype,
                psfs=self.model_frame.psf, channels=self.model_frame.channels
            )
        # If this is the model frame then the slices are already cached
        use_cached = frame == self.model_frame

        full_model = np.zeros(frame.shape, dtype=frame.dtype)

        models = []
        slices = []
        i = 0

        for c in self.components:
            if len(params):
                j = len(c.parameters)
                p = params[i: i + j]
                i += j
                model = c.get_model(*p)
            else:
                model = c.get_model()

            models.append(model)

            if use_cached:
                slices.append((c.model_frame_slices, c.model_slices))
            else:
                # Get the slices needed to insert the model
                slices.append(overlapped_slices(frame, c.bbox))

        # We have to declare the function that inserts sources
        # into the blend with autograd.
        # This has to be done each time we fit a blend,
        # since the number of components => the number of arguments,
        # which must be linked to the autograd primitive function.
        defvjp(_add_models, *([partial(_grad_add_models, index=k) for k in range(len(self.components))]))

        full_model = _add_models(*models, full_model=full_model, slices=slices)

        return full_model

    def set_model_frame(self, model_frame):
        """Set the frame for all components in the tree

        see `~scarlet.Component.set_model_frame` for details.

        Parameters
        ----------
        model_frame: `~scarlet.Frame`
            Frame to adopt for every component
        """
        for c in self.components:
            c.set_model_frame(model_frame)

    def __iadd__(self, c):
        """Add another component or tree.

        A `Component` is appended as a new node. The nodes of a
        `ComponentTree` are appended individually (its top level is merged
        into this tree). Every appended node is re-linked to this tree, and
        the cached union bbox is discarded.

        Parameters
        ----------
        c: `~scarlet.component.Component` or `~scarlet.component.ComponentTree`
        """
        c_index = self.n_sources
        if isinstance(c, ComponentTree):
            added = tuple(c._tree)
        elif isinstance(c, Component):
            added = (c,)
        else:
            raise NotImplementedError("argument needs to be Component or ComponentTree")
        self._tree = self._tree + added
        # keep `coord` consistent with `__getitem__` for the new nodes
        self._link_nodes(added, start=c_index)
        # the union bbox depends on the node set; rebuild it lazily
        try:
            del self._bbox
        except AttributeError:
            pass
        self._components = self._tree_to_components()
        return self

    def __getitem__(self, coord):
        """Access node in the tree.

        Parameters
        ----------
        coords: int or tuple of ints

        Returns
        -------
        `~scarlet.component.Component` or `~scarlet.component.ComponentTree`
        """
        if isinstance(coord, (tuple, list)):
            if len(coord) > 1:
                return self._tree[coord[0]].__getitem__(coord[1:])
            else:
                return self._tree[coord[0]]
        elif isinstance(coord, int):
            return self._tree[coord]
        else:
            raise NotImplementedError("coord needs to be index or list of indices")

    def __getstate__(self):
        # needed for pickling to understand what to save
        return (self._tree,)

    def __setstate__(self, state):
        self._tree = state[0]
        # Only `_tree` is pickled: restore this node's own links (a parent
        # tree re-links it when its own state is restored) and the links of
        # the children, then rebuild the flattened component cache.
        self._index = None
        self._parent = None
        self._link_nodes(self._tree)
        self._components = self._tree_to_components()

    def model_to_frame(self, frame=None, model=None):
        """Project a model into a frame

        Parameters
        ----------
        model: array
            Image of the model to project.
            This must be the same shape as `self.bbox`.
            If `model` is `None` then `self.get_model()` is used.
        frame: `~scarlet.frame.Frame`
            The frame to project the model into.
            If `frame` is `None` then the model is projected
            into `self.model_frame`.

        Returns
        -------
        projected_model: array
            (Channels, Height, Width) image of the model
        """
        # Use the current model by default
        if model is None:
            model = self.get_model()
        # Use the full model frame by default
        if frame is None:
            frame = self.model_frame

        if hasattr(frame, "dtype"):
            dtype = frame.dtype
        else:
            dtype = self.model_frame.dtype

        frame_slices, model_slices = overlapped_slices(frame, self.bbox)

        result = np.zeros(frame.shape, dtype=dtype)
        result[frame_slices] = model[model_slices]
        return result