"""Render one or more edge-labelled meshes (OBJ-like files with `e` lines) in 3D."""
import mpl_toolkits.mplot3d as a3
import matplotlib.colors as colors
import pylab as pl
import numpy as np

V = np.array


def r2h(x):
    """Convert an (R, G, B) triple in 0..255 to a hex colour string."""
    return colors.rgb2hex(tuple(y / 255. for y in x))


surface_color = r2h((255, 230, 205))
edge_color = r2h((90, 90, 90))
# Palette for labelled edge groups; group index wraps around modulo its length.
edge_colors = (r2h((15, 167, 175)), r2h((230, 81, 81)), r2h((142, 105, 252)),
               r2h((248, 235, 57)), r2h((51, 159, 255)), r2h((225, 117, 231)),
               r2h((97, 243, 185)), r2h((161, 183, 196)))


def init_plot():
    """Create a 3D axes with panes, spines and ticks hidden.

    Returns:
        ``(ax, limits)`` where ``limits`` is
        ``[xmin, xmax, ymin, ymax, zmin, zmax]`` initialised to +/-inf so the
        first ``update_lim`` call overwrites every entry.
    """
    ax = pl.figure().add_subplot(111, projection='3d')
    transparent = (1.0, 1.0, 1.0, 0.0)
    # Hide panes and spines, see
    # https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/
    # `ax.xaxis` replaces the `ax.w_xaxis` alias removed in Matplotlib 3.8.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(transparent)
        axis.line.set_color(transparent)
    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    return (ax, [np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf])


def update_lim(mesh, plot):
    """Grow the running per-axis [min, max] limits in ``plot[1]`` to cover ``mesh``.

    Args:
        mesh: ``(vs, faces, edges)``; only the vertex array ``vs`` is read.
        plot: ``(ax, limits)`` as returned by ``init_plot``; mutated in place.

    Returns:
        The same ``plot`` object.
    """
    vs = mesh[0]
    limits = plot[1]
    for i in range(3):
        limits[2 * i] = min(limits[2 * i], vs[:, i].min())
        # Compare against the previous *max* slot (the old code used the min
        # slot, discarding the maximum accumulated from earlier meshes).
        limits[2 * i + 1] = max(limits[2 * i + 1], vs[:, i].max())
    return plot


def update_plot(mesh, plot):
    """Create the plot on first use, then extend its limits to include ``mesh``."""
    if plot is None:
        plot = init_plot()
    return update_lim(mesh, plot)


def surfaces(mesh, plot):
    """Draw the mesh faces as a translucent polygon collection.

    Face outlines are drawn only when the mesh has no labelled edge groups;
    otherwise ``segments`` is expected to draw the edges in colour.
    """
    vs, faces, edges = mesh
    vtx = vs[faces]
    edgecolor = 'none' if len(edges) else edge_color
    # Appending '55' to the hex colour adds an alpha channel (~33% opacity).
    tri = a3.art3d.Poly3DCollection(vtx, facecolors=surface_color + '55',
                                    edgecolors=edgecolor, linewidths=.5,
                                    linestyles='dashdot')
    plot[0].add_collection3d(tri)
    return plot


def segments(mesh, plot):
    """Draw each labelled edge group in its own palette colour.

    ``edges[c]`` is an integer array of vertex-index rows; each row is drawn
    as one polyline. Groups are batched into a single collection each.
    """
    vs, _, edges = mesh
    for edge_c, edge_group in enumerate(edges):
        if len(edge_group) == 0:
            # Placeholder groups created for skipped labels have no edges.
            continue
        line = a3.art3d.Line3DCollection(vs[edge_group], linewidths=.5,
                                         linestyles='dashdot')
        line.set_color(edge_colors[edge_c % len(edge_colors)])
        plot[0].add_collection3d(line)
    return plot


def plot_mesh(mesh, *whats, show=True, plot=None):
    """Apply the drawing functions ``whats`` to ``mesh`` on a (possibly new) plot.

    Args:
        mesh: ``(vs, faces, edges)`` as returned by ``parse_obje``.
        *whats: callables ``f(mesh, plot) -> plot`` such as ``surfaces``,
            ``segments``.
        show: if true, set a cubic view box starting at 0 and open the window.
        plot: an existing ``(ax, limits)`` to draw into, or None to create one.

    Returns:
        The ``(ax, limits)`` plot tuple.
    """
    for what in [update_plot] + list(whats):
        plot = what(mesh, plot)
    if show:
        limits = plot[1]
        # Use one extent for all axes so the mesh is not distorted.
        li = max(limits[1], limits[3], limits[5])
        plot[0].auto_scale_xyz([0, li], [0, li], [0, li])
        pl.tight_layout()
        pl.show()
    return plot


def _add_to_edges(edges, edge_c, edge_v):
    """Append ``edge_v`` to group ``edge_c``, padding ``edges`` with empty groups."""
    while len(edges) <= edge_c:
        edges.append([])
    edges[edge_c].append(edge_v)


def _normalize_vertices(vs, scale_by):
    """Swap y/z, shift each axis to start at 0 and divide by ``scale_by``.

    When ``scale_by`` is falsy, the largest per-axis extent is used instead.
    A zero scale (degenerate mesh) leaves the coordinates unscaled rather than
    producing NaNs.

    Returns:
        ``(vs, scale_by)`` with ``vs`` modified in place.
    """
    vs[:, [1, 2]] = vs[:, [2, 1]]
    max_range = 0
    for i in range(3):
        min_value = vs[:, i].min()
        max_value = vs[:, i].max()
        max_range = max(max_range, max_value - min_value)
        vs[:, i] -= min_value
    if not scale_by:
        scale_by = max_range
    if scale_by:
        vs /= scale_by
    return vs, scale_by


def parse_obje(obj_file, scale_by):
    """Parse an OBJ-like file whose ``e`` lines carry labelled edges.

    Recognised records:
        ``v x y z ...``  vertex; only the first three coordinates are kept.
        ``f a b c ...``  face of 1-based vertex indices (``a/t/n`` tokens OK).
        ``e a b ... c``  edge through 1-based vertices ``a b ...`` in group
                         ``c``; lines with fewer than 3 values are ignored.

    Args:
        obj_file: path to the file.
        scale_by: divisor for the normalised coordinates; if falsy, the
            mesh's largest axis extent is used (and returned) so later meshes
            can share the same scale.

    Returns:
        ``((vs, faces, edges), scale_by)``, where ``edges[c]`` holds the
        0-based vertex indices of group ``c``.
    """
    vs = []
    faces = []
    edges = []
    with open(obj_file) as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            tag = tokens[0]
            if tag == 'v':
                vs.append([float(v) for v in tokens[1:4]])
            elif tag == 'f':
                # Keep only the vertex index of `v/vt/vn` style tokens.
                faces.append([int(c.split('/')[0]) - 1 for c in tokens[1:]])
            elif tag == 'e' and len(tokens) >= 4:
                edge_v = [int(c) - 1 for c in tokens[1:-1]]
                edge_c = int(tokens[-1])
                _add_to_edges(edges, edge_c, edge_v)
    vs, scale_by = _normalize_vertices(V(vs, dtype=float), scale_by)
    faces = V(faces, dtype=int)
    edges = [V(c, dtype=int) for c in edges]
    return (vs, faces, edges), scale_by


def view_meshes(*files, offset=.2):
    """Show the given mesh files side by side along x in one 3D figure.

    All meshes share the scale computed from the first file. The window is
    opened once, after the last mesh has been added.

    Returns:
        The ``(ax, limits)`` plot tuple, or None when no files are given.
    """
    plot = None
    max_x = 0
    scale = 0
    last = len(files) - 1
    for idx, file in enumerate(files):
        mesh, scale = parse_obje(file, scale)
        max_x_current = mesh[0][:, 0].max()
        mesh[0][:, 0] += max_x + offset
        # Compare by position, not by name: a path repeated in the list must
        # not trigger an early show().
        plot = plot_mesh(mesh, surfaces, segments, plot=plot, show=idx == last)
        max_x += max_x_current + offset
    return plot


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser("view meshes")
    parser.add_argument('--files', nargs='+', default=['checkpoints/human_seg/meshes/shrec__14_0.obj',
                                                       'checkpoints/human_seg/meshes/shrec__14_3.obj'], type=str,
                        help="list of 1 or more .obj files")
    args = parser.parse_args()

    # view meshes
    view_meshes(*args.files)