#!/usr/bin/env python # -*- coding: utf-8 -*- # snap.py # Copyright (c) 2016-2020, Richard Gerum # # This file is part of Pylustrator. # # Pylustrator is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # Pylustrator is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Pylustrator. If not, see <http://www.gnu.org/licenses/> from typing import List, Tuple, Optional import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np from matplotlib.artist import Artist from matplotlib.axes._subplots import Axes from matplotlib.legend import Legend from matplotlib.lines import Line2D from matplotlib.patches import Rectangle, Ellipse, FancyArrowPatch from matplotlib.text import Text DIR_X0 = 1 DIR_Y0 = 2 DIR_X1 = 4 DIR_Y1 = 8 def checkXLabel(target: Artist): """ checks if the target is the xlabel of an axis """ for axes in target.figure.axes: if axes.xaxis.get_label() == target: return axes def checkYLabel(target: Artist): """ checks if the target is the ylabel of an axis """ for axes in target.figure.axes: if axes.yaxis.get_label() == target: return axes class TargetWrapper(object): """ a wrapper to add unified set and get position methods for any matplotlib artist """ target = None def __init__(self, target: Artist): self.target = target self.figure = target.figure self.do_scale = True self.fixed_aspect = False # a patch uses the data_transform if isinstance(self.target, mpl.patches.Patch): self.get_transform = self.target.get_data_transform # axes use the figure_transform elif isinstance(self.target, Axes): # and optionally have a fixed aspect ratio if self.target.get_aspect() != "auto" and self.target.get_adjustable() != "datalim": self.fixed_aspect = True self.get_transform = lambda: self.target.figure.transFigure # texts use get_transform elif isinstance(self.target, Text): if getattr(self.target, "xy", None) is not None: self.do_scale = True else: self.do_scale = False if checkXLabel(self.target): self.label_factor = self.figure.dpi / 72.0 if getattr(self.target, "pad_offset", None) is None: self.target.pad_offset = self.target.get_position()[1] - checkXLabel( self.target).xaxis.labelpad * self.label_factor self.label_y = self.target.get_position()[1] elif checkYLabel(self.target): self.label_factor = self.figure.dpi / 72.0 if getattr(self.target, "pad_offset", None) is None: self.target.pad_offset = self.target.get_position()[0] - checkYLabel( self.target).yaxis.labelpad * self.label_factor self.label_x = self.target.get_position()[0] self.get_transform = self.target.get_transform # the default is to use get_transform else: self.get_transform = self.target.get_transform self.do_scale = False def get_positions(self) -> (int, int, int, int): """ get the current position of the target Artist """ points = [] if isinstance(self.target, Rectangle): points.append(self.target.get_xy()) p2 = (self.target.get_x() + self.target.get_width(), self.target.get_y() + self.target.get_height()) points.append(p2) elif isinstance(self.target, Ellipse): c = self.target.center w = self.target.width h = self.target.height points.append((c[0] - w / 2, c[1] - h / 2)) points.append((c[0] + w / 2, c[1] + h / 2)) elif isinstance(self.target, FancyArrowPatch): points.append(self.target._posA_posB[0]) points.append(self.target._posA_posB[1]) points.extend(self.target.get_path().vertices) elif isinstance(self.target, Text): points.append(self.target.get_position()) if checkXLabel(self.target): points[0] = (points[0][0], self.label_y) elif checkYLabel(self.target): points[0] = (self.label_x, points[0][1]) if getattr(self.target, "xy", None) is not None: points.append(self.target.xy) bbox = self.target.get_bbox_patch() if bbox: points.append(bbox.get_transform().transform((bbox.get_x(), bbox.get_y()))) points.append( bbox.get_transform().transform((bbox.get_x() + bbox.get_width(), bbox.get_y() + bbox.get_height()))) points[-2:] = self.transform_inverted_points(points[-2:]) elif isinstance(self.target, Axes): p1, p2 = np.array(self.target.get_position()) points.append(p1) points.append(p2) elif isinstance(self.target, Legend): bbox = self.target.get_frame().get_bbox() if isinstance(self.target._get_loc(), int): # if the legend doesn't have a location yet, use the left bottom corner of the bounding box self.target._set_loc(tuple(self.target.axes.transAxes.inverted().transform(tuple([bbox.x0, bbox.y0])))) points.append(self.target.axes.transAxes.transform(self.target._get_loc())) # add points to span bounding box around the frame points.append([bbox.x0, bbox.y0]) points.append([bbox.x1, bbox.y1]) return self.transform_points(points) def set_positions(self, points: (int, int)): """ set the position of the target Artist """ points = self.transform_inverted_points(points) if isinstance(self.target, Rectangle): self.target.set_xy(points[0]) self.target.set_width(points[1][0] - points[0][0]) self.target.set_height(points[1][1] - points[0][1]) if self.target.get_label() is None or not self.target.get_label().startswith("_rect"): self.figure.change_tracker.addChange(self.target, ".set_xy([%f, %f])" % tuple(self.target.get_xy())) self.figure.change_tracker.addChange(self.target, ".set_width(%f)" % self.target.get_width()) self.figure.change_tracker.addChange(self.target, ".set_height(%f)" % self.target.get_height()) elif isinstance(self.target, Ellipse): self.target.center = np.mean(points, axis=0) self.target.width = points[1][0] - points[0][0] self.target.height = points[1][1] - points[0][1] self.figure.change_tracker.addChange(self.target, ".center = (%f, %f)" % tuple(self.target.center)) self.figure.change_tracker.addChange(self.target, ".width = %f" % self.target.width) self.figure.change_tracker.addChange(self.target, ".height = %f" % self.target.height) elif isinstance(self.target, FancyArrowPatch): self.target.set_positions(points[0], points[1]) self.figure.change_tracker.addChange(self.target, ".set_positions(%s, %s)" % (tuple(points[0]), tuple(points[1]))) elif isinstance(self.target, Text): if checkXLabel(self.target): axes = checkXLabel(self.target) axes.xaxis.labelpad = -(points[0][1] - self.target.pad_offset) / self.label_factor self.figure.change_tracker.addChange(axes, ".xaxis.labelpad = %f" % axes.xaxis.labelpad) self.target.set_position(points[0]) self.label_y = points[0][1] elif checkYLabel(self.target): axes = checkYLabel(self.target) axes.yaxis.labelpad = -(points[0][0] - self.target.pad_offset) / self.label_factor self.figure.change_tracker.addChange(axes, ".yaxis.labelpad = %f" % axes.yaxis.labelpad) self.target.set_position(points[0]) self.label_x = points[0][0] else: self.target.set_position(points[0]) self.figure.change_tracker.addChange(self.target, ".set_position([%f, %f])" % self.target.get_position()) if getattr(self.target, "xy", None) is not None: self.target.xy = points[1] self.figure.change_tracker.addChange(self.target, ".xy = (%f, %f)" % tuple(self.target.xy)) elif isinstance(self.target, Legend): point = self.target.axes.transAxes.inverted().transform(self.transform_inverted_points(points)[0]) self.target._loc = tuple(point) self.figure.change_tracker.addChange(self.target, "._set_loc((%f, %f))" % tuple(point)) elif isinstance(self.target, Axes): position = np.array([points[0], points[1] - points[0]]).flatten() if self.fixed_aspect: position[3] = position[2] * self.target.get_position().height / self.target.get_position().width self.target.set_position(position) self.figure.change_tracker.addChange(self.target, ".set_position([%f, %f, %f, %f])" % tuple( np.array([points[0], points[1] - points[0]]).flatten())) def get_extent(self) -> (int, int, int, int): """ get the extend of the target """ points = np.array(self.get_positions()) return [np.min(points[:, 0]), np.min(points[:, 1]), np.max(points[:, 0]), np.max(points[:, 1])] def transform_points(self, points: (int, int)) -> (int, int): """ transform points from the targets local coordinate system to the figure coordinate system """ transform = self.get_transform() return [transform.transform(p) for p in points] def transform_inverted_points(self, points: (int, int)) -> (int, int): """ transform points from the figure coordinate system to the targets local coordinate system """ transform = self.get_transform() return [transform.inverted().transform(p) for p in points] class SnapBase(Line2D): """ The base class to implement snaps. """ def __init__(self, ax_source: Artist, ax_target: Artist, edge: int): # wrap both object with a TargetWrapper self.ax_source = TargetWrapper(ax_source) self.ax_target = TargetWrapper(ax_target) self.edge = edge # initialize a line object for the visualisation of the snap Line2D.__init__(self, [], [], transform=None, clip_on=False, lw=1, zorder=100, linestyle="dashed", color="r", marker="o", ms=1, label="_tmp_snap") plt.gca().add_artist(self) def getPosition(self, target: TargetWrapper): """ get the position of a target """ try: return target.get_extent() except AttributeError: return np.array(target.figure.transFigure.transform(target.get_position())).flatten() def getDistance(self, index: int) -> (int, int): """ Calculate the distance of the snap to its target """ return 0, 0 def checkSnap(self, index: int) -> Optional[float]: """ Return the distance to the targets or None """ distance = self.getDistance(index) if abs(distance) < 10: return distance return None def checkSnapActive(self): """ Test if the snap condition is fullfilled """ distance = min([self.getDistance(index) for index in [0, 1]]) # show the snap if the distance to a target is smaller than 1 if abs(distance) < 1: self.show() else: self.hide() def show(self): """ Implements a visualisation of the snap, e.g. lines to indicate what objects are snapped to what """ pass def hide(self): """ Hides the visualisation """ self.set_data((), ()) def remove(self): """ Remove the snap and its visualisation """ self.hide() try: self.axes.artists.remove(self) except ValueError: pass class SnapSameEdge(SnapBase): """ a snap that checks if two objects share an edge """ def getDistance(self, index: int) -> (int, int): """ Calculate the distance of the snap to its target """ # only if the right edge index (x or y) is queried, if not the distance is infinite if self.edge % 2 != index: return np.inf # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) # and return the difference in the target dimension return p1[self.edge] - p2[self.edge] def show(self): """ A visualisation of the snap, e.g. lines to indicate what objects are snapped to what """ # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) # if the focus edge is x, draw a line along the edge if self.edge % 2 == 0: self.set_data((p1[self.edge], p1[self.edge], p2[self.edge], p2[self.edge]), (p1[self.edge - 1], p1[self.edge + 1], p2[self.edge - 1], p2[self.edge + 1])) # if the focus edge is y else: self.set_data((p1[self.edge - 1], p1[self.edge - 3], p2[self.edge - 1], p2[self.edge - 3]), (p1[self.edge], p1[self.edge], p2[self.edge], p2[self.edge])) class SnapSameDimension(SnapBase): """ a snap that checks if two objects have the same width or height """ def getDistance(self, index: int) -> (int, int): """ Calculate the distance of the snap to its target """ # only if the right edge index (x or y) is queried, if not the distance is infinite if self.edge % 2 != index: return np.inf # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) # and the difference of the widths (or heights) of the objects return (p2[self.edge - 2] - p2[self.edge]) - (p1[self.edge - 2] - p1[self.edge]) def show(self): """ A visualisation of the snap, e.g. lines to indicate what objects are snapped to what """ # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) # if the focus edge is x, draw a line though the center of each object if self.edge % 2 == 0: self.set_data((p1[0], p1[2], np.nan, p2[0], p2[2]), (p1[1] * 0.5 + p1[3] * 0.5, p1[1] * 0.5 + p1[3] * 0.5, np.nan, p2[1] * 0.5 + p2[3] * 0.5, p2[1] * 0.5 + p2[3] * 0.5)) # if the focus edge is y else: self.set_data((p1[0] * 0.5 + p1[2] * 0.5, p1[0] * 0.5 + p1[2] * 0.5, np.nan, p2[0] * 0.5 + p2[2] * 0.5, p2[0] * 0.5 + p2[2] * 0.5), (p1[1], p1[3], np.nan, p2[1], p2[3])) class SnapSamePos(SnapBase): """ a snap that checks if two objects have the same position """ def getPosition(self, text: TargetWrapper) -> (int, int): # get the position of an object return np.array(text.get_transform().transform(text.target.get_position())) def getDistance(self, index: int) -> int: """ Calculate the distance of the snap to its target """ # only if the right edge index (x or y) is queried, if not the distance is infinite if self.edge % 2 != index: return np.inf # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) # get the distance of the two objects in the target dimension return p1[self.edge] - p2[self.edge] def show(self): """ A visualisation of the snap, e.g. lines to indicate what objects are snapped to what """ # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) # draw a line connecting the centers of the objects self.set_data((p1[0], p2[0]), (p1[1], p2[1])) class SnapSameBorder(SnapBase): """ A snap that checks if tree axes share the space between them """ def __init__(self, ax_source: Artist, ax_target: Artist, ax_target2: Artist, edge: int): super().__init__(ax_source, ax_target, edge) self.ax_target2 = ax_target2 def overlap(self, p1: list, p2: list, dir: int): """ Test if two objects have an overlapping x or y region """ if p1[dir + 2] < p2[dir] or p1[dir] > p2[dir + 2]: return False return True def getBorders(self, p1: list, p2: list): borders = [] for edge in [0, 1]: if self.overlap(p1, p2, 1 - edge): if p1[edge + 2] < p2[edge]: dist = p2[edge] - p1[edge + 2] borders.append([edge * 2 + 0, dist]) if p1[edge] > p2[edge + 2]: dist = p1[edge] - p2[edge + 2] borders.append([edge * 2 + 1, dist]) return np.array(borders) def getDistance(self, index: int): """ Calculate the distance of the snap to its target """ # get the positions of all three targets p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) p3 = self.getPosition(self.ax_target2) for edge in [index]: if not (self.edge & DIR_X1) and not (self.edge & DIR_Y1): if p1[edge + 2] < p2[edge]: continue if not (self.edge & DIR_X0) and not (self.edge & DIR_Y0): if p1[edge] > p2[edge + 2]: continue if (p1[edge + 2] < p2[edge] or p1[edge] > p2[edge + 2]) and self.overlap(p1, p2, 1 - edge): distances = np.array([p2[edge] - p1[edge + 2], p1[edge] - p2[edge + 2]]) index1 = np.argmax(distances) distance = distances[index1] borders = self.getBorders(p2, p3) if len(borders): deltas = distance - borders[:, 1] index2 = np.argmin(np.abs(deltas)) self.dir2 = borders[index2, 0] self.dir1 = edge * 2 + index1 return deltas[index2] * (-1 + 2 * index1) return np.inf def getConnection(self, p1: list, p2: list, dir: int): """ return the coordinates of a line that spans the space between to axes """ # check which edge (e.g. x, y) and which direction (e.g. if to change the order of p1 and p2) edge, order = dir // 2, dir % 2 # optionally change p1 with p2 if order == 1: p1, p2 = p2, p1 # if edge is x if edge == 0: y = np.mean([max(p1[1], p2[1]), min(p1[3], p2[3])]) return [[p1[2], p2[0], np.nan], [y, y, np.nan]] # if edge is y x = np.mean([max(p1[0], p2[0]), min(p1[2], p2[2])]) return [[x, x, np.nan], [p1[3], p2[1], np.nan]] def show(self): """ A visualisation of the snap, e.g. lines to indicate what objects are snapped to what """ # get the positions of all three axes p1 = self.getPosition(self.ax_source) p2 = self.getPosition(self.ax_target) p3 = self.getPosition(self.ax_target2) # get the x1, y1 = self.getConnection(p1, p2, self.dir1) x2, y2 = self.getConnection(p2, p3, self.dir2) x1.extend(x2) y1.extend(y2) self.set_data((x1, y1)) class SnapCenterWith(SnapBase): """ A snap that checks if a text is centered with an axes """ def getPosition(self, text: TargetWrapper) -> (int, int): """ get the position of the first object """ return np.array(text.get_transform().transform(text.target.get_position())) def getPosition2(self, axes: TargetWrapper) -> int: """ get the position of the second object """ pos = np.array(axes.figure.transFigure.transform(axes.target.get_position())) p = pos[0, :] p[self.edge] = np.mean(pos, axis=0)[self.edge] return p def getDistance(self, index: int) -> int: """ Calculate the distance of the snap to its target """ # only if the right edge index (x or y) is queried, if not the distance is infinite if self.edge % 2 != index: return np.inf # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition2(self.ax_target) # get the distance of the two objects in the target dimension return p1[self.edge] - p2[self.edge] def show(self): """ A visualisation of the snap, e.g. lines to indicate what objects are snapped to what """ # get the position of both objects p1 = self.getPosition(self.ax_source) p2 = self.getPosition2(self.ax_target) # draw a line connecting the centers of the objects self.set_data((p1[0], p2[0]), (p1[1], p2[1])) def checkSnaps(snaps: List[SnapBase]) -> (int, int): """ get the x and y offsets the snaps suggest """ result = [0, 0] # iterate over x and y for index in range(2): # find the best snap best = np.inf for snap in snaps: delta = snap.checkSnap(index) if delta is not None and abs(delta) < abs(best): best = delta # if there is a snap suggestion, store it if best < np.inf: result[index] = best # return the best suggestion return result def checkSnapsActive(snaps: List[SnapBase]): """ check if snaps are active and show them if yes """ for snap in snaps: snap.checkSnapActive() def getSnaps(targets: List[TargetWrapper], dir: int, no_height=False) -> List[SnapBase]: """ get all snap objects for the target and the direction """ snaps = [] targets = [t.target for t in targets] for target in targets: if isinstance(target, Legend): continue if isinstance(target, Text): if checkXLabel(target): snaps.append(SnapCenterWith(target, checkXLabel(target), 0)) elif checkYLabel(target): snaps.append(SnapCenterWith(target, checkYLabel(target), 1)) for ax in target.figure.axes + [target.figure]: for txt in ax.texts: # for other texts if txt in targets or not txt.get_visible(): continue # snap to the x and the y coordinate x, y = txt.get_transform().transform(txt.get_position()) snaps.append(SnapSamePos(target, txt, 0)) snaps.append(SnapSamePos(target, txt, 1)) continue for index, axes in enumerate(target.figure.axes): if axes not in targets and axes.get_visible(): # axes edged if dir & DIR_X0: snaps.append(SnapSameEdge(target, axes, 0)) if dir & DIR_Y0: snaps.append(SnapSameEdge(target, axes, 1)) if dir & DIR_X1: snaps.append(SnapSameEdge(target, axes, 2)) if dir & DIR_Y1: snaps.append(SnapSameEdge(target, axes, 3)) # snap same dimensions if not no_height: if dir & DIR_X0: snaps.append(SnapSameDimension(target, axes, 0)) if dir & DIR_X1: snaps.append(SnapSameDimension(target, axes, 2)) if dir & DIR_Y0: snaps.append(SnapSameDimension(target, axes, 1)) if dir & DIR_Y1: snaps.append(SnapSameDimension(target, axes, 3)) for axes2 in target.figure.axes: if axes2 != axes and axes2 not in targets and axes2.get_visible(): snaps.append(SnapSameBorder(target, axes, axes2, dir)) return snaps