"""Automated rejection and repair of trials in M/EEG."""

# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr>
#          Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#          Denis A. Engemann <denis.engemann@gmail.com>

import os.path as op
from functools import partial

import numpy as np
from scipy.stats.distributions import uniform

from joblib import Parallel, delayed

import mne
from mne import pick_types
from mne.externals.h5io import read_hdf5, write_hdf5
from mne.viz import plot_epochs as plot_mne_epochs

from sklearn.base import BaseEstimator
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import KFold, StratifiedShuffleSplit
from sklearn.model_selection import cross_val_score

from .utils import (_clean_by_interp, interpolate_bads, _get_epochs_type,
                    _pbar, _handle_picks, _check_data, _compute_dots,
                    _get_picks_by_type, _pprint)
from .bayesopt import expected_improvement, bayes_opt

_INIT_PARAMS = ('consensus', 'n_interpolate', 'picks',
                'verbose', 'n_jobs', 'cv', 'random_state',
                'thresh_method')

_FIT_PARAMS = ('threshes_', 'n_interpolate_', 'consensus_',
               'dots', 'picks_', 'loss_')


def _slicemean(obj, this_slice, axis):
    mean = np.nan
    if len(obj[this_slice]) > 0:
        mean = np.mean(obj[this_slice], axis=axis)
    return mean


def validation_curve(epochs, y=None, param_name="thresh", param_range=None,
                     cv=None, return_param_range=False, n_jobs=1):
    """Validation curve on epochs for global autoreject.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs.
    y : array | None
        The labels.
    param_name : str
        Name of the parameter that will be varied.
        Defaults to 'thresh'.
    param_range : array | None
        The values of the parameter that will be evaluated.
        If None, 15 values between the min and the max threshold
        will be tested.
    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation strategy.
    return_param_range : bool
        If True the used param_range is returned.
        Defaults to False.
    n_jobs : int
        The number of thresholds to compute in parallel.

    Returns
    -------
    train_scores : array, shape (n_thresholds, n_folds)
        The scores on the training set.
    test_scores : array, shape (n_thresholds, n_folds)
        The scores on the test set.
    param_range : array
        The thresholds used to build the validation curve.
        Only returned if `return_param_range` is True.
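
    Examples
    --------
    A minimal sketch, assuming ``epochs`` is an mne.Epochs instance
    already loaded in memory:
        >>> train_scores, test_scores, param_range = validation_curve(
        ...     epochs, return_param_range=True)
        >>> best_thresh = param_range[test_scores.mean(axis=1).argmax()]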
    """
    # alias the sklearn function so it does not shadow this one
    from sklearn.model_selection import validation_curve as _validation_curve
    estimator = _GlobalAutoReject()

    BaseEpochs = _get_epochs_type()
    if not isinstance(epochs, BaseEpochs):
        raise ValueError('Only accepts MNE epochs objects.')

    data_picks = _handle_picks(info=epochs.info, picks=None)
    X = epochs.get_data()[:, data_picks, :]
    n_epochs, n_channels, n_times = X.shape

    if param_range is None:
        ptps = np.ptp(X, axis=2)
        param_range = np.linspace(ptps.min(), ptps.max(), 15)

    estimator.n_channels = n_channels
    estimator.n_times = n_times

    train_scores, test_scores = \
        _validation_curve(estimator, X.reshape(n_epochs, -1), y=y,
                          param_name=param_name, param_range=param_range,
                          cv=cv, n_jobs=n_jobs, verbose=0)

    out = (train_scores, test_scores)
    if return_param_range:
        out += (param_range,)

    return out


def read_auto_reject(fname):
    """Read AutoReject object.

    Parameters
    ----------
    fname : str
        The filename where the AutoReject object is saved.

    Returns
    -------
    ar : instance of autoreject.AutoReject
        The AutoReject instance as read from disk.
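
    Examples
    --------
    A minimal sketch; the filename here is hypothetical:
        >>> ar = read_auto_reject('path/to/sub-01-ar.h5')
        >>> epochs_clean = ar.transform(epochs)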
    """
    state = read_hdf5(fname, title='autoreject')
    init_kwargs = {param: state[param] for param in _INIT_PARAMS}
    if isinstance(init_kwargs['verbose'], int):
        init_kwargs['verbose'] = bool(init_kwargs['verbose'])
    ar = AutoReject(**init_kwargs)
    ar.__setstate__(state)
    return ar


class BaseAutoReject(BaseEstimator):
    """Base class for rejection."""

    def score(self, X, y=None):
        """Score it."""
        if hasattr(self, 'n_channels'):
            X = X.reshape(-1, self.n_channels, self.n_times)
        if np.any(np.isnan(self.mean_)):
            return -np.inf
        else:
            return -np.sqrt(np.mean((np.median(X, axis=0) - self.mean_) ** 2))


class _GlobalAutoReject(BaseAutoReject):
    """Class to compute global rejection thresholds.

    Parameters
    ----------
    n_channels : int | None
        The number of channels in the epochs. Defaults to None.
    n_times : int | None
        The number of time points in the epochs. Defaults to None.
    thresh : float
        The rejection threshold. Exposed as a parameter for the
        scikit-learn API.
    """

    def __init__(self, n_channels=None, n_times=None, thresh=40e-6):
        """Init it."""
        self.thresh = thresh
        self.n_channels = n_channels
        self.n_times = n_times

    def fit(self, X, y=None):
        """Fit it."""
        if self.n_channels is None or self.n_times is None:
            raise ValueError('Cannot fit without knowing n_channels'
                             ' and n_times')
        X = X.reshape(-1, self.n_channels, self.n_times)
        deltas = np.ptp(X, axis=2)
        epoch_deltas = deltas.max(axis=1)
        keep = epoch_deltas <= self.thresh
        self.mean_ = _slicemean(X, keep, axis=0)
        return self


def get_rejection_threshold(epochs, decim=1, random_state=None,
                            ch_types=None, cv=5, verbose=True):
    """Compute global rejection thresholds.

    Parameters
    ----------
    epochs : mne.Epochs object
        The epochs from which to estimate the rejection dictionary.
    decim : int
        The decimation factor: Increment for selecting every nth time slice.
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use.
    ch_types : str | list of str | None
        The channel types for which to find the rejection dictionary.
        e.g., ['mag', 'grad']. If None, the rejection dictionary
        will have keys ['mag', 'grad', 'eeg', 'eog'].
    cv : int
        The number of folds used. Defaults to 5.
    verbose : bool
        If False, suppress all output messages.

    Returns
    -------
    reject : dict
        The rejection dictionary with keys as specified by ch_types.

    Notes
    -----
    Sensors marked as bad by the user will be excluded when estimating the
    rejection dictionary.
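
    Examples
    --------
    A minimal sketch, assuming ``epochs`` is an mne.Epochs instance
    already loaded in memory:
        >>> reject = get_rejection_threshold(epochs, ch_types='eeg')
        >>> epochs.drop_bad(reject=reject)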
    """
    reject = dict()

    if ch_types is not None and not isinstance(ch_types, (list, str)):
        raise ValueError('ch_types must be of type None, list, '
                         'or str. Got %s' % type(ch_types))

    if ch_types is None:
        ch_types = ['mag', 'grad', 'eeg', 'eog']
    elif isinstance(ch_types, str):
        ch_types = [ch_types]

    if decim > 1:
        epochs = epochs.copy()
        epochs.decimate(decim=decim)

    # shuffle so that random_state has an effect; recent scikit-learn
    # errors if random_state is set while shuffle is False
    cv = KFold(n_splits=cv, shuffle=True, random_state=random_state)
    for ch_type in ch_types:
        if ch_type not in epochs:
            continue

        if ch_type == 'mag':
            picks = pick_types(epochs.info, meg='mag', eeg=False)
        elif ch_type == 'eeg':
            picks = pick_types(epochs.info, meg=False, eeg=True)
        elif ch_type == 'eog':
            picks = pick_types(epochs.info, meg=False, eog=True)
        elif ch_type == 'grad':
            picks = pick_types(epochs.info, meg='grad', eeg=False)
        else:
            raise ValueError('Unsupported channel type %r' % ch_type)

        X = epochs.get_data()[:, picks, :]
        n_epochs, n_channels, n_times = X.shape
        deltas = np.ptp(X, axis=-1)
        all_threshes = np.sort(deltas.max(axis=1))

        if verbose:
            print('Estimating rejection dictionary for %s' % ch_type)
        cache = dict()
        est = _GlobalAutoReject(n_channels=n_channels, n_times=n_times)

        def func(thresh):
            idx = np.where(thresh - all_threshes >= 0)[0][-1]
            thresh = all_threshes[idx]
            if thresh not in cache:
                est.set_params(thresh=thresh)
                obj = -np.mean(cross_val_score(est, X, cv=cv))
                cache.update({thresh: obj})
            return cache[thresh]

        n_epochs = all_threshes.shape[0]
        idx = np.concatenate((
            np.linspace(0, n_epochs, 5, endpoint=False, dtype=int),
            [n_epochs - 1]))  # ensure last point is in init
        idx = np.unique(idx)  # linspace may be non-unique if n_epochs < 5
        initial_x = all_threshes[idx]
        best_thresh, _ = bayes_opt(func, initial_x,
                                   all_threshes,
                                   expected_improvement,
                                   max_iter=10, debug=False,
                                   random_state=random_state)
        reject[ch_type] = best_thresh

    return reject


class _ChannelAutoReject(BaseAutoReject):
    """docstring for AutoReject."""

    def __init__(self, thresh=40e-6):
        self.thresh = thresh

    def fit(self, X, y=None):
        """Fit it.

        Parameters
        ----------
        X : array, shape (n_epochs, n_times)
            The data for one channel.
        y : None
            Unused. Present only for compatibility with the sklearn
            API.
        """
        deltas = np.ptp(X, axis=1)
        self.deltas_ = deltas
        keep = deltas <= self.thresh
        # XXX: actually go over all the folds before setting the min
        # in skopt. Otherwise, may confuse skopt.
        if self.thresh < deltas.min():
            assert np.sum(keep) == 0
            keep = deltas <= deltas.min()
        self.mean_ = _slicemean(X, keep, axis=0)
        return self


def _compute_thresh(this_data, method='bayesian_optimization',
                    cv=10, y=None, random_state=None):
    """Compute the rejection threshold for one channel.

    Parameters
    ----------
    this_data : array, shape (n_epochs, n_times)
        Data for one channel.
    method : str
        'bayesian_optimization' or 'random_search'
    cv : int | iterator
        Iterator (or number of folds) for cross-validation.
    y : array, shape (n_epochs,) | None
        The labels.
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use.

    Returns
    -------
    best_thresh : float
        The best threshold.

    Notes
    -----
    For method='random_search', the random_state parameter gives deterministic
    results only for scipy versions >= 0.16. This is why we recommend using
    autoreject with scipy version 0.16 or greater.
    """
    est = _ChannelAutoReject()
    all_threshes = np.sort(np.ptp(this_data, axis=1))

    if method == 'random_search':
        # scipy's uniform is parametrized as (loc, scale) with support
        # [loc, loc + scale], so the scale must be the range, not the max
        param_dist = dict(thresh=uniform(all_threshes[0],
                                         all_threshes[-1] -
                                         all_threshes[0]))
        rs = RandomizedSearchCV(est,
                                param_distributions=param_dist,
                                n_iter=20, cv=cv,
                                random_state=random_state)
        rs.fit(this_data, y)
        best_thresh = rs.best_estimator_.thresh
    elif method == 'bayesian_optimization':
        cache = dict()

        def func(thresh):
            idx = np.where(thresh - all_threshes >= 0)[0][-1]
            thresh = all_threshes[idx]
            if thresh not in cache:
                est.set_params(thresh=thresh)
                obj = -np.mean(cross_val_score(est, this_data, y=y, cv=cv))
                cache.update({thresh: obj})
            return cache[thresh]

        n_epochs = all_threshes.shape[0]
        idx = np.concatenate((
            np.linspace(0, n_epochs, 40, endpoint=False, dtype=int),
            [n_epochs - 1]))  # ensure last point is in init
        idx = np.unique(idx)  # linspace may be non-unique if n_epochs < 40
        initial_x = all_threshes[idx]
        best_thresh, _ = bayes_opt(func, initial_x,
                                   all_threshes,
                                   expected_improvement,
                                   max_iter=10, debug=False,
                                   random_state=random_state)

    return best_thresh


def compute_thresholds(epochs, method='bayesian_optimization',
                       random_state=None, picks=None, augment=True,
                       verbose='progressbar', n_jobs=1):
    """Compute thresholds for each channel.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs object for which the thresholds must be computed.
    method : str
        'bayesian_optimization' or 'random_search'
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use
    picks : ndarray, shape(n_channels,) | None
        The channels to be considered for autoreject. If None, defaults
        to data channels {'meg', 'eeg'}.
    augment : bool
        Whether to augment the data or not. Defaults to True. Set it to
        False if the channel locations are not available.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
        If `'progressbar'`, use `mne.utils.ProgressBar`.
        If `'tqdm'`, use `tqdm.tqdm`.
        If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`.
        If False, suppress all output messages.
    n_jobs : int
        Number of jobs to run in parallel

    Returns
    -------
    threshes : dict
        The channel-level rejection thresholds

    Examples
    --------
    For example, we can compute the channel-level thresholds for all
    the data channels this way:
        >>> compute_thresholds(epochs)
    """
    return _compute_thresholds(epochs, method=method,
                               random_state=random_state, picks=picks,
                               augment=augment, verbose=verbose, n_jobs=n_jobs)


def _compute_thresholds(epochs, method='bayesian_optimization',
                        random_state=None, picks=None, augment=True,
                        dots=None, verbose='progressbar', n_jobs=1):
    if method not in ['bayesian_optimization', 'random_search']:
        raise ValueError('`method` param not recognized')
    picks = _handle_picks(info=epochs.info, picks=picks)
    _check_data(epochs, picks, verbose=verbose,
                ch_constraint='data_channels')
    picks_by_type = _get_picks_by_type(picks=picks, info=epochs.info)
    picks_by_type = None if len(picks_by_type) == 1 else picks_by_type  # XXX
    if picks_by_type is not None:
        threshes = dict()
        for ch_type, this_picks in picks_by_type:
            threshes.update(_compute_thresholds(
                epochs=epochs, method=method, random_state=random_state,
                picks=this_picks, augment=augment, dots=dots,
                verbose=verbose, n_jobs=n_jobs))
    else:
        n_epochs = len(epochs)
        data, y = epochs.get_data(), np.ones((n_epochs, ))
        if augment:
            epochs_interp = _clean_by_interp(epochs, picks=picks,
                                             dots=dots, verbose=verbose)
            # non-data channels will be duplicated
            data = np.concatenate((epochs.get_data(),
                                   epochs_interp.get_data()), axis=0)
            y = np.r_[np.zeros((n_epochs, )), np.ones((n_epochs, ))]
        cv = StratifiedShuffleSplit(n_splits=10, test_size=0.2,
                                    random_state=random_state)

        ch_names = epochs.ch_names

        my_thresh = delayed(_compute_thresh)
        parallel = Parallel(n_jobs=n_jobs, verbose=0)
        desc = 'Computing thresholds ...'
        threshes = parallel(
            my_thresh(data[:, pick], cv=cv, method=method, y=y,
                      random_state=random_state)
            for pick in _pbar(picks, desc=desc, verbose=verbose))
        threshes = {ch_names[p]: thresh for p, thresh in zip(picks, threshes)}
    return threshes


class _AutoReject(BaseAutoReject):
    r"""Automatically reject bad epochs and repair bad trials.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs object
    consensus : float (0 to 1.0)
        Fraction of channels that must agree, relative to the total
        number of channels, for an epoch to be rejected. This sets
        :math:`\kappa/Q`.
    n_interpolate : int (default 0)
        Number of channels to interpolate. This is :math:`\rho`.
    thresh_func : callable | None
        Function which returns the channel-level thresholds. If None,
        defaults to :func:`autoreject.compute_thresholds`.
    picks : ndarray, shape(n_channels,) | None
        The channels to be considered for autoreject. If None, defaults
        to data channels {'meg', 'eeg'}.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
        If `'progressbar'`, use `mne.utils.ProgressBar`.
        If `'tqdm'`, use `tqdm.tqdm`.
        If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`.
        If False, suppress all output messages.

    Attributes
    ----------
    threshes_ : dict
        The sensor-level thresholds with channel names as keys
        and the peak-to-peak thresholds as the values.
    consensus_ : dict
        The consensus per channel type, set at fit time.
    n_interpolate_ : dict
        The number of channels to interpolate per channel type,
        set at fit time.
    picks_ : array-like, shape (n_data_channels,)
        The data channels considered by autoreject.
    """

    def __init__(self, consensus=0.1,
                 n_interpolate=0, thresh_func=None,
                 method='bayesian_optimization',
                 picks=None, dots=None,
                 verbose='progressbar'):
        """Init it."""
        if thresh_func is None:
            thresh_func = _compute_thresholds
        if not (0 <= consensus <= 1):
            raise ValueError('"consensus" must be between 0 and 1. '
                             'You gave me %s.' % consensus)
        self.consensus = consensus
        self.n_interpolate = n_interpolate
        self.thresh_func = thresh_func
        self.picks = picks
        self.verbose = verbose
        self.dots = dots

    def __repr__(self):
        """repr."""
        class_name = self.__class__.__name__
        params = dict(n_interpolate=self.n_interpolate,
                      consensus=self.consensus,
                      verbose=self.verbose, picks=self.picks)
        return '%s(%s)' % (class_name, _pprint(params,
                                               offset=len(class_name),),)

    def _vote_bad_epochs(self, epochs, picks):
        """Each channel votes for an epoch as good or bad.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object for which bad epochs must be found.
        picks : array-like
            The indices of the channels to consider.
        """
        labels = np.full((len(epochs), len(epochs.ch_names)), np.nan)

        this_ch_names = [epochs.ch_names[p] for p in picks]
        deltas = np.ptp(epochs.get_data()[:, picks], axis=-1).T
        threshes = [self.threshes_[ch_name] for ch_name in this_ch_names]
        for ch_idx, (delta, thresh) in enumerate(zip(deltas, threshes)):
            bad_epochs_idx = np.where(delta > thresh)[0]
            labels[:, picks[ch_idx]] = 0
            labels[bad_epochs_idx, picks[ch_idx]] = 1

        bad_sensor_counts = np.sum(labels == 1, axis=1)
        return labels, bad_sensor_counts

    def _get_epochs_interpolation(self, epochs, labels,
                                  picks, n_interpolate,
                                  verbose='progressbar'):
        """Interpolate the bad epochs."""
        # 1: bad segment, # 2: interpolated
        assert labels.shape[0] == len(epochs)
        assert labels.shape[1] == len(epochs.ch_names)
        labels = labels.copy()
        non_picks = np.setdiff1d(range(epochs.info['nchan']), picks)
        for epoch_idx in range(len(epochs)):
            n_bads = labels[epoch_idx, picks].sum()
            if n_bads == 0:
                continue
            else:
                if n_bads <= n_interpolate:
                    interp_chs_mask = labels[epoch_idx] == 1
                else:
                    # get peak-to-peak for channels in that epoch
                    data = epochs[epoch_idx].get_data()[0]
                    peaks = np.ptp(data, axis=-1)
                    peaks[non_picks] = -np.inf
                    # find channels which are bad by rejection threshold
                    interp_chs_mask = labels[epoch_idx] == 1
                    # ignore good channels
                    peaks[~interp_chs_mask] = -np.inf
                    # find the ordering of channels amongst the bad channels
                    sorted_ch_idx_picks = np.argsort(peaks)[::-1]
                    # then select only the worst n_interpolate channels
                    interp_chs_mask[
                        sorted_ch_idx_picks[n_interpolate:]] = False

            labels[epoch_idx][interp_chs_mask] = 2
        return labels

    def _get_bad_epochs(self, bad_sensor_counts, ch_type, picks):
        """Get the mask of bad epochs."""
        # XXX : avoid sorting twice
        sorted_epoch_idx = np.argsort(bad_sensor_counts)[::-1]
        bad_sensor_counts = np.sort(bad_sensor_counts)[::-1]
        n_channels = len(picks)
        n_consensus = self.consensus_[ch_type] * n_channels
        bad_epochs = np.zeros(len(bad_sensor_counts), dtype=bool)
        if np.max(bad_sensor_counts) >= n_consensus:
            n_epochs_drop = np.sum(bad_sensor_counts >=
                                   n_consensus)
            bad_epochs_idx = sorted_epoch_idx[:n_epochs_drop]
            bad_epochs[bad_epochs_idx] = True
        return bad_epochs

    def get_reject_log(self, epochs, threshes=None, picks=None):
        """Get rejection logs from epochs.

        .. note::
           If multiple channel types are present, reject_log.bad_epochs
           reflects the union of bad epochs across channel types.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs from which to get the drop logs.
        threshes : dict | None
            The channel-level thresholds. Currently unused; the fitted
            ``threshes_`` attribute is used instead.
        picks : np.ndarray, shape(n_channels, ) | list | None
            The channel indices to be used. If None, the fitted
            ``picks_`` attribute will be used.

        Returns
        -------
        reject_log : instance of autoreject.RejectLog
            The rejection log.
        """
        picks = (self.picks_ if picks is None else picks)
        picks_by_type = _get_picks_by_type(picks=picks, info=epochs.info)
        assert len(picks_by_type) == 1
        ch_type, this_picks = picks_by_type[0]
        del picks

        labels, bad_sensor_counts = self._vote_bad_epochs(
            epochs, picks=this_picks)

        labels = self._get_epochs_interpolation(
            epochs, labels=labels, picks=this_picks,
            n_interpolate=self.n_interpolate_[ch_type])

        assert len(labels) == len(epochs)

        bad_epochs = self._get_bad_epochs(
            bad_sensor_counts, ch_type=ch_type, picks=this_picks)

        reject_log = RejectLog(labels=labels, bad_epochs=bad_epochs,
                               ch_names=epochs.ch_names)
        return reject_log

    def fit(self, epochs):
        """Compute the thresholds.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object from which the channel-level thresholds are
            estimated.

        Returns
        -------
        self : instance of _AutoReject
            The instance.
        """
        self.picks_ = _handle_picks(info=epochs.info, picks=self.picks)
        _check_data(epochs, picks=self.picks_, verbose=self.verbose,
                    ch_constraint='single_channel_type')

        picks_by_type = _get_picks_by_type(picks=self.picks_, info=epochs.info)
        assert len(picks_by_type) == 1
        ch_type, this_picks = picks_by_type[0]

        self.consensus_ = dict()
        self.n_interpolate_ = dict()
        self.n_interpolate_[ch_type] = self.n_interpolate
        self.consensus_[ch_type] = self.consensus

        self.threshes_ = self.thresh_func(
            epochs.copy(), dots=self.dots, picks=self.picks_,
            verbose=self.verbose)

        reject_log = self.get_reject_log(epochs=epochs, picks=self.picks_)

        epochs_copy = epochs.copy()
        interp_channels = _get_interp_chs(
            reject_log.labels, reject_log.ch_names, this_picks)

        # interpolate copy to compute the clean .mean_
        _interpolate_bad_epochs(
            epochs_copy, interp_channels=interp_channels,
            picks=self.picks_, verbose=self.verbose)
        self.mean_ = _slicemean(
            epochs_copy.get_data(),
            np.nonzero(np.invert(reject_log.bad_epochs))[0], axis=0)
        del epochs_copy  # free memory as early as possible
        return self

    def transform(self, epochs, return_log=False):
        """Fix and find the bad epochs.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object for which bad epochs must be found.

        return_log : bool
            If true the rejection log is also returned.

        Returns
        -------
        epochs_clean : instance of mne.Epochs
            The cleaned epochs.

        reject_log : instance of autoreject.RejectLog
            The rejection log. Returned only if return_log is True.
        """
        _check_data(epochs, picks=self.picks_, verbose=self.verbose,
                    ch_constraint='data_channels')

        reject_log = self.get_reject_log(epochs, picks=None)
        if np.all(reject_log.bad_epochs):
            raise ValueError('All epochs are bad. Sorry.')

        epochs_clean = epochs.copy()
        # this one knows how to handle picks.
        _apply_interp(reject_log, epochs_clean, self.threshes_,
                      self.picks_, self.dots, self.verbose)

        _apply_drop(reject_log, epochs_clean, self.threshes_,
                    self.picks_, self.verbose)

        if return_log:
            return epochs_clean, reject_log
        else:
            return epochs_clean


def _interpolate_bad_epochs(
        epochs, interp_channels, picks, dots=None, verbose='progressbar'):
    """Actually do the interpolation."""
    assert len(epochs) == len(interp_channels)
    pos = 2

    for epoch_idx, interp_chs in _pbar(
            list(enumerate(interp_channels)),
            desc='Repairing epochs',
            position=pos, leave=True, verbose=verbose):
        epoch = epochs[epoch_idx]
        epoch.info['bads'] = interp_chs
        interpolate_bads(epoch, dots=dots, picks=picks, reset_bads=True)
        epochs._data[epoch_idx] = epoch._data


def _run_local_reject_cv(epochs, thresh_func, picks_, n_interpolate, cv,
                         consensus, dots, verbose):
    n_folds = cv.get_n_splits()
    loss = np.zeros((len(consensus), len(n_interpolate),
                     n_folds))

    # The thresholds must be learnt from the entire data
    local_reject = _AutoReject(thresh_func=thresh_func,
                               verbose=verbose, picks=picks_,
                               dots=dots)
    local_reject.fit(epochs)

    assert len(local_reject.consensus_) == 1  # works with one ch_type
    ch_type = next(iter(local_reject.consensus_))

    labels, bad_sensor_counts = \
        local_reject._vote_bad_epochs(epochs, picks=picks_)
    desc = 'n_interp'

    for jdx, n_interp in enumerate(_pbar(n_interpolate, desc=desc,
                                         position=1, verbose=verbose)):
        # we can interpolate before doing cross-validation
        # because interpolation is independent across trials.
        local_reject.n_interpolate_[ch_type] = n_interp
        # keep the original vote labels intact so that each value of
        # n_interp starts from the same set of votes
        this_labels = local_reject._get_epochs_interpolation(
            epochs, labels=labels, picks=picks_, n_interpolate=n_interp)

        interp_channels = _get_interp_chs(this_labels, epochs.ch_names,
                                          picks_)
        epochs_interp = epochs.copy()
        # for learning we need to go by channel type, even for meg
        _interpolate_bad_epochs(
            epochs_interp, interp_channels=interp_channels,
            picks=picks_, dots=dots, verbose=verbose)

        # Hack to allow len(self.cv_.split(X)) as ProgressBar
        # assumes an iterable whereas self.cv_.split(X) is a
        # generator
        class CVSplits(object):
            def __init__(self, gen, length):
                self.gen = gen
                self.length = length

            def __len__(self):
                return self.length

            def __iter__(self):
                return self.gen

        X = epochs.get_data()[:, picks_]
        cv_splits = CVSplits(cv.split(X), n_folds)
        pbar = _pbar(cv_splits, desc='Fold',
                     position=3, verbose=verbose)

        for fold, (train, test) in enumerate(pbar):
            for idx, this_consensus in enumerate(consensus):
                # \kappa must be greater than \rho
                n_channels = len(picks_)
                if this_consensus * n_channels <= n_interp:
                    loss[idx, jdx, fold] = np.inf
                    continue

                local_reject.consensus_[ch_type] = this_consensus
                bad_epochs = local_reject._get_bad_epochs(
                    bad_sensor_counts[train], picks=picks_, ch_type=ch_type)

                good_epochs_idx = np.nonzero(np.invert(bad_epochs))[0]

                local_reject.mean_ = _slicemean(
                    epochs_interp[train].get_data()[:, picks_],
                    good_epochs_idx, axis=0)
                loss[idx, jdx, fold] = -local_reject.score(X[test])

    return local_reject, loss


class AutoReject(object):
    r"""Efficiently find n_interpolate and consensus.

    .. note::
       AutoReject by design supports multiple channel types.
       If no picks are passed, separate solutions will be computed for each
       channel type and internally combined. This then readily supports
       cleaning unseen epochs from the different channel types used during
       fit.

    Parameters
    ----------
    consensus : array | None
        The values to try for the fraction of channels that must agree,
        relative to the total number of channels. This sets
        :math:`\kappa/Q`. If None, defaults to
        ``np.linspace(0, 1.0, 11)``.
    n_interpolate : array | None
        The values to try for the number of channels to interpolate.
        This is :math:`\rho`. If None, defaults to
        ``np.array([1, 4, min(n_channels - 1, 32)])``.
    cv : int | sklearn cross-validation object
        The number of cross-validation folds, or the cross-validation
        object itself. Defaults to 10.
    picks : ndarray, shape(n_channels) | None
        The channels to be considered for autoreject. If None, defaults
        to data channels {'meg', 'eeg'}, which will lead to fitting and
        combining autoreject solutions across these channel types. Note
        that, if picks is None, autoreject ignores channels marked bad
        in epochs.info['bads'].
    thresh_method : str
        'bayesian_optimization' or 'random_search'
    n_jobs : int
        The number of jobs.
    random_state : int seed, RandomState instance, or None (default)
        The seed of the pseudo random number generator to use.
    verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
        The verbosity of progress messages.
        If `'progressbar'`, use `mne.utils.ProgressBar`.
        If `'tqdm'`, use `tqdm.tqdm`.
        If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`.
        If False, suppress all output messages.

    Attributes
    ----------
    local_reject_ : dict
        The instances of _AutoReject keyed by channel type.
    threshes_ : dict
        The sensor-level thresholds with channel names as keys
        and the peak-to-peak thresholds as the values.
    loss_ : dict of array, shape (len(consensus), len(n_interpolate), n_folds)
        The cross-validation error for different parameter values.
    consensus_ : dict
        The estimated consensus per channel type.
    n_interpolate_ : dict
        The estimated n_interpolate per channel type.
    picks_ : array-like, shape (n_data_channels,)
        The data channels considered by autoreject. By default only
        data channels not already marked as bad are considered.
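
    Examples
    --------
    A minimal sketch, assuming ``epochs`` is an mne.Epochs instance
    already loaded in memory:
        >>> ar = AutoReject(random_state=42)
        >>> epochs_clean, reject_log = ar.fit_transform(epochs,
        ...                                             return_log=True)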
    """

    def __init__(self, n_interpolate=None, consensus=None,
                 thresh_func=None, cv=10, picks=None,
                 thresh_method='bayesian_optimization',
                 n_jobs=1, random_state=None, verbose='progressbar'):
        """Init it."""
        self.n_interpolate = n_interpolate
        self.consensus = consensus
        self.thresh_method = thresh_method
        self.cv = cv
        self.verbose = verbose
        self.picks = picks  # XXX : should maybe be ch_types?
        self.n_jobs = n_jobs
        self.random_state = random_state

        if self.consensus is None:
            self.consensus = np.linspace(0, 1.0, 11)

    def __repr__(self):
        """repr."""
        class_name = self.__class__.__name__
        params = dict(n_interpolate=self.n_interpolate,
                      consensus=self.consensus,
                      cv=self.cv, verbose=self.verbose, picks=self.picks,
                      thresh_method=self.thresh_method,
                      random_state=self.random_state, n_jobs=self.n_jobs)
        return '%s(%s)' % (class_name, _pprint(params,
                                               offset=len(class_name),),)

    def __getstate__(self):
        """Get the state of autoreject as a dictionary."""
        state = dict()

        for param in _INIT_PARAMS:
            state[param] = getattr(self, param)
        for param in _FIT_PARAMS:
            if hasattr(self, param):
                state[param] = getattr(self, param)

        if hasattr(self, 'local_reject_'):
            state['local_reject_'] = dict()
            for ch_type in self.local_reject_:
                state['local_reject_'][ch_type] = dict()
                for param in _INIT_PARAMS[:4] + _FIT_PARAMS[:4]:
                    state['local_reject_'][ch_type][param] = \
                        getattr(self.local_reject_[ch_type], param)
        return state

    def __setstate__(self, state):
        """Set the state of autoreject."""
        for param in state.keys():
            if param == 'local_reject_':
                local_reject_ = dict()
                for ch_type in state['local_reject_']:
                    init_kwargs = {
                        key: state['local_reject_'][ch_type][key]
                        for key in _INIT_PARAMS[:4]
                    }
                    if isinstance(init_kwargs['verbose'], int):
                        init_kwargs['verbose'] = bool(init_kwargs['verbose'])
                    local_reject_[ch_type] = _AutoReject(**init_kwargs)
                    for key in _FIT_PARAMS[:4]:
                        setattr(local_reject_[ch_type], key,
                                state['local_reject_'][ch_type][key])
                self.local_reject_ = local_reject_
            elif param not in _INIT_PARAMS:
                setattr(self, param, state[param])

    def fit(self, epochs):
        """Fit the epochs on the AutoReject object.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object to be fit.

        Returns
        -------
        self : instance of AutoReject
            The instance.
        """
        self.picks_ = _handle_picks(picks=self.picks, info=epochs.info)
        _check_data(epochs, picks=self.picks_, verbose=self.verbose)
        self.cv_ = self.cv
        if isinstance(self.cv_, int):
            self.cv_ = KFold(n_splits=self.cv_)

        # XXX : maybe use an mne function in pick.py ?
        picks_by_type = _get_picks_by_type(info=epochs.info, picks=self.picks_)
        ch_types = [ch_type for ch_type, _ in picks_by_type]
        self.dots = None
        if 'mag' in ch_types or 'grad' in ch_types:
            meg_picks = pick_types(epochs.info, meg=True,
                                   eeg=False, exclude=[])
            this_info = mne.pick_info(epochs.info, meg_picks, copy=True)
            self.dots = _compute_dots(this_info)

        thresh_func = partial(_compute_thresholds, n_jobs=self.n_jobs,
                              method=self.thresh_method,
                              random_state=self.random_state,
                              dots=self.dots)

        if self.n_interpolate is None:
            if len(self.picks_) < 4:
                raise ValueError('Too few channels. autoreject is unlikely'
                                 ' to be effective')
            # XXX: don't interpolate all channels
            max_interp = min(len(self.picks_) - 1, 32)
            self.n_interpolate = np.array([1, 4, max_interp])

        self.n_interpolate_ = dict()  # rho
        self.consensus_ = dict()  # kappa
        self.threshes_ = dict()  # update
        self.loss_ = dict()
        self.local_reject_ = dict()

        for ch_type, this_picks in picks_by_type:
            if self.verbose is not False:
                print('Running autoreject on ch_type=%s' % ch_type)
            this_local_reject, this_loss = \
                _run_local_reject_cv(epochs, thresh_func, this_picks,
                                     self.n_interpolate, self.cv_,
                                     self.consensus, self.dots,
                                     self.verbose)
            self.threshes_.update(this_local_reject.threshes_)

            best_idx, best_jdx = \
                np.unravel_index(this_loss.mean(axis=-1).argmin(),
                                 this_loss.shape[:2])

            self.consensus_[ch_type] = self.consensus[best_idx]
            self.n_interpolate_[ch_type] = self.n_interpolate[best_jdx]
            self.loss_[ch_type] = this_loss

            # update local reject with best and store it
            this_local_reject.consensus_[ch_type] = self.consensus_[ch_type]
            this_local_reject.n_interpolate_[ch_type] = \
                self.n_interpolate_[ch_type]

            # needed for generating reject logs by channel
            self.local_reject_[ch_type] = this_local_reject

            if self.verbose is not False:
                print('\n\n\n\nEstimated consensus=%0.2f and n_interpolate=%d'
                      % (self.consensus_[ch_type],
                         self.n_interpolate_[ch_type]))
        return self

    def get_reject_log(self, epochs, picks=None):
        """Get rejection logs of epochs.

        .. note::
           If multiple channel types are present, reject_log.bad_epochs
           reflects the union of bad epochs across channel types.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epoched data for which the reject log is computed.
        picks : np.ndarray, shape(n_channels, ) | list | None
            The channel indices to be used. If None, the fitted
            ``picks_`` attribute will be used.

        Returns
        -------
        reject_log : instance of autoreject.RejectLog
            The reject log.
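
        Examples
        --------
        A minimal sketch, assuming ``ar`` is a fitted AutoReject
        instance:
            >>> reject_log = ar.get_reject_log(epochs)
            >>> reject_log.plot()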
        """
        # XXX gut feeling that there is a bad condition that we miss
        ch_names = list(epochs.ch_names)
        labels = np.full((len(epochs), len(ch_names)), np.nan)
        reject_log = RejectLog(
            labels=labels,
            bad_epochs=np.zeros(len(epochs), dtype=bool),
            ch_names=ch_names)

        picks = self.picks_ if picks is None else picks
        picks_by_type = _get_picks_by_type(info=epochs.info, picks=picks)
        for ch_type, this_picks in picks_by_type:
            this_reject_log = self.local_reject_[ch_type].get_reject_log(
                epochs, threshes=self.threshes_, picks=this_picks)
            reject_log.labels[:, this_picks] = \
                this_reject_log.labels[:, this_picks]
            reject_log.bad_epochs = np.logical_or(
                reject_log.bad_epochs, this_reject_log.bad_epochs)
            reject_log.ch_names = this_reject_log.ch_names
        return reject_log

    def transform(self, epochs, return_log=False):
        """Remove bad epochs, repairs sensors and returns clean epochs.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object which must be cleaned.

        return_log : bool
            If true the rejection log is also returned.

        Returns
        -------
        epochs_clean : instance of mne.Epochs
            The cleaned epochs

        reject_log : instance of autoreject.RejectLog
            The rejection log. Returned only if return_log is True.
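
        Examples
        --------
        A minimal sketch, assuming this instance ``ar`` has already
        been fitted with ``ar.fit(epochs)``:
            >>> epochs_clean = ar.transform(epochs)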
        """
        # XXX : should be a check_fitted method
        if not hasattr(self, 'n_interpolate_'):
            raise ValueError('Please run autoreject.fit() method first')

        _check_data(epochs, picks=self.picks_, verbose=self.verbose)

        reject_log = self.get_reject_log(epochs)
        epochs_clean = epochs.copy()
        _apply_interp(reject_log, epochs_clean, self.threshes_,
                      self.picks_, self.dots, self.verbose)

        _apply_drop(reject_log, epochs_clean, self.threshes_, self.picks_,
                    self.verbose)

        if return_log:
            return epochs_clean, reject_log
        else:
            return epochs_clean

    def fit_transform(self, epochs, return_log=False):
        """Estimate the rejection params and finds bad epochs.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object which must be cleaned.

        return_log : bool
            If true the rejection log is also returned.

        Returns
        -------
        epochs_clean : instance of mne.Epochs
            The cleaned epochs.

        reject_log : instance of autoreject.RejectLog
            The rejection log. Returned only if return_log is True.
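
        Examples
        --------
        A minimal sketch, assuming ``epochs`` is an mne.Epochs instance
        already loaded in memory:
            >>> ar = AutoReject()
            >>> epochs_clean = ar.fit_transform(epochs)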
        """
        return self.fit(epochs).transform(epochs, return_log=return_log)

    def save(self, fname, overwrite=False):
        """Save autoreject object.

        Parameters
        ----------
        fname : str
            The filename to save to. The filename must end
            in '.h5' or '.hdf5'.
        overwrite : bool
            If True, overwrite file if it already exists. Defaults to False.
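
        Examples
        --------
        A minimal sketch; the filename here is hypothetical:
            >>> ar.save('path/to/sub-01-ar.h5', overwrite=True)
            >>> ar = read_auto_reject('path/to/sub-01-ar.h5')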
        """
        fname = op.realpath(fname)
        if not overwrite and op.isfile(fname):
            raise ValueError('%s already exists. Please make overwrite=True'
                             ' if you want to overwrite this file' % fname)

        write_hdf5(fname, self.__getstate__(), overwrite=overwrite,
                   title='autoreject')


def _check_fit(epochs, threshes_, picks_):
    msg = ('You are passing channels which were not present '
           'at fit-time. Please fit autoreject again with '
           'the same channels.')
    if not all(epochs.ch_names[pp] in threshes_
               for pp in picks_):
        raise ValueError(msg)


def _apply_interp(reject_log, epochs, threshes_, picks_, dots,
                  verbose):
    _check_fit(epochs, threshes_, picks_)
    interp_channels = _get_interp_chs(
        reject_log.labels, reject_log.ch_names, picks_)
    _interpolate_bad_epochs(
        epochs, interp_channels=interp_channels,
        picks=picks_, dots=dots, verbose=verbose)


def _apply_drop(reject_log, epochs, threshes_, picks_,
                verbose):
    _check_fit(epochs, threshes_, picks_)
    if np.any(reject_log.bad_epochs):
        epochs.drop(np.nonzero(reject_log.bad_epochs)[0],
                    reason='AUTOREJECT')
    elif verbose:
        print("No bad epochs were found for your data. Returning "
              "a copy of the data you wanted to clean. Interpolation "
              "may have been done.")


def _get_interp_chs(labels, ch_names, picks):
    """Convert labels to channel names.
    It returns a list of length n_epochs. Each entry contains
    the names of the channels to interpolate.

    labels is of shape n_epochs x n_channels
    and picks is the sublist of channels to consider.
    """
    interp_channels = list()
    assert labels.shape[1] == len(ch_names)
    assert labels.shape[1] > np.max(picks)
    # the columns that contain at least one non-NaN entry must be
    # exactly the picked channels
    non_nan_cols = np.where(np.any(~np.isnan(labels), axis=0))[0]
    np.testing.assert_array_equal(picks, non_nan_cols)
    for this_labels in labels:
        interp_idx = np.where(this_labels == 2)[0]
        interp_channels.append([ch_names[ii] for ii in interp_idx])
    return interp_channels


class RejectLog(object):
    """The Rejection Log.

    Parameters
    ----------
    bad_epochs : array-like, shape (n_epochs,)
        The boolean array with entries True for epochs that
        are marked as bad.
    labels : array, shape (n_epochs, n_channels)
        It contains integers that encode if a channel in a given
        epoch is good (value 0), bad (1), or bad and interpolated (2).
    ch_names : list of str
        The list of channels corresponding to the columns of the labels.
    """

    def __init__(self, bad_epochs, labels, ch_names):
        self.bad_epochs = bad_epochs
        self.labels = labels
        self.ch_names = ch_names
        assert len(bad_epochs) == labels.shape[0]
        assert len(ch_names) == labels.shape[1]

    def plot(self, orientation='vertical', show=True):
        """Plot.

        Parameters
        ----------
        orientation : 'vertical' or 'horizontal'
            If `'vertical'`, will plot sensors on x-axis and epochs on y-axis.
            If `'horizontal'`, will plot epochs on x-axis and sensors
            on y-axis.
        show : bool
            If True, display the figure immediately.

        Returns
        -------
        figure : Instance of matplotlib.figure.Figure
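
        Examples
        --------
        A minimal sketch, assuming ``reject_log`` came from a fitted
        AutoReject instance:
            >>> fig = reject_log.plot(orientation='horizontal')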
        """
        import matplotlib.pyplot as plt

        figure, ax = plt.subplots(figsize=(12, 6))
        ax.grid(False)
        ch_names_ = self.ch_names[7::10]

        if orientation == 'horizontal':
            ax.imshow(self.labels.T, cmap='Reds',
                      interpolation='nearest')
            ax.set_xlabel('Epochs')
            ax.set_ylabel('Channels')
            plt.setp(ax, yticks=range(7, self.labels.shape[1], 10),
                     yticklabels=ch_names_)
        elif orientation == 'vertical':
            ax.imshow(self.labels, cmap='Reds',
                      interpolation='nearest')
            ax.set_xlabel('Channels')
            ax.set_ylabel('Epochs')
            plt.setp(ax, xticks=range(7, self.labels.shape[1], 10),
                     xticklabels=ch_names_)
        else:
            raise ValueError("orientation can only be 'horizontal' or"
                             " 'vertical'. Got %s" % orientation)

        # XXX to be fixed
        plt.setp(ax.get_yticklabels(), rotation=0)
        plt.setp(ax.get_xticklabels(), rotation=90)
        ax.tick_params(axis='both', which='both', length=0)
        plt.tight_layout(rect=[None, None, None, 1.1])
        if show:
            plt.show()
        return figure

    def plot_epochs(self, epochs, scalings=None, title=''):
        """Plot interpolated and dropped epochs.

        Parameters
        ----------
        epochs : instance of Epochs
            The epochs.
        scalings : dict | None
            Scaling factors for the traces. If None, defaults to::

                dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
                     emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
                     resp=1, chpi=1e-4, whitened=1e2)
        title : str
            The title to display.

        Returns
        -------
        fig : Instance of matplotlib.figure.Figure
            Epochs traces.
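
        Examples
        --------
        A minimal sketch; ``epochs`` must be the epochs *before*
        cleaning, so that their shape matches the log:
            >>> fig = reject_log.plot_epochs(epochs, title='Rejected')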
        """
        labels = self.labels
        n_epochs, n_channels = labels.shape

        if not labels.shape[0] == len(epochs.events):
            raise ValueError('The number of epochs should match the number '
                             'of epochs *before* autoreject. Please provide '
                             'the epochs object before running autoreject.')
        if not labels.shape[1] == len(epochs.ch_names):
            raise ValueError('The number of channels should match the number'
                             ' of channels before running autoreject.')
        bad_epochs_idx = np.where(self.bad_epochs)[0]
        if len(bad_epochs_idx) > 0 and \
                bad_epochs_idx.max() >= len(epochs.events):
            raise ValueError('You had a bad epoch with index '
                             '%d but there are only %d epochs. Make sure'
                             ' to provide the epochs *before* running '
                             'autoreject.'
                             % (bad_epochs_idx.max(),
                                len(epochs.events)))

        color_map = {0: None, 1: 'r', 2: (0.6, 0.6, 0.6, 1.0)}
        epoch_colors = list()
        for epoch_idx, label_epoch in enumerate(labels):
            if self.bad_epochs[epoch_idx]:
                epoch_color = ['r'] * n_channels
                epoch_colors.append(epoch_color)
                continue
            epoch_color = list()
            for this_label in label_epoch:
                if not np.isnan(this_label):
                    epoch_color.append(color_map[this_label])
                else:
                    epoch_color.append(None)
            epoch_colors.append(epoch_color)

        return plot_mne_epochs(
            epochs=epochs,
            epoch_colors=epoch_colors, scalings=scalings,
            title=title)