Python matplotlib.pyplot.locator_params() Examples
The following are 18 code examples of matplotlib.pyplot.locator_params().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: plot_apogee_lamost_cannon.py From TheCannon with MIT License | 6 votes |
def plot_one(title, ax, x, y, lim):
    """Scatter the residual ``y - x`` against ``x`` on ``ax`` and annotate it.

    The annotation in the upper-left corner shows the standard deviation of
    the residual (labelled "RMS") and its mean ("Bias"), both passed through
    ``round_2``.  The y-range is clipped to ``[-lim, lim]``; how many residuals
    fall above and below that window is printed.  ``title`` is accepted but
    not drawn.
    """
    resid = y - x
    ax.scatter(x, resid, marker='x', c='k', alpha=0.5)
    ax.axhline(y=0, c='r')

    scat = round_2(np.std(resid))
    bias = round_2(np.mean(resid))
    textstr = "RMS: %s \nBias: %s" % (scat, bias)
    ax.text(0.05, 0.95, textstr, ha='left', va='top', transform=ax.transAxes)

    for which_axis in ('x', 'y'):
        ax.locator_params(axis=which_axis, nbins=5)
    ax.set_ylim(-1 * lim, lim)

    # Report the points that the y-limits push out of view.
    num_up = sum(resid > lim)
    num_down = sum(resid < -1 * lim)
    print("%s above, %s below" % (num_up, num_down))
Example #2
Source File: oraclesplot.py From actions-for-actions with GNU General Public License v3.0 | 6 votes |
def finalize_plot(allticks, handles):
    """Apply the final tick, legend and layout formatting to the current figure.

    Args:
        allticks: sequence of ``(position, label)`` pairs used as the y ticks.
        handles: sequence whose items carry a legend artist at index 0; these
            are paired with the module-level ``seriesnames`` when ``LEGEND``
            is enabled.
    """
    # Cap the number of x tick intervals at the number of oracles.  ``nticks``
    # is not a MaxNLocator parameter (older matplotlib silently ignored it,
    # newer releases complain about it), so only ``nbins`` is passed.
    plt.locator_params(axis='x', nbins=Noracles)
    positions = [tick[0] for tick in allticks]
    labels = [tick[1] for tick in allticks]
    plt.yticks(positions, labels)
    # Use booleans: the string 'off' is truthy, so on matplotlib versions that
    # no longer translate 'on'/'off' the ticks would stay visible.
    plt.tick_params(
        axis='y',      # changes apply to the y-axis
        which='both',  # both major and minor ticks are affected
        left=False,    # ticks along the left edge are off
        right=False,   # ticks along the right edge are off
    )
    if LEGEND:
        plt.legend([h[0] for h in handles], seriesnames,
                   loc='upper right', borderaxespad=0., ncol=1,
                   fontsize=10, numpoints=1)
    plt.gcf().tight_layout()


######################################################
# Data processing
Example #3
Source File: slashdot_results.py From news-popularity-prediction with Apache License 2.0 | 5 votes |
def make_slashdot_figures(output_path_prefix, method_name_list, slashdot_mse, slashdot_jaccard, slashdot_k_list):
    """Plot per-method MSE curves for SlashDot comments and users and save them.

    Two side-by-side panels ("SlashDot Comments", "SlashDot Users") share the
    x-axis.  Each method in ``method_name_list`` contributes one line per
    panel, labelled through ``get_method_name_to_legend_name_dict()``.  The
    first lifetime point is skipped.  The figure is written to
    ``output_path_prefix + "_mse_slashdot_SNOW"`` with ``.png`` and ``.eps``
    extensions.

    ``slashdot_mse[method]["comments"|"users"]`` is averaged along axis 1 and
    passed through ``handle_nan``.  Presumably it has one row per entry of
    ``slashdot_k_list`` (TODO confirm).  ``slashdot_jaccard`` is accepted for
    interface compatibility but not used.
    """
    sns.set_style("darkgrid")
    sns.set_context("paper")

    translator = get_method_name_to_legend_name_dict()

    # The first lifetime point is dropped from every curve.
    lifetimes = list(slashdot_k_list)[1:]

    fig, axes = plt.subplots(1, 2, sharex=True)
    axes[0].set_title("SlashDot Comments")
    axes[1].set_title("SlashDot Users")

    # plt.locator_params() would only affect the current (right-hand) Axes,
    # leaving the left panel's y ticks at the default density.
    for ax in axes:
        ax.locator_params(nbins=8)
        ax.set_xlabel("Lifetime (sec)")
    axes[0].set_ylabel("MSE")

    for ax, target in zip(axes, ("comments", "users")):
        for method in method_name_list:
            mse_curve = handle_nan(slashdot_mse[method][target].mean(axis=1))
            ax.plot(lifetimes, mse_curve[1:], label=translator[method])

    axes[1].legend(loc="upper right")

    stem = output_path_prefix + "_mse_slashdot_SNOW"
    plt.savefig(stem + ".png", format="png")
    plt.savefig(stem + ".eps", format="eps")
Example #4
Source File: slashdot_results.py From news-popularity-prediction with Apache License 2.0 | 5 votes |
def make_barrapunto_figures(output_path_prefix, method_name_list, barrapunto_mse, barrapunto_jaccard, barrapunto_k_list):
    """Plot per-method MSE curves for BarraPunto comments and users and save them.

    Two side-by-side panels ("BarraPunto Comments", "BarraPunto Users") share
    the x-axis.  Each method in ``method_name_list`` contributes one line per
    panel, labelled through ``get_method_name_to_legend_name_dict()``.  The
    first lifetime point is skipped.  The figure is written to
    ``output_path_prefix + "_mse_barrapunto_SNOW"`` with ``.png`` and ``.eps``
    extensions.

    ``barrapunto_mse[method]["comments"|"users"]`` is averaged along axis 1
    and passed through ``handle_nan``.  Presumably it has one row per entry of
    ``barrapunto_k_list`` (TODO confirm).  ``barrapunto_jaccard`` is accepted
    for interface compatibility but not used.
    """
    sns.set_style("darkgrid")
    sns.set_context("paper")

    translator = get_method_name_to_legend_name_dict()

    # The first lifetime point is dropped from every curve.
    lifetimes = list(barrapunto_k_list)[1:]

    fig, axes = plt.subplots(1, 2, sharex=True)
    axes[0].set_title("BarraPunto Comments")
    axes[1].set_title("BarraPunto Users")

    # plt.locator_params() would only affect the current (right-hand) Axes,
    # leaving the left panel's y ticks at the default density.
    for ax in axes:
        ax.locator_params(nbins=8)
        ax.set_xlabel("Lifetime (sec)")
    axes[0].set_ylabel("MSE")

    for ax, target in zip(axes, ("comments", "users")):
        for method in method_name_list:
            mse_curve = handle_nan(barrapunto_mse[method][target].mean(axis=1))
            ax.plot(lifetimes, mse_curve[1:], label=translator[method])

    axes[1].legend(loc="upper right")

    stem = output_path_prefix + "_mse_barrapunto_SNOW"
    plt.savefig(stem + ".png", format="png")
    plt.savefig(stem + ".eps", format="eps")
Example #5
Source File: util.py From razzy-spinner with GNU General Public License v3.0 | 5 votes |
def _show_plot(x_values, y_values, x_labels=None, y_labels=None): try: import matplotlib.pyplot as plt except ImportError: raise ImportError('The plot function requires matplotlib to be installed.' 'See http://matplotlib.org/') plt.locator_params(axis='y', nbins=3) axes = plt.axes() axes.yaxis.grid() plt.plot(x_values, y_values, 'ro', color='red') plt.ylim(ymin=-1.2, ymax=1.2) plt.tight_layout(pad=5) if x_labels: plt.xticks(x_values, x_labels, rotation='vertical') if y_labels: plt.yticks([-1, 0, 1], y_labels, rotation='horizontal') # Pad margins so that markers are not clipped by the axes plt.margins(0.2) plt.show() #//////////////////////////////////////////////////////////// #{ Parsing and conversion functions #////////////////////////////////////////////////////////////
Example #6
Source File: plot_apogee_lamost_cannon.py From TheCannon with MIT License | 5 votes |
def create_grid():
    """Build the 3x3 grid of comparison panels.

    In every row, the second and third panels share both axes with the first
    panel of that row, hide their y tick labels, and hide their first major x
    tick.  Columns are packed together horizontally (``wspace=0``).

    Returns:
        ``(fig, axarr)`` where ``axarr`` is a 3-tuple of row 3-tuples of Axes.
    """
    fig = plt.figure(figsize=(15, 20))
    rows = []
    for row in range(3):
        leader = fig.add_subplot(3, 3, 3 * row + 1)
        panels = [leader]
        for col in (1, 2):
            follower = fig.add_subplot(3, 3, 3 * row + col + 1,
                                       sharex=leader, sharey=leader)
            plt.setp(follower.get_yticklabels(), visible=False)
            follower.xaxis.get_major_ticks()[0].set_visible(False)
            panels.append(follower)
        rows.append(tuple(panels))
    fig.subplots_adjust(wspace=0, hspace=0.2)
    return fig, tuple(rows)
Example #7
Source File: plot_sweep.py From gpkit with MIT License | 5 votes |
def format_and_label_axes(var, posys, axes, ylabel=True):
    """Format and label the axes of a sweep plot.

    Each Axes in ``axes`` is paired with the matching entry of ``posys``.
    Surplus items on either side are ignored, as with ``zip``.

    Args:
        var: swept variable; its key supplies the x-axis label, which is set
            on the last formatted Axes only.
        posys: plotted quantities.  Items with a ``key`` are labelled from the
            key's ``"label"`` descr entry (falling back to its name) plus
            units; other items are labelled with ``str()``.
        axes: Axes to format, in the same order as ``posys``.
        ylabel: when true, set a y-axis label on every Axes.
    """
    last_ax = None
    for posy, ax in zip(posys, axes):
        if ylabel:
            # Keep the label in its own name: reusing ``ylabel`` overwrote the
            # flag, so an empty label string disabled all later y labels.
            if hasattr(posy, "key"):
                label = (posy.key.descr.get("label", posy.key.name)
                         + " [%s]" % posy.key.unitstr(dimless="-"))
            else:
                label = str(posy)
            ax.set_ylabel(label)
        ax.grid(color="0.6")
        for item in [ax.xaxis.label, ax.yaxis.label]:
            item.set_fontsize(12)
        for item in ax.get_xticklabels() + ax.get_yticklabels():
            item.set_fontsize(9)
        ax.tick_params(length=0)
        ax.spines['left'].set_visible(False)
        ax.spines['top'].set_visible(False)
        for spine in ax.spines.values():
            spine.set_linewidth(0.6)
            spine.set_color("0.6")
            spine.set_linestyle("dotted")
        last_ax = ax
    # With no axes there is nothing to label (previously an UnboundLocalError).
    if last_ax is not None:
        xlabel = (var.key.descr.get("label", var.key.name)
                  + " [%s]" % var.key.unitstr(dimless="-"))
        last_ax.set_xlabel(xlabel)
    # Applies to pyplot's current Axes/figure, as before.
    plt.locator_params(nbins=4)
    plt.subplots_adjust(wspace=0.15)


# pylint: disable=too-many-locals,too-many-branches,too-many-statements
Example #8
Source File: make_figure.py From pyhawkes with MIT License | 5 votes |
def make_figure_a(S, F, C):
    """
    Plot fluorescence traces, filtered fluorescence, and spike times for two
    neurons (columns 0 and 1) and save the figure to ``figure3a.pdf``.

    The first 3000 samples are shown, with a sample spacing of 0.02 s.

    Parameters
    ----------
    S : array indexed ``[time, neuron]``; nonzero entries mark spikes.
    F : array indexed ``[time, neuron]``; raw fluorescence.
    C : array indexed ``[time, neuron]``; filtered fluorescence.
    """
    col = harvard_colors()
    dt = 0.02
    T_start = 0
    T_stop = 1 * 50 * 60
    t = dt * np.arange(T_start, T_stop)
    ks = [0, 1]
    nk = len(ks)

    fig = create_figure((3, 3))
    for ind, k in enumerate(ks):
        ax = fig.add_subplot(nk, 1, ind + 1)
        ax.plot(t, F[T_start:T_stop, k], color=col[1], label="$F$")       # raw fluorescence
        ax.plot(t, C[T_start:T_stop, k], color=col[0], lw=1.5,
                label=r"$\widehat{F}$")                                  # filtered fluorescence
        # ``spks`` is relative to T_start, so offset it when indexing C.
        spks = np.where(S[T_start:T_stop, k])[0]
        ax.plot(t[spks], C[T_start + spks, k], 'ko', label="S")          # spike times

        if ind == 0:
            # Put a legend above the top panel
            ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                      ncol=3, mode="expand", borderaxespad=0.,
                      prop={'size': 9})

        ax.set_ylabel("$F_%d(t)$" % (k + 1))
        if ind == nk - 1:
            ax.set_xlabel("Time $t$ [sec]")

        ax.set_ylim([-0.1, 1.0])
        ax.locator_params(nbins=5, axis="y")

    fig.subplots_adjust(left=0.2, bottom=0.2)
    fig.savefig("figure3a.pdf")
    plt.show()
Example #9
Source File: util.py From V1EngineeringInc-Docs with Creative Commons Attribution Share Alike 4.0 International | 5 votes |
def _show_plot(x_values, y_values, x_labels=None, y_labels=None): try: import matplotlib.pyplot as plt except ImportError: raise ImportError( 'The plot function requires matplotlib to be installed.' 'See http://matplotlib.org/' ) plt.locator_params(axis='y', nbins=3) axes = plt.axes() axes.yaxis.grid() plt.plot(x_values, y_values, 'ro', color='red') plt.ylim(ymin=-1.2, ymax=1.2) plt.tight_layout(pad=5) if x_labels: plt.xticks(x_values, x_labels, rotation='vertical') if y_labels: plt.yticks([-1, 0, 1], y_labels, rotation='horizontal') # Pad margins so that markers are not clipped by the axes plt.margins(0.2) plt.show() # //////////////////////////////////////////////////////////// # { Parsing and conversion functions # ////////////////////////////////////////////////////////////
Example #10
Source File: ab_exp.py From abyes with Apache License 2.0 | 4 votes |
def expected_loss_decision(self, posterior, var):
    """
    Calculate the expected loss of choosing A or B and apply the decision rule.

    Parameters
    ----------
    posterior : dict
        ``posterior[var]`` is a ``(values, bin_edges)`` pair for the posterior
        of ``mu_B - mu_A``, with one value per bin (``len(bin_edges) - 1``).
        When ``self.plot`` is set, ``posterior['muA']`` and
        ``posterior['muB']`` must have the same form.
    var : hashable
        Key of the difference posterior inside ``posterior``.

    Returns
    -------
    int or float
        ``0`` if both expected losses are ``<= self.toc`` (threshold of
        caring); otherwise ``1`` if B's expected loss is ``< self.toc``,
        ``-1`` if A's expected loss is ``< self.toc``, else ``np.nan``.
    """
    edges = posterior[var][1]
    # Integrate over bin centres rather than bin edges.
    dl = 0.5 * (edges[:-1] + edges[1:])
    fdl = posterior[var][0]

    # np.trapz was renamed np.trapezoid in NumPy 2.0; the old name is deprecated.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    # Choosing A loses whatever B is better by (dl > 0), and vice versa.
    ela = integrate(np.maximum(dl, 0) * fdl, dl)
    elb = integrate(np.maximum(-dl, 0) * fdl, dl)

    if self.plot:
        plt.subplot(1, 2, 1)
        for key, label in (('muA', r'$f(\mu_A)$'), ('muB', r'$f(\mu_B)$')):
            mu_edges = posterior[key][1]
            plt.plot(0.5 * (mu_edges[:-1] + mu_edges[1:]), posterior[key][0],
                     lw=2, label=label)
        plt.xlabel(r'$\mu_A,\ \mu_B$')
        plt.xlim([0, 1])
        plt.title('Conversion Rate')
        # MaxNLocator takes ``nbins`` (max intervals = max ticks - 1);
        # ``nticks`` is not one of its parameters.
        plt.locator_params(nbins=5)
        plt.gca().set_ylim(bottom=0)
        plt.legend()

        plt.subplot(1, 2, 2)
        marker_height = 0.3 * np.max(fdl)
        plt.plot(dl, fdl, 'b', lw=3, label=r'f$(\mu_B - \mu_A)$')
        plt.plot([ela, ela], [0, marker_height], 'r', lw=3, label='A: Expected Loss')
        plt.plot([elb, elb], [0, marker_height], 'c', lw=3, label='B: Expected Loss')
        plt.plot([self.toc, self.toc], [0, marker_height], 'k--', lw=3,
                 label='Threshold of Caring')
        plt.xlabel(r'$\mu_B-\mu_A$')
        plt.title('Expected Loss')
        plt.gca().set_ylim(bottom=0)
        # ``numticks`` belongs to LinearLocator, not the default MaxNLocator.
        plt.gca().locator_params(axis='x', nbins=5)
        plt.legend()

    if ela <= self.toc and elb <= self.toc:
        result = 0
    elif elb < self.toc:
        result = 1
    elif ela < self.toc:
        result = -1
    else:
        result = np.nan
    return result
Example #11
Source File: orbit_plots.py From radvel with MIT License | 4 votes |
def plot_timeseries(self):
    """
    Make a plot of the RV data and Gaussian Process + orbit model in the
    current Axes.

    When ``self.subtract_orbit_model`` is set, the data are drawn as bare
    residuals; otherwise the orbit model ``self.rvmod`` is added back.  A
    secondary top x-axis labelled "Year" spans the same time range in
    decimal years.
    """
    ax = pl.gca()
    ax.axhline(0, color='0.5', linestyle='--')

    if self.subtract_orbit_model:
        orbit_model4data = np.zeros(self.rvmod.shape)
    else:
        orbit_model4data = self.rvmod

    # plot_gp_like returns the colour index for the next likelihood.
    ci = 0
    for like in self.like_list:
        ci = self.plot_gp_like(like, orbit_model4data, ci)

    # plot data
    plot.mtelplot(
        # data = residuals + model
        self.plttimes, self.rawresid + orbit_model4data, self.rverr,
        self.post.likelihood.telvec, ax, telfmts=self.telfmts
    )

    if self.set_xlim is not None:
        ax.set_xlim(self.set_xlim)
    else:
        ax.set_xlim(min(self.plttimes) - 0.01 * self.dt,
                    max(self.plttimes) + 0.01 * self.dt)
    pl.setp(ax.get_xticklabels(), visible=False)

    # legend
    if self.legend:
        ax.legend(numpoints=1, **self.legend_kwargs)

    # years on upper axis
    axyrs = ax.twiny()
    xl = np.array(list(ax.get_xlim())) + self.epoch
    decimalyear = Time(xl, format='jd', scale='utc').decimalyear
    # No placeholder (year, year) line: twiny() shares the y-axis with ``ax``,
    # so such a line would stretch the shared y autoscale out to the year values.
    axyrs.get_xaxis().get_major_formatter().set_useOffset(False)
    axyrs.set_xlim(*decimalyear)
    # Target the year axis explicitly rather than pyplot's "current" Axes.
    axyrs.locator_params(axis='x', nbins=5)
    axyrs.set_xlabel('Year', fontweight='bold')
Example #12
Source File: orbit_plots.py From radvel with MIT License | 4 votes |
def plot_timeseries(self):
    """
    Draw the RV data and the orbit model into the current Axes.

    Per-instrument residual scatter (``np.std`` of each likelihood's
    residuals) is forwarded to ``plot.mtelplot`` when ``self.show_rms`` is
    set.  The newest point can be highlighted.  A top x-axis shows decimal
    years, and the y-range can be pinned to ``self.yscale_sigma`` standard
    deviations of the data.  The lowest y tick is removed at the end.
    """
    ax = pl.gca()
    ax.axhline(0, color='0.5', linestyle='--')

    rms_values = False
    if self.show_rms:
        rms_values = {like.suffix: np.std(like.residuals()) for like in self.like_list}

    # orbit model
    ax.plot(self.mplttimes, self.orbit_model, 'b-', rasterized=False,
            lw=self.fit_linewidth)

    # data = residuals + model
    vels = self.rawresid + self.rvmod
    plot.mtelplot(self.plttimes, vels, self.rverr,
                  self.post.likelihood.telvec, ax,
                  telfmts=self.telfmts, rms_values=rms_values)

    if self.set_xlim is None:
        pad = 0.01 * self.dt
        ax.set_xlim(min(self.plttimes) - pad, max(self.plttimes) + pad)
    else:
        ax.set_xlim(self.set_xlim)
    pl.setp(ax.get_xticklabels(), visible=False)

    if self.highlight_last:
        newest = np.argmax(self.plttimes)
        pl.plot(self.plttimes[newest], vels[newest], **plot.highlight_format)

    if self.legend:
        ax.legend(numpoints=1, **self.legend_kwargs)

    # Secondary x-axis on top, in decimal years.
    axyrs = ax.twiny()
    jd_limits = np.array(list(ax.get_xlim())) + self.epoch
    decimalyear = Time(jd_limits, format='jd', scale='utc').decimalyear
    axyrs.get_xaxis().get_major_formatter().set_useOffset(False)
    axyrs.set_xlim(*decimalyear)
    axyrs.set_xlabel('Year', fontweight='bold')
    pl.locator_params(axis='x', nbins=5)

    if not self.yscale_auto:
        scale = np.std(self.rawresid + self.rvmod)
        half_range = self.yscale_sigma * scale
        ax.set_ylim(-half_range, half_range)

    ax.set_ylabel('RV [{ms:}]'.format(**plot.latex), weight='bold')
    # Drop the lowest y tick.
    tick_locs = ax.yaxis.get_majorticklocs()
    ax.yaxis.set_ticks(tick_locs[1:])
Example #13
Source File: plotting.py From snn_toolbox with MIT License | 4 votes |
def plot_max_activ_hist(h, title=None, layer_label=None, path=None,
                        scale_fac=None):
    """Plot a histogram of the maximum activations, then save or show it.

    Parameters
    ----------

    h: dict
        Datasets to histogram, one per key; keys are used as legend labels in
        sorted order.
    title: string, optional
        Title prefix. Used together with ``layer_label``; if either is
        missing, the plot is simply titled "Distribution".
    layer_label: string, optional
        Label of the layer the data was taken from; also prefixes the file name.
    path: string, optional
        If given, the figure is saved into this directory; otherwise it is
        displayed.
    scale_fac: float, optional
        The value by which parameters are normalized (maximum of activations
        or parameter value of a layer). If given, it is drawn as a dashed
        vertical line and inserted into the plot title.
    """
    labels = sorted(h)
    plt.hist([h[label] for label in labels], label=labels, bins=1000,
             edgecolor='none', histtype='stepfilled')
    plt.xlabel('Maximum ANN activations')
    plt.ylabel('Sample count')
    if scale_fac:
        plt.axvline(scale_fac, color='red', linestyle='dashed', linewidth=2,
                    label='scale factor')
    plt.legend()
    plt.locator_params(axis='x', nbins=5)

    if not (title and layer_label):
        plt.title('Distribution')
        filename = 'Maximum_activity_distribution'
    else:
        filename = layer_label + '_maximum_activity_distribution'
        divisor_note = ("Applied divisor: {:.2f}".format(scale_fac)
                        if scale_fac else '')
        plt.title('{} distribution \n of layer {} \n {}'.format(
            title, layer_label, divisor_note))

    if path:
        plt.savefig(os.path.join(path, filename), bbox_inches='tight')
    else:
        plt.show()
    plt.close()
Example #14
Source File: plotting.py From snn_toolbox with MIT License | 4 votes |
def plot_activ_hist(h, title=None, layer_label=None, path=None,
                    scale_fac=None):
    """Plot a log-scale histogram of all activations, then save or show it.

    Parameters
    ----------

    h: dict
        Datasets to histogram, one per key; keys are used as legend labels in
        sorted order.
    title: string, optional
        Title prefix. Used together with ``layer_label``; if either is
        missing, the plot is simply titled "Distribution".
    layer_label: string, optional
        Label of the layer the data was taken from; also prefixes the file name.
    path: string, optional
        If given, the figure is saved into this directory; otherwise it is
        displayed.
    scale_fac: float, optional
        The value by which parameters are normalized (maximum of activations
        or parameter value of a layer). If given, it is drawn as a dashed
        vertical line and inserted into the plot title.
    """
    labels = sorted(h)
    plt.hist([h[label] for label in labels], label=labels, bins=1000,
             edgecolor='none', histtype='stepfilled', log=True, bottom=1)
    if scale_fac:
        plt.axvline(scale_fac, color='red', linestyle='dashed', linewidth=2,
                    label='scale factor')
    plt.legend()
    plt.locator_params(axis='x', nbins=5)
    plt.xlabel('ANN activations')
    plt.ylabel('Count')
    plt.xlim(xmin=0)

    if not (title and layer_label):
        plt.title('Distribution')
        filename = 'Activity_distribution'
    else:
        filename = layer_label + '_activ_distribution'
        divisor_note = ("Applied divisor: {:.2f}".format(scale_fac)
                        if scale_fac else '')
        plt.title('{} distribution \n of layer {} \n {}'.format(
            title, layer_label, divisor_note))

    if path:
        plt.savefig(os.path.join(path, filename), bbox_inches='tight')
    else:
        plt.show()
    plt.close()
Example #15
Source File: plotting.py From snn_toolbox with MIT License | 4 votes |
def plot_hist(h, title=None, layer_label=None, path=None, scale_fac=None):
    """Plot overlaid, semi-transparent log-scale histograms of the datasets in ``h``.

    Parameters
    ----------

    h: dict
        Datasets to histogram, one per key; keys are used as legend labels in
        sorted order.
    title: string, optional
        Title prefix. Used together with ``layer_label``; if either is
        missing, the plot is simply titled "Distribution".  Titles containing
        "Spikerates" get a "[Hz]" unit and a file name starting with "4".
    layer_label: string, optional
        Label of the layer the data was taken from.
    path: string, optional
        If given, the figure is saved into this directory; otherwise it is
        displayed.
    scale_fac: float, optional
        The value by which parameters are normalized (maximum of activations
        or parameter value of a layer). If given, it is drawn as a dashed
        vertical line and inserted into the plot title.
    """
    labels = sorted(h)
    plt.hist([h[label] for label in labels], label=labels, log=True,
             bottom=1, bins=1000, histtype='stepfilled', alpha=0.5,
             edgecolor='none')
    if scale_fac:
        plt.axvline(scale_fac, color='red', linestyle='dashed', linewidth=2,
                    label='scale factor')
    plt.legend()
    plt.locator_params(axis='x', nbins=5)

    if not (title and layer_label):
        plt.title('Distribution')
        filename = 'Activity_distribution'
    else:
        is_rate = 'Spikerates' in title
        prefix = '4' + title if is_rate else layer_label + '_' + title
        filename = prefix + '_distribution'
        unit = '[Hz]' if is_rate else ''
        divisor_note = ("Applied divisor: {:.2f}".format(scale_fac)
                        if scale_fac else '')
        plt.title('{} distribution {} \n of layer {} \n {}'.format(
            title, unit, layer_label, divisor_note))

    if path:
        plt.savefig(os.path.join(path, filename), bbox_inches='tight')
    else:
        plt.show()
    plt.close()
Example #16
Source File: plotting.py From snn_toolbox with MIT License | 4 votes |
def plot_network_correlations(spikerates, layer_activations):
    """Plot the correlation between SNN spiketrains and ANN activations.

    For each layer, draw a scatter plot of the ANN activations (x-axis)
    against the mean firing rate of the corresponding SNN neurons (y-axis).
    Panels are arranged in a near-square grid; unused panels are hidden.

    Parameters
    ----------

    spikerates: list of tuples ``(spikerate, label)``.
        ``spikerate`` is a 1D array containing the mean firing rates of the
        neurons in a specific layer.

        ``label`` is a string specifying both the layer type and the index,
        e.g. ``'3Dense'``; it is used as the panel title.

    layer_activations: list of tuples ``(activations, label)``
        Each entry represents a layer in the ANN for which an activation can
        be calculated (e.g. ``Dense``, ``Conv2D``).

        ``activations`` is an array of the same dimension as the
        corresponding layer; it is flattened for plotting.

        ``label`` is a string specifying the layer type, e.g. ``'Dense'``.
    """
    num_layers = len(layer_activations)

    # Determine optimal shape for rectangular arrangement of plots
    num_rows = int(np.ceil(np.sqrt(num_layers)))
    num_cols = int(np.ceil(num_layers / num_rows))
    f, ax = plt.subplots(num_rows, num_cols, squeeze=False,
                         figsize=(8, 1 + num_rows * 4))
    # ax.flat walks the grid row by row, matching layer_num = i * num_cols + j.
    for layer_num, panel in enumerate(ax.flat):
        if layer_num >= num_layers:
            # Surplus cells of the rectangular grid stay empty; hide them.
            panel.set_visible(False)
            continue
        activations = layer_activations[layer_num][0]
        rates = spikerates[layer_num][0]
        panel.plot(activations.flatten(), rates, '.')
        panel.set_title(spikerates[layer_num][1], fontsize='medium')
        panel.locator_params(nbins=4)
        panel.set_xlim([None, np.max(activations) * 1.1])
        panel.set_ylim([None, max(rates) * 1.1])
    f.suptitle('ANN-SNN correlations', fontsize=20)
    f.subplots_adjust(wspace=0.3, hspace=0.3)
    # Shared axis titles: activations are on x, spike rates on y.
    f.text(0.5, 0.04, 'ANN activations', ha='center', fontsize=16)
    f.text(0.04, 0.5, 'SNN spikerates (Hz)', va='center',
           rotation='vertical', fontsize=16)
Example #17
Source File: plotting.py From snn_toolbox with MIT License | 4 votes |
def plot_layer_correlation(rates, activations, title, config, path=None,
                           same_xylim=True):
    """
    Scatter-plot one layer's ANN activations (x) against its SNN spikerates (y).

    The fraction of units whose rate reaches the saturation level is
    annotated in the upper-right corner.

    Parameters
    ----------

    rates: np.array
        The spikerates of a layer, flattened to 1D.
    activations: Union[ndarray, Iterable]
        The activations of a layer, flattened to 1D.
    title: str
        Plot title.
    config: configparser.ConfigParser
        Settings; ``simulation.dt`` and ``simulation.duration`` are read.
    path: Optional[str]
        If not ``None``, the image is saved there as ``5Correlation``;
        otherwise it is displayed.
    same_xylim: Optional[bool]
        If ``True`` (default), both axes run from 0 to the largest of 1.1,
        ``max(activations)`` and ``max(rates)``.
    """
    dt = config.getfloat('simulation', 'dt')
    duration = config.getint('simulation', 'duration')
    # Saturation threshold: the maximal rate minus one time step.
    # NOTE(review): with rates in Hz and dt/duration in ms, one spike fewer
    # than the maximum would be 1000/dt - 1000/duration; this expression
    # matches that only for dt == 1. Confirm intent.
    saturation_rate = 1000 / dt - 1000 / duration / dt
    p = np.mean(np.greater_equal(rates, saturation_rate))

    plt.figure()
    plt.plot(activations, rates, '.')
    plt.annotate("{:.2%} units saturated.".format(p), xy=(1, 1),
                 xycoords='axes fraction', xytext=(-200, -20),
                 textcoords='offset points')
    plt.title(title, fontsize=20)
    plt.locator_params(nbins=4)

    lim = None
    if same_xylim:
        lim = max([1.1, max(activations), max(rates)])
    plt.xlim([0, lim])
    plt.ylim([0, lim])
    plt.xlabel('ANN activations', fontsize=16)
    plt.ylabel('SNN spikerates [Hz]', fontsize=16)

    if path is None:
        plt.show()
    else:
        plt.savefig(os.path.join(path, '5Correlation'), bbox_inches='tight')
    plt.close()
Example #18
Source File: visualize.py From adversarial-policies with MIT License | 4 votes |
def bar_chart(envs, victim_id, n_components, covariance, savefile=None):
    """Grouped bar chart of mean log probability per opponent type.

    One group of bars is drawn per environment, one bar per opponent type
    (ordered by ``BAR_ORDER``).  Colours follow ``CYCLE_ORDER`` within the
    default property cycle.  For unspecified parameters, see
    get_full_directory.

    :param envs: (list of str) list of environments.
    :param savefile: (None or str) path to save figure to.
    :return: the created matplotlib Figure.
    """
    per_env = []
    for env in envs:
        metadata = load_metadata(env, victim_id, n_components, covariance)
        metadata["Environment"] = PRETTY_ENVS.get(env, env)
        per_env.append(metadata)

    longform = pd.concat(per_env)
    longform["opponent_id"] = longform["opponent_id"].apply(PRETTY_OPPONENTS.get)
    longform = longform.reset_index(drop=True)

    # Margins are specified in inches and converted to figure fractions.
    fig_width, fig_height = plt.rcParams.get("figure.figsize")
    legend_height = 0.4
    top_margin_in = legend_height + 0.05
    gridspec_kw = dict(
        left=0.55 / fig_width,
        top=1 - (top_margin_in / fig_height),
        bottom=0.5 / fig_height,
    )
    fig, ax = plt.subplots(1, 1, gridspec_kw=gridspec_kw)

    # Make colors consistent with previous figures
    cycle = list(plt.rcParams["axes.prop_cycle"])
    palette = {}
    for pretty_name in PRETTY_OPPONENTS.values():
        palette[pretty_name] = cycle[CYCLE_ORDER.index(pretty_name)]["color"]

    sns.barplot(
        x="Environment",
        y="log_proba",
        hue="opponent_id",
        order=PRETTY_ENVS.values(),
        hue_order=BAR_ORDER,
        data=longform,
        palette=palette,
        errwidth=1,
    )
    ax.set_ylabel("Mean Log Probability Density")
    plt.locator_params(axis="y", nbins=4)
    util.rotate_labels(ax, xrot=0)

    # Replace seaborn's legend with one placed outside the axes.
    ax.get_legend().remove()
    util.outside_legend(
        ax.get_legend_handles_labels(), 3, fig, ax, ax,
        legend_padding=0.05, legend_height=0.6, handletextpad=0.2,
    )

    if savefile is not None:
        fig.savefig(savefile)
    return fig