import matplotlib as mpl import matplotlib.pyplot as plt from matplotlib.path import Path from matplotlib.patches import PathPatch import numpy as np try: numeric_types = (int, float, long) except NameError: numeric_types = (int, float) class SimpleVectorPlotter(object): """Plots vector data represented as lists of coordinates.""" # _graphics = {} def __init__(self, interactive, ticks=False, figsize=None, limits=None): """Construct a new SimpleVectorPlotter. interactive - boolean flag denoting interactive mode ticks - boolean flag denoting whether to show axis tickmarks figsize - optional figure size limits - optional geographic limits (x_min, x_max, y_min, y_max) """ # if figsize: # plt.figure(num=1, figsize=figsize) plt.figure(num=1, figsize=figsize) self.interactive = interactive self.ticks = ticks if interactive: plt.ion() else: plt.ioff() if limits is not None: self.set_limits(*limits) if not ticks: self.no_ticks() plt.axis('equal') self._graphics = {} self._init_colors() def adjust_markers(self): figsize = plt.gcf().get_size_inches() r = min(figsize[0] / 8, figsize[1] / 6) mpl.rcParams['lines.markersize'] = 6 * r mpl.rcParams['lines.markeredgewidth'] = 0.5 * r mpl.rcParams['lines.linewidth'] = r mpl.rcParams['patch.linewidth'] = r def adjust_markersize(self, size): figsize = plt.gcf().get_size_inches() r = min(figsize[0] / 8, figsize[1] / 6) return 6 * r def axis_on(self, on): """Turn the axes and labels on or off.""" if on: plt.axis('on') else: plt.axis('off') def clear(self): """Clear the plot area.""" plt.cla() self._graphics = {} if not self.ticks: self.no_ticks() def close(self): """Close the plot.""" self.clear() plt.close() def draw(self): """Draw a non-interactive plot.""" plt.show() def hide(self, name): """Hide the layer with the given name.""" try: graphics = self._graphics[name] graphic_type = type(graphics[0]) if graphic_type is mpl.lines.Line2D: for graphic in graphics: plt.axes().lines.remove(graphic) elif graphic_type is mpl.patches.Polygon or graphic_type is mpl.patches.PathPatch: for graphic in graphics: plt.axes().patches.remove(graphic) else: raise RuntimeError('{} not supported'.format(graphic_type)) except (KeyError, ValueError): pass def plot_line(self, data, symbol='', name='', **kwargs): """Plot a line. data - list of (x, y) tuples symbol - optional pyplot symbol to draw the line with name - optional name to assign to layer so can access it later kwargs - optional pyplot drawing parameters """ graphics = self._plot_line(data, symbol, **kwargs) self._set_graphics(graphics, name, symbol or kwargs) def plot_multiline(self, data, symbol='', name='', **kwargs): """Plot a multiline. data - list of lines, each of which is a list of (x, y) tuples symbol - optional pyplot symbol to draw the lines with name - optional name to assign to layer so can access it later kwargs - optional pyplot drawing parameters """ has_symbol = symbol or kwargs symbol = symbol or self._line_symbol() graphics = self._plot_multiline(data, symbol, **kwargs) self._set_graphics(graphics, name, has_symbol) def plot_multipoint(self, data, symbol='', name='', **kwargs): """Plot a multipoint. data - list of (x, y) tuples symbol - optional pyplot symbol to draw the points with name - optional name to assign to layer so can access it later kwargs - optional pyplot drawing parameters """ has_symbol = symbol or kwargs symbol = symbol or self._point_symbol() graphics = self._plot_multipoint(data, **kwargs) self._set_graphics(graphics, name, has_symbol) def plot_multipolygon(self, data, color='', name='', **kwargs): """Plot a multipolygon. data - list of polygons, each of which is a list of rings, each of which is a list of (x, y) tuples color - optional pyplot color to draw the polygons with name - optional name to assign to layer so can access it later kwargs - optional pyplot drawing parameters """ has_symbol = bool(color or kwargs) if not ('facecolor' in kwargs or 'fc' in kwargs): kwargs['fc'] = color or self._next_color() graphics = self._plot_multipolygon(data, **kwargs) self._set_graphics(graphics, name, has_symbol) def plot_point(self, data, symbol='', name='', **kwargs): """Plot a point. data - (x, y) tuple symbol - optional pyplot symbol to draw the point with name - optional name to assign to layer so can access it later kwargs - optional pyplot drawing parameters """ has_symbol = symbol or kwargs symbol = symbol or self._point_symbol() graphics = self._plot_point(data, symbol, **kwargs) self._set_graphics(graphics, name, has_symbol) def plot_polygon(self, data, color='', name='', **kwargs): """Plot a polygon. data - list of rings, each of which is a list of (x, y) tuples color - optional pyplot color to draw the polygon with name - optional name to assign to layer so can access it later kwargs - optional pyplot drawing parameters """ has_symbol = bool(color or kwargs) if not ('facecolor' in kwargs or 'fc' in kwargs): kwargs['fc'] = color or self._next_color() graphics = self._plot_polygon(data, **kwargs) self._set_graphics(graphics, name, has_symbol) def remove(self, name): """Remove a layer with the given name.""" try: self.hide(name) del self._graphics[name] except KeyError: pass def save(self, fn, dpi=300): plt.savefig(fn, dpi=dpi, bbox_inches='tight', pad_inches=0.02) def set_limits(self, x_min, x_max, y_min, y_max): """Set geographic limits for plotting.""" self.x_lim = x_min, x_max self.y_lim = y_min, y_max self._set_limits() def show(self, name): """Show the layer with the given name.""" try: graphics = self._graphics[name] graphic_type = type(graphics[0]) if graphic_type is mpl.lines.Line2D: for graphic in graphics: plt.axes().add_line(graphic) elif graphic_type is mpl.patches.Polygon or graphic_type is mpl.patches.PathPatch: for graphic in graphics: plt.axes().add_patch(graphic) else: raise RuntimeError('{} not supported'.format(graphic_type)) except KeyError: pass def no_ticks(self): plt.gca().get_xaxis().set_ticks([]) plt.gca().get_yaxis().set_ticks([]) def zoom(self, factor): """Zoom in or out by a percentage; negative is out.""" x_min, x_max, y_min, y_max = plt.axis() x_delta = (x_max - x_min) * factor / 100 y_delta = (y_max - y_min) * factor / 100 plt.axis((x_min + x_delta, x_max - x_delta, y_min + y_delta, y_max - y_delta)) def _clockwise(self, data): """Determine if points are in clockwise order.""" total = 0 x1, y1 = data[0] for x, y in data[1:]: total += (x - x1) * (y + y1) x1, y1 = x, y x, y = data[0] total += (x - x1) * (y + y1) return total > 0 def _codes(self, data): """Get a list of codes for creating a new PathPatch.""" codes = np.ones(len(data), dtype=np.int) * Path.LINETO codes[0] = Path.MOVETO return codes def _init_colors(self): if mpl.__version__ >= '1.5': self.colors = list(mpl.rcParams['axes.prop_cycle']) self.current_color = -1 self._next_color = self._next_color_new else: self._next_color = self._next_color_old def _line_symbol(self): """Get a default line symbol.""" return self._next_color() + '-' def _next_color_old(self): """Get the next color in the rotation.""" return next(plt.gca()._get_lines.color_cycle) def _next_color_new(self): """Get the next color in the rotation.""" self.current_color += 1 if self.current_color >= len(self.colors): self.current_color = 0 return self.colors[self.current_color]['color'] def _order_vertices(self, data, clockwise=True): """Order vertices in clockwise or counter-clockwise order.""" self._clockwise(data) != clockwise or data.reverse() if data[0] != data[-1]: data.append(data[0]) return data def _plot_line(self, data, symbol, **kwargs): x, y = zip(*data) return plt.plot(x, y, symbol, **kwargs) def _plot_multiline(self, data, symbol, **kwargs): """Plot a multiline.""" graphics = [] for line in data: graphics += self._plot_line(line, symbol, **kwargs) return graphics def _plot_multipoint(self, data, symbol, **kwargs): """Plot a multipoint.""" graphics = [] for point in data: graphics += self._plot_point(point, symbol, **kwargs) return graphics def _plot_multipolygon(self, data, **kwargs): """Plot a multipolygon.""" graphics = [] for poly in data: graphics += self._plot_polygon(poly, **kwargs) return graphics def _plot_point(self, data, symbol, **kwargs): """Plot a point.""" return plt.plot(data[0], data[1], symbol, **kwargs) def _plot_polygon(self, data, **kwargs): """Plot a polygon.""" outer = self._order_vertices(data[0], True) inner = [self._order_vertices(d, False) for d in data[1:]] vertices = np.concatenate( [np.asarray(outer)] + [np.asarray(i) for i in inner]) codes = np.concatenate( [self._codes(outer)] + [self._codes(i) for i in inner]) patch = PathPatch(Path(vertices, codes), **kwargs) plt.axes().add_patch(patch) return [patch] def _point_symbol(self): """Get a default point symbol.""" return self._next_color() + 'o' def _same_type(self, graphic1, graphic2): """Determine if two graphics are of the same type.""" if type(graphic1) is not type(graphic2): return False if type(graphic1) is mpl.patches.Polygon: ## huh? return True if len(graphic1.get_xdata()) == len(graphic2.get_xdata()): return True return len(graphic1.get_xdata()) > 1 and len(graphic2.get_xdata()) > 1 def _set_graphics(self, graphics, name, has_symbol): """Add graphics to plot.""" name = name or len(self._graphics) if name in self._graphics: self.hide(name) if not has_symbol and self._same_type(graphics[0], self._graphics[name][0]): styled_graphic = self._graphics[name][0] for graphic in graphics: graphic.update_from(styled_graphic) self._graphics[name] = graphics plt.axis('equal') def _set_limits(self): """Set axis limits.""" plt.xlim(*self.x_lim) plt.ylim(*self.y_lim) plt.axes().set_aspect('equal')