import numpy as np
import matplotlib.cm
import matplotlib.colors
from ..gui import QT
import pyqtgraph as pg
from pyqtgraph.util.mutex import Mutex

import pyacq
from pyacq import WidgetNode,ThreadPollInput, StreamConverter, InputStream

#~ _dtype_spike = [('index', 'int64'), ('label', 'int64'), ('jitter', 'float64'),]
from ..peeler import _dtype_spike
from ..tools import make_color_dict
from ..labelcodes import LABEL_UNCLASSIFIED

import time


class MyViewBox(pg.ViewBox):
    """ViewBox that turns mouse gestures into two custom signals.

    * ``doubleclicked`` is emitted on a double click.
    * ``gain_zoom`` is emitted on a wheel event with a multiplicative
      factor (float).

    The context menu is disabled by overriding ``raiseContextMenu``.
    """
    doubleclicked = QT.pyqtSignal()
    gain_zoom = QT.pyqtSignal(float)

    def mouseDoubleClickEvent(self, ev):
        """Emit ``doubleclicked`` and consume the event."""
        self.doubleclicked.emit()
        ev.accept()

    def wheelEvent(self, ev, axis=None):
        """Emit ``gain_zoom`` with a zoom factor and consume the event.

        The step is 10 when Ctrl is the active modifier and 1.3 otherwise;
        a positive wheel delta emits the step, a non-positive delta emits
        its inverse.
        """
        if ev.modifiers() == QT.Qt.ControlModifier:
            step = 10.
        else:
            step = 1.3
        z = step if ev.delta() > 0 else 1. / step
        self.gain_zoom.emit(z)
        ev.accept()

    def raiseContextMenu(self, ev):
        """Do nothing, so that no context menu is ever shown."""
        # for some reasons enableMenu=False is not taken (bug ????)
        pass


class OnlineWaveformHistViewer(WidgetNode):
    """Widget node showing a live 2D histogram of spike waveforms.

    For every label in the catalogue (plus ``LABEL_UNCLASSIFIED``) the node
    accumulates a 2D histogram: one row per flattened waveform sample, one
    column per amplitude bin.  Waveforms are cut from the ``'signals'`` input
    between ``index + n_left`` and ``index + n_right`` for each spike
    received on the ``'spikes'`` input.  The histogram of the label selected
    in the combo box is displayed as an image, with the last waveform of
    that label drawn on top.

    The catalogue is read through the keys ``'clusters'``,
    ``'cluster_labels'``, ``'centers0'``, ``'n_left'`` and ``'n_right'``.
    """

    _input_specs = {
        'signals': dict(streamtype='signals'),
        'spikes': dict(streamtype='events', shape = (-1, ), dtype=_dtype_spike),
    }

    _params = [
        {'name': 'colormap', 'type': 'list', 'values' : ['hot', 'viridis', 'jet', 'gray', ] },
        {'name': 'bin_min', 'type': 'float', 'value' : -20. },
        {'name': 'bin_max', 'type': 'float', 'value' : 8. },
        {'name': 'bin_size', 'type': 'float', 'value' : .1 },
        {'name': 'refresh_interval', 'type': 'int', 'value': 100, 'limits':[5, 1000]},
    ]

    def __init__(self, **kargs):
        """Build the widgets: combo box, clear button, label, plot, settings tree."""
        WidgetNode.__init__(self, **kargs)

        self.layout = QT.QVBoxLayout()
        self.setLayout(self.layout)

        # top toolbar: label selector, clear button, spike counter
        h = QT.QHBoxLayout()
        self.layout.addLayout(h)

        self.combobox = QT.QComboBox()
        h.addWidget(self.combobox)

        but = QT.QPushButton('clear')
        h.addWidget(but)
        but.clicked.connect(self.on_clear)

        self.label = QT.QLabel('')
        h.addWidget(self.label)

        self.graphicsview = pg.GraphicsView()
        self.layout.addWidget(self.graphicsview)

        # settings tree, shown as a separate window (toggled by open_settings)
        self.params = pg.parametertree.Parameter.create(
            name='settings', type='group', children=self._params)
        self.tree_params = pg.parametertree.ParameterTree(parent=self)
        self.tree_params.header().hide()
        self.tree_params.setParameters(self.params, showTop=True)
        self.tree_params.setWindowTitle('Options for waveforms hist viewer')
        self.tree_params.setWindowFlags(QT.Qt.Window)

        self.params.sigTreeStateChanged.connect(self.on_params_changed)

        self.initialize_plot()

        # protects histogram_2d / last_waveform / nb_spikes and the catalogue
        self.mutex = Mutex()

    def _configure(self, peak_buffer_size=100000, catalogue=None, **kargs):
        """Store the spike buffer size and the catalogue used for display."""
        self.peak_buffer_size = peak_buffer_size
        self.catalogue = catalogue

    def _initialize(self, **kargs):
        """Set up input buffers, pollers, per-label state and the refresh timer."""
        self.sample_rate = self.inputs['signals'].params['sample_rate']
        self.wf_dtype = self.inputs['signals'].params['dtype']

        self.inputs['spikes'].set_buffer(size=self.peak_buffer_size, double=False)

        # keep 3 seconds of signal
        buffer_sigs_size = int(self.sample_rate*3.)
        self.inputs['signals'].set_buffer(size=buffer_sigs_size, double=False)

        # poller
        self.poller_sigs = ThreadPollInput(input_stream=self.inputs['signals'], return_data=True)
        self.poller_spikes = ThreadPollInput(input_stream=self.inputs['spikes'], return_data=True)

        self.histogram_2d = {}
        self.last_waveform = {}
        self.change_catalogue(self.catalogue)

        # start from the configured parameter rather than a hard-coded value
        self.timer = QT.QTimer(interval=self.params['refresh_interval'])
        self.timer.timeout.connect(self.refresh)

    def _start(self, **kargs):
        """Reset the read heads, flush both input queues, start timer and pollers."""
        self.last_head_sigs = None
        self.last_head_spikes = None
        self.timer.start()
        self.inputs['signals'].empty_queue()
        self.inputs['spikes'].empty_queue()
        self.poller_sigs.start()
        self.poller_spikes.start()

    def _stop(self, **kargs):
        """Stop the timer and both pollers, waiting for each poller to finish."""
        self.timer.stop()
        for poller in (self.poller_sigs, self.poller_spikes):
            poller.stop()
            poller.wait()

    def _close(self, **kargs):
        """Nothing to release."""
        pass

    def open_settings(self):
        """Toggle the visibility of the settings window."""
        if self.tree_params.isVisible():
            self.tree_params.hide()
        else:
            self.tree_params.show()

    def on_params_changed(self, params, changes):
        """Apply a settings change: rebuild LUT and histograms, update the timer."""
        self.change_lut()
        self.change_catalogue(self.catalogue)
        self.timer.setInterval(self.params['refresh_interval'])

    def initialize_plot(self):
        """Create the view box, plot, image item and the two overlay curves."""
        self.viewBox = MyViewBox()
        self.viewBox.doubleclicked.connect(self.open_settings)
        self.viewBox.gain_zoom.connect(self.gain_zoom)
        self.viewBox.disableAutoRange()

        self.plot = pg.PlotItem(viewBox=self.viewBox)
        self.graphicsview.setCentralItem(self.plot)
        self.plot.hideButtons()

        self.image = pg.ImageItem()
        self.plot.addItem(self.image)

        # last waveform of the selected label
        self.curve_spike = pg.PlotCurveItem()
        self.plot.addItem(self.curve_spike)

        # vertical separators between channels
        self.curve_limit = pg.PlotCurveItem()
        self.plot.addItem(self.curve_limit)

        self.change_lut()

    def change_lut(self):
        """Rebuild ``self.lut`` (512 x 3, uint8) from the selected colormap."""
        N = 512
        cmap_name = self.params['colormap']
        try:
            # matplotlib >= 3.6: registry access + resampling
            cmap = matplotlib.colormaps[cmap_name].resampled(N)
        except AttributeError:
            # older matplotlib: no registry and/or no ``resampled``
            cmap = matplotlib.cm.get_cmap(cmap_name, N)

        # integer input indexes the colormap table; keep RGB, drop alpha
        rgba = np.asarray(cmap(np.arange(N)))
        self.lut = (rgba[:, :3] * 255).astype('uint8')

    def change_catalogue(self, catalogue):
        """Install ``catalogue`` and reset everything derived from it.

        Recomputes the label colors, the list of plotted labels, the
        amplitude bins, the combo box content, the histograms (through
        ``on_clear``) and the channel separator lines.  When the catalogue
        has at least one center, ``bin_min`` / ``bin_max`` are overwritten
        with 1.5 times the min / max of ``centers0``.

        Parameter signals are blocked for the duration of the call so that
        writing ``bin_min`` / ``bin_max`` does not re-enter
        ``on_params_changed``.
        """
        self.params.blockSignals(True)
        try:
            with self.mutex:
                self.catalogue = catalogue

                colors = make_color_dict(self.catalogue['clusters'])
                # QColor needs integers: implicit float -> int conversion is
                # rejected by recent Python / PyQt versions
                self.qcolors = {
                    k: QT.QColor(int(r*255), int(g*255), int(b*255))
                    for k, (r, g, b) in colors.items()
                }

                self.all_plotted_labels = self.catalogue['cluster_labels'].tolist() + [LABEL_UNCLASSIFIED]

                centers0 = self.catalogue['centers0']
                if centers0.shape[0]>0:
                    self.params['bin_min'] = np.min(centers0)*1.5
                    self.params['bin_max'] = np.max(centers0)*1.5

                bin_min, bin_max = self.params['bin_min'], self.params['bin_max']
                bin_size = self.params['bin_size']
                self.bins = np.arange(bin_min, bin_max, bin_size)

                self.combobox.clear()
                self.combobox.addItems([str(k) for k in self.all_plotted_labels])

            # on_clear() takes the mutex itself, so it must run once the
            # lock above has been released
            self.on_clear()
            self._max = 10

            # one vertical line at every boundary between two channels
            _, peak_width, nb_chan = self.catalogue['centers0'].shape
            x, y = [], []
            for c in range(1, nb_chan):
                x.extend([c*peak_width, c*peak_width, np.nan])
                y.extend([-1000, 1000, np.nan])
            self.curve_limit.setData(x=x, y=y, connect='finite')
        finally:
            # never leave the parameter tree muted, even on error
            self.params.blockSignals(False)

    def on_clear(self):
        """Zero the histograms, last waveforms and spike counts of every label.

        Also resets the plot ranges to the full waveform width and to
        ``[bin_min, bin_max]``.
        """
        with self.mutex:
            shape = self.catalogue['centers0'].shape
            # flattened waveform length: shape[1] * shape[2]
            flat_size = shape[1]*shape[2]
            self.indexes0 = np.arange(flat_size, dtype='int64')

            self.histogram_2d = {}
            self.last_waveform = {}
            self.nb_spikes = {}
            for k in self.all_plotted_labels:
                self.histogram_2d[k] = np.zeros((flat_size, self.bins.size), dtype='int64')
                self.last_waveform[k] = np.zeros((flat_size,), dtype=self.wf_dtype)
                self.nb_spikes[k] = 0

        self.plot.setXRange(0, self.indexes0[-1]+1)
        self.plot.setYRange(self.params['bin_min'], self.params['bin_max'])

    def auto_scale(self):
        """Placeholder, does nothing."""
        pass

    def gain_zoom(self, v):
        """Multiply the upper image level by ``v`` and apply it to the image."""
        self._max *= v
        self.image.setLevels([0, self._max], update=True)

    def refresh(self):
        """Timer slot: accumulate the new spikes, then redraw the selected label.

        Spikes whose right edge (``index + n_right``) lies beyond the current
        head of the signal stream are left for the next call.
        """
        head_sigs = self.poller_sigs.pos()
        head_spikes = self.poller_spikes.pos()

        if self.last_head_sigs is None:
            self.last_head_sigs = head_sigs

        if self.last_head_spikes is None:
            self.last_head_spikes = head_spikes

        # still None: at least one poller has not reported a position yet
        if self.last_head_spikes is None or self.last_head_sigs is None:
            return

        # update image
        n_right, n_left = self.catalogue['n_right'],self.catalogue['n_left']
        bin_min, bin_max, bin_size = self.params['bin_min'], self.params['bin_max'],self.params['bin_size']

        # check peak_buffer_size here: never read back more than 90% of
        # the spike buffer
        if (head_spikes-self.last_head_spikes)>=(0.9*self.peak_buffer_size):
            self.last_head_spikes = head_spikes - int(0.9*self.peak_buffer_size)

        new_spikes = self.inputs['spikes'].get_data(self.last_head_spikes, head_spikes)

        right_indexes = new_spikes['index'] + n_right
        too_recent = right_indexes > head_sigs
        if np.any(too_recent):
            # the signal buffer is not yet available for some spikes:
            # drop them for this pass and rewind head_spikes so that they
            # are read again on the next pass
            first_out = np.nonzero(too_recent)[0][0]
            head_spikes = head_spikes - (new_spikes.size - first_out)
            new_spikes = new_spikes[:first_out]

        for k in self.all_plotted_labels:
            mask = new_spikes['cluster_label'] == k
            indexes = new_spikes[mask]['index']
            for ind in indexes:
                wf = self.inputs['signals'].get_data(ind+n_left, ind+n_right)
                # flatten after transpose
                # (assumes get_data returns (samples, channels) - TODO confirm)
                wf = wf.T.reshape(-1)
                # amplitude -> bin index, out-of-range values go to the edge bins
                wf_bined = np.floor((wf-bin_min)/bin_size).astype('int32')
                wf_bined = wf_bined.clip(0, self.bins.size-1)

                with self.mutex:
                    self.histogram_2d[k][self.indexes0, wf_bined] += 1
                    self.last_waveform[k] = wf
                    self.nb_spikes[k] += 1

        self.last_head_sigs = head_sigs
        self.last_head_spikes = head_spikes

        if self.combobox.currentIndex() == -1:
            return

        if self.visibleRegion().isEmpty():
            # when several tabs not need to refresh
            return

        # refresh plot , update image
        k = self.all_plotted_labels[self.combobox.currentIndex()]
        hist2d = self.histogram_2d[k]

        self.image.setImage(hist2d, lut=self.lut, levels=[0, self._max])
        self.image.setRect(QT.QRectF(-0.5, bin_min, hist2d.shape[0], bin_max-bin_min))
        self.image.show()

        self.curve_spike.setData(x=self.indexes0, y=self.last_waveform[k],
                                 pen=pg.mkPen(self.qcolors[k], width=1.5))

        txt = 'nbs_pike = {}'.format(self.nb_spikes[k])
        self.label.setText(txt)