from spikeextractors import RecordingExtractor
from spikeextractors.extraction_tools import check_get_traces_args
import numpy as np

try:
    from scipy import special, signal

    HAVE_RR = True
except ImportError:
    HAVE_RR = False


class ResampleRecording(RecordingExtractor):
    '''
    Recording extractor that lazily resamples another recording to a new sampling rate.

    Traces are resampled on demand in ``get_traces``:
      - equal rates: traces are returned unchanged (only cast to the parent dtype);
      - parent rate an integer multiple of ``resample_rate``: ``scipy.signal.decimate``;
      - otherwise: FFT-based ``scipy.signal.resample``.
    '''
    preprocessor_name = 'Resample'
    installed = HAVE_RR  # check at class level if installed or not
    # message shown to users when scipy is not available
    installation_mesg = "To use the ResampleRecording, install scipy: \n\n pip install scipy\n\n"

    def __init__(self, recording, resample_rate):
        '''
        Parameters
        ----------
        recording: RecordingExtractor
            The recording extractor to be resampled
        resample_rate: int or float
            The target sampling frequency
        '''
        assert HAVE_RR, "To use the ResampleRecording, install scipy: \n\n pip install scipy\n\n"
        self._recording = recording
        self._resample_rate = resample_rate
        RecordingExtractor.__init__(self)
        self._dtype = recording.get_dtype()
        self.copy_channel_properties(recording)
        self.is_filtered = self._recording.is_filtered

        self._kwargs = {'recording': recording.make_serialized_dict(), 'resample_rate': resample_rate}

    def get_sampling_frequency(self):
        '''Return the target (resampled) sampling frequency.'''
        return self._resample_rate

    def get_num_frames(self):
        '''Return the number of frames at the resampled rate, truncated to an int.'''
        # Multiply before dividing: `n / fs * rr` can land just below an exact integer
        # (e.g. 1 / 49 * 49 == 0.9999999999999999) and int() would then drop a frame.
        return int(self._recording.get_num_frames() * self._resample_rate
                   / self._recording.get_sampling_frequency())

    def get_dtype(self):
        '''Return the parent recording's dtype, cached at construction.'''
        # NOTE(review): presumably overridden so the base class does not derive the dtype
        # by reading (and therefore resampling) a trace sample — confirm against base class.
        return self._dtype

    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        '''
        Return resampled traces for frames [start_frame, end_frame) of the resampled recording.

        The requested frame range is mapped back onto the parent recording, the parent traces
        are read, then resampled along axis 1 and cast to the parent dtype.
        '''
        original_rate = self._recording.get_sampling_frequency()
        target_rate = self.get_sampling_frequency()
        # Multiply before dividing so that integer rate ratios map frames exactly; a
        # misaligned window would make the decimated output length differ from the request.
        start_frame_not_sampled = int(start_frame * original_rate / target_rate)
        end_frame_not_sampled = int(end_frame * original_rate / target_rate)
        traces = self._recording.get_traces(start_frame=start_frame_not_sampled,
                                            end_frame=end_frame_not_sampled,
                                            channel_ids=channel_ids)
        if original_rate == self._resample_rate:
            # Identity resampling: decimate(q=1) would still apply its anti-aliasing low-pass
            # filter and alter the signal, so pass the traces through untouched.
            traces_resampled = traces
        elif np.mod(original_rate, self._resample_rate) == 0:
            # Integer downsampling factor. NOTE(review): decimate filters with filtfilt, which
            # may reject very short windows — confirm chunk sizes used by callers.
            traces_resampled = signal.decimate(traces,
                                               q=int(original_rate / self._resample_rate),
                                               axis=1)
        else:
            traces_resampled = signal.resample(traces, int(end_frame - start_frame), axis=1)
        return traces_resampled.astype(self._dtype)

    def get_channel_ids(self):
        '''Return the channel ids of the parent recording.'''
        return self._recording.get_channel_ids()


def resample(recording, resample_rate):
    '''
    Resamples the recording extractor traces. If the sampling rate of the recording is an integer
    multiple of the resampling rate, scipy's decimate function is used; otherwise scipy's
    FFT-based resample function is used.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor to be resampled
    resample_rate: int or float
        The resampling frequency

    Returns
    -------
    resampled_recording: ResampleRecording
        The resample recording extractor

    '''
    return ResampleRecording(
        recording=recording,
        resample_rate=resample_rate
    )