# -*- coding: utf-8 -*-
from typing import Any
from typing import Iterator
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np

from .types import AxisType
from .types import GridAxisType
from .types import is_basic_indexing
from .utils.formatting import make_identifiable


def _log_axis(
    min_: Union[float, np.ndarray], max_: Union[float, np.ndarray], points: int
) -> np.ndarray:
    """Build a logarithmically spaced axis between two limits.

    The limits are given in linear space; their base-10 logarithms are used
    as the exponents handed to ``np.logspace``. The samples run along the
    first dimension of the result, i.e. the output has the shape
    ``(points,) + np.shape(min/max)``.

    If neither limit is of a floating type, ``float`` is requested
    explicitly as output dtype; otherwise numpy picks the dtype itself.
    """

    def _is_floating(limit: Union[float, np.ndarray]) -> bool:
        # arrays are judged by their dtype, anything else by its python type
        kind = limit.dtype if isinstance(limit, np.ndarray) else type(limit)
        return np.issubdtype(kind, np.floating)

    out_dtype = None if (_is_floating(min_) or _is_floating(max_)) else float

    start_exponent = np.log10(min_)
    stop_exponent = np.log10(max_)
    return np.logspace(start_exponent, stop_exponent, points, dtype=out_dtype)


def _lin_axis(
    min_: Union[float, np.ndarray], max_: Union[float, np.ndarray], points: int
) -> np.ndarray:
    """Build a linearly spaced axis between two limits.

    Thin wrapper around ``np.linspace``. The samples run along the first
    dimension of the result, i.e. the output has the shape
    ``(points,) + np.shape(min/max)``.

    If neither limit is of a floating type, ``float`` is requested
    explicitly as output dtype; otherwise numpy picks the dtype itself.
    """
    # arrays are judged by their dtype, anything else by its python type
    lower_kind = min_.dtype if isinstance(min_, np.ndarray) else type(min_)

    if np.issubdtype(lower_kind, np.floating):
        return np.linspace(min_, max_, points, dtype=None)

    upper_kind = max_.dtype if isinstance(max_, np.ndarray) else type(max_)

    if np.issubdtype(upper_kind, np.floating):
        return np.linspace(min_, max_, points, dtype=None)

    # no floating limit at all -> enforce a floating result
    return np.linspace(min_, max_, points, dtype=float)


class Axis:
    """Container for axis values together with a name, label and unit.

    The values are stored internally with a leading dimension which indexes
    the individual (temporal) slices of the axis; ``len(axis)`` is the size
    of that dimension. If only a single slice is stored, the leading
    dimension is hidden from ``data``, ``shape``, ``ndim`` and from the
    conversion to a numpy array.
    """

    def __init__(
        self,
        data: np.ndarray,
        *,
        name: str = "unnamed",
        label: str = "",
        unit: str = "",
    ):
        """Create an axis from array-like ``data``.

        Args:
            data: Axis values. Anything which is not already an
                ``np.ndarray`` is converted with ``np.asanyarray``. A
                zero-dimensional input is promoted to one dimension.
            name: Identifier of the axis. It is passed through
                ``make_identifiable``; an empty result falls back to
                ``"unnamed"``.
            label: Label of the axis, stored as given.
            unit: Unit of the axis, stored as given.
        """
        if not isinstance(data, np.ndarray):
            data = np.asanyarray(data)

        # the internal storage always has at least the leading dimension
        self._data = data if data.ndim > 0 else data[np.newaxis]

        name = make_identifiable(name)
        self._name = name if name else "unnamed"
        self._label = label
        self._unit = unit

    def __repr__(self) -> str:
        repr_ = f"{self.__class__.__name__}("
        repr_ += f"name='{self.name}', "
        repr_ += f"label='{self.label}', "
        repr_ += f"unit='{self.unit}', "
        repr_ += f"axis_dim={self.axis_dim}, "
        repr_ += f"len={len(self)}"
        repr_ += ")"

        return repr_

    def __len__(self) -> int:
        """Return the size of the leading (slice) dimension."""
        return len(self._data)

    def __iter__(self) -> Iterator["Axis"]:
        """Yield one single-slice axis per entry of the leading dimension."""
        for d in self._data:
            yield self.__class__(
                d[np.newaxis], name=self.name, label=self.label, unit=self.unit,
            )

    def __getitem__(
        self, key: Union[int, slice, Tuple[Union[int, slice]]]
    ) -> "Axis":
        """Return a new axis holding the selected part of ``data``.

        Raises:
            IndexError: If ``key`` is rejected by ``is_basic_indexing``.
        """
        if not is_basic_indexing(key):
            raise IndexError("Only basic indexing is supported!")

        key = np.index_exp[key]

        # The new axis needs its leading dimension back if
        #  - only one slice is stored -> `data` hides the leading dimension
        #  - the first index is an integer -> indexing removes the dimension
        # Numpy integers (e.g. from `np.arange`) are not instances of `int`
        # but reduce the dimensionality in the same way. The length check
        # guards against an empty index tuple.
        if len(self) == 1:
            requires_new_axis = True
        else:
            requires_new_axis = len(key) > 0 and isinstance(
                key[0], (int, np.integer)
            )

        data = self.data[key]

        if requires_new_axis:
            data = data[np.newaxis]

        return self.__class__(
            data, name=self.name, label=self.label, unit=self.unit,
        )

    def __setitem__(
        self, key: Union[int, slice, Tuple[Union[int, slice]]], value: Any
    ) -> None:
        # `data` is a view of the internal storage -> assignment writes through
        self.data[key] = value

    def __array__(
        self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None
    ) -> np.ndarray:
        """Convert the axis to a numpy array.

        Args:
            dtype: Requested dtype. ``None`` keeps the stored dtype and
                returns a view of the internal storage.
            copy: Copy behavior following numpy's ``__array__`` protocol.
                ``True`` always returns a copy, ``False`` never copies and
                ``None`` copies only for a requested dtype.

        Raises:
            ValueError: If ``copy=False`` but the requested dtype differs
                from the stored one.
        """
        # Compare against `None` explicitly, the truthiness of a dtype object
        # is not a reliable indicator of "a dtype was passed".
        if dtype is None:
            data = self._data.copy() if copy else self._data
        elif copy is False:
            if np.dtype(dtype) != self._data.dtype:
                raise ValueError(
                    "Unable to avoid copy while converting the axis to "
                    + f"dtype '{np.dtype(dtype)}'"
                )
            data = self._data
        else:
            # `astype` always creates a new array
            data = self._data.astype(dtype)

        # hide the leading dimension if only one slice is stored
        return np.squeeze(data, axis=0) if len(self) == 1 else data

    @property
    def data(self) -> np.ndarray:
        """Axis values; leading dimension is dropped for a single slice."""
        return np.asanyarray(self)

    @data.setter
    def data(self, value: Union[np.ndarray, Any]) -> None:
        # the new values have to be broadcastable to the current shape
        new = np.broadcast_to(value, self.shape, subok=True)
        if len(self) == 1:
            # restore the leading dimension which `shape` hides
            self._data = np.array(new, subok=True)[np.newaxis]
        else:
            self._data = np.array(new, subok=True)

    @property
    def axis_dim(self) -> int:
        """Dimensionality of the storage without the leading dimension."""
        return self._data.ndim - 1

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of ``data``."""
        return self._data.shape[1:] if len(self) == 1 else self._data.shape

    @property
    def dtype(self) -> np.dtype:
        return self._data.dtype

    @property
    def ndim(self) -> int:
        """Number of dimensions of ``data``."""
        return (self._data.ndim - 1) if len(self) == 1 else self._data.ndim

    @property
    def name(self) -> str:
        return self._name

    @name.setter
    def name(self, value) -> None:
        parsed_value = make_identifiable(str(value))
        if not parsed_value:
            raise ValueError(
                "Invalid name provided! Has to be able to be valid code"
            )
        self._name = parsed_value

    @property
    def label(self) -> str:
        return self._label

    @label.setter
    def label(self, value) -> None:
        value = str(value)
        self._label = value

    @property
    def unit(self) -> str:
        return self._unit

    @unit.setter
    def unit(self, value) -> None:
        value = str(value)
        self._unit = value

    def equivalent(self, other: Union[Any, AxisType]) -> bool:
        """Check if ``other`` describes the same kind of axis.

        Two axes are equivalent if ``other`` is an instance of this axis'
        class and ``axis_dim``, ``name``, ``label`` and ``unit`` match. The
        stored values are not compared.
        """
        if not isinstance(other, self.__class__):
            return False

        if self.axis_dim != other.axis_dim:
            return False

        if self.name != other.name:
            return False

        if self.label != other.label:
            return False

        if self.unit != other.unit:
            return False

        return True

    def append(self, other: "Axis") -> None:
        """Append the slices of ``other`` in place.

        Raises:
            TypeError: If ``other`` is not an instance of this axis' class.
            ValueError: If ``other`` is not equivalent to this axis.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Can not append '{other}' to '{self}'")

        if not self.equivalent(other):
            raise ValueError(
                f"Mismatch in attributes between '{self}' and '{other}'"
            )

        # The internal storage always carries the leading dimension, hence
        # it can be concatenated directly along the first axis.
        self._data = np.append(self._data, other._data, axis=0)


# Unique module-level sentinel object. It is not referenced in the visible
# part of this module; presumably it marks arguments which are ignored when
# data is passed explicitly -- TODO confirm against the callers.
_ignored_if_data = object()


class GridAxis(Axis):
    """Axis which additionally carries an axis type.

    The axis type has to be one of the entries in ``_supported_axis_types``.
    """

    _supported_axis_types: Tuple[str, ...] = (
        "lin",
        "linear",
        "log",
        "logarithmic",
        "custom",
    )

    def __init__(
        self,
        data: np.ndarray,
        *,
        axis_type: str = "linear",
        name: str = "unnamed",
        label: str = "",
        unit: str = "",
    ) -> None:
        """Create a grid axis from array-like ``data``.

        Args:
            data: Axis values, forwarded to the parent initializer.
            axis_type: Type of the axis, one of ``_supported_axis_types``.
            name: Name of the axis, forwarded to the parent initializer.
            label: Label of the axis, forwarded to the parent initializer.
            unit: Unit of the axis, forwarded to the parent initializer.

        Raises:
            ValueError: If ``axis_type`` is not supported.
        """
        if axis_type not in self._supported_axis_types:
            raise ValueError(
                f"'{axis_type}' is not supported for axis_type! "
                + f"It has to by one of {self._supported_axis_types}"
            )

        super().__init__(data, name=name, label=label, unit=unit)
        self._axis_type = axis_type

    def __iter__(self) -> Iterator["GridAxis"]:
        """Yield one single-slice axis per entry of the leading dimension."""
        for d in self._data:
            yield self.__class__(
                d[np.newaxis],
                name=self.name,
                label=self.label,
                unit=self.unit,
                axis_type=self.axis_type,
            )

    def __getitem__(
        self, key: Union[int, slice, Tuple[Union[int, slice]]]
    ) -> "GridAxis":
        """Return a new axis holding the selected part of ``data``.

        Raises:
            IndexError: If ``key`` is rejected by ``is_basic_indexing``.
        """
        if not is_basic_indexing(key):
            raise IndexError("Only basic indexing is supported!")

        key = np.index_exp[key]

        # The new axis needs its leading dimension back if
        #  - only one slice is stored -> `data` hides the leading dimension
        #  - the first index is an integer -> indexing removes the dimension
        # Numpy integers (e.g. from `np.arange`) are not instances of `int`
        # but reduce the dimensionality in the same way. The length check
        # guards against an empty index tuple.
        if len(self) == 1:
            requires_new_axis = True
        else:
            requires_new_axis = len(key) > 0 and isinstance(
                key[0], (int, np.integer)
            )

        data = self.data[key]

        if requires_new_axis:
            data = data[np.newaxis]

        return self.__class__(
            data,
            name=self.name,
            label=self.label,
            unit=self.unit,
            axis_type=self.axis_type,
        )

    def __repr__(self) -> str:
        repr_ = f"{self.__class__.__name__}("
        repr_ += f"name='{self.name}', "
        repr_ += f"label='{self.label}', "
        repr_ += f"unit='{self.unit}', "
        repr_ += f"axis_type={self.axis_type}, "
        repr_ += f"axis_dim={self.axis_dim}, "
        repr_ += f"len={len(self)}"
        repr_ += ")"

        return repr_

    @property
    def axis_type(self) -> str:
        return self._axis_type

    @axis_type.setter
    def axis_type(self, value: str) -> None:
        value = str(value)
        if value not in self._supported_axis_types:
            raise ValueError(
                f"'{value}' is not supported for axis_type! "
                + f"It has to by one of {self._supported_axis_types}"
            )
        self._axis_type = value

    @classmethod
    def from_limits(
        cls,
        min_value: Union[np.ndarray, int, float],
        max_value: Union[np.ndarray, int, float],
        cells: int,
        *,
        axis_type: str = "linear",
        name: str = "unnamed",
        label: str = "",
        unit: str = "",
    ) -> "GridAxis":
        """Create a grid axis spanning ``min_value`` to ``max_value``.

        Args:
            min_value: Lower limit of the axis.
            max_value: Upper limit of the axis.
            cells: Number of points generated between the limits.
            axis_type: Spacing of the points. ``"lin"``/``"linear"`` uses
                ``_lin_axis`` and ``"log"``/``"logarithmic"`` uses
                ``_log_axis``.
            name: Name of the axis.
            label: Label of the axis.
            unit: Unit of the axis.

        Raises:
            ValueError: If ``axis_type`` is neither linear nor logarithmic.
        """
        if axis_type in ("lin", "linear"):
            values = _lin_axis(min_value, max_value, cells)
        elif axis_type in ("log", "logarithmic"):
            values = _log_axis(min_value, max_value, cells)
        else:
            raise ValueError(
                "Invalid axis type provided. "
                + "Only 'lin', 'linear', 'log', and 'logarithmic' "
                + "are supported!"
            )

        # one-dimensional values represent a single slice -> add the
        # leading dimension explicitly
        if values.ndim == 1:
            values = values[np.newaxis]

        grid_axis = cls(values, name=name, label=label, unit=unit)
        # the axis type was already checked by the branches above
        grid_axis._axis_type = axis_type
        return grid_axis

    def equivalent(self, other: Union[Any, GridAxisType]) -> bool:
        """Check if ``other`` describes the same kind of grid axis.

        In addition to the checks of the parent class, the axis types have
        to match.
        """
        if not super().equivalent(other):
            return False

        if self.axis_type != other.axis_type:
            return False

        return True