#!/usr/bin/env python3

"""Tkinter GUI for browsing dataset examples and 3D pose predictions."""

import matplotlib
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D
import tkinter as tk
from tkinter import ttk
import tkinter.font
import argparse
import torch
from functools import lru_cache
import os
import numpy as np

from pose3d_utils.coords import ensure_homogeneous, ensure_cartesian

from margipose.data.get_dataset import get_dataset
from margipose.data.skeleton import absolute_to_root_relative, \
    VNect_Common_Skeleton, apply_rigid_alignment, CanonicalSkeletonDesc
from margipose.utils import plot_skeleton_on_axes3d, plot_skeleton_on_axes, seed_all, init_algorithms
from margipose.models import load_model
from margipose.eval import mpjpe, pck
from margipose.data_specs import DataSpecs, ImageSpecs, JointsSpecs
from margipose.cli import Subcommand


CPU = torch.device('cpu')


def parse_args(argv):
    """Parse command-line arguments.

    Args:
        argv: Full argument vector; the first element (program name) is skipped.

    Returns:
        The parsed ``argparse.Namespace`` with ``model`` and ``dataset`` attributes.
    """
    parser = argparse.ArgumentParser(prog='margipose-gui',
                                     description='3D human pose browser GUI',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', type=str, metavar='FILE',
                        help='path to model file')
    parser.add_argument('--dataset', type=str, metavar='STR', default='mpi3d-test',
                        help='dataset name')

    args = parser.parse_args(argv[1:])

    return args


@lru_cache(maxsize=32)
def load_example(dataset, example_index):
    """Load one dataset example and derive its ground truth skeletons, if any.

    Results are memoised with ``lru_cache``, so ``dataset`` must be hashable and
    the returned dict is shared between callers (do not mutate it).

    Args:
        dataset: Dataset object to read the example from.
        example_index: Index of the example within ``dataset``.

    Returns:
        A dict with the keys ``input``, ``input_image``, ``camera``,
        ``transform_opts`` and ``gt_skel``. ``gt_skel`` is ``None`` when the
        example has no ``'target'`` entry; otherwise it is a dict with the keys
        ``original``, ``image_space`` and ``camera_space``.
    """
    example = dataset[example_index]
    # Named ``input_tensor`` rather than ``input`` to avoid shadowing the builtin.
    input_tensor = example['input']
    input_image = dataset.input_to_pil_image(input_tensor)
    camera = example['camera_intrinsic']
    transform_opts = example['transform_opts']
    gt_skel = None
    if 'target' in example:
        gt_skel = dict(original=example['original_skel'])
        gt_skel_norm = ensure_homogeneous(example['target'], d=3)
        gt_skel_denorm = dataset.denormalise_with_skeleton_height(
            gt_skel_norm, camera, transform_opts)
        gt_skel['image_space'] = camera.project_cartesian(gt_skel_denorm)
        gt_skel['camera_space'] = dataset.untransform_skeleton(gt_skel_denorm, transform_opts)
    return dict(
        input=input_tensor,
        input_image=input_image,
        camera=camera,
        transform_opts=transform_opts,
        gt_skel=gt_skel,
    )


@lru_cache(maxsize=32)
def load_and_process_example(dataset, example_index, device, model):
    """Load an example and, when a model is given, run it to obtain predictions.

    Results are memoised with ``lru_cache``, so every argument must be hashable
    and the returned dict is shared between callers.

    Args:
        dataset: Dataset object to read the example from.
        example_index: Index of the example within ``dataset``.
        device: Torch device the model input is moved to.
        model: Pose model, or ``None`` to skip inference.

    Returns:
        The dict from ``load_example`` when ``model`` is ``None``. Otherwise the
        same entries plus ``pred_skel`` (a dict with the keys ``normalised``,
        ``camera_space`` and ``image_space``) and the ``xy_heatmaps``,
        ``zy_heatmaps`` and ``xz_heatmaps`` lists read from the model after the
        forward pass.
    """
    example = load_example(dataset, example_index)
    if model is None:
        return example
    in_var = example['input'].unsqueeze(0).to(device, torch.float32)
    out_var = model(in_var)
    pred_skel_norm = ensure_homogeneous(out_var.squeeze(0).to(CPU, torch.float64), d=3)
    pred_skel_denorm = dataset.denormalise_with_skeleton_height(
        pred_skel_norm, example['camera'], example['transform_opts'])
    pred_skel_image_space = example['camera'].project_cartesian(pred_skel_denorm)
    pred_skel_camera_space = dataset.untransform_skeleton(
        pred_skel_denorm, example['transform_opts'])
    return dict(
        pred_skel=dict(
            normalised=pred_skel_norm,
            camera_space=pred_skel_camera_space,
            image_space=pred_skel_image_space,
        ),
        # The heatmaps are read from attributes the model exposes after the forward pass.
        xy_heatmaps=[hm.squeeze(0).to(CPU, torch.float32) for hm in model.xy_heatmaps],
        zy_heatmaps=[hm.squeeze(0).to(CPU, torch.float32) for hm in model.zy_heatmaps],
        xz_heatmaps=[hm.squeeze(0).to(CPU, torch.float32) for hm in model.xz_heatmaps],
        **example
    )


def root_relative(skel):
    """Return ``skel`` as 3D cartesian coordinates relative to the canonical root joint."""
    return absolute_to_root_relative(
        ensure_cartesian(skel, d=3),
        CanonicalSkeletonDesc.root_joint_id
    )


class MainGUIApp(tk.Tk):
    """Main window: a global toolbar above a notebook with overview and heatmap tabs."""

    def __init__(self, dataset, device, model):
        """Build the window and display the first example.

        Args:
            dataset: Dataset to browse.
            device: Torch device used for model inference.
            model: Pose model, or ``None`` to browse the dataset without predictions.
        """
        super().__init__()

        self.dataset = dataset
        self.device = device
        self.model = model

        self.wm_title('3D pose estimation')
        self.geometry('1280x800')

        matplotlib.rcParams['savefig.format'] = 'svg'
        matplotlib.rcParams['savefig.directory'] = os.curdir

        # Variables
        self.var_cur_example = tk.StringVar()
        self.var_pred_visible = tk.IntVar(value=0)
        self.var_gt_visible = tk.IntVar(value=1)
        self.var_mpjpe = tk.StringVar(value='??')
        self.var_pck = tk.StringVar(value='??')
        self.var_aligned = tk.IntVar(value=0)
        self.var_joint = tk.StringVar(value='pelvis')

        if self.model is not None:
            self.var_pred_visible.set(1)

        global_toolbar = self._make_global_toolbar(self)
        global_toolbar.pack(side=tk.TOP, fill=tk.X)

        self.notebook = ttk.Notebook(self)
        self.notebook.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True, padx=4, pady=4)

        def on_change_tab(event):
            self.update_current_tab()

        self.notebook.bind('<<NotebookTabChanged>>', on_change_tab)

        # One redraw callback per tab, in the same order the tabs are added.
        self.tab_update_funcs = [
            self._make_overview_tab(self.notebook),
            self._make_heatmap_tab(self.notebook),
        ]

        # The setter loads the example and triggers the first redraw.
        self.current_example_index = 0

    @property
    def current_example_index(self):
        """Index of the example shown, parsed from the spinbox variable."""
        return int(self.var_cur_example.get())

    @current_example_index.setter
    def current_example_index(self, value):
        self.var_cur_example.set(str(value))
        self.on_change_example()

    @property
    def pred_visible(self):
        """Whether the "Show prediction" checkbox is ticked."""
        return self.var_pred_visible.get() != 0

    @property
    def gt_visible(self):
        """Whether ground truth should be drawn (checkbox ticked and data present)."""
        return self.var_gt_visible.get() != 0 and self.current_example['gt_skel'] is not None

    @property
    def is_aligned(self):
        """Whether the "Procrustes alignment" checkbox is ticked."""
        return self.var_aligned.get() != 0

    def update_current_tab(self):
        """Refresh the metric labels and redraw the currently selected tab."""
        cur_tab_index = self.notebook.index('current')

        if self.model is not None:
            gt_skel = self.current_example['gt_skel']
            if gt_skel is not None:
                actual = root_relative(self.current_example['pred_skel']['camera_space'])
                expected = root_relative(gt_skel['original'])
                if self.is_aligned:
                    actual = apply_rigid_alignment(actual, expected)
                included_joints = [
                    CanonicalSkeletonDesc.joint_names.index(joint_name)
                    for joint_name in VNect_Common_Skeleton
                ]
                self.var_mpjpe.set('{:0.4f}'.format(mpjpe(actual, expected, included_joints)))
                self.var_pck.set('{:0.4f}'.format(pck(actual, expected, included_joints)))
            else:
                # Without ground truth the metrics are undefined; do not leave
                # the previous example's values on screen.
                self.var_mpjpe.set('??')
                self.var_pck.set('??')

        self.tab_update_funcs[cur_tab_index]()

    def _make_global_toolbar(self, master):
        """Create the top toolbar (example selector, metrics, alignment toggle).

        Returns:
            The toolbar frame; the caller is responsible for packing it.
        """
        toolbar = tk.Frame(master, bd=1, relief=tk.RAISED)

        def add_label(text):
            # A plain string becomes static text; anything else is treated as a Tk variable.
            opts = dict(text=text) if isinstance(text, str) else dict(textvariable=text)
            label = tk.Label(toolbar, **opts)
            label.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)
            return label

        add_label('Example index:')
        txt_cur_example = tk.Spinbox(
            toolbar, textvariable=self.var_cur_example, command=self.on_change_example,
            wrap=True, from_=0, to=len(self.dataset) - 1, font=tk.font.Font(size=12))

        def on_key_cur_example(event):
            # Typed values only take effect once the user presses Return.
            if event.keysym == 'Return':
                self.on_change_example()

        txt_cur_example.bind('<Key>', on_key_cur_example)
        txt_cur_example.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)

        if self.model is not None:
            add_label('MPJPE:')
            add_label(self.var_mpjpe)
            add_label('PCK@150mm:')
            add_label(self.var_pck)

            chk_aligned = tk.Checkbutton(
                toolbar, text='Procrustes alignment', variable=self.var_aligned,
                command=self.update_current_tab)
            chk_aligned.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)

        return toolbar

    def _make_overview_tab(self, notebook: ttk.Notebook):
        """Add the "Overview" tab (3D skeleton plot next to the input image).

        Returns:
            A zero-argument function which redraws the tab for the current example.
        """
        tab = tk.Frame(notebook)
        notebook.add(tab, text='Overview')

        toolbar = tk.Frame(tab, bd=1, relief=tk.RAISED)
        toolbar.pack(side=tk.TOP, fill=tk.X)

        chk_pred_visible = tk.Checkbutton(
            toolbar, text='Show prediction', variable=self.var_pred_visible,
            command=self.update_current_tab)
        chk_pred_visible.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)
        if self.model is None:
            self.var_pred_visible.set(0)
            chk_pred_visible.configure(state='disabled')

        chk_gt_visible = tk.Checkbutton(
            toolbar, text='Show ground truth', variable=self.var_gt_visible,
            command=self.update_current_tab)
        if getattr(self.dataset, 'subset', None) == 'test':
            self.var_gt_visible.set(0)
            chk_gt_visible.configure(state='disabled')
        chk_gt_visible.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)

        fig = Figure()
        fig.subplots_adjust(0.05, 0.10, 0.95, 0.95, 0.05, 0.05)
        canvas = FigureCanvasTkAgg(fig, tab)
        canvas.draw()
        canvas.get_tk_widget().pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True)

        nav_toolbar = NavigationToolbar2Tk(canvas, tab)
        nav_toolbar.update()
        # Same widget as above, reached through the public accessor instead of
        # the private ``_tkcanvas`` attribute.
        canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        prev_ax1: Axes3D = None

        def update_tab():
            nonlocal prev_ax1

            fig.clf()

            skels = []
            if self.pred_visible:
                skels.append(self.current_example['pred_skel'])
            if self.gt_visible:
                skels.append(self.current_example['gt_skel'])

            ax1: Axes3D = fig.add_subplot(1, 2, 1, projection='3d')
            ax2 = fig.add_subplot(1, 2, 2)

            ax2.imshow(self.current_example['input_image'])

            # ``gt_skel`` is None for examples without a target, so the
            # alignment reference only exists when ground truth is available.
            gt_skel = self.current_example['gt_skel']
            ground_truth = None
            if self.is_aligned and gt_skel is not None:
                ground_truth = root_relative(gt_skel['original'])

            for i, skel in enumerate(skels):
                # Each successive skeleton is drawn fainter than the previous one.
                alpha = 1 / (3 ** i)
                skel3d = root_relative(skel['camera_space'])
                if ground_truth is not None:
                    skel3d = apply_rigid_alignment(skel3d, ground_truth)
                plot_skeleton_on_axes3d(skel3d, CanonicalSkeletonDesc,
                                        ax1, invert=True, alpha=alpha)
                plot_skeleton_on_axes(skel['image_space'], CanonicalSkeletonDesc, ax2, alpha=alpha)

            # Preserve 3D axes view
            if prev_ax1 is not None:
                ax1.view_init(prev_ax1.elev, prev_ax1.azim)
            prev_ax1 = ax1

            canvas.draw()

        return update_tab

    def _make_heatmap_tab(self, notebook: ttk.Notebook):
        """Add the "Heatmaps" tab (per-joint xy/xz/zy heatmaps plus a 3D skeleton).

        The tab is disabled when no model is loaded.

        Returns:
            A zero-argument function which redraws the tab for the current example.
        """
        tab = tk.Frame(notebook)
        tab_index = len(notebook.tabs())
        notebook.add(tab, text='Heatmaps')

        if self.model is None:
            notebook.tab(tab_index, state='disabled')

        toolbar = tk.Frame(tab, bd=1, relief=tk.RAISED)
        toolbar.pack(side=tk.TOP, fill=tk.X)

        joint_names = sorted(self.dataset.skeleton_desc.joint_names)
        # OptionMenu passes the selected value to the callback; it is not needed
        # here because the selection is read back from ``self.var_joint``.
        opt_joint = tk.OptionMenu(
            toolbar, self.var_joint, *joint_names,
            command=lambda _value: self.update_current_tab())
        opt_joint.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)

        var_image_visible = tk.IntVar(value=1)
        chk_image_visible = tk.Checkbutton(
            toolbar, text='Show image overlay', variable=var_image_visible,
            command=self.update_current_tab)
        chk_image_visible.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)

        var_mean_crosshairs = tk.IntVar(value=1)
        chk_mean_crosshairs = tk.Checkbutton(
            toolbar, text='Show mean', variable=var_mean_crosshairs,
            command=self.update_current_tab)
        chk_mean_crosshairs.pack(side=tk.LEFT, fill=tk.Y, padx=2, pady=2)

        fig = Figure()
        canvas = FigureCanvasTkAgg(fig, tab)
        canvas.draw()
        canvas.get_tk_widget().pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True)

        nav_toolbar = NavigationToolbar2Tk(canvas, tab)
        nav_toolbar.update()
        # Same widget as above, reached through the public accessor instead of
        # the private ``_tkcanvas`` attribute.
        canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        prev_ax3d: Axes3D = None

        def update_tab():
            nonlocal prev_ax3d

            fig.clf()

            joint_index = self.dataset.skeleton_desc.joint_names.index(self.var_joint.get())

            cmap = plt.get_cmap('gist_yarg')
            img = self.current_example['input_image']
            # (subplot position in the 2x2 grid, heatmap from the last list entry, axis labels)
            hms = [
                (3, self.current_example['xy_heatmaps'][-1][joint_index], ('x', 'y')),
                (1, self.current_example['xz_heatmaps'][-1][joint_index], ('x', 'z')),
                (4, self.current_example['zy_heatmaps'][-1][joint_index], ('z', 'y')),
            ]

            for subplot_id, hm, (xlabel, ylabel) in hms:
                ax = fig.add_subplot(2, 2, subplot_id)
                ax.set_xlabel(xlabel)
                ax.set_ylabel(ylabel)
                extent = [0, hm.size(-1), hm.size(-2), 0]
                ax.imshow(hm, cmap=cmap, extent=extent)
                # Only the xy heatmap (subplot 3) gets the image overlay.
                if subplot_id == 3 and var_image_visible.get() != 0:
                    ax.imshow(img, extent=extent, alpha=0.5)
                if var_mean_crosshairs.get() != 0:
                    # Crosshairs at the heatmap-weighted mean position along each axis.
                    ax.axvline(
                        np.average(np.arange(0, hm.size(-1)), weights=np.array(hm.sum(-2))),
                        ls='dashed',
                    )
                    ax.axhline(
                        np.average(np.arange(0, hm.size(-2)), weights=np.array(hm.sum(-1))),
                        ls='dashed',
                    )

            size = self.current_example['xy_heatmaps'][-1].size(-1)
            ax3d: Axes3D = fig.add_subplot(2, 2, 2, projection='3d')
            # Map the normalised skeleton onto the [0, size] axes limits set below.
            plot_skeleton_on_axes3d(
                (root_relative(self.current_example['pred_skel']['normalised']) + 1) * 0.5 * size,
                self.dataset.skeleton_desc, ax3d, invert=True)
            ax3d.set_xlim(0, size)
            ax3d.set_ylim(0, size)
            ax3d.set_zlim(size, 0)

            # Preserve 3D axes view
            if prev_ax3d is not None:
                ax3d.view_init(prev_ax3d.elev, prev_ax3d.azim)
            prev_ax3d = ax3d

            canvas.draw()

        return update_tab

    def on_change_example(self):
        """Load (and process, if a model is set) the selected example, then redraw."""
        self.current_example = load_and_process_example(
            self.dataset, self.current_example_index, self.device, self.model)

        self.update_current_tab()


def main(argv, common_opts):
    """Entry point for the ``gui`` subcommand.

    Args:
        argv: Argument vector handed to ``parse_args``.
        common_opts: Mapping of shared options; only ``'device'`` is read here.
    """
    args = parse_args(argv)
    seed_all(12345)
    init_algorithms(deterministic=True)
    torch.set_grad_enabled(False)

    device = common_opts['device']

    if args.model:
        model = load_model(args.model).to(device).eval()
        data_specs = model.data_specs
    else:
        # No model: fall back to explicitly constructed data specs.
        model = None
        data_specs = DataSpecs(
            ImageSpecs(224, mean=ImageSpecs.IMAGENET_MEAN, stddev=ImageSpecs.IMAGENET_STDDEV),
            JointsSpecs(CanonicalSkeletonDesc, n_dims=3),
        )

    dataset = get_dataset(args.dataset, data_specs, use_aug=False)

    app = MainGUIApp(dataset, device, model)
    app.mainloop()


GUI_Subcommand = Subcommand(name='gui', func=main, help='browse examples and predictions')

if __name__ == '__main__':
    GUI_Subcommand.run()