"""`ElectrodeArray`, `ElectrodeGrid`"""
from collections import OrderedDict
from itertools import count, islice, product
from string import ascii_uppercase

import matplotlib.pyplot as plt
import numpy as np

from .electrodes import Electrode, PointSource, DiskElectrode
from ..utils import PrettyPrint


class ElectrodeArray(PrettyPrint):
    """Electrode array

    A collection of :py:class:`~pulse2percept.implants.Electrode` objects.

    Parameters
    ----------
    electrodes : array-like
        Either a single :py:class:`~pulse2percept.implants.Electrode` object
        or a dict, list, tuple, or NumPy array thereof. The keys of the dict
        will serve as electrode names. Otherwise electrodes will be indexed
        0..N-1 in the order in which they are given (a NumPy array is first
        flattened in row-major order).

        .. note::

            If you pass multiple electrodes in a dictionary, they are added
            in the dictionary's iteration order (insertion order for a plain
            ``dict`` on Python 3.7+); the keys are not sorted.

    Examples
    --------
    Electrode array made from a single DiskElectrode:

    >>> from pulse2percept.implants import ElectrodeArray, DiskElectrode
    >>> earray = ElectrodeArray(DiskElectrode(0, 0, 0, 100))
    >>> earray.electrodes  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    OrderedDict([(0, DiskElectrode(r=100..., x=0..., y=0..., z=0...))])

    Electrode array made from a single DiskElectrode with name 'A1':

    >>> from pulse2percept.implants import ElectrodeArray, DiskElectrode
    >>> earray = ElectrodeArray({'A1': DiskElectrode(0, 0, 0, 100)})
    >>> earray.electrodes  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    OrderedDict([('A1', DiskElectrode(r=100..., x=0..., y=0..., z=0...))])

    """
    # Frozen class: User cannot add more class attributes
    __slots__ = ('_electrodes',)

    def __init__(self, electrodes):
        self.electrodes = OrderedDict()
        if isinstance(electrodes, dict):
            for name, electrode in electrodes.items():
                self.add_electrode(name, electrode)
        elif isinstance(electrodes, (list, tuple, np.ndarray)):
            if isinstance(electrodes, np.ndarray):
                # Flatten so that an N-D array behaves like a flat list:
                electrodes = electrodes.ravel()
            for electrode in electrodes:
                # Unnamed electrodes are keyed by their running index 0..N-1:
                self.add_electrode(self.n_electrodes, electrode)
        elif isinstance(electrodes, Electrode):
            self.add_electrode(self.n_electrodes, electrodes)
        else:
            raise TypeError(("electrodes must be a list or dict, not "
                             "%s") % type(electrodes))

    @property
    def electrodes(self):
        """OrderedDict mapping electrode name -> Electrode object"""
        return self._electrodes

    @electrodes.setter
    def electrodes(self, electrodes):
        self._electrodes = electrodes

    @property
    def n_electrodes(self):
        """Number of electrodes currently in the array"""
        return len(self.electrodes)

    def _pprint_params(self):
        """Return dict of class attributes to pretty-print"""
        return {'electrodes': self.electrodes,
                'n_electrodes': self.n_electrodes}

    def add_electrode(self, name, electrode):
        """Add an electrode to the array

        Parameters
        ----------
        name : int|str|...
            Electrode name or index
        electrode : implants.Electrode
            An Electrode object, such as a PointSource or a DiskElectrode.

        Raises
        ------
        TypeError
            If ``electrode`` is not an Electrode object.
        ValueError
            If an electrode with the same ``name`` already exists.
        """
        if not isinstance(electrode, Electrode):
            raise TypeError(("Electrode %s must be an Electrode object, not "
                             "%s.") % (name, type(electrode)))
        if name in self.electrodes:
            raise ValueError(("Cannot add electrode: key '%s' already "
                              "exists.") % name)
        self.electrodes[name] = electrode

    def remove_electrode(self, name):
        """Remove an electrode from the array

        Parameters
        ----------
        name: int|str|...
            Electrode name or index

        Raises
        ------
        ValueError
            If no electrode with that ``name`` exists.
        """
        if name not in self.electrodes:
            raise ValueError(("Cannot remove electrode: key '%s' does not "
                              "exist") % name)
        del self.electrodes[name]

    def plot(self, annotate=False, autoscale=True, ax=None):
        """Plot the electrode array

        Parameters
        ----------
        annotate : bool, optional
            Flag whether to label electrodes in the implant.
        autoscale : bool, optional
            Whether to adjust the x,y limits of the plot to fit the implant
        ax : matplotlib.axes._subplots.AxesSubplot, optional
            A Matplotlib axes object. If None, will either use the current axes
            (if exists) or create a new Axes object.

        Returns
        -------
        ax : ``matplotlib.axes.Axes``
            Returns the axis object of the plot
        """
        if ax is None:
            ax = plt.gca()
        if autoscale:
            ax.autoscale(True)
        ax.set_aspect('equal')
        for name, electrode in self.items():
            # Autoscaling is handled once for the whole array above:
            electrode.plot(autoscale=False, ax=ax)
            if annotate:
                # Draw the label on top of the electrode patch (high zorder)
                # with a semi-transparent white box for legibility:
                ax.text(electrode.x, electrode.y, name, ha='center',
                        va='center', color='black', size='large',
                        bbox={'boxstyle': 'square,pad=-0.2', 'ec': 'none',
                              'fc': (1, 1, 1, 0.7)},
                        zorder=11)
        ax.set_xlabel('x (microns)')
        ax.set_ylabel('y (microns)')
        return ax

    def __getitem__(self, item):
        """Return an electrode from the array

        An electrode in the array can be accessed either by its name (the
        key value in the dictionary) or by index (in the list).

        Parameters
        ----------
        item : int|string
            If `item` is an integer, returns the `item`-th electrode in the
            array. If `item` is a string, returns the electrode with string
            identifier `item`. A list or NumPy array of such items returns a
            list of electrodes.

        Returns
        -------
        electrode : Electrode, list thereof, or None
            ``None`` if a string `item` is not a known name.

        Raises
        ------
        StopIteration
            If an integer `item` is neither a key nor a valid list index.
        """
        if isinstance(item, (list, np.ndarray)):
            # Recursive call for list items:
            return [self.__getitem__(i) for i in item]
        if isinstance(item, str):
            # A string is probably a dict key:
            try:
                return self.electrodes[item]
            except KeyError:
                return None
        try:
            # Else, try indexing in various ways:
            return self.electrodes[item]
        except (KeyError, TypeError):
            # If not a dict key, `item` might be an int index into the list:
            try:
                key = list(self.electrodes.keys())[item]
            except IndexError:
                # NOTE(review): IndexError would be the conventional signal
                # here; StopIteration is kept because callers may rely on it.
                raise StopIteration
            return self.electrodes[key]

    def __iter__(self):
        """Iterate over electrode names (dict keys)"""
        return iter(self.electrodes)

    def keys(self):
        """Return a view of the electrode names"""
        return self.electrodes.keys()

    def values(self):
        """Return a view of the Electrode objects"""
        return self.electrodes.values()

    def items(self):
        """Return a view of (name, Electrode) pairs"""
        return self.electrodes.items()


def _get_alphabetic_names(n_electrodes):
    """Create alphabetic electrode names: A-Z, AA-AZ, BA-BZ, etc. """
    n_times = int(np.ceil(n_electrodes / 26))
    names = [''.join(chars) for r in range(1, n_times + 1)
             for chars in product(ascii_uppercase, repeat=r)]
    return names[:n_electrodes]


def _get_numeric_names(n_electrodes):
    """Create numeric electrode names: 1-n"""
    return [str(i) for i in range(1, n_electrodes + 1)]


class ElectrodeGrid(ElectrodeArray):
    """2D grid of electrodes

    Parameters
    ----------
    shape : (rows, cols)
        A tuple containing the number of rows x columns in the grid. Both
        entries must be positive.
    spacing : double
        Electrode-to-electrode spacing in microns.
    type : 'rect' or 'hex', optional, default: 'rect'
        Grid type ('rect': rectangular, 'hex': hexagonal).
    orientation : 'horizontal' or 'vertical', optional, default: 'horizontal'
        Only used for hexagonal grids. 'horizontal' shifts every other row by
        half the spacing; 'vertical' shifts every other column by half the
        spacing.
    x, y, z : double, optional, default: (0,0,0)
        3D coordinates of the center of the grid. ``z`` may also be a list or
        NumPy array with one entry per electrode (row-major order).
    rot : double, optional, default: 0rad
        Rotation of the grid in radians (positive angle: counter-clockwise
        rotation on the retinal surface)
    names: (name_rows, name_cols), each of which either 'A' or '1'
        Naming convention for rows and columns, respectively.
        If 'A', rows or columns will be labeled alphabetically: A-Z, AA-AZ,
        BA-BZ, CA-CZ, etc.
        If '1', rows or columns will be labeled numerically.
        Letters will always precede numbers in electrode names.
        For example ('1', 'A') will number rows numerically and columns
        alphabetically; first row: 'A1', 'B1', 'C1', NOT '1A', '1B', '1C'.
        Alternatively, pass a list with one name per electrode (row-major
        order). For a grid with exactly two electrodes, a two-element
        ``names`` is always interpreted as one name per electrode.
    etype : :py:class:`~pulse2percept.implants.Electrode`, optional
        A valid Electrode class. By default,
        :py:class:`~pulse2percept.implants.PointSource` is used.
    **kwargs :
        Any additional arguments that should be passed to the
        :py:class:`~pulse2percept.implants.Electrode` constructor, such as
        radius ``r`` for :py:class:`~pulse2percept.implants.DiskElectrode`
        (a scalar, or a list with one radius per electrode).
        See examples below.

    Examples
    --------
    A hexagonal electrode grid with 3 rows and 4 columns, made of disk
    electrodes with 10um radius spaced 20um apart, centered at (10, 20)um, and
    located 500um away from the retinal surface, with names like this:

    .. raw:: html

        A1    A2    A3    A4
           B1    B2    B3    B4
        C1    C2    C3    C4

    >>> from pulse2percept.implants import ElectrodeGrid, DiskElectrode
    >>> ElectrodeGrid((3, 4), 20, x=10, y=20, z=500, names=('A', '1'), r=10,
    ...               type='hex', etype=DiskElectrode) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    ElectrodeGrid(shape=(3, 4), spacing=20, type='hex')

    A rectangular electrode grid with 2 rows and 4 columns, made of disk
    electrodes with 10um radius spaced 20um apart, centered at (10, 20)um, and
    located 500um away from the retinal surface, with names like this:

    .. raw:: html

        A1 A2 A3 A4
        B1 B2 B3 B4

    >>> from pulse2percept.implants import ElectrodeGrid, DiskElectrode
    >>> ElectrodeGrid((2, 4), 20, x=10, y=20, z=500, names=('A', '1'), r=10,
    ...               type='rect', etype=DiskElectrode) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    ElectrodeGrid(shape=(2, 4), spacing=20, type='rect')

    There are three ways to access (e.g.) the last electrode in the grid,
    either by name (``grid['C3']``), by row/column index (``grid[2, 2]``), or
    by index into the flattened array (``grid[8]``):

    >>> from pulse2percept.implants import ElectrodeGrid
    >>> grid = ElectrodeGrid((3, 3), 20, names=('A', '1'))
    >>> grid['C3']  # doctest: +ELLIPSIS
    PointSource(x=20..., y=20..., z=0...)
    >>> grid['C3'] == grid[8] == grid[2, 2]
    True

    You can also access multiple electrodes at the same time by passing a
    list of indices/names (it's ok to mix-and-match):

    >>> from pulse2percept.implants import ElectrodeGrid, DiskElectrode
    >>> grid = ElectrodeGrid((3, 3), 20, etype=DiskElectrode, r=10)
    >>> grid[['A1', 1, (0, 2)]]  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    [DiskElectrode(r=10..., x=-20.0, y=-20.0, z=0...),
     DiskElectrode(r=10..., x=0.0, y=-20.0, z=0...),
     DiskElectrode(r=10..., x=20.0, y=-20.0, z=0...)]

    """
    # Frozen class: User cannot add more class attributes
    __slots__ = ('shape', 'type', 'spacing')

    def __init__(self, shape, spacing, x=0, y=0, z=0, rot=0, names=('A', '1'),
                 type='rect', orientation='horizontal', etype=PointSource,
                 **kwargs):
        # NOTE: the public parameter ``type`` shadows the builtin inside this
        # method, so ``type(...)`` must not be called here.
        if not isinstance(names, (tuple, list, np.ndarray)):
            raise TypeError("'names' must be a tuple/list of (rows, cols)")
        if not isinstance(shape, (tuple, list, np.ndarray)):
            raise TypeError("'shape' must be a tuple/list of (rows, cols)")
        if len(shape) != 2:
            raise ValueError("'shape' must have two elements: (rows, cols)")
        # Check each dimension: a check on the product alone would accept
        # e.g. (-2, -3) and silently build an empty grid.
        if np.any(np.asarray(shape) <= 0):
            raise ValueError("Grid must have all non-zero rows and columns.")
        if not isinstance(type, str):
            raise TypeError("'type' must be a string, either 'rect' or 'hex'.")
        if not isinstance(orientation, str):
            raise TypeError("'orientation' must be a string, either "
                            "'horizontal' or 'vertical'.")
        if type not in ['rect', 'hex']:
            raise ValueError("'type' must be either 'rect' or 'hex'.")
        if orientation not in ['horizontal', 'vertical']:
            raise ValueError(
                "'orientation' must be either 'horizontal' or 'vertical'.")
        if not issubclass(etype, Electrode):
            raise TypeError("'etype' must be a valid Electrode object.")
        if issubclass(etype, DiskElectrode):
            if 'r' not in kwargs:
                raise ValueError("A DiskElectrode needs a radius ``r``.")
        if len(names) != 2 and len(names) != np.prod(shape):
            raise ValueError("'names' must either have two entries for "
                             "rows/columns or %d entries, not "
                             "%d" % (np.prod(shape), len(names)))

        self.shape = shape
        self.type = type
        self.spacing = spacing
        # Instantiate empty collection of electrodes. This dictionary will be
        # populated in the private method ``_make_grid``:
        self.electrodes = OrderedDict()
        self._make_grid(x, y, z, rot, names, orientation, etype, **kwargs)

    def _pprint_params(self):
        """Return dict of class attributes to pretty-print"""
        params = {'shape': self.shape, 'spacing': self.spacing,
                  'type': self.type}
        return params

    def __getitem__(self, item):
        """Access electrode(s) in the grid

        Parameters
        ----------
        item : index, string, tuple, or list thereof
            An electrode in the grid can be accessed in three ways:

            *  by name, e.g. grid['A1']
            *  by index into the flattened array, e.g. grid[0]
            *  by (row, column) index into the 2D grid, e.g. grid[0, 0]

            You can also pass a list or NumPy array of the above.

        Returns
        -------
        electrode : `~pulse2percept.implants.Electrode`, list thereof, or None
            Returns the corresponding `~pulse2percept.implants.Electrode`
            object or ``None`` if index is not valid.
        """
        if isinstance(item, (list, np.ndarray)):
            # Recursive call for list items:
            return [self.__getitem__(i) for i in item]
        try:
            # Access by key into OrderedDict, e.g. grid['A1']:
            return self.electrodes[item]
        except (KeyError, TypeError):
            # Access by index into flattened array, e.g. grid[0]:
            try:
                return list(self.electrodes.values())[item]
            except (IndexError, KeyError, TypeError):
                # Access by [r, c] into 2D grid, e.g. grid[0, 3]:
                try:
                    idx = np.ravel_multi_index(item, self.shape)
                    return list(self.electrodes.values())[idx]
                except (KeyError, TypeError, ValueError):
                    # Index not found:
                    return None

    def _make_grid(self, x, y, z, rot, names, orientation, etype, **kwargs):
        """Private method to build the electrode grid

        Computes one (x, y, z) position per electrode in row-major order,
        rotates the layout by ``rot`` radians about the origin, translates it
        by (x, y), and adds one ``etype`` electrode per position.
        """
        n_elecs = np.prod(self.shape)
        rows, cols = self.shape

        # The user did not specify a unique naming scheme. For a grid with
        # exactly two electrodes, a two-element ``names`` is ambiguous and is
        # treated as one name per electrode instead:
        if len(names) == 2 and n_elecs != 2:
            name_rows, name_cols = names
            if not isinstance(name_rows, str):
                raise TypeError("Row name must be a string, not "
                                "%s." % type(name_rows))
            if not isinstance(name_cols, str):
                raise TypeError("Column name must be a string, not "
                                "%s." % type(name_cols))
            # Row names:
            if name_rows.isalpha():
                rws = _get_alphabetic_names(rows)
            elif name_rows.isdigit():
                rws = _get_numeric_names(rows)
            else:
                raise ValueError("Row name must be alphabetic or numeric.")
            # Column names:
            if name_cols.isalpha():
                clms = _get_alphabetic_names(cols)
            elif name_cols.isdigit():
                clms = _get_numeric_names(cols)
            else:
                raise ValueError("Column name must be alphabetic or numeric.")
            # Letters before digits:
            if name_cols.isalpha() and not name_rows.isalpha():
                names = [clm + rw for rw in rws for clm in clms]
            else:
                names = [rw + clm for rw in rws for clm in clms]

        if isinstance(z, (list, np.ndarray)):
            # Specify different height for every electrode in a list:
            z_arr = np.asarray(z).flatten()
            if z_arr.size != n_elecs:
                raise ValueError("If `z` is a list, it must have %d entries, "
                                 "not %d." % (n_elecs, z_arr.size))
        else:
            # If `z` is a scalar, choose same height for all electrodes:
            z_arr = np.ones(n_elecs, dtype=float) * z

        spc = self.spacing
        if self.type == 'rect':
            # Rectangular grid from x, y coordinates:
            # For example, cols=3 with spacing=100 should give: [-100, 0, 100]
            x_arr = (np.arange(cols) * spc - (cols / 2.0 - 0.5) * spc)
            y_arr = (np.arange(rows) * spc - (rows / 2.0 - 0.5) * spc)
            x_arr, y_arr = np.meshgrid(x_arr, y_arr, sparse=False)
        elif self.type == 'hex':
            # NOTE(review): the hex lattice pitch along the staggered axis is
            # sqrt(3)/2 * spc, but the centering offset subtracts
            # (n / 2 - 0.5) * spc and the stagger is +/- spc / 4, so the grid
            # is only approximately centered on (x, y). Left unchanged to keep
            # existing electrode coordinates stable -- confirm intent.
            if orientation == 'horizontal':
                # Even rows are shifted left and odd rows right by spc / 4,
                # i.e. every other row is offset by half the spacing:
                x_base = np.arange(cols) * spc - (cols / 2.0 - 0.5) * spc
                y_row = (np.arange(rows) * np.sqrt(3) * spc / 2.0 -
                         (rows / 2.0 - 0.5) * spc)
                x_lshift, y_arr = np.meshgrid(x_base - spc * 0.25, y_row,
                                              sparse=False)
                x_rshift, _ = np.meshgrid(x_base + spc * 0.25, y_row,
                                          sparse=False)
                odd_rows = (np.arange(rows) % 2 == 1)[:, np.newaxis]
                x_arr = np.where(odd_rows, x_rshift, x_lshift)
            else:
                # Vertical: even columns are shifted up and odd columns down
                # by spc / 4, i.e. every other column is offset by half the
                # spacing:
                x_col = (np.arange(cols) * np.sqrt(3) * spc / 2.0 -
                         (cols / 2.0 - 0.5) * spc)
                y_base = np.arange(rows) * spc - (rows / 2.0 - 0.5) * spc
                x_arr, y_down = np.meshgrid(x_col, y_base - spc * 0.25,
                                            sparse=False)
                _, y_up = np.meshgrid(x_col, y_base + spc * 0.25,
                                      sparse=False)
                odd_cols = np.arange(cols) % 2 == 1
                y_arr = np.where(odd_cols, y_down, y_up)
        else:
            raise NotImplementedError

        # Rotate the grid (counter-clockwise for positive ``rot``):
        rotmat = np.array([np.cos(rot), -np.sin(rot),
                           np.sin(rot), np.cos(rot)]).reshape((2, 2))
        xy = np.matmul(rotmat, np.vstack((x_arr.flatten(), y_arr.flatten())))
        x_arr = xy[0, :]
        y_arr = xy[1, :]

        # Apply offset to make the grid centered at (x, y):
        x_arr += x
        y_arr += y

        if issubclass(etype, DiskElectrode):
            # ``kwargs`` is a fresh dict for this call, so popping is safe:
            r = kwargs.pop('r')
            if isinstance(r, (list, np.ndarray)):
                # Specify different radius for every electrode in a list:
                if len(r) != n_elecs:
                    err_s = ("If `r` is a list, it must have %d entries, not "
                             "%d." % (n_elecs, len(r)))
                    raise ValueError(err_s)
                r_arr = r
            else:
                # If `r` is a scalar, choose same radius for all electrodes:
                r_arr = np.ones(n_elecs, dtype=float) * r
            # Use ``etype`` (not DiskElectrode) so subclasses are honored, and
            # forward any remaining keyword arguments as documented:
            for xe, ye, ze, rad, name in zip(x_arr, y_arr, z_arr, r_arr,
                                             names):
                self.add_electrode(name, etype(xe, ye, ze, rad, **kwargs))
        else:
            # Pass keyword arguments to the electrode constructor:
            for xe, ye, ze, name in zip(x_arr, y_arr, z_arr, names):
                self.add_electrode(name, etype(xe, ye, ze, **kwargs))