from .myqt import QT
import pyqtgraph as pg

import numpy as np
import pandas as pd

from .base import WidgetBase


class MyViewBox(pg.ViewBox):
    """ViewBox whose mouse wheel changes the waveform gain instead of zooming.

    Signals:
        doubleclicked: emitted on a mouse double click.
        gain_zoom(float): emitted on every wheel event with the multiplicative
            gain factor to apply: 10 or 1/10 when Ctrl is held, 1.3 or 1/1.3
            otherwise, depending on the wheel direction.
    """
    doubleclicked = QT.pyqtSignal()
    gain_zoom = QT.pyqtSignal(float)

    def __init__(self, *args, **kwds):
        pg.ViewBox.__init__(self, *args, **kwds)

    def mouseClickEvent(self, ev):
        # Overrides the default ViewBox click handling: clicks are accepted
        # and otherwise ignored.
        ev.accept()

    def mouseDoubleClickEvent(self, ev):
        self.doubleclicked.emit()
        ev.accept()

    def wheelEvent(self, ev, axis=None):
        # The wheel does not zoom the view; it is turned into a gain factor
        # that the owning widget applies to the waveform amplitude.
        if ev.modifiers() == QT.Qt.ControlModifier:
            z = 10 if ev.delta() > 0 else 1 / 10.
        else:
            z = 1.3 if ev.delta() > 0 else 1 / 1.3
        self.gain_zoom.emit(z)
        ev.accept()


class WaveformViewerBase(WidgetBase):
    """Common base for WaveformViewer (Catalogue) and PeelerWaveformViewer.

    Plots the centroid of every visible cluster in one of two modes:
      * 'geometry': one short trace per channel, placed at the channel's
        probe (x, y) position;
      * 'flatten': channels are concatenated side by side on a single x axis,
        with the dispersion metric (mad or std) shown in a second plot below.
    """

    def __init__(self, controller=None, parent=None):
        WidgetBase.__init__(self, parent=parent, controller=controller)

        self.layout = QT.QVBoxLayout()
        self.setLayout(self.layout)

        self.create_toolbar()
        self.layout.addWidget(self.toolbar)

        self.graphicsview = pg.GraphicsView()
        self.layout.addWidget(self.graphicsview)

        self.initialize_plot()
        # Alpha passed to QColor.setAlpha for the centroid +/- dispersion band
        # drawn in flatten mode.
        self.alpha = 60
        self.refresh()

    def create_toolbar(self):
        """Build the toolbar: mode combo box plus settings/scale/refresh buttons."""
        tb = self.toolbar = QT.QToolBar()

        # Display mode: 'geometry' (default) or 'flatten'.
        self.combo_mode = QT.QComboBox()
        tb.addWidget(self.combo_mode)
        self.mode = 'geometry'
        self.combo_mode.addItems(['geometry', 'flatten'])
        self.combo_mode.currentIndexChanged.connect(self.on_combo_mode_changed)
        tb.addSeparator()

        but = QT.QPushButton('settings')
        but.clicked.connect(self.open_settings)
        tb.addWidget(but)

        but = QT.QPushButton('scale')
        but.clicked.connect(self.zoom_range)
        tb.addWidget(but)

        but = QT.QPushButton('refresh')
        but.clicked.connect(self.refresh)
        tb.addWidget(but)

    def on_combo_mode_changed(self):
        """Switch display mode: the plot layout differs, so rebuild it."""
        self.mode = str(self.combo_mode.currentText())
        self.initialize_plot()
        self.refresh()

    def on_params_changed(self, params, changes):
        """React to settings changes.

        `changes` is iterated as (param, change, data) triples (presumably the
        pyqtgraph Parameter tree change list -- confirm against WidgetBase).
        A 'flip_bottom_up' change requires rebuilding the plot because the
        flip is baked into arr_geometry by initialize_plot.
        """
        for param, change, data in changes:
            if change != 'value':
                continue
            if param.name() == 'flip_bottom_up':
                self.initialize_plot()
        self.refresh()

    def initialize_plot(self):
        """(Re)build the plot items for the current mode.

        Does nothing while the controller's waveform window is unknown
        (get_waveform_left_right()[0] is None); refresh() retries later.
        In geometry mode this also computes arr_geometry, delta_x/delta_y and
        xvect, or sets xvect to None when the channel group has no geometry.
        """
        if self.controller.get_waveform_left_right()[0] is None:
            return

        self.viewBox1 = MyViewBox()
        self.viewBox1.disableAutoRange()

        grid = pg.GraphicsLayout(border=(100, 100, 100))
        self.graphicsview.setCentralItem(grid)

        self.plot1 = grid.addPlot(row=0, col=0, rowspan=2, viewBox=self.viewBox1)
        self.plot1.hideButtons()
        self.plot1.showAxis('left', True)

        # Curve reused to overlay the single selected spike.
        self.curve_one_waveform = pg.PlotCurveItem([], [], pen=pg.mkPen(QT.QColor('white'), width=1),
                                                   connect='finite')
        self.plot1.addItem(self.curve_one_waveform)

        if self.mode == 'flatten':
            grid.nextRow()
            grid.nextRow()
            self.viewBox2 = MyViewBox()
            self.viewBox2.disableAutoRange()
            self.plot2 = grid.addPlot(row=2, col=0, rowspan=1, viewBox=self.viewBox2)
            self.plot2.hideButtons()
            self.plot2.showAxis('left', True)
            self.viewBox2.setXLink(self.viewBox1)
            self.factor_y = 1.
            self._common_channels_flat = None

        elif self.mode == 'geometry':
            self.plot2 = None

            chan_grp = self.controller.chan_grp
            channel_group = self.controller.dataio.channel_groups[chan_grp]
            if channel_group['geometry'] is None:
                print('no geometry')
                self.xvect = None
            else:
                n_left, n_right = self.controller.get_waveform_left_right()
                width = n_right - n_left

                # (x, y) position of each channel; row i <-> controller.channel_indexes[i].
                arr_geometry = []
                for chan in self.controller.channel_indexes:
                    x, y = channel_group['geometry'][chan]
                    arr_geometry.append([x, y])
                self.arr_geometry = np.array(arr_geometry, dtype='float64')
                if self.params['flip_bottom_up']:
                    self.arr_geometry[:, 1] *= -1.

                unique_x = np.unique(self.arr_geometry[:, 0])
                unique_y = np.unique(self.arr_geometry[:, 1])
                if unique_x.size > 1:
                    # smallest horizontal spacing between channel columns
                    self.delta_x = np.min(np.diff(unique_x))
                else:
                    # NOTE(review): with a single column this is the column's x
                    # coordinate, not a spacing, yet it still sets the trace
                    # width below -- confirm this is intended.
                    self.delta_x = unique_x[0]
                if unique_y.size > 1:
                    # smallest vertical spacing between channel rows
                    self.delta_y = np.min(np.diff(unique_y))
                else:
                    # NOTE(review): single row -> the row's y coordinate floored
                    # at 1 is used as the vertical scale -- confirm intended.
                    self.delta_y = max(unique_y[0], 1)
                self.factor_y = .3

                if self.delta_x > 0.:
                    espx = self.delta_x / 2.5
                else:
                    espx = .5

                # Each channel's trace spans [x - espx, x + espx]. Rows are built
                # from arr_geometry so xvect and arr_geometry stay aligned: both
                # are indexed with the same channel indices when plotting.
                self.xvect = np.zeros((self.controller.nb_channel, width), dtype='float32')
                for i, x in enumerate(self.arr_geometry[:, 0]):
                    self.xvect[i, :] = np.linspace(x - espx, x + espx, num=width)

        self.wf_min, self.wf_max = self.controller.get_min_max_centroids()

        self._x_range = None
        self._y1_range = None
        self._y2_range = None

        self.viewBox1.gain_zoom.connect(self.gain_zoom)
        self.viewBox1.doubleclicked.connect(self.open_settings)

    def gain_zoom(self, factor_ratio):
        """Multiply the waveform gain (used in geometry mode) and redraw, keeping the view."""
        self.factor_y *= factor_ratio
        self.refresh(keep_range=True)

    def zoom_range(self):
        """Forget the stored view ranges and redraw with the default ranges."""
        self._x_range = None
        self._y1_range = None
        self._y2_range = None
        self.refresh(keep_range=False)

    def refresh(self, keep_range=False):
        """Redraw everything for the current mode.

        keep_range: when True, the currently displayed view range is kept.
        """
        if not hasattr(self, 'viewBox1'):
            # initialize_plot bails out while the waveform window is unknown; retry.
            self.initialize_plot()
        if not hasattr(self, 'viewBox1'):
            return

        n_selected = np.sum(self.controller.spike_selection)

        if self.params['show_only_selected_cluster'] and n_selected == 1:
            # Show only the cluster of the single selected spike.
            cluster_visible = {k: False for k in self.controller.cluster_visible}
            ind = np.nonzero(self.controller.spike_selection)[0][0]
            k = self.controller.spikes[ind]['cluster_label']
            cluster_visible[k] = True
        else:
            cluster_visible = self.controller.cluster_visible

        if self.mode == 'flatten':
            self.refresh_mode_flatten(cluster_visible, keep_range)
        elif self.mode == 'geometry':
            self.refresh_mode_geometry(cluster_visible, keep_range)

        self._refresh_one_spike(n_selected)

    def _get_metric_keys(self):
        """Return (centroid_key, dispersion_key) for the 'metrics' setting."""
        metrics = self.params['metrics']
        if metrics == 'median/mad':
            return 'median', 'mad'
        elif metrics == 'mean/std':
            return 'mean', 'std'
        raise ValueError('unknown metrics {}'.format(metrics))

    def _add_channel_spans(self, plot, channels, n_left, width):
        """Draw per-channel delimiters on a flatten-mode plot.

        Every other channel slot (width samples each) gets a faint shaded
        region, and each slot gets a vertical line at offset -n_left inside it
        (presumably the peak position, assuming n_left <= 0).
        """
        white = pg.mkColor(255, 255, 255, 20)
        for i in range(len(channels)):
            if i % 2 == 1:
                region = pg.LinearRegionItem([width * i, width * (i + 1) - 1], movable=False, brush=white)
                plot.addItem(region, ignoreBounds=True)
                for line in region.lines:
                    line.setPen(white)
            vline = pg.InfiniteLine(pos=-n_left + width * i, angle=90, movable=False, pen=pg.mkPen('w'))
            plot.addItem(vline)

    def refresh_mode_flatten(self, cluster_visible, keep_range):
        """Draw centroids (plot1) and dispersion (plot2) with channels side by side."""
        if self._x_range is not None and keep_range:
            # this may change with pyqtgraph
            self._x_range = tuple(self.viewBox1.state['viewRange'][0])
            self._y1_range = tuple(self.viewBox1.state['viewRange'][1])
            self._y2_range = tuple(self.viewBox2.state['viewRange'][1])

        self.plot1.clear()
        self.plot2.clear()
        self.plot1.addItem(self.curve_one_waveform)
        # Reset so _refresh_one_spike never reuses the channels of a previous refresh.
        self._common_channels_flat = None

        # np.size works for lists and arrays; `array == []` raises on NumPy >= 1.25.
        if np.size(self.controller.spike_index) == 0:
            return

        n_left, n_right = self.controller.get_waveform_left_right()
        if n_left is None:
            return
        width = n_right - n_left

        sparse = self.controller.have_sparse_template and self.params['sparse_display']
        if sparse:
            visibles = [k for k, v in cluster_visible.items() if v and k >= -1]
            if len(visibles) == 0:
                return
            common_channels = self.controller.get_common_sparse_channels(visibles)
        else:
            common_channels = self.controller.channels
        self._common_channels_flat = common_channels

        if self.params['plot_limit_for_flatten']:
            self._add_channel_spans(self.plot1, common_channels, n_left, width)
            self._add_channel_spans(self.plot2, common_channels, n_left, width)

        if self.params['display_threshold']:
            thresh = self.controller.get_threshold()
            thresh_line = pg.InfiniteLine(pos=thresh, angle=0, movable=False, pen=pg.mkPen('w'))
            self.plot1.addItem(thresh_line)

        key1, key2 = self._get_metric_keys()

        xvect = np.arange(width * len(common_channels))
        for k, visible in cluster_visible.items():
            if not visible:
                continue

            wf0, chans = self.controller.get_waveform_centroid(k, key1, channels=common_channels)
            if wf0 is None:
                continue
            wf0 = wf0.T.flatten()

            mad, chans = self.controller.get_waveform_centroid(k, key2, channels=common_channels)
            if mad is not None:
                # Flatten unconditionally: plot2 needs the 1d version even when
                # fillbetween is off (a 2d array does not match xvect's shape).
                mad = mad.T.flatten()

            color = self.controller.qcolors.get(k, QT.QColor('white'))
            curve = pg.PlotCurveItem(xvect, wf0, pen=pg.mkPen(color, width=2))
            self.plot1.addItem(curve)

            if self.params['fillbetween'] and mad is not None:
                color2 = QT.QColor(color)
                color2.setAlpha(self.alpha)
                curve1 = pg.PlotCurveItem(xvect, wf0 + mad, pen=color2)
                curve2 = pg.PlotCurveItem(xvect, wf0 - mad, pen=color2)
                self.plot1.addItem(curve1)
                self.plot1.addItem(curve2)
                fill = pg.FillBetweenItem(curve1=curve1, curve2=curve2, brush=color2)
                self.plot1.addItem(fill)

            if mad is not None:
                curve = pg.PlotCurveItem(xvect, mad, pen=color)
                self.plot2.addItem(curve)

        if self.params['show_channel_num']:
            cn = self.controller.channel_indexes_and_names
            for i, c in enumerate(common_channels):
                # i is the slot position on the x axis, c the channel index
                # (same number as shown in geometry mode).
                _, name = cn[c]
                itemtxt = pg.TextItem('{}: {}'.format(c, name), anchor=(.5, .5), color='#FFFF00')
                itemtxt.setFont(QT.QFont('', pointSize=12))
                self.plot1.addItem(itemtxt)
                itemtxt.setPos(width * i - n_left, 0)

        if self._x_range is None or not keep_range:
            if xvect.size > 0:
                self._x_range = xvect[0], xvect[-1]
                self._y1_range = self.wf_min * 1.1, self.wf_max * 1.1
                self._y2_range = 0., 5.

        if self._x_range is not None:
            self.plot1.setXRange(*self._x_range, padding=0.0)
            self.plot1.setYRange(*self._y1_range, padding=0.0)
            self.plot2.setYRange(*self._y2_range, padding=0.0)

    def _geometry_curve_data(self, wf, chans):
        """Place a (n_sample, n_chan) waveform at the probe positions of `chans`.

        The amplitude is scaled by factor_y * delta_y and offset by each
        channel's y position. Returns flat (x, y, connect) arrays for
        PlotCurveItem. connect is False at each channel's first and last
        sample so separate channels are not joined by a line.
        """
        ypos = self.arr_geometry[chans, 1]
        wf = wf * self.factor_y * self.delta_y + ypos[None, :]
        connect = np.ones(wf.shape, dtype='bool')
        connect[0, :] = 0
        connect[-1, :] = 0
        return self.xvect[chans, :].flatten(), wf.T.flatten(), connect.T.flatten()

    def refresh_mode_geometry(self, cluster_visible, keep_range):
        """Draw each visible centroid at the probe geometry positions."""
        if self._x_range is not None and keep_range:
            # this may change with pyqtgraph
            self._x_range = tuple(self.viewBox1.state['viewRange'][0])
            self._y1_range = tuple(self.viewBox1.state['viewRange'][1])

        self.plot1.clear()

        if self.xvect is None:
            return

        n_left, n_right = self.controller.get_waveform_left_right()
        if n_left is None:
            return
        width = n_right - n_left
        if width != self.xvect.shape[1]:
            # n_left/n_right changed since initialize_plot: xvect must be rebuilt.
            self.initialize_plot()

        self.plot1.addItem(self.curve_one_waveform)

        sparse = self.controller.have_sparse_template and self.params['sparse_display']
        key1, _ = self._get_metric_keys()

        for k, visible in cluster_visible.items():
            if not visible:
                continue
            wf, chans = self.controller.get_waveform_centroid(k, key1, sparse=sparse)
            if wf is None:
                continue
            xvect, yvect, connect = self._geometry_curve_data(wf, chans)
            color = self.controller.qcolors.get(k, QT.QColor('white'))
            curve = pg.PlotCurveItem(xvect, yvect, pen=pg.mkPen(color, width=2), connect=connect)
            self.plot1.addItem(curve)

        if self.params['show_channel_num']:
            for i, (chan, name) in enumerate(self.controller.channel_indexes_and_names):
                x, y = self.arr_geometry[i, :]
                itemtxt = pg.TextItem('{}: {}'.format(i, name), anchor=(.5, .5), color='#FFFF00')
                itemtxt.setFont(QT.QFont('', pointSize=12))
                self.plot1.addItem(itemtxt)
                itemtxt.setPos(x, y)

        if self._x_range is None or not keep_range:
            ypos = self.arr_geometry[:, 1]
            self._x_range = np.min(self.xvect), np.max(self.xvect)
            self._y1_range = np.min(ypos) - self.delta_y * 2, np.max(ypos) + self.delta_y * 2

        self.plot1.setXRange(*self._x_range, padding=0.0)
        self.plot1.setYRange(*self._y1_range, padding=0.0)

    def _refresh_one_spike(self, n_selected):
        """Overlay the raw (processed) signal of the selected spike.

        The overlay is cleared unless exactly one spike is selected and
        'plot_selected_spike' is on. It is also cleared when it cannot be
        drawn: unknown waveform window, no geometry, or truncated chunk.
        """
        # TODO peak the selected peak if only one
        if n_selected != 1 or not self.params['plot_selected_spike']:
            self.curve_one_waveform.setData([], [])
            return

        n_left, n_right = self.controller.get_waveform_left_right()
        if n_left is None or (self.mode == 'geometry' and self.xvect is None):
            self.curve_one_waveform.setData([], [])
            return

        ind = np.nonzero(self.controller.spike_selection)[0][0]
        seg_num = self.controller.spike_segment[ind]
        peak_ind = self.controller.spike_index[ind]

        wf = self.controller.dataio.get_signals_chunk(seg_num=seg_num, chan_grp=self.controller.chan_grp,
                                                      i_start=peak_ind + n_left, i_stop=peak_ind + n_right,
                                                      signal_type='processed')

        if wf.shape[0] != (n_right - n_left):
            # Chunk is shorter than the window (presumably the spike is near a
            # segment border): clear instead of leaving a previous spike drawn.
            self.curve_one_waveform.setData([], [])
            return

        if self.mode == 'flatten':
            if self._common_channels_flat is None:
                self.curve_one_waveform.setData([], [])
                return
            wf = wf[:, self._common_channels_flat].T.flatten()
            self.curve_one_waveform.setData(np.arange(wf.size), wf)
        elif self.mode == 'geometry':
            xvect, yvect, connect = self._geometry_curve_data(wf, slice(None))
            self.curve_one_waveform.setData(xvect, yvect, connect=connect)

    def on_spike_selection_changed(self):
        """Selection changed: redraw (cluster filter and overlay) keeping the view."""
        self.refresh(keep_range=True)


class WaveformViewer(WaveformViewerBase):
    """
    **Waveform viewer** is undoubtedly the view to inspect waveforms.

    Note that in some respects **Waveform hist viewer** can be a better friend.

    The centroid (median or mean) of every visible cluster is plotted here.

    2 main modes:
      * **geometry** waveforms are organized with the 2d geometry given by the PRB file.
      * **flatten** each chunk of each channel is put side by side in channel order
        so that it can be plotted in 1d. The bottom view is the mad. On a good cluster
        the mad must be as close as possible to the value 1, because 1 is the normalized noise.

    The **geometry** mode is more intuitive and helps users with spatial information.
    But the **flatten** mode is really important because it gives information about the
    variance (mad or std) at each point and about peak alignment.

    The centroid is defined by median+mad, but you can also check with mean+std.
    For a healthy cluster both should be more or less the same.

    Important for zooming:
      * **geometry** : zoomXY geometry = right click, move = left click and mouse wheel = zoom waveforms
      * **flatten**: zoomXY = right click and move = left click


    Settings:
      * **plot_selected_spike**: superimpose one selected peak on the centroid
      * **show_only_selected_cluster**: automatically hide all clusters except the one of the selected spike
      * **plot_limit_for_flatten**: in flatten mode, plot lines delimiting the channels.
        Plotting them is useful but slows down the zoom.
      * **metrics**: choose median+mad or mean+std.
      * **fillbetween**: in flatten mode, shade the band centroid +/- dispersion.
      * **show_channel_num**: display the channel number and name on each channel.
      * **flip_bottom_up**: in geometry mode, flip the channel geometry upside down.
      * **display_threshold**: in flatten mode, draw a horizontal line at the controller's threshold.
      * **sparse_display**: when the catalogue has sparse templates, only display their channels.
    """
    _params = [{'name': 'plot_selected_spike', 'type': 'bool', 'value': False},
               {'name': 'show_only_selected_cluster', 'type': 'bool', 'value': False},
               {'name': 'plot_limit_for_flatten', 'type': 'bool', 'value': True},
               {'name': 'metrics', 'type': 'list', 'values': ['median/mad', 'mean/std']},
               {'name': 'fillbetween', 'type': 'bool', 'value': True},
               {'name': 'show_channel_num', 'type': 'bool', 'value': False},
               {'name': 'flip_bottom_up', 'type': 'bool', 'value': False},
               {'name': 'display_threshold', 'type': 'bool', 'value': True},
               {'name': 'sparse_display', 'type': 'bool', 'value': True},
               ]


class PeelerWaveformViewer(WaveformViewerBase):
    """
    **Waveform viewer**

    Same viewer for the Peeler: by default the selected spike is plotted,
    only its cluster is shown and only median/mad is available.
    """
    _params = [{'name': 'plot_selected_spike', 'type': 'bool', 'value': True},
               {'name': 'show_only_selected_cluster', 'type': 'bool', 'value': True},
               {'name': 'plot_limit_for_flatten', 'type': 'bool', 'value': True},
               {'name': 'metrics', 'type': 'list', 'values': ['median/mad']},
               {'name': 'fillbetween', 'type': 'bool', 'value': True},
               {'name': 'show_channel_num', 'type': 'bool', 'value': False},
               {'name': 'flip_bottom_up', 'type': 'bool', 'value': False},
               {'name': 'display_threshold', 'type': 'bool', 'value': True},
               {'name': 'sparse_display', 'type': 'bool', 'value': True},
               ]