"""Helpers for plotting and ranking per-iteration regret curves of several methods.

Each value in a ``regrets`` mapping is a 2-D array of shape ``(n_evals, n_test)``:
rows are iterations (the x axis) and columns are repetitions, which are
aggregated into a centre line plus a shaded band.
"""
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st


def patch_plot(ax, x, y, yl, yu, label):
    """Draw ``y`` against ``x`` on ``ax`` with a shaded band from ``yl`` to ``yu``.

    Labels starting with ``'random'`` (presumably a random baseline) are drawn
    dashed and semi-transparent. The band reuses the line colour and gets an
    empty label, which keeps it out of the legend.
    """
    if label.startswith('random'):
        h = ax.plot(x, y, '--', label=label, linewidth=1, alpha=0.6)
    else:
        h = ax.plot(x, y, '-', label=label, linewidth=1)
    ax.fill_between(x, yl, yu, label='', alpha=0.15, facecolor=h[0].get_color())


def plot_comparison(n_evals, regrets, PLOT_TYPE='mean', pct_values=(25, 50, 75)):
    """Plot one aggregated curve per entry of ``regrets`` on the current axes.

    Args:
        n_evals: Number of iterations. The x axis is ``range(n_evals)`` and must
            match the number of rows of every array.
        regrets: Mapping from legend label to an array of shape
            ``(n_evals, n_test)``.
        PLOT_TYPE: ``'mean'`` draws the mean over columns with a +/- one
            standard-error band, computed as ``std / sqrt(n_test)`` using the
            population std. Any other value draws percentiles instead.
        pct_values: Percentiles used when ``PLOT_TYPE != 'mean'``. The first
            three are used as (lower band, centre line, upper band). The
            default is a tuple so it cannot be mutated and shared across calls.
    """
    plt.xlabel('Iterations')
    ax = plt.gca()
    x = range(n_evals)
    for label, r in regrets.items():
        n_test = r.shape[1]  # also enforces that r is at least 2-D
        if PLOT_TYPE == 'mean':
            centre = r.mean(axis=1)
            se = r.std(axis=1) / np.sqrt(n_test)
            low, high = centre - se, centre + se
        else:
            pcts = np.percentile(r, pct_values, axis=1)
            low, centre, high = pcts[0, :], pcts[1, :], pcts[2, :]
        patch_plot(ax, x, centre, low, high, label=label)
    plt.legend(loc='upper right')
    if PLOT_TYPE == 'mean':
        plt.ylabel(r'Regret (mean $\pm$ SE)')  # raw string: '\p' is not a valid escape
    else:
        plt.ylabel('Regret (quartiles)')


def _num_evals(regrets):
    """Return the number of rows (iterations) of the first array in ``regrets``."""
    if not regrets:
        raise ValueError('regrets must contain at least one entry')
    return next(iter(regrets.values())).shape[0]


def compare_regrets(regrets, PLOT_TYPE='mean'):
    """Plot every regret curve in ``regrets`` and show the figure.

    ``PLOT_TYPE`` is forwarded to ``plot_comparison``; its default reproduces
    the mean +/- SE plot.
    """
    plot_comparison(_num_evals(regrets), regrets, PLOT_TYPE=PLOT_TYPE)
    plt.show()


def compare_ranks(regrets, num_decimals=10, PLOT_TYPE='mean'):
    """Plot each method's rank among all methods per iteration, then show the figure.

    For every (iteration, repetition) cell, the values of all methods are
    ranked with ``scipy.stats.rankdata``. Rank 1 is the smallest value, and
    ties receive the average rank. Values are first rounded to
    ``num_decimals`` decimals, so values differing only by floating-point
    noise count as ties.

    Args:
        regrets: Mapping from label to an array of shape ``(n_evals, n_test)``.
            All arrays must have the same shape.
        num_decimals: Rounding precision applied before ranking.
        PLOT_TYPE: Forwarded to ``plot_comparison``. ``'mean'`` (the default)
            plots mean rank +/- SE; anything else plots percentiles.
    """
    names = list(regrets)
    all_results = np.stack([regrets[k] for k in names], axis=2)
    all_results = np.around(all_results, num_decimals)
    ranks0 = np.apply_along_axis(st.rankdata, 2, all_results)
    # Slices are kept 2-D (no squeeze). Squeezing would collapse the
    # n_test == 1 or n_evals == 1 cases, and plot_comparison needs r.shape[1].
    ranks = {k: ranks0[:, :, i] for i, k in enumerate(names)}
    plot_comparison(_num_evals(regrets), ranks, PLOT_TYPE=PLOT_TYPE)
    if PLOT_TYPE == 'mean':
        plt.ylabel(r'Rank (mean $\pm$ SE)')
    else:
        plt.ylabel('Rank (quartiles)')
    plt.show()