""" A more convenient interface to matplotlib. Note ---- It appears that matplotlib's savefig has a bug in certain versions. Update your matplotlib to the latest version and this problem will go away. """ from __future__ import absolute_import from __future__ import division import os import subprocess import sys from os.path import join import numpy as np import matplotlib as mpl; mpl.use('Agg') # For compatibility on cluster import matplotlib.pyplot as plt import mpl_toolkits.mplot3d.art3d as art3d from matplotlib.colors import colorConverter from matplotlib.mlab import PCA from mpl_toolkits.mplot3d import Axes3D from scipy.special import cbrt from .utils import mkdir_p THIS = "pycog.figtools" #========================================================================================= # Font, LaTeX #========================================================================================= mpl.rcParams['font.family'] = 'sans-serif' mpl.rcParams['ps.useafm'] = True mpl.rcParams['pdf.use14corefonts'] = True # Setup LaTeX if available try: FNULL = open(os.devnull, 'w') subprocess.check_call('latex --version', shell=True, stdout=FNULL, stderr=subprocess.STDOUT) except subprocess.CalledProcessError: latex = False print("[ {} ] Warning: Couldn't find LaTeX. Your figures will look ugly." 
.format(THIS)) else: latex = True mpl.rcParams['text.usetex'] = True mpl.rcParams['text.latex.preamble'] = ( '\usepackage{sfmath}' '\usepackage[T1]{fontenc}' '\usepackage{amsmath}' '\usepackage{amssymb}' ) #========================================================================================= # Global defaults #========================================================================================= mpl.rcParams['xtick.direction'] = 'out' mpl.rcParams['ytick.direction'] = 'out' #========================================================================================= # Simple color map #========================================================================================= def gradient(cmin, cmax): if isinstance(cmin, str): cmin = colorConverter.to_rgb(cmin) if isinstance(cmax, str): cmax = colorConverter.to_rgb(cmax) cdict = { 'red': [(0, 0, cmin[0]), (1, cmax[0], 1)], 'green': [(0, 0, cmin[1]), (1, cmax[1], 1)], 'blue': [(0, 0, cmin[2]), (1, cmax[2], 1)] } return mpl.colors.LinearSegmentedColormap('cmap', cdict, N=1000) #========================================================================================= # Colors #========================================================================================= def apply_alpha(color, alpha=0.7): fg = np.asarray(colorConverter.to_rgb(color)) bg = np.ones(3) return tuple(alpha*fg + (1-alpha)*bg) colors = { 'aquamarine': '#4c968e', 'gold': '#b18955', 'strongblue': '#2171b5', 'strongred': '#cb181d', 'blue': apply_alpha('#447294'), 'green': apply_alpha('#30701e'), 'red': apply_alpha('#bf2121'), 'salmon': '#ee9572', 'lightred': '#bf6d6b', 'darkred': '#8f2a2a', 'lightblue': '#8fbcdb', 'orange': apply_alpha('#e58c2c'), 'magenta': apply_alpha('#c42d95'), 'purple': apply_alpha('#8064a2'), 'lightgreen': '#78cd71', 'darkblue': '#084594',#'#315d7d', 'gray': '0.5', 'darkgray': '0.3', 'lightgray': '0.7', 'lightlightgray': '0.9', 'black': '#000000', 'white': '#ffffff' } 
#=========================================================================================
# Subplot
#=========================================================================================

class Subplot(object):
    """
    Interface to Axes. You can access any non-overridden Axes attribute directly
    through Subplot.

    Attribute lookup falls back to the wrapped Axes: `plot.name` resolves to
    `ax.name` if it exists, otherwise to `ax.set_name` (so `plot.xlim(0, 1)`
    calls `ax.set_xlim(0, 1)`). Anything else raises NotImplementedError.

    """
    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If `ax` itself is
        # missing (e.g. an instance created without __init__, as copy/pickle
        # do), evaluating `self.ax` below would re-enter here forever.
        if name == 'ax':
            raise AttributeError(name)
        if hasattr(self.ax, name):
            return getattr(self.ax, name)
        if hasattr(self.ax, 'set_'+name):
            return getattr(self.ax, 'set_'+name)
        raise NotImplementedError("Subplot." + name)

    def __init__(self, fig, p, rect):
        """
        fig  : matplotlib figure to add the axes to.
        p    : dict of formatting parameters (see Figure.defaults).
        rect : [left, bottom, width, height], as for Figure.add_axes.

        """
        self.p    = p
        self.ax   = fig.add_axes(rect)
        self.axes = [self.xaxis, self.yaxis]

    def set_thickness(self, thickness):
        """Set the line width of all spines and the width of all ticks."""
        for v in self.ax.spines.values():
            v.set_linewidth(thickness)
        for ax in self.axes:
            ax.set_tick_params(width=thickness)

    def set_tick_params(self, ticksize, ticklabelsize, ticklabelpad):
        """Set tick length, tick-label font size, and tick-label padding on both axes."""
        for ax in self.axes:
            ax.set_tick_params(size=ticksize, labelsize=ticklabelsize, pad=ticklabelpad)

    def format(self, style, p=None):
        """
        Apply a predefined style.

        style : 'bottomleft' hides the top and right spines, keeps ticks on the
                bottom/left axes only, and applies the thickness and tick
                settings from `p`. 'none' hides all spines and removes all
                ticks. Other values do nothing.
        p     : parameter dict; defaults to this subplot's own parameters.

        """
        if p is None:
            p = self.p

        if style == 'bottomleft':
            for s in ['top', 'right']:
                self.spines[s].set_visible(False)
            self.xaxis.tick_bottom()
            self.yaxis.tick_left()
            self.set_thickness(p['thickness'])
            self.set_tick_params(p['ticksize'], p['ticklabelsize'], p['ticklabelpad'])
        elif style == 'none':
            for s in self.spines.values():
                s.set_visible(False)
            self.xticks()
            self.yticks()

    def axis_off(self, axis='left'):
        """Hide the spine named `axis`; for 'bottom'/'left' also remove its ticks and labels."""
        self.spines[axis].set_visible(False)
        if axis == 'bottom':
            self.xticks()
            self.xticklabels()
        elif axis == 'left':
            self.yticks()
            self.yticklabels()

    #/////////////////////////////////////////////////////////////////////////////////////

    def plot(self, *args, **kwargs):
        """Axes.plot, unclipped and drawn on top (zorder 10) by default."""
        kwargs.setdefault('clip_on', False)
        kwargs.setdefault('zorder', 10)

        return self.ax.plot(*args, **kwargs)

    def xlabel(self, *args, **kwargs):
        """Set the x label using this subplot's label size and padding by default."""
        kwargs.setdefault('fontsize', self.p['axislabelsize'])
        kwargs.setdefault('labelpad', self.p['labelpadx'])

        return self.set_xlabel(*args, **kwargs)

    def ylabel(self, *args, **kwargs):
        """Set the y label using this subplot's label size and padding by default."""
        kwargs.setdefault('fontsize', self.p['axislabelsize'])
        kwargs.setdefault('labelpad', self.p['labelpady'])

        return self.set_ylabel(*args, **kwargs)

    def xticks(self, *args, **kwargs):
        """Set x tick positions; with no arguments, remove all x ticks."""
        if len(args) == 0:
            args = [],
        return self.set_xticks(*args, **kwargs)

    def yticks(self, *args, **kwargs):
        """Set y tick positions; with no arguments, remove all y ticks."""
        if len(args) == 0:
            args = [],
        return self.set_yticks(*args, **kwargs)

    def xticklabels(self, *args, **kwargs):
        """Set x tick labels; with no arguments, clear them."""
        if len(args) == 0:
            args = [],
        return self.set_xticklabels(*args, **kwargs)

    def yticklabels(self, *args, **kwargs):
        """Set y tick labels; with no arguments, clear them."""
        if len(args) == 0:
            args = [],
        return self.set_yticklabels(*args, **kwargs)

    def equal(self):
        """Use an equal aspect ratio."""
        return self.set_aspect('equal')

    #/////////////////////////////////////////////////////////////////////////////////////
    # Annotation

    def legend(self, *args, **kwargs):
        """Axes.legend in axes coordinates, without frame, one marker per entry by default."""
        kwargs.setdefault('bbox_transform', self.transAxes)
        kwargs.setdefault('frameon', False)
        kwargs.setdefault('numpoints', 1)

        return self.ax.legend(*args, **kwargs)

    # The text_* helpers place text at fixed positions in axes coordinates,
    # shifted by (dx, dy).

    def text_upper_center(self, s, dx=0, dy=0, fontsize=7.5, color='k', **kwargs):
        return self.text(0.5+dx, 1+dy, s, ha='center', va='bottom',
                         fontsize=fontsize, color=color, transform=self.transAxes,
                         **kwargs)

    def text_upper_left(self, s, dx=0, dy=0, fontsize=7.5, color='k', **kwargs):
        return self.text(0.04+dx, 0.97+dy, s, ha='left', va='top',
                         fontsize=fontsize, color=color, transform=self.transAxes,
                         **kwargs)

    def text_upper_right(self, s, dx=0, dy=0, fontsize=7.5, color='k', **kwargs):
        return self.text(0.97+dx, 0.97+dy, s, ha='right', va='top',
                         fontsize=fontsize, color=color, transform=self.transAxes,
                         **kwargs)

    def text_lower_center(self, s, dx=0, dy=0, fontsize=7.5, color='k', **kwargs):
        return self.text(0.5+dx, dy, s, ha='center', va='bottom',
                         fontsize=fontsize, color=color, transform=self.transAxes,
                         **kwargs)

    def text_lower_left(self, s, dx=0, dy=0, fontsize=7.5, color='k', **kwargs):
        return self.text(0.04+dx, 0.03+dy, s, ha='left', va='bottom',
                         fontsize=fontsize, color=color, transform=self.transAxes,
                         **kwargs)

    def text_lower_right(self, s, dx=0, dy=0, fontsize=7.5, color='k', **kwargs):
        return self.text(0.97+dx, 0.03+dy, s, ha='right', va='bottom',
                         fontsize=fontsize, color=color, transform=self.transAxes,
                         **kwargs)

    #/////////////////////////////////////////////////////////////////////////////////////

    def hline(self, y, **kwargs):
        """Draw a horizontal line at `y` spanning the current x limits."""
        kwargs.setdefault('zorder', 0)
        kwargs.setdefault('color', '0.2')

        return self.plot(self.get_xlim(), 2*[y], **kwargs)

    def vline(self, x, **kwargs):
        """Draw a vertical line at `x` spanning the current y limits."""
        kwargs.setdefault('zorder', 0)
        kwargs.setdefault('color', '0.2')

        return self.plot(2*[x], self.get_ylim(), **kwargs)

    def circle(self, center, r, **kwargs):
        """Add a circle patch of radius `r` at `center` (data coordinates)."""
        kwargs.setdefault('zorder', 0)

        circle = mpl.patches.Circle(center, r, **kwargs)
        return self.add_patch(circle)

    def highlight(self, x1, x2, **kwargs):
        """Shade the vertical band x1 <= x <= x2 without changing the axis limits."""
        xmin, xmax = self.get_xlim()
        ymin, ymax = self.get_ylim()

        kwargs.setdefault('facecolor', '0.9')
        kwargs.setdefault('edgecolor', 'none')
        kwargs.setdefault('alpha', 0.7)
        kwargs.setdefault('zorder', 1)

        fill = self.fill_between([x1, x2], ymin*np.ones(2), ymax*np.ones(2), **kwargs)

        # Restore axis limits (fill_between may have changed them)
        self.xlim(xmin, xmax)
        self.ylim(ymin, ymax)

        return fill

    #/////////////////////////////////////////////////////////////////////////////////////

    def lim(self, axis, data, lower=None, upper=None, margin=0.05, relative=True):
        """
        Automatically set axis limits from the range of `data`, plus a margin.

        axis         : 'x' or 'y'.
        data         : sequence of numbers, or a sequence of sequences (which is
                       flattened one level).
        lower, upper : if given, override the computed lower/upper limit.
        margin       : padding added on both sides; a fraction of the data
                       range if `relative`, otherwise an absolute amount.

        Returns the (lower, upper) limits that were set.
        Raises ValueError if `data` is empty or `axis` is invalid.

        """
        data = list(data)

        # Flatten one level. A flat sequence of scalars raises TypeError here
        # (scalars are not iterable) and is used as-is.
        try:
            data = [item for sublist in data for item in sublist]
        except TypeError:
            pass
        if not data:
            raise ValueError("Cannot set limits from empty data.")

        # Data bounds
        amin = min(data)
        amax = max(data)

        # Add margin
        if relative:
            da = margin*(amax - amin)
        else:
            da = margin

        # Adjust bounds
        amin -= da
        amax += da

        # Fixed bounds
        if lower is not None:
            amin = lower
        if upper is not None:
            amax = upper

        # Set
        alim = amin, amax
        if axis == 'x':
            self.xlim(*alim)
        elif axis == 'y':
            self.ylim(*alim)
        else:
            raise ValueError("Invalid axis.")

        return alim

    #/////////////////////////////////////////////////////////////////////////////////////

    @staticmethod
    def sturges(data):
        """
        Sturges' rule for the number of bins in a histogram: ceil(log2(n) + 1)
        for n data points.

        """
        return int(np.ceil(np.log2(len(data)) + 1))

    @staticmethod
    def scott(data, ddof=0):
        """
        Scott's rule for the number of bins in a histogram.

        The bin width is h = 3.5*std/n^(1/3), and the number of bins is the data
        range divided by h (at least 1). Falls back to Sturges' rule when the
        data have zero standard deviation.

        """
        std = np.std(data, ddof=ddof)
        if std == 0:
            return Subplot.sturges(data)

        h = 3.5*std/cbrt(len(data))
        return max(1, int((np.max(data) - np.min(data))/h))

    def hist(self, data, bins=None, get_binedges=False, lw=0, **kwargs):
        """
        Plot a histogram of `data` and return the bin values.

        bins         : passed to Axes.hist; if None, the number of bins is
                       chosen by Sturges' rule for fewer than 200 points and by
                       Scott's rule otherwise.
        get_binedges : if True, return (values, bin_edges) instead of values.
        lw           : line width of the histogram patches.

        Remaining keyword arguments go to Axes.hist. By default the histogram
        is normalized, filled ('stepfilled'), and drawn in the 'blue' color.

        """
        # NOTE(review): 'normed' is deprecated in newer matplotlib in favor of
        # 'density' -- confirm against the installed version.
        defaults = {
            'color':    Figure.colors('blue'),
            'normed':   True,
            'rwidth':   1,
            'histtype': 'stepfilled'
            }

        # Fill parameters
        for k in defaults:
            kwargs.setdefault(k, defaults[k])

        # Determine number of bins
        if bins is None:
            if len(data) < 200:
                bins = Subplot.sturges(data)
            else:
                bins = Subplot.scott(data)

        # Plot histogram
        pdf, binedges, patches = self.ax.hist(data, bins, **kwargs)

        # Modify appearance
        plt.setp(patches, 'facecolor', kwargs['color'], 'linewidth', lw)

        if get_binedges:
            return pdf, binedges
        return pdf

#=========================================================================================
# Subplot (3D)
#=========================================================================================

class Subplot3D(object):
    """
    Interface to a 3D Axes, with the same attribute fallback as Subplot:
    `plot3d.name` resolves to `ax.name`, then `ax.set_name`.

    """
    def __getattr__(self, name):
        # Guard against infinite recursion when `ax` has not been set.
        if name == 'ax':
            raise AttributeError(name)
        if hasattr(self.ax, name):
            return getattr(self.ax, name)
        if hasattr(self.ax, 'set_'+name):
            return getattr(self.ax, 'set_'+name)
        raise NotImplementedError("Subplot3D." + name)

    def __init__(self, fig, p, rect):
        """
        fig  : matplotlib figure to add the 3D axes to.
        p    : dict of formatting parameters (see Figure.defaults).
        rect : [left, bottom, width, height], as for Figure.add_axes.

        """
        self.p  = p
        self.ax = fig.add_axes(rect, projection='3d')

    #/////////////////////////////////////////////////////////////////////////////////////

    def xlabel(self, *args, **kwargs):
        """Set the x label using this subplot's label size and padding by default."""
        kwargs.setdefault('fontsize', self.p['axislabelsize'])
        kwargs.setdefault('labelpad', self.p['labelpadx'])

        return self.set_xlabel(*args, **kwargs)

    def ylabel(self, *args, **kwargs):
        """Set the y label using this subplot's label size and padding by default."""
        kwargs.setdefault('fontsize', self.p['axislabelsize'])
        kwargs.setdefault('labelpad', self.p['labelpady'])

        return self.set_ylabel(*args, **kwargs)

    def zlabel(self, *args, **kwargs):
        """Set the z label using this subplot's label size and padding by default."""
        kwargs.setdefault('fontsize', self.p['axislabelsize'])
        kwargs.setdefault('labelpad', self.p['labelpadz'])

        return self.set_zlabel(*args, **kwargs)

    def xticks(self, *args, **kwargs):
        """Set x tick positions; with no arguments, remove all x ticks."""
        if len(args) == 0:
            args = [],
        return self.set_xticks(*args, **kwargs)

    def yticks(self, *args, **kwargs):
        """Set y tick positions; with no arguments, remove all y ticks."""
        if len(args) == 0:
            args = [],
        return self.set_yticks(*args, **kwargs)

    def zticks(self, *args, **kwargs):
        """Set z tick positions; with no arguments, remove all z ticks."""
        if len(args) == 0:
            args = [],
        return self.set_zticks(*args, **kwargs)

    #/////////////////////////////////////////////////////////////////////////////////////

    def zcircle(self, center, r, z0, **kwargs):
        """
        Draw a circle of radius `r` at `center` in the plane z = z0.

        Returns the patch (previously nothing was returned).

        """
        circle = mpl.patches.Circle(center, r, **kwargs)
        patch  = self.add_patch(circle)
        art3d.pathpatch_2d_to_3d(patch, z=z0)

        return patch

#=========================================================================================
# Figure
#=========================================================================================

class Figure(object):
    """
    Wrapper around a matplotlib figure that creates formatted subplots and
    saves the result. Unknown attributes are looked up on the wrapped figure.

    """
    # Default figure parameters; any of them can be overridden in __init__ and,
    # per subplot, in add().
    defaults = {
        'w':             6.5,
        'h':             5.5,
        'rect':          [0.12, 0.12, 0.8, 0.8],
        'thickness':     0.7,
        'ticksize':      3.5,
        'axislabelsize': 8.5,
        'ticklabelsize': 7,
        'ticklabelpad':  2.5,
        'labelpadx':     7.5,
        'labelpady':     8.5,
        'labelpadz':     7.5,
        'format':        'pdf'
        }

    @staticmethod
    def colors(name):
        """Return the module-level `colors` entry for `name`, or `name` unchanged."""
        if name in colors:
            return colors[name]
        return name

    #/////////////////////////////////////////////////////////////////////////////////////

    def __getattr__(self, name):
        # Guard against infinite recursion when `fig` has not been set.
        if name == 'fig':
            raise AttributeError(name)
        if hasattr(self.fig, name):
            return getattr(self.fig, name)
        raise NotImplementedError("Figure." + name)

    def __init__(self, **kwargs):
        """Create the figure; keyword arguments override Figure.defaults."""
        self.p = kwargs.copy()
        for k in Figure.defaults:
            self.p.setdefault(k, Figure.defaults[k])

        self.fig   = plt.figure(figsize=(self.p['w'], self.p['h']))
        self.plots = []

    #/////////////////////////////////////////////////////////////////////////////////////

    def add(self, rect=None, style='bottomleft', projection=None, **kwargs):
        """
        Add a subplot.

        rect       : [left, bottom, width, height]; defaults to the figure's
                     'rect' parameter.
        style      : passed to Subplot.format (skipped if None; not used for 3D).
        projection : '3d' for a Subplot3D, which is returned without being
                     formatted or recorded in self.plots.
        kwargs     : override figure parameters for this subplot only; keys
                     that are not figure parameters are ignored.

        """
        if rect is None:
            rect = self.p['rect']

        # Override figure defaults for this subplot
        p = self.p.copy()
        for k, v in kwargs.items():
            if k in p:
                p[k] = v

        # 3D subplot
        if projection == '3d':
            return Subplot3D(self.fig, p, rect)

        plot = Subplot(self.fig, p, rect)
        if style is not None:
            plot.format(style, p)

        # List of plots in this figure
        self.plots.append(plot)

        return plot

    #/////////////////////////////////////////////////////////////////////////////////////

    def plotlabels(self, labels, **kwargs):
        """
        Place text labels at positions given in figure coordinates.

        labels : dict mapping label text to an (x, y) position in figure
                 coordinates. The text is attached to the first subplot, so at
                 least one (2D) subplot must have been added.

        """
        plot = self.plots[0]
        for label, (x, y) in labels.items():
            plot.text(x, y, label, ha='left', va='bottom',
                      transform=self.transFigure, **kwargs)

    #/////////////////////////////////////////////////////////////////////////////////////

    def shared_lim(self, plots, axis, data, **kwargs):
        """
        Make the axis scale the same in all the plots.

        plots  : subplots whose limits are set.
        axis   : 'x' or 'y'.
        data   : array-like, or a sequence of array-likes which are concatenated.
        kwargs : passed to Subplot.lim.

        Returns the (lower, upper) limits that were set, or None if `plots` is
        empty.

        """
        # Concatenate a sequence of arrays; a flat sequence of scalars cannot
        # be concatenated and is used as-is.
        try:
            data = np.concatenate(data)
        except (ValueError, TypeError):
            pass
        data = np.ravel(data)

        lim = None
        for plot in plots:
            lim = plot.lim(axis, data, **kwargs)

        return lim

    #/////////////////////////////////////////////////////////////////////////////////////

    def save(self, name=None, path=None, fmt=None, transparent=True, **kwargs):
        """
        Save the figure to <path>/<name>.<fmt> and print the file name.

        name        : file name without extension; defaults to the running
                      script's file name (sys.argv[0]) without its extension.
        path        : output directory, passed to mkdir_p (presumably created
                      if missing); defaults to work/figs next to the running
                      script.
        fmt         : file extension; defaults to the figure's 'format'
                      parameter.
        transparent : passed to savefig, as are any other keyword arguments.

        """
        if path is None:
            path = os.path.dirname(os.path.realpath(sys.argv[0]))
            path = os.path.join(path, 'work', 'figs')
        mkdir_p(path)

        if name is None:
            name = os.path.splitext(os.path.basename(sys.argv[0]))[0]

        # Previously `fmt` was accepted but ignored.
        if fmt is None:
            fmt = self.p['format']

        fname = join(path, name + '.' + fmt)
        plt.figure(self.fig.number)
        plt.savefig(fname, transparent=transparent, **kwargs)
        print("[ {}.Figure.save ] ".format(THIS) + fname)

    def close(self):
        """Close the underlying matplotlib figure."""
        plt.close(self.fig)