Python matplotlib.pyplot.margins() Examples

The following are 30 code examples of `matplotlib.pyplot.margins()`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `matplotlib.pyplot`, or try the search function.
Example #1
Source File: plotter.py (from message-analyser, MIT License; 7 votes)
def distplot_messages_per_month(msgs, path_to_save):
    """Save a histogram of message counts over time as a PNG.

    Each message is placed at its day offset from the first message's date.
    Bin edges are the tick positions returned by ``_get_xticks`` (presumably
    one per month, given the name) plus the offset of the last message.
    Assumes ``msgs`` is non-empty and ordered by date (``msgs[0]`` is treated
    as the start, ``msgs[-1]`` as the end).

    :param msgs: sequence of messages exposing a ``date`` datetime attribute
    :param path_to_save: directory the PNG is written into
    """
    sns.set(style="whitegrid")

    start_date = msgs[0].date.date()
    (xticks, xticks_labels, xlabel) = _get_xticks(msgs)

    ax = sns.distplot([(msg.date.date() - start_date).days for msg in msgs],
                      bins=xticks + [(msgs[-1].date.date() - start_date).days], color="m", kde=False)
    ax.set_xticklabels(xticks_labels)
    ax.set(xlabel=xlabel, ylabel="messages")
    ax.margins(x=0)

    plt.xticks(xticks, rotation=65)
    fig = plt.gcf()
    # Resize first: tight_layout computes padding for the *current* figure size,
    # so running it before set_size_inches lays out a figure that is then resized.
    fig.set_size_inches(11, 8)
    plt.tight_layout()

    fig.savefig(os.path.join(path_to_save, distplot_messages_per_month.__name__ + ".png"), dpi=500)
    log_line(f"{distplot_messages_per_month.__name__} was created.")
    plt.close("all")
Example #2
Source File: plotter.py (from message-analyser, MIT License; 7 votes)
def distplot_messages_per_day(msgs, path_to_save):
    """Save a histogram of how many messages were sent on each day.

    Days come from ``stools.get_messages_per_day``. Bin edges run in steps of
    50 messages from 0 up to the busiest day's count, which is always the
    final edge.

    :param msgs: sequence of message objects
    :param path_to_save: directory the PNG is written into
    """
    sns.set(style="whitegrid")

    messages_by_day = stools.get_messages_per_day(msgs)
    day_sizes = [len(day) for day in messages_by_day.values()]
    busiest = max(day_sizes)
    edges = [*range(0, busiest, 50), busiest]

    axis = sns.distplot(day_sizes, bins=edges, color="m", kde=False)
    axis.set(xlabel="messages", ylabel="days")
    axis.margins(x=0)

    figure = plt.gcf()
    figure.set_size_inches(11, 8)

    out_file = os.path.join(path_to_save, distplot_messages_per_day.__name__ + ".png")
    figure.savefig(out_file, dpi=500)
    log_line(f"{distplot_messages_per_day.__name__} was created.")
    plt.close("all")
Example #3
Source File: plotter.py (from message-analyser, MIT License; 7 votes)
def distplot_messages_per_hour(msgs, path_to_save):
    """Save a histogram of messages by hour of day as a PNG.

    :param msgs: sequence of messages exposing a ``date`` datetime attribute
    :param path_to_save: directory the PNG is written into
    """
    sns.set(style="whitegrid")

    # bins=range(25) gives 24 one-hour bins with edges 0, 1, ..., 24
    ax = sns.distplot([msg.date.hour for msg in msgs], bins=range(25), color="m", kde=False)
    ax.set_xticklabels(stools.get_hours())
    ax.set(xlabel="hour", ylabel="messages")
    ax.margins(x=0)

    plt.xticks(range(24), rotation=65)
    fig = plt.gcf()
    # Resize first: tight_layout computes padding for the *current* figure size,
    # so running it before set_size_inches lays out a figure that is then resized.
    fig.set_size_inches(11, 8)
    plt.tight_layout()

    fig.savefig(os.path.join(path_to_save, distplot_messages_per_hour.__name__ + ".png"), dpi=500)
    log_line(f"{distplot_messages_per_hour.__name__} was created.")
    plt.close("all")
Example #4
Source File: visualization.py (from Deep-QLearning-Agent-for-Traffic-Signal-Control, MIT License; 6 votes)
def save_data_and_plot(self, data, filename, xlabel, ylabel):
    """
    Plot ``data`` as a line chart saved to ``<self._path>/plot_<filename>.png``
    and dump the raw values, one per line, to ``plot_<filename>_data.txt``.
    """
    lowest, highest = min(data), max(data)

    plt.rcParams.update({'font.size': 24})  # set bigger font size (global rcParams)

    plt.plot(data)
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.margins(0)
    # leave 5% of each extreme's magnitude as headroom above/below the curve
    plt.ylim(lowest - 0.05 * abs(lowest), highest + 0.05 * abs(highest))

    figure = plt.gcf()
    figure.set_size_inches(20, 11.25)
    image_path = os.path.join(self._path, 'plot_' + filename + '.png')
    figure.savefig(image_path, dpi=self._dpi)
    plt.close("all")

    data_path = os.path.join(self._path, 'plot_' + filename + '_data.txt')
    with open(data_path, "w") as out:
        out.writelines("%s\n" % value for value in data)
Example #5
Source File: 2019-07-01 abf1 sample rate test.py (from pyABF, MIT License; 6 votes)
def plotAdriansFile():
    """Print a summary of 190619B_0003.abf and show all of its sweeps overlaid."""
    abf = pyabf.ABF(PATH_DATA+"/190619B_0003.abf")

    print(abf)
    # OUTPUT:
    #   ABF (version 1.8.3.0) with 2 channels (mV, pA),
    #   sampled at 20.0 kHz, containing 10 sweeps,
    #   having no tags, with a total length of 0.28 minutes,
    #   recorded with protocol "IV_FI_IN0_saray".

    plt.figure(figsize=(10, 4))
    plt.grid(alpha=.2, ls='--')
    for sweep_index in abf.sweepList:
        abf.setSweep(sweep_index)
        # legend labels are 1-based for readability
        plt.plot(abf.sweepX, abf.sweepY, label=f"sweep {sweep_index+1}")

    plt.margins(0, .1)
    plt.legend(fontsize=8)
    plt.title(abf.abfID+".abf")
    plt.ylabel(abf.sweepLabelY)
    plt.xlabel(abf.sweepLabelX)
    plt.tight_layout()
    plt.show()
Example #6
Source File: go.py (from pyABF, MIT License; 6 votes)
def makeFigure1(colormap):
    """Save and show every third current-clamp sweep, colored by ``colormap``.

    The axes frame and ticks are hidden so only the traces remain. The image
    is written to ``_output/1_<colormap>.png``.

    :param colormap: name of a matplotlib colormap (e.g. "winter")
    """
    abf=pyabf.ABF("../../data/17o05028_ic_steps.abf")
    # plt.get_cmap works on all matplotlib versions; matplotlib.cm.get_cmap
    # (what plt.cm.get_cmap resolves to) was deprecated in 3.7 and removed in 3.9.
    # Looking it up once also avoids repeating the lookup for every sweep.
    cmap = plt.get_cmap(colormap)
    plt.figure(figsize=(6,3))
    for sweep in abf.sweepList[::3]:
        color = cmap(sweep/abf.sweepCount)  # position along the colormap
        abf.setSweep(sweep)
        plt.plot(abf.dataX,abf.dataY,color=color)
    plt.margins(0,.1)
    plt.axis([0,1,None,None])
    plt.gca().axis('off') # remove square around edges
    plt.xticks([]) # remove x labels
    plt.yticks([]) # remove y labels
    plt.tight_layout()
    plt.savefig("_output/1_%s.png"%colormap,dpi=150)
    plt.show()
    plt.close()
Example #7
Source File: go.py (from pyABF, MIT License; 6 votes)
def makeFigure2(colormap):
    """Save and show a staggered 3D-like stack of voltage-clamp sweeps.

    Sweeps are drawn from last to first. Each one is shifted right by
    0.05 per sweep and up by 4 per sweep. Everything except the final
    ``pointsPerSec`` samples (presumably the last second) is blanked with NaN.
    The image is written to ``_output/2_<colormap>.png``.

    :param colormap: name of a matplotlib colormap (e.g. "winter")
    """
    abf=pyabf.ABF("../../data/17o05026_vc_stim.abf")
    # plt.get_cmap works on all matplotlib versions; matplotlib.cm.get_cmap
    # (what plt.cm.get_cmap resolves to) was deprecated in 3.7 and removed in 3.9.
    cmap = plt.get_cmap(colormap)
    plt.figure(figsize=(6,3))
    for sweep in abf.sweepList[::-1]:
        color = cmap(sweep/abf.sweepCount)
        abf.setSweep(sweep)
        abf.dataY[:-int(abf.pointsPerSec*1)]=np.nan  # keep only the tail of the sweep
        abf.dataY+=4*sweep  # vertical offset per sweep
        plt.plot(abf.dataX+.05*sweep,abf.dataY,color=color,alpha=.7)  # horizontal offset per sweep
    plt.margins(0,0)
    plt.gca().axis('off') # remove square around edges
    plt.xticks([]) # remove x labels
    plt.yticks([]) # remove y labels
    plt.tight_layout()
    plt.savefig("_output/2_%s.png"%colormap,dpi=150)
    plt.show()
    plt.close()
Example #8
Source File: gettingStarted.py (from pyABF, MIT License; 6 votes)
def demo_02a_plot_matplotlib_sweep(self):
        """
        ## Plot a Sweep with Matplotlib

        Matplotlib is a fantastic plotting library for Python. This example
        shows how to plot an ABF sweep using matplotlib.
        ABF `setSweep()` is used to tell the ABF class what sweep to load
        into memory. After that you can just plot `sweepX` and `sweepY`.
        """

        import pyabf
        abf = pyabf.ABF("data/abfs/17o05028_ic_steps.abf")
        abf.setSweep(14)  # load sweep 14 so sweepX/sweepY hold its data
        plt.figure(figsize=self.figsize)
        plt.plot(abf.sweepX, abf.sweepY)
        plt.grid(alpha=.2)  # ignore
        plt.margins(0, .1)  # ignore
        plt.tight_layout()  # ignore
        self.saveAndClose()
Example #9
Source File: gettingStarted.py (from pyABF, MIT License; 6 votes)
def demo_03a_decorate_matplotlib_plot(self):
        """
        ## Decorate Plots with ABF Information

        The ABF class provides easy access to lots of information about the ABF.
        This example shows how to use these class methods to create a prettier
        plot of several sweeps from the same file.
        """

        import pyabf
        abf = pyabf.ABF("data/abfs/17o05028_ic_steps.abf")
        plt.figure(figsize=self.figsize)
        plt.title("pyABF and Matplotlib are a great pair!")
        # axis labels are provided by the ABF object itself
        plt.ylabel(abf.sweepLabelY)
        plt.xlabel(abf.sweepLabelX)
        # overlay four sweeps, each with its own legend label
        for i in [0, 5, 10, 15]:
            abf.setSweep(i)
            plt.plot(abf.sweepX, abf.sweepY, alpha=.5, label="sweep %d" % (i))
        plt.margins(0, .1)  # ignore
        plt.tight_layout()  # ignore
        plt.grid(alpha=.2)  # ignore
        plt.legend()
        self.saveAndClose()
Example #10
Source File: gettingStarted.py (from pyABF, MIT License; 6 votes)
def demo_08a_xy_offset(self):
        """
        ## Plot Sweeps in 3D

        The previous example showed how to plot stacked sweeps by adding a Y
        offset to each sweep. If you add an X and Y offset to each sweep, you
        can create a 3D effect.
        """

        import pyabf
        abf = pyabf.ABF("data/abfs/171116sh_0018.abf")

        plt.figure(figsize=self.figsize)
        for sweepNumber in abf.sweepList:
            abf.setSweep(sweepNumber)
            i1, i2 = 0, int(abf.dataRate * 1)  # plot part of the sweep
            # shift each successive sweep right and up to fake depth
            dataX = abf.sweepX[i1:i2] + .025 * sweepNumber
            dataY = abf.sweepY[i1:i2] + 15 * sweepNumber
            plt.plot(dataX, dataY, color='C0', alpha=.5)

        plt.gca().axis('off')  # hide axes to enhance floating effect
        plt.margins(.02, .02)  # ignore
        plt.tight_layout()  # ignore
        self.saveAndClose()
Example #11
Source File: gettingStarted.py (from pyABF, MIT License; 6 votes)
def demo_11a_gap_free(self):
        """
        ## Plotting Gap-Free ABFs

        The pyABF treats every ABF like it's episodic (with sweeps). As such,
        gap free ABF files are loaded as if they were episodic files with
        a single sweep. When an ABF is loaded, `setSweep(0)` is called
        automatically, so the entire gap-free set of data is already available
        by plotting `sweepX` and `sweepY`.
        """

        import pyabf
        abf = pyabf.ABF("data/abfs/abf1_with_tags.abf")
        plt.figure(figsize=self.figsize)
        plt.plot(abf.sweepX, abf.sweepY, lw=.5)
        plt.axis([725, 825, -150, -15])  # zoom to a fixed window
        plt.ylabel(abf.sweepLabelY)
        plt.xlabel(abf.sweepLabelX)
        plt.title("Example Gap Free File")
        # NOTE(review): plt.axis() above fixed the limits (autoscaling off), so
        # margins likely has no visible effect here.
        plt.margins(0, .1)  # ignore
        plt.grid(alpha=.2)  # ignore
        self.saveAndClose()
Example #12
Source File: generate_is_plot.py (from big-discriminator-batch-spoofing-gan, MIT License; 6 votes)
def generate_plot(x, y, title, save_path):
    """
    generates the plot of Inception Score (IS) values over epochs
    :param x: the indices (epochs)
    :param y: IS values
    :param title: title of the generated plot
    :param save_path: path to save the file
    :return: None (saves file)

    Note: sets the global matplotlib font size to 20 (rcParams side effect).
    """
    # Only the size is wanted. 'normal' is a weight/style value, not a font
    # family, so passing family='normal' only makes matplotlib warn
    # "findfont: Font family ['normal'] not found" and fall back to the default.
    matplotlib.rc('font', size=20)
    plt.figure(figsize=(10, 6))
    annot_max(x, y)
    plt.margins(.05, .05)
    plt.title(title)
    plt.xlabel("Epochs")
    plt.ylabel("Inception scores")
    # y axis starts at 0 with a little headroom above the best score
    plt.ylim(0, max(y) + 2)
    plt.plot(x, y, linewidth=4)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
Example #13
Source File: generate_fid_plot.py (from big-discriminator-batch-spoofing-gan, MIT License; 6 votes)
def generate_plot(x, y, title, save_path):
    """
    generates the plot of FID values over epochs
    :param x: the indices (epochs)
    :param y: fid values
    :param title: title of the generated plot
    :param save_path: path to save the file
    :return: None (saves file)

    Note: sets the global matplotlib font size to 20 (rcParams side effect).
    """
    # Only the size is wanted. 'normal' is a weight/style value, not a font
    # family, so passing family='normal' only makes matplotlib warn
    # "findfont: Font family ['normal'] not found" and fall back to the default.
    matplotlib.rc('font', size=20)
    plt.figure(figsize=(10, 6))
    annot_min(x, y)
    plt.margins(.05, .05)
    plt.title(title)
    plt.xlabel("Epochs")
    plt.ylabel("FID scores")
    plt.plot(x, y, linewidth=4)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
Example #14
Source File: gettingStarted.py (from pyABF, MIT License; 5 votes)
def demo_07a_stacked_sweeps(self):
        """
        ## Plot Stacked Sweeps

        I often like to view sweeps stacked one on top of another. In ClampFit
        this is done with "distribute traces". Here we can add a bit of offset
        when plotting sweeps and achieve the same effect. This example makes
        use of `abf.sweepList`, which is the same as `range(abf.sweepCount)`.
        """

        import pyabf
        abf = pyabf.ABF("data/abfs/171116sh_0018.abf")
        plt.figure(figsize=self.figsize)

        # plot every sweep (with vertical offset)
        for sweepNumber in abf.sweepList:
            abf.setSweep(sweepNumber)
            offset = 140*sweepNumber  # fixed step per sweep, in sweepY units
            plt.plot(abf.sweepX, abf.sweepY+offset, color='C0')

        # decorate the plot
        plt.gca().get_yaxis().set_visible(False)  # hide Y axis: offsets make its values meaningless
        plt.xlabel(abf.sweepLabelX)
        plt.margins(0, .02)  # ignore
        plt.tight_layout()  # ignore
        self.saveAndClose()
Example #15
Source File: visualize.py (from Pixel2MeshPlusPlus, BSD 3-Clause "New" or "Revised" License; 5 votes)
def plot_scatter(pt, data_name, plt_path):
    """Save a 3D scatter plot of a point cloud as a transparent PNG.

    All three axes get the same half-range (half the largest per-axis span),
    centered on each axis' midpoint, so the cloud is drawn in a cube.

    :param pt: 2-D array whose columns 0, 1, 2 are the x, y, z coordinates
    :param data_name: source file name; '.dat' is replaced by '.png' for output
    :param plt_path: directory the PNG is written into
    """
    fig = plt.figure()
    fig.set_size_inches(20.0 / 3, 20.0 / 3)
    # Figure.gca() stopped accepting keyword arguments such as projection=
    # in matplotlib 3.6; add_subplot is the supported way to get 3D axes.
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect('equal')
    ax.grid(color='r', linestyle='-',)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_zticklabels([])
    X = pt[:, 0]
    Y = pt[:, 1]
    Z = pt[:, 2]

    ax.scatter(X, Y, Z, depthshade=True, marker='.')

    # Cube-shaped limits: the largest span decides the range for every axis.
    max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max() / 2.0

    mid_x = (X.max() + X.min()) * 0.5
    mid_y = (Y.max() + Y.min()) * 0.5
    mid_z = (Z.max() + Z.min()) * 0.5
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    plt.margins(0, 0)
    fig.savefig(os.path.join(plt_path, data_name.replace('.dat', '.png')), format='png', transparent=True, dpi=300, pad_inches=0, bbox_inches='tight')
    # Release the figure; otherwise every call leaves one more figure open.
    plt.close(fig)
Example #16
Source File: gettingStarted.py (from pyABF, MIT License; 5 votes)
def demo_08b_custom_colormap(self):
        """
        ## Custom Colormaps

        Matplotlib's colormap tools can be used to add an extra dimension to
        graphs. All matplotlib colormaps are [listed here](https://matplotlib.org/examples/color/colormaps_reference.html).
        For an interesting discussion on choosing ideal colormaps for scientific
        data visit [bids.github.io/colormap/](https://bids.github.io/colormap/).
        Good colors for e-phys are "winter", "rainbow", and "viridis".
        """

        import pyabf
        abf = pyabf.ABF("data/abfs/171116sh_0018.abf")

        # use a custom colormap to create a different color for every sweep
        cm = plt.get_cmap("winter")
        colors = [cm(x/abf.sweepCount) for x in abf.sweepList]  # one evenly spaced color per sweep
        # colors.reverse()  # uncomment to run the colors in the opposite order

        plt.figure(figsize=self.figsize)
        for sweepNumber in abf.sweepList:
            abf.setSweep(sweepNumber)
            i1, i2 = 0, int(abf.dataRate * 1)
            # shift each successive sweep right and up to fake depth
            dataX = abf.sweepX[i1:i2] + .025 * sweepNumber
            dataY = abf.sweepY[i1:i2] + 15 * sweepNumber
            plt.plot(dataX, dataY, color=colors[sweepNumber], alpha=.5)

        plt.gca().axis('off')
        plt.margins(.02, .02)  # ignore
        plt.tight_layout()  # ignore
        self.saveAndClose()
Example #17
Source File: gettingStarted.py (from pyABF, MIT License; 5 votes)
def demo_12a_tags(self):
        """
        ## Accessing Comments (Tags) in ABF Files

        While recording an ABF the user can insert a comment at a certain
        time point. pClamp calls these "tags", and they can be a useful
        way to mark when a drug was applied during an experiment. For this
        to work, `sweepX` needs to be a list of times in the ABF recording
        (not times which always start at 0 for every new sweep). Set this
        behavior by setting `absoluteTime=True` when calling `setSweep()`.

        A list of comments (the text of tags) is stored in a list
        `abf.tagComments`. The sweep for each tag is in `abf.tagSweeps`, while
        the time of each tag is in `abf.tagTimesSec` and `abf.tagTimesMin`
        """

        import pyabf
        abf = pyabf.ABF("data/abfs/16d05007_vc_tags.abf")

        # create a plot with time on the horizontal axis
        plt.figure(figsize=self.figsize)
        for sweep in abf.sweepList:
            abf.setSweep(sweep, absoluteTime=True)  # <-- relates to sweepX
            abf.sweepY[:int(abf.dataRate*1.0)] = np.nan  # ignore
            plt.plot(abf.sweepX, abf.sweepY, lw=.5, alpha=.5, color='C0')
        plt.margins(0, .5)
        plt.grid(alpha=.2)  # ignore
        plt.ylabel(abf.sweepLabelY)
        plt.xlabel(abf.sweepLabelX)

        # now add the tags as vertical lines, pairing each time with its comment
        tags = zip(abf.tagTimesSec, abf.tagComments)
        for i, (tagTimeSec, comment) in enumerate(tags):
            color = "C%d" % (i+1)  # C0 is used by the sweeps, so start at C1
            plt.axvline(tagTimeSec, label=comment, color=color, ls='--')
        plt.legend()

        plt.title("ABF File with Comments (Tags)")
        self.saveAndClose()
Example #18
Source File: generate.py (from pyABF, MIT License; 5 votes)
def plot(self, show=False):
    """Display the current sweep.

    Draws ``sweepX`` against ``sweepY`` in red on a new 10x4 figure, titled
    with the clamp mode and labeled with the units. ``plt.show()`` is called
    only when ``show`` is true; otherwise the figure is left open.
    """
    import matplotlib.pyplot as plt
    _, axes = plt.subplots(figsize=(10, 4))
    plt.grid(alpha=.2, ls='--')
    axes.plot(self.sweepX, self.sweepY, color='r', lw=.5)
    plt.margins(0, .1)
    plt.title("Simulated %s Sweep" % self.clampMode)
    plt.xlabel("Time (seconds)")
    plt.ylabel(self.unitsLong)
    plt.tight_layout()
    if not show:
        return
    plt.show()
Example #19
Source File: plotting_callbacks.py (from neon, Apache License 2.0; 5 votes)
def on_epoch_end(self, callback_data, model, epoch):
    """Save sample image grids and, for WGAN training, a cost plot.

    Writes ``<filename>_data_<epoch>.png`` and ``<filename>_noise_<epoch>.png``
    from the model's current data/noise batches. When the cost function is
    'wasserstein', it also writes ``<filename>_training.png``. That plot shows
    the negated discriminator cost against generator iterations, with a
    median-filtered (kernel 101) overlay.
    """
    # convert to numpy arrays
    data_batch = model.data_batch.get()
    noise_batch = model.noise_batch.get()
    # value transform
    data_batch = self._value_transform(data_batch)
    noise_batch = self._value_transform(noise_batch)
    # shape transform
    data_canvas = self._shape_transform(data_batch)
    noise_canvas = self._shape_transform(noise_batch)

    # Save both canvases as 8-bit RGB images. Values are expected in [0, 1]
    # (presumably guaranteed by _value_transform — confirm); clip anyway, since
    # casting out-of-range floats to uint8 wraps around into garbage pixels.
    for tag, canvas in (('data', data_canvas), ('noise', noise_canvas)):
        fname = '{}_{}_{:03d}.png'.format(self.filename, tag, epoch)
        pixels = np.uint8(np.clip(canvas, 0., 1.) * 255)
        Image.fromarray(pixels).convert('RGB').save(fname)

    # plot logged WGAN costs if logged
    if model.cost.costfunc.func == 'wasserstein':
        giter = callback_data['gan/gen_iter'][:]
        nonzeros = np.where(giter)  # keep only iterations that were logged
        giter = giter[nonzeros]
        cost_dis = callback_data['gan/cost_dis'][:][nonzeros]
        # smoothed Wasserstein estimate (negated discriminator cost)
        w_dist = medfilt(np.array(-cost_dis, dtype='float64'), kernel_size=101)
        plt.figure(figsize=(400/self.dpi, 300/self.dpi), dpi=self.dpi)
        plt.plot(giter, -cost_dis, 'k-', lw=0.25)
        plt.plot(giter, w_dist, 'r-', lw=2.)
        plt.title(self.filename, fontsize=self.font_size)
        plt.xlabel("Generator Iterations", fontsize=self.font_size)
        plt.ylabel("Wasserstein estimate", fontsize=self.font_size)
        plt.margins(0, 0, tight=True)
        plt.savefig(self.filename+'_training.png', bbox_inches='tight')
        plt.close()
Example #20
Source File: util.py (from V1EngineeringInc-Docs, Creative Commons Attribution Share Alike 4.0 International; 5 votes)
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """Show a red-dot plot of ``y_values`` against ``x_values``.

    The y axis is fixed to [-1.2, 1.2]. If given, ``x_labels`` label each x
    position (rotated vertically) and ``y_labels`` label the ticks at -1, 0, 1.

    :raises ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            'The plot function requires matplotlib to be installed. '
            'See http://matplotlib.org/'
        ) from err

    # Create the axes first and configure that same object. Calling
    # plt.locator_params() first would set up an axes via gca(), and
    # plt.axes() would then add a second axes on top of it.
    axes = plt.axes()
    axes.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'ro' already means red circles; an extra color='red' is redundant and
    # makes matplotlib warn about the color being defined twice.
    plt.plot(x_values, y_values, 'ro')
    plt.ylim(bottom=-1.2, top=1.2)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    # Lay out after the tick labels exist so rotated labels are accounted for.
    plt.tight_layout(pad=5)
    plt.show()


# ////////////////////////////////////////////////////////////
# { Parsing and conversion functions
# //////////////////////////////////////////////////////////// 
Example #21
Source File: synth-atf.py (from pyABF, MIT License; 5 votes)
def display(data, rate=20000):
    """Display a stimulus waveform array.

    Plots ``data`` against time in seconds (sample index / ``rate``), saves
    the figure to ``stimulus-waveform.png`` and then shows it.

    :param data: 1-D sequence of stimulus samples
    :param rate: sample rate in samples per second
    """
    Xs=np.arange(len(data))/rate
    plt.figure(figsize=(8,2))
    plt.plot(Xs,data)
    plt.margins(0,.1)
    plt.title("Stimulus Waveform")
    plt.ylabel("mV or pA")
    plt.xlabel("Stimulus Time (seconds)")
    # Lay out before saving. Previously tight_layout ran after savefig, so the
    # PNG was written without it and only the on-screen figure got the layout.
    plt.tight_layout()
    plt.savefig("stimulus-waveform.png",dpi=100)
    plt.show()
Example #22
Source File: 2018-11-24 simulated data.py (from pyABF, MIT License; 5 votes)
def load_simulated_data():
    """load the NPY file and display its first row as a current trace."""
    data = np.load(FOLDER_HERE+"/2018-11-24 simulated data.npy")
    # sample index -> seconds, presumably at a 20 kHz sample rate
    sweepX = np.arange(len(data[0]))/20_000
    plt.figure(figsize=(8, 4))
    plt.plot(sweepX*1000, data[0], color='r', lw=.5)  # x plotted in milliseconds
    plt.margins(0, .1)
    plt.ylabel("current (pA)")
    # Time is the horizontal axis; this was a second ylabel call, which
    # overwrote "current (pA)" and left the x axis unlabeled.
    plt.xlabel("time (ms)")
    plt.show()
Example #23
Source File: observe_input.py (from Self-Supervised-Speech-Pretraining-and-Representation-Learning, MIT License; 5 votes)
def plot_x(x, name='x', xlabel='Frames'):
    """Save a heat-map of a 2-D feature matrix to ``<out_dir>/<name>.png``.

    ``x`` is transposed so its second dimension runs vertically. That axis is
    labeled 'Channels', so the function presumably expects ``x`` as
    (frames, channels); confirm with callers. Tick marks are removed and the
    image is cropped tightly. ``out_dir`` is a module-level name defined
    outside this function.
    """
    x = x.transpose(1, 0)
    fig, ax = plt.subplots(figsize=(10, 3))
    im = ax.imshow(x, aspect='auto', origin='lower',
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel(xlabel)
    plt.ylabel('Channels')
    plt.tight_layout()
    plt.margins(0,0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())

    fig.canvas.draw()
    fig.savefig(os.path.join(out_dir, name + '.png'), bbox_inches='tight', pad_inches = 0)
    # Release the figure; otherwise every call leaves one more figure open.
    plt.close(fig)
Example #24
Source File: util.py (from razzy-spinner, GNU General Public License v3.0; 5 votes)
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """Show a red-dot plot of ``y_values`` against ``x_values``.

    The y axis is fixed to [-1.2, 1.2]. If given, ``x_labels`` label each x
    position (rotated vertically) and ``y_labels`` label the ticks at -1, 0, 1.

    :raises ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/') from err

    # Create the axes first and configure that same object. Calling
    # plt.locator_params() first would set up an axes via gca(), and
    # plt.axes() would then add a second axes on top of it.
    axes = plt.axes()
    axes.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'ro' already means red circles; an extra color='red' is redundant and
    # makes matplotlib warn about the color being defined twice.
    plt.plot(x_values, y_values, 'ro')
    plt.ylim(bottom=-1.2, top=1.2)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    # Lay out after the tick labels exist so rotated labels are accounted for.
    plt.tight_layout(pad=5)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example #25
Source File: generate_multiple_is_plots.py (from big-discriminator-batch-spoofing-gan, MIT License; 5 votes)
def generate_plot(xs, ys, titles, save_path):
    """
    generates one plot overlaying several Inception Score (IS) curves
    :param xs: list of index sequences (epochs), one per curve
    :param ys: list of IS value sequences, one per curve
    :param titles: legend label for each curve
    :param save_path: path to save the file
    :return: None (saves file)

    Note: sets the global matplotlib font size to 20 (rcParams side effect).
    """
    # Only the size is wanted. 'normal' is a weight/style value, not a font
    # family, so passing family='normal' only makes matplotlib warn
    # "findfont: Font family ['normal'] not found" and fall back to the default.
    matplotlib.rc('font', size=20)
    plt.figure(figsize=(10, 6))

    plt.xlabel("Epochs")
    plt.ylabel("Inception scores")

    # set the y limit to 5 + max of everything
    plt.ylim(0, max(map(max, ys)) + 5)
    plt.margins(.05, .05)

    for cnt, (x, y, title) in enumerate(zip(xs, ys, titles)):
        # stack each curve's max annotation a little below the previous one
        annot_max(x, y, y_offset=0.96 - (0.07 * cnt))
        plt.plot(x, y, linewidth=4, label=title)

    plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
Example #26
Source File: generate_multiple_fid_plots.py (from big-discriminator-batch-spoofing-gan, MIT License; 5 votes)
def generate_plot(xs, ys, titles, save_path):
    """
    generates one plot overlaying several FID curves
    :param xs: list of index sequences (epochs), one per curve
    :param ys: list of FID value sequences, one per curve
    :param titles: legend label for each curve
    :param save_path: path to save the file
    :return: None (saves file)

    Note: sets the global matplotlib font size to 20 (rcParams side effect).
    """
    # Only the size is wanted. 'normal' is a weight/style value, not a font
    # family, so passing family='normal' only makes matplotlib warn
    # "findfont: Font family ['normal'] not found" and fall back to the default.
    matplotlib.rc('font', size=20)
    plt.figure(figsize=(10, 6))

    plt.xlabel("Epochs")
    plt.ylabel("FID scores")

    # set the y limit to 50 + max of everything
    plt.ylim(0, max(map(max, ys)) + 50)
    plt.margins(.05, .05)

    for cnt, (x, y, title) in enumerate(zip(xs, ys, titles)):
        # stack each curve's min annotation a little below the previous one
        annot_min(x, y, y_offset=0.96 - (0.07 * cnt))
        plt.plot(x, y, linewidth=4, label=title)

    plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
Example #27
Source File: wall_time_comparison_sample_factory_seed_rl.py (from sample-factory, MIT License; 5 votes)
def main():
    """Plot Sample Factory and SEED RL results against wall time; save a PDF.

    Both result sets are drawn onto the same 2x2 grid. The top and bottom
    rows of axes are handed to ``plot_envs`` for each framework. The figure
    is written to ``../final_plots/reward_wall_time.pdf`` relative to the
    current working directory.

    :return: 0 (process exit status)
    """
    sample_factory_runs = '/home/alex/all/projects/sample-factory/train_dir/paper_doom_wall_time_v97_fs4'
    sample_factory_runs_path = Path(sample_factory_runs)

    seed_rl_runs = '/home/alex/all/projects/sample-factory/train_dir/seedrl/seed_rl_csv'
    seed_rl_runs_path = Path(seed_rl_runs)

    _, (top_ax, bottom_ax) = plt.subplots(2, 2)

    interpolated_keys_by_env = extract_data_tensorboard_events(sample_factory_runs_path, SAMPLE_FACTORY)
    plot_envs(interpolated_keys_by_env, top_ax, bottom_ax, SAMPLE_FACTORY)

    interpolated_keys_by_env = extract_data_csv(seed_rl_runs_path, SEED_RL)
    plot_envs(interpolated_keys_by_env, top_ax, bottom_ax, SEED_RL)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.12, hspace=0.4)

    # NOTE(review): plt.margins only affects the current (last-created) axes.
    plt.margins(0, 0)
    plot_name = 'wall_time'  # plain string: the f-string had no placeholders
    plt.savefig(os.path.join(os.getcwd(), f'../final_plots/reward_{plot_name}.pdf'), format='pdf', bbox_inches='tight', pad_inches=0)

    return 0
Example #28
Source File: throughput_plot.py (from sample-factory, MIT License; 5 votes)
def main():
    """Draw the 2x3 throughput comparison grid and save it as a PDF.

    One panel is drawn per entry of ``measurements``, in row-major order.
    A shared legend, taken from the last panel, is placed above the grid.
    Output: ``../final_plots/throughput.pdf`` relative to the working directory.

    Requirements this figure is meant to satisfy:
    1) dark background
    2) both axis should start at 0
    3) Legend should be on background
    4) Legend should not obstruct data
    5) Export in eps
    6) Markers. Little circles for every data point
    7) Dashed lines for missing data
    """
    fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3)
    ax = (ax1, ax2, ax3, ax4, ax5, ax6)

    for count, (name, measurement) in enumerate(measurements.items()):
        build_plot(name, measurement, ax[count], count)

    handles, labels = ax[-1].get_legend_handles_labels()
    lgd = fig.legend(handles, labels, bbox_to_anchor=(0.05, 0.88, 0.9, 0.5), loc='lower left', ncol=5, mode="expand")
    lgd.set_in_layout(True)

    plot_name = 'throughput'
    # Reserve the top 10% of the figure for the legend.
    plt.tight_layout(rect=(0, 0, 1.0, 0.9))

    # ensure_dir_exists is expected to return the directory it was given
    # (.../final_plots). The file belongs directly inside it; the old
    # '../final_plots/...' suffix only walked up one level and back down.
    plot_dir = ensure_dir_exists(os.path.join(os.getcwd(), '../final_plots'))

    plt.savefig(os.path.join(plot_dir, f'{plot_name}.pdf'), format='pdf', bbox_extra_artists=(lgd,))
Example #29
Source File: utils.py (from labelKeypoint, GNU General Public License v3.0; 5 votes)
def draw_label(label, img, label_names, colormap=None):
    """Render a label map over ``img`` with a class legend; return RGB pixels.

    The plot is drawn edge-to-edge without ticks or axes and saved as an image
    into memory. It is then resized back to ``img``'s width/height and
    returned as a numpy array.

    :param label: label map passed to ``label2rgb``
    :param img: base image; ``img.shape[0]``/``img.shape[1]`` are used as
        height/width (presumably an (H, W[, C]) array)
    :param label_names: class names, one legend entry per name
    :param colormap: per-class colors indexed by class id; defaults to
        ``label_colormap(len(label_names))``
    """
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                        wspace=0, hspace=0)
    plt.margins(0, 0)
    current_axes = plt.gca()
    current_axes.xaxis.set_major_locator(plt.NullLocator())
    current_axes.yaxis.set_major_locator(plt.NullLocator())

    n_classes = len(label_names)
    if colormap is None:
        colormap = label_colormap(n_classes)

    # NOTE(review): colormap only colors the legend swatches below; label2rgb
    # is not given it, so a custom colormap may not match the image — confirm.
    plt.imshow(label2rgb(label, img, n_labels=n_classes))
    plt.axis('off')

    swatches = [plt.Rectangle((0, 0), 1, 1, fc=colormap[class_id])
                for class_id, _ in enumerate(label_names)]
    plt.legend(swatches, list(label_names), loc='lower right', framealpha=.5)

    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', pad_inches=0)
    plt.cla()
    plt.close()

    width_height = (img.shape[1], img.shape[0])  # PIL sizes are (width, height)
    rendered = PIL.Image.open(buffer).resize(width_height, PIL.Image.BILINEAR).convert('RGB')
    return np.asarray(rendered)
Example #30
Source File: plotter.py (from message-analyser, MIT License; 5 votes)
def lineplot_messages(msgs, your_name, target_name, path_to_save):
    """Save a filled line chart of per-period message counts for two authors.

    ``_get_plot_data`` supplies the x positions and, for each position, the
    messages in that period. Counts are taken separately for ``your_name``
    and ``target_name``. The PNG is written into ``path_to_save``.
    """
    sns.set(style="whitegrid")

    x, y_total = _get_plot_data(msgs)
    xticks, xticks_labels, xlabel = _get_xticks(msgs)

    def count_by(author):
        # number of messages from `author` in each period
        return [sum(1 for msg in period if msg.author == author) for period in y_total]

    y_your = count_by(your_name)
    y_target = count_by(target_name)

    plt.fill_between(x, y_your, alpha=0.3)
    # NOTE(review): seaborn applies `palette` to hue levels; with no hue set it
    # appears to have no effect here — confirm before relying on the color.
    ax = sns.lineplot(x=x, y=y_your, palette="denim blue", linewidth=2.5, label=your_name)
    plt.fill_between(x, y_target, alpha=0.3)
    sns.lineplot(x=x, y=y_target, linewidth=2.5, label=target_name)

    ax.set(xlabel=xlabel, ylabel="messages")
    ax.set_xticklabels(xticks_labels)

    ax.tick_params(axis='x', bottom=True, color="#A9A9A9")
    plt.xticks(xticks, rotation=65)
    ax.margins(x=0, y=0)

    figure = plt.gcf()
    figure.set_size_inches(13, 7)

    out_file = os.path.join(path_to_save, lineplot_messages.__name__ + ".png")
    figure.savefig(out_file, dpi=500)
    plt.close("all")
    log_line(f"{lineplot_messages.__name__} was created.")