Python matplotlib.pyplot.subplots_adjust() Examples

The following are 30 code examples showing how to use matplotlib.pyplot.subplots_adjust(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.

Example 1
Project: tensortrade   Author: tensortrade-org   File: matplotlib_trading_chart.py    License: Apache License 2.0 6 votes vote down vote up
def __init__(self, df):
        """Build the trading-chart figure and its three axes.

        A net-worth axis sits on top, a price axis below it shares the same
        x-axis, and a volume axis is twinned onto the price axis.

        :param df: data to be charted; it is only stored on the instance here.
        """
        self.df = df

        # Create a figure on screen (no title is set at this point)
        self.fig = plt.figure()

        # Create top subplot (rows 0-1 of the 6-row grid) for net worth axis
        self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1)

        # Create bottom subplot for shared price/volume axis. Starting at
        # row 2 of a 6-row grid leaves exactly 4 rows, so the span is 4; the
        # previous value of 8 reached past the end of the grid.
        self.price_ax = plt.subplot2grid((6, 1), (2, 0), rowspan=4,
                                         colspan=1, sharex=self.net_worth_ax)

        # Create a new axis for volume which shares its x-axis with price
        self.volume_ax = self.price_ax.twinx()

        # Add padding to make graph easier to view
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)

        # Show the graph without blocking the rest of the program
        plt.show(block=False)
Example 2
Project: Attention-Gated-Networks   Author: ozan-oktay   File: visualise_att_maps_epoch.py    License: MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show every slice ``units[:, :, k]`` as an image in one grid figure.

    The grid uses ``round(sqrt(K))`` columns, where ``K = units.shape[2]``,
    and one row more than is needed to hold the ``K`` images.

    :param units: array indexed as ``units[:, :, k]``; each slice is
        transposed before being drawn.
    :param figure_id: identifier passed to ``plt.figure``; that figure is
        cleared before drawing.
    :param interp: interpolation mode handed to ``plt.imshow``.
    :param colormap: colormap handed to ``plt.imshow``.
    :param colormap_lim: optional pair whose first two items are applied
        with ``plt.clim`` to every image.
    """
    plt.ion()
    n_filters = units.shape[2]
    grid_cols = round(math.sqrt(n_filters))
    grid_rows = math.ceil(n_filters / grid_cols) + 1
    # NOTE(review): the row count feeds the first (width) figsize entry and
    # the column count the second (height) -- confirm this is intended.
    figure = plt.figure(figure_id, figsize=(grid_rows*3, grid_cols*3))
    figure.clf()

    for position in range(1, n_filters + 1):
        panel = plt.subplot(grid_rows, grid_cols, position)
        plt.imshow(units[:, :, position - 1].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Keep the frame and ticks, but hide the tick labels.
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            low, high = colormap_lim[0], colormap_lim[1]
            plt.clim(low, high)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Epochs 
Example 3
Project: Attention-Gated-Networks   Author: ozan-oktay   File: visualise_fmaps.py    License: MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Draw each feature map ``units[:, :, k]`` in its own cell of a grid figure.

    With ``K = units.shape[2]`` maps, the grid has ``round(sqrt(K))`` columns
    and ``ceil(K / columns) + 1`` rows.

    :param units: array indexed as ``units[:, :, k]``; each slice is
        transposed before being drawn.
    :param figure_id: identifier passed to ``plt.figure``; that figure is
        cleared before drawing.
    :param interp: interpolation mode handed to ``plt.imshow``.
    :param colormap: colormap handed to ``plt.imshow``.
    :param colormap_lim: optional pair whose first two items are applied
        with ``plt.clim`` to every image.
    """
    plt.ion()
    map_count = units.shape[2]
    columns = round(math.sqrt(map_count))
    rows = math.ceil(map_count / columns) + 1
    # NOTE(review): figsize receives (rows*3, columns*3), i.e. the row count
    # drives the width -- confirm this is intended.
    canvas = plt.figure(figure_id, figsize=(rows*3, columns*3))
    canvas.clf()

    for k in range(map_count):
        cell = plt.subplot(rows, columns, k + 1)
        feature_map = units[:, :, k].T
        plt.imshow(feature_map, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Tick labels are blanked; the axes frame itself stays visible.
        cell.set_xticklabels([])
        cell.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Load options 
Example 4
Project: Attention-Gated-Networks   Author: ozan-oktay   File: visualise_attention.py    License: MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Draw each slice ``units[:, :, k]`` in a grid figure with a super-title.

    With ``K = units.shape[2]`` slices, the grid has ``round(sqrt(K))``
    columns and ``ceil(K / columns) + 1`` rows.

    :param units: array indexed as ``units[:, :, k]``; each slice is
        transposed before being drawn.
    :param figure_id: identifier passed to ``plt.figure``; that figure is
        cleared before drawing.
    :param interp: interpolation mode handed to ``plt.imshow``.
    :param colormap: colormap handed to ``plt.imshow``.
    :param colormap_lim: optional pair whose first two items are applied
        with ``plt.clim`` to every image.
    :param title: text passed to ``plt.suptitle`` once the layout is done.
    """
    plt.ion()
    slice_count = units.shape[2]
    n_cols = round(math.sqrt(slice_count))
    n_rws = math.ceil(slice_count / n_cols) + 1
    # NOTE(review): figsize receives (rows*3, columns*3), i.e. the row count
    # drives the width -- confirm this is intended.
    fig_handle = plt.figure(figure_id, figsize=(n_rws*3, n_cols*3))
    fig_handle.clf()

    for k in range(slice_count):
        current_ax = plt.subplot(n_rws, n_cols, k + 1)
        plt.imshow(units[:, :, k].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Hide tick labels on both axes while keeping the frame.
        for clear_labels in (current_ax.set_xticklabels, current_ax.set_yticklabels):
            clear_labels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.suptitle(title)
Example 5
Project: Attention-Gated-Networks   Author: ozan-oktay   File: visualise_attention.py    License: MIT License 6 votes vote down vote up
def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay each slice ``units[:, :, k]`` on ``input_im[:, :, 0]``.

    No subplots are created: every iteration draws the grey-scale input and
    the semi-transparent slice into the current axes of the same 5x5-inch
    figure, adding a colorbar and the title each time.

    :param input_im: array whose channel ``[:, :, 0]`` is the background.
    :param units: array indexed as ``units[:, :, k]``.
    :param figure_id: identifier passed to ``plt.figure``; that figure is
        cleared before drawing.
    :param interp: interpolation mode handed to ``plt.imshow``.
    :param colormap: colormap used for the overlay.
    :param colormap_lim: optional pair whose first two items are applied
        with ``plt.clim``.
    :param title: axes title, drawn with a small font.
    :param alpha: opacity of the overlay.
    """
    plt.ion()
    overlay_count = units.shape[2]
    canvas = plt.figure(figure_id, figsize=(5,5))
    canvas.clf()

    background = None
    for k in range(overlay_count):
        background = input_im[:, :, 0]
        plt.imshow(background, interpolation=interp, cmap='gray')
        plt.imshow(units[:, :, k], interpolation=interp, cmap=colormap, alpha=alpha)
        plt.axis('off')
        plt.colorbar()
        plt.title(title, fontsize='small')
        if colormap_lim:
            low, high = colormap_lim[0], colormap_lim[1]
            plt.clim(low, high)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

    # Saving is currently disabled:
    # plt.savefig('{}/{}.png'.format(dir_name,time.time()))




## Load options 
Example 6
Project: RLTrader   Author: notadamking   File: TradingChart.py    License: GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, df):
        """Build the trading-chart figure and its three axes.

        A net-worth axis sits on top, a price axis below it shares the same
        x-axis, and a volume axis is twinned onto the price axis.

        :param df: data to be charted; it is only stored on the instance here.
        """
        self.df = df

        # Create a figure on screen (no title is set at this point)
        self.fig = plt.figure()

        # Create top subplot (rows 0-1 of the 6-row grid) for net worth axis
        self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1)

        # Create bottom subplot for shared price/volume axis. Starting at
        # row 2 of a 6-row grid leaves exactly 4 rows, so the span is 4; the
        # previous value of 8 reached past the end of the grid.
        self.price_ax = plt.subplot2grid((6, 1), (2, 0), rowspan=4, colspan=1, sharex=self.net_worth_ax)

        # Create a new axis for volume which shares its x-axis with price
        self.volume_ax = self.price_ax.twinx()

        # Add padding to make graph easier to view
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)

        # Show the graph without blocking the rest of the program
        plt.show(block=False)
Example 7
Project: AMLSim   Author: IBM   File: plot_alert_pattern_subgraphs.py    License: Apache License 2.0 6 votes vote down vote up
def plot_alerts(_g, _bank_accts, _output_png):
    """Draw the graph with nodes coloured per bank and save it as an image.

    :param _g: networkx graph; its edge attribute ``"label"`` supplies the
        edge labels.
    :param _bank_accts: mapping of bank ID -> collection of member nodes.
    :param _output_png: output path handed to ``plt.savefig``.
    """
    cmap = plt.get_cmap("tab10")
    pos = nx.nx_agraph.graphviz_layout(_g)

    plt.figure(figsize=(12.0, 8.0))
    plt.axis('off')

    for i, (bank_id, members) in enumerate(_bank_accts.items()):
        # Wrap around the palette: an integer index past its last entry would
        # otherwise give every remaining bank the same colour.
        color = cmap(i % cmap.N)
        # Pass a one-element list so the RGBA tuple is read as a single colour
        # and not as per-node values when a bank has 3 or 4 members.
        nx.draw_networkx_nodes(_g, pos, members, node_size=300, node_color=[color], label=bank_id)
        nx.draw_networkx_labels(_g, pos, {n: n for n in members}, font_size=10)

    edge_labels = nx.get_edge_attributes(_g, "label")
    nx.draw_networkx_edges(_g, pos)
    nx.draw_networkx_edge_labels(_g, pos, edge_labels, font_size=6)

    plt.legend(numpoints=1)
    # Let the drawing fill the whole canvas.
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.savefig(_output_png, dpi=120)
Example 8
Project: supair   Author: stelzner   File: visualize.py    License: MIT License 6 votes vote down vote up
def show_images(images, nrow=10, text=None, overlay=None, matplot=False):
    """Send a batch of images to ``vis``, directly or through matplotlib.

    :param images: batch of images; singleton dimensions are squeezed away.
    :param nrow: forwarded to ``vis.images`` (unused in the matplotlib path).
    :param text: optional per-image captions, used only in the direct path.
    :param overlay: optional per-image overlays, used only in the
        matplotlib path.
    :param matplot: when true, lay the images out on a fixed 2x4 matplotlib
        grid and send the figure via ``vis.matplot``.
    """
    images = np.squeeze(images)

    if matplot:
        fig, axes = plt.subplots(2, 4, figsize=(8, 4))
        plt.subplots_adjust(top=1.0, bottom=0.0,
                            left=0.1, right=0.9,
                            wspace=0.1, hspace=-0.15)
        for idx, img in enumerate(images):
            # The grid is fixed at 2 rows x 4 columns.
            row, col = divmod(idx, 4)
            panel = axes[row, col]
            setup_axis(panel)
            panel.imshow(img, cmap='gray', interpolation='none')
            if overlay is None:
                continue
            # NOTE(review): scipy.misc.imresize is not available in recent
            # SciPy releases -- confirm the pinned SciPy version.
            resized = scipy.misc.imresize(overlay[idx], img.shape)
            panel.imshow(resized, cmap='RdYlGn', alpha=0.5)
        vis.matplot(plt)
        plt.close(fig)
        return

    annotated = []
    for j in range(len(images)):
        caption = None if text is None else text[j]
        annotated.append(draw_overlay(images[j], caption))
    images = np.stack(annotated, axis=0)
    vis.images(images, padding=4, nrow=nrow)
Example 9
Project: bert-as-service   Author: hanxiao   File: example7.py    License: MIT License 6 votes vote down vote up
def vis(embed, vis_alg='PCA', pool_alg='REDUCE_MEAN'):
    """Scatter-plot each 2-D embedding in ``embed`` on a 2x6 grid and show it.

    Points are coloured by the module-level ``subset_label``; a narrow
    colorbar on the right carries the four category names.

    :param embed: iterable of arrays; columns 0 and 1 of each are plotted.
    :param vis_alg: name of the projection method, used in the title only.
    :param pool_alg: name of the pooling strategy, used in the title only.
    """
    plt.close()
    # Set the default size *before* creating the figure: rcParams only affect
    # figures created afterwards, so the old order left the first figure at
    # the previous default size.
    plt.rcParams['figure.figsize'] = [21, 7]
    fig = plt.figure()
    for idx, ebd in enumerate(embed):
        ax = plt.subplot(2, 6, idx + 1)
        vis_x = ebd[:, 0]
        vis_y = ebd[:, 1]
        plt.scatter(vis_x, vis_y, c=subset_label, cmap=ListedColormap(["blue", "green", "yellow", "red"]), marker='.',
                    alpha=0.7, s=2)
        ax.set_title('pool_layer=-%d' % (idx + 1))
    plt.tight_layout()
    # Leave a strip on the right for the colorbar axes created below.
    plt.subplots_adjust(bottom=0.1, right=0.95, top=0.9)
    cax = plt.axes([0.96, 0.1, 0.01, 0.3])
    cbar = plt.colorbar(cax=cax, ticks=range(num_label))
    cbar.ax.get_yaxis().set_ticks([])
    # Replace the numeric ticks with category names placed at 1/8, 3/8, 5/8, 7/8.
    for j, lab in enumerate(['ent.', 'bus.', 'sci.', 'heal.']):
        cbar.ax.text(.5, (2 * j + 1) / 8.0, lab, ha='center', va='center', rotation=270)
    fig.suptitle('%s visualization of BERT layers using "bert-as-service" (-pool_strategy=%s)' % (vis_alg, pool_alg),
                 fontsize=14)
    plt.show()
Example 10
Project: WannaPark   Author: dalmia   File: mnist.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plot_bad_images(images):
    """This takes a list of images misclassified by a pretty good
    neural network --- one achieving over 93 percent accuracy --- and
    turns them into a figure.

    Each selected image is drawn in its own cell of a 25 x 125 subplot grid
    and titled with its index into ``images``.
    """
    bad_image_indices = [8, 18, 33, 92, 119, 124, 149, 151, 193, 233, 241, 247, 259, 300, 313, 321, 324, 341, 349, 352, 359, 362, 381, 412, 435, 445, 449, 478, 479, 495, 502, 511, 528, 531, 547, 571, 578, 582, 597, 610, 619, 628, 629, 659, 667, 691, 707, 717, 726, 740, 791, 810, 844, 846, 898, 938, 939, 947, 956, 959, 965, 982, 1014, 1033, 1039, 1044, 1050, 1055, 1107, 1112, 1124, 1147, 1181, 1191, 1192, 1198, 1202, 1204, 1206, 1224, 1226, 1232, 1242, 1243, 1247, 1256, 1260, 1263, 1283, 1289, 1299, 1310, 1319, 1326, 1328, 1357, 1378, 1393, 1413, 1422, 1435, 1467, 1469, 1494, 1500, 1522, 1523, 1525, 1527, 1530, 1549, 1553, 1609, 1611, 1634, 1641, 1676, 1678, 1681, 1709, 1717, 1722, 1730, 1732, 1737, 1741, 1754, 1759, 1772, 1773, 1790, 1808, 1813, 1823, 1843, 1850, 1857, 1868, 1878, 1880, 1883, 1901, 1913, 1930, 1938, 1940, 1952, 1969, 1970, 1984, 2001, 2009, 2016, 2018, 2035, 2040, 2043, 2044, 2053, 2063, 2098, 2105, 2109, 2118, 2129, 2130, 2135, 2148, 2161, 2168, 2174, 2182, 2185, 2186, 2189, 2224, 2229, 2237, 2266, 2272, 2293, 2299, 2319, 2325, 2326, 2334, 2369, 2371, 2380, 2381, 2387, 2393, 2395, 2406, 2408, 2414, 2422, 2433, 2450, 2488, 2514, 2526, 2548, 2574, 2589, 2598, 2607, 2610, 2631, 2648, 2654, 2695, 2713, 2720, 2721, 2730, 2770, 2771, 2780, 2863, 2866, 2896, 2907, 2925, 2927, 2939, 2995, 3005, 3023, 3030, 3060, 3073, 3102, 3108, 3110, 3114, 3115, 3117, 3130, 3132, 3157, 3160, 3167, 3183, 3189, 3206, 3240, 3254, 3260, 3280, 3329, 3330, 3333, 3383, 3384, 3475, 3490, 3503, 3520, 3525, 3559, 3567, 3573, 3597, 3598, 3604, 3629, 3664, 3702, 3716, 3718, 3725, 3726, 3727, 3751, 3752, 3757, 3763, 3766, 3767, 3769, 3776, 3780, 3798, 3806, 3808, 3811, 3817, 3821, 3838, 3848, 3853, 3855, 3869, 3876, 3902, 3906, 3926, 3941, 3943, 3951, 3954, 3962, 3976, 3985, 3995, 4000, 4002, 4007, 4017, 4018, 4065, 4075, 4078, 4093, 4102, 4139, 4140, 4152, 4154, 4163, 4165, 4176, 4199, 4201, 4205, 4207, 4212, 4224, 4238, 4248, 4256, 4284, 4289, 4297, 4300, 4306, 4344, 4355, 4356, 
4359, 4360, 4369, 4405, 4425, 4433, 4435, 4449, 4487, 4497, 4498, 4500, 4521, 4536, 4548, 4563, 4571, 4575, 4601, 4615, 4620, 4633, 4639, 4662, 4690, 4722, 4731, 4735, 4737, 4739, 4740, 4761, 4798, 4807, 4814, 4823, 4833, 4837, 4874, 4876, 4879, 4880, 4886, 4890, 4910, 4950, 4951, 4952, 4956, 4963, 4966, 4968, 4978, 4990, 5001, 5020, 5054, 5067, 5068, 5078, 5135, 5140, 5143, 5176, 5183, 5201, 5210, 5331, 5409, 5457, 5495, 5600, 5601, 5617, 5623, 5634, 5642, 5677, 5678, 5718, 5734, 5735, 5749, 5752, 5771, 5787, 5835, 5842, 5845, 5858, 5887, 5888, 5891, 5906, 5913, 5936, 5937, 5945, 5955, 5957, 5972, 5973, 5985, 5987, 5997, 6035, 6042, 6043, 6045, 6053, 6059, 6065, 6071, 6081, 6091, 6112, 6124, 6157, 6166, 6168, 6172, 6173, 6347, 6370, 6386, 6390, 6391, 6392, 6421, 6426, 6428, 6505, 6542, 6555, 6556, 6560, 6564, 6568, 6571, 6572, 6597, 6598, 6603, 6608, 6625, 6651, 6694, 6706, 6721, 6725, 6740, 6746, 6768, 6783, 6785, 6796, 6817, 6827, 6847, 6870, 6872, 6926, 6945, 7002, 7035, 7043, 7089, 7121, 7130, 7198, 7216, 7233, 7248, 7265, 7426, 7432, 7434, 7494, 7498, 7691, 7777, 7779, 7797, 7800, 7809, 7812, 7821, 7849, 7876, 7886, 7897, 7902, 7905, 7917, 7921, 7945, 7999, 8020, 8059, 8081, 8094, 8095, 8115, 8246, 8256, 8262, 8272, 8273, 8278, 8279, 8293, 8322, 8339, 8353, 8408, 8453, 8456, 8502, 8520, 8522, 8607, 9009, 9010, 9013, 9015, 9019, 9022, 9024, 9026, 9036, 9045, 9046, 9128, 9214, 9280, 9316, 9342, 9382, 9433, 9446, 9506, 9540, 9544, 9587, 9614, 9634, 9642, 9645, 9700, 9716, 9719, 9729, 9732, 9738, 9740, 9741, 9742, 9744, 9745, 9749, 9752, 9768, 9770, 9777, 9779, 9792, 9808, 9831, 9839, 9856, 9858, 9867, 9879, 9883, 9888, 9890, 9893, 9905, 9944, 9970, 9982]
    bad_images = [images[j] for j in bad_image_indices]
    fig = plt.figure(figsize=(10, 15))
    # enumerate/zip replaces the Python-2-only ``xrange`` index loop;
    # subplot positions are 1-based.
    for j, (index, image) in enumerate(zip(bad_image_indices, bad_images), start=1):
        ax = fig.add_subplot(25, 125, j)
        ax.matshow(image, cmap = matplotlib.cm.binary)
        ax.set_title(str(index))
        # Hide the ticks on the subplot that was just added.
        plt.xticks(np.array([]))
        plt.yticks(np.array([]))
    plt.subplots_adjust(hspace = 1.2)
    plt.show()
Example 11
Project: neural-network-animation   Author: miloharper   File: test_skew.py    License: MIT License 6 votes vote down vote up
def test_skew_rectange():
    """Draw a 2x2 square under every pairing of five X and five Y skews.

    The 25 skew pairs (multiples of 45 degrees) fill a 5x5 grid of axes that
    share both axis ranges.
    """
    skew_steps = [-3, -1, 0, 1, 3]

    fig, grid = plt.subplots(5, 5, sharex=True, sharey=True, figsize=(16, 12))
    panels = grid.flat

    # Limits and aspect are set on the first axes only; x and y are shared.
    first_panel = panels[0]
    first_panel.set_xlim([-4, 4])
    first_panel.set_ylim([-4, 4])
    first_panel.set_aspect('equal')

    skew_pairs = list(itertools.product(skew_steps, repeat=2))
    for panel, (x_step, y_step) in zip(panels, skew_pairs):
        x_deg = 45 * x_step
        y_deg = 45 * y_step
        skew = transforms.Affine2D().skew_deg(x_deg, y_deg)

        panel.set_title('Skew of {0} in X and {1} in Y'.format(x_deg, y_deg))
        square = mpatch.Rectangle([-1, -1], 2, 2,
                                  transform=skew + panel.transData,
                                  alpha=0.5, facecolor='coral')
        panel.add_patch(square)

    plt.subplots_adjust(wspace=0, left=0, right=1, bottom=0)
Example 12
Project: dal   Author: montrealrobotics   File: dal.py    License: MIT License 6 votes vote down vote up
def init_figure(self):
        """Create the 3x5 dashboard figure and attach each panel's axes to ``self``.

        ``self.init_fig`` is always set to True. The figure is only built
        when ``self.args.figure == True``.
        """
        self.init_fig = True
        if not (self.args.figure == True):  # and self.obj_fig==None:
            return

        self.obj_fig = plt.figure(figsize=(16,12))
        plt.set_cmap('viridis')
        self.gridspec = gridspec.GridSpec(3,5)

        # One entry per grid column (group), listing the attribute names from
        # the top row to the bottom row. Axes are created in this order.
        column_layout = (
            (0, ('ax_map', 'ax_scan', 'ax_pose')),
            (1, ('ax_bel', 'ax_lik', 'ax_gtl')),
            (slice(2, 4), ('ax_pbel', 'ax_plik', 'ax_pgtl')),  # double width
            (4, ('ax_act', 'ax_rew', 'ax_err')),
        )
        for col, attr_names in column_layout:
            for row, attr_name in enumerate(attr_names):
                setattr(self, attr_name, plt.subplot(self.gridspec[row, col]))

        plt.subplots_adjust(hspace = 0.4, wspace=0.4, top=0.95, bottom=0.05)
Example 13
Project: dal   Author: montrealrobotics   File: dal_ros_aml.py    License: MIT License 6 votes vote down vote up
def init_figure(self):
        """Set up the 3x5 panel figure and store every panel's axes on ``self``.

        ``self.init_fig`` is always set to True. Nothing else happens unless
        ``self.args.figure == True``.
        """
        self.init_fig = True
        if not (self.args.figure == True):  # and self.obj_fig==None:
            return

        self.obj_fig = plt.figure(figsize=(16,12))
        plt.set_cmap('viridis')
        self.gridspec = gridspec.GridSpec(3,5)

        # (column selector, attribute names top-to-bottom); the third group
        # spans grid columns 2 and 3. Creation order follows this table.
        panel_table = [
            (0, ['ax_map', 'ax_scan', 'ax_pose']),
            (1, ['ax_bel', 'ax_lik', 'ax_gtl']),
            (slice(2, 4), ['ax_pbel', 'ax_plik', 'ax_pgtl']),
            (4, ['ax_act', 'ax_rew', 'ax_err']),
        ]
        for column, names in panel_table:
            for row_index, name in enumerate(names):
                panel = plt.subplot(self.gridspec[row_index, column])
                setattr(self, name, panel)

        plt.subplots_adjust(hspace = 0.4, wspace=0.4, top=0.95, bottom=0.05)
Example 14
Project: python3_ios   Author: holzschu   File: test_skew.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_skew_rectangle():
    """Draw a 2x2 square under every pairing of five X and five Y skews.

    The 25 skew pairs (multiples of 45 degrees) fill a 5x5 grid of axes that
    share both axis ranges.
    """
    skew_steps = [-3, -1, 0, 1, 3]

    fig, grid = plt.subplots(5, 5, sharex=True, sharey=True, figsize=(8, 8))
    panels = grid.flat

    # Configure the first axes only; the equal aspect is requested with
    # share=True and the axis ranges are shared by construction.
    first_panel = panels[0]
    first_panel.set_xlim([-3, 3])
    first_panel.set_ylim([-3, 3])
    first_panel.set_aspect('equal', share=True)

    skew_pairs = list(itertools.product(skew_steps, repeat=2))
    for panel, (x_step, y_step) in zip(panels, skew_pairs):
        x_deg = 45 * x_step
        y_deg = 45 * y_step
        skew = transforms.Affine2D().skew_deg(x_deg, y_deg)

        panel.set_title('Skew of {0} in X and {1} in Y'.format(x_deg, y_deg))
        square = mpatch.Rectangle([-1, -1], 2, 2,
                                  transform=skew + panel.transData,
                                  alpha=0.5, facecolor='coral')
        panel.add_patch(square)

    plt.subplots_adjust(wspace=0, left=0.01, right=0.99, bottom=0.01, top=0.99)
Example 15
Project: mac-network   Author: stanfordnlp   File: visualization.py    License: Apache License 2.0 5 votes vote down vote up
def showImgAtts(instance):
    """Save one attention figure per entry of ``instance["attentions"]["kb"]``.

    :param instance: mapping providing ``"imageId"`` and
        ``"attentions"["kb"]``; it is also forwarded to the drawing and
        naming helpers.
    """
    image = imread(inImgName(instance["imageId"]))
    step_count = len(instance["attentions"]["kb"])

    # One full-bleed figure per attention step.
    for step in range(step_count):
        figure, axis = plt.subplots()
        height, width = figureImageDims[0], figureImageDims[1]
        figure.set_figheight(height)
        figure.set_figwidth(width)

        showImgAtt(image, instance, step, axis)

        # Remove all margins so the image fills the canvas.
        plt.subplots_adjust(bottom = 0, top = 1, left = 0, right = 1)
        savePlot(figure, outImgAttName(instance, step))
Example 16
Project: recaptcha-cracker   Author: nocturnaltortoise   File: main.py    License: GNU General Public License v3.0 5 votes vote down vote up
def show_checkbox_predictions(checkboxes, rows, cols, captcha_query, correct):
    """Show the checkbox images on a grid, titled with their predictions.

    :param checkboxes: iterable of dicts with the keys ``'position'``
        (1-based ``(row, column)``), ``'path'``, ``'predictions'`` and
        ``'matching'``. Entries with a falsy path are skipped.
    :param rows: number of grid rows.
    :param cols: number of grid columns.
    :param captcha_query: query text shown in the figure title.
    :param correct: truthy to label the figure "Correct", else "Incorrect".
    """
    # squeeze=False keeps ``axes`` two-dimensional even for a single row or
    # column, so the ``axes[x, y]`` indexing below always works.
    fig, axes = plt.subplots(rows, cols, squeeze=False)
    for checkbox in checkboxes:
        position = checkbox['position']
        path = checkbox['path']
        predictions = checkbox['predictions']
        matching = checkbox['matching']

        if not path:
            continue

        path = os.path.join('../', path)
        # Positions are 1-based; the axes array is 0-based.
        cell = axes[position[0] - 1, position[1] - 1]
        cell.imshow(skimage.io.imread(path))
        if matching:
            cell.set_title("Picked \n {0}".format(predictions))
        else:
            cell.set_title(predictions)

        cell.set_xticks([])
        cell.set_yticks([])

    verdict = "Correct" if correct else "Incorrect"
    fig.suptitle("{0}, {1}".format(captcha_query, verdict))
    plt.subplots_adjust(hspace=0.5)

    plt.show()
Example 17
Project: CapsLayer   Author: naturomics   File: figure.py    License: Apache License 2.0 5 votes vote down vote up
def plot_activation(matrix, step, save_to=None):
    """Plot one activation curve per capsule and save the figure as a PNG.

    The last column of ``matrix`` is used as the sort key for the rows; every
    other column is plotted in its own subplot.

    :param matrix: 2-D array; ``matrix.shape[1] - 1`` columns are plotted.
    :param step: step identifier used in the title and the file name.
    :param save_to: output directory; defaults to ``./activations``. It is
        created if missing.
    :raises ValueError: if ``matrix`` is not two-dimensional.
    """
    # Validate before touching the file system, and format the message as a
    # single string (it used to be raised as a two-element tuple).
    if len(matrix.shape) != 2:
        raise ValueError('Input "matrix" should have 2 rank, but it is %d' % len(matrix.shape))
    save_to = os.path.join(".", "activations") if save_to is None else save_to
    os.makedirs(save_to, exist_ok=True)
    num_label = matrix.shape[1] - 1
    matrix = matrix[matrix[:, num_label].argsort()]
    # squeeze=False always yields an array of axes, so ``flatten`` also works
    # when there is only one subplot.
    fig, axes = plt.subplots(ncols=1, nrows=num_label, figsize=(15,12), squeeze=False)
    fig.suptitle("The probability of entity presence (step %s)"%str(step), fontsize=20)
    fig.tight_layout()
    for i, ax in enumerate(axes.flatten()):
        # Subplots run top to bottom from the highest capsule index down to 0.
        idx = num_label - (i + 1)
        ax.spines['top'].set_color('none')
        ax.spines['bottom'].set_color('none')
        ax.set_ylim(0, 1.05)
        ax.set_ylabel("Capsule " + str(idx))
        ax.yaxis.set_major_locator(ticker.NullLocator())
        if idx > 0:
            ax.xaxis.set_major_locator(ticker.NullLocator())
        else:
            # Only the bottom subplot gets x ticks and an x label.
            ax.xaxis.set_major_locator(ticker.IndexLocator(base=500,offset=0))
            ax.set_xlabel("Sample index ")
        ax.plot(matrix[:,idx])
        ax_prime = ax.twinx()
        ax_prime.spines['top'].set_color('none')
        ax_prime.spines['bottom'].set_color('none')
    plt.subplots_adjust(hspace=0.2, left=0.05, right=0.95, bottom=0.05, top=.95)
    fig.savefig(os.path.join(save_to, "activation_%s.png" % str(step)))
    plt.close(fig)
Example 18
Project: qb   Author: Pinafore   File: performance.py    License: MIT License 5 votes vote down vote up
def plot_summary(summary_only, stats_dir, output):
    """Render the ablation bar charts, one facet per experiment, to a PNG.

    :param summary_only: not used by this function.
    :param stats_dir: directory handed to ``parse_data``.
    :param output: destination passed to the grid's ``savefig``.
    """
    import seaborn as sns
    data = parse_data(stats_dir)
    # NOTE(review): ``factorplot`` and its ``size`` argument belong to older
    # seaborn releases -- confirm the pinned seaborn version.
    grid = sns.factorplot(y='result', x='score', col='experiment',
                          data=data, kind='bar', ci=None,
                          order=ANSWER_PLOT_ORDER, size=4, col_wrap=4, sharex=False)
    for facet_ax in grid.axes.flat:
        tick_labels = facet_ax.get_xticklabels()
        for tick_label in tick_labels:
            tick_label.set_rotation(30)
    # Leave room at the top for the super-title.
    plt.subplots_adjust(top=0.93)
    grid.fig.suptitle('Feature Ablation Study')
    grid.savefig(output, format='png', dpi=200)
Example 19
Project: reinforcement-learning-an-introduction   Author: ShangtongZhang   File: blackjack.py    License: MIT License 5 votes vote down vote up
def figure_5_1():
    """Plot the on-policy Monte Carlo results after 10000 and 500000 episodes.

    Four heatmaps (usable ace / no usable ace, for each episode count) are
    drawn on a 2x2 grid and saved to ``../images/figure_5_1.png``.
    """
    usable_short, no_usable_short = monte_carlo_on_policy(10000)
    usable_long, no_usable_long = monte_carlo_on_policy(500000)

    # (data, title) pairs in row-major panel order.
    panels = [
        (usable_short, 'Usable Ace, 10000 Episodes'),
        (usable_long, 'Usable Ace, 500000 Episodes'),
        (no_usable_short, 'No Usable Ace, 10000 Episodes'),
        (no_usable_long, 'No Usable Ace, 500000 Episodes'),
    ]

    _, axes = plt.subplots(2, 2, figsize=(40, 30))
    plt.subplots_adjust(wspace=0.1, hspace=0.2)
    axes = axes.flatten()

    for (values, heading), axis in zip(panels, axes):
        # Flip vertically so the row order matches the reversed y tick labels.
        heat_ax = sns.heatmap(np.flipud(values), cmap="YlGnBu", ax=axis, xticklabels=range(1, 11),
                              yticklabels=list(reversed(range(12, 22))))
        heat_ax.set_ylabel('player sum', fontsize=30)
        heat_ax.set_xlabel('dealer showing', fontsize=30)
        heat_ax.set_title(heading, fontsize=30)

    plt.savefig('../images/figure_5_1.png')
    plt.close()
Example 20
Project: reinforcement-learning-an-introduction   Author: ShangtongZhang   File: blackjack.py    License: MIT License 5 votes vote down vote up
def figure_5_2():
    """Plot the greedy policy and value derived from Monte Carlo ES.

    For each of "usable ace" and "no usable ace" the best action (argmax)
    and its value (max) over the last axis of the state-action values are
    drawn as heatmaps on a 2x2 grid, saved to ``../images/figure_5_2.png``.
    """
    q_values = monte_carlo_es(500000)

    # Index 0 / 1 on the third axis selects no-usable-ace / usable-ace.
    q_no_ace = q_values[:, :, 0, :]
    q_ace = q_values[:, :, 1, :]

    state_value_no_usable_ace = np.max(q_no_ace, axis=-1)
    state_value_usable_ace = np.max(q_ace, axis=-1)

    # get the optimal policy
    action_no_usable_ace = np.argmax(q_no_ace, axis=-1)
    action_usable_ace = np.argmax(q_ace, axis=-1)

    # (data, title) pairs in row-major panel order.
    panels = [
        (action_usable_ace, 'Optimal policy with usable Ace'),
        (state_value_usable_ace, 'Optimal value with usable Ace'),
        (action_no_usable_ace, 'Optimal policy without usable Ace'),
        (state_value_no_usable_ace, 'Optimal value without usable Ace'),
    ]

    _, axes = plt.subplots(2, 2, figsize=(40, 30))
    plt.subplots_adjust(wspace=0.1, hspace=0.2)
    axes = axes.flatten()

    for (values, heading), axis in zip(panels, axes):
        # Flip vertically so the row order matches the reversed y tick labels.
        heat_ax = sns.heatmap(np.flipud(values), cmap="YlGnBu", ax=axis, xticklabels=range(1, 11),
                              yticklabels=list(reversed(range(12, 22))))
        heat_ax.set_ylabel('player sum', fontsize=30)
        heat_ax.set_xlabel('dealer showing', fontsize=30)
        heat_ax.set_title(heading, fontsize=30)

    plt.savefig('../images/figure_5_2.png')
    plt.close()
Example 21
Project: AIX360   Author: IBM   File: dipvae_utils.py    License: Apache License 2.0 5 votes vote down vote up
def plot_latent_traversal(explainer, input_images, args, dataset_obj, image_id_to_plot=0, num_sweeps=15,
                          max_abs_edit_value=10.0, epoch=0, batch_id = 0, save_dir="results"):
    """Plot a grid of images obtained by sweeping each latent dimension.

    Row ``i`` corresponds to latent dimension ``i`` and column ``j`` to the
    ``j``-th edit value, taken from ``num_sweeps`` evenly spaced values in
    ``[-max_abs_edit_value, max_abs_edit_value]``. The figure is written to
    ``save_dir`` as ``traversal_epoch_<epoch>_batch_id_<batch_id>.png``.

    :param explainer: object whose ``explain`` method produces the edited
        images.
    :param input_images: images forwarded to ``explainer.explain``.
    :param args: object providing ``latent_dim``.
    :param dataset_obj: forwarded to ``convert_and_reshape``.
    :param image_id_to_plot: index of the image shown in every cell.
    :param num_sweeps: number of edit values per latent dimension.
    :param max_abs_edit_value: largest absolute edit value.
    :param epoch: used in the output file name.
    :param batch_id: used in the output file name.
    :param save_dir: output directory, created if missing.
    """
    edit_dim_values = np.linspace(-1.0 *max_abs_edit_value, max_abs_edit_value, num_sweeps)

    # squeeze=False keeps ``axarr`` two-dimensional, so ``axarr[i][j]`` also
    # works with a single latent dimension or a single sweep value.
    f, axarr = plt.subplots(args.latent_dim, len(edit_dim_values), sharex=True, sharey=True,
                            squeeze=False)
    f.set_size_inches(10, 10* args.latent_dim / len(edit_dim_values))

    for i in range(args.latent_dim):
        for j in range(len(edit_dim_values)):
            cell = axarr[i][j]

            edited_images = convert_and_reshape(explainer.explain(input_images=input_images,
                             edit_dim_id = i,
                             edit_dim_value = edit_dim_values[j],edit_z_sample=False), dataset_obj)
            # NOTE(review): the test inspects shape[2] while the grey-scale
            # branch indexes the 4th axis -- confirm which axis holds channels.
            if edited_images.shape[2] == 1:
                cell.imshow(edited_images[image_id_to_plot,:,:,0], cmap="gray", aspect='auto')
            else:
                cell.imshow(edited_images[image_id_to_plot]*0.5 + 0.5, aspect='auto')
            #axarr[j][i].axis('off')
            # Edit values label the bottom row, latent ids the first column.
            if i == len(axarr) - 1:
                cell.set_xlabel("z:" + str(np.round(edit_dim_values[j], 1)))
            if j == 0:
                cell.set_ylabel("l:" + str(i))
            cell.set_yticks([])
            cell.set_xticks([])
    plt.subplots_adjust(hspace=0, wspace=0)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    f.savefig(os.path.join(save_dir, 'traversal_epoch_{}_batch_id_{}.png'.format(epoch, batch_id)))
    plt.close(fig=f)
Example 22
Project: supair   Author: stelzner   File: make_plots.py    License: MIT License 5 votes vote down vote up
def make_perf(spair_file, air_file, output_file, max_t,
              add_annotation=None, directory='./plots', no_bg_file=None):
    """Plot count accuracy over time for SuPAIR and AIR and save the figure.

    Each curve is the mean returned by ``process_perf`` with a shaded band
    between ``means + stdev`` and ``means - stdev``.

    :param spair_file: CSV file name of the SuPAIR results.
    :param air_file: CSV file name of the AIR results.
    :param output_file: file name of the saved figure.
    :param max_t: upper x-axis limit; the data is processed up to
        ``max_t + 20``.
    :param add_annotation: currently unused (the call is commented out).
    :param directory: directory that all file names are relative to.
    :param no_bg_file: optional CSV file name for an extra
        'SuPAIR w/o bg' curve.
    """
    def _draw_curve(bins, means, stdev, label):
        # Shaded band around the mean, then the mean curve itself.
        plot.fill_between(bins, means + stdev, means - stdev, alpha=0.5)
        plot.plot(bins, means, linewidth=2.0, label=label)

    horizon = max_t + 20
    spair_file = os.path.join(directory, spair_file)
    air_file = os.path.join(directory, air_file)
    output_file = os.path.join(directory, output_file)

    spn_data = rows_to_matrix(read_csv(spair_file))
    air_data = rows_to_matrix(read_csv(air_file))

    air_bins, air_means, air_stdev = process_perf(air_data, horizon)
    spn_bins, spn_means, spn_stdev = process_perf(spn_data, horizon)

    print(spair_file, 'AIR:', air_means[-2:], 'SuPAIR', spn_means[-2:])

    fig, plot = plt.subplots(figsize=[3.8, 2.44])
    plt.subplots_adjust(top=0.97, bottom=0.21,
                        left=0.19, right=0.95,
                        wspace=0.1, hspace=0.1)
    _draw_curve(spn_bins, spn_means, spn_stdev, 'SuPAIR')

    if no_bg_file is not None:
        no_bg_file = os.path.join(directory, no_bg_file)
        nobg_data = rows_to_matrix(read_csv(no_bg_file))
        nobg_bins, nobg_means, nobg_stdev = process_perf(nobg_data, horizon)
        _draw_curve(nobg_bins, nobg_means, nobg_stdev, 'SuPAIR w/o bg')

    _draw_curve(air_bins, air_means, air_stdev, 'AIR')
    plot.set_xlim(min(air_bins), max_t)
    plot.set_ylim(-0.002, 1.0)
    plot.set_xlabel('time (s)', fontsize=14)
    plot.set_ylabel('count accuracy', fontsize=14)
    plot.tick_params(labelsize=12)
    plot.legend(loc='lower right')
    # if add_annotation is not None:
    #     add_annotation(plot)

    vis.matplot(plt)
    plt.savefig(output_file)
Example 23
Project: verb-attributes   Author: uwnlp   File: fig_4.py    License: MIT License 5 votes vote down vote up
def att_plot(top_labels, gt_ind, probs, fn):
    """Save a horizontal bar chart of ``probs`` with the labels beside the bars.

    The label at ``gt_ind`` gets a ' (gt)' suffix. The y axis and the x-axis
    label are hidden; the label text is drawn next to each bar instead.

    :param top_labels: label strings, one per bar (copied, not modified).
    :param gt_ind: index of the ground-truth label.
    :param probs: bar lengths, plotted on an x range of 0 to 1.
    :param fn: output path passed to ``fig.savefig``.
    """
    labels = deepcopy(top_labels)
    labels[gt_ind] += ' (gt)'
    frame = pd.DataFrame(data={'probs': probs, 'labels':labels})

    fig, ax = plt.subplots(figsize=(4,5))
    ax.tick_params(labelsize=15)

    sns.barplot(y='labels', x='probs', ax=ax, data=frame, orient='h', ci=None)
    ax.set(xlim=(0,1))

    # Write each label just past the end of its bar.
    for bar, caption in zip(ax.patches, labels):
        text_x = bar.get_width() + .02
        text_y = bar.get_y() + bar.get_height()*4/5
        ax.text(text_x, text_y, caption, ha='left', va='bottom',
                fontsize=25)

    # ax.yaxis.set_label_coords(0.5, 0.5)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().label.set_visible(False)
    fig.savefig(fn, bbox_inches='tight', transparent=True)
    plt.close('all')
Example 24
Project: gpkit   Author: convexengineering   File: plot_sweep.py    License: MIT License 5 votes vote down vote up
def format_and_label_axes(var, posys, axes, ylabel=True):
    """Formats and labels axes

    Every axis gets a grey grid, small fonts, no tick marks, hidden left/top
    spines and thin dotted grey spines. The x label built from ``var`` is
    put on the last axis that was formatted.

    :param var: swept variable; ``var.key`` supplies the x label text.
    :param posys: plotted quantities, paired one-to-one with ``axes``.
    :param axes: matplotlib axes to format.
    :param ylabel: when truthy, each axis gets a y label derived from its
        posy (from ``posy.key`` when present, otherwise ``str(posy)``).
    """
    last_ax = None
    for posy, ax in zip(posys, axes):
        last_ax = ax
        if ylabel:
            # Use a separate name for the text: overwriting the ``ylabel``
            # flag with it would disable labelling of all later axes as soon
            # as one label text happened to be empty.
            if hasattr(posy, "key"):
                label_text = (posy.key.descr.get("label", posy.key.name)
                              + " [%s]" % posy.key.unitstr(dimless="-"))
            else:
                label_text = str(posy)
            ax.set_ylabel(label_text)
        ax.grid(color="0.6")
        # ax.set_frame_on(False)
        for item in [ax.xaxis.label, ax.yaxis.label]:
            item.set_fontsize(12)
        for item in ax.get_xticklabels() + ax.get_yticklabels():
            item.set_fontsize(9)
        ax.tick_params(length=0)
        ax.spines['left'].set_visible(False)
        ax.spines['top'].set_visible(False)
        for i in ax.spines.values():
            i.set_linewidth(0.6)
            i.set_color("0.6")
            i.set_linestyle("dotted")
    xlabel = (var.key.descr.get("label", var.key.name)
              + " [%s]" % var.key.unitstr(dimless="-"))
    # With no axes at all there is nothing to put the x label on.
    if last_ax is not None:
        last_ax.set_xlabel(xlabel)
    plt.locator_params(nbins=4)
    plt.subplots_adjust(wspace=0.15)


# pylint: disable=too-many-locals,too-many-branches,too-many-statements 
Example 25
Project: vnpy_crypto   Author: birforce   File: plotting.py    License: MIT License 5 votes vote down vote up
def adjust_subplots(**kwds):
    """Call ``plt.subplots_adjust`` with default margins overridden by *kwds*.

    The defaults are bottom=0.05, top=0.925, left=0.05, right=0.95 and
    hspace=0.2; any keyword given by the caller replaces or extends them.
    """
    import matplotlib.pyplot as plt

    defaults = {'bottom': 0.05, 'top': 0.925,
                'left': 0.05, 'right': 0.95,
                'hspace': 0.2}
    plt.subplots_adjust(**dict(defaults, **kwds))

#-------------------------------------------------------------------------------
# Multiple impulse response (cum_effects, etc.) cplots 
Example 26
Project: scikit-multiflow   Author: scikit-multiflow   File: evaluation_visualizer.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def on_new_train_step(self, sample_id, data_buffer):
        """ Listener entry point: refresh the plots with a new data buffer.

        Clears the existing annotations and updates the plots on every call.
        The canvas itself is only redrawn when the frame counter has reached
        5 or more than one second has passed since the last draw.

        Parameters
        ----------
        sample_id: int
            The current sample id.

        data_buffer: EvaluationDataBuffer
            A buffer containing evaluation data for a single training / visualization step.

        Raises
        ------
        ValueError: If an exception is raised during the update or draw
            operation. The original exception is attached as ``__cause__``.
            Interpreter-level signals such as ``KeyboardInterrupt`` and
            ``SystemExit`` are not wrapped and propagate unchanged.

        """

        try:
            current_time = time.time()
            self._clear_annotations()
            self._update_plots(sample_id, data_buffer)

            # To mitigate re-drawing overhead for fast models use frame counter (default = 5 frames).
            # To avoid slow refresh rate in slow models use a time limit (default = 1 sec).
            if (self._frame_cnt == 5) or (current_time - self._last_draw_timestamp > 1):
                plt.subplots_adjust(right=0.72, bottom=0.22)  # Adjust subplots to include metrics annotations
                if get_backend() == 'nbAgg':
                    self.fig.canvas.draw()    # Force draw in 'notebook' backend
                plt.pause(1e-9)
                self._frame_cnt = 0
                self._last_draw_timestamp = current_time
            else:
                self._frame_cnt += 1
        except Exception as exception:
            # Catch Exception rather than BaseException so Ctrl+C / sys.exit()
            # are not turned into a ValueError; chain to keep the traceback.
            raise ValueError('Failed when trying to draw plot. Exception: {} | Type: {}'.
                             format(exception, type(exception).__name__)) from exception
Example 27
Project: sdwan-harvester   Author: sdnewhop   File: core.py    License: GNU General Public License v2.0 5 votes vote down vote up
def create_pie_chart(elements, suptitle, png, figure_id):
    """
    Draw a pie chart of ``elements`` and save it under the results directory

    The slice with the largest value (the first one on ties) is pulled
    slightly out of the pie.

    :param elements: mapping of slice label to slice value (dict)
    :param suptitle: title placed above the chart (str)
    :param png: file name of the saved figure (str)
    :param figure_id: id of the pyplot figure to draw on (int)
    :return: None
    """
    sizes = list(elements.values())
    labels = list(elements.keys())

    plt.figure(figure_id)
    plt.subplots_adjust(bottom=.05, left=.01, right=.99, top=.90, hspace=.35)

    # One offset per slice; only the largest slice is exploded.
    offsets = [0] * len(labels)
    offsets[sizes.index(max(sizes))] = 0.1

    plt.pie(
        sizes,
        labels=labels,
        autopct=make_autopct(sizes),
        explode=offsets,
        textprops={'fontsize': PIE_LABEL_FONT_SIZE},
    )
    plt.axis("equal")
    plt.suptitle(suptitle, fontsize=PIE_SUPTITLE_FONT_SIZE)

    plt.gcf().set_dpi(PIE_DPI)
    target = "{dest}/{png}/{result_file}".format(dest=RESULTS_DIR,
                                                 png=PNG_DIR,
                                                 result_file=png)
    plt.savefig(target)
Example 28
Project: pbt   Author: MattKleinsmith   File: utils.py    License: MIT License 5 votes vote down vote up
def plots(imgs, figsize=(12, 12), rows=None, cols=None,
          interp=None, titles=None, cmap='gray',
          fig=None):
    """Show one or more images as a grid of borderless subplots.

    Parameters
    ----------
    imgs : image or list of images
        Each item is converted with ``np.array``; a single image is wrapped
        in a list.
    figsize : tuple
        Only used when ``fig`` is supplied (its second entry sets the figure
        height). When a new figure is created the size is computed from the
        grid instead.
    rows, cols : int, optional
        Grid shape. With neither given, a single row is used; with only one
        given, the other is set to the same value. When a new figure is
        created, ``rows`` is recomputed from the image count and ``cols``.
    interp : str or list, optional
        Interpolation passed to ``imshow``; a non-list is applied to all
        images.
    titles : sequence, optional
        Per-image titles, indexed like ``imgs``.
    cmap : str or list
        Colormap(s). A non-list value is replaced by ``'gray'`` when the
        first image is 2-D, then applied to all images.
    fig : matplotlib figure, optional
        Figure to draw into; it is cleared first. A new one is created via
        ``plt.figure`` when omitted.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [np.array(img) for img in imgs]
    n = len(imgs)
    if not isinstance(cmap, list):
        if imgs[0].ndim == 2:
            cmap = 'gray'
        cmap = [cmap] * n
    if not isinstance(interp, list):
        interp = [interp] * n
    if not rows and not cols:
        cols = n
        rows = 1
    elif not rows:
        rows = cols
    elif not cols:
        cols = rows
    if not fig:
        rows = int(np.ceil(n / cols))
        w = 12
        h = rows * (w / cols + 1)
        figsize = (w, h)
        fig = plt.figure(figsize=figsize)
    fontsize = 13 if cols == 5 else 16
    fig.set_figheight(figsize[1], forward=True)
    fig.clear()
    for i, img in enumerate(imgs):
        sp = fig.add_subplot(rows, cols, i + 1)
        if titles:
            sp.set_title(titles[i], fontsize=fontsize)
        # Draw on the subplot itself: the pyplot-level calls act on pyplot's
        # current figure, which need not be a caller-supplied ``fig``.
        sp.imshow(img, interpolation=interp[i], cmap=cmap[i])
        sp.axis('off')
    # Layout is per-figure, so one call after all subplots exist is enough.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=.1, hspace=0)
    fig.canvas.draw()
Example 29
Project: lightnet   Author: jing-vision   File: tsne.py    License: MIT License 5 votes vote down vote up
def tsne_plot(labels, tokens, show_labels=False):
    """Create a TSNE model of ``tokens`` and plot one image per label.

    Parameters
    ----------
    labels : sequence
        One label per token; each is passed to ``getImage`` to obtain the
        picture placed at that point. Only the first ``to_plot`` labels
        (as returned by ``tsne_to_grid``) are drawn.
    tokens : array-like
        Input passed to ``TSNE.fit_transform``.
    show_labels : bool, optional
        When true, also write each label as text next to its point.
        Defaults to ``False``, which matches the previous behaviour.
    """
    tsne_model = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23)
    X_2d = tsne_model.fit_transform(tokens)
    # Shift each embedding axis to start at 0, then scale by its maximum.
    X_2d -= X_2d.min(axis=0)
    X_2d /= X_2d.max(axis=0)

    width = 1200
    grid, to_plot = tsne_to_grid(X_2d)
    out_dim = int(width / np.sqrt(to_plot))

    fig, ax = plt.subplots(figsize=(width/100, width/100))
    # Let the single axes fill the whole figure.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)

    for pos, label in zip(grid, labels[:to_plot]):
        point = (pos[0], pos[1])
        ax.scatter(point[0], point[1])
        if show_labels:
            ax.annotate(label,
                        xy=point,
                        xytext=(5, 2),
                        fontsize=9,
                        textcoords='offset points',
                        ha='right',
                        va='bottom')
        ab = AnnotationBbox(getImage(label, new_size=out_dim / 2), point, frameon=False)
        ax.add_artist(ab)

    plt.show()
Example 30
Project: pyhawkes   Author: slinderman   File: make_figure.py    License: MIT License 5 votes vote down vote up
def make_figure_a(S, F, C):
    """
    Plot fluorescence traces, filtered fluorescence, and spike times
    for two neurons (columns 0 and 1), one subplot per neuron.

    S, F and C are indexed as ``[time, neuron]``. The figure is saved to
    "figure3a.pdf" and then shown.
    """
    col = harvard_colors()
    dt = 0.02
    T_start = 0
    T_stop = 1 * 50 * 60
    t = dt * np.arange(T_start, T_stop)

    ks = [0, 1]
    nk = len(ks)
    fig = create_figure((3, 3))
    for ind, k in enumerate(ks):
        ax = fig.add_subplot(nk, 1, ind + 1)
        # Raw fluorescence
        ax.plot(t, F[T_start:T_stop, k], color=col[1], label="$F$")
        # Filtered fluorescence (raw string: "\w" is not a valid escape)
        ax.plot(t, C[T_start:T_stop, k], color=col[0], lw=1.5, label=r"$\widehat{F}$")
        # Spike times as black circles. ``spks`` is relative to the
        # [T_start, T_stop) window, so offset it when indexing the full C.
        spks = np.where(S[T_start:T_stop, k])[0]
        ax.plot(t[spks], C[T_start + spks, k], 'ko', label="S")

        # Make a legend
        if ind == 0:
            # Put a legend above
            ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                      ncol=3, mode="expand", borderaxespad=0.,
                      prop={'size': 9})

        # Add labels
        ax.set_ylabel("$F_%d(t)$" % (k+1))
        if ind == nk-1:
            ax.set_xlabel("Time $t$ [sec]")

        # Format the ticks
        ax.set_ylim([-0.1, 1.0])
        ax.locator_params(nbins=5, axis="y")

    fig.subplots_adjust(left=0.2, bottom=0.2)
    fig.savefig("figure3a.pdf")
    plt.show()