from .filterrecording import FilterRecording import spikeextractors as se import numpy as np try: import scipy.signal as ss HAVE_NFR = True except ImportError: HAVE_NFR = False class NotchFilterRecording(FilterRecording): preprocessor_name = 'NotchFilter' installed = HAVE_NFR # check at class level if installed or not installation_mesg = "To use the NotchFilterRecording, install scipy: \n\n pip install scipy\n\n" # error message when not installed def __init__(self, recording, freq=3000, q=30, chunk_size=30000, cache_chunks=False): assert HAVE_NFR, "To use the NotchFilterRecording, install scipy: \n\n pip install scipy\n\n" self._freq = freq self._q = q fn = 0.5 * float(recording.get_sampling_frequency()) self._b, self._a = ss.iirnotch(self._freq / fn, self._q) if not np.all(np.abs(np.roots(self._a)) < 1): raise ValueError('Filter is not stable') FilterRecording.__init__(self, recording=recording, chunk_size=chunk_size, cache_chunks=cache_chunks) self.copy_channel_properties(recording) self.is_filtered = self._recording.is_filtered self._kwargs = {'recording': recording.make_serialized_dict(), 'freq': freq, 'q': q, 'chunk_size': chunk_size, 'cache_chunks': cache_chunks} def filter_chunk(self, *, start_frame, end_frame): padding = 3000 i1 = start_frame - padding i2 = end_frame + padding padded_chunk = self._read_chunk(i1, i2) filtered_padded_chunk = self._do_filter(padded_chunk) return filtered_padded_chunk[:, start_frame - i1:end_frame - i1] def _do_filter(self, chunk): chunk_filtered = ss.filtfilt(self._b, self._a, chunk, axis=1) return chunk_filtered def _read_chunk(self, i1, i2): M = len(self._recording.get_channel_ids()) N = self._recording.get_num_frames() if i1 < 0: i1b = 0 else: i1b = i1 if i2 > N: i2b = N else: i2b = i2 ret = np.zeros((M, i2 - i1)) ret[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b) return ret def notch_filter(recording, freq=3000, q=30, chunk_size=30000, cache_to_file=False, cache_chunks=False): ''' Performs a notch filter on the recording extractor traces using scipy iirnotch function. Parameters ---------- recording: RecordingExtractor The recording extractor to be notch-filtered. freq: int or float The target frequency of the notch filter. q: int The quality factor of the notch filter. chunk_size: int The chunk size to be used for the filtering. cache_to_file: bool (default False). If True, filtered traces are computed and cached all at once on disk in temp file cache_chunks: bool (default False). If True then each chunk is cached in memory (in a dict) Returns ------- filter_recording: NotchFilterRecording The notch-filtered recording extractor object ''' if cache_to_file: assert not cache_chunks, 'if cache_to_file cache_chunks should be False' notch_recording = NotchFilterRecording( recording=recording, freq=freq, q=q, chunk_size=chunk_size, cache_chunks=cache_chunks, ) if cache_to_file: return se.CacheRecordingExtractor(notch_recording, chunk_size=chunk_size) else: return notch_recording