import os
from pathlib import Path

import numpy as np
from spikeextractors import SortingExtractor
from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor
from spikeextractors.extraction_tools import save_to_probe_file, load_probe_file, check_valid_unit_id

try:
    import hybridizer.io as sbio
    import hybridizer.probes as sbprb
    HAVE_SBEX = True
except ImportError:
    HAVE_SBEX = False


class SHYBRIDRecordingExtractor(BinDatRecordingExtractor):
    """Recording extractor for binary recordings described by a SHYBRID parameter file.

    The parameter file (read with ``hybridizer.io.get_params``) supplies the
    probe file, byte ordering, sampling frequency and dtype of the binary data;
    reading the samples themselves is delegated to ``BinDatRecordingExtractor``.
    """
    extractor_name = 'SHYBRIDRecording'
    installed = HAVE_SBEX
    is_writable = True
    mode = 'file'
    installation_mesg = "To use the SHYBRID extractors, install SHYBRID: \n\n pip install shybrid\n\n"

    def __init__(self, file_path):
        """
        Parameters
        ----------
        file_path: str or Path
            Path to the SHYBRID binary recording.

        Raises
        ------
        ValueError
            If the parameter file's byte ordering is neither 'C' nor 'F'.
        """
        assert HAVE_SBEX, self.installation_mesg

        # load params file related to the given shybrid recording
        params = sbio.get_params(file_path)['data']

        # create a shybrid probe object; here it is only used for the channel count
        probe = sbprb.Probe(params['probe'])
        nb_channels = probe.total_nb_channels

        # translate the byte ordering
        # TODO still ambiguous, shybrid should assume time_axis=1, since spike
        # interface makes an assumption on the byte ordering
        byte_order = params['order']
        if byte_order == 'C':
            time_axis = 1
        elif byte_order == 'F':
            time_axis = 0
        else:
            # Without this branch an unexpected value surfaced later as an
            # opaque UnboundLocalError on ``time_axis``.
            raise ValueError("Unsupported byte ordering {!r} in SHYBRID parameters "
                             "(expected 'C' or 'F')".format(byte_order))

        # piggyback on binary data recording extractor
        BinDatRecordingExtractor.__init__(self, file_path, params['fs'], nb_channels,
                                          params['dtype'], time_axis=time_axis)
        self._kwargs = {'file_path': str(Path(file_path).absolute())}

        # Attach the probe information. Rebinding ``self`` to the return value of
        # load_probe_file cannot replace the instance being constructed, so any
        # properties stored only on the returned extractor would be lost. Copy
        # them onto this extractor explicitly; re-setting a property that
        # load_probe_file may already have set on ``self`` is harmless.
        probe_recording = load_probe_file(self, params['probe'])
        for channel_id in probe_recording.get_channel_ids():
            for property_name in probe_recording.get_channel_property_names(channel_id):
                value = probe_recording.get_channel_property(channel_id, property_name)
                self.set_channel_property(channel_id, property_name, value)

    @staticmethod
    def write_recording(recording, save_path, initial_sorting_fn, dtype='float32'):
        """ Convert and save the recording extractor to SHYBRID format

        Writes ``recording.bin`` (time_axis=0, i.e. 'F' ordering), ``probe.prb``
        and ``recording.yml`` into ``save_path``.

        Parameters
        ----------
        recording: RecordingExtractor
            The recording extractor to be converted and saved. It must carry a
            'location' channel property.
        save_path: str
            Full path to desired target folder (created if it does not exist)
        initial_sorting_fn: str
            Full path to the initial sorting csv file (can also be generated
            using write_sorting static method from the SHYBRIDSortingExtractor)
        dtype: dtype
            Type of the saved data. Any value accepted by ``numpy.dtype``
            (e.g. 'float32' or np.float32). Default float32.

        Raises
        ------
        GeometryNotLoadedError
            If the recording has no 'location' channel property.
        """
        assert HAVE_SBEX, SHYBRIDRecordingExtractor.installation_mesg

        RECORDING_NAME = 'recording.bin'
        PROBE_NAME = 'probe.prb'
        PARAMETERS_NAME = 'recording.yml'

        # location information has to be present in order for shybrid to
        # be able to operate on the recording
        if 'location' not in recording.get_shared_channel_property_names():
            raise GeometryNotLoadedError("Channel locations were not found")

        # Normalise the dtype to its canonical name so that the binary file and
        # the YAML parameters agree (a numpy type object would otherwise be
        # written to the YAML as "<class 'numpy.float32'>").
        dtype_name = np.dtype(dtype).name

        os.makedirs(save_path, exist_ok=True)

        # write recording
        recording_fn = os.path.join(save_path, RECORDING_NAME)
        BinDatRecordingExtractor.write_recording(recording, recording_fn,
                                                 time_axis=0, dtype=dtype_name)

        # write probe file
        probe_fn = os.path.join(save_path, PROBE_NAME)
        save_to_probe_file(recording, probe_fn)

        # create parameters file; byte ordering 'F' matches time_axis=0 above
        parameters = params_template.format(initial_sorting_fn=initial_sorting_fn,
                                            data_type=dtype_name,
                                            sampling_frequency=str(recording.get_sampling_frequency()),
                                            byte_ordering='F',
                                            probe_fn=probe_fn)

        # write parameters file
        parameters_fn = os.path.join(save_path, PARAMETERS_NAME)
        with open(parameters_fn, 'w') as fp:
            fp.write(parameters)


class SHYBRIDSortingExtractor(SortingExtractor):
    """Sorting extractor backed by a SHYBRID (hybridizer) spike-cluster CSV file."""
    extractor_name = 'SHYBRIDSortingExtractor'
    installed = HAVE_SBEX
    is_writable = True
    installation_mesg = "To use the SHYBRID extractors, install SHYBRID: \n\n pip install shybrid\n\n"

    def __init__(self, file_path, delimiter=','):
        """
        Parameters
        ----------
        file_path: str or Path
            Path to the ground-truth / sorting CSV file.
        delimiter: str
            Field delimiter used in the CSV file. Default ','.

        Raises
        ------
        FileNotFoundError
            If ``file_path`` is not an existing file.
        """
        assert HAVE_SBEX, self.installation_mesg
        SortingExtractor.__init__(self)

        if not os.path.isfile(file_path):
            raise FileNotFoundError('the ground truth file "{}" could not be found'.format(file_path))

        self._spike_clusters = sbio.SpikeClusters()
        self._spike_clusters.fromCSV(file_path, None, delimiter=delimiter)
        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'delimiter': delimiter}

    def get_unit_ids(self):
        """Return the unit ids as a list (not a dict view, so it can be indexed
        and converted to a numpy array)."""
        return list(self._spike_clusters.keys())

    @check_valid_unit_id
    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):
        """Return the spike frames of ``unit_id`` within [start_frame, end_frame).

        A missing ``start_frame`` defaults to 0 and a missing ``end_frame`` to
        no upper bound.
        """
        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)
        train = self._spike_clusters[unit_id].get_actual_spike_train().spikes

        if start_frame is None:
            start_frame = 0
        if end_frame is None:
            # np.inf (not the np.Inf alias, which NumPy 2.0 removed)
            end_frame = np.inf

        return train[(start_frame <= train) & (train < end_frame)]

    @staticmethod
    def write_sorting(sorting, save_path):
        """ Convert and save the sorting extractor to SHYBRID CSV format

        Writes ``initial_sorting.csv`` into ``save_path``, one
        ``unit_id,spike_frame`` row per spike, formatted as integers.

        Parameters
        ----------
        sorting : SortingExtractor
            The sorting extractor to be converted and saved
        save_path : str
            Full path to the desired target folder (created if it does not exist)
        """
        assert HAVE_SBEX, SHYBRIDSortingExtractor.installation_mesg

        # Collect per-unit (unit_id, spike) chunks and concatenate once, instead
        # of re-concatenating the growing array for every unit.
        chunks = [np.empty((0, 2))]
        for unit_id in sorting.get_unit_ids():
            spikes = np.asarray(sorting.get_unit_spike_train(unit_id))
            unit_column = np.full(spikes.size, unit_id, dtype=float)
            chunks.append(np.column_stack((unit_column, spikes.ravel())))
        dump = np.concatenate(chunks, axis=0)

        os.makedirs(save_path, exist_ok=True)
        sorting_fn = os.path.join(save_path, 'initial_sorting.csv')
        np.savetxt(sorting_fn, dump, delimiter=',', fmt='%i')


class GeometryNotLoadedError(Exception):
    """ Raised when the recording extractor has no associated channel locations
    """
    pass


# YAML parameter file consumed by SHYBRID; ``data`` holds the keys read back by
# SHYBRIDRecordingExtractor (dtype, fs, order, probe).
params_template = \
"""clusters:
  csv: {initial_sorting_fn}
data:
  dtype: {data_type}
  fs: {sampling_frequency}
  order: {byte_ordering}
  probe: {probe_fn}
"""