"""Contains custom skorch Dataset and CVSplit."""
import warnings
from functools import partial
from numbers import Number

import numpy as np
from scipy import sparse
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import check_cv
import torch
import torch.utils.data

from skorch.utils import flatten
from skorch.utils import is_pandas_ndframe
from skorch.utils import check_indexing
from skorch.utils import multi_indexing
from skorch.utils import to_numpy


ERROR_MSG_1_ITEM = (
    "You are using a non-skorch dataset that returns 1 value. "
    "Remember that for skorch, Dataset.__getitem__ must return exactly "
    "2 values, X and y (more info: "
    "https://skorch.readthedocs.io/en/stable/user/dataset.html).")


ERROR_MSG_MORE_THAN_2_ITEMS = (
    "You are using a non-skorch dataset that returns {} values. "
    "Remember that for skorch, Dataset.__getitem__ must return exactly "
    "2 values, X and y (more info: "
    "https://skorch.readthedocs.io/en/stable/user/dataset.html).")


def _apply_to_data(data, func, unpack_dict=False):
    """Apply a function to data, trying to unpack different data
    types.

    """
    apply_ = partial(_apply_to_data, func=func, unpack_dict=unpack_dict)

    if isinstance(data, dict):
        if unpack_dict:
            return [apply_(v) for v in data.values()]
        return {k: apply_(v) for k, v in data.items()}

    if isinstance(data, (list, tuple)):
        try:
            # e.g.list/tuple of arrays
            return [apply_(x) for x in data]
        except TypeError:
            return func(data)

    return func(data)


def _is_sparse(x):
    try:
        return sparse.issparse(x) or x.is_sparse
    except AttributeError:
        return False


def _len(x):
    if _is_sparse(x):
        return x.shape[0]
    return len(x)


def get_len(data):
    lens = [_apply_to_data(data, _len, unpack_dict=True)]
    lens = list(flatten(lens))
    len_set = set(lens)
    if len(len_set) != 1:
        raise ValueError("Dataset does not have consistent lengths.")
    return list(len_set)[0]


def uses_placeholder_y(ds):
    """If ``ds`` is a ``skorch.dataset.Dataset`` or a
    ``skorch.dataset.Dataset`` nested inside a
    ``torch.utils.data.Subset`` and uses
    y as a placeholder, return ``True``."""

    if isinstance(ds, torch.utils.data.Subset):
        return uses_placeholder_y(ds.dataset)
    return isinstance(ds, Dataset) and hasattr(ds, "y") and ds.y is None


def unpack_data(data):
    """Unpack data returned by the net's iterator into a 2-tuple.

    If the wrong number of items is returned, raise a helpful error
    message.

    """
    # Note: This function cannot detect it when a user only returns 1
    # item that is exactly of length 2 (e.g. because the batch size is
    # 2). In that case, the item will be erroneously split into X and
    # y.
    try:
        X, y = data
        return X, y
    except ValueError:
        # if a 1-tuple/list or something else like a torch tensor
        if not isinstance(data, (tuple, list)) or len(data) < 2:
            raise ValueError(ERROR_MSG_1_ITEM)
        raise ValueError(ERROR_MSG_MORE_THAN_2_ITEMS.format(len(data)))


class Dataset(torch.utils.data.Dataset):
    # pylint: disable=anomalous-backslash-in-string
    """General dataset wrapper that can be used in conjunction with
    PyTorch :class:`~torch.utils.data.DataLoader`.

    The dataset will always yield a tuple of two values, the first
    from the data (``X``) and the second from the target (``y``).
    However, the target is allowed to be ``None``. In that case,
    :class:`.Dataset` will currently return a dummy tensor, since
    :class:`~torch.utils.data.DataLoader` does not work with
    ``None``\s.

    :class:`.Dataset` currently works with the following data types:

    * numpy ``array``\s
    * PyTorch :class:`~torch.Tensor`\s
    * scipy sparse CSR matrices
    * pandas NDFrame
    * a dictionary of the former three
    * a list/tuple of the former three

    Parameters
    ----------
    X : see above
      Everything pertaining to the input data.

    y : see above or None (default=None)
      Everything pertaining to the target, if there is anything.

    length : int or None (default=None)
      If not ``None``, determines the length (``len``) of the data.
      Should usually be left at ``None``, in which case the length is
      determined by the data itself.

    """
    def __init__(
            self,
            X,
            y=None,
            length=None,
    ):
        self.X = X
        self.y = y

        self.X_indexing = check_indexing(X)
        self.y_indexing = check_indexing(y)
        self.X_is_ndframe = is_pandas_ndframe(X)

        if length is not None:
            self._len = length
            return

        # pylint: disable=invalid-name
        len_X = get_len(X)
        if y is not None:
            len_y = get_len(y)
            if len_y != len_X:
                raise ValueError("X and y have inconsistent lengths.")
        self._len = len_X

    def __len__(self):
        return self._len

    def transform(self, X, y):
        # pylint: disable=anomalous-backslash-in-string
        """Additional transformations on ``X`` and ``y``.

        By default, they are cast to PyTorch :class:`~torch.Tensor`\s.
        Override this if you want a different behavior.

        Note: If you use this in conjuction with PyTorch
        :class:`~torch.utils.data.DataLoader`, the latter will call
        the dataset for each row separately, which means that the
        incoming ``X`` and ``y`` each are single rows.

        """
        # pytorch DataLoader cannot deal with None so we use 0 as a
        # placeholder value. We only return a Tensor with one value
        # (as opposed to ``batchsz`` values) since the pytorch
        # DataLoader calls __getitem__ for each row in the batch
        # anyway, which results in a dummy ``y`` value for each row in
        # the batch.
        y = torch.Tensor([0]) if y is None else y

        # pytorch cannot convert sparse matrices, for now just make it
        # dense; squeeze because X[i].shape is (1, n) for csr matrices
        if sparse.issparse(X):
            X = X.toarray().squeeze(0)
        return X, y

    def __getitem__(self, i):
        X, y = self.X, self.y
        if self.X_is_ndframe:
            X = {k: X[k].values.reshape(-1, 1) for k in X}

        Xi = multi_indexing(X, i, self.X_indexing)
        yi = multi_indexing(y, i, self.y_indexing)
        return self.transform(Xi, yi)


class CVSplit:
    """Class that performs the internal train/valid split on a dataset.

    The ``cv`` argument here works similarly to the regular sklearn ``cv``
    parameter in, e.g., ``GridSearchCV``. However, instead of cycling
    through all splits, only one fixed split (the first one) is
    used. To get a full cycle through the splits, don't use
    ``NeuralNet``'s internal validation but instead the corresponding
    sklearn functions (e.g. ``cross_val_score``).

    We additionally support a float, similar to sklearn's
    ``train_test_split``.

    Parameters
    ----------
    cv : int, float, cross-validation generator or an iterable, optional
      (Refer sklearn's User Guide for cross_validation for the various
      cross-validation strategies that can be used here.)

      Determines the cross-validation splitting strategy.
      Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a ``(Stratified)KFold``,
        - float, to represent the proportion of the dataset to include
          in the validation split.
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, validation splits.

    stratified : bool (default=False)
      Whether the split should be stratified. Only works if ``y`` is
      either binary or multiclass classification.

    random_state : int, RandomState instance, or None (default=None)
      Control the random state in case that ``(Stratified)ShuffleSplit``
      is used (which is when a float is passed to ``cv``). For more
      information, look at the sklearn documentation of
      ``(Stratified)ShuffleSplit``.

    """
    def __init__(
            self,
            cv=5,
            stratified=False,
            random_state=None,
    ):
        self.stratified = stratified
        self.random_state = random_state

        if isinstance(cv, Number) and (cv <= 0):
            raise ValueError("Numbers less than 0 are not allowed for cv "
                             "but CVSplit got {}".format(cv))

        if not self._is_float(cv) and random_state is not None:
            # TODO: raise a ValueError instead of a warning
            warnings.warn(
                "Setting a random_state has no effect since cv is not a float. "
                "This will raise an error in a future. You should leave "
                "random_state to its default (None), or set cv to a float value.",
                FutureWarning
            )

        self.cv = cv

    def _is_stratified(self, cv):
        return isinstance(cv, (StratifiedKFold, StratifiedShuffleSplit))

    def _is_float(self, x):
        if not isinstance(x, Number):
            return False
        return not float(x).is_integer()

    def _check_cv_float(self):
        cv_cls = StratifiedShuffleSplit if self.stratified else ShuffleSplit
        return cv_cls(test_size=self.cv, random_state=self.random_state)

    def _check_cv_non_float(self, y):
        return check_cv(
            self.cv,
            y=y,
            classifier=self.stratified,
        )

    def check_cv(self, y):
        """Resolve which cross validation strategy is used."""
        y_arr = None
        if self.stratified:
            # Try to convert y to numpy for sklearn's check_cv; if conversion
            # doesn't work, still try.
            try:
                y_arr = to_numpy(y)
            except (AttributeError, TypeError):
                y_arr = y

        if self._is_float(self.cv):
            return self._check_cv_float()
        return self._check_cv_non_float(y_arr)

    def _is_regular(self, x):
        return (x is None) or isinstance(x, np.ndarray) or is_pandas_ndframe(x)

    def __call__(self, dataset, y=None, groups=None):
        bad_y_error = ValueError(
            "Stratified CV requires explicitly passing a suitable y.")
        if (y is None) and self.stratified:
            raise bad_y_error

        cv = self.check_cv(y)
        if self.stratified and not self._is_stratified(cv):
            raise bad_y_error

        # pylint: disable=invalid-name
        len_dataset = get_len(dataset)
        if y is not None:
            len_y = get_len(y)
            if len_dataset != len_y:
                raise ValueError("Cannot perform a CV split if dataset and y "
                                 "have different lengths.")

        args = (np.arange(len_dataset),)
        if self._is_stratified(cv):
            args = args + (to_numpy(y),)

        idx_train, idx_valid = next(iter(cv.split(*args, groups=groups)))
        dataset_train = torch.utils.data.Subset(dataset, idx_train)
        dataset_valid = torch.utils.data.Subset(dataset, idx_valid)
        return dataset_train, dataset_valid

    def __repr__(self):
        # pylint: disable=useless-super-delegation
        return super(CVSplit, self).__repr__()