"""
Waveform projection (feature extraction) methods for the catalogue constructor.

Each projector class fits a dimensionality reduction on a set of waveforms and
exposes:
  * ``get_features(catalogueconstructor)``: features for all ``some_peaks_index``
  * ``transform(waveforms)``: project new waveforms
  * ``channel_to_features``: bool array of shape (nb_channel, nb_features)
    mapping which feature columns relate to which channel.
"""

import time

import numpy as np
import joblib

import sklearn
import sklearn.decomposition
import sklearn.cluster
import sklearn.manifold
import sklearn.discriminant_analysis

from . import tools


def project_waveforms(method='pca_by_channel', catalogueconstructor=None, selection=None, **params):
    """
    Project waveforms to a feature space with the given method.

    If selection is None then the fit of the projector is done on all
    some_peaks_index. Otherwise the fit is done on a subset of waveforms
    forced by the selection bool mask.

    Parameters
    ----------
    method : str
        One of 'global_pca', 'peak_max', 'pca_by_channel', 'global_lda'.
    catalogueconstructor : CatalogueConstructor
        Catalogue constructor holding waveforms and peaks.
    selection : bool array or None
        Mask of size all_peaks selecting which peaks are used for the fit.
    **params
        Extra keyword arguments forwarded to the projector constructor.

    Returns
    -------
    features : array
        Features for all some_peaks_index.
    channel_to_features : bool array
        (nb_channel, nb_features) mapping.
    projector : object
        The fitted projector instance.
    """
    cc = catalogueconstructor

    if method == 'global_pca':
        projector = GlobalPCA(catalogueconstructor=cc, selection=selection, **params)
    elif method == 'peak_max':
        projector = PeakMaxOnChannel(catalogueconstructor=cc, selection=selection, **params)
    elif method == 'pca_by_channel':
        projector = PcaByChannel(catalogueconstructor=cc, selection=selection, **params)
    #~ elif method=='neighborhood_pca':
    #~     projector = NeighborhoodPca(waveforms, catalogueconstructor=catalogueconstructor, **params)
    elif method == 'global_lda':
        projector = GlobalLDA(catalogueconstructor=cc, selection=selection, **params)
    else:
        # include the offending method name for easier debugging
        raise NotImplementedError(method)

    features = projector.get_features(catalogueconstructor)
    channel_to_features = projector.channel_to_features

    return features, channel_to_features, projector


class GlobalPCA:
    """
    One decomposition fitted on flattened waveforms across all channels.

    NOTE: despite the name, this uses TruncatedSVD, which does not center
    the data, so it is not a strict PCA.
    """

    def __init__(self, catalogueconstructor=None, selection=None, n_components=5, **params):
        cc = catalogueconstructor

        self.n_components = n_components

        # All waveforms are kept for get_features(), even when the fit
        # itself is restricted to a selection subset.
        self.waveforms = cc.get_some_waveforms()
        if selection is None:
            waveforms = self.waveforms
        else:
            peaks_index, = np.nonzero(selection)
            waveforms = cc.get_some_waveforms(peaks_index=peaks_index)

        flatten_waveforms = waveforms.reshape(waveforms.shape[0], -1)
        #~ self.pca = sklearn.decomposition.IncrementalPCA(n_components=n_components, **params)
        self.pca = sklearn.decomposition.TruncatedSVD(n_components=n_components, **params)
        self.pca.fit(flatten_waveforms)

        # In GlobalPCA every feature relates to all channels.
        self.channel_to_features = np.ones((cc.nb_channel, self.n_components), dtype='bool')

    def get_features(self, catalogueconstructor):
        features = self.transform(self.waveforms)
        # Free memory: the cached waveforms are no longer needed.
        del self.waveforms
        return features

    def transform(self, waveforms):
        flatten_waveforms = waveforms.reshape(waveforms.shape[0], -1)
        return self.pca.transform(flatten_waveforms)


class PeakMaxOnChannel:
    """
    Feature = amplitude at the peak sample; one feature per channel.
    """

    def __init__(self, catalogueconstructor=None, selection=None, **params):
        if selection is not None:
            # selection has no effect here: there is nothing to fit
            print('selection with PeakMaxOnChannel is a non sens')
        cc = catalogueconstructor

        # TODO something faster with only the max!!!!!
        self.waveforms = cc.get_some_waveforms()
        # Index of the peak sample inside the waveform window
        # (n_left is negative, so -n_left is the peak position).
        self.ind_peak = -catalogueconstructor.info['waveform_extractor_params']['n_left']

        # In PeakMaxOnChannel one feature maps to exactly one channel.
        self.channel_to_features = np.eye(cc.nb_channel, dtype='bool')

    def get_features(self, catalogueconstructor):
        features = self.transform(self.waveforms)
        # Free memory: the cached waveforms are no longer needed.
        del self.waveforms
        return features

    def transform(self, waveforms):
        features = waveforms[:, self.ind_peak, :].copy()
        return features


#~ def get_pca_one_channel(wf_chan, chan, thresh, n_left, n_components_by_channel, params):
#~     print(chan)
#~     pca = sklearn.decomposition.IncrementalPCA(n_components=n_components_by_channel, **params)
#~     wf_chan = waveforms[:,:,chan]
#~     keep = (wf_chan[:, -n_left]>thresh) | (wf_chan[:, -n_left]<-thresh)
#~     if keep.sum() >=n_components_by_channel:
#~         pca.fit(wf_chan[keep, :])
#~         return pca
#~     else:
#~         return None


class PcaByChannel:
    """
    One TruncatedSVD per channel, each fitted on the waveforms whose
    detected peak is on that channel.
    """

    def __init__(self, catalogueconstructor=None, selection=None,
                 n_components_by_channel=3, adjacency_radius_um=200, **params):
        cc = catalogueconstructor

        # NOTE(review): thresh and n_left are read but never used below;
        # kept so a malformed cc.info still raises KeyError early — confirm
        # whether they can be removed.
        thresh = cc.info['peak_detector_params']['relative_threshold']
        n_left = cc.info['waveform_extractor_params']['n_left']

        self.dtype = cc.info['internal_dtype']
        self.n_components_by_channel = n_components_by_channel
        self.adjacency_radius_um = adjacency_radius_um

        if selection is None:
            peaks_index = cc.some_peaks_index
        else:
            peaks_index, = np.nonzero(selection)
        some_peaks = cc.all_peaks[peaks_index]

        self.pcas = []
        for chan in range(cc.nb_channel):
            # Fit on waveforms whose peak channel is this channel only.
            sel = some_peaks['channel'] == chan
            wf_chan = cc.get_some_waveforms(peaks_index=peaks_index[sel], channel_indexes=[chan])
            wf_chan = wf_chan[:, :, 0]
            # Fit only when there are enough waveforms for the decomposition.
            if wf_chan.shape[0] - 1 > n_components_by_channel:
                #~ pca = sklearn.decomposition.IncrementalPCA(n_components=n_components_by_channel, **params)
                pca = sklearn.decomposition.TruncatedSVD(n_components=n_components_by_channel, **params)
                pca.fit(wf_chan)
            else:
                pca = None
            self.pcas.append(pca)

        #~ n_jobs = -1
        #~ self.pcas = joblib.Parallel(n_jobs=n_jobs)(joblib.delayed(get_pca_one_channel)(waveforms, chan, thresh, n_components_by_channel, params) for chan in range(cc.nb_channel))

        # n_components_by_channel consecutive feature columns map to one channel.
        self.channel_to_features = np.zeros((cc.nb_channel, cc.nb_channel * n_components_by_channel), dtype='bool')
        for c in range(cc.nb_channel):
            self.channel_to_features[c, c * n_components_by_channel:(c + 1) * n_components_by_channel] = True

    def get_features(self, catalogueconstructor):
        cc = catalogueconstructor

        nb = cc.some_peaks_index.size
        n = self.n_components_by_channel
        features = np.zeros((nb, cc.nb_channel * self.n_components_by_channel), dtype=self.dtype)

        some_peaks = cc.all_peaks[cc.some_peaks_index]

        if cc.mode == 'sparse':
            # Sparse mode relies on geometrical peak detection to define
            # channel adjacency.
            assert cc.info['peak_detector_params']['method'] == 'geometrical'
            channel_adjacency = cc.dataio.get_channel_adjacency(
                chan_grp=cc.chan_grp, adjacency_radius_um=self.adjacency_radius_um)

        for chan, pca in enumerate(self.pcas):
            if pca is None:
                continue
            if cc.mode == 'dense':
                wf_chan = cc.get_some_waveforms(peaks_index=cc.some_peaks_index, channel_indexes=[chan])
                wf_chan = wf_chan[:, :, 0]
                features[:, chan * n:(chan + 1) * n] = pca.transform(wf_chan)
            elif cc.mode == 'sparse':
                # Only peaks whose detected channel is adjacent to chan
                # contribute; other rows keep their zero initialization.
                sel = np.in1d(some_peaks['channel'], channel_adjacency[chan])
                wf_chan = cc.get_some_waveforms(peaks_index=cc.some_peaks_index[sel], channel_indexes=[chan])
                wf_chan = wf_chan[:, :, 0]
                features[:, chan * n:(chan + 1) * n][sel, :] = pca.transform(wf_chan)

        return features

    def transform(self, waveforms):
        n = self.n_components_by_channel
        # 'out' (was 'all', shadowing the builtin) collects per-channel features.
        out = np.zeros((waveforms.shape[0], waveforms.shape[2] * n), dtype=self.dtype)
        for c, pca in enumerate(self.pcas):
            if pca is None:
                continue
            out[:, c * n:(c + 1) * n] = pca.transform(waveforms[:, :, c])
        return out


class GlobalLDA:
    """
    Supervised LDA on flattened waveforms, using existing cluster labels.
    """

    def __init__(self, catalogueconstructor=None, selection=None, **params):
        cc = catalogueconstructor

        # All waveforms are kept for get_features().
        self.waveforms = cc.get_some_waveforms()
        if selection is None:
            # LDA needs labels, so a selection of labeled peaks is mandatory.
            raise NotImplementedError
        else:
            peaks_index, = np.nonzero(selection)
            waveforms = cc.get_some_waveforms(peaks_index=peaks_index)
            labels = cc.all_peaks[peaks_index]['cluster_label']

        flatten_waveforms = waveforms.reshape(waveforms.shape[0], -1)
        self.lda = sklearn.discriminant_analysis.LinearDiscriminantAnalysis()
        self.lda.fit(flatten_waveforms, labels)

        # In GlobalLDA every feature relates to all channels.
        # NOTE(review): _max_components is a private sklearn attribute and may
        # change across sklearn versions — confirm.
        self.channel_to_features = np.ones((cc.nb_channel, self.lda._max_components), dtype='bool')

    def get_features(self, catalogueconstructor):
        features = self.transform(self.waveforms)
        # Free memory: the cached waveforms are no longer needed.
        del self.waveforms
        return features

    def transform(self, waveforms):
        flatten_waveforms = waveforms.reshape(waveforms.shape[0], -1)
        return self.lda.transform(flatten_waveforms)


#~ class NeighborhoodPca:
#~     def __init__(self, waveforms, catalogueconstructor=None, n_components_by_neighborhood=6, radius_um=300., **params):
#~         cc = catalogueconstructor
#~         self.n_components_by_neighborhood = n_components_by_neighborhood
#~         self.neighborhood = tools.get_neighborhood(cc.geometry, radius_um)
#~         self.pcas = []
#~         for c in range(cc.nb_channel):
#~             neighbors = self.neighborhood[c, :]
#~             pca = sklearn.decomposition.IncrementalPCA(n_components=n_components_by_neighborhood, **params)
#~             wfs = waveforms[:,:,neighbors]
#~             wfs = wfs.reshape(wfs.shape[0], -1)
#~             pca.fit(wfs)
#~             self.pcas.append(pca)
#~         #In full NeighborhoodPca n_components_by_neighborhood feature correspond to one channel
#~         self.channel_to_features = np.zeros((cc.nb_channel, cc.nb_channel*n_components_by_neighborhood), dtype='bool')
#~         for c in range(cc.nb_channel):
#~             self.channel_to_features[c, c*n_components_by_neighborhood:(c+1)*n_components_by_neighborhood] = True

#~     def transform(self, waveforms):
#~         n = self.n_components_by_neighborhood
#~         all = np.zeros((waveforms.shape[0], waveforms.shape[2]*n), dtype=waveforms.dtype)
#~         for c, pca in enumerate(self.pcas):
#~             neighbors = self.neighborhood[c, :]
#~             wfs = waveforms[:,:,neighbors]
#~             wfs = wfs.reshape(wfs.shape[0], -1)
#~             all[:, c*n:(c+1)*n] = pca.transform(wfs)
#~         return all