Python matplotlib.pyplot.ticklabel_format() Examples

The following are 19 code examples of matplotlib.pyplot.ticklabel_format(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: visualizing_data.py    From data-science-from-scratch with MIT License 6 votes vote down vote up
def make_chart_misleading_y_axis(mislead=True):
    """Show a two-bar chart of 'data science' mentions for 2013 and 2014.

    When ``mislead`` is true the y-axis is cropped to 499..506, which makes the
    difference between 500 and 505 look enormous; otherwise the y-axis runs
    from 0 to 550 and the bars look almost equal.
    """
    bar_positions = [2012.6, 2013.6]
    counts = [500, 505]
    tick_years = [2013, 2014]

    plt.bar(bar_positions, counts, 0.8)
    plt.xticks(tick_years)
    plt.ylabel("# of times I heard someone say 'data science'")

    # Without this, matplotlib labels the x-axis 0, 1 and puts a
    # +2.013e3 offset in the corner.
    plt.ticklabel_format(useOffset=False)

    if mislead:
        # Only the sliver above 500 is visible.
        y_low, y_high = 499, 506
        caption = "Look at the 'Huge' Increase!"
    else:
        y_low, y_high = 0, 550
        caption = "Not So Huge Anymore."

    plt.axis([2012.5, 2014.5, y_low, y_high])
    plt.title(caption)
    plt.show()
Example #2
Source File: genesis_plot.py    From ocelot with GNU General Public License v3.0 6 votes vote down vote up
def subfig_evo_el_energy(ax_energy, g, legend, **kwargs):
    """Plot electron energy deviation and energy spread against ``g.z``.

    The left axis (``ax_energy``, blue) shows the column-wise average of the
    electron energy after subtracting its integer-truncated overall mean; the
    subtracted mean is reported in the y-label.  A twin right axis (red) shows
    the ``g.I``-weighted average and the maximum of the energy spread.

    Parameters
    ----------
    ax_energy : matplotlib axes
        Axes to draw on; a twin axes is created from it for the spread.
    g : object
        Must provide ``z``, ``el_energy``, ``el_e_spread`` and ``I``.
    legend : any
        Accepted for signature compatibility; not used by this function.
    **kwargs
        Only ``grid`` is read (default ``True``); it toggles the grid on
        ``ax_energy``.
    """
    number_ticks = 6

    el_energy = g.el_energy * m_e_MeV
    el_energy_av = int(np.mean(el_energy))
    ax_energy.plot(g.z, np.average(el_energy - el_energy_av, axis=0), 'b-', linewidth=1.5)
    ax_energy.set_ylabel('E + ' + str(el_energy_av) + '[MeV]')
    ax_energy.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3), useOffset=False)
    ax_energy.grid(kwargs.get('grid', True))

    ax_spread = ax_energy.twinx()
    ax_spread.plot(g.z, np.average(g.el_e_spread * m_e_GeV * 1000, weights=g.I, axis=0), 'm--', g.z,
                   np.amax(g.el_e_spread * m_e_GeV * 1000, axis=0), 'r--', linewidth=1.5)
    ax_spread.set_ylabel(r'$\sigma_E$ [MeV]')
    ax_spread.grid(False)
    # ``ymin=`` was removed from set_ylim in matplotlib 3.5; ``bottom=`` is
    # the keyword accepted by both old and new releases.
    ax_spread.set_ylim(bottom=0)

    ax_energy.yaxis.major.locator.set_params(nbins=number_ticks)
    ax_spread.yaxis.major.locator.set_params(nbins=number_ticks)

    # Colour each y-axis to match the curves it belongs to.
    ax_energy.tick_params(axis='y', which='both', colors='b')
    ax_energy.yaxis.label.set_color('b')
    ax_spread.tick_params(axis='y', which='both', colors='r')
    ax_spread.yaxis.label.set_color('r')
Example #3
Source File: em_3d_help.py    From mapper-tda with MIT License 6 votes vote down vote up
def plot_clustering_3d(obj, data_local, data_global, filename):
    """Save a 3-D scatter plot of clustered points over a background cloud.

    ``data_global`` is drawn as faint black dots.  For every key in
    ``obj.c_to_ind`` the rows ``data_local[obj.c_to_ind[key]]`` are drawn in
    their own colour from ``get_colors``.  The figure is written to
    ``filename`` as PNG and always closed afterwards, even if plotting or
    saving fails, so repeated calls do not accumulate open figures.

    NOTE(review): ``data_local`` is assumed to support index-array selection
    (e.g. a numpy array) -- confirm against callers.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')

        colors = get_colors(len(obj.c_to_ind))

        ax.plot(*zip(*data_global), marker='o', color='k', ls='', ms=4., mew=1.0, alpha=0.4, mec='none')

        for i, c in enumerate(obj.c_to_ind):
            ax.plot(*zip(*data_local[obj.c_to_ind[c]]), marker='o', color=colors[i], ls='', ms=4., mew=1.0, alpha=0.8, mec='none')

        # Operate on ``ax`` directly instead of pyplot's implicit current axes.
        ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
        ax.view_init(*params.ANGLE)

        fig.savefig(filename, format='png')
    finally:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
Example #4
Source File: em_help.py    From mapper-tda with MIT License 6 votes vote down vote up
def plot_clustering(obj, data, filename, axis_str=('', ''), tit_str_add='', anot=None):
    """Save a 2-D scatter plot with one colour per cluster.

    For every key in ``obj.c_to_ind`` the rows ``data[obj.c_to_ind[key]]`` are
    drawn in their own colour from ``get_colors``.  ``axis_str`` supplies the
    (x, y) axis labels.  ``tit_str_add`` and ``anot`` are accepted but not
    used by this function.  The figure is written to ``filename`` as PNG and
    always closed afterwards, even if plotting or saving fails.

    NOTE(review): ``data`` is assumed to support index-array selection
    (e.g. a numpy array) -- confirm against callers.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111)

        colors = get_colors(len(obj.c_to_ind))

        for i, c in enumerate(obj.c_to_ind):
            ax.plot(*zip(*data[obj.c_to_ind[c]]), marker='o', color=colors[i], ls='', ms=4., mew=1.0, alpha=0.8, mec='none')

        # Use the axes created above instead of pyplot's implicit current axes.
        ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
        ax.set_xlabel(axis_str[0])
        ax.set_ylabel(axis_str[1])
        fig.savefig(filename, format='png')
    finally:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
Example #5
Source File: visualizing_data.py    From data-science-from-scratch with MIT License 6 votes vote down vote up
def make_chart_misleading_y_axis(plt, mislead=True):
    """Draw the 'data science' mentions bar chart on the given pyplot-like object.

    ``plt`` must expose ``bar``, ``xticks``, ``ylabel``, ``ticklabel_format``,
    ``axis``, ``title`` and ``show``.  With ``mislead`` true the y-axis is
    cropped to 499..506 so a 1% change looks huge; otherwise it spans 0..550.
    """
    counts = [500, 505]
    tick_years = [2013, 2014]

    plt.bar([2012.6, 2013.6], counts, 0.8)
    plt.xticks(tick_years)
    plt.ylabel("# of times I heard someone say 'data science'")

    # Without this, matplotlib labels the x-axis 0, 1 and puts a
    # +2.013e3 offset in the corner.
    plt.ticklabel_format(useOffset=False)

    # (y-axis lower bound, upper bound, title) for each presentation.
    if mislead:
        y_low, y_high, caption = 499, 506, "Look at the 'Huge' Increase!"
    else:
        y_low, y_high, caption = 0, 550, "Not So Huge Anymore."

    plt.axis([2012.5, 2014.5, y_low, y_high])
    plt.title(caption)
    plt.show()
Example #6
Source File: plot_output.py    From alphacsc with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_convergence_curve(data, info, dirname):
    """Plot relative sub-optimality against cumulative time for every run.

    ``data`` is a sequence of ``(args, res)`` pairs where ``res['pobj']`` holds
    objective values and ``res['times']`` per-step durations.  Each run is
    drawn on log-log axes as ``(pobj - best) / best + eps`` where ``best`` is
    the smallest objective seen across all runs.  The figure is saved as
    ``<dirname>/convergence.png``.
    """
    # Small offset added to every curve; presumably keeps the best run
    # strictly positive so it can be drawn on a log axis.
    eps = 1e-6

    # Best objective value reached by any run.
    best_pobj = np.min([np.min(r['pobj']) for _, r in data])

    fig = plt.figure("convergence", figsize=(12, 12))
    plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    palette = itertools.cycle(COLORS)
    for run, color in zip(data, palette):
        args, res = run
        elapsed = list(np.cumsum(res['times']))
        rel_gap = (res['pobj'] - best_pobj) / best_pobj + eps
        curve_label = get_label(info['grid_key'], args)
        plt.loglog(elapsed, rel_gap, '.-', label=curve_label, color=color,
                   linewidth=2)

    plt.xlabel('Time (s)', fontsize=24)
    plt.ylabel('Objective value', fontsize=24)
    # One legend column per ten curves.
    plt.legend(ncol=int(np.ceil(len(data) / 10)), fontsize=24)

    # Hide the tick marks on all four sides.
    plt.gca().tick_params(axis='x', which='both', bottom=False, top=False)
    plt.gca().tick_params(axis='y', which='both', left=False, right=False)
    plt.tight_layout()
    plt.grid(True)
    fig.savefig("{}/convergence.png".format(dirname), dpi=150)
Example #7
Source File: plot.py    From reaver with MIT License 6 votes vote down vote up
def plot_from_summaries(summaries_path, title=None, samples_per_update=512, updates_per_log=100):
    """Plot mean episode reward (with a +/- std band) from a summaries directory.

    Reads the ``Rewards/Mean`` and ``Rewards/Std`` scalar series through
    ``EventAccumulator`` and plots them against the number of samples, computed
    as ``samples_per_update * updates_per_log * log_index``.

    Parameters
    ----------
    summaries_path : str
        Path handed to ``EventAccumulator``.
    title : str, optional
        Plot title.  When falsy, it is derived from the last path component,
        taking the text before the first underscore.
    samples_per_update, updates_per_log : int
        Scale factors converting a log index into a sample count.
    """
    acc = EventAccumulator(summaries_path)
    acc.Reload()

    # Index 2 of each scalar record is used as the plotted value
    # (presumably the event's ``value`` field -- confirm against the
    # EventAccumulator version in use).
    rews_mean = np.array([s[2] for s in acc.Scalars('Rewards/Mean')])
    rews_std = np.array([s[2] for s in acc.Scalars('Rewards/Std')])
    x = samples_per_update * updates_per_log * np.arange(0, len(rews_mean))

    if not title:
        # Strip trailing separators first so "runs/foo_1/" yields "foo"
        # instead of an empty title.
        title = summaries_path.rstrip('/').split('/')[-1].split('_')[0]

    plt.plot(x, rews_mean)
    plt.fill_between(x, rews_mean - rews_std, rews_mean + rews_std, alpha=0.2)
    plt.xlabel('Samples')
    plt.ylabel('Episode Rewards')
    plt.title(title)
    # With no logged scalars there is no last x value to anchor the limit on.
    if len(x):
        plt.xlim([0, x[-1] + 1])
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
Example #8
Source File: data.py    From osm-python-tools with GNU General Public License v3.0 5 votes vote down vote up
def plot(self, *args, plotTitle=None, showPlot=True, filename=None, **kwargs):
        """Plot the selected data as a pandas line plot and return ``self``.

        ``*args`` and ``**kwargs`` are forwarded to ``self.select``.  The
        resulting data frame must have a single index level, otherwise
        ``self._raiseException`` is called.  The title is ``plotTitle`` when
        truthy, else the selection keyword arguments.  Displaying or saving is
        delegated to ``self.showPlot``.
        """
        frame = self.select(*args, **kwargs).getDataFrame()
        if frame.index.nlevels > 1:
            self._raiseException('Please restrict the dataset such that only one index is left.')
        frame.plot()
        # Plain numbers on the axes: no offset, no scientific notation.
        plt.ticklabel_format(useOffset=False, style='plain')
        plt.title(plotTitle or kwargs)
        self.showPlot(showPlot=showPlot, filename=filename)
        return self
Example #9
Source File: genesis_plot.py    From ocelot with GNU General Public License v3.0 5 votes vote down vote up
def subfig_z_energy_espread_bunching(ax_energy, g, zi=None, x_units='um', legend=False, *args, **kwargs):
    """Plot energy +/- spread and bunching along the beam at one z index.

    The left axis (``ax_energy``) is cleared and shows ``g.el_energy[:, zi]``
    in blue with dashed red lines one ``g.el_e_spread`` above and below.  A
    twin right axis shows ``g.bunching[:, zi]`` in grey.

    Parameters
    ----------
    ax_energy : matplotlib axes
        Axes to clear and draw on.
    g : object
        Must provide ``t``, ``el_energy``, ``el_e_spread`` and ``bunching``.
    zi : int, optional
        Column index to plot; ``None`` selects the last column (``-1``).
    x_units : {'um', 'fs'}
        ``'fs'`` plots against ``g.t`` directly; ``'um'`` scales ``g.t`` by
        ``speed_of_light * 1e-15 * 1e6`` (presumably fs -> micrometres).
    legend, *args
        Accepted for signature compatibility; not used here.
    **kwargs
        Only ``grid`` is read (default ``True``).

    Raises
    ------
    ValueError
        If ``x_units`` is neither ``'um'`` nor ``'fs'``.
    """
    ax_energy.clear()
    number_ticks = 6

    if x_units == 'um':
        ax_energy.set_xlabel(r's [$\mu$m]')
        x = g.t * speed_of_light * 1.0e-15 * 1e6
    elif x_units == 'fs':
        ax_energy.set_xlabel(r't [fs]')
        x = g.t
    else:
        raise ValueError('Unknown parameter x_units (should be um or fs)')

    # Identity test: ``== None`` misbehaves for array-like indices.
    if zi is None:
        zi = -1

    ax_energy.plot(x, g.el_energy[:, zi] * m_e_GeV, 'b-', x, (g.el_energy[:, zi] + g.el_e_spread[:, zi]) * m_e_GeV,
                   'r--', x, (g.el_energy[:, zi] - g.el_e_spread[:, zi]) * m_e_GeV, 'r--')
    ax_energy.set_ylabel(r'$E\pm\sigma_E$ [GeV]')
    ax_energy.ticklabel_format(useOffset=False, style='plain')
    ax_energy.grid(kwargs.get('grid', True))

    ax_bunching = ax_energy.twinx()
    ax_bunching.plot(x, g.bunching[:, zi], 'grey', linewidth=0.5)
    ax_bunching.set_ylabel('Bunching')
    # ``ymin=`` was removed from set_ylim in matplotlib 3.5; ``bottom=`` is
    # the keyword accepted by both old and new releases.
    ax_bunching.set_ylim(bottom=0)
    ax_bunching.grid(False)

    ax_energy.yaxis.major.locator.set_params(nbins=number_ticks)
    ax_bunching.yaxis.major.locator.set_params(nbins=number_ticks)

    # Colour each y-axis to match the curve it belongs to.
    ax_energy.tick_params(axis='y', which='both', colors='b')
    ax_energy.yaxis.label.set_color('b')

    ax_bunching.tick_params(axis='y', which='both', colors='grey')
    ax_bunching.yaxis.label.set_color('grey')

    ax_energy.set_xlim([x[0], x[-1]])
Example #10
Source File: genesis_plot.py    From ocelot with GNU General Public License v3.0 5 votes vote down vote up
def subfig_z_energy_espread(ax_energy, g, zi=None, x_units='um', legend=False, *args, **kwargs):
    """Plot energy +/- energy spread along the beam at one z index.

    ``ax_energy`` is cleared and shows ``g.el_energy[:, zi]`` in blue with
    dashed red lines one ``g.el_e_spread`` above and below.

    Parameters
    ----------
    ax_energy : matplotlib axes
        Axes to clear and draw on.
    g : object
        Must provide ``t``, ``el_energy`` and ``el_e_spread``.
    zi : int, optional
        Column index to plot; ``None`` selects the last column (``-1``).
    x_units : {'um', 'fs'}
        ``'fs'`` plots against ``g.t`` directly; ``'um'`` scales ``g.t`` by
        ``speed_of_light * 1e-15 * 1e6`` (presumably fs -> micrometres).
    legend, *args
        Accepted for signature compatibility; not used here.
    **kwargs
        Only ``grid`` is read (default ``True``).

    Raises
    ------
    ValueError
        If ``x_units`` is neither ``'um'`` nor ``'fs'``.
    """
    ax_energy.clear()
    number_ticks = 6

    if x_units == 'um':
        ax_energy.set_xlabel(r's [$\mu$m]')
        x = g.t * speed_of_light * 1.0e-15 * 1e6
    elif x_units == 'fs':
        ax_energy.set_xlabel(r't [fs]')
        x = g.t
    else:
        raise ValueError('Unknown parameter x_units (should be um or fs)')

    # Identity test: ``== None`` misbehaves for array-like indices.
    if zi is None:
        zi = -1

    ax_energy.plot(x, g.el_energy[:, zi] * m_e_GeV, 'b-', x, (g.el_energy[:, zi] + g.el_e_spread[:, zi]) * m_e_GeV,
                   'r--', x, (g.el_energy[:, zi] - g.el_e_spread[:, zi]) * m_e_GeV, 'r--')
    ax_energy.set_ylabel(r'$E\pm\sigma_E$ [GeV]')
    ax_energy.ticklabel_format(useOffset=False, style='plain')
    ax_energy.grid(kwargs.get('grid', True))

    ax_energy.yaxis.major.locator.set_params(nbins=number_ticks)
    ax_energy.tick_params(axis='y', which='both', colors='b')
    ax_energy.yaxis.label.set_color('b')

    ax_energy.set_xlim([x[0], x[-1]])
Example #11
Source File: plot_process_info.py    From benchmarks with Apache License 2.0 5 votes vote down vote up
def visualize(file_path):
  """Plot every metric in a JSON-lines process log into ``process_info.pdf``.

  Each non-blank line of ``file_path`` is parsed as one JSON object.  The
  ``time`` key supplies the x values; every other key of the first entry
  gets its own figure, which is appended to the PDF.  All figures are shown
  at the end.  If the file holds no entries a message is printed and nothing
  is written.
  """
  with open(file_path) as f:
    # Iterate the file lazily; blank/whitespace-only lines are skipped.
    entries = [json.loads(line) for line in f if line.strip()]

  if not entries:
    print('There is no data in file {}'.format(file_path))
    return

  names = [name for name in entries[0] if name != 'time']
  times = [entry['time'] for entry in entries]

  # The context manager closes the PDF even if a plot call raises.
  with backend_pdf.PdfPages("process_info.pdf") as pdf:
    for idx, name in enumerate(names):
      values = [entry[name] for entry in entries]
      fig = plt.figure()
      ax = plt.gca()
      ax.yaxis.set_major_formatter(tick.ScalarFormatter(useMathText=True))
      plt.ticklabel_format(style='sci', axis='y', scilimits=(-2,3))
      # Cycle through the module-level ``colors`` list.
      plt.plot(times, values, colors[idx % len(colors)], marker='x', label=name)
      plt.xlabel('Time (sec)')
      plt.ylabel(name)
      # ``ymin=`` was removed from ylim in matplotlib 3.5; use ``bottom=``.
      plt.ylim(bottom=0)
      plt.legend(loc = 'upper left')
      pdf.savefig(fig)

    plt.show()
  print('Generated process_info.pdf from {}'.format(file_path))
Example #12
Source File: plot.py    From firedup with MIT License 5 votes vote down vote up
def plot_data(
    data,
    xaxis="Epoch",
    value="AverageEpRet",
    condition="Condition1",
    smooth=1,
    **kwargs
):
    """Draw a seaborn line plot of ``value`` against ``xaxis``, one hue per ``condition``.

    ``data`` is either a single data frame or a list of them (concatenated
    before plotting).  When ``smooth > 1`` the ``value`` column of every item
    in ``data`` is replaced in place by its moving-window average.  Extra
    keyword arguments are forwarded to ``sns.lineplot``.
    """
    if smooth > 1:
        # Moving-window average:
        #     smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
        # where the "smooth" param is the width of that window (2k+1).
        window = np.ones(smooth)
        for datum in data:
            series = np.asarray(datum[value])
            # Convolving a vector of ones gives the number of samples under
            # the window at each position.
            coverage = np.convolve(np.ones(len(series)), window, "same")
            datum[value] = np.convolve(series, window, "same") / coverage

    if isinstance(data, list):
        data = pd.concat(data, ignore_index=True)
    sns.set(style="darkgrid", font_scale=1.5)
    sns.lineplot(data=data, x=xaxis, y=value, hue=condition, ci="sd", **kwargs)
    legend = plt.legend(
        loc="upper center", ncol=3, handlelength=1, borderaxespad=0.0, prop={"size": 13}
    )
    legend.set_draggable(True)

    # Just some formatting niceness: x-axis scale in scientific notation if max x is large
    if np.max(np.asarray(data[xaxis])) > 5e3:
        plt.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))

    plt.tight_layout(pad=0.5)
Example #13
Source File: wrappers.py    From treetime with MIT License 5 votes vote down vote up
def print_save_plot_skyline(tt, n_std=2.0, screen=True, save='', plot=''):
    """Print, save and/or plot the inferred skyline with confidence bounds.

    The skyline and its bounds come from
    ``tt.merger_model.skyline_inferred(gen=50, confidence=n_std)``.

    Parameters
    ----------
    tt : object
        Must expose ``merger_model.skyline_inferred``.
    n_std : float
        Confidence width in standard deviations, passed through as
        ``confidence`` and echoed in the header.
    screen : bool
        When true, print the tab-separated table to stdout.
    save : str
        When non-empty, path of a UTF-8 text file to write the table to.
    plot : str
        When non-empty, path the matplotlib figure is saved to.
    """
    if plot:
        import matplotlib.pyplot as plt

    skyline, conf = tt.merger_model.skyline_inferred(gen=50, confidence=n_std)
    header1 = "Skyline assuming 50 gen/year and approximate confidence bounds (+/- %f standard deviations of the LH)\n"%n_std
    header2 = "date \tN_e \tlower \tupper"
    # Materialise once: the rows are consumed by both the screen and file paths.
    rows = list(zip(skyline.x, skyline.y, conf[0], conf[1]))

    if screen:
        print('\t'+header1+'\t'+header2)
        for row in rows:
            print("\t%1.1f\t%1.1f\t%1.1f\t%1.1f" % row)

    if save:
        # ``with`` guarantees the handle is closed even if formatting fails.
        with open(save, 'w', encoding='utf-8') as fh:
            fh.write("#"+ header1+'#'+header2+'\n')
            for row in rows:
                fh.write("%1.1f\t%1.1f\t%1.1f\t%1.1f\n" % row)
        print("\n --- written skyline to %s\n"%save)

    if plot:
        plt.figure()
        plt.fill_between(skyline.x, conf[0], conf[1], color=(0.8, 0.8, 0.8))
        plt.plot(skyline.x, skyline.y, label='maximum likelihood skyline')
        plt.yscale('log')
        plt.legend()
        plt.ticklabel_format(axis='x',useOffset=False)
        plt.savefig(plot)
Example #14
Source File: fitVisc.py    From PyLAT with GNU General Public License v3.0 4 votes vote down vote up
def fitvisc(self,time,visc,stddev,plot, popt2):
        """Fit ``self.doubexp`` to the viscosity curve and return the fitted estimate.

        The fit window starts at the first index (>= 1) whose time exceeds
        2000 and ends at the first index (>= 1) where the standard deviation
        exceeds 40% of the viscosity; if either condition is never met the
        corresponding index is ``len(visc)``.

        Parameters
        ----------
        time, visc, stddev : sequences (numpy arrays when ``plot`` is true,
            since ``time`` is divided by a scalar)
            Time axis, viscosity values and their standard deviations.
        plot : bool
            When true, print the fit parameters and show a matplotlib plot of
            the data, the fit and its two components.
        popt2 : sequence of 4 floats
            Initial guess ``(A, alpha, tau1, tau2)`` for ``curve_fit``.

        Returns
        -------
        float
            ``A*alpha*tau1 + A*(1-alpha)*tau2`` from the fitted parameters.
        """
        n_points = len(visc)
        # Index 0 is never considered; fall back to len(visc) (at least 1)
        # when no sample satisfies the condition.
        not_found = max(1, n_points)
        start = next((i for i in range(1, n_points) if time[i] > 2000), not_found)
        cut = next((i for i in range(1, n_points) if stddev[i] > 0.4*visc[i]), not_found)

        # Bounds keep every parameter non-negative and alpha <= 1.
        popt2, _ = optimize.curve_fit(self.doubexp, time[start:cut], visc[start:cut],maxfev=1000000,p0=popt2, sigma=stddev[start:cut],bounds=(0,[np.inf,1,np.inf,np.inf]))

        Value = popt2[0]*popt2[1]*popt2[2]+popt2[0]*(1-popt2[1])*popt2[3]

        if plot:
            # The fitted curves are only needed for display.
            fit = [self.doubexp(t,*popt2) for t in time]
            fit1 = [self.doubexp1(t,*popt2) for t in time]
            fit2 = [self.doubexp2(t,*popt2) for t in time]

            timep = time/1000000
            # When no cutoff was found ``cut`` equals len(visc), which may be
            # one past the end of ``time``; clamp so reporting cannot raise.
            cut_idx = min(cut, len(time) - 1)
            from matplotlib import pyplot as plt
            from matplotlib import rcParams
            rcParams.update({'font.size':14})
            print('Viscosity estimate is {}'.format(Value))
            print('A={}, alpha={}, tau1={}, tau2={}'.format(popt2[0],popt2[1],popt2[2],popt2[3]))
            print('Time cutoff is {}'.format(time[cut_idx]))
            plt.ticklabel_format(axis='x', style='sci', scilimits=(0,0))
            plt.plot(timep[:len(visc)],visc,label='Viscosity')
            plt.plot(timep[:len(fit)],fit,label='Double Exponential fit')
            plt.plot(timep[:len(fit1)],fit1,label=r'Contribution of $\tau_1$')
            plt.plot(timep[:len(fit2)],fit2,label=r'Contribution of $\tau_2$')
            plt.axvline(timep[cut_idx])
            plt.ylabel('Viscosity (mPa*s)')
            plt.xlabel('Time (ns)')
            plt.legend()
            plt.show()

        return Value
Example #15
Source File: _tqdm_gui.py    From Tautulli with GNU General Public License v3.0 4 votes vote down vote up
def __init__(self, *args, **kwargs):
        """Initialise the progress bar and build its matplotlib display.

        ``gui=True`` is forced into ``kwargs`` before delegating to the parent
        initialiser.  Unless the bar is disabled, a small figure is created
        with two lines ('cur' in blue, 'est' in black) whose data lists are
        stored on ``self``.  The previous ``toolbar`` rcParam and the previous
        interactive state are remembered on ``self.toolbar`` and
        ``self.wasion``, and interactive mode is switched on.
        """
        # matplotlib is imported lazily so it is only needed for the GUI bar.
        import matplotlib as mpl
        import matplotlib.pyplot as plt
        from collections import deque
        kwargs['gui'] = True

        super(tqdm_gui, self).__init__(*args, **kwargs)

        # Initialize the GUI display
        if self.disable or not kwargs['gui']:
            return

        warn('GUI is experimental/alpha', TqdmExperimentalWarning)
        self.mpl = mpl
        self.plt = plt
        self.sp = None

        # Remember if external environment uses toolbars
        self.toolbar = self.mpl.rcParams['toolbar']
        self.mpl.rcParams['toolbar'] = 'None'

        # Never use a refresh interval below 0.5.
        self.mininterval = max(self.mininterval, 0.5)
        self.fig, ax = plt.subplots(figsize=(9, 2.2))
        # self.fig.subplots_adjust(bottom=0.2)
        # Plain lists when a total is known, deques otherwise.
        if self.total:
            self.xdata = []
            self.ydata = []
            self.zdata = []
        else:
            self.xdata = deque([])
            self.ydata = deque([])
            self.zdata = deque([])
        self.line1, = ax.plot(self.xdata, self.ydata, color='b')
        self.line2, = ax.plot(self.xdata, self.zdata, color='k')
        ax.set_ylim(0, 0.001)
        if self.total:
            # Known total: x-axis is percent complete, 0..100.
            ax.set_xlim(0, 100)
            ax.set_xlabel('percent')
            self.fig.legend((self.line1, self.line2), ('cur', 'est'),
                            loc='center right')
            # progressbar
            self.hspan = plt.axhspan(0, 0.001,
                                     xmin=0, xmax=0, color='g')
        else:
            # Unknown total: x-axis is a 60..0 seconds window (inverted).
            # ax.set_xlim(-60, 0)
            ax.set_xlim(0, 60)
            ax.invert_xaxis()
            ax.set_xlabel('seconds')
            ax.legend(('cur', 'est'), loc='lower left')
        ax.grid()
        # ax.set_xlabel('seconds')
        ax.set_ylabel((self.unit if self.unit else 'it') + '/s')
        if self.unit_scale:
            # Scientific notation on y, with the offset text shifted left.
            plt.ticklabel_format(style='sci', axis='y',
                                 scilimits=(0, 0))
            ax.yaxis.get_offset_text().set_x(-0.15)

        # Remember if external environment is interactive
        self.wasion = plt.isinteractive()
        plt.ion()
        self.ax = ax 
Example #16
Source File: plot.py    From teachDeepRL with MIT License 4 votes vote down vote up
def plot_data(data, xaxis='Epoch', value="AverageEpRet", condition="Condition1", smooth=1, **kwargs):
    """Draw a seaborn time-series plot of ``value`` against ``xaxis`` per ``condition``.

    ``data`` is either a single data frame or a list of them (concatenated
    before plotting).  When ``smooth > 1`` the ``value`` column of every item
    in ``data`` is replaced in place by its moving-window average.  Extra
    keyword arguments are forwarded to ``sns.tsplot``.
    """
    if smooth > 1:
        # Smooth data with a moving window average, that is,
        #     smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
        # where the "smooth" param is the width of that window (2k+1).
        y = np.ones(smooth)
        for datum in data:
            x = np.asarray(datum[value])
            z = np.ones(len(x))
            smoothed_x = np.convolve(x,y,'same') / np.convolve(z,y,'same')
            datum[value] = smoothed_x

    if isinstance(data, list):
        data = pd.concat(data, ignore_index=True)
    sns.set(style="darkgrid", font_scale=1.5)
    # If you upgrade to any version of Seaborn greater than 0.8.1, switch from
    # tsplot to lineplot by replacing the call below with:
    #
    #     sns.lineplot(data=data, x=xaxis, y=value, hue=condition, ci='sd', **kwargs)
    #
    # Changes the colorscheme and the default legend style, though.
    sns.tsplot(data=data, time=xaxis, value=value, unit="Unit", condition=condition, ci='sd', **kwargs)

    # For the version of the legend used in the Spinning Up benchmarking page,
    # create the legend with:
    #
    #     plt.legend(loc='upper center', ncol=6, handlelength=1,
    #                mode="expand", borderaxespad=0., prop={'size': 13})
    legend = plt.legend(loc='best')
    if hasattr(legend, 'set_draggable'):
        # Legend.draggable() was removed in matplotlib 3.1.
        legend.set_draggable(True)
    else:
        # Older matplotlib only offers the toggle-style API.
        legend.draggable()

    xscale = np.max(np.asarray(data[xaxis])) > 5e3
    if xscale:
        # Just some formatting niceness: x-axis scale in scientific notation if max x is large
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

    plt.tight_layout(pad=0.5)
Example #17
Source File: plotting.py    From snn_toolbox with MIT License 4 votes vote down vote up
def plot_spikecount_vs_time(spiketrains_n_b_l_t, duration, dt, path=None):
    """Plot the cumulative number of spikes over simulation time.

    Spikes (non-zero entries) are counted per sample and time step across all
    layers, accumulated along time, and the mean over samples is plotted with
    a +/- one standard deviation band.

    Parameters
    ----------

    spiketrains_n_b_l_t:
        Sequence with one entry per layer; element ``[0]`` of each entry is an
        array whose first axis is the batch and whose last axis is time
        (all axes in between are summed over).
    duration: int
        Simulation duration.
    dt: float
        Simulation time resolution.
    path: Optional[str]
        Where to save the output. If ``None`` the figure is shown instead.
    """

    # Accumulator of shape (batch, time), taken from the first layer.
    reference = spiketrains_n_b_l_t[0][0]
    spikecounts_b_t = np.zeros((reference.shape[0], reference.shape[-1]))
    for layer_entry in spiketrains_n_b_l_t:
        fired_b_l_t = np.not_equal(layer_entry[0], 0)
        # Collapse every axis between batch and time.
        inner_axes = tuple(range(1, fired_b_l_t.ndim - 1))
        spikecounts_b_t += np.sum(fired_b_l_t, inner_axes)

    cum_spikecounts_b_t = np.cumsum(spikecounts_b_t, 1)
    cum_spikecounts_t = np.mean(cum_spikecounts_b_t, 0)
    std_t = np.std(cum_spikecounts_b_t, 0)
    time = np.arange(0, duration, dt)

    plt.figure()
    plt.title('SNN spike count')
    plt.plot(time, cum_spikecounts_t, '.b')
    plt.fill_between(time, cum_spikecounts_t-std_t, cum_spikecounts_t+std_t,
                     alpha=0.1, color='b')
    plt.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
    plt.ylim(0, None)
    plt.ylabel('# spikes')
    plt.xlabel('Simulation time [ms] in steps of {} ms.'.format(dt))
    if path is None:
        plt.show()
    else:
        plt.savefig(os.path.join(path, 'Total_spike_count'),
                    bbox_inches='tight')
    plt.close()
Example #18
Source File: main.py    From trading-bitcoin-with-reinforcement-learning with MIT License 4 votes vote down vote up
def plot_result(rwd_lst, open_price, close_price):
    """Compare the RL agent with two baselines on returns and drawdown.

    The top panel plots cumulative log-returns of buy-and-hold ('BnH'), a
    30-period moving-average momentum rule ('MMT') and the agent ('RL', the
    cumulative sum of ``rwd_lst``).  The bottom panel plots, for each curve,
    its distance below its own running maximum.  ``open_price`` and
    ``close_price`` are used as pandas Series (``shift``/``rolling``).
    """

    def mdd(x):
        # Distance of each point below the running maximum seen so far.
        peak = None
        below_peak = []
        for point in x:
            if peak is None or point > peak:
                peak = point
            below_peak.append(point - peak)
        return below_peak

    plt.subplot(211)

    # Baseline1: BnH
    ret = np.log(close_price / close_price.shift(1)).fillna(0)
    bnh = np.cumsum(ret.values) * 2
    plt.plot(bnh, label='BnH')

    # Baseline2: Momentum
    log_ret = np.log(close_price / open_price)
    sma = close_price.rolling(30, min_periods=1).mean()
    # shift by 1 since we trade on the next opening price
    above_sma = (close_price > sma).shift(1).astype(float)
    signal = (above_sma * 4).fillna(0)
    mmt = np.cumsum(log_ret.values * signal.values)  # convert to cum. simple return
    plt.plot(mmt, label='MMT')

    # RL agent performance
    rl = np.cumsum(rwd_lst)

    plt.xticks(())
    plt.ylabel('Cumulative Log-Returns')
    plt.plot(rl, label='RL')
    plt.legend()

    plt.subplot(212)
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.xlabel('Timesteps')
    plt.ylabel('MDD')
    for curve in (bnh, mmt, rl):
        plt.plot(mdd(curve))
    plt.show()
Example #19
Source File: plot.py    From spinningup with MIT License 4 votes vote down vote up
def plot_data(data, xaxis='Epoch', value="AverageEpRet", condition="Condition1", smooth=1, **kwargs):
    """Draw a seaborn time-series plot of ``value`` against ``xaxis`` per ``condition``.

    ``data`` is either a single data frame or a list of them (concatenated
    before plotting).  When ``smooth > 1`` the ``value`` column of every item
    in ``data`` is replaced in place by its moving-window average.  Extra
    keyword arguments are forwarded to ``sns.tsplot``.
    """
    if smooth > 1:
        # Moving-window average:
        #     smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
        # where the "smooth" param is the width of that window (2k+1).
        window = np.ones(smooth)
        for datum in data:
            series = np.asarray(datum[value])
            # Convolving a vector of ones gives the number of samples under
            # the window at each position.
            coverage = np.convolve(np.ones(len(series)), window, 'same')
            datum[value] = np.convolve(series, window, 'same') / coverage

    if isinstance(data, list):
        data = pd.concat(data, ignore_index=True)
    sns.set(style="darkgrid", font_scale=1.5)
    # With Seaborn newer than 0.8.1, tsplot can be swapped for:
    #     sns.lineplot(data=data, x=xaxis, y=value, hue=condition, ci='sd', **kwargs)
    # which changes the colorscheme and the default legend style, though.
    sns.tsplot(data=data, time=xaxis, value=value, unit="Unit", condition=condition, ci='sd', **kwargs)

    # Alternative legend, as used on the Spinning Up benchmarking page:
    #     plt.legend(loc='upper center', ncol=6, handlelength=1,
    #                mode="expand", borderaxespad=0., prop={'size': 13})
    legend = plt.legend(loc='best')
    legend.set_draggable(True)

    # Just some formatting niceness: x-axis scale in scientific notation if max x is large
    if np.max(np.asarray(data[xaxis])) > 5e3:
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

    plt.tight_layout(pad=0.5)