# -*- coding: utf-8 -*- """ Mostly deprecated in favor of kwplot """ from __future__ import absolute_import, division, print_function, unicode_literals import cv2 import itertools as it import numpy as np import six import ubelt as ub from os.path import join, dirname def pandas_plot_matrix(df, rot=90, ax=None, grid=True, label=None, zerodiag=False, cmap='viridis', showvals=False, logscale=True): import matplotlib as mpl import copy from matplotlib import pyplot as plt import matplotlib.cm # NOQA import kwplot if ax is None: fig = kwplot.figure(fnum=1, pnum=(1, 1, 1)) fig.clear() ax = plt.gca() ax = plt.gca() values = df.values if zerodiag: values = values.copy() values = values - np.diag(np.diag(values)) # aximg = ax.imshow(values, interpolation='none', cmap='viridis') if logscale: from matplotlib.colors import LogNorm vmin = df[df > 0].min().min() norm = LogNorm(vmin=vmin, vmax=values.max()) else: norm = None cmap = copy.copy(mpl.cm.get_cmap(cmap)) # copy the default cmap cmap.set_bad((0, 0, 0)) aximg = ax.matshow(values, interpolation='none', cmap=cmap, norm=norm) # aximg = ax.imshow(values, interpolation='none', cmap='viridis', norm=norm) # ax.imshow(values, interpolation='none', cmap='viridis') ax.grid(False) cax = plt.colorbar(aximg, ax=ax) if label is not None: cax.set_label(label) ax.set_xticks(list(range(len(df.index)))) ax.set_xticklabels([lbl[0:100] for lbl in df.index]) for lbl in ax.get_xticklabels(): lbl.set_rotation(rot) for lbl in ax.get_xticklabels(): lbl.set_horizontalalignment('center') ax.set_yticks(list(range(len(df.columns)))) ax.set_yticklabels([lbl[0:100] for lbl in df.columns]) for lbl in ax.get_yticklabels(): lbl.set_horizontalalignment('right') for lbl in ax.get_yticklabels(): lbl.set_verticalalignment('center') # Grid lines around the pixels if grid: offset = -.5 xlim = [-.5, len(df.columns)] ylim = [-.5, len(df.index)] segments = [] for x in range(ylim[1]): xdata = [x + offset, x + offset] ydata = ylim segment = list(zip(xdata, ydata)) segments.append(segment) for y in range(xlim[1]): xdata = xlim ydata = [y + offset, y + offset] segment = list(zip(xdata, ydata)) segments.append(segment) bingrid = mpl.collections.LineCollection(segments, color='w', linewidths=1) ax.add_collection(bingrid) if showvals: x_basis = np.arange(len(df.columns)) y_basis = np.arange(len(df.index)) x, y = np.meshgrid(x_basis, y_basis) for c, r in zip(x.flatten(), y.flatten()): val = df.iloc[r, c] ax.text(c, r, val, va='center', ha='center', color='white') return ax def axes_extent(axs, pad=0.0): """ Get the full extent of a group of axes, including axes labels, tick labels, and titles. """ import matplotlib as mpl def axes_parts(ax): yield ax for label in ax.get_xticklabels(): if label.get_text(): yield label for label in ax.get_yticklabels(): if label.get_text(): yield label xlabel = ax.get_xaxis().get_label() ylabel = ax.get_yaxis().get_label() for label in (xlabel, ylabel, ax.title): if label.get_text(): yield label items = it.chain.from_iterable(axes_parts(ax) for ax in axs) extents = [item.get_window_extent() for item in items] #mpl.transforms.Affine2D().scale(1.1) extent = mpl.transforms.Bbox.union(extents) extent = extent.expanded(1.0 + pad, 1.0 + pad) return extent def extract_axes_extents(fig, combine=False, pad=0.0): """ Extracts the extent of each axes item in inches. The main purpose of this is to set `bbox_inches` in `fig.savefig`, such that only the important data is visualized. Args: fig (Figure): the figure combine (bool): if True returns the union of each extent pad (float): additional padding around each axes Returns: matplotlib.transforms.Bbox or list of matplotlib.transforms.Bbox """ # Make sure we draw the axes first so we can # extract positions from the text objects import matplotlib as mpl fig.canvas.draw() # Group axes that belong together atomic_axes = [] seen_ = set([]) for ax in fig.axes: if ax not in seen_: atomic_axes.append([ax]) seen_.add(ax) dpi_scale_trans_inv = fig.dpi_scale_trans.inverted() axes_bboxes_ = [axes_extent(axs, pad) for axs in atomic_axes] axes_extents_ = [extent.transformed(dpi_scale_trans_inv) for extent in axes_bboxes_] # axes_extents_ = axes_bboxes_ if combine: # Grab include extents of figure text as well # FIXME: This might break on OSX # http://stackoverflow.com/questions/22667224/bbox-backend renderer = fig.canvas.get_renderer() for mpl_text in fig.texts: bbox = mpl_text.get_window_extent(renderer=renderer) extent_ = bbox.expanded(1.0 + pad, 1.0 + pad) extent = extent_.transformed(dpi_scale_trans_inv) # extent = extent_ axes_extents_.append(extent) axes_extents = mpl.transforms.Bbox.union(axes_extents_) # if True: # axes_extents.x0 = 0 # # axes_extents.y1 = 0 else: axes_extents = axes_extents_ return axes_extents def adjust_subplots(left=None, right=None, bottom=None, top=None, wspace=None, hspace=None, fig=None): """ Kwargs: left (float): left side of the subplots of the figure right (float): right side of the subplots of the figure bottom (float): bottom of the subplots of the figure top (float): top of the subplots of the figure wspace (float): width reserved for blank space between subplots hspace (float): height reserved for blank space between subplots """ from matplotlib import pyplot as plt kwargs = dict(left=left, right=right, bottom=bottom, top=top, wspace=wspace, hspace=hspace) kwargs = {k: v for k, v in kwargs.items() if v is not None} if fig is None: fig = plt.gcf() subplotpars = fig.subplotpars adjust_dict = subplotpars.__dict__.copy() del adjust_dict['validate'] adjust_dict.update(kwargs) fig.subplots_adjust(**adjust_dict) def render_figure_to_image(fig, dpi=None, transparent=None, **savekw): """ Saves a figure as an image in memory. Args: fig (matplotlib.figure.Figure): figure to save dpi (int or str, Optional): The resolution in dots per inch. If *None* it will default to the value ``savefig.dpi`` in the matplotlibrc file. If 'figure' it will set the dpi to be the value of the figure. transparent (bool): If *True*, the axes patches will all be transparent; the figure patch will also be transparent unless facecolor and/or edgecolor are specified via kwargs. **savekw: other keywords passed to `fig.savefig`. Valid keywords include: facecolor, edgecolor, orientation, papertype, format, pad_inches, frameon. Returns: np.ndarray: an image in BGR or BGRA format. Notes: Be sure to use `fig.set_size_inches` to an appropriate size before calling this function. """ import io import cv2 # import matplotlib as mpl # axes_extents = extract_axes_extents(fig) # extent = mpl.transforms.Bbox.union(axes_extents) extent = 'tight' # mpl might do this correctly these days with io.BytesIO() as stream: # This call takes 23% - 15% of the time depending on settings fig.savefig(stream, bbox_inches=extent, dpi=dpi, transparent=transparent, **savekw) # fig.savefig(stream, **savekw) stream.seek(0) data = np.fromstring(stream.getvalue(), dtype=np.uint8) im_bgra = cv2.imdecode(data, cv2.IMREAD_UNCHANGED) return im_bgra def savefig2(fig, fpath, **kwargs): """ Does a tight layout and saves the figure with transparency DEPRICATE """ import matplotlib as mpl if 'transparent' not in kwargs: kwargs['transparent'] = True if 'extent' not in kwargs: axes_extents = extract_axes_extents(fig) extent = mpl.transforms.Bbox.union(axes_extents) kwargs['extent'] = extent fig.savefig(fpath, **kwargs) def copy_figure_to_clipboard(fig): """ References: https://stackoverflow.com/questions/17676373/python-matplotlib-pyqt-copy-image-to-clipboard """ print('Copying figure %d to the clipboard' % fig.number) import matplotlib as mpl app = mpl.backends.backend_qt5.qApp QtGui = mpl.backends.backend_qt5.QtGui im_bgra = render_figure_to_image(fig, transparent=True) im_rgba = cv2.cvtColor(im_bgra, cv2.COLOR_BGRA2RGBA) im = im_rgba QImage = QtGui.QImage qim = QImage(im.data, im.shape[1], im.shape[0], im.strides[0], QImage.Format_RGBA8888) clipboard = app.clipboard() clipboard.setImage(qim) # size = fig.canvas.size() # width, height = size.width(), size.height() # qim = QtGui.QImage(fig.canvas.buffer_rgba(), width, height, QtGui.QImage.Format_ARGB32) # QtWidgets = mpl.backends.backend_qt5.QtWidgets # pixmap = QtWidgets.QWidget.grab(fig.canvas) # clipboard.setPixmap(pixmap) def _get_axis_xy_width_height(ax=None, xaug=0, yaug=0, waug=0, haug=0): """ gets geometry of a subplot """ from matplotlib import pyplot as plt if ax is None: ax = plt.gca() autoAxis = ax.axis() xy = (autoAxis[0] + xaug, autoAxis[2] + yaug) width = (autoAxis[1] - autoAxis[0]) + waug height = (autoAxis[3] - autoAxis[2]) + haug return xy, width, height _LEGEND_LOCATION = { 'upper right': 1, 'upper left': 2, 'lower left': 3, 'lower right': 4, 'right': 5, 'center left': 6, 'center right': 7, 'lower center': 8, 'upper center': 9, 'center': 10, } _BASE_FNUM = 9001 def _save_requested(fpath_, save_parts): raise NotImplementedError('havent done this yet') # dpi = ub.argval('--dpi', type_=int, default=200) from os.path import expanduser from matplotlib import pyplot as plt dpi = 200 fpath_ = expanduser(fpath_) print('Figure save was requested') # arg_dict = ut.get_arg_dict(prefix_list=['--', '-'], # type_hints={'t': list, 'a': list}) arg_dict = {} # HACK arg_dict = { key: (val[0] if len(val) == 1 else '[' + ']['.join(val) + ']') if isinstance(val, list) else val for key, val in arg_dict.items() } fpath_ = fpath_.format(**arg_dict) fpath_ = fpath_.replace(' ', '').replace('\'', '').replace('"', '') dpath = ub.argval('--dpath', type_=str, default=None) if dpath is None: gotdpath = False dpath = '.' else: gotdpath = True fpath = join(dpath, fpath_) if not gotdpath: dpath = dirname(fpath_) print('dpath = %r' % (dpath,)) fig = plt.gcf() fig.dpi = dpi fpath_strict = ub.expandpath(fpath) CLIP_WHITE = ub.argflag('--clipwhite') from netharn import util if save_parts: # TODO: call save_parts instead, but we still need to do the # special grouping. # Group axes that belong together atomic_axes = [] seen_ = set([]) for ax in fig.axes: div = _get_plotdat(ax, _DF2_DIVIDER_KEY, None) if div is not None: df2_div_axes = _get_plotdat_dict(ax).get('df2_div_axes', []) seen_.add(ax) seen_.update(set(df2_div_axes)) atomic_axes.append([ax] + df2_div_axes) # TODO: pad these a bit else: if ax not in seen_: atomic_axes.append([ax]) seen_.add(ax) hack_axes_group_row = ub.argflag('--grouprows') if hack_axes_group_row: groupid_list = [] for axs in atomic_axes: for ax in axs: groupid = ax.colNum groupid_list.append(groupid) groups = ub.group_items(atomic_axes, groupid_list) new_groups = list(map(ub.flatten, groups.values())) atomic_axes = new_groups #[[(ax.rowNum, ax.colNum) for ax in axs] for axs in atomic_axes] # save all rows of each column subpath_list = save_parts(fig=fig, fpath=fpath_strict, grouped_axes=atomic_axes, dpi=dpi) absfpath_ = subpath_list[-1] if CLIP_WHITE: for subpath in subpath_list: # remove white borders util.clipwhite_ondisk(subpath, subpath) else: savekw = {} # savekw['transparent'] = fpath.endswith('.png') and not noalpha savekw['transparent'] = ub.argflag('--alpha') savekw['dpi'] = dpi savekw['edgecolor'] = 'none' savekw['bbox_inches'] = extract_axes_extents(fig, combine=True) # replaces need for clipwhite absfpath_ = ub.expandpath(fpath) fig.savefig(absfpath_, **savekw) if CLIP_WHITE: # remove white borders fpath_in = fpath_out = absfpath_ util.clipwhite_ondisk(fpath_in, fpath_out) if ub.argflag(('--diskshow', '--ds')): # show what we wrote ub.startfile(absfpath_) def save_parts(fig, fpath, grouped_axes=None, dpi=None): """ FIXME: this works in mpl 2.0.0, but not 2.0.2 Args: fig (?): fpath (str): file path string dpi (None): (default = None) Returns: list: subpaths CommandLine: python -m draw_func2 save_parts Ignore: >>> # DISABLE_DOCTEST >>> import kwplot >>> kwplot.autompl() >>> import matplotlib as mpl >>> import matplotlib.pyplot as plt >>> def testimg(fname): >>> return plt.imread(mpl.cbook.get_sample_data(fname)) >>> fnames = ['grace_hopper.png', 'ada.png'] * 4 >>> fig = plt.figure(1) >>> for c, fname in enumerate(fnames, start=1): >>> ax = fig.add_subplot(3, 4, c) >>> ax.imshow(testimg(fname)) >>> ax.set_title(fname[0:3] + str(c)) >>> ax.set_xticks([]) >>> ax.set_yticks([]) >>> ax = fig.add_subplot(3, 1, 3) >>> ax.plot(np.sin(np.linspace(0, np.pi * 2))) >>> ax.set_xlabel('xlabel') >>> ax.set_ylabel('ylabel') >>> ax.set_title('title') >>> fpath = 'test_save_parts.png' >>> adjust_subplots(fig=fig, wspace=.3, hspace=.3, top=.9) >>> subpaths = save_parts(fig, fpath, dpi=300) >>> fig.savefig(fpath) >>> ub.startfile(subpaths[0]) >>> ub.startfile(fpath) """ if dpi: # Need to set figure dpi before we draw fig.dpi = dpi # We need to draw the figure before calling get_window_extent # (or we can figure out how to set the renderer object) # if getattr(fig.canvas, 'renderer', None) is None: fig.canvas.draw() # Group axes that belong together if grouped_axes is None: grouped_axes = [] for ax in fig.axes: grouped_axes.append([ax]) subpaths = [] _iter = enumerate(grouped_axes, start=0) _iter = ub.ProgIter(list(_iter), label='save subfig') for count, axs in _iter: subpath = ub.augpath(fpath, suffix=chr(count + 65)) extent = axes_extent(axs).transformed(fig.dpi_scale_trans.inverted()) savekw = {} savekw['transparent'] = ub.argflag('--alpha') if dpi is not None: savekw['dpi'] = dpi savekw['edgecolor'] = 'none' fig.savefig(subpath, bbox_inches=extent, **savekw) subpaths.append(subpath) return subpaths _qtensured = False def _current_ipython_session(): """ Returns a reference to the current IPython session, if one is running """ try: __IPYTHON__ except NameError: return None else: import IPython ipython = IPython.get_ipython() # if ipython is None we must have exited ipython at some point return ipython def qtensure(): """ If you are in an IPython session, ensures that your backend is Qt. """ global _qtensured if not _qtensured: ipython = _current_ipython_session() if ipython: import sys if 'PyQt4' in sys.modules: ipython.magic('pylab qt4 --no-import-all') _qtensured = True else: ipython.magic('pylab qt5 --no-import-all') _qtensured = True def aggensure(): """ Ensures that you are in agg mode as long as IPython is not running This might help prevent errors in tmux like: qt.qpa.screen: QXcbConnection: Could not connect to display localhost:10.0 Could not connect to any X display. """ import matplotlib as mpl current_backend = mpl.get_backend() if current_backend != 'agg': ipython = _current_ipython_session() if not ipython: import kwplot kwplot.set_mpl_backend('agg') def colorbar(scalars, colors, custom=False, lbl=None, ticklabels=None, float_format='%.2f', **kwargs): """ adds a color bar next to the axes based on specific scalars Args: scalars (ndarray): colors (ndarray): custom (bool): use custom ticks Kwargs: See plt.colorbar Returns: cb : matplotlib colorbar object Ignore: >>> import kwplot >>> kwplot.autompl() >>> scalars = np.array([-1, -2, 1, 1, 2, 7, 10]) >>> cmap_ = 'plasma' >>> logscale = False >>> custom = True >>> reverse_cmap = True >>> val2_customcolor = { ... -1: UNKNOWN_PURP, ... -2: LIGHT_BLUE, ... } >>> colors = scores_to_color(scalars, cmap_=cmap_, logscale=logscale, reverse_cmap=reverse_cmap, val2_customcolor=val2_customcolor) >>> colorbar(scalars, colors, custom=custom) >>> df2.present() >>> show_if_requested() Ignore: >>> # ENABLE_DOCTEST >>> scalars = np.linspace(0, 1, 100) >>> cmap_ = 'plasma' >>> logscale = False >>> custom = False >>> reverse_cmap = False >>> colors = scores_to_color(scalars, cmap_=cmap_, logscale=logscale, >>> reverse_cmap=reverse_cmap) >>> colors = [lighten_rgb(c, .3) for c in colors] >>> colorbar(scalars, colors, custom=custom) >>> df2.present() >>> show_if_requested() """ import matplotlib as mpl import matplotlib.pyplot as plt import matplotlib.cm # NOQA assert len(scalars) == len(colors), 'scalars and colors must be corresponding' if len(scalars) == 0: return None # Parameters ax = plt.gca() divider = _ensure_divider(ax) cax = divider.append_axes('right', size='5%', pad=0.05) xy, width, height = _get_axis_xy_width_height(ax) #orientation = ['vertical', 'horizontal'][0] TICK_FONTSIZE = 8 # # Create scalar mappable with cmap if custom: # FIXME: clean this code up and change the name custom # to be meaningful. It is more like: display unique colors unique_scalars, unique_idx = np.unique(scalars, return_index=True) unique_colors = np.array(colors)[unique_idx] #max_, min_ = unique_scalars.max(), unique_scalars.min() #extent_ = max_ - min_ #bounds = np.linspace(min_, max_ + 1, extent_ + 2) listed_cmap = mpl.colors.ListedColormap(unique_colors) #norm = mpl.colors.BoundaryNorm(bounds, listed_cmap.N) #sm = mpl.cm.ScalarMappable(cmap=listed_cmap, norm=norm) sm = mpl.cm.ScalarMappable(cmap=listed_cmap) sm.set_array(np.linspace(0, 1, len(unique_scalars) + 1)) else: sorted_scalars = sorted(scalars) listed_cmap = scores_to_cmap(scalars, colors) sm = plt.cm.ScalarMappable(cmap=listed_cmap) sm.set_array(sorted_scalars) # Use mapable object to create the colorbar #COLORBAR_SHRINK = .42 # 1 #COLORBAR_PAD = .01 # 1 #COLORBAR_ASPECT = np.abs(20 * height / (width)) # 1 cb = plt.colorbar(sm, cax=cax, **kwargs) ## Add the colorbar to the correct label #axis = cb.ax.yaxis # if orientation == 'horizontal' else cb.ax.yaxis #position = 'bottom' if orientation == 'horizontal' else 'right' #axis.set_ticks_position(position) # This line alone removes data # axis.set_ticks([0, .5, 1]) if custom: ticks = np.linspace(0, 1, len(unique_scalars) + 1) if len(ticks) < 2: ticks += .5 else: # SO HACKY ticks += (ticks[1] - ticks[0]) / 2 if isinstance(unique_scalars, np.ndarray) and unique_scalars.dtype.kind == 'f': ticklabels = [float_format % scalar for scalar in unique_scalars] else: ticklabels = unique_scalars cb.set_ticks(ticks) # tick locations cb.set_ticklabels(ticklabels) # tick labels elif ticklabels is not None: ticks_ = cb.ax.get_yticks() mx = ticks_.max() mn = ticks_.min() ticks = np.linspace(mn, mx, len(ticklabels)) cb.set_ticks(ticks) # tick locations cb.set_ticklabels(ticklabels) #cb.ax.get_yticks() #cb.set_ticks(ticks) # tick locations #cb.set_ticklabels(ticklabels) # tick labels # _set_plotdat(cb.ax, 'viztype', 'colorbar-%s' % (lbl,)) # _set_plotdat(cb.ax, 'sm', sm) # FIXME: Figure out how to make a maximum number of ticks # and to enforce them to be inside the data bounds cb.ax.tick_params(labelsize=TICK_FONTSIZE) # Sets current axis plt.sca(ax) if lbl is not None: cb.set_label(lbl) return cb _DF2_DIVIDER_KEY = '_df2_divider' def _get_plotdat(ax, key, default=None): """ returns internal property from a matplotlib axis """ _plotdat = _get_plotdat_dict(ax) val = _plotdat.get(key, default) return val def _set_plotdat(ax, key, val): """ sets internal property to a matplotlib axis """ _plotdat = _get_plotdat_dict(ax) _plotdat[key] = val def _del_plotdat(ax, key): """ sets internal property to a matplotlib axis """ _plotdat = _get_plotdat_dict(ax) if key in _plotdat: del _plotdat[key] def _get_plotdat_dict(ax): """ sets internal property to a matplotlib axis """ if '_plotdat' not in ax.__dict__: ax.__dict__['_plotdat'] = {} plotdat_dict = ax.__dict__['_plotdat'] return plotdat_dict def _ensure_divider(ax): """ Returns previously constructed divider or creates one """ from mpl_toolkits.axes_grid1 import make_axes_locatable divider = _get_plotdat(ax, _DF2_DIVIDER_KEY, None) if divider is None: divider = make_axes_locatable(ax) _set_plotdat(ax, _DF2_DIVIDER_KEY, divider) orig_append_axes = divider.append_axes def df2_append_axes(divider, position, size, pad=None, add_to_figure=True, **kwargs): """ override divider add axes to register the divided axes """ div_axes = _get_plotdat(ax, 'df2_div_axes', []) new_ax = orig_append_axes(position, size, pad=pad, add_to_figure=add_to_figure, **kwargs) div_axes.append(new_ax) _set_plotdat(ax, 'df2_div_axes', div_axes) return new_ax new_method = df2_append_axes.__get__(divider, divider.__class__) setattr(divider, 'append_axes', new_method) # ut.inject_func_as_method(divider, df2_append_axes, 'append_axes', allow_override=True) return divider def scores_to_cmap(scores, colors=None, cmap_='hot'): import matplotlib as mpl if colors is None: colors = scores_to_color(scores, cmap_=cmap_) scores = np.array(scores) colors = np.array(colors) sortx = scores.argsort() sorted_colors = colors[sortx] # Make a listed colormap and mappable object listed_cmap = mpl.colors.ListedColormap(sorted_colors) return listed_cmap def scores_to_color(score_list, cmap_='hot', logscale=False, reverse_cmap=False, custom=False, val2_customcolor=None, score_range=None, cmap_range=(.1, .9)): """ Other good colormaps are 'spectral', 'gist_rainbow', 'gist_ncar', 'Set1', 'Set2', 'Accent' # TODO: plasma Args: score_list (list): cmap_ (str): defaults to hot logscale (bool): cmap_range (tuple): restricts to only a portion of the cmap to avoid extremes Returns: <class '_ast.ListComp'> Ignore: >>> ut.exec_funckw(scores_to_color, globals()) >>> score_list = np.array([-1, -2, 1, 1, 2, 10]) >>> # score_list = np.array([0, .1, .11, .12, .13, .8]) >>> # score_list = np.linspace(0, 1, 100) >>> cmap_ = 'plasma' >>> colors = scores_to_color(score_list, cmap_) >>> imgRGB = kwarray.atleast_nd(np.array(colors)[:, 0:3], 3, tofront=True) >>> imgRGB = imgRGB.astype(np.float32) >>> imgBGR = kwimage.convert_colorspace(imgRGB, 'BGR', 'RGB') >>> imshow(imgBGR) >>> show_if_requested() Ignore: >>> score_list = np.array([-1, -2, 1, 1, 2, 10]) >>> cmap_ = 'hot' >>> logscale = False >>> reverse_cmap = True >>> custom = True >>> val2_customcolor = { ... -1: UNKNOWN_PURP, ... -2: LIGHT_BLUE, ... } """ import matplotlib.pyplot as plt assert len(score_list.shape) == 1, 'score must be 1d' if len(score_list) == 0: return [] def apply_logscale(scores): scores = np.array(scores) above_zero = scores >= 0 scores_ = scores.copy() scores_[above_zero] = scores_[above_zero] + 1 scores_[~above_zero] = scores_[~above_zero] - 1 scores_ = np.log2(scores_) return scores_ if logscale: # Hack score_list = apply_logscale(score_list) #if loglogscale #score_list = np.log2(np.log2(score_list + 2) + 1) #if isinstance(cmap_, six.string_types): cmap = plt.get_cmap(cmap_) #else: # cmap = cmap_ if reverse_cmap: cmap = reverse_colormap(cmap) #if custom: # base_colormap = cmap # data = score_list # cmap = customize_colormap(score_list, base_colormap) if score_range is None: min_ = score_list.min() max_ = score_list.max() else: min_ = score_range[0] max_ = score_range[1] if logscale: min_, max_ = apply_logscale([min_, max_]) if cmap_range is None: cmap_scale_min, cmap_scale_max = 0., 1. else: cmap_scale_min, cmap_scale_max = cmap_range extent_ = max_ - min_ if extent_ == 0: colors = [cmap(.5) for fx in range(len(score_list))] else: if False and logscale: # hack def score2_01(score): return np.log2( 1 + cmap_scale_min + cmap_scale_max * (float(score) - min_) / (extent_)) score_list = np.array(score_list) #rank_multiplier = score_list.argsort() / len(score_list) #normscore = np.array(list(map(score2_01, score_list))) * rank_multiplier normscore = np.array(list(map(score2_01, score_list))) colors = list(map(cmap, normscore)) else: def score2_01(score): return cmap_scale_min + cmap_scale_max * (float(score) - min_) / (extent_) colors = [cmap(score2_01(score)) for score in score_list] if val2_customcolor is not None: colors = [ np.array(val2_customcolor.get(score, color)) for color, score in zip(colors, score_list)] return colors def interpolated_colormap(colors, resolution=64, space='lch-ab'): """ Interpolates between colors in `space` to create a smooth listed colormap Args: colors (list or dict): list of colors or color objects and where in the map they should appear. resolution (int): number of discrete items in the colormap space (str): colorspace to interpolate in, using a CIE-LAB space will result in a perceptually uniform interpolation. HSV also works well. References: http://stackoverflow.com/questions/12073306/customize-colorbar-in-matplotlib CommandLine: python -m netharn.util.mplutil interpolated_colormap Example: >>> # DISABLE_DOCTEST >>> import kwplot >>> colors = [ >>> (0.0, kwplot.Color('green')), >>> (0.5, kwplot.Color('gray')), >>> (1.0, kwplot.Color('red')), >>> ] >>> space = 'lab' >>> #resolution = 16 + 1 >>> resolution = 256 + 1 >>> cmap = interpolated_colormap(colors, resolution, space) >>> # xdoc: +REQUIRES(--show) >>> import pylab >>> from matplotlib import pyplot as plt >>> a = np.linspace(0, 1, resolution).reshape(1, -1) >>> pylab.imshow(a, aspect='auto', cmap=cmap, interpolation='nearest') # , origin="lower") >>> plt.grid(False) >>> show_if_requested() """ import colorsys import matplotlib as mpl colors_inputs = colors if isinstance(colors_inputs, dict): colors_inputs = [(f, c) for f, c in sorted(colors_inputs.items())] else: if len(colors_inputs[0]) != 2: fracs = np.linspace(0, 1, len(colors_inputs)) colors_inputs = list(zip(fracs, colors_inputs)) # print('colors_inputs = {!r}'.format(colors_inputs)) import kwplot colors = [kwplot.Color(c) for f, c in colors_inputs] fracs = [f for f, c in colors_inputs] basis = np.linspace(0, 1, resolution) fracs = np.array(fracs) indices = np.searchsorted(fracs, basis) indices = np.maximum(indices, 1) cpool = [] from colormath import color_conversions # FIXME: need to ensure monkeypatch for networkx 2.0 in colormath # color_conversions._conversion_manager = color_conversions.GraphConversionManager() from colormath import color_objects def new_convertor(target_obj): source_obj = color_objects.sRGBColor def to_target(src_tup): src_tup = src_tup[0:3] src_co = source_obj(*src_tup) target_co = color_conversions.convert_color(src_co, target_obj) target_tup = target_co.get_value_tuple() return target_tup def from_target(target_tup): target_co = target_obj(*target_tup) src_co = color_conversions.convert_color(target_co, source_obj) src_tup = src_co.get_value_tuple() return src_tup return to_target, from_target def from_hsv(rgb): return colorsys.rgb_to_hsv(*rgb[0:3]) def to_hsv(hsv): return colorsys.hsv_to_rgb(*hsv[0:3].tolist()) classnames = { # 'AdobeRGBColor', # 'BaseRGBColor', 'cmk': 'CMYColor', 'cmyk': 'CMYKColor', 'hsl': 'HSLColor', 'hsv': 'HSVColor', 'ipt': 'IPTColor', 'lch-ab': 'LCHabColor', 'lch-uv': 'LCHuvColor', 'lab': 'LabColor', 'luv': 'LuvColor', # 'SpectralColor', 'xyz': 'XYZColor', # 'sRGBColor', 'xyy': 'xyYColor' } conversions = {k: new_convertor(getattr(color_objects, v)) for k, v in classnames.items()} from_rgb, to_rgb = conversions['hsv'] from_rgb, to_rgb = conversions['xyz'] from_rgb, to_rgb = conversions['lch-uv'] from_rgb, to_rgb = conversions['lch-ab'] from_rgb, to_rgb = conversions[space] # from_rgb, to_rgb = conversions['lch'] # from_rgb, to_rgb = conversions['lab'] # from_rgb, to_rgb = conversions['lch-uv'] for idx2, b in zip(indices, basis): idx1 = idx2 - 1 f1 = fracs[idx1] f2 = fracs[idx2] c1 = colors[idx1].as01('rgb') c2 = colors[idx2].as01('rgb') # from_rgb, to_rgb = conversions['lch'] h1 = np.array(from_rgb(c1)) h2 = np.array(from_rgb(c2)) alpha = (b - f1) / (f2 - f1) new_h = h1 * (1 - alpha) + h2 * (alpha) new_c = np.clip(to_rgb(new_h), 0, 1) # print('new_c = %r' % (new_c,)) cpool.append(new_c) cpool = np.array(cpool) cmap = mpl.colors.ListedColormap(cpool, 'indexed') return cmap def reverse_colormap(cmap): """ References: http://nbviewer.ipython.org/github/kwinkunks/notebooks/blob/master/Matteo_colourmaps.ipynb """ import matplotlib as mpl if isinstance(cmap, mpl.colors.ListedColormap): return mpl.colors.ListedColormap(cmap.colors[::-1]) else: reverse = [] k = [] for key, channel in six.iteritems(cmap._segmentdata): data = [] for t in channel: data.append((1 - t[0], t[1], t[2])) k.append(key) reverse.append(sorted(data)) cmap_reversed = mpl.colors.LinearSegmentedColormap( cmap.name + '_reversed', dict(zip(k, reverse))) return cmap_reversed def draw_border(ax, color, lw=2, offset=None, adjust=True): 'draws rectangle border around a subplot' if adjust: xy, width, height = _get_axis_xy_width_height(ax, -.7, -.2, 1, .4) else: xy, width, height = _get_axis_xy_width_height(ax) if offset is not None: xoff, yoff = offset xy = [xoff, yoff] height = - height - yoff width = width - xoff import matplotlib as mpl rect = mpl.patches.Rectangle(xy, width, height, lw=lw) rect = ax.add_patch(rect) rect.set_clip_on(False) rect.set_fill(False) import kwplot rect.set_edgecolor(kwplot.Color(color).as01('rgb')) return rect def colorbar_image(domain, cmap='plasma', dpi=96, shape=(200, 20), transparent=False): """ Notes: shape is approximate Ignore: domain = np.linspace(-30, 200) cmap='plasma' dpi = 80 dsize = (20, 200) util.imwrite('foo.png', util.colorbar_image(np.arange(0, 1)), shape=(400, 80)) import plottool as pt pt.qtensure() import matplotlib as mpl mpl.style.use('ggplot') util.imwrite('foo.png', util.colorbar_image(np.linspace(0, 1, 100), dpi=200, shape=(1000, 40), transparent=1)) ub.startfile('foo.png') """ import kwplot plt = kwplot.autoplt() fig = plt.figure(dpi=dpi) w, h = shape[1] / dpi, shape[0] / dpi # w, h = 1, 10 fig.set_size_inches(w, h) ax = fig.add_subplot('111') sm = plt.cm.ScalarMappable(cmap=plt.get_cmap(cmap)) sm.set_array(domain) plt.colorbar(sm, cax=ax) cb_img = render_figure_to_image(fig, dpi=dpi, transparent=transparent) plt.close(fig) return cb_img def make_legend_img(classname_to_rgb, dpi=96, shape=(200, 200), mode='line', transparent=False): """ Makes an image of a categorical legend CommandLine: python -m netharn.util.mplutil make_legend_img Example: >>> # xdoctest: +REQUIRES(module:kwplot) >>> import kwplot >>> classname_to_rgb = { >>> 'blue': kwplot.Color('blue').as01(), >>> 'red': kwplot.Color('red').as01(), >>> } >>> img = make_legend_img(classname_to_rgb) >>> # xdoctest: +REQUIRES(--show) >>> kwplot.autompl() >>> kwplot.imshow(img) >>> kwplot.show_if_requested() """ import kwplot plt = kwplot.autoplt() def append_phantom_legend_label(label, color, type_='line', alpha=1.0, ax=None): if ax is None: ax = plt.gca() _phantom_legend_list = getattr(ax, '_phantom_legend_list', None) if _phantom_legend_list is None: _phantom_legend_list = [] setattr(ax, '_phantom_legend_list', _phantom_legend_list) if type_ == 'line': phantom_actor = plt.Line2D((0, 0), (1, 1), color=color, label=label, alpha=alpha) else: phantom_actor = plt.Circle((0, 0), 1, fc=color, label=label, alpha=alpha) _phantom_legend_list.append(phantom_actor) fig = plt.figure(dpi=dpi) w, h = shape[1] / dpi, shape[0] / dpi fig.set_size_inches(w, h) ax = fig.add_subplot('111') for label, color in classname_to_rgb.items(): append_phantom_legend_label(label, color, type_=mode, ax=ax) _phantom_legend_list = getattr(ax, '_phantom_legend_list', None) if _phantom_legend_list is None: _phantom_legend_list = [] setattr(ax, '_phantom_legend_list', _phantom_legend_list) ax.legend(handles=_phantom_legend_list) ax.grid(False) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) plt.axis('off') legend_img = render_figure_to_image(fig, dpi=dpi, transparent=transparent) plt.close(fig) return legend_img if __name__ == '__main__': r""" CommandLine: python -m netharn.util.mplutil """ import xdoctest xdoctest.doctest_module(__file__)