from .myqt import QT
import pyqtgraph as pg

import numpy as np
import matplotlib.cm
import matplotlib.colors
import time

from .base import WidgetBase
from .tools import ParamDialog
from ..tools import median_mad
from .. import labelcodes


class MyViewBox(pg.ViewBox):
    """ViewBox that emits a signal on double click and turns the mouse wheel
    into a multiplicative gain signal instead of zooming."""
    doubleclicked = QT.pyqtSignal()
    gain_zoom = QT.pyqtSignal(float)

    def mouseDoubleClickEvent(self, ev):
        self.doubleclicked.emit()
        ev.accept()

    def wheelEvent(self, ev, axis=None):
        # x10 / x0.1 steps with Ctrl held, x1.3 / x(1/1.3) otherwise
        if ev.modifiers() == QT.Qt.ControlModifier:
            z = 10 if ev.delta() > 0 else 1 / 10.
        else:
            z = 1.3 if ev.delta() > 0 else 1 / 1.3
        self.gain_zoom.emit(z)
        ev.accept()

    def raiseContextMenu(self, ev):
        # For some reason enableMenu=False is not honoured (pyqtgraph bug?),
        # so the context menu is suppressed by overriding this method.
        pass


class WaveformHistViewer(WidgetBase):
    """
    **Waveform histogram viewer** is also an important tool.

    It is equivalent to the **Waveform viewer** in **flatten** mode, but with
    a 2D histogram that shows the density (probability) of a cluster.

    Waveforms are flattened from (nb_peak, nb_sample, nb_channel) to
    (nb_peak, nb_channel*nb_sample) and binned into a 2D histogram, which is
    then plotted as a map where the color codes the density.
    The median of each displayed cluster is drawn on top.

    This is the best tool to see whether two clusters are well discriminated
    somewhere, or whether one cluster must be split.

    Important:
      * use right click for X/Y zoom
      * use left click to move
      * use the **mouse wheel** for color zoom (hold Ctrl for x10 steps).
        It is really important to play with this to discover low-density regions.
      * intentionally, not all clusters are displayed, otherwise we would see
        nothing. The best is to plot 2 by 2; it is also faster to plot only
        a few clusters.
      * don't forget to display the **noise snippet** to validate that the
        mad is 1 for all channels.

    Settings:
      * **colormap** hot is good because low density is black, like the background.
      * **data** choose waveforms or features
      * **bin_min** lower y limit of the histogram (waveforms mode)
      * **bin_max** upper y limit of the histogram (waveforms mode)
      * **bin_size** height of one histogram bin (waveforms mode; values
        below 0.01 are clamped to 0.01)
      * **display_threshold** show the threshold line (waveforms mode only)
      * **max_label** maximum number of labels displayed simultaneously
        (2 by default but you can set more)
      * **n_spike_for_centroid** maximum number of spikes randomly sampled
        per cluster to build the histogram
      * **sparse_display** when the controller has sparse templates, only
        show the channels common to all displayed clusters
    """

    _params = [
        {'name': 'colormap', 'type': 'list', 'values' : ['hot', 'viridis', 'jet', 'gray', ] },
        {'name': 'data', 'type': 'list', 'values' : ['waveforms', 'features', ] },
        {'name': 'bin_min', 'type': 'float', 'value' : -20. },
        {'name': 'bin_max', 'type': 'float', 'value' : 8. },
        {'name': 'bin_size', 'type': 'float', 'value' : .1 },
        {'name': 'display_threshold', 'type': 'bool', 'value' : True },
        {'name': 'max_label', 'type': 'int', 'value' : 2 },
        {'name': 'n_spike_for_centroid', 'type': 'int', 'value' : 300 },
        {'name': 'sparse_display', 'type': 'bool', 'value' : True },
    ]

    def __init__(self, controller=None, parent=None):
        WidgetBase.__init__(self, parent=parent, controller=controller)

        self.layout = QT.QVBoxLayout()
        self.setLayout(self.layout)

        h = QT.QHBoxLayout()
        self.layout.addLayout(h)
        but = QT.QPushButton('Show 1D dist', checkable=True)
        h.addWidget(but)
        but.clicked.connect(self.show_hide_1d_dist)

        self.graphicsview = pg.GraphicsView()
        self.layout.addWidget(self.graphicsview)

        # secondary view toggled by the 'Show 1D dist' button, hidden by default
        self.graphicsview2 = pg.GraphicsView()
        self.layout.addWidget(self.graphicsview2)
        self.graphicsview2.hide()

        self.create_settings()

        self.initialize_plot()
        self.similarity = None

        self.on_params_changed()  # this also calls refresh()

    @staticmethod
    def _make_lut(cmap_name, n):
        """Return an (n, 3) uint8 RGB lookup table sampled from a matplotlib colormap."""
        colormaps = getattr(matplotlib, 'colormaps', None)
        if colormaps is not None and hasattr(colormaps[cmap_name], 'resampled'):
            cmap = colormaps[cmap_name].resampled(n)
        else:
            # older matplotlib; matplotlib.cm.get_cmap was removed in 3.9
            cmap = matplotlib.cm.get_cmap(cmap_name, n)
        rgba = cmap(np.arange(n))
        return (rgba[:, :3] * 255).astype('uint8')

    def on_params_changed(self):
        """Rebuild the color lookup table from the 'colormap' setting, reset the
        stored view ranges and refresh the plot."""
        N = 512
        self.lut = self._make_lut(self.params['colormap'], N)
        self._x_range = None
        self._y_range = None
        self.refresh()

    def initialize_plot(self):
        """Create the view box, plot, image item and threshold line, and derive
        the initial bin_min/bin_max/bin_size settings from the centroid extrema."""
        self.viewBox = MyViewBox()
        self.viewBox.doubleclicked.connect(self.open_settings)
        self.viewBox.gain_zoom.connect(self.gain_zoom)
        self.viewBox.disableAutoRange()

        self.plot = pg.PlotItem(viewBox=self.viewBox)
        self.graphicsview.setCentralItem(self.plot)
        self.plot.hideButtons()

        self.image = pg.ImageItem()
        self.plot.addItem(self.image)

        # median curves of the displayed clusters, rebuilt on every refresh
        self.curves = []

        thresh = self.controller.get_threshold()
        self.thresh_line = pg.InfiniteLine(pos=thresh, angle=0, movable=False, pen=pg.mkPen('w'))
        self.plot.addItem(self.thresh_line)

        # Block param-change signals while the bin range is set programmatically;
        # try/finally guarantees they are re-enabled even if the controller raises.
        # (Percentile-based ranges over all waveforms were tried but were too slow
        # and used too much memory, hence the centroid extrema.)
        self.params.blockSignals(True)
        try:
            self.wf_min, self.wf_max = self.controller.get_min_max_centroids()
            self.params['bin_min'] = min(self.wf_min * 2, -5.)
            self.params['bin_max'] = max(self.wf_max * 2, 5)
            if (self.params['bin_max'] - self.params['bin_min']) < 60:
                self.params['bin_size'] = 0.1
            else:
                # keep about 600 bins for wide ranges
                self.params['bin_size'] = (self.params['bin_max'] - self.params['bin_min']) / 600
        finally:
            self.params.blockSignals(False)

    def gain_zoom(self, v):
        """Multiply the current image color levels by ``v`` (wheel color zoom)."""
        levels = self.image.getLevels()
        if levels is not None:
            self.image.setLevels(levels * v, update=True)

    @staticmethod
    def _add_to_hist2d(hist2d, data, bin_min, bin_size):
        """Accumulate every row of ``data`` (n_rows, n_columns) into ``hist2d``
        (n_columns, n_bins) in place; out-of-range values go to the edge bins."""
        n_bins = hist2d.shape[1]
        data_bined = np.floor((data - bin_min) / bin_size).astype('int32')
        data_bined = data_bined.clip(0, n_bins - 1)
        columns = np.arange(hist2d.shape[0])
        # unbuffered add: repeated (column, bin) pairs are all counted
        np.add.at(hist2d, (columns, data_bined), 1)

    def refresh(self):
        """Recompute and redraw the 2D density histogram of the visible clusters.

        The image is hidden (not removed from the plot) when there is nothing
        to show, so later refreshes can display it again.
        """
        if not hasattr(self, 'viewBox'):
            self.initialize_plot()
        if not hasattr(self, 'viewBox'):
            return

        cluster_visible = self.controller.cluster_visible
        visibles = [k for k, v in cluster_visible.items() if v and k >= -1]

        sparse = self.controller.have_sparse_template and self.params['sparse_display']

        # channels common to all visible clusters
        if sparse:
            if len(visibles) > 0:
                common_channels = self.controller.get_common_sparse_channels(visibles)
            else:
                common_channels = np.array([], dtype='int64')
        else:
            common_channels = self.controller.channels

        # remove old median curves
        for curve in self.curves:
            self.plot.removeItem(curve)
        self.curves = []

        if len(visibles) > self.params['max_label'] or len(visibles) == 0:
            self.image.hide()
            return

        if len(common_channels) == 0:
            self.image.hide()
            return

        # randomly keep at most n_spike_for_centroid spikes per visible cluster
        keep = np.zeros(self.controller.spikes.size, dtype='bool')
        for label in visibles:
            ind_keep, = np.nonzero(self.controller.spikes['cluster_label'] == label)
            if ind_keep.size > self.params['n_spike_for_centroid']:
                sub_sel = np.random.choice(ind_keep.size, self.params['n_spike_for_centroid'], replace=False)
                ind_keep = ind_keep[sub_sel]
            keep[ind_keep] = True
        ind_keep, = np.nonzero(keep)

        if self.params['data'] == 'waveforms':
            if ind_keep.size == 0:
                self.image.hide()
                return
            seg_nums = self.controller.spikes['segment'][ind_keep]
            peak_sample_indexes = self.controller.spikes['index'][ind_keep]
            data_kept = self.controller.get_some_waveforms(seg_nums, peak_sample_indexes,
                                                           channel_indexes=common_channels)
            if data_kept.size == 0:
                self.image.hide()
                return
            # flatten channel by channel: presumably (n_spike, n_sample, n_channel)
            # -> (n_spike, n_channel * n_sample) as described in the class docstring
            data_kept = data_kept.swapaxes(1, 2).reshape(data_kept.shape[0], -1)

        elif self.params['data'] == 'features':
            data = self.controller.some_features
            if data is None:
                self.image.hide()
                return
            labels = self.controller.spike_label[self.controller.some_peaks_index]
            keep = np.isin(labels, visibles)
            ind_keep, = np.nonzero(keep)
            nb_feature_by_channel = data.shape[1] // self.controller.nb_channel
            mask_feat = np.zeros(data.shape[1], dtype='bool')
            for i in range(nb_feature_by_channel):
                mask_feat[common_channels * nb_feature_by_channel + i] = True
            data_kept = data[ind_keep, :][:, mask_feat]
            if data_kept.size == 0:
                self.image.hide()
                return

        # TODO: change for PCA
        if self.params['data'] == 'waveforms':
            bin_min, bin_max = self.params['bin_min'], self.params['bin_max']
            # The clamped size must be used both to build the bins and to
            # digitize the data, otherwise rows and bins disagree.
            bin_size = max(self.params['bin_size'], 0.01)
            bins = np.arange(bin_min, bin_max, bin_size)
        elif self.params['data'] == 'features':
            bin_min, bin_max = np.min(data_kept), np.max(data_kept)
            bins = np.linspace(bin_min, bin_max, 500)
            bin_size = bins[1] - bins[0]

        if bins.size == 0:
            # bin_min >= bin_max: no bin to fill
            self.image.hide()
            return

        hist2d = np.zeros((data_kept.shape[1], bins.size))
        indexes0 = np.arange(data_kept.shape[1])
        self._add_to_hist2d(hist2d, data_kept, bin_min, bin_size)

        # noise snippets (catalogue window only)
        if self.controller.cluster_visible.get(labelcodes.LABEL_NOISE, False) and \
                self.controller.some_noise_snippet is not None:
            if self.params['data'] == 'waveforms':
                noise = self.controller.some_noise_snippet[:, :, common_channels]
                noise = noise.swapaxes(1, 2).reshape(noise.shape[0], -1)
                self._add_to_hist2d(hist2d, noise, bin_min, bin_size)

        self.image.setImage(hist2d, lut=self.lut)
        # bin j spans [bin_min + j*bin_size, bin_min + (j+1)*bin_size), so the
        # image must be bins.size * bin_size tall to line up with the curves
        self.image.setRect(QT.QRectF(-0.5, bin_min, data_kept.shape[1], bins.size * bin_size))
        self.image.show()

        # overlay the median of each visible cluster
        for k in visibles:
            median, _ = self.controller.get_waveform_centroid(k, 'median', channels=common_channels)
            if median is None:
                continue
            if self.params['data'] == 'waveforms':
                y = median.T.flatten()
            else:
                sel = labels[ind_keep] == k
                y = np.median(data_kept[sel, :], axis=0)
            color = self.controller.qcolors.get(k, QT.QColor('white'))
            curve = pg.PlotCurveItem(x=indexes0, y=y, pen=pg.mkPen(color, width=2))
            self.plot.addItem(curve)
            self.curves.append(curve)

        if self.params['display_threshold'] and self.params['data'] == 'waveforms':
            self.thresh_line.show()
        else:
            self.thresh_line.hide()

        # the view range is reset to the full histogram on every refresh
        self._x_range = 0, indexes0[-1]
        self._y_range = bin_min, bin_max
        self.plot.setXRange(*self._x_range, padding=0.0)
        self.plot.setYRange(*self._y_range, padding=0.0)

    def on_spike_selection_changed(self):
        pass

    def on_spike_label_changed(self):
        self.refresh()

    def on_colors_changed(self):
        self.refresh()

    def on_cluster_visibility_changed(self):
        self.refresh()

    def on_cluster_tag_changed(self):
        pass

    def show_hide_1d_dist(self, v=None):
        """Show or hide the secondary view according to the button's checked state ``v``."""
        if v:
            self.graphicsview2.show()
        else:
            self.graphicsview2.hide()