""" Some heuristic method to clean after clustering: * auto-split * auto-merge """ import numpy as np import os import time import sklearn import sklearn.cluster import sklearn.mixture import sklearn.metrics import sklearn.decomposition from joblib import Parallel, delayed import matplotlib.pyplot as plt from .dip import diptest from .waveformtools import equal_template import hdbscan debug_plot = False #~ debug_plot = True def _get_sparse_waveforms_flatten(cc, dense_mode, label, channel_adjacency, n_spike_for_centroid=None): peak_index, = np.nonzero(cc.all_peaks['cluster_label'] == label) if n_spike_for_centroid is not None and peak_index.size>n_spike_for_centroid: keep = np.random.choice(peak_index.size, n_spike_for_centroid, replace=False) peak_index = peak_index[keep] if dense_mode: waveforms = cc.get_some_waveforms(peak_index, channel_indexes=None) extremum_channel = 0 centroid = np.median(waveforms, axis=0) else: waveforms = cc.get_some_waveforms(peak_index, channel_indexes=None) centroid = np.median(waveforms, axis=0) peak_sign = cc.info['peak_detector_params']['peak_sign'] n_left = cc.info['waveform_extractor_params']['n_left'] if peak_sign == '-': extremum_channel = np.argmin(centroid[-n_left,:], axis=0) elif peak_sign == '+': extremum_channel = np.argmax(centroid[-n_left,:], axis=0) # TODO by sparsity level threhold and not radius adjacency = channel_adjacency[extremum_channel] waveforms = waveforms.take(adjacency, axis=2) wf_flat = waveforms.swapaxes(1,2).reshape(waveforms.shape[0], -1) return waveforms, wf_flat, peak_index def _compute_one_dip_test(cc, dirname, chan_grp, label, n_components_local_pca, adjacency_radius_um): # compute dip test to try to over split from .dataio import DataIO from .catalogueconstructor import CatalogueConstructor if cc is None: dataio = DataIO(dirname) cc = CatalogueConstructor(dataio=dataio, chan_grp=chan_grp) peak_sign = cc.info['peak_detector_params']['peak_sign'] dense_mode = cc.info['mode'] == 'dense' n_left = cc.info['waveform_extractor_params']['n_left'] n_right = cc.info['waveform_extractor_params']['n_right'] peak_width = n_right - n_left nb_channel = cc.nb_channel if dense_mode: channel_adjacency = {c: np.arange(nb_channel) for c in range(nb_channel)} else: channel_adjacency = {} for c in range(nb_channel): nearest, = np.nonzero(cc.channel_distances[c, :] < adjacency_radius_um) channel_adjacency[c] = nearest waveforms, wf_flat, peak_index = _get_sparse_waveforms_flatten(cc, dense_mode, label, channel_adjacency, n_spike_for_centroid=cc.n_spike_for_centroid) #~ pca = sklearn.decomposition.IncrementalPCA(n_components=n_components_local_pca, whiten=True) n_components = min(wf_flat.shape[1]-1, n_components_local_pca) pca = sklearn.decomposition.TruncatedSVD(n_components=n_components) feats = pca.fit_transform(wf_flat) pval = diptest(np.sort(feats[:, 0]), numt=200) return pval def auto_split(catalogueconstructor, n_spike_for_centroid=None, adjacency_radius_um = 30, n_components_local_pca=3, pval_thresh=0.1, min_cluster_size=20, maximum_shift=2, n_jobs=-1, #~ n_jobs=1, joblib_backend='loky', ): cc = catalogueconstructor peak_sign = cc.info['peak_detector_params']['peak_sign'] dense_mode = cc.info['mode'] == 'dense' n_left = cc.info['waveform_extractor_params']['n_left'] n_right = cc.info['waveform_extractor_params']['n_right'] peak_width = n_right - n_left nb_channel = cc.nb_channel if dense_mode: channel_adjacency = {c: np.arange(nb_channel) for c in range(nb_channel)} else: channel_adjacency = {} for c in range(nb_channel): nearest, = np.nonzero(cc.channel_distances[c, :] < adjacency_radius_um) channel_adjacency[c] = nearest if len(cc.positive_cluster_labels) ==0: return m = np.max(cc.positive_cluster_labels) + 1 # pvals = [] # for label in cc.positive_cluster_labels: # pval = _compute_one_dip_test(cc.dataio.dirname, cc.chan_grp, label, n_components_local_pca, adjacency_radius_um) # print('label', label,'pval', pval, pval<pval_thresh) # pvals.append(pval) if cc.memory_mode == 'ram': # prevent paralell because not persistent n_jobs = 1 cc2 = cc else: cc2 = None pvals = Parallel(n_jobs=n_jobs, backend=joblib_backend)( delayed(_compute_one_dip_test)(cc2, cc.dataio.dirname, cc.chan_grp, label, n_components_local_pca, adjacency_radius_um) for label in cc.positive_cluster_labels) pvals = np.array(pvals) inds, = np.nonzero(pvals<pval_thresh) splitable_labels = cc.positive_cluster_labels[inds] #~ print('splitable_labels', splitable_labels) for label in splitable_labels: waveforms, wf_flat, peak_index = _get_sparse_waveforms_flatten(cc, dense_mode, label, channel_adjacency, n_spike_for_centroid=None) #~ pca = sklearn.decomposition.IncrementalPCA(n_components=n_components_local_pca, whiten=True) n_components = min(wf_flat.shape[1]-1, n_components_local_pca) pca = sklearn.decomposition.TruncatedSVD(n_components=n_components) feats = pca.fit_transform(wf_flat) clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, allow_single_cluster=False, metric='l2') sub_labels = clusterer.fit_predict(feats[:, :2]) unique_sub_labels = np.unique(sub_labels) #~ print(unique_sub_labels) if unique_sub_labels.size == 1 and unique_sub_labels[0] == -1: sub_labels[:] = 0 unique_sub_labels = np.unique(sub_labels) if not dense_mode: peak_is_aligned = check_peak_all_aligned(sub_labels, waveforms, peak_sign, n_left, maximum_shift) if debug_plot: fig, ax= plt.subplots() ax.plot(np.median(wf_flat, axis=0)) ax.set_title('label '+str(label)) for i in range(waveforms.shape[2]): ax.axvline(i*peak_width-n_left, color='k') ax.set_title('pval ' + str(pval)) plt.show() if unique_sub_labels.size >1: for i, sub_label in enumerate(unique_sub_labels): sub_mask = sub_labels == sub_label if dense_mode: valid=True else: valid = peak_is_aligned[i] #~ print('sub_label', 'valid', valid) if sub_label == -1 or not valid: #~ cluster_labels[ind_keep[sub_mask]] = -1 cc.all_peaks['cluster_label'][peak_index[sub_mask]] = -1 else: #~ cluster_labels[ind_keep[sub_mask]] = sub_label + m new_label = label + m #~ print(label, m, new_label) cc.all_peaks['cluster_label'][peak_index[sub_mask]] = new_label cc.add_one_cluster(new_label) m += 1 cc.pop_labels_from_cluster([label]) #~ m += np.max(unique_sub_labels) + 1 #~ if True: #~ if False: if debug_plot: print('label', label,'pval', pval, pval<pval_thresh) fig, axs = plt.subplots(ncols=4) colors = plt.cm.get_cmap('Set3', len(unique_sub_labels)) colors = {unique_sub_labels[l]:colors(l) for l in range(len(unique_sub_labels))} colors[-1] = 'k' for sub_label in unique_sub_labels: #~ valid = sub_label in possible_labels[peak_is_aligned] if dense_mode: valid=True else: valid = peak_is_aligned[i] sub_mask = sub_labels == sub_label if valid: ls = '-' color = colors[sub_label] else: ls = '--' color = 'k' ax = axs[0] ax.plot(wf_flat[sub_mask].T, color=color, alpha=0.1) ax = axs[3] if sub_label>=0: ax.plot(np.median(wf_flat[sub_mask], axis=0), color=color, lw=2, ls=ls) for sub_label in unique_sub_labels: if dense_mode: valid=True else: valid = peak_is_aligned[i] sub_mask = sub_labels == sub_label color = colors[sub_label] if valid: color = colors[sub_label] else: color = 'k' ax = axs[1] ax.plot(feats[sub_mask].T, color=color, alpha=0.1) ax = axs[2] ax.scatter(feats[sub_mask][:, 0], feats[sub_mask][:, 1], color=color) plt.show() def check_peak_all_aligned(local_labels, waveforms, peak_sign, n_left, maximum_shift): peak_is_aligned = [] for k in np.unique(local_labels): wfs = waveforms[local_labels == k] centroid = np.median(wfs, axis=0) if peak_sign == '-': chan_peak_local = np.argmin(np.min(centroid, axis=0)) pos_peak = np.argmin(centroid[:, chan_peak_local]) elif peak_sign == '+': chan_peak_local = np.argmax(np.max(centroid, axis=0)) pos_peak = np.argmax(centroid[:, chan_peak_local]) al = np.abs(-n_left - pos_peak) <= maximum_shift peak_is_aligned.append(al) return np.array(peak_is_aligned) def trash_not_aligned(cc, maximum_shift=2): n_left = cc.info['waveform_extractor_params']['n_left'] peak_sign = cc.info['peak_detector_params']['peak_sign'] to_remove = [] for k in list(cc.positive_cluster_labels): #~ print(k) centroid = cc.get_one_centroid(k) if peak_sign == '-': chan_peak = np.argmin(np.min(centroid, axis=0)) extremum_index = np.argmin(centroid[:, chan_peak]) peak_val = centroid[-n_left, chan_peak] elif peak_sign == '+': chan_peak = np.argmax(np.max(centroid, axis=0)) extremum_index = np.argmax(centroid[:, chan_peak]) peak_val = centroid[-n_left, chan_peak] if np.abs(-n_left - extremum_index)>maximum_shift: if debug_plot: n_left = cc.info['waveform_extractor_params']['n_left'] n_right = cc.info['waveform_extractor_params']['n_right'] peak_width = n_right - n_left print('remove not aligned peak', 'k', k) fig, ax = plt.subplots() #~ centroid = centroids[k] ax.plot(centroid.T.flatten()) ax.set_title('not aligned peak') for i in range(centroid.shape[1]): ax.axvline(i*peak_width-n_left, color='k') plt.show() mask = cc.all_peaks['cluster_label'] == k cc.all_peaks['cluster_label'][mask] = -1 to_remove.append(k) cc.pop_labels_from_cluster(to_remove) def auto_merge(catalogueconstructor, auto_merge_threshold=2.3, maximum_shift=2, amplitude_factor_thresh = 0.2, ): cc = catalogueconstructor peak_sign = cc.info['peak_detector_params']['peak_sign'] #~ dense_mode = cc.info['mode'] == 'dense' n_left = cc.info['waveform_extractor_params']['n_left'] n_right = cc.info['waveform_extractor_params']['n_right'] peak_width = n_right - n_left threshold = cc.info['peak_detector_params']['relative_threshold'] while True: labels = cc.positive_cluster_labels.copy() nb_merge = 0 n = labels.size #~ pop_from_centroids = [] new_centroids = [] pop_from_cluster = [] for i in range(n): k1 = labels[i] if k1 == -1: # this can have been removed yet continue for j in range(i+1, n): k2 = labels[j] if k2 == -1: # this can have been removed yet continue #~ print(k1, k2) #~ print(' k2', k2) ind1 = cc.index_of_label(k1) extremum_amplitude1 = np.abs(cc.clusters[ind1]['extremum_amplitude']) centroid1 = cc.get_one_centroid(k1) ind2 = cc.index_of_label(k2) extremum_amplitude2 = np.abs(cc.clusters[ind2]['extremum_amplitude']) centroid2 = cc.get_one_centroid(k2) thresh = max(extremum_amplitude1, extremum_amplitude2) * amplitude_factor_thresh thresh = max(thresh, auto_merge_threshold) #~ print('thresh', thresh) #~ t1 = time.perf_counter() do_merge = equal_template(centroid1, centroid2, thresh=thresh, n_shift=maximum_shift) #~ t2 = time.perf_counter() #~ print('equal_template', t2-t1) #~ print('do_merge', do_merge) #~ if debug_plot: #~ print(k1, k2) #~ if k1==4 and k2==5: #~ print(k1, k2, do_merge, thresh) #~ fig, ax = plt.subplots() #~ ax.plot(centroid1.T.flatten()) #~ ax.plot(centroid2.T.flatten()) #~ ax.set_title('merge ' + str(do_merge)) #~ plt.show() if do_merge: #~ print('merge', k1, k2) #~ cluster_labels2[cluster_labels2==k2] = k1 mask = cc.all_peaks['cluster_label'] == k2 cc.all_peaks['cluster_label'][mask] = k1 #~ t1 = time.perf_counter() #~ cc.compute_one_centroid(k1) #~ t2 = time.perf_counter() #~ print('cc.compute_one_centroid', t2-t1) new_centroids.append(k1) pop_from_cluster.append(k2) labels[j] = -1 nb_merge += 1 if debug_plot: fig, ax = plt.subplots() ax.plot(centroid1.T.flatten()) ax.plot(centroid2.T.flatten()) ax.set_title('merge '+str(k1)+' '+str(k2)) plt.show() #~ for k in np.unique(pop_from_cluster): #~ cc.pop_labels_from_cluster([k]) pop_from_cluster = np.unique(pop_from_cluster) cc.pop_labels_from_cluster(pop_from_cluster) new_centroids = np.unique(new_centroids) new_centroids = [k for k in new_centroids if k not in pop_from_cluster] cc.compute_several_centroids(new_centroids) #~ cc.compute_one_centroid(k) #~ for k in np.unique(pop_from_centroids): #~ if k in centroids: #~ centroids.pop(k) #~ print('nb_merge', nb_merge) if nb_merge == 0: break def trash_low_extremum(cc, min_extremum_amplitude=None): if min_extremum_amplitude is None: threshold = cc.info['peak_detector_params']['relative_threshold'] min_extremum_amplitude = threshold + 0.5 to_remove = [] for k in list(cc.positive_cluster_labels): #~ print(k) ind = cc.index_of_label(k) assert k == cc.clusters[ind]['cluster_label'], 'this is a bug in trash_low_extremum' extremum_amplitude = np.abs(cc.clusters[ind]['extremum_amplitude']) #~ print('k', k , extremum_amplitude) if extremum_amplitude < min_extremum_amplitude: if debug_plot: print('k', k , extremum_amplitude, 'too small') mask = cc.all_peaks['cluster_label']==k cc.all_peaks['cluster_label'][mask] = -1 to_remove.append(k) cc.pop_labels_from_cluster(to_remove) def trash_small_cluster(cc, minimum_size=10): to_remove = [] for k in list(cc.positive_cluster_labels): mask = cc.all_peaks['cluster_label']==k cluster_size = np.sum(mask) #~ print(k, cluster_size) if cluster_size <= minimum_size : cc.all_peaks['cluster_label'][mask] = -1 to_remove.append(k) cc.pop_labels_from_cluster(to_remove)