Python matplotlib.gridspec.GridSpec() Examples

The following are 30 code examples of matplotlib.gridspec.GridSpec(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.gridspec, or try the search function.
Example #1
Source File: cem.py    From visual_dynamics with MIT License 6 votes vote down vote up
def visualization_init(self):
    """Create the CEM training figure with a returns plot and a learning plot.

    Returns:
        tuple: ``(fig, return_plotter, learning_plotter)``; the two
        ``LossPlotter`` instances occupy the left and right cells of a
        1x2 grid.
    """
    fig = plt.figure(figsize=(12, 6), frameon=False, tight_layout=True)
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4 and
    # removed in 3.6. The old canvas method simply forwarded to the figure
    # manager when one existed, so call the manager directly.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(self.servoing_pol.predictor.name)
    gs = gridspec.GridSpec(1, 2)
    plt.show(block=False)

    return_plotter = LossPlotter(fig, gs[0],
                                 format_dicts=[dict(linewidth=2)] * 2,
                                 labels=['mean returns / 10', 'mean discounted returns'],
                                 ylabel='returns')
    # Integer-labelled major ticks and minor ticks at every unit on the x axis.
    return_plotter._ax.xaxis.set_major_locator(MultipleLocator(1))
    return_plotter._ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    return_plotter._ax.xaxis.set_minor_locator(MultipleLocator(1))

    learning_plotter = LossPlotter(fig, gs[1], format_dicts=[dict(linewidth=2)] * 2, ylabel='mean evaluation values')
    return fig, return_plotter, learning_plotter
Example #2
Source File: plotting.py    From sonata with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_raster_cmp(spike_trains, population=None, time_window=None, node_ids=None, ts_units=None, show_plot=True,
                    save_as=None, with_labels=False):
    """Overlay the raster plots of several spike-train sets on one axes.

    Args:
        spike_trains: spike-train source(s), normalised by ``_get_spike_trains``.
        population: population name passed to ``to_dataframe``.
        time_window: ``(start, stop)``; derived with ``_find_time_window`` when falsy.
        node_ids: optional node-id filter passed to ``to_dataframe``.
        ts_units: accepted for API compatibility; not used by this function.
        show_plot: call ``plt.show()`` at the end.
        save_as: if given, path passed to ``plt.savefig``.
        with_labels: passed to ``_build_labels_lu`` to build legend labels.
    """
    spike_trains_l = _get_spike_trains(spike_trains)
    time_window = time_window or _find_time_window(spike_trains_l, population)
    labels = _build_labels_lu(with_labels, spike_trains)

    gs = gridspec.GridSpec(1, 1)
    # Create the axes once. Re-issuing plt.subplot with the same spec on every
    # iteration triggers reuse/deprecation warnings on some Matplotlib versions.
    ax1 = plt.subplot(gs[0])
    for i, spikes in enumerate(spike_trains_l):
        spikes_df = spikes.to_dataframe(populations=population, time_window=time_window, node_ids=node_ids)
        ax1.scatter(spikes_df['timestamps'], spikes_df['node_ids'], lw=0, s=5, label=labels[i])

    # Legend and limits only need to be set once, after all sets are drawn.
    ax1.legend(loc=1, prop={'size': 10})
    ax1.set_xlim([time_window[0], time_window[1]])

    if save_as is not None:
        plt.savefig(save_as)

    if show_plot:
        plt.show()
Example #3
Source File: rbm_binary_cd.py    From generative-models with The Unlicense 6 votes vote down vote up
def plot(samples, size, name):
    """Save square grayscale samples as a 4x4 image grid.

    Each sample is reshaped to ``(size, size)`` and drawn without axes
    decoration. The figure is written to ``out/<name>.png`` and then closed.
    More than 16 samples raise ``IndexError`` from the grid lookup.
    """
    side = int(size)
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for position, sample_img in enumerate(samples):
        axis = plt.subplot(grid[position])
        plt.axis('off')
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.set_aspect('equal')
        plt.imshow(sample_img.reshape(side, side), cmap='Greys_r')

    out_path = 'out/{}.png'.format(name)
    plt.savefig(out_path, bbox_inches='tight')
    plt.close(fig)
Example #4
Source File: rbm_binary_pcd.py    From generative-models with The Unlicense 6 votes vote down vote up
def plot(samples, size, name):
    """Render samples on a tight 4x4 grid and save them to ``out/<name>.png``.

    Every sample is reshaped to a ``size`` x ``size`` grayscale image. The
    figure is closed after saving.
    """
    edge = int(size)
    figure = plt.figure(figsize=(4, 4))
    layout = gridspec.GridSpec(4, 4)
    layout.update(wspace=0.05, hspace=0.05)

    for slot, pixels in enumerate(samples):
        cell_ax = plt.subplot(layout[slot])
        plt.axis('off')
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        image = pixels.reshape(edge, edge)
        plt.imshow(image, cmap='Greys_r')

    plt.savefig('out/{}.png'.format(name), bbox_inches='tight')
    plt.close(figure)
Example #5
Source File: _utils.py    From scanpy with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def make_grid_spec(
    ax_or_figsize: Union[Tuple[int, int], _AxesSubplot],
    nrows: int,
    ncols: int,
    wspace: Optional[float] = None,
    hspace: Optional[float] = None,
    width_ratios: Optional[Sequence[float]] = None,
    height_ratios: Optional[Sequence[float]] = None,
) -> Tuple[Figure, gridspec.GridSpecBase]:
    """Return a figure together with an ``nrows`` x ``ncols`` grid spec.

    If *ax_or_figsize* is a tuple, a new figure of that size is created and a
    top-level ``GridSpec`` is returned. Otherwise it is treated as an axes:
    the axes is hidden (no axis, frame or ticks) and a sub-grid is created
    inside its subplot spec. The spacing/ratio arguments are forwarded
    unchanged in both cases.
    """
    spacing = dict(
        wspace=wspace,
        hspace=hspace,
        width_ratios=width_ratios,
        height_ratios=height_ratios,
    )
    if not isinstance(ax_or_figsize, tuple):
        host = ax_or_figsize
        # Hide the host axes entirely; only the nested grid should be visible.
        host.axis('off')
        host.set_frame_on(False)
        host.set_xticks([])
        host.set_yticks([])
        nested = host.get_subplotspec().subgridspec(nrows, ncols, **spacing)
        return host.figure, nested
    new_fig = pl.figure(figsize=ax_or_figsize)
    return new_fig, gridspec.GridSpec(nrows, ncols, **spacing)
Example #6
Source File: scatterplots.py    From scanpy with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _panel_grid(hspace, wspace, ncols, num_panels):
    """Create a figure and a GridSpec for ``num_panels`` panels, ``ncols`` per row.

    The figure size scales with the panel count, using
    ``rcParams['figure.figsize']`` per panel (width padded by ``wspace``).
    Margins shrink proportionally as the number of rows/columns grows.
    """
    from matplotlib import gridspec

    cols = min(ncols, num_panels)
    rows = np.ceil(num_panels / cols).astype(int)
    size = (
        cols * rcParams['figure.figsize'][0] * (1 + wspace),
        rows * rcParams['figure.figsize'][1],
    )
    fig = pl.figure(figsize=size)

    margin_x = 0.2 / cols
    margin_y = 0.13 / rows
    grid = gridspec.GridSpec(
        nrows=rows,
        ncols=cols,
        left=margin_x,
        right=1 - (cols - 1) * margin_x - 0.01 / cols,
        bottom=margin_y,
        top=1 - (rows - 1) * margin_y - 0.1 / rows,
        hspace=hspace,
        wspace=wspace,
    )
    return fig, grid
Example #7
Source File: layouts.py    From MDT with GNU Lesser General Public License v3.0 6 votes vote down vote up
def get_gridspec(self, figure, nmr_plots):
    """Build a grid layout specifier from the configured rows/columns.

    If neither ``self.rows`` nor ``self.cols`` is set, layout is delegated to
    ``AutoGridLayout``. A missing dimension is derived from *nmr_plots*. If
    the resulting grid has fewer cells than plots, columns are added.
    """
    rows = self.rows
    cols = self.cols

    if rows is None and cols is None:
        return AutoGridLayout(spacings=self.spacings).get_gridspec(figure, nmr_plots)

    if rows is None:
        rows = int(np.ceil(nmr_plots / cols))
    if cols is None:
        cols = int(np.ceil(nmr_plots / rows))

    if rows * cols < nmr_plots:
        cols = int(np.ceil(nmr_plots / rows))

    # With nmr_plots == 0 a derived dimension becomes 0, which GridSpec
    # rejects; always keep at least one row and one column.
    rows = max(rows, 1)
    cols = max(cols, 1)

    return GridLayoutSpecifier(GridSpec(rows, cols, **self.spacings), figure)
Example #8
Source File: helmholtz.py    From generative-models with The Unlicense 6 votes vote down vote up
def plot(samples, size, name):
    """Draw samples as ``size`` x ``size`` grayscale tiles on a 4x4 grid.

    The grid image is saved to ``out/<name>.png`` (tight bounding box), and
    the figure is then closed.
    """
    dim = int(size)
    canvas = plt.figure(figsize=(4, 4))
    tiles = gridspec.GridSpec(4, 4)
    tiles.update(wspace=0.05, hspace=0.05)

    for n, vec in enumerate(samples):
        tile_ax = plt.subplot(tiles[n])
        plt.axis('off')
        tile_ax.set_xticklabels([])
        tile_ax.set_yticklabels([])
        tile_ax.set_aspect('equal')
        plt.imshow(vec.reshape(dim, dim), cmap='Greys_r')

    target = 'out/{}.png'.format(name)
    plt.savefig(target, bbox_inches='tight')
    plt.close(canvas)
Example #9
Source File: plotting_utils.py    From QUANTAXIS with MIT License 5 votes vote down vote up
def __init__(self, rows, cols):
    """Create a figure holding a ``rows`` x ``cols`` grid.

    The figure is 14 inches wide and 7 inches tall per row. The fill cursor
    (``curr_row``, ``curr_col``) starts at the top-left cell.
    """
    self.rows, self.cols = rows, cols
    self.fig = plt.figure(figsize=(14, rows * 7))
    self.gs = gridspec.GridSpec(rows, cols, wspace=0.4, hspace=0.3)
    # Position of the next cell to be filled.
    self.curr_row = self.curr_col = 0
Example #10
Source File: fqi.py    From visual_dynamics with MIT License 5 votes vote down vote up
def visualization_init(self):
    """Create the FQI training figure with a returns plot and a Bellman-error plot.

    Returns:
        tuple: ``(fig, return_plotter, learning_plotter)``; the two
        ``LossPlotter`` instances occupy the left and right cells of a
        1x2 grid.
    """
    fig = plt.figure(figsize=(12, 6), frameon=False, tight_layout=True)
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4 and
    # removed in 3.6. The old canvas method simply forwarded to the figure
    # manager when one existed, so call the manager directly.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(self.servoing_pol.predictor.name)
    gs = gridspec.GridSpec(1, 2)
    plt.show(block=False)

    return_plotter = LossPlotter(fig, gs[0],
                                 format_dicts=[dict(linewidth=2)] * 2,
                                 labels=['mean returns / 10', 'mean discounted returns'],
                                 ylabel='returns')
    # Integer-labelled major ticks and minor ticks at every unit on the x axis.
    return_plotter._ax.xaxis.set_major_locator(MultipleLocator(1))
    return_plotter._ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    return_plotter._ax.xaxis.set_minor_locator(MultipleLocator(1))

    learning_plotter = LossPlotter(fig, gs[1],
                                   format_strings=['', 'r--'],
                                   format_dicts=[dict(linewidth=2)] * 2,
                                   ylabel='Bellman errors', yscale='log')
    # Same integer-labelled major ticks, with minor ticks every 0.2.
    learning_plotter._ax.xaxis.set_major_locator(MultipleLocator(1))
    learning_plotter._ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    learning_plotter._ax.xaxis.set_minor_locator(MultipleLocator(0.2))
    return fig, return_plotter, learning_plotter
Example #11
Source File: avb_tensorflow.py    From generative-models with The Unlicense 5 votes vote down vote up
def plot(samples):
    """Draw 28x28 grayscale samples on a 4x4 grid and return the figure.

    The figure is returned open; the caller decides whether to save or close it.
    """
    fig = plt.figure(figsize=(4, 4))
    cells = gridspec.GridSpec(4, 4)
    cells.update(wspace=0.05, hspace=0.05)

    for k, flat in enumerate(samples):
        panel = plt.subplot(cells[k])
        plt.axis('off')
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        panel.set_aspect('equal')
        plt.imshow(flat.reshape(28, 28), cmap='Greys_r')

    return fig
Example #12
Source File: nupic_anomaly_output.py    From ecg-htm with GNU Affero General Public License v3.0 5 votes vote down vote up
def __init__(self, *args, **kwargs):
    """Initialise plotting state and build the two-panel figure.

    Creates a 16x10 inch figure split 3:1 vertically. The top panel
    (``self._mainGraph``) is titled with ``self.name`` and labelled
    Value/Date. The bottom panel (``self._anomalyGraph``) is labelled
    Percentage/Date.
    """
    super(NuPICPlotOutput, self).__init__(*args, **kwargs)
    # Turn matplotlib interactive mode on.
    plt.ion()
    # Data series and line handles, presumably filled in by other methods
    # as records arrive (NOTE(review): confirm against the writer methods).
    self.dates = []
    self.convertedDates = []
    self.value = []
    self.rawValue = []
    self.allValues = []
    self.allRawValues = []
    self.predicted = []
    self.anomalyScore = []
    self.anomalyLikelihood = []
    self.actualLine = None
    self.rawLine = None
    self.predictedLine = None
    self.anomalyScoreLine = None
    self.anomalyLikelihoodLine = None
    self.linesInitialized = False
    self._chartHighlights = []
    fig = plt.figure(figsize=(16, 10))
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])

    # Label each axes explicitly instead of relying on pyplot's "current axes".
    self._mainGraph = fig.add_subplot(gs[0, 0])
    self._mainGraph.set_title(self.name)
    self._mainGraph.set_ylabel('Value')
    self._mainGraph.set_xlabel('Date')

    self._anomalyGraph = fig.add_subplot(gs[1, 0])
    self._anomalyGraph.set_ylabel('Percentage')
    self._anomalyGraph.set_xlabel('Date')

    # Resize the window to 800x600 pixels (this does not maximise it).
    mng = plt.get_current_fig_manager()
    mng.resize(800, 600)

    plt.tight_layout()
Example #13
Source File: _subplots.py    From neural-network-animation with MIT License 5 votes vote down vote up
def change_geometry(self, numrows, numcols, num):
    """Move this subplot to cell *num* of a ``numrows`` x ``numcols`` grid.

    *num* is 1-based (e.g. going from 1,1,1 to 2,2,3), while GridSpec
    indexing is 0-based.
    """
    grid = GridSpec(numrows, numcols)
    self._subplotspec = grid[num - 1]
    self.update_params()
    self.set_position(self.figbox)
Example #14
Source File: test_tightlayout.py    From neural-network-animation with MIT License 5 votes vote down vote up
def test_tight_layout6():
    """Test tight_layout for two side-by-side GridSpecs in one figure.

    The left half holds a 2x1 grid and the right half a 3x1 grid. Each is
    laid out separately, then both are laid out again so that they share a
    common top and bottom edge.
    """

    # This raises warnings since tight layout cannot
    # do this fully automatically. But the test is
    # correct since the layout is manually edited
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        fig = plt.figure()

        import matplotlib.gridspec as gridspec

        # Left column: two stacked axes confined to the left half (rect below).
        gs1 = gridspec.GridSpec(2, 1)
        ax1 = fig.add_subplot(gs1[0])
        ax2 = fig.add_subplot(gs1[1])

        example_plot(ax1)
        example_plot(ax2)

        gs1.tight_layout(fig, rect=[0, 0, 0.5, 1])

        # Right column: three stacked axes with their titles and x-labels cleared.
        gs2 = gridspec.GridSpec(3, 1)

        for ss in gs2:
            ax = fig.add_subplot(ss)
            example_plot(ax)
            ax.set_title("")
            ax.set_xlabel("")

        # `ax` is the last axes created by the loop (the bottom cell of gs2);
        # only that one gets an x-label back.
        ax.set_xlabel("x-label", fontsize=12)

        gs2.tight_layout(fig, rect=[0.5, 0, 1, 1], h_pad=0.45)

        # Common edges: the lower of the two tops, the higher of the two bottoms.
        top = min(gs1.top, gs2.top)
        bottom = max(gs1.bottom, gs2.bottom)

        # Re-run tight_layout with each rect shrunk by that grid's excess so both
        # grids line up on the common top/bottom.
        # NOTE(review): the `None` rect entries presumably leave that edge at its
        # default; confirm against GridSpec.tight_layout.
        gs1.tight_layout(fig, rect=[None, 0 + (bottom-gs1.bottom),
                                    0.5, 1 - (gs1.top-top)])
        gs2.tight_layout(fig, rect=[0.5, 0 + (bottom-gs2.bottom),
                                    None, 1 - (gs2.top-top)],
                         h_pad=0.45)
Example #15
Source File: test_gridspec.py    From neural-network-animation with MIT License 5 votes vote down vote up
def test_equal():
    """Two lookups of the same GridSpec cell/slice must compare equal."""
    grid = gridspec.GridSpec(2, 1)
    for key in ((0, 0), (slice(None), 0)):
        assert_equal(grid[key], grid[key])
Example #16
Source File: pyplot.py    From neural-network-animation with MIT License 5 votes vote down vote up
def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs):
    """
    Create a subplot in a grid.  The grid is specified by *shape*, at
    location of *loc*, spanning *rowspan*, *colspan* cells in each
    direction.  The index for loc is 0-based. ::

      subplot2grid(shape, loc, rowspan=1, colspan=1)

    is identical to ::

      gridspec=GridSpec(shape[0], shape[1])
      subplotspec=gridspec.new_subplotspec(loc, rowspan, colspan)
      subplot(subplotspec)

    Any existing axes of the current figure whose bounding box is fully
    overlapped by the new subplot's bounding box is deleted.
    """

    fig = gcf()
    s1, s2 = shape
    subplotspec = GridSpec(s1, s2).new_subplotspec(loc,
                                                   rowspan=rowspan,
                                                   colspan=colspan)
    a = fig.add_subplot(subplotspec, **kwargs)
    bbox = a.bbox
    # Collect the overlapped axes first, then delete them in a separate pass.
    byebye = [other for other in fig.axes
              if other is not a and bbox.fully_overlaps(other.bbox)]
    for ax in byebye:
        delaxes(ax)

    draw_if_interactive()
    return a
Example #17
Source File: _subplots.py    From Mastering-Elasticsearch-7.0 with MIT License 5 votes vote down vote up
def get_gridspec(self):
    """Return the GridSpec that owns this subplot's SubplotSpec."""
    spec = self._subplotspec
    return spec.get_gridspec()
Example #18
Source File: solver.py    From visual_dynamics with MIT License 5 votes vote down vote up
def loss_visualization_init(self, validate=True):
    """Create the training figure: a loss plot next to an image grid.

    Args:
        validate: when true, a 'val' curve label is added to the loss plot.

    Returns:
        tuple: ``(fig, loss_plotter, image_visualizer)``. The loss plotter
        takes the first of three columns; the image visualizer spans the
        other two, with one row per entry of ``self.output_names`` and two
        columns.
    """
    fig = plt.figure(figsize=(18, 18), frameon=False, tight_layout=True)
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4 and
    # removed in 3.6. The old canvas method simply forwarded to the figure
    # manager when one existed, so call the manager directly.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(self.snapshot_prefix)
    gs = gridspec.GridSpec(1, 3)
    plt.show(block=False)

    # NOTE(review): 'train' appears twice -- presumably two training-loss
    # series (e.g. raw and smoothed); confirm against LossPlotter.
    loss_labels = ['train', 'train']
    if validate:
        loss_labels += ['val']
    loss_plotter = LossPlotter(fig, gs[0], labels=loss_labels)

    # Flatten the per-row name groups into one label list, row by row
    # (assumes each entry of output_names is a pair -- TODO confirm).
    output_labels = [str(output_name) for output_name_pair in self.output_names for output_name in output_name_pair]
    image_visualizer = GridImageVisualizer(fig, gs[1:], rows=len(self.output_names), cols=2, labels=output_labels)
    return fig, loss_plotter, image_visualizer
Example #19
Source File: pyplot.py    From matplotlib-4-abaqus with MIT License 5 votes vote down vote up
def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs):
    """
    Create a subplot in a grid.  The grid is specified by *shape*, at
    location of *loc*, spanning *rowspan*, *colspan* cells in each
    direction.  The index for loc is 0-based. ::

      subplot2grid(shape, loc, rowspan=1, colspan=1)

    is identical to ::

      gridspec=GridSpec(shape[0], shape[1])
      subplotspec=gridspec.new_subplotspec(loc, rowspan, colspan)
      subplot(subplotspec)

    Any existing axes of the current figure whose bounding box is fully
    overlapped by the new subplot's bounding box is deleted.
    """

    fig = gcf()
    s1, s2 = shape
    subplotspec = GridSpec(s1, s2).new_subplotspec(loc,
                                                   rowspan=rowspan,
                                                   colspan=colspan)
    a = fig.add_subplot(subplotspec, **kwargs)
    bbox = a.bbox
    # Collect the overlapped axes first, then delete them in a separate pass.
    byebye = [other for other in fig.axes
              if other is not a and bbox.fully_overlaps(other.bbox)]
    for ax in byebye:
        delaxes(ax)

    draw_if_interactive()
    return a
Example #20
Source File: DyStockBackTestingStrategyResultStatsWidget.py    From DevilYuan with MIT License 5 votes vote down vote up
def _plotStats(self, df, strategyName):
    """Plot the account profit/loss statistics of a backtested strategy.

    Draws two vertically stacked axes that share the x axis. The top three
    quarters plot every column of *df* except the two position columns
    ('持仓资金(%)' = position capital %, '持仓股票数' = number of held
    stocks). The bottom quarter shows position capital as bars, its running
    average, and the number of held stocks. The x axis is the row index,
    labelled with the row's date.

    Args:
        df: DataFrame whose index entries support ``strftime``.
        strategyName: strategy name used in the top axes' title.
    """
    n_rows = df.shape[0]

    def _dateFormatter(x, pos):
        # Return '' (not None) outside the data range: older Matplotlib
        # versions render a None tick label as the literal text "None".
        idx = int(x)
        if not (0 <= idx < n_rows):
            return ''

        return df.index[idx].strftime("%y-%m-%d")

    # create grid spec: 4 rows with no vertical gap between the two axes
    gs = GridSpec(4, 1)
    gs.update(hspace=0)

    # subplot for PnL (top 3 of the 4 rows)
    axPnl = plt.subplot(gs[:-1, :])
    axPnl.grid(True)
    axPnl.set_title('{}: 盈亏(%)'.format(strategyName))

    # set x ticks: every `xspace`-th row (about 10 ticks) plus the last row
    x = list(range(n_rows))
    xspace = max((len(x) + 9) // 10, 1)
    axPnl.xaxis.set_major_locator(FixedLocator(x[:-xspace-1: xspace] + x[-1:]))
    axPnl.xaxis.set_major_formatter(FuncFormatter(_dateFormatter))

    # plot pnl: every column except the two position columns
    for name in df.columns:
        if name not in ['持仓资金(%)', '持仓股票数']:
            axPnl.plot(x, df[name].values, label=name)

    axPnl.legend(loc='upper left', frameon=False)

    # subplot for position (bottom row), sharing the x axis with the PnL axes
    axPos = plt.subplot(gs[-1, :], sharex=axPnl)
    axPos.grid(True)
    position = df['持仓资金(%)'].values
    axPos.bar(x, position, label='持仓资金(%)')
    # running mean of position capital over the first k rows
    axPos.plot(x, position.cumsum() / np.arange(1, n_rows + 1), label='平均持仓资金(%)', color='g')
    axPos.plot(x, df['持仓股票数'].values, label='持仓股票数', color='r')

    axPos.legend(loc='upper left', frameon=False)
Example #21
Source File: layouts.py    From MDT with GNU Lesser General Public License v3.0 5 votes vote down vote up
def __init__(self, gridspec, figure, positions=None):
    """Bind a grid spec to the figure that subplots will be created in.

    Args:
        gridspec (GridSpec): the grid spec to lay axes out on.
        figure (Figure): the figure subplots are generated for.
        positions (:class:`list`): optional grid-spec index for every requested
            axis. Use either logical indices or (x, y) coordinate indices
            consistently, not a mix.
    """
    self.gridspec, self.figure, self.positions = gridspec, figure, positions
Example #22
Source File: deeplab.py    From edafa with MIT License 5 votes vote down vote up
def vis_segmentation(image, seg_map):
  """Visualizes input image, segmentation map and overlay view.

  Draws a 1x4 row: the input image, the colour-coded segmentation map, the
  map blended over the image (alpha 0.7), and a narrow strip listing the
  labels present in *seg_map*.

  Args:
    image: image array shown as-is by ``imshow``.
    seg_map: label map; colourised with ``label_to_color_image`` and its
      unique values used to index ``FULL_COLOR_MAP`` and ``LABEL_NAMES``.
  """
  plt.figure(figsize=(15, 5))
  grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

  plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  plt.title('input image')

  plt.subplot(grid_spec[1])
  seg_image = label_to_color_image(seg_map).astype(np.uint8)
  plt.imshow(seg_image)
  plt.axis('off')
  plt.title('segmentation map')

  plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(seg_image, alpha=0.7)
  plt.axis('off')
  plt.title('segmentation overlay')

  # Legend strip: one colour swatch per label present, names on the right.
  unique_labels = np.unique(seg_map)
  ax = plt.subplot(grid_spec[3])
  plt.imshow(
      FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
  ax.yaxis.tick_right()
  plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
  plt.xticks([], [])
  ax.tick_params(width=0.0)
  # Hide the grid. plt.grid('off') passed a truthy string, which modern
  # Matplotlib treats as "visible" and so turned the grid *on*.
  ax.grid(False)
  plt.show()
Example #23
Source File: axes_divider.py    From Computable with MIT License 5 votes vote down vote up
def change_geometry(self, numrows, numcols, num):
    """Relocate this axes to cell *num* of a new ``numrows`` x ``numcols`` grid.

    *num* is 1-based, as in ``add_subplot`` (e.g. 1,1,1 -> 2,2,3).
    """
    cell_index = num - 1
    self._subplotspec = GridSpec(numrows, numcols)[cell_index]
    self.update_params()
    self.set_position(self.figbox)
Example #24
Source File: pyplot.py    From Computable with MIT License 5 votes vote down vote up
def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs):
    """
    Create a subplot in a grid.  The grid is specified by *shape*, at
    location of *loc*, spanning *rowspan*, *colspan* cells in each
    direction.  The index for loc is 0-based. ::

      subplot2grid(shape, loc, rowspan=1, colspan=1)

    is identical to ::

      gridspec=GridSpec(shape[0], shape[1])
      subplotspec=gridspec.new_subplotspec(loc, rowspan, colspan)
      subplot(subplotspec)

    Any existing axes of the current figure whose bounding box is fully
    overlapped by the new subplot's bounding box is deleted.
    """

    fig = gcf()
    s1, s2 = shape
    subplotspec = GridSpec(s1, s2).new_subplotspec(loc,
                                                   rowspan=rowspan,
                                                   colspan=colspan)
    a = fig.add_subplot(subplotspec, **kwargs)
    bbox = a.bbox
    # Collect the overlapped axes first, then delete them in a separate pass.
    byebye = [other for other in fig.axes
              if other is not a and bbox.fully_overlaps(other.bbox)]
    for ax in byebye:
        delaxes(ax)

    draw_if_interactive()
    return a
Example #25
Source File: layouts.py    From MDT with GNU Lesser General Public License v3.0 5 votes vote down vote up
def get_gridspec(self, figure, nmr_plots):
    """Lay out *nmr_plots* axes on the grid size given by ``self._get_square_size``."""
    n_rows, n_cols = self._get_square_size(nmr_plots)
    grid = GridSpec(n_rows, n_cols, **self.spacings)
    return GridLayoutSpecifier(grid, figure)
Example #26
Source File: disp.py    From hart with GNU General Public License v3.0 5 votes vote down vote up
def _tile_vertical(imgs, glimpses, boxes, n_objects, fig_size, img_size, colors):
    """Tile each frame next to its per-object glimpses, one frame per row.

    Note: Python 2 code (``xrange``, ``izip``, ``dict.iteritems``).

    Args:
        imgs: frames indexed by row (``imgs.shape[0]`` rows).
        glimpses: indexed as ``glimpses[r, n]`` for frame r and object n.
        boxes: mapping name -> ``(y, x, h, w)`` arrays indexed ``[r, n]``.
        n_objects: number of tracked objects (glimpse columns).
        fig_size: ``(fig_y, fig_x)`` inches per grid cell.
        img_size: ``(img_y, img_x)`` grid cells spanned by each full frame.
        colors: edge colour per entry of *boxes*, in iteration order.

    Returns:
        ``(fig, axes)`` where ``axes`` is an object array of shape
        ``(imgs.shape[0], 1 + n_objects)``. Column 0 holds the frame with
        its boxes; columns 1.. hold the glimpses.
    """
    # prepare figure
    yy, xx = imgs.shape[0], 1 + n_objects
    fig_y, fig_x = fig_size
    img_y, img_x = img_size

    # Each frame row spans img_y grid rows; the frame spans img_x grid
    # columns and every glimpse takes one further column.
    sy, sx = yy * img_y, n_objects + img_x
    gs = gridspec.GridSpec(sy, sx)
    fig = plt.figure(figsize=(sx * fig_x, sy * fig_y))

    axes = np.empty((yy, xx), dtype=object)
    for i in xrange(yy):
        axes[i, 0] = plt.subplot(gs[i * img_y:(i + 1) * img_y, :img_x])

    for i in xrange(yy):
        for j in xrange(1, xx):
            axes[i, j] = plt.subplot(gs[i * img_y:(i + 1) * img_y, j + img_x - 1])

    # plot
    for r in xrange(yy):
        axes[r, 0].imshow(imgs[r], 'gray')

        for n in xrange(n_objects):
            for (k, v), color in izip(boxes.iteritems(), colors):
                y, x, h, w = v
                # Label only the first object's box of each kind, so the legend
                # (ncol=len(boxes)) gets one entry per kind; labels starting
                # with '_' are skipped by Matplotlib's legend.
                label = k if n == 0 else '_nolegend_'
                bbox = Rectangle((x[r, n], y[r, n]), w[r, n], h[r, n],
                                 edgecolor=color, facecolor='none', label=label)
                axes[r, 0].add_patch(bbox)

        for c in xrange(1, xx):
            axes[r, c].imshow(glimpses[r, c - 1], 'gray')

    # TODO: improve legend placement
    len_bbox = len(boxes)
    if len_bbox > 1:
        x_offset = .25 * len_bbox
        axes[-1, 0].legend(bbox_to_anchor=(x_offset, -.75),
                           ncol=len_bbox, loc='lower center')

    return fig, axes
Example #27
Source File: 16_basic_kernels.py    From deep-learning-note with MIT License 5 votes vote down vote up
def show_images(images, rgb=True):
    """Show *images* side by side in one row, without axes, then call ``plt.show``.

    With ``rgb=False`` each image is reshaped to its first two dimensions
    (presumably dropping a singleton channel axis) and drawn in gray.
    """
    row = gridspec.GridSpec(1, len(images))
    for col, img in enumerate(images):
        plt.subplot(row[0, col])
        if not rgb:
            flat = img.reshape(img.shape[0], img.shape[1])
            plt.imshow(flat, cmap='gray')
        else:
            plt.imshow(img)
        plt.axis('off')
    plt.show()
Example #28
Source File: test_frame.py    From vnpy_crypto with MIT License 5 votes vote down vote up
def _generate_4_axes_via_gridspec():
    """Create a 2x2 GridSpec with one axes per cell.

    Returns the GridSpec and the four axes in column-major order:
    top-left, lower-left, top-right, lower-right.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    import matplotlib.gridspec  # noqa

    gs = mpl.gridspec.GridSpec(2, 2)
    quad_axes = [plt.subplot(gs[row, col]) for col in (0, 1) for row in (0, 1)]
    return gs, quad_axes
Example #29
Source File: layouts.py    From MDT with GNU Lesser General Public License v3.0 5 votes vote down vote up
def get_gridspec(self, figure, nmr_plots):
    """Lay out *nmr_plots* axes using the size and positions from ``_get_size_and_position``."""
    n_rows, n_cols, cell_positions = self._get_size_and_position(nmr_plots)
    grid = GridSpec(n_rows, n_cols, **self.spacings)
    return GridLayoutSpecifier(grid, figure, positions=cell_positions)
Example #30
Source File: cvae_tensorflow.py    From generative-models with The Unlicense 5 votes vote down vote up
def plot(samples):
    """Build a 4x4 grid of 28x28 grayscale samples and return the open figure."""
    figure = plt.figure(figsize=(4, 4))
    layout = gridspec.GridSpec(4, 4)
    layout.update(wspace=0.05, hspace=0.05)

    for slot, pixels in enumerate(samples):
        cell_ax = plt.subplot(layout[slot])
        plt.axis('off')
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        plt.imshow(pixels.reshape(28, 28), cmap='Greys_r')

    return figure