Python matplotlib.pyplot.locator_params() Examples

The following are 18 code examples of matplotlib.pyplot.locator_params(). Several of them call the equivalent Axes method, ax.locator_params(), which configures a specific Axes rather than the current one. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: plot_apogee_lamost_cannon.py    From TheCannon with MIT License 6 votes vote down vote up
def plot_one(title, ax, x, y, lim):
    """Scatter the residual ``y - x`` against ``x`` on ``ax`` and annotate it.

    The standard deviation ("RMS") and mean ("Bias") of the residual are
    rounded with ``round_2`` and written in the upper-left corner. The y-range
    is fixed to ``[-lim, lim]``, and the number of residuals falling outside
    that range is printed.

    Parameters
    ----------
    title : str
        Currently unused (setting the title is disabled); kept for API
        compatibility.
    ax : matplotlib Axes
        Axes to draw on.
    x, y : array-like
        Values supporting element-wise subtraction (presumably numpy arrays
        of equal length -- confirm against callers).
    lim : float
        Half-height of the visible residual range.
    """
    diff = y - x  # computed once and reused below
    ax.scatter(x, diff, marker='x', c='k', alpha=0.5)
    ax.axhline(y=0, c='r')
    scat = round_2(np.std(diff))
    bias = round_2(np.mean(diff))
    textstr = "RMS: %s \nBias: %s" % (scat, bias)
    ax.text(0.05, 0.95, textstr, ha='left', va='top', transform=ax.transAxes)
    # At most 5 tick intervals on each axis.
    ax.locator_params(axis='x', nbins=5)
    ax.locator_params(axis='y', nbins=5)
    ax.set_ylim(-1 * lim, lim)
    # Points outside [-lim, lim] are not visible; report how many were clipped.
    num_up = np.count_nonzero(diff > lim)
    num_down = np.count_nonzero(diff < -1 * lim)
    print("%s above, %s below" % (num_up, num_down))
Example #2
Source File: oraclesplot.py    From actions-for-actions with GNU General Public License v3.0 6 votes vote down vote up
def finalize_plot(allticks, handles):
    """Apply final tick, legend and layout formatting to the current figure.

    Parameters
    ----------
    allticks : sequence of (position, label) pairs
        Positions and labels used for the y-axis ticks.
    handles : sequence
        Each element's first item is used as a legend handle; labels come
        from the module-level ``seriesnames``. Only used when ``LEGEND``
        is true.
    """
    # The default linear-axis locator (MaxNLocator) accepts ``nbins`` but has
    # no ``nticks`` parameter; recent matplotlib versions raise TypeError for
    # unknown locator parameters.
    plt.locator_params(axis='x', nbins=Noracles)
    plt.yticks([x[0] for x in allticks], [x[1] for x in allticks])
    plt.tick_params(
        axis='y',      # changes apply to the y-axis
        which='both',  # both major and minor ticks are affected
        left=False,    # hide ticks along the left edge (bool, not 'off')
        right=False,   # hide ticks along the right edge
    )
    if LEGEND:
        plt.legend([h[0] for h in handles], seriesnames,
                   loc='upper right', borderaxespad=0.,
                   ncol=1, fontsize=10, numpoints=1)
    plt.gcf().tight_layout()


######################################################
# Data processing 
Example #3
Source File: slashdot_results.py    From news-popularity-prediction with Apache License 2.0 5 votes vote down vote up
def make_slashdot_figures(output_path_prefix, method_name_list, slashdot_mse, slashdot_jaccard, slashdot_k_list):
    """Plot mean MSE against lifetime for SlashDot comments and users.

    Draws two side-by-side panels sharing the x-axis, one line per method,
    and saves the figure as PNG and EPS.

    Parameters
    ----------
    output_path_prefix : str
        Prefix of the output files; ``"_mse_slashdot_SNOW.png"`` and
        ``"_mse_slashdot_SNOW.eps"`` are appended.
    method_name_list : iterable of str
        Methods to plot; each is translated to a legend name.
    slashdot_mse : mapping
        ``slashdot_mse[method]["comments" | "users"]`` is averaged with
        ``.mean(axis=1)`` (presumably a 2-D array -- confirm against callers).
    slashdot_jaccard :
        Unused; kept for API compatibility.
    slashdot_k_list : iterable
        Lifetimes (seconds) used as x values.

    The first element of the lifetimes and of every series is dropped.
    """
    sns.set_style("darkgrid")
    sns.set_context("paper")

    translator = get_method_name_to_legend_name_dict()

    x_values = list(slashdot_k_list)[1:]

    fig, axes = plt.subplots(1, 2, sharex=True)

    axes[0].set_title("SlashDot Comments")
    axes[1].set_title("SlashDot Users")
    axes[0].set_ylabel("MSE")

    for ax, target in zip(axes, ("comments", "users")):
        # plt.locator_params() would only reach the current (last-created)
        # axes; configure each panel explicitly so both get the same density.
        ax.locator_params(nbins=8)
        ax.set_xlabel("Lifetime (sec)")
        for method in method_name_list:
            ax.plot(x_values,
                    handle_nan(slashdot_mse[method][target].mean(axis=1))[1:],
                    label=translator[method])

    axes[1].legend(loc="upper right")

    fig.savefig(output_path_prefix + "_mse_slashdot_SNOW" + ".png", format="png")
    fig.savefig(output_path_prefix + "_mse_slashdot_SNOW" + ".eps", format="eps")
    # pyplot keeps every figure alive until it is closed; release it.
    plt.close(fig)
Example #4
Source File: slashdot_results.py    From news-popularity-prediction with Apache License 2.0 5 votes vote down vote up
def make_barrapunto_figures(output_path_prefix, method_name_list, barrapunto_mse, barrapunto_jaccard, barrapunto_k_list):
    """Plot mean MSE against lifetime for BarraPunto comments and users.

    Draws two side-by-side panels sharing the x-axis, one line per method,
    and saves the figure as PNG and EPS.

    Parameters
    ----------
    output_path_prefix : str
        Prefix of the output files; ``"_mse_barrapunto_SNOW.png"`` and
        ``"_mse_barrapunto_SNOW.eps"`` are appended.
    method_name_list : iterable of str
        Methods to plot; each is translated to a legend name.
    barrapunto_mse : mapping
        ``barrapunto_mse[method]["comments" | "users"]`` is averaged with
        ``.mean(axis=1)`` (presumably a 2-D array -- confirm against callers).
    barrapunto_jaccard :
        Unused; kept for API compatibility.
    barrapunto_k_list : iterable
        Lifetimes (seconds) used as x values.

    The first element of the lifetimes and of every series is dropped.
    """
    sns.set_style("darkgrid")
    sns.set_context("paper")

    translator = get_method_name_to_legend_name_dict()

    x_values = list(barrapunto_k_list)[1:]

    fig, axes = plt.subplots(1, 2, sharex=True)

    axes[0].set_title("BarraPunto Comments")
    axes[1].set_title("BarraPunto Users")
    axes[0].set_ylabel("MSE")

    for ax, target in zip(axes, ("comments", "users")):
        # plt.locator_params() would only reach the current (last-created)
        # axes; configure each panel explicitly so both get the same density.
        ax.locator_params(nbins=8)
        ax.set_xlabel("Lifetime (sec)")
        for method in method_name_list:
            ax.plot(x_values,
                    handle_nan(barrapunto_mse[method][target].mean(axis=1))[1:],
                    label=translator[method])

    axes[1].legend(loc="upper right")

    fig.savefig(output_path_prefix + "_mse_barrapunto_SNOW" + ".png", format="png")
    fig.savefig(output_path_prefix + "_mse_barrapunto_SNOW" + ".eps", format="eps")
    # pyplot keeps every figure alive until it is closed; release it.
    plt.close(fig)
Example #5
Source File: util.py    From razzy-spinner with GNU General Public License v3.0 5 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """Show red circle markers of ``y_values`` against ``x_values``.

    The y-range is fixed to [-1.2, 1.2] with horizontal grid lines.

    :param x_values: x coordinates of the points.
    :param y_values: y coordinates of the points.
    :param x_labels: optional tick labels placed at ``x_values`` (vertical).
    :param y_labels: optional labels for the y ticks at -1, 0 and 1.
    :raises ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/') from err

    # Create the axes first and configure it directly. Calling
    # plt.locator_params() before plt.axes() configures whatever axes gca()
    # returns, and plt.axes() may then add a separate axes on top of it.
    axes = plt.axes()
    axes.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'ro' already selects red circles; an extra color='red' triggers a
    # "color is redundantly defined" warning.
    plt.plot(x_values, y_values, 'ro')
    plt.ylim(bottom=-1.2, top=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example #6
Source File: plot_apogee_lamost_cannon.py    From TheCannon with MIT License 5 votes vote down vote up
def create_grid():
    """Create a 15x20 inch figure with a 3x3 grid of subplots.

    In each row, the second and third subplots share both axes with the
    first subplot of that row. Their y tick labels and their first x major
    tick are hidden. There is no horizontal gap between columns, and rows
    are separated by ``hspace=0.2``.

    Returns
    -------
    fig : matplotlib Figure
    axarr : tuple of three 3-tuples of Axes, indexed ``axarr[row][col]``.
    """
    fig = plt.figure(figsize=(15, 20))
    rows = []
    for row in range(3):
        first = fig.add_subplot(3, 3, 3 * row + 1)
        row_axes = [first]
        for col in (1, 2):
            ax = fig.add_subplot(3, 3, 3 * row + col + 1,
                                 sharex=first, sharey=first)
            # Shared y values are already labelled on the first column.
            plt.setp(ax.get_yticklabels(), visible=False)
            # Hide the leftmost x tick so it does not collide with the
            # neighbouring panel's last tick (columns have no gap).
            ax.xaxis.get_major_ticks()[0].set_visible(False)
            row_axes.append(ax)
        rows.append(tuple(row_axes))
    fig.subplots_adjust(wspace=0)
    fig.subplots_adjust(hspace=0.2)
    return fig, tuple(rows)
Example #7
Source File: plot_sweep.py    From gpkit with MIT License 5 votes vote down vote up
def format_and_label_axes(var, posys, axes, ylabel=True):
    """Format and label the axes of a sweep plot.

    Each axis in ``axes`` is paired with the corresponding entry of
    ``posys``. Extra entries in the longer sequence are ignored. Every paired
    axis gets a dotted grey frame, a grid, small tick labels, no tick marks,
    and at most 4 tick intervals per axis. The last paired axis receives the
    x label.

    Arguments
    ---------
    var : swept variable
        Its key's ``"label"`` descr (or name) and units form the x label.
    posys : sequence
        Plotted quantities. Those with a ``key`` use its label/name and
        units as the y label; others use ``str(posy)``.
    axes : sequence of matplotlib Axes
        Axes to format.
    ylabel : bool
        Whether to set y-axis labels.
    """
    ax = None
    for posy, ax in zip(posys, axes):
        if ylabel:
            # A separate name keeps the ``ylabel`` flag intact: reusing it
            # meant an empty label ("") disabled labels for later axes.
            if hasattr(posy, "key"):
                label = (posy.key.descr.get("label", posy.key.name)
                         + " [%s]" % posy.key.unitstr(dimless="-"))
            else:
                label = str(posy)
            ax.set_ylabel(label)
        ax.grid(color="0.6")
        for item in [ax.xaxis.label, ax.yaxis.label]:
            item.set_fontsize(12)
        for item in ax.get_xticklabels() + ax.get_yticklabels():
            item.set_fontsize(9)
        ax.tick_params(length=0)
        ax.spines['left'].set_visible(False)
        ax.spines['top'].set_visible(False)
        for spine in ax.spines.values():
            spine.set_linewidth(0.6)
            spine.set_color("0.6")
            spine.set_linestyle("dotted")
        # plt.locator_params() would only reach the current axes; apply the
        # tick density to every formatted axis.
        ax.locator_params(nbins=4)
    if ax is None:
        # Nothing was paired, so there is no axis to receive the x label.
        return
    xlabel = (var.key.descr.get("label", var.key.name)
              + " [%s]" % var.key.unitstr(dimless="-"))
    ax.set_xlabel(xlabel)
    plt.subplots_adjust(wspace=0.15)


# pylint: disable=too-many-locals,too-many-branches,too-many-statements 
Example #8
Source File: make_figure.py    From pyhawkes with MIT License 5 votes vote down vote up
def make_figure_a(S, F, C):
    """
    Plot fluorescence traces, filtered fluorescence, and spike times
    for neurons 0 and 1, one panel each, and save to ``figure3a.pdf``.

    S, F, C are indexed as ``[time step, neuron]`` (rank 2). Judging by the
    legend labels, they hold spike indicators, raw fluorescence and filtered
    fluorescence -- confirm against callers.
    """
    col = harvard_colors()
    dt = 0.02
    T_start = 0
    T_stop = 1 * 50 * 60
    t = dt * np.arange(T_start, T_stop)

    ks = [0, 1]
    nk = len(ks)
    fig = create_figure((3, 3))
    for ind, k in enumerate(ks):
        ax = fig.add_subplot(nk, 1, ind + 1)
        # Raw fluorescence.
        ax.plot(t, F[T_start:T_stop, k], color=col[1], label="$F$")
        # Filtered fluorescence (raw string: "\w" is not a valid escape).
        ax.plot(t, C[T_start:T_stop, k], color=col[0], lw=1.5,
                label=r"$\widehat{F}$")
        # np.where gives offsets relative to T_start: they index ``t``
        # directly but must be shifted back to index the full ``C``.
        spks = np.where(S[T_start:T_stop, k])[0]
        ax.plot(t[spks], C[T_start + spks, k], 'ko', label="S")

        if ind == 0:
            # Put a legend above the first panel.
            ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                      ncol=3, mode="expand", borderaxespad=0.,
                      prop={'size': 9})

        ax.set_ylabel("$F_%d(t)$" % (k + 1))
        if ind == nk - 1:
            ax.set_xlabel("Time $t$ [sec]")

        # Format the ticks: fixed y-range, at most 5 y intervals.
        ax.set_ylim([-0.1, 1.0])
        ax.locator_params(nbins=5, axis="y")

    plt.subplots_adjust(left=0.2, bottom=0.2)
    fig.savefig("figure3a.pdf")
    plt.show()
Example #9
Source File: util.py    From V1EngineeringInc-Docs with Creative Commons Attribution Share Alike 4.0 International 5 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """Show red circle markers of ``y_values`` against ``x_values``.

    The y-range is fixed to [-1.2, 1.2] with horizontal grid lines.

    :param x_values: x coordinates of the points.
    :param y_values: y coordinates of the points.
    :param x_labels: optional tick labels placed at ``x_values`` (vertical).
    :param y_labels: optional labels for the y ticks at -1, 0 and 1.
    :raises ImportError: if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            'The plot function requires matplotlib to be installed. '
            'See http://matplotlib.org/'
        ) from err

    # Create the axes first and configure it directly. Calling
    # plt.locator_params() before plt.axes() configures whatever axes gca()
    # returns, and plt.axes() may then add a separate axes on top of it.
    axes = plt.axes()
    axes.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'ro' already selects red circles; an extra color='red' triggers a
    # "color is redundantly defined" warning.
    plt.plot(x_values, y_values, 'ro')
    plt.ylim(bottom=-1.2, top=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()


# ////////////////////////////////////////////////////////////
# { Parsing and conversion functions
# //////////////////////////////////////////////////////////// 
Example #10
Source File: ab_exp.py    From abyes with Apache License 2.0 4 votes vote down vote up
def expected_loss_decision(self, posterior, var):
        """
        Calculate the expected losses of choosing A and of choosing B and
        apply the threshold-of-caring decision rule; optionally plot them.

        Parameters
        ----------
        posterior : dict
            Maps names to ``(values, bin_edges)`` pairs, where ``bin_edges``
            has one more entry than ``values`` (histogram layout). The keys
            ``'muA'`` and ``'muB'`` are read only when ``self.plot`` is set.
        var : str
            Key of the distribution of ``mu_B - mu_A`` (per the plot labels).

        Returns
        -------
        int or float
            ``0`` if both expected losses are <= ``self.toc``; otherwise
            ``1`` if B's loss is < ``self.toc``, ``-1`` if A's loss is
            < ``self.toc``, and ``np.nan`` if neither is.
        """
        # np.trapz was renamed np.trapezoid in NumPy 2.0 (trapz is deprecated).
        trapz = getattr(np, "trapezoid", None) or np.trapz

        dl = posterior[var][1]
        dl = 0.5 * (dl[0:-1] + dl[1:])  # bin edges -> bin centres
        fdl = posterior[var][0]
        inta = np.maximum(dl, 0) * fdl   # loss of choosing A: mass where B > A
        intb = np.maximum(-dl, 0) * fdl  # loss of choosing B: mass where A > B

        ela = trapz(inta, dl)
        elb = trapz(intb, dl)

        if self.plot:
            plt.subplot(1, 2, 1)
            b = posterior['muA'][1]
            plt.plot(0.5*(b[0:-1]+b[1:]), posterior['muA'][0], lw=2, label=r'$f(\mu_A)$')
            b = posterior['muB'][1]
            plt.plot(0.5*(b[0:-1]+b[1:]), posterior['muB'][0], lw=2, label=r'$f(\mu_B)$')
            # Raw string: '\m' and '\ ' are invalid escapes in a plain literal.
            plt.xlabel(r'$\mu_A,\  \mu_B$')
            plt.xlim([0, 1])
            plt.title('Conversion Rate')
            # MaxNLocator takes ``nbins`` (maximum number of intervals), not
            # ``nticks``; 5 intervals gives at most 6 ticks.
            plt.locator_params(nbins=5)
            plt.gca().set_ylim(bottom=0)
            plt.legend()

            plt.subplot(1, 2, 2)
            plt.plot(dl, fdl, 'b', lw=3, label=r'f$(\mu_B - \mu_A)$')
            plt.plot([ela, ela], [0, 0.3*np.max(fdl)], 'r', lw=3, label='A: Expected Loss')
            plt.plot([elb, elb], [0, 0.3*np.max(fdl)], 'c', lw=3, label='B: Expected Loss')
            plt.plot([self.toc, self.toc], [0, 0.3*np.max(fdl)], 'k--', lw=3, label='Threshold of Caring')
            plt.xlabel(r'$\mu_B-\mu_A$')
            plt.title('Expected Loss')
            plt.gca().set_ylim(bottom=0)
            # ``numticks`` is not a MaxNLocator parameter; use ``nbins``.
            plt.gca().locator_params(axis='x', nbins=5)
            plt.legend()

        if ela <= self.toc and elb <= self.toc:
            result = 0
        elif elb < self.toc:
            result = 1
        elif ela < self.toc:
            result = -1
        else:
            result = np.nan

        return result
Example #11
Source File: orbit_plots.py    From radvel with MIT License 4 votes vote down vote up
def plot_timeseries(self):
        """
        Draw the RV data together with the Gaussian Process + orbit model on
        the current Axes, then add a twin x-axis on top labelled in years.
        """
        main_ax = pl.gca()
        main_ax.axhline(0, color='0.5', linestyle='--')

        # Plot data either as bare residuals (zero model) or residuals + orbit.
        model_for_data = (np.zeros(self.rvmod.shape) if self.subtract_orbit_model
                          else self.rvmod)

        # plot_gp_like returns the index to pass to the next call
        # (presumably a colour index -- confirm in plot_gp_like).
        cycle_index = 0
        for likelihood in self.like_list:
            cycle_index = self.plot_gp_like(likelihood, model_for_data, cycle_index)

        # data = residuals + model
        plot.mtelplot(
            self.plttimes, self.rawresid+model_for_data, self.rverr,
            self.post.likelihood.telvec, main_ax, telfmts=self.telfmts
        )

        if self.set_xlim is None:
            main_ax.set_xlim(min(self.plttimes)-0.01*self.dt,
                             max(self.plttimes)+0.01*self.dt)
        else:
            main_ax.set_xlim(self.set_xlim)
        pl.setp(main_ax.get_xticklabels(), visible=False)

        if self.legend:
            main_ax.legend(numpoints=1, **self.legend_kwargs)

        # Top axis: the same x-range converted from JD (offset by epoch)
        # to decimal years.
        year_ax = main_ax.twiny()
        jd_limits = np.array(list(main_ax.get_xlim())) + self.epoch
        years = Time(jd_limits, format='jd', scale='utc').decimalyear
        year_ax.plot(years, years)
        year_ax.get_xaxis().get_major_formatter().set_useOffset(False)
        year_ax.set_xlim(*years)
        pl.locator_params(axis='x', nbins=5)
        year_ax.set_xlabel('Year', fontweight='bold')
Example #12
Source File: orbit_plots.py    From radvel with MIT License 4 votes vote down vote up
def plot_timeseries(self):
        """
        Draw the RV data and orbit model on the current Axes, add a top axis
        in decimal years, optionally fix the y-range, and drop the lowest
        y tick.
        """
        main_ax = pl.gca()
        main_ax.axhline(0, color='0.5', linestyle='--')

        if not self.show_rms:
            rms_values = False
        else:
            # Residual scatter per likelihood, keyed by its suffix.
            rms_values = {like.suffix: np.std(like.residuals())
                          for like in self.like_list}

        # Orbit model curve.
        main_ax.plot(self.mplttimes, self.orbit_model, 'b-', rasterized=False,
                     lw=self.fit_linewidth)

        # Data = residuals + model.
        velocities = self.rawresid+self.rvmod
        plot.mtelplot(
            self.plttimes, velocities, self.rverr, self.post.likelihood.telvec,
            main_ax, telfmts=self.telfmts, rms_values=rms_values
        )

        if self.set_xlim is None:
            main_ax.set_xlim(min(self.plttimes)-0.01*self.dt,
                             max(self.plttimes)+0.01*self.dt)
        else:
            main_ax.set_xlim(self.set_xlim)
        pl.setp(main_ax.get_xticklabels(), visible=False)

        if self.highlight_last:
            latest = np.argmax(self.plttimes)
            pl.plot(self.plttimes[latest], velocities[latest], **plot.highlight_format)

        if self.legend:
            main_ax.legend(numpoints=1, **self.legend_kwargs)

        # Top axis: the same x-range converted from JD (offset by epoch)
        # to decimal years.
        year_ax = main_ax.twiny()
        jd_limits = np.array(list(main_ax.get_xlim())) + self.epoch
        years = Time(jd_limits, format='jd', scale='utc').decimalyear
        year_ax.get_xaxis().get_major_formatter().set_useOffset(False)
        year_ax.set_xlim(*years)
        year_ax.set_xlabel('Year', fontweight='bold')
        pl.locator_params(axis='x', nbins=5)

        if not self.yscale_auto:
            spread = np.std(self.rawresid+self.rvmod)
            main_ax.set_ylim(-self.yscale_sigma * spread, self.yscale_sigma * spread)

        main_ax.set_ylabel('RV [{ms:}]'.format(**plot.latex), weight='bold')
        # Remove the lowest major tick on the RV axis.
        y_ticks = main_ax.yaxis.get_majorticklocs()
        main_ax.yaxis.set_ticks(y_ticks[1:])
Example #13
Source File: plotting.py    From snn_toolbox with MIT License 4 votes vote down vote up
def plot_max_activ_hist(h, title=None, layer_label=None, path=None,
                        scale_fac=None):
    """Plot a histogram of the maximum activations.

    Parameters
    ----------

    h: dict
        Maps each legend label to the dataset histogrammed under it; labels
        are drawn in sorted order.
    title: string, optional
        Title of histogram.
    layer_label: string, optional
        Label of layer from which data was taken.
    path: string, optional
        If not ``None``, specifies where to save the resulting image. Else,
        display plots without saving.
    scale_fac: float, optional
        The value with which parameters are normalized (maximum of activations
        or parameter value of a layer). If given (and non-zero), it is marked
        with a dashed vertical line and inserted into the plot title.
    """
    labels = sorted(h)
    datasets = [h[label] for label in labels]
    plt.hist(datasets, label=labels, bins=1000, edgecolor='none',
             histtype='stepfilled')
    plt.xlabel('Maximum ANN activations')
    plt.ylabel('Sample count')
    if scale_fac:
        plt.axvline(scale_fac, color='red', linestyle='dashed', linewidth=2,
                    label='scale factor')
    plt.legend()
    plt.locator_params(axis='x', nbins=5)

    if not (title and layer_label):
        plt.title('Distribution')
        filename = 'Maximum_activity_distribution'
    else:
        filename = layer_label + '_' + 'maximum_activity_distribution'
        divisor_note = ("Applied divisor: {:.2f}".format(scale_fac)
                        if scale_fac else '')
        plt.title('{} distribution \n of layer {} \n {}'.format(
            title, layer_label, divisor_note))

    if not path:
        plt.show()
    else:
        plt.savefig(os.path.join(path, filename), bbox_inches='tight')
    plt.close()
Example #14
Source File: plotting.py    From snn_toolbox with MIT License 4 votes vote down vote up
def plot_activ_hist(h, title=None, layer_label=None, path=None,
                    scale_fac=None):
    """Plot a log-scaled histogram over all activities of a network.

    Parameters
    ----------

    h: dict
        Maps each legend label to the dataset histogrammed under it; labels
        are drawn in sorted order.
    title: string, optional
        Title of histogram.
    layer_label: string, optional
        Label of layer from which data was taken.
    path: string, optional
        If not ``None``, specifies where to save the resulting image. Else,
        display plots without saving.
    scale_fac: float, optional
        The value with which parameters are normalized (maximum of activations
        or parameter value of a layer). If given (and non-zero), it is marked
        with a dashed vertical line and inserted into the plot title.
    """
    labels = sorted(h)
    datasets = [h[label] for label in labels]
    plt.hist(datasets, label=labels, bins=1000, edgecolor='none',
             histtype='stepfilled', log=True, bottom=1)
    if scale_fac:
        plt.axvline(scale_fac, color='red', linestyle='dashed', linewidth=2,
                    label='scale factor')
    plt.legend()
    plt.locator_params(axis='x', nbins=5)
    plt.xlabel('ANN activations')
    plt.ylabel('Count')
    plt.xlim(left=0)  # same as the xmin=0 alias

    if not (title and layer_label):
        plt.title('Distribution')
        filename = 'Activity_distribution'
    else:
        filename = layer_label + '_' + 'activ_distribution'
        divisor_note = ("Applied divisor: {:.2f}".format(scale_fac)
                        if scale_fac else '')
        plt.title('{} distribution \n of layer {} \n {}'.format(
            title, layer_label, divisor_note))

    if not path:
        plt.show()
    else:
        plt.savefig(os.path.join(path, filename), bbox_inches='tight')
    plt.close()
Example #15
Source File: plotting.py    From snn_toolbox with MIT License 4 votes vote down vote up
def plot_hist(h, title=None, layer_label=None, path=None, scale_fac=None):
    """Plot overlaid, semi-transparent, log-scaled histograms of the datasets.

    Parameters
    ----------

    h: dict
        Maps each legend label to the dataset histogrammed under it; labels
        are drawn in sorted order.
    title: string, optional
        Title of histogram. If it contains ``'Spikerates'``, the unit
        ``[Hz]`` is added and the file name does not include the layer label.
    layer_label: string, optional
        Label of layer from which data was taken.
    path: string, optional
        If not ``None``, specifies where to save the resulting image. Else,
        display plots without saving.
    scale_fac: float, optional
        The value with which parameters are normalized (maximum of activations
        or parameter value of a layer). If given (and non-zero), it is marked
        with a dashed vertical line and inserted into the plot title.
    """
    labels = sorted(h)
    plt.hist([h[label] for label in labels], label=labels, log=True, bottom=1,
             bins=1000, histtype='stepfilled', alpha=0.5, edgecolor='none')
    if scale_fac:
        plt.axvline(scale_fac, color='red', linestyle='dashed', linewidth=2,
                    label='scale factor')
    plt.legend()
    plt.locator_params(axis='x', nbins=5)

    if not (title and layer_label):
        plt.title('Distribution')
        filename = 'Activity_distribution'
    else:
        is_rate = 'Spikerates' in title
        unit = '[Hz]' if is_rate else ''
        if is_rate:
            filename = '4' + title + '_distribution'
        else:
            filename = layer_label + '_' + title + '_distribution'
        divisor_note = ("Applied divisor: {:.2f}".format(scale_fac)
                        if scale_fac else '')
        plt.title('{} distribution {} \n of layer {} \n {}'.format(
            title, unit, layer_label, divisor_note))

    if not path:
        plt.show()
    else:
        plt.savefig(os.path.join(path, filename), bbox_inches='tight')
    plt.close()
Example #16
Source File: plotting.py    From snn_toolbox with MIT License 4 votes vote down vote up
def plot_network_correlations(spikerates, layer_activations):
    """Plot the correlation between SNN spiketrains and ANN activations.

    For each layer, the method draws a scatter plot, showing the correlation
    between the average firing rate of neurons in the SNN layer and the
    activation of the corresponding neurons in the ANN layer. ANN
    activations are on the x-axis and SNN spikerates on the y-axis. Unused
    cells of the near-square grid are hidden.

    Parameters
    ----------

    spikerates: list of tuples ``(spikerate, label)``.

        ``spikerate`` is a 1D array containing the mean firing rates of the
        neurons in a specific layer.

        ``label`` is a string specifying both the layer type and the index,
        e.g. ``'3Dense'``.

    layer_activations: list of tuples ``(activations, label)``
        Each entry represents a layer in the ANN for which an activation can be
        calculated (e.g. ``Dense``, ``Conv2D``).

        ``activations`` is an array of the same dimension as the corresponding
        layer, containing the activations of Dense or Convolution layers.

        ``label`` is a string specifying the layer type, e.g. ``'Dense'``.

    Raises
    ------
    ValueError
        If ``layer_activations`` is empty.
    """
    num_layers = len(layer_activations)
    if num_layers == 0:
        # Without this check the grid computation divides by zero.
        raise ValueError("layer_activations must contain at least one layer.")
    # Determine optimal shape for rectangular arrangement of plots
    num_rows = int(np.ceil(np.sqrt(num_layers)))
    num_cols = int(np.ceil(num_layers / num_rows))
    f, ax = plt.subplots(num_rows, num_cols, squeeze=False,
                         figsize=(8, 1 + num_rows * 4))
    # ax.flat walks the grid row by row, matching layer index j + i * num_cols.
    for layer_num, cell in enumerate(ax.flat):
        if layer_num >= num_layers:
            cell.axis('off')  # hide grid cells without a layer
            continue
        activations = layer_activations[layer_num][0]
        rates = spikerates[layer_num][0]
        cell.plot(activations.flatten(), rates, '.')
        cell.set_title(spikerates[layer_num][1], fontsize='medium')
        cell.locator_params(nbins=4)
        cell.set_xlim([None, np.max(activations) * 1.1])
        cell.set_ylim([None, np.max(rates) * 1.1])
    f.suptitle('ANN-SNN correlations', fontsize=20)
    f.subplots_adjust(wspace=0.3, hspace=0.3)
    # x data are ANN activations and y data SNN spikerates, so the shared
    # labels must be placed accordingly (they were previously swapped).
    f.text(0.5, 0.04, 'ANN activations', ha='center', fontsize=16)
    f.text(0.04, 0.5, 'SNN spikerates (Hz)', va='center', rotation='vertical',
           fontsize=16)
Example #17
Source File: plotting.py    From snn_toolbox with MIT License 4 votes vote down vote up
def plot_layer_correlation(rates, activations, title, config, path=None,
                           same_xylim=True):
    """
    Plot spikerates against activations of a single layer as a 2D dot plot,
    annotated with the percentage of saturated units.

    Parameters
    ----------

    rates: np.array
        The spikerates of a layer, flattened to 1D (y-axis).
    activations: Union[ndarray, Iterable]
        The activations of a layer, flattened to 1D (x-axis).
    title: str
        Plot title.
    config: configparser.ConfigParser
        Settings; ``simulation.dt`` and ``simulation.duration`` are read.
    path: Optional[str]
        If not ``None``, specifies where to save the resulting image. Else,
        display plots without saving.
    same_xylim: Optional[bool]
        Whether to use the same axis limit on the ``rates`` and
        ``activations``. If ``True``, the maximum of 1.1 and both datasets'
        maxima is chosen. Default: ``True``.
    """
    # A unit counts as saturated if its rate is at least the threshold below.
    # The original intent is to subtract one time step from the maximum rate.
    # NOTE(review): with dt in ms and duration in ms, one spike fewer
    # corresponds to 1000 / duration Hz, not 1000 / duration / dt; the two
    # agree only for dt == 1 -- confirm the units of ``rates``.
    dt = config.getfloat('simulation', 'dt')
    duration = config.getint('simulation', 'duration')
    saturated_fraction = np.mean(
        np.greater_equal(rates, 1000 / dt - 1000 / duration / dt))

    plt.figure()
    plt.plot(activations, rates, '.')
    plt.annotate("{:.2%} units saturated.".format(saturated_fraction),
                 xy=(1, 1), xycoords='axes fraction', xytext=(-200, -20),
                 textcoords='offset points')
    plt.title(title, fontsize=20)
    plt.locator_params(nbins=4)
    if same_xylim:
        upper = max([1.1, max(activations), max(rates)])
    else:
        upper = None
    plt.xlim([0, upper])
    plt.ylim([0, upper])
    plt.xlabel('ANN activations', fontsize=16)
    plt.ylabel('SNN spikerates [Hz]', fontsize=16)
    if path is None:
        plt.show()
    else:
        plt.savefig(os.path.join(path, '5Correlation'), bbox_inches='tight')
    plt.close()
Example #18
Source File: visualize.py    From adversarial-policies with MIT License 4 votes vote down vote up
def bar_chart(envs, victim_id, n_components, covariance, savefile=None):
    """Bar chart of mean log probability for all opponent types, grouped by environment.
    For unspecified parameters, see get_full_directory.

    :param envs: (list of str) list of environments.
    :param savefile: (None or str) path to save figure to.
    :return: the matplotlib Figure that was drawn.
    """
    frames = []
    for env in envs:
        frame = load_metadata(env, victim_id, n_components, covariance)
        frame["Environment"] = PRETTY_ENVS.get(env, env)
        frames.append(frame)
    longform = pd.concat(frames)
    # Map raw opponent ids to display names (unknown ids become None).
    longform["opponent_id"] = longform["opponent_id"].apply(PRETTY_OPPONENTS.get)
    longform = longform.reset_index(drop=True)

    # Margins are specified in inches and converted to figure fractions.
    fig_width, fig_height = plt.rcParams.get("figure.figsize")
    legend_height = 0.4
    left_margin_in = 0.55
    top_margin_in = legend_height + 0.05
    bottom_margin_in = 0.5
    gridspec_kw = dict(
        left=left_margin_in / fig_width,
        top=1 - (top_margin_in / fig_height),
        bottom=bottom_margin_in / fig_height,
    )
    fig, ax = plt.subplots(1, 1, gridspec_kw=gridspec_kw)

    # Make colors consistent with previous figures
    color_cycle = list(plt.rcParams["axes.prop_cycle"])
    palette = {}
    for label in PRETTY_OPPONENTS.values():
        palette[label] = color_cycle[CYCLE_ORDER.index(label)]["color"]

    sns.barplot(
        x="Environment",
        y="log_proba",
        hue="opponent_id",
        order=PRETTY_ENVS.values(),
        hue_order=BAR_ORDER,
        data=longform,
        palette=palette,
        errwidth=1,
    )
    ax.set_ylabel("Mean Log Probability Density")
    plt.locator_params(axis="y", nbins=4)
    util.rotate_labels(ax, xrot=0)

    # Replace seaborn's legend with one placed outside the axes.
    ax.get_legend().remove()
    util.outside_legend(
        ax.get_legend_handles_labels(), 3, fig, ax, ax,
        legend_padding=0.05, legend_height=0.6, handletextpad=0.2
    )

    if savefile is not None:
        fig.savefig(savefile)

    return fig