import matplotlib.cm
import numpy as np
import pyqtgraph as pg
from PyQt5 import QtGui, QtCore
from matplotlib.colors import hsv_to_rgb

import suite2p.gui.merge
from . import io


def _rescale_unit(values, lo, hi):
    """Linearly map ``values`` from [lo, hi] onto [0, 1], clipping outside.

    Returns an all-zero array when ``hi == lo`` so that a constant statistic
    does not turn into NaNs through a 0/0 division.
    """
    span = hi - lo
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return np.maximum(0, np.minimum(1, (values - lo) / span))


def make_buttons(parent, b0):
    """Create the colour-selection widgets starting at grid row ``b0``.

    Adds a "Colors" label, the colormap chooser, one ``ColorButton`` per entry
    of ``parent.color_names`` and the three line edits (chan2 threshold,
    classifier probability, bin size) to ``parent.l0``.

    The "X: " shortcut prefix is stripped from every entry of
    ``parent.color_names`` once its button has been created.

    Returns the first free grid row below the widgets added here.
    """
    # color buttons
    parent.color_names = [
        "A: random",
        "S: skew",
        "D: compact",
        "F: footprint",
        "G: aspect_ratio",
        "H: chan2_prob",
        "J: classifier, cell prob=",
        "K: correlations, bin=",
        "L: corr with 1D var, bin=^^^",
        "M: rastermap / custom",
    ]
    parent.colorbtns = QtGui.QButtonGroup(parent)
    clabel = QtGui.QLabel(parent)
    clabel.setText("<font color='white'>Colors</font>")
    clabel.setFont(parent.boldfont)
    parent.l0.addWidget(clabel, b0, 0, 1, 1)

    iwid = 65

    # add colormaps
    parent.CmapChooser = QtGui.QComboBox()
    cmaps = ['hsv', 'viridis', 'plasma', 'inferno', 'magma', 'cividis',
             'viridis_r', 'plasma_r', 'inferno_r', 'magma_r', 'cividis_r']
    parent.CmapChooser.addItems(cmaps)
    parent.CmapChooser.setCurrentIndex(0)
    parent.CmapChooser.activated.connect(lambda: cmap_change(parent))
    parent.CmapChooser.setFont(QtGui.QFont("Arial", 8))
    parent.CmapChooser.setFixedWidth(iwid)
    parent.l0.addWidget(parent.CmapChooser, b0, 1, 1, 1)

    nv = b0
    # one button per colour statistic
    labels = parent.color_names.copy()
    for b, label in enumerate(labels):
        btn = ColorButton(b, "&" + label, parent)
        parent.colorbtns.addButton(btn, b)
        # buttons 5-7 share their row with a line edit in column 1
        colspan = 1 if 4 < b < 8 else 2
        parent.l0.addWidget(btn, nv + b + 1, 0, 1, colspan)
        btn.setEnabled(False)
        # drop the "X: " keyboard-shortcut prefix
        parent.color_names[b] = label[3:]
    b = len(labels)

    parent.chan2edit = QtGui.QLineEdit(parent)
    parent.chan2edit.setText("0.6")
    parent.chan2edit.setFixedWidth(iwid)
    parent.chan2edit.setAlignment(QtCore.Qt.AlignRight)
    parent.chan2edit.returnPressed.connect(lambda: chan2_prob(parent))
    parent.l0.addWidget(parent.chan2edit, nv + b - 4, 1, 1, 1)

    parent.probedit = QtGui.QLineEdit(parent)
    parent.probedit.setText("0.5")
    parent.probedit.setFixedWidth(iwid)
    parent.probedit.setAlignment(QtCore.Qt.AlignRight)
    parent.probedit.returnPressed.connect(
        lambda: suite2p.gui.merge.apply(parent)
    )
    parent.l0.addWidget(parent.probedit, nv + b - 3, 1, 1, 1)

    parent.binedit = QtGui.QLineEdit(parent)
    parent.binedit.setValidator(QtGui.QIntValidator(0, 500))
    parent.binedit.setText("1")
    parent.binedit.setFixedWidth(iwid)
    parent.binedit.setAlignment(QtCore.Qt.AlignRight)
    parent.binedit.returnPressed.connect(
        lambda: parent.mode_change(parent.activityMode)
    )
    parent.l0.addWidget(parent.binedit, nv + b - 2, 1, 1, 1)

    b0 = nv + b + 2
    return b0


def cmap_change(parent):
    """Apply the colormap selected in ``parent.CmapChooser``.

    Stores the name in ``parent.ops_plot['colormap']``.  When data is loaded,
    every colour table except the random one (index 0) is recomputed from its
    stored ``istat`` values, the colorbar image is rebuilt and the plot is
    refreshed.
    """
    index = parent.CmapChooser.currentIndex()
    parent.ops_plot['colormap'] = parent.CmapChooser.itemText(index)
    if parent.loaded:
        print('colormap changed to %s, loading...' % parent.ops_plot['colormap'])
        istat = parent.colors['istat']
        for c in range(1, istat.shape[0]):
            parent.colors['cols'][c] = istat_transform(
                istat[c], parent.ops_plot['colormap'])
            rgb_masks(parent, parent.colors['cols'][c], c)
        parent.colormat = draw_colorbar(parent.ops_plot['colormap'])
        parent.update_plot()


def hsv2rgb(cols):
    """Convert a 1-D array of hues to an (n, 3) uint8 RGB array.

    Saturation and value are fixed at 1.
    """
    cols = cols[:, np.newaxis]
    cols = np.concatenate((cols, np.ones_like(cols), np.ones_like(cols)), axis=-1)
    cols = (255 * hsv_to_rgb(cols)).astype(np.uint8)
    return cols


def make_colors(parent):
    """Build the per-ROI colour tables for the statistic-based colourings.

    Fills ``parent.colors['cols']`` (uint8 RGB per colouring and ROI),
    ``parent.colors['istat']`` (the normalised statistic behind each colour)
    and ``parent.colors['colorbar']`` (three tick values per colouring), and
    sets ``parent.randcols``.  The last three colourings only get placeholder
    colorbar ticks here; the ``*_masks`` functions in this module write
    indices 7, 8 and 9.
    """
    parent.colors['colorbar'] = []
    ncells = len(parent.stat)
    ncolors = len(parent.color_names)
    parent.colors['cols'] = np.zeros((ncolors, ncells, 3), np.uint8)
    parent.colors['istat'] = np.zeros((ncolors, ncells), np.float32)

    # fixed seed so the random colouring is reproducible between sessions
    np.random.seed(seed=0)
    allcols = np.random.random((ncells,))
    if 'meanImg_chan2' in parent.ops:
        allcols = allcols / 1.4
        allcols = allcols + 0.1
        print(parent.redcell.sum())
        parent.randcols = allcols
        # red cells get hue 0 (randcols aliases allcols, so it is updated too)
        allcols[parent.redcell] = 0
    else:
        parent.randcols = allcols
    parent.colors['istat'][0] = parent.randcols
    parent.colors['cols'][0] = hsv2rgb(parent.randcols)

    for b, names in enumerate(parent.color_names[:-3]):
        if b == 0:
            # random colouring: colours were set above, only ticks are needed
            parent.colors['colorbar'].append([0, 0.5, 1])
            continue
        istat = np.zeros((ncells, 1))
        # NOTE(review): with the current 10-entry colour list this condition
        # is true for every b visited here, so the probcell branch below looks
        # unreachable -- confirm which index is meant to use parent.probcell.
        if b < ncolors - 2:
            if names in parent.stat[0]:
                for n in range(ncells):
                    istat[n] = parent.stat[n][names]
            istat1 = np.percentile(istat, 2)
            istat99 = np.percentile(istat, 98)
            parent.colors['colorbar'].append(
                [istat1, (istat99 - istat1) / 2 + istat1, istat99])
            istat = _rescale_unit(istat, istat1, istat99)
        else:
            istat = np.expand_dims(parent.probcell, axis=1)
            parent.colors['colorbar'].append([0.0, .5, 1.0])
        col = istat_transform(istat, parent.ops_plot['colormap'])
        parent.colors['cols'][b] = col
        parent.colors['istat'][b] = istat.flatten()

    # placeholder ticks for the three colourings computed on demand
    parent.colors['colorbar'].append([0, 0.5, 1])
    parent.colors['colorbar'].append([0, 0.5, 1])
    parent.colors['colorbar'].append([0, 0.5, 1])


def flip_plot(parent):
    """Flip every ROI in ``parent.imerge`` to the other view and save.

    ROIs listed under an ROI's ``'imerge'`` stat entry have their ``iscell``
    flag flipped as well (their masks are not moved here).
    """
    parent.iflip = parent.ichosen
    for n in parent.imerge:
        parent.iscell[n] = ~parent.iscell[n]
        parent.ichosen = n
        flip_roi(parent)
        if 'imerge' in parent.stat[n]:
            for k in parent.stat[n]['imerge']:
                parent.iscell[k] = ~parent.iscell[k]
    parent.update_plot()
    io.save_iscell(parent)


def chan2_prob(parent):
    """Apply the channel-2 threshold typed into ``parent.chan2edit``.

    If the value differs from the current threshold by more than 1e-3,
    ``parent.redcell`` is recomputed, the random-colour masks are redrawn and
    the result is saved.  Text that is not a number is ignored.
    """
    try:
        chan2prob = float(parent.chan2edit.text())
    except ValueError:
        print('invalid chan2 probability, ignoring')
        return
    if abs(parent.chan2prob - chan2prob) > 1e-3:
        parent.chan2prob = chan2prob
        parent.redcell = parent.probredcell > parent.chan2prob
        chan2_masks(parent)
        parent.update_plot()
        io.save_redcell(parent)


def make_colorbar(parent, b0):
    """Create the colorbar widget at grid row ``b0``.

    Sets ``parent.colorbar`` (the image item) and ``parent.clabel`` (the three
    tick labels).
    """
    colorbarW = pg.GraphicsLayoutWidget(parent)
    colorbarW.setMaximumHeight(60)
    colorbarW.setMaximumWidth(150)
    colorbarW.ci.layout.setRowStretchFactor(0, 2)
    colorbarW.ci.layout.setContentsMargins(0, 0, 0, 0)
    parent.l0.addWidget(colorbarW, b0, 0, 1, 2)
    parent.colorbar = pg.ImageItem()
    cbar = colorbarW.addViewBox(row=0, col=0, colspan=3)
    cbar.setMenuEnabled(False)
    cbar.addItem(parent.colorbar)
    parent.clabel = [
        colorbarW.addLabel("0.0", color=[255, 255, 255], row=1, col=0),
        colorbarW.addLabel("0.5", color=[255, 255, 255], row=1, col=1),
        colorbarW.addLabel("1.0", color=[255, 255, 255], row=1, col=2),
    ]


def init_masks(parent):
    """Build the ROI index, weight and RGBA images from scratch.

    Reads ``parent.stat`` (``ypix``, ``xpix``, ``lam`` and optional
    ``imerge``), ``parent.iscell``, ``parent.colors['cols']`` and the image
    size ``parent.Ly`` x ``parent.Lx``.

    Writes, with a leading axis of size 2 (index 0 for ROIs with ``iscell``
    true, 1 otherwise):

    * ``parent.rois['Sroi']``    -- bool, whether a pixel belongs to any ROI
    * ``parent.rois['iROI']``    -- ROI index on each of 3 overlap layers
      (-1 where empty)
    * ``parent.rois['Lam']``     -- pixel weights on the same 3 layers
    * ``parent.rois['LamMean']`` / ``parent.rois['LamNorm']`` -- mean weight
      and clipped, normalised top-layer weights
    * ``parent.colors['RGB']``   -- RGBA image for every colouring
    """
    stat = parent.stat
    iscell = parent.iscell
    cols = parent.colors['cols']
    ncells = len(stat)
    Ly = parent.Ly
    Lx = parent.Lx
    # builtin bool: the np.bool alias no longer exists in current NumPy
    parent.rois['Sroi'] = np.zeros((2, Ly, Lx), bool)
    LamAll = np.zeros((Ly, Lx), np.float32)
    # these have 3 layers
    parent.rois['Lam'] = np.zeros((2, 3, Ly, Lx), np.float32)
    parent.rois['iROI'] = -1 * np.ones((2, 3, Ly, Lx), np.int32)

    # ignore merged cells
    iignore = np.zeros(ncells, bool)

    # iterate from the last ROI down so that ROI 0 ends up on top
    for n in np.arange(ncells - 1, -1, -1, int):
        ypix = stat[n]['ypix']
        if ypix is None or iignore[n]:
            continue
        if 'imerge' in stat[n]:
            for k in stat[n]['imerge']:
                iignore[k] = True
                print(k)
        xpix = stat[n]['xpix']
        lam = stat[n]['lam']
        lam = lam / lam.sum()
        i = int(1 - iscell[n])
        # add cell on top, pushing existing layers down by one
        for key in ('iROI', 'Lam'):
            layers = parent.rois[key]
            layers[i, 2, ypix, xpix] = layers[i, 1, ypix, xpix]
            layers[i, 1, ypix, xpix] = layers[i, 0, ypix, xpix]
        parent.rois['iROI'][i, 0, ypix, xpix] = n
        parent.rois['Lam'][i, 0, ypix, xpix] = lam
        parent.rois['Sroi'][i, ypix, xpix] = 1
        LamAll[ypix, xpix] = lam

    parent.rois['LamMean'] = LamAll[LamAll > 1e-10].mean()
    parent.rois['LamNorm'] = np.maximum(
        0, np.minimum(1, 0.75 * parent.rois['Lam'][:, 0] / parent.rois['LamMean']))
    parent.colors['RGB'] = np.zeros((2, cols.shape[0], Ly, Lx, 4), np.uint8)

    for c in range(cols.shape[0]):
        rgb_masks(parent, cols[c], c)


def rgb_masks(parent, col, c):
    """Paint colouring ``c`` of ``parent.colors['RGB']`` from table ``col``.

    ``col`` is indexed by the top-layer ROI index of each pixel, for both
    views; only the RGB channels are written.
    """
    for i in range(2):
        top_roi = parent.rois['iROI'][i, 0]
        parent.colors['RGB'][i, c, :, :, :3] = col[top_roi, :]


def draw_masks(parent):
    """Return the two RGBA images to display for the current colouring.

    Alpha is reset from ``parent.ops_plot['opacity']``, ``Sroi`` and
    ``LamNorm``; the ROIs in ``parent.imerge`` are then highlighted in the
    view holding ``parent.ichosen`` -- as grey-level patches when
    ``parent.ops_plot['view']`` is 0, as outlines (``ycirc``/``xcirc``)
    otherwise.

    Returns a tuple ``(M0, M1)``: copies of the RGBA image for view 0 and
    view 1.
    """
    view = parent.ops_plot['view']
    color = parent.ops_plot['color']
    opacity = parent.ops_plot['opacity']

    wplot = int(1 - parent.iscell[parent.ichosen])
    # reset transparency; opacity[1] is used in view 0, opacity[0] otherwise
    alpha = opacity[int(view == 0)]
    for i in range(2):
        parent.colors['RGB'][i, color, :, :, 3] = (
            alpha * parent.rois['Sroi'][i] * parent.rois['LamNorm'][i]
        ).astype(np.uint8)
    M = [np.array(parent.colors['RGB'][0, color]),
         np.array(parent.colors['RGB'][1, color])]

    if view == 0:
        for n in parent.imerge:
            ypix = parent.stat[n]['ypix'].flatten()
            xpix = parent.stat[n]['xpix'].flatten()
            # number of other ROIs stacked on each pixel -> darker grey
            v = (parent.rois['iROI'][wplot][:, ypix, xpix] > -1).sum(axis=0) - 1
            v = 1 - v / 3
            M[wplot] = make_chosen_ROI(M[wplot], ypix, xpix, v)
    else:
        for n in parent.imerge:
            ycirc = parent.stat[n]['ycirc']
            xcirc = parent.stat[n]['xcirc']
            ypix = parent.stat[n]['ypix'].flatten()
            xpix = parent.stat[n]['xpix'].flatten()
            # hide the filled mask and draw only the outline
            M[wplot][ypix, xpix, 3] = 0
            col = parent.colors['cols'][color, n]
            M[wplot] = make_chosen_circle(M[wplot], ycirc, xcirc, col, 1)

    return M[0], M[1]


def make_chosen_ROI(M0, ypix, xpix, v):
    """Set all four channels at (ypix, xpix) of ``M0`` to ``255 * v``."""
    M0[ypix, xpix, :] = np.tile((255 * v[:, np.newaxis]).astype(np.uint8), (1, 4))
    return M0


def make_chosen_circle(M0, ycirc, xcirc, col, sat):
    """Draw an opaque outline of colour ``col`` at (ycirc, xcirc) in ``M0``.

    ``sat`` is accepted for interface compatibility and is not used.
    """
    M0[ycirc, xcirc, :3] = col
    M0[ycirc, xcirc, 3] = 255
    return M0


def chan2_masks(parent):
    """Recolour the random colouring (index 0) after ``redcell`` changed."""
    c = 0
    col = parent.randcols
    # NOTE(review): this writes into parent.randcols itself, so a cell that
    # stops being red keeps hue 0 -- confirm whether that is intended.
    col[parent.redcell] = 0
    col = col.flatten()
    parent.colors['cols'][c] = hsv2rgb(col)
    # rgb_masks needs the (ncells, 3) RGB table, not the 1-D hue vector
    rgb_masks(parent, parent.colors['cols'][c], c)


def custom_masks(parent):
    """Colour ROIs by ``parent.custom_mask`` (colouring index 9)."""
    c = 9
    istat = parent.custom_mask
    istat1 = istat.min()
    istat99 = istat.max()
    parent.colors['colorbar'][c] = [istat1, (istat99 - istat1) / 2 + istat1, istat99]
    istat = _rescale_unit(istat, istat1, istat99)
    col = istat_transform(istat, parent.ops_plot['colormap'])
    parent.colors['cols'][c] = col
    parent.colors['istat'][c] = istat.flatten()
    rgb_masks(parent, col, c)


def rastermap_masks(parent):
    """Colour ROIs by ``parent.isort`` (colouring index 9).

    ROIs whose ``isort`` is -1 are painted black.
    """
    c = 9
    istat = parent.isort
    parent.colors['colorbar'][c] = [0, istat.max() / 2, istat.max()]
    istat = istat / istat.max()
    col = istat_transform(istat, parent.ops_plot['colormap'])
    col[parent.isort == -1] = 0
    parent.colors['cols'][c] = col
    parent.colors['istat'][c] = istat.flatten()
    rgb_masks(parent, col, c)


def beh_masks(parent):
    """Colour ROIs by correlation with the 1-D variable (colouring index 8).

    ``parent.beh_resampled`` is averaged in bins of ``parent.bin`` samples,
    mean-subtracted and correlated with ``parent.Fbin``.
    """
    c = 8
    n = np.array(parent.imerge)
    nb = int(np.floor(parent.beh_resampled.size / parent.bin))
    sn = np.reshape(parent.beh_resampled[:nb * parent.bin],
                    (nb, parent.bin)).mean(axis=1)
    sn -= sn.mean()
    snstd = (sn ** 2).mean() ** 0.5
    cc = np.dot(parent.Fbin, sn.T) / parent.Fbin.shape[-1] / (parent.Fstd * snstd)
    # NOTE(review): selected ROIs are overwritten with the mean as in
    # corr_masks -- confirm this is wanted for the 1-D variable too.
    cc[n] = cc.mean()
    istat_min = cc.min()
    istat_max = cc.max()
    istat = _rescale_unit(cc, istat_min, istat_max)
    col = istat_transform(istat, parent.ops_plot['colormap'])
    parent.colors['cols'][c] = col
    parent.colors['istat'][c] = istat.flatten()
    parent.colors['colorbar'][c] = [
        istat_min, (istat_max - istat_min) / 2 + istat_min, istat_max]
    rgb_masks(parent, col, c)


def corr_masks(parent):
    """Colour ROIs by correlation with the selected ROIs (colouring index 7).

    The reference trace is the mean of ``parent.Fbin`` over ``parent.imerge``;
    the selected ROIs themselves are assigned the mean correlation.
    """
    c = 7
    n = np.array(parent.imerge)
    sn = parent.Fbin[n].mean(axis=-2).squeeze()
    snstd = (sn ** 2).mean() ** 0.5
    cc = np.dot(parent.Fbin, sn.T) / parent.Fbin.shape[-1] / (parent.Fstd * snstd)
    cc[n] = cc.mean()
    istat_min = cc.min()
    istat_max = cc.max()
    parent.colors['colorbar'][c] = [
        istat_min, (istat_max - istat_min) / 2 + istat_min, istat_max]
    istat = _rescale_unit(cc, istat_min, istat_max)
    col = istat_transform(istat, parent.ops_plot['colormap'])
    parent.colors['cols'][c] = col
    parent.colors['istat'][c] = istat.flatten()
    rgb_masks(parent, col, c)


def flip_for_class(parent, iscell):
    """Replace ``parent.iscell`` with ``iscell`` and update the masks.

    When fewer than 100 ROIs change label they are flipped one at a time;
    otherwise all masks are rebuilt with ``init_masks``.
    """
    ncells = iscell.size
    # count the ROIs that CHANGE: only a small number is worth flipping singly
    n_changed = (iscell != parent.iscell).sum()
    if n_changed < 100:
        for n in range(ncells):
            if iscell[n] != parent.iscell[n]:
                parent.iscell[n] = iscell[n]
                parent.ichosen = n
                flip_roi(parent)
    else:
        parent.iscell = iscell
        init_masks(parent)


def plot_colorbar(parent):
    """Show the colorbar image and tick labels for the active colouring.

    The random colouring (index 0) gets a black bar.
    """
    bid = parent.ops_plot['color']
    if bid == 0:
        parent.colorbar.setImage(np.zeros((20, 100, 3)))
    else:
        parent.colorbar.setImage(parent.colormat)
    for k in range(3):
        parent.clabel[k].setText('%1.2f' % parent.colors['colorbar'][bid][k])


def plot_masks(parent, M):
    """Display ``M[0]`` and ``M[1]`` in ``parent.color1`` / ``parent.color2``."""
    parent.color1.setImage(M[0], levels=(0., 255.))
    parent.color2.setImage(M[1], levels=(0., 255.))
    parent.color1.show()
    parent.color2.show()


def remove_roi(parent, n, i0):
    """Remove ROI ``n`` from view ``i0``, pulling the next layer up."""
    ypix = parent.stat[n]['ypix']
    xpix = parent.stat[n]['xpix']

    # pixel coordinates where n sits on each of the three overlap layers;
    # all three are located before anything is modified
    layers = parent.rois['iROI'][i0]
    top, mid, bot = [
        np.array((layers[k] == n).nonzero()).astype(np.int32) for k in range(3)
    ]

    # get rid of cell and push up overlaps (weights use 0, indices -1 as empty)
    for key, empty in (('Lam', 0), ('iROI', -1)):
        arr = parent.rois[key]
        arr[i0, 0, top[0], top[1]] = arr[i0, 1, top[0], top[1]]
        arr[i0, 1, top[0], top[1]] = empty
        arr[i0, 1, mid[0], mid[1]] = arr[i0, 2, mid[0], mid[1]]
        arr[i0, 2, mid[0], mid[1]] = empty
        arr[i0, 2, bot[0], bot[1]] = empty

    # a pixel still belongs to an ROI if its top layer holds a valid index;
    # index 0 is a real ROI, the empty sentinel is -1
    parent.rois['Sroi'][i0, ypix, xpix] = parent.rois['iROI'][i0, 0, ypix, xpix] >= 0
    parent.rois['LamNorm'][i0, ypix, xpix] = np.maximum(
        0, np.minimum(1, 0.75 * parent.rois['Lam'][i0, 0, ypix, xpix]
                      / parent.rois['LamMean']))


def add_roi(parent, n, i):
    """Add ROI ``n`` on top of view ``i``, pushing existing layers down."""
    ypix = parent.stat[n]['ypix']
    xpix = parent.stat[n]['xpix']
    lam = parent.stat[n]['lam']
    for key in ('iROI', 'Lam'):
        layers = parent.rois[key]
        layers[i, 2, ypix, xpix] = layers[i, 1, ypix, xpix]
        layers[i, 1, ypix, xpix] = layers[i, 0, ypix, xpix]
    parent.rois['iROI'][i, 0, ypix, xpix] = n
    # NOTE(review): init_masks stores lam / lam.sum(); here the raw weights
    # are stored -- confirm which is intended.
    parent.rois['Lam'][i, 0, ypix, xpix] = lam

    # set whether or not an ROI + weighting of pixels
    parent.rois['Sroi'][i, ypix, xpix] = 1
    parent.rois['LamNorm'][:, ypix, xpix] = np.maximum(
        0, np.minimum(1, 0.75 * parent.rois['Lam'][:, 0, ypix, xpix]
                      / parent.rois['LamMean']))


def redraw_masks(parent, ypix, xpix):
    """Repaint the RGB channels at (ypix, xpix) for every colouring and view."""
    for c in range(parent.colors['cols'].shape[0]):
        col = parent.colors['cols'][c]
        for i in range(2):
            rgb = col[parent.rois['iROI'][i, 0, ypix, xpix], :]
            parent.colors['RGB'][i, c, ypix, xpix, :3] = rgb


def flip_roi(parent):
    """Move ROI ``parent.ichosen`` to the view matching its ``iscell`` value.

    There are 3 levels of overlap, so this may be buggy if more than 3 cells
    are on top of each other.
    """
    n = parent.ichosen
    i = int(1 - parent.iscell[n])
    i0 = 1 - i
    # remove ROI from the view it is currently drawn in
    remove_roi(parent, n, i0)
    # add cell to other side (on top) and push down overlaps
    add_roi(parent, n, i)
    # redraw colors
    ypix = parent.stat[n]['ypix']
    xpix = parent.stat[n]['xpix']
    redraw_masks(parent, ypix, xpix)


def draw_colorbar(colormap='hsv'):
    """Return a (20, 101, 3) uint8 image of ``colormap`` sampled over [0, 1]."""
    H = np.linspace(0, 1, 101).astype(np.float32)
    rgb = istat_transform(H, colormap)
    colormat = np.expand_dims(rgb, axis=0)
    colormat = np.tile(colormat, (20, 1, 1))
    return colormat


def istat_hsv(istat):
    """Map values to (n, 3) uint8 RGB through a restricted hue range.

    The hue is ``1 - (istat / 1.4 + 0.4 / 1.4)``; the input is flattened.
    """
    istat = istat / 1.4
    istat = istat + (0.4 / 1.4)
    icols = 1 - istat
    icols = hsv2rgb(icols.flatten())
    return icols


def istat_transform(istat, colormap='hsv'):
    """Map values to (n, 3) uint8 RGB with the named colormap.

    ``'hsv'`` uses ``istat_hsv``; any other name is looked up in matplotlib.
    An unknown name falls back to hsv after printing a message.  The input is
    flattened, matching ``istat_hsv``.
    """
    if colormap == 'hsv':
        return istat_hsv(istat)
    try:
        # newer matplotlib exposes a registry and no longer has cm.get_cmap
        registry = getattr(matplotlib, 'colormaps', None)
        if registry is not None:
            cmap = registry[colormap]
        else:
            cmap = matplotlib.cm.get_cmap(colormap)
    except (KeyError, ValueError):
        print('bad colormap, using hsv')
        return istat_hsv(istat)
    icols = cmap(np.asarray(istat).flatten())[:, :3]
    icols = icols * 255
    return icols.astype(np.uint8)


### Changes colors of ROIs
# button group is exclusive (at least one color is always chosen)
class ColorButton(QtGui.QPushButton):
    """Checkable button that selects colouring ``bid`` on the parent window."""

    def __init__(self, bid, Text, parent=None):
        super(ColorButton, self).__init__(parent)
        self.setText(Text)
        self.setCheckable(True)
        self.setStyleSheet(parent.styleInactive)
        self.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.resize(self.minimumSizeHint())
        self.clicked.connect(lambda: self.press(parent, bid))
        self.show()

    def press(self, parent, bid):
        """Activate colouring ``bid`` and update the button styles.

        All enabled colour buttons are reset to the unpressed style and this
        one is marked pressed.  Top buttons 1 and 2 are enabled only for
        non-random colourings; when size button 1 is checked, top buttons
        0-2 are all disabled.
        """
        for b in range(len(parent.color_names)):
            other = parent.colorbtns.button(b)
            if other.isEnabled():
                other.setStyleSheet(parent.styleUnpressed)
        self.setStyleSheet(parent.stylePressed)
        parent.ops_plot['color'] = bid

        if parent.sizebtns.button(1).isChecked():
            affected, enable = range(3), False
        else:
            affected, enable = [1, 2], bid != 0
        style = parent.styleUnpressed if enable else parent.styleInactive
        for b in affected:
            parent.topbtns.button(b).setEnabled(enable)
            parent.topbtns.button(b).setStyleSheet(style)

        parent.update_plot()
        parent.show()