# Copyright 2017 Fabian Isensee # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pyqtgraph.Qt import QtGui, QtCore import pyqtgraph as pg import sys import numpy as np class ImageViewer2DWidget(QtGui.QWidget): def __init__(self, width=300, height=300): QtGui.QWidget.__init__(self) self.image = None self.lut=None self.init(width, height) def init(self, width, height): self.rect = QtCore.QRect(0, 0, width, height) self.imageItem = pg.ImageItem() self.imageItem.setImage(None) self.graphicsScene = pg.GraphicsScene() self.graphicsScene.addItem(self.imageItem) self.graphicsView = pg.GraphicsView() self.graphicsView.setRenderHint(QtGui.QPainter.Antialiasing) self.graphicsView.setScene(self.graphicsScene) layout = QtGui.QHBoxLayout() layout.addWidget(self.graphicsView) self.setLayout(layout) self.setMaximumSize(width, height) self.setMinimumSize(width-10, height-10) def setImage(self, image): assert len(image.shape) == 2 or len(image.shape) == 3 if len(image.shape) == 4: assert image.shape[-1] == 4 self.image = image self.imageItem.setImage(self.image) self.imageItem.setRect(self.rect) if self.lut is not None: self.imageItem.setLookupTable(self.lut) def setLevels(self, levels, update=True): self.imageItem.setLevels(levels, update) def setLUT(self, lut): self.lut = lut class ImageSlicingWidget(ImageViewer2DWidget): def __init__(self, width=300, height=300): self.slice = 0 ImageViewer2DWidget.__init__(self, width, height) def setImage(self, image): assert len(image.shape) == 3 or len(image.shape) == 4 if len(image.shape) == 4: assert image.shape[-1] == 4 self.image3D = np.array(image) self._updateImageSlice() def _updateImageSlice(self): self.imageItem.setImage(self.image3D[self.slice]) self.imageItem.setRect(self.rect) if self.lut is not None: self.imageItem.setLookupTable(self.lut) def setSlice(self, slice): slice = np.max((0, slice)) slice = np.min((slice, self.image3D.shape[0] - 1)) self.slice = slice self._updateImageSlice() def getSlice(self): return self.slice class BatchViewer(QtGui.QWidget): def __init__(self, parent=None, width=300, height=300): QtGui.QWidget.__init__(self, parent) self.batch = None self.width = width self.height = height self.slicingWidgets = {} self._init_gui() def setBatch(self, batch, lut={}): assert len(batch.shape) == 4 batch = np.copy(batch) for v in self.slicingWidgets.values(): self._my_layout.removeWidget(v) self.slicingWidgets = {} if not isinstance(lut, dict): lut = {i: lut for i in range(self.batch.shape[0])} for b in range(batch.shape[0]): mn = batch[b].min() mx = batch[b].max() batch[b, :, 0, 0] = mn batch[b, :, 0, 1] = mx self.batch = batch num_col = int(np.ceil(np.sqrt(batch.shape[0]))) col = 0 row = 0 for i in range(self.batch.shape[0]): w = ImageSlicingWidget(self.width, self.height) if lut is not None and i in lut.keys(): w.setLUT(lut[i]) w.setImage(self.batch[i]) w.setLevels([self.batch[i].min(), self.batch[i].max()]) w.setSlice(0) self._my_layout.addWidget(w, row, col) col += 1 if col >= num_col: col = 0 row += 1 self.slicingWidgets[i] = w def _init_gui(self): self._my_layout = QtGui.QGridLayout() self.slicingWidgets = {} self.setLayout(self._my_layout) self.setWindowTitle("Batch Viewer") def wheelEvent(self, QWheelEvent): for v in self.slicingWidgets.values(): offset=np.sign(QWheelEvent.angleDelta().y()) v.setSlice(v.getSlice() + np.sign(offset)) def view_batch(*args, width=300, height=300, lut={}): use_these = args if not isinstance(use_these, (np.ndarray, np.memmap)): use_these = list(use_these) for i in range(len(use_these)): item = use_these[i] try: import torch if isinstance(item, torch.Tensor): item = item.detach().cpu().numpy() except ImportError: pass while len(item.shape) < 4: item = item[None] use_these[i] = item use_these = np.concatenate(use_these, 0) else: while len(use_these.shape) < 4: use_these = use_these[None] global app app = QtGui.QApplication.instance() if app is None: app = QtGui.QApplication(sys.argv) sv = BatchViewer(width=width, height=height) sv.setBatch(use_these, lut) sv.show() app.exit(app.exec_()) if __name__ == '__main__': import matplotlib.pyplot as plt global app app = QtGui.QApplication.instance() if app is None: app = QtGui.QApplication(sys.argv) sv = BatchViewer() batch = np.random.uniform(0, 3, (6, 100, 100, 100)).astype(int) lut = {2: np.array([[0, 0.5, 0, 1], [0, 0, 0.5, 1], [0.5, 0, 0, 1], [0.5, 0.5, 0, 1]])*255, 1: np.array([[1, 0.5, 0, 1], [0, 0, 0.5, 1], [0.5, 0, 0, 1], [0.5, 0.5, 0, 1]]) * 255} sv.setBatch(batch, lut) sv.show() app.exec_() app.deleteLater() sys.exit() # IPython.embed() '''class SliceViewer(QtGui.QWidget): def __init__(self): super(SliceViewer, self).__init__() self.initUI() def wheelEvent(self, event): for v in [self.imageViewer, self.imageViewer2]: v.setSlice(v.getSlice() + event.delta()/120) def initUI(self): image1 = np.random.uniform(-0.5, 255., (100, 100, 100)).astype(np.float32) image2 = np.random.uniform(-100., 255., (100, 100, 100)).astype(np.float32) self.imageViewer = ImageSlicingWidget() self.imageViewer2 = ImageSlicingWidget() self.imageViewer.setImage(image1) self.imageViewer2.setImage(image2) hLayout = QtGui.QHBoxLayout() hLayout.addWidget(self.imageViewer) hLayout.addWidget(self.imageViewer2) hLayout.addStretch() self.setLayout(hLayout) #self.setGeometry(0, 0, 1200, 1200) self.setWindowTitle('QtGui.QCheckBox') self.show()'''