"""A GUI for interacting with a trained descriptor network. """ import argparse import os import sys import glob import pickle import torch import torch.nn as nn import numpy as np import matplotlib.cm as cm from form2fit import config from form2fit.code.ml.models import * from form2fit.code.ml.dataloader import get_corr_loader from form2fit.code.utils import ml, misc from PyQt5.QtGui import * from PyQt5.QtWidgets import * from PyQt5.QtCore import * class Debugger(QDialog): """A PyQt5 GUI for debugging a descriptor network. """ USE_CUDA = True WINDOW_WIDTH = 1500 WINDOW_HEIGHT = 1000 WINDOW_TITLE = "Debug Descriptor Network" def __init__(self, args): super().__init__() self._foldername = args.foldername self._dtype = args.dtype self._num_desc = args.num_desc self._background_subtract = args.background_subtract self._augment = args.augment self._num_channels = args.num_channels self._init_loader_and_network() self._reset() self._init_UI() self.show() def _init_loader_and_network(self): """Initializes the data loader and network. """ self._dev = torch.device("cuda" if Debugger.USE_CUDA and torch.cuda.is_available() else "cpu") self._data = get_corr_loader( self._foldername, batch_size=1, sample_ratio=1, dtype=self._dtype, shuffle=False, num_workers=0, augment=self._augment, num_rotations=20, background_subtract=config.BACKGROUND_SUBTRACT[self._foldername], num_channels=self._num_channels, ) self._net = CorrespondenceNet(self._num_channels, self._num_desc, 20).to(self._dev) self._net.eval() stats = self._data.dataset.stats self._color_mean = stats[0][0] self._color_std = stats[0][1] self._resolve_data_dims() def _resolve_data_dims(self): """Reads the image dimensions from the data loader. """ x, _, _ = next(iter(self._data)) self._h, self._w = x.shape[2:] self._c = 3 self._zeros = np.zeros((self._h, self._w, self._c), dtype="uint8") self.xs = None self.xt = None def _reset(self): """Resets the GUI. """ def _he_init(m): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight, mode="fan_in") self._is_switch = False self._pair_idx = 0 self._dloader = iter(self._data) self._get_network_names() self._net.apply(_he_init) def _get_network_names(self): """Reads all saved model weights. """ self.weights_dir = os.path.join(config.weights_dir, "matching") filenames = glob.glob(os.path.join(self.weights_dir, "*.tar")) self._model_names = [os.path.basename(x).split(".")[0] for x in filenames] def _load_selected_network(self, name): """Loads a trained network. """ if name: self._model_name = name state_dict = torch.load(os.path.join(self.weights_dir, name + ".tar"), map_location=self._dev) self._net.load_state_dict(state_dict['model_state']) self._set_prediction_text("{} was loaded...".format(name)) def _init_UI(self): """Initializes the UI. """ self.setWindowTitle(Debugger.WINDOW_TITLE) # self.setFixedSize(Debugger.WINDOW_WIDTH, Debugger.WINDOW_HEIGHT) self._create_menu() self._create_main() self._create_progress() self._all_layout = QVBoxLayout(self) self._all_layout.addLayout(self._menu_layout) self._all_layout.addLayout(self._main_layout) self._all_layout.addLayout(self._progress_layout) def _create_menu(self): """Creates the top horizontal menu bar. """ # buttons next_button = QPushButton("Next Pair", self) next_button.clicked.connect(self._next_click) reset_button = QPushButton("Reset", self) reset_button.clicked.connect(self._reset_click) sample_button = QPushButton("Sample", self) sample_button.clicked.connect(self._sample_click) colorize_button = QPushButton("Rotation Error", self) colorize_button.clicked.connect(self._colorize_click) self._switch_button = QPushButton("View RGB", self) self._switch_button.clicked.connect(self._switch_click) # boxes self._is_correct_box = QLabel(self) self._networks_box = QComboBox(self) self._networks_box.addItems([""] + self._model_names) self._networks_box.activated[str].connect(self._load_selected_network) self._networks_box_label = QLabel("Network Name", self) self._networks_box_label.setBuddy(self._networks_box) # add to layout self._menu_layout = QHBoxLayout() self._menu_layout.addWidget(self._networks_box_label) self._menu_layout.addWidget(self._networks_box) self._menu_layout.addWidget(next_button) self._menu_layout.addWidget(sample_button) self._menu_layout.addWidget(colorize_button) self._menu_layout.addWidget(self._is_correct_box) self._menu_layout.addStretch(1) self._menu_layout.addWidget(self._switch_button) self._menu_layout.addWidget(reset_button) def _create_main(self): """Creates the main layout. """ vbox_left = QVBoxLayout() grid_right = QGridLayout() self._target_widget = QLabel(self) self._source_widget = QLabel(self) self._grid_widgets = [QLabel(self) for _ in range(20)] self._draw_target(init=True) self._draw_source(init=True) vbox_left.addWidget(self._target_widget) vbox_left.addWidget(self._source_widget) self._target_widget.mousePressEvent = self._get_mouse_pos self._draw_rotations(init=True) for col in range(5): for row in range(4): grid_right.addWidget(self._grid_widgets[col * 4 + row], col, row) self._main_layout = QHBoxLayout() self._main_layout.addLayout(vbox_left) self._main_layout.addLayout(grid_right) def _create_progress(self): """A progress bar for the data loader. """ self._progress_bar = QProgressBar(self) self._progress_bar.setRange(0, len(self._dloader)) self._progress_bar.setValue(0) self._progress_layout = QHBoxLayout() self._progress_layout.addWidget(self._progress_bar) self._advance_progress_bar() def _draw_target(self, uv=None, init=False): img_target = self._zeros.copy() if init else self._xt_np.copy() if uv is not None: img_target[uv[0] - 1 : uv[0] + 1, uv[1] - 1 : uv[1] + 1] = [255, 0, 0] self._target_img = QImage( img_target.data, self._w, self._h, self._c * self._w, QImage.Format_RGB888 ) self._target_pixmap = QPixmap.fromImage(self._target_img) self._target_widget.setPixmap(self._target_pixmap) self._target_widget.setScaledContents(True) def _draw_source(self, uvs=None, init=False): if uvs is None: img_source = self._zeros.copy() if init else self._xs_np.copy() else: img_source = self._xt_np.copy() colors = [[0, 255, 0], [0, 0, 255], [255, 0, 0]] color_names = ["green", "blue", "red"] for i in range(3): mask = np.where(uvs[:, 2] == i)[0] idxs = uvs[mask] img_source[idxs[:, 0], idxs[:, 1]] = colors[i] self._source_img = QImage( img_source.data, self._w, self._h, self._c * self._w, QImage.Format_RGB888 ) self._source_pixmap = QPixmap.fromImage(self._source_img) self._source_widget.setPixmap(self._source_pixmap) self._source_widget.setScaledContents(True) def _draw_rotations(self, init=False, heatmap=True): def _hist_eq(img): from skimage import exposure img_cdf, bin_centers = exposure.cumulative_distribution(img) return np.interp(img, bin_centers, img_cdf) for col in range(5): for row in range(4): offset = col * 4 + row if init: img = self._zeros.copy() else: if heatmap: img = self.heatmaps[offset].copy() img = img / img.max() img = _hist_eq(img) img = np.uint8(cm.viridis(img) * 255)[..., :3] img = img.copy() else: img = misc.rotate_img(self._xs_np, -(360 / 20) * offset, center=(self.center[1], self.center[0])) img = img.copy() if offset == self._uv[-1]: img[ self._uv[0] - 1 : self._uv[0] + 1, self._uv[1] - 1 : self._uv[1] + 1, ] = [255, 0, 0] self._add_border_clr(img, [255, 0, 0]) if offset == self.best_rot_idx: self._add_border_clr(img, [0, 255, 0]) self._img = QImage( img.data, self._w, self._h, self._c * self._w, QImage.Format_RGB888 ) pixmap = QPixmap.fromImage(self._img) self._grid_widgets[offset].setPixmap(pixmap) self._grid_widgets[offset].setScaledContents(True) def _switch_click(self): if not self._is_switch: self._switch_button.setText("Heatmap View") self._is_switch = True self._draw_rotations(heatmap=False) else: self._switch_button.setText("RGB View") self._is_switch = False self._draw_rotations(heatmap=True) def _next_click(self): if self._pair_idx == len(self._dloader): self.close() else: self._get_next_data() self._draw_target() self._draw_source() self._draw_rotations(init=True) self._advance_progress_bar() def _reset_click(self): self._reset() self._networks_box.setCurrentIndex(0) self._draw_target(init=True) self._draw_source(init=True) self._draw_rotations(init=True) self._advance_progress_bar() def _colorize_click(self): filename = os.path.join( config.rot_stats_dir, self._model_name, self._dtype, str(self._pair_idx - 1), "rot_color.npy", ) pixel_colors = np.load(filename) self._draw_source(pixel_colors) def _set_prediction_text(self, msg): self._is_correct_box.setText(msg) def _sample_click(self): if self._pair_idx > 0: self._forward_network() rand_idx = np.random.choice(np.arange(len(self.target_pixel_idxs))) u_rand, v_rand = self.target_pixel_idxs[rand_idx] self._draw_target([u_rand, v_rand]) u_s, v_s = self.source_pixel_idxs[rand_idx] target_vector = self.out_t[:, :, u_rand, v_rand] outs_flat = self.outs.view(self.outs.shape[0], self.outs.shape[1], -1) target_vector_flat = target_vector.unsqueeze_(2).repeat( (outs_flat.shape[0], 1, outs_flat.shape[2]) ) diff = outs_flat - target_vector_flat dist = diff.pow(2).sum(1).sqrt() self.heatmaps = dist.view(dist.shape[0], self._h, self._w).cpu().numpy() predicted_best_idx = dist.min(dim=1)[0].argmin() is_correct = predicted_best_idx == self.best_rot_idx msg = "Correct!" if is_correct else "Wrong!" self._set_prediction_text(msg) min_val = self.heatmaps[predicted_best_idx].argmin() u_min, v_min = misc.make2d(min_val, self._w) self._uv = [u_min, v_min, predicted_best_idx] self._draw_rotations(heatmap=not self._is_switch) else: print("[!] You must first click next to load a data sample.") def _get_mouse_pos(self, event): v = event.pos().x() u = event.pos().y() u = int(u * (self._h / self._target_widget.height())) v = int(v * (self._w / self._target_widget.width())) uv = [u, v] if self.xs is not None and self.xt is not None: self._forward_network() row_idx = np.where((self.target_pixel_idxs == uv).all(axis=1))[0] if row_idx.size != 0: row_idx = row_idx[0] self._draw_target(uv) u_s, v_s = self.source_pixel_idxs[row_idx] target_vector = self.out_t[:, :, uv[0], uv[1]] outs_flat = self.outs.view(self.outs.shape[0], self.outs.shape[1], -1) target_vector_flat = target_vector.unsqueeze_(2).repeat( (outs_flat.shape[0], 1, outs_flat.shape[2]) ) diff = outs_flat - target_vector_flat dist = diff.pow(2).sum(1).sqrt() self.heatmaps = dist.view(dist.shape[0], self._h, self._w).cpu().numpy() predicted_best_idx = dist.min(dim=1)[0].argmin() is_correct = predicted_best_idx == self.best_rot_idx msg = "Correct!" if is_correct else "Wrong!" self._set_prediction_text(msg) min_val = self.heatmaps[predicted_best_idx].argmin() u_min, v_min = misc.make2d(min_val, self._w) self._uv = [u_min, v_min, predicted_best_idx] self._draw_rotations(heatmap=not self._is_switch) def _get_next_data(self): """Grabs a fresh pair of source and target data points. """ self._pair_idx += 1 self.imgs, labels, center = next(self._dloader) self.center = center[0] label = labels[0] self.xs, self.xt = self.imgs[:, :self._num_channels, :, :], self.imgs[:, self._num_channels:, :, :] if self._num_channels == 4: self._xs_np = ml.tensor2ndarray(self.xs[:, :3], [self._color_mean * 3, self._color_std * 3]) self._xt_np = ml.tensor2ndarray(self.xt[:, :3], [self._color_mean * 3, self._color_std * 3]) else: self._xs_np = ml.tensor2ndarray(self.xs[:, :1], [self._color_mean, self._color_std], False) self._xt_np = ml.tensor2ndarray(self.xt[:, :1], [self._color_mean, self._color_std], False) self._xs_np = np.uint8(cm.viridis(self._xs_np) * 255)[..., :3] self._xt_np = np.uint8(cm.viridis(self._xt_np) * 255)[..., :3] source_idxs = label[:, 0:2] target_idxs = label[:, 2:4] rot_idx = label[:, 4] is_match = label[:, 5] self.best_rot_idx = rot_idx[0].item() mask = (is_match == 1) & (rot_idx == self.best_rot_idx) self.source_pixel_idxs = source_idxs[mask].numpy() self.target_pixel_idxs = target_idxs[mask].numpy() def _forward_network(self): """Forwards the current source-target pair through the network. """ self.imgs = self.imgs.to(self._dev) with torch.no_grad(): self.outs, self.out_t = self._net(self.imgs, *self.center) self.outs = self.outs[0] def _advance_progress_bar(self): """Advances the progress bar. """ curr_val = self._pair_idx max_val = self._progress_bar.maximum() self._progress_bar.setValue(curr_val + (max_val - curr_val) / 100) def _add_border_clr(self, img, color): """Adds a border color to an image. """ img[0 : self._h - 1, 0:10] = color # left img[0:10, 0 : self._w - 1] = color # top img[self._h - 11 : self._h - 1, 0 : self._w - 1] = color img[0 : self._h - 1, self._w - 11 : self._w - 1] = color return img if __name__ == "__main__": def str2bool(s): return s.lower() in ["1", "true"] parser = argparse.ArgumentParser(description="Descriptor Network Visualizer") parser.add_argument("foldername", type=str) parser.add_argument("--dtype", type=str, default="valid") parser.add_argument("--num_desc", type=int, default=64) parser.add_argument("--num_channels", type=int, default=2) parser.add_argument("--background_subtract", type=tuple, default=None) parser.add_argument("--augment", type=str2bool, default=False) args, unparsed = parser.parse_known_args() app = QApplication(sys.argv) window = Debugger(args) window.show() sys.exit(app.exec_())