#  Copyright 2016 Intel Corporation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Full Correlation Matrix Analysis (FCMA)

Activity-based voxel selection
"""

# Authors: Yida Wang
# (Intel Labs), 2017

import numpy as np
from sklearn import model_selection
import logging
from mpi4py import MPI

logger = logging.getLogger(__name__)

__all__ = [
    "MVPAVoxelSelector",
]


def _sfn(data, mask, myrad, bcast_var):
    """Return the mean cross-validation accuracy for one searchlight.

    Parameters
    ----------
    data: list
        only the first element is used; it is indexed with `mask` on its
        leading axes and transposed before being handed to the classifier
        (presumably yielding one row per epoch -- confirm against the
        Searchlight caller).

    mask: boolean array
        selects the voxels of `data[0]` that take part in the classification.

    myrad: unused
        accepted only because it is part of the call signature; it is never
        read here.

    bcast_var: sequence
        `bcast_var[0]` holds the labels, `bcast_var[1]` the number of
        cross-validation folds, and `bcast_var[2]` the classifier.

    Returns
    -------
    accuracy: float
        the mean of the per-fold scores returned by `cross_val_score`.
    """
    classifier = bcast_var[2]
    samples = data[0][mask, :].T
    labels = bcast_var[0]
    # Unshuffled stratified folds keep the split deterministic.
    folds = model_selection.StratifiedKFold(n_splits=bcast_var[1],
                                            shuffle=False)
    fold_scores = model_selection.cross_val_score(classifier,
                                                  samples,
                                                  y=labels,
                                                  cv=folds,
                                                  n_jobs=1)
    return np.mean(fold_scores)


class MVPAVoxelSelector:
    """Activity-based voxel selection component of FCMA

    Parameters
    ----------

    data: 4D array in shape [brain 3D + epoch]
        contains the averaged and normalized brain data epoch by epoch.
        It is generated by .io.prepare_searchlight_mvpa_data

    mask: 3D array
        converted to a boolean array on construction; nonzero entries mark
        the voxels to be processed.

    labels: 1D array
        contains the labels of the epochs.
        It is generated by .io.prepare_searchlight_mvpa_data

    num_folds: int
        the number of folds to be conducted in the cross validation

    sl: Searchlight
        the distributed Searchlight object

    Raises
    ------
    ValueError
        if the mask does not select any voxel.
    """
    def __init__(self,
                 data,
                 mask,
                 labels,
                 num_folds,
                 sl
                 ):
        self.data = data
        # Use the builtin `bool`: the `np.bool` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24.
        self.mask = mask.astype(bool)
        self.labels = labels
        self.num_folds = num_folds
        self.sl = sl
        num_voxels = np.sum(self.mask)
        if num_voxels == 0:
            raise ValueError('Zero processed voxels')

    def run(self, clf):
        """ run activity-based voxel selection

        Sort the voxels based on the cross-validation accuracy
        of their activity vectors within the searchlight

        Parameters
        ----------
        clf: classification function
            the classifier to be used in cross validation

        Returns
        -------
        result_volume: 3D array of accuracy numbers
            contains the voxelwise accuracy numbers obtained via Searchlight
        results: list of tuple (voxel_id, accuracy)
            the accuracy numbers of all voxels, in accuracy descending order
            the length of array equals the number of voxels.
            voxel_id is the position of the voxel among the masked voxels,
            in the order produced by boolean indexing with the mask.
            The list is only filled on MPI rank 0; every other rank
            returns an empty list.
        """
        rank = MPI.COMM_WORLD.Get_rank()
        if rank == 0:
            logger.info(
                'running activity-based voxel selection via Searchlight'
            )
        self.sl.distribute([self.data], self.mask)
        self.sl.broadcast((self.labels, self.num_folds, clf))
        if rank == 0:
            logger.info(
                'data preparation done'
            )

        # obtain a 3D array with accuracy numbers
        result_volume = self.sl.run_searchlight(_sfn)
        # get result tuple list from the volume
        result_list = result_volume[self.mask]
        results = []
        if rank == 0:
            # Voxels for which Searchlight produced no value (None) are
            # given an accuracy of 0 so that they can still be ranked.
            results = [(idx, 0 if value is None else value)
                       for idx, value in enumerate(result_list)]
            # Sort the voxels
            results.sort(key=lambda tup: tup[1], reverse=True)
            logger.info(
                'activity-based voxel selection via Searchlight is done'
            )
        return result_volume, results