Python matplotlib.pyplot.locator_params() Examples

The following code examples show how to use matplotlib.pyplot.locator_params(). They are taken from open-source Python projects. You can vote up the examples you like or vote down the ones you don't like.

Example 1
Project: razzy-spinner   Author: rafasashi   File: util.py    GNU General Public License v3.0 9 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 2
Project: OpenBottle   Author: xiaozhuchacha   File: util.py    MIT License 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 3
Project: OpenBottle   Author: xiaozhuchacha   File: util.py    MIT License 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 4
Project: Health-Checker   Author: KriAga   File: util.py    MIT License 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 5
Project: spatial_patterns   Author: sim-web   File: plotting.py    GNU General Public License v3.0 6 votes vote down vote up
def plot_grid_axes_angles_histogram(self, grid_axes_angles, end_frame=-1):
    """
    Draw overlaid histograms of the three grid axes angles.

    For each of the three entries along the last dimension of
    `grid_axes_angles`, the values at `end_frame` are histogrammed with NaNs
    removed. The x ticks are placed at multiples of pi/6 and labelled with the
    corresponding values in degrees.

    Note: This is not a stand alone plotting function
    """
    bin_edges = np.linspace(-np.pi/1.5, np.pi/1.5, 50)
    for idx in range(3):
        axis_angles = grid_axes_angles[:, end_frame, idx]
        finite_angles = axis_angles[~np.isnan(axis_angles)]
        plt.hist(finite_angles, color=color_cycle_blue3[idx],
                 alpha=0.5, bins=bin_edges, lw=0.)
    # plt.locator_params(axis='y', tight=True, nbins=4)
    tick_angles = np.array([-np.pi/2, -np.pi/3, -np.pi/6,
                            0, np.pi/6, np.pi/2, np.pi/3])
    tick_labels = tick_angles * 180 / np.pi
    current_axes = plt.gca()
    current_axes.set_xticks(tick_angles)
    current_axes.set_xticklabels(tick_labels)
Example 6
Project: evil   Author: txt   File: demoMatplot.py    The Unlicense 6 votes vote down vote up
def lines(xlabel, ylabel, title, f="lines.png",
          xsize=5, ysize=5, lines=()):
  """Plot several labelled series as lines and save the figure to ``f``.

  Each entry of ``lines`` is a sequence whose first element is the legend
  label and whose remaining elements are the y values. The x values are
  1..N, where N is the number of y values of the first entry.

  With no series at all, an empty but labelled figure is saved instead of
  failing on ``lines[0]``.
  """
  n_series = len(lines)
  width = len(lines[0][1:]) if n_series else 0
  # range() exists on both Python 2 and 3; xrange() is Python 2 only.
  xs = list(range(1, width + 1))
  plt.figure(figsize=(xsize, ysize))
  plt.xlabel(xlabel)
  plt.ylabel(ylabel)
  for line in lines:
    plt.plot(xs, line[1:],
             label=line[0])

  if xs:
    plt.locator_params(nbins=len(xs))
  plt.title(title)
  if n_series:
    plt.legend()
  plt.tight_layout()
  plt.savefig(f)
Example 7
Project: Codigo-Network   Author: davinci26   File: comparison_plotter.py    GNU General Public License v3.0 6 votes vote down vote up
def plot(filepath, label, colour_, limit, trendline):
    """
    Plot average, maximum and minimum delay against the number of nodes.

    The data come from ``parse_file(filepath)``. Only the first ``limit``
    points of each series are drawn, and the area between the maximum and
    minimum curves is shaded.

    :param filepath: path handed to ``parse_file``.
    :param label: prefix used for the legend entries.
    :param colour_: colour used for every artist of this data set.
    :param limit: number of leading points to draw; also the upper x limit.
    :param trendline: if true, also draw ``linear_reg(user_no, delay_avg)``
        over the full (un-limited) x range.
    """
    # The third and sixth returned values are not used here.
    user_no, delay_avg, _, delay_max, delay_min, _ = parse_file(filepath)
    if trendline:
        y_trendline = linear_reg(user_no, delay_avg)
        plt.plot(user_no, y_trendline, '-', color=colour_, alpha=0.2, label=label + " Trendline")

    shown_users = user_no[:limit]
    shown_max = delay_max[:limit]
    shown_min = delay_min[:limit]

    plt.plot(shown_users, delay_avg[:limit], 'o--', color=colour_, label=label + " Average delay", ms=3)
    plt.plot(shown_users, shown_max, '--', color=colour_, label=label + " Max delay", alpha=0.3)
    plt.plot(shown_users, shown_min, '--', color=colour_, label=label + " Min delay", alpha=0.3)
    plt.fill_between(shown_users,
                     shown_max,
                     shown_min,
                     color=colour_,
                     alpha=0.2)
    plt.locator_params(nbins=14)
    plt.xlabel('Number of Nodes')
    plt.ylabel('Average delay[sec]')
    plt.xlim(0, limit)
Example 8
Project: NICERsoft   Author: paulray   File: plotutils.py    MIT License 6 votes vote down vote up
def plot_total_count_hist(etable, ax_rate, ax_counts):
    """Plot event count per ID as a bar chart with counts and count rate on the y axes.

    The counts and bar colours come from ``hist_use(etable)``. ``ax_counts``
    receives the bars, and ``ax_rate`` is given the same y range divided by
    ``etable.meta['EXPOSURE']`` so that it reads as counts per second.

    :param etable: event table passed to ``hist_use``; its ``meta`` mapping
        must contain ``'EXPOSURE'``.
    :param ax_rate: axes whose y axis shows the count rate.
    :param ax_counts: axes on which the bars are drawn.
    :return: the per-ID event counts returned by ``hist_use``.
    """
    num_events, colors = hist_use(etable)

    ax_counts.bar(IDS, num_events, color=colors)

    ax_rate.set_ylabel('c/s')
    exposure = etable.meta['EXPOSURE']
    cntmin, cntmax = ax_counts.get_ylim()
    ax_rate.set_ylim((cntmin / exposure, cntmax / exposure))

    ax_counts.set_xlabel('DET_ID')
    ax_counts.set_ylabel('# of Events')
    # 'nbins' is the tick-count parameter understood by the default locator;
    # the previous 'nticks' keyword is not a locator parameter.
    # NOTE(review): `plot` is presumably matplotlib.pyplot under an alias - confirm.
    plot.locator_params(nbins=20)
    plot.title('Total (Filtered) Event Count by Detector')

    return num_events

#----------------------THIS MAKES THE GRAYSCALE ID/EVENT COUNT CHART--------------------- 
Example 9
Project: honours_project   Author: JFriel   File: util.py    GNU General Public License v3.0 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 10
Project: honours_project   Author: JFriel   File: util.py    GNU General Public License v3.0 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 11
Project: aop-helpFinder   Author: jecarvaill   File: util.py    GNU General Public License v3.0 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 12
Project: actions-for-actions   Author: gsig   File: oraclesplot.py    GNU General Public License v3.0 6 votes vote down vote up
def finalize_plot(allticks, handles):
    """
    Apply the final tick, legend and layout settings to the current figure.

    :param allticks: sequence of ``(position, label)`` pairs used for the
        y ticks.
    :param handles: sequence whose items' first elements are the legend
        handles; paired with the module-level ``seriesnames``. Only used when
        the module-level ``LEGEND`` flag is true.
    """
    # 'nbins' is the only tick-count parameter the locator understands; the
    # extra 'nticks' keyword is not a locator parameter.
    plt.locator_params(axis='x', nbins=Noracles)
    plt.yticks([x[0] for x in allticks], [x[1] for x in allticks])
    plt.tick_params(
        axis='y',          # changes apply to the y-axis
        which='both',      # both major and minor ticks are affected
        left=False,        # ticks along the left edge are off
        right=False        # ticks along the right edge are off
    )
    if LEGEND:
        plt.legend([h[0] for h in handles], seriesnames,
                   loc='upper right', borderaxespad=0.,
                   ncol=1, fontsize=10, numpoints=1)
    plt.gcf().tight_layout()


######################################################
# Data processing 
Example 13
Project: serverless-chatbots-workshop   Author: datteswararao   File: util.py    Apache License 2.0 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 14
Project: serverless-chatbots-workshop   Author: datteswararao   File: util.py    Apache License 2.0 6 votes vote down vote up
def _show_plot(x_values, y_values, x_labels=None, y_labels=None):
    """
    Show a scatter plot of ``y_values`` against ``x_values`` as red circles.

    The y-axis is limited to [-1.2, 1.2] and gets horizontal grid lines.

    :param x_values: x coordinates of the points; also used as the x tick
        positions when ``x_labels`` is given.
    :param y_values: y coordinates of the points.
    :param x_labels: optional labels for the x ticks (drawn vertically).
    :param y_labels: optional labels for the y ticks placed at -1, 0 and 1.
    :raise ImportError: if matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the sentences apart.
        raise ImportError('The plot function requires matplotlib to be installed. '
                          'See http://matplotlib.org/')

    # Create the axes first so that the locator settings are applied to the
    # axes that is actually drawn on, not to an implicitly created one.
    axes = plt.axes()
    plt.locator_params(axis='y', nbins=3)
    axes.yaxis.grid()
    # 'o' only selects the marker; the colour comes from the keyword, so the
    # colour is no longer specified twice.
    plt.plot(x_values, y_values, 'o', color='red')
    plt.ylim(ymin=-1.2, ymax=1.2)
    plt.tight_layout(pad=5)
    if x_labels:
        plt.xticks(x_values, x_labels, rotation='vertical')
    if y_labels:
        plt.yticks([-1, 0, 1], y_labels, rotation='horizontal')
    # Pad margins so that markers are not clipped by the axes
    plt.margins(0.2)
    plt.show()

#////////////////////////////////////////////////////////////
#{ Parsing and conversion functions
#//////////////////////////////////////////////////////////// 
Example 15
Project: KerasPlotCallback   Author: prvccy   File: plotcallback.py    MIT License 5 votes vote down vote up
def on_epoch_end(self, epoch, logs=None):
    """
    Refresh the training plots at the end of an epoch.

    Runs every data generator in ``self.datagens``, redraws one subplot per
    entry of ``self.data`` that carries a ``'plotclass'`` key, and finally
    notifies every controller in ``self.controllers``.

    :param epoch: epoch index, forwarded to the data generators.
    :param logs: metrics mapping forwarded to the data generators; a fresh
        empty dict is used when omitted.
    """
    if logs is None:
        # Avoid a shared mutable default that the data generators could alter.
        logs = {}
    clear_output(wait=True)
    predicts = dict()
    # TODO handle no testset error
    predicts['test_y'] = self.validation_data[1]
    predicts['test_pred'] = self.model.predict(self.validation_data[0])

    # first run data generator
    for dg in self.datagens:
        dg.on_epoch_end(epoch, logs, predicts)
        self.data.update(dg.getData())

    # then plot
    plottable = [name for name in self.data if 'plotclass' in self.data[name]]
    numofsubplot = len(plottable)
    # Ceiling division: number of grid rows needed for all subplots.
    self.rows = numofsubplot//self.plotcols + int(numofsubplot % self.plotcols > 0)

    # With nothing to draw there is no valid grid (zero rows), so skip plotting.
    if plottable:
        # squeeze=False always yields a 2-D array of axes, so flatten() also
        # works when the grid consists of a single subplot.
        fig, axes = plt.subplots(
            nrows=self.rows, ncols=self.plotcols,
            figsize=(self.subplotsize[0] * self.plotcols,
                     self.subplotsize[1] * self.rows),
            squeeze=False)
        axs = axes.flatten()

        plt.xticks(range(1, self.params['epochs'] + 1))
        plt.xlim(1, self.params['epochs'])
        plt.locator_params(axis='x', nbins=10)

        for subplot_no, name in enumerate(plottable):
            # NOTE(review): eval() instantiates whatever expression is stored
            # under 'plotclass'; only safe while that value is trusted.
            subplot = eval(self.data[name]['plotclass'] + '()')
            subplot.plot(axs[subplot_no], self.data[name], name)

        plt.tight_layout()
        plt.show()

    # last run controller
    for c in self.controllers:
        c.on_epoch_end(self.data)
        # NOTE(review): if epochs are numbered from 0 this comparison never
        # matches the final epoch - confirm against the caller.
        if epoch == self.params['epochs']:
            c.on_training_end(self.data)
Example 16
Project: news-popularity-prediction   Author: MKLab-ITI   File: slashdot_results.py    Apache License 2.0 5 votes vote down vote up
def make_slashdot_figures(output_path_prefix, method_name_list, slashdot_mse, slashdot_jaccard, slashdot_k_list):
    """
    Plot the mean MSE per method for SlashDot comments and users and save it.

    Two side-by-side subplots sharing the x axis are drawn, and the figure is
    written as ``<output_path_prefix>_mse_slashdot_SNOW.png`` and ``.eps``.

    :param output_path_prefix: prefix of the two output file names.
    :param method_name_list: method names; keys of ``slashdot_mse`` and of the
        legend-name translator.
    :param slashdot_mse: mapping ``method -> {"comments": ..., "users": ...}``
        whose values support ``.mean(axis=1)``.
    :param slashdot_jaccard: unused here; kept for interface compatibility.
    :param slashdot_k_list: x values; the first entry is skipped.
    """
    sns.set_style("darkgrid")
    sns.set_context("paper")

    translator = get_method_name_to_legend_name_dict()

    slashdot_k_list = list(slashdot_k_list)

    fig, axes = plt.subplots(1, 2, sharex=True)

    axes[0].set_title("SlashDot Comments")
    axes[1].set_title("SlashDot Users")

    plt.locator_params(nbins=8)

    # The axis labels do not depend on the method, so set them once instead
    # of on every loop iteration (and also when no method is given).
    axes[0].set_ylabel("MSE")
    axes[0].set_xlabel("Lifetime (sec)")
    axes[1].set_xlabel("Lifetime (sec)")

    lifetimes = slashdot_k_list[1:]
    # Comments go on the left axes, users on the right one.
    for ax, target in zip(axes, ("comments", "users")):
        for method in method_name_list:
            ax.plot(lifetimes,
                    handle_nan(slashdot_mse[method][target].mean(axis=1))[1:],
                    label=translator[method])

    axes[1].legend(loc="upper right")

    # plt.show()
    base_path = output_path_prefix + "_mse_slashdot_SNOW"
    plt.savefig(base_path + ".png", format="png")
    plt.savefig(base_path + ".eps", format="eps")
Example 17
Project: news-popularity-prediction   Author: MKLab-ITI   File: slashdot_results.py    Apache License 2.0 5 votes vote down vote up
def make_barrapunto_figures(output_path_prefix, method_name_list, barrapunto_mse, barrapunto_jaccard, barrapunto_k_list):
    """
    Plot the mean MSE per method for BarraPunto comments and users and save it.

    Two side-by-side subplots sharing the x axis are drawn, and the figure is
    written as ``<output_path_prefix>_mse_barrapunto_SNOW.png`` and ``.eps``.

    :param output_path_prefix: prefix of the two output file names.
    :param method_name_list: method names; keys of ``barrapunto_mse`` and of
        the legend-name translator.
    :param barrapunto_mse: mapping ``method -> {"comments": ..., "users": ...}``
        whose values support ``.mean(axis=1)``.
    :param barrapunto_jaccard: unused here; kept for interface compatibility.
    :param barrapunto_k_list: x values; the first entry is skipped.
    """
    sns.set_style("darkgrid")
    sns.set_context("paper")

    translator = get_method_name_to_legend_name_dict()

    barrapunto_k_list = list(barrapunto_k_list)

    fig, axes = plt.subplots(1, 2, sharex=True)

    axes[0].set_title("BarraPunto Comments")
    axes[1].set_title("BarraPunto Users")

    plt.locator_params(nbins=8)

    # The axis labels do not depend on the method, so set them once instead
    # of on every loop iteration (and also when no method is given).
    axes[0].set_ylabel("MSE")
    axes[0].set_xlabel("Lifetime (sec)")
    axes[1].set_xlabel("Lifetime (sec)")

    lifetimes = barrapunto_k_list[1:]
    # Comments go on the left axes, users on the right one.
    for ax, target in zip(axes, ("comments", "users")):
        for method in method_name_list:
            ax.plot(lifetimes,
                    handle_nan(barrapunto_mse[method][target].mean(axis=1))[1:],
                    label=translator[method])

    axes[1].legend(loc="upper right")

    # plt.show()
    base_path = output_path_prefix + "_mse_barrapunto_SNOW"
    plt.savefig(base_path + ".png", format="png")
    plt.savefig(base_path + ".eps", format="eps")
Example 18
Project: spatial_patterns   Author: sim-web   File: plotting.py    GNU General Public License v3.0 5 votes vote down vote up
def plot_list(fig, plot_list, automatic_arrangement=True):
    """
    Draw each not-yet-evaluated plot function of `plot_list` in its own subplot.

    With fewer than four plots they are stacked in a single column;
    otherwise they are arranged in two columns.

    The entries may be lambda forms or functools.partial objects. A subplot
    is created with a polar projection if the string representation of the
    function (for a partial: of its wrapped function) contains 'polar'.

    Parameters
    ----------
    fig : figure
        Object providing `add_subplot(rows, cols, index, polar=...)`.
    plot_list : list of callables
        Called without arguments, each right after its subplot is created.
    automatic_arrangement : bool
        If False, the list is handed to
        `plot_input_initrate_correlogram_finalrate_correlogram` instead.
    """
    n_plots = len(plot_list)
    # A title for the entire figure (super title)
    # fig.suptitle('Time evolution of firing rates', y=1.1)
    if automatic_arrangement:
        if n_plots < 4:
            n_rows, n_cols = n_plots, 1
        else:
            # int() because math.ceil returns a float on Python 2.
            n_rows, n_cols = int(math.ceil(n_plots / 2.)), 2
        for n, p in enumerate(plot_list, start=1):
            # Check if the function name contains 'polar'; this is needed
            # for the subplotting. A silly hack that only works if every
            # function that should use polar plotting actually has the
            # string 'polar' in its name.
            # Partials expose the wrapped function as `.func`; plain
            # callables such as lambdas are inspected directly.
            target = getattr(p, 'func', p)
            polar = 'polar' in str(target)
            fig.add_subplot(n_rows, n_cols, n, polar=polar)
            # plt.locator_params(axis='y', nbins=4)
            p()

    else:
        # plot_inputs_rates_heatmap(plot_list=plot_list)
        # plot_output_rates_and_gridspacing_vs_parameter(plot_list=plot_list)
        # plot_input_initrate_finalrate_correlogram(plot_list)
        plot_input_initrate_correlogram_finalrate_correlogram(plot_list)
        # plot_input_rate_correlogram_hd_tuning(plot_list)
Example 19
Project: spatial_patterns   Author: sim-web   File: plotting.py    GNU General Public License v3.0 5 votes vote down vote up
def plot_grid_score_histogram(self, grid_scores, end_frame=-1,
                              show_initial_fraction=True):
    """
    Overlay step histograms of the initial and the final grid scores.

    The percentage of counts in bins whose left edge is >= 0 is written into
    the upper right corner for the final scores and, optionally, into the
    upper left corner for the initial scores.

    Note: This is not a stand alone plotting function

    Parameters
    ----------
    grid_scores : ndarray
        Contains grids scores for all seeds and times. Column 0 is taken as
        the initial scores, column `end_frame` as the final ones; NaNs are
        ignored.
    end_frame : int
        Column index of the final scores.
    show_initial_fraction : bool
        If True, the percentage for the initial scores is shown as well.
    """
    def _non_negative_percentage(counts, edges):
        # Share of all counts that fall into bins starting at or above zero.
        kept = counts[edges[:-1] >= 0]
        return '{0}%'.format(int(100 * np.sum(kept) / np.sum(counts)))

    # my_bins = np.linspace(-1.2, 1.4, 27)
    step_style = dict(alpha=1.0, bins=np.linspace(-1.2, 1.4, 14), lw=1,
                      histtype='step')
    initial_scores = grid_scores[:, 0]
    final_scores = grid_scores[:, end_frame]
    initial_color = color_cycle_blue4[2]
    final_color = color_cycle_blue4[0]

    initial_counts, initial_edges, _ = plt.hist(
        initial_scores[~np.isnan(initial_scores)], color=initial_color,
        **step_style)
    final_counts, final_edges, _ = plt.hist(
        final_scores[~np.isnan(final_scores)], color=final_color,
        **step_style)

    initial_text = _non_negative_percentage(initial_counts, initial_edges)
    final_text = _non_negative_percentage(final_counts, final_edges)

    ax = plt.gca()
    corner_style = dict(verticalalignment='top', transform=ax.transAxes)
    if show_initial_fraction:
        ax.text(0.05, 0.95, initial_text, horizontalalignment='left',
                color=initial_color, **corner_style)
    ax.text(0.95, 0.95, final_text, horizontalalignment='right',
            color=final_color, **corner_style)

    plt.xlim([-1.2, 1.4])
    plt.xticks([-1.0, 0,  1.0])
    plt.locator_params(axis='y', tight=True, nbins=2)
    # plt.yticks([0, 100])
Example 20
Project: SPaT_Prediction   Author: priscillaboyd   File: Plotter.py    Apache License 2.0 5 votes vote down vote up
def plot_phase_vs_dt(df, seconds, analysis_folder):
    """
    Plot phase versus date/time, taking data frame and number of seconds to plot.

    The figure is saved as ``phase_vs_dt.png`` inside ``analysis_folder`` and
    then displayed.

    :param object df: dataframe with a ``Date_Time`` column
    :param int seconds: number of leading rows (seconds) to plot
    :param string analysis_folder: analysis folder location
    """
    import os

    # tick positions and the phase names shown at them
    tick_positions = [0, 1, 2, 3]
    labels = ['Red', 'Red + Amber', 'Amber', 'Green']

    # Slice first, then index by date/time. Using the returned frame instead
    # of an in-place change on the slice leaves the caller's frame untouched.
    window = df[:seconds].set_index(['Date_Time'])

    # plot the data
    window.plot()

    # configure plot
    plt.locator_params(axis='x', nbins=10)
    plt.yticks(tick_positions, labels, rotation='vertical')
    plt.xlabel('Time')
    plt.ylabel('Phase')

    # save plot to file; os.path.join copes with a folder given with or
    # without a trailing separator
    plt.savefig(os.path.join(analysis_folder, 'phase_vs_dt.png'))

    # display plot
    plt.show()
Example 21
Project: Monitor   Author: LSSTDESC   File: monitor.py    BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def visualize_seeing_curve(self):
    """
    Plot the seeing curve of every filter as a scatter plot and a histogram.

    One row is drawn per entry of ``self.filter_list``: seeing against MJD on
    the left and a histogram of the seeing values on the right. Rows are
    selected from ``self.seeing_curve`` by the bandpass name ``'lsst' + filt``.

    :return: the created figure.
    :raise ValueError: if ``self.seeing_curve`` is None.
    """
    if self.seeing_curve is None:
        raise ValueError('No lightcurve yet. Use build_lightcurve first.')

    n_col = 2
    n_row = len(self.filter_list)
    fig = plt.figure(figsize=(4. * n_col, 3. * n_row))

    palette = ['b', 'g', 'y', 'orange', 'r', 'k']

    for row, filt in enumerate(self.filter_list):
        # Cycle through the palette so that more filters than colours do not
        # run past the end of the list.
        row_color = palette[row % len(palette)]
        filt_name = str('lsst' + filt)

        sc_bp = self.seeing_curve['bandpass']
        filt_seeing = self.seeing_curve['seeing'][sc_bp == filt_name]
        filt_mjd = self.seeing_curve['mjd'][sc_bp == filt_name]

        # Left column: seeing against time.
        fig.add_subplot(n_row, n_col, 2 * row + 1)
        plt.title(filt_name)
        plt.scatter(filt_mjd, filt_seeing, c=row_color, marker='+')
        plt.locator_params(axis='x', nbins=5)
        plt.ylabel('Seeing (arcseconds)')
        plt.xlabel('MJD')

        # Right column: distribution of the seeing values.
        fig.add_subplot(n_row, n_col, 2 * row + 2)
        plt.title(filt_name)
        plt.hist(filt_seeing, color=row_color)
        plt.xlabel('Seeing (arcseconds)')

    fig.tight_layout()

    return fig
Example 22
Project: rl-botics   Author: Suman7495   File: manipulation_obstacles.py    MIT License 5 votes vote down vote up
def render(self):
    """
    Redraw the scene on ``self.ax``.

    Clears the current axes, draws the end effector (``self.start_loc``) as a
    red circle and the goal (``self.goal_loc``) as a green triangle, adds an
    arrow starting at the end effector, and refreshes the figure without
    blocking.
    """
    plt.cla()

    start, goal = self.start_loc, self.goal_loc
    self.ax.scatter(start[0], start[1], start[2],
                    c='r', marker='o', s=100, label='End effector')
    self.ax.scatter(goal[0], goal[1], goal[2],
                    c='g', marker='^', s=100, label='Goal')

    self.ax.quiver(start[0], start[1], start[2], 0.1, 0.1, 0.1)

    # Same padded range on all three axes.
    for set_limits in (self.ax.set_xlim, self.ax.set_ylim, self.ax.set_zlim):
        set_limits(self.low - 0.1, self.high+0.1)

    for set_label, text in ((self.ax.set_xlabel, 'X'),
                            (self.ax.set_ylabel, 'Y'),
                            (self.ax.set_zlabel, 'Z')):
        set_label(text)

    plt.title("Manipulation with Obstacles")
    for axis_name in ('x', 'y', 'z'):
        plt.locator_params(axis=axis_name, nbins=4)
    plt.tight_layout()
    plt.legend()
    plt.grid(False)
    plt.pause(0.05)
    plt.show(block=False)
Example 23
Project: optimal_landing   Author: darioizzo   File: vis.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def vis_control(traj, shadow_last=0):
    """
    Plot every column of ``traj`` (except ``'t'``) in a grid of subplots.

    The grid has three columns. If ``traj`` has a ``'t'`` column it is used
    as the x values, otherwise the row number is used. Unused grid cells are
    hidden.

    :param traj: table with a ``columns`` attribute, indexable by column name.
    :param shadow_last: number of trailing columns drawn on a light gray
        background.
    """
    if 't' in traj:
        t = traj['t']
        xlabel = 't'
        x_limits = (t.iloc[0], t.iloc[-1])
    else:
        # Plain list of row numbers; it has no .iloc, so the limits are
        # derived from the length instead.
        t = list(range(len(traj)))
        xlabel = ''
        x_limits = (0, len(traj) - 1)

    columns = [c for c in traj.columns if c != 't']

    plot_columns = 3
    # Ceiling division: rows needed to hold all columns.
    plot_rows = -(-len(columns) // plot_columns)

    plt.rcParams['figure.figsize'] = (10, 3*plot_rows)
    # squeeze=False keeps a 2-D array of axes even for a single row.
    fig, axes = plt.subplots(nrows=plot_rows, ncols=plot_columns,
                             squeeze=False)
    fig.tight_layout()
    cells = list(axes.flat)

    first_shadowed = len(columns) - shadow_last
    for i, c in enumerate(columns):
        ax = cells[i]
        if i >= first_shadowed:
            # set_facecolor replaces the removed 'axisbg' keyword.
            ax.set_facecolor('lightgray')
        ax.locator_params(nbins=2)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(c)
        ax.plot(t, traj[c])
        ax.set_xlim(x_limits)
        ax.set_ylim((min(traj[c]), max(traj[c])))

    # Hide the grid cells that have no column to show.
    for ax in cells[len(columns):]:
        ax.axis('off')

    plt.tight_layout()
Example 24
Project: hypercl   Author: chrhenning   File: regression1d_data.py    Apache License 2.0 5 votes vote down vote up
def plot_dataset(self, show=True):
    """Plot the whole dataset.

    Draws the underlying function as a dashed line together with the
    training, test and (if there are any) validation samples.

    Args:
        show: Whether the plot should be shown.
    """
    train_points = (self.get_train_inputs().squeeze(),
                    self.get_train_outputs().squeeze())
    test_points = (self.get_test_inputs().squeeze(),
                   self.get_test_outputs().squeeze())

    val_points = None
    if self.num_val_samples > 0:
        val_points = (self.get_val_inputs().squeeze(),
                      self.get_val_outputs().squeeze())

    func_x, func_y = self._get_function_vals()

    # The default matplotlib setting is usually too high for most plots.
    for axis_name, n_bins in (('y', 2), ('x', 6)):
        plt.locator_params(axis=axis_name, nbins=n_bins)

    plt.plot(func_x, func_y, color='k', label='f(x)',
             linestyle='dashed', linewidth=.5)
    plt.scatter(train_points[0], train_points[1], color='r', label='Train')
    plt.scatter(test_points[0], test_points[1], color='b', label='Test',
                alpha=0.8)
    if val_points is not None:
        plt.scatter(val_points[0], val_points[1], color='g', label='Val',
                    alpha=0.5)
    plt.legend()
    plt.title('1D-Regression Dataset')
    plt.xlabel('$x$')
    plt.ylabel('$y$')

    if not show:
        return
    plt.show()
Example 25
Project: hypercl   Author: chrhenning   File: regression1d_data.py    Apache License 2.0 5 votes vote down vote up
def plot_dataset(self, show=True):
    """Plot the whole dataset.

    Draws the underlying function as a dashed line together with the
    training, test and (if there are any) validation samples.

    Args:
        show: Whether the plot should be shown.
    """
    train_points = (self.get_train_inputs().squeeze(),
                    self.get_train_outputs().squeeze())
    test_points = (self.get_test_inputs().squeeze(),
                   self.get_test_outputs().squeeze())

    val_points = None
    if self.num_val_samples > 0:
        val_points = (self.get_val_inputs().squeeze(),
                      self.get_val_outputs().squeeze())

    func_x, func_y = self._get_function_vals()

    # The default matplotlib setting is usually too high for most plots.
    for axis_name, n_bins in (('y', 2), ('x', 6)):
        plt.locator_params(axis=axis_name, nbins=n_bins)

    plt.plot(func_x, func_y, color='k', label='f(x)',
             linestyle='dashed', linewidth=.5)
    plt.scatter(train_points[0], train_points[1], color='r', label='Train')
    plt.scatter(test_points[0], test_points[1], color='b', label='Test',
                alpha=0.8)
    if val_points is not None:
        plt.scatter(val_points[0], val_points[1], color='g', label='Val',
                    alpha=0.5)
    plt.legend()
    plt.title('1D-Regression Dataset')
    plt.xlabel('$x$')
    plt.ylabel('$y$')

    if not show:
        return
    plt.show()
Example 26
Project: denoiser   Author: cdiazbas   File: bayesPrediction.py    MIT License 5 votes vote down vote up
def predict(self):
    """
    Run the network on ``self.image``, plot the result and save the output.

    A three-panel figure (input image, network output, uncertainty) is saved
    to ``docs/prediction<self.number>.png`` and the raw network result is
    stored with ``np.save`` at ``self.output``.
    """
    print("Predicting validation data...")

    input_validation = np.zeros((1, self.nx, self.ny, 1), dtype='float32')
    input_validation[0, :, :, 0] = self.image

    # From our tests, the epistemic uncertainty is around one order of magnitude
    # smaller. Therefore, one could make a single forward pass of the network to
    # have a rough estimation of the total uncertainty (without the MonteCarlo).

    start = time.time()
    result = np.array(self.model_prediction.predict(input_validation))
    # NOTE(review): result[0] is treated as a log-variance here - confirm
    # against the model definition.
    sigma_total = np.sqrt(np.exp(result[0]))
    prediction = result[1]
    end = time.time()
    print("Prediction took {0:3.2} seconds...".format(end-start))

    # Symmetric colour range shared by the first two panels.
    medio = 3*2.6e-3
    import matplotlib.pyplot as plt
    plt.figure(figsize=(12, 6))

    plt.subplot(131)
    plt.title('Original')
    # Show the same image that was fed to the network; the previous name
    # `imgs` is not defined anywhere in this method.
    plt.imshow(self.image, cmap='seismic', origin='lower', interpolation='None', vmin=-medio, vmax=+medio)
    plt.minorticks_on()
    plt.locator_params(axis='y', nbins=4)
    plt.ylabel('Y [pixel]')
    plt.xlabel('X [pixel]')

    plt.subplot(132)
    plt.title('Output DNN')
    plt.imshow(prediction[0, :, :, 0], cmap='seismic', vmin=-medio, vmax=+medio, origin='lower', interpolation='None')
    plt.minorticks_on()
    plt.locator_params(axis='y', nbins=4)
    plt.xlabel('X [pixel]')
    plt.tick_params(axis='y', labelleft=False)

    plt.subplot(133)
    plt.title('Uncertainty')
    plt.imshow(sigma_total[0, :, :, 0]*1e3, cmap='gray_r', origin='lower', interpolation='None')
    plt.minorticks_on()
    plt.locator_params(axis='y', nbins=4)
    plt.xlabel('X [pixel]')
    plt.tick_params(axis='y', labelleft=False)

    plt.savefig('docs/prediction'+str(self.number)+'.png', bbox_inches='tight')

    np.save(self.output, result)
Example 27
Project: denoiser   Author: cdiazbas   File: prediction_sst.py    MIT License 5 votes vote down vote up
def predict(self):
        """Denoise ``self.image`` with ``self.model`` and save a comparison plot.

        ``self.model.predict`` is called once on a
        ``(1, self.nx, self.ny, 1)`` float32 batch holding ``self.image`` and
        the call is timed.  A three-panel figure (original, network output,
        difference) is written to ``docs/prediction<self.number>.png`` and the
        first sample/channel of the network output is stored with ``np.save``
        at ``self.output``.
        """
        print("Predicting validation data...")

        input_validation = np.zeros((1,self.nx,self.ny,1), dtype='float32')
        input_validation[0,:,:,0] = self.image

        start = time.time()
        out = self.model.predict(input_validation)
        end = time.time()
        print("Prediction took {0:3.2} seconds...".format(end-start))

        ima = self.image
        denoised = out[0,:,:,0]
        # Symmetric colour range shared by all three panels.
        medio = 3*2.6e-3
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12,6))
        plt.subplot(131)
        plt.title('Original')
        plt.imshow(ima,cmap='seismic',origin='lower',interpolation='None',vmin=-medio,vmax=+medio)
        plt.minorticks_on(); plt.locator_params(axis='y', nbins=4); plt.ylabel('Y [pixel]'); plt.xlabel('X [pixel]')
        plt.subplot(132)
        plt.title('Output DNN')
        plt.imshow(denoised,cmap='seismic',vmin=-medio,vmax=+medio,origin='lower',interpolation='None')
        plt.minorticks_on(); plt.locator_params(axis='y', nbins=4); plt.xlabel('X [pixel]'); plt.tick_params(axis='y',labelleft=False)
        plt.subplot(133)
        plt.title('Difference')
        plt.imshow(ima-denoised,cmap='seismic',vmin=-medio,vmax=+medio,origin='lower',interpolation='None')
        plt.minorticks_on(); plt.locator_params(axis='y', nbins=4); plt.xlabel('X [pixel]'); plt.tick_params(axis='y',labelleft=False)
        # The layout must be adjusted before the figure is written; calling
        # tight_layout() after savefig() cannot affect the saved image.
        plt.tight_layout()
        plt.savefig('docs/prediction'+str(self.number)+'.png',bbox_inches='tight')
        np.save(self.output,denoised)
Example 28
Project: treebuffers   Author: rgrig   File: make_plots.py    MIT License 5 votes vote down vote up
def main():
  """Parse the command line, configure matplotlib and draw the requested plots.

  The parsed options are published through the module-level globals
  ``algorithms`` (minus the excluded ones), ``datadir``, ``history`` and
  ``legend_location``.  Each plot is produced when its own flag is set or
  when ``plot_all`` is set.
  """
  global algorithms
  global datadir
  global history
  global legend_location
  args = argparser.parse_args()
  algorithms = set(algorithms) - set(args.exclude_algorithm)
  datadir = args.datadir
  history = args.history
  legend_location = args.legend_location
  plt.rc('font', size=5)
  # The 'axes.color_cycle' rcParam was removed from matplotlib; the property
  # cycle is its replacement and yields the same red/green/blue/yellow order.
  plt.rc('axes', prop_cycle=plt.cycler('color', ['r','g','b','y']))
  plt.figure(figsize=(2,2))
  plt.locator_params(axis='x', nbins=5)
  if args.plot_all or args.steps_frequency:
    plot_steps_frequency()
  if args.plot_all or args.stepssum_vs_opcount:
    plot_stepssum_vs_opcount()
  if args.plot_all or args.nodesmax_vs_opcount:
    plot_nodesmax_vs_opcount()
  if args.plot_all or args.stepssum_vs_history:
    plot_stepssum_vs_history()
  if args.plot_all or args.nodesmax_vs_history:
    plot_nodesmax_vs_history()
  if args.plot_all or args.stepsavg_vs_history:
    plot_stepsavg_vs_history()
  if args.plot_all or args.stepsmed_vs_history:
    plot_stepsmed_vs_history()
  if args.plot_all or args.stepsavgdev_vs_history:
    plot_stepsavgdev_vs_history()
  if args.plot_all or args.stepsmedmax_vs_history:
    plot_stepsmedmax_vs_history()
Example 29
Project: PyRsw   Author: PyRsw   File: Diagnose.py    MIT License 5 votes vote down vote up
def plot(sim):
    """Draw a 2x2 grid of conservation diagnostics and save it as a PDF.

    The panels show, against ``sim.diag_times``: the kinetic-energy and
    potential-energy deviations from their first sample, the relative
    deviation of ``|KE| + |PE|`` from ``KE[0] + PE[0]``, and the relative
    mass deviation.  The figure is written to
    ``Outputs/<sim.run_name>/diagnostics.pdf`` and returned.
    """
    kinetic = np.array(sim.KEs)
    potential = np.array(sim.PEs)
    enstrophy = np.array(sim.ENs)  # read from the simulation but not plotted
    mass = np.array(sim.Ms)
    times = np.array(sim.diag_times)

    fig = plt.figure()

    panels = (
        ('Tot. KE Dev.', kinetic - kinetic[0]),
        ('Tot. PE Dev.', potential - potential[0]),
        ('Rel. Energy Dev.',
         (abs(kinetic) + abs(potential))/(kinetic[0]+potential[0]) - 1),
        ('Rel. Mass Dev.', mass/mass[0] - 1.0),
    )

    for slot, (heading, series) in enumerate(panels, start=1):
        plt.subplot(2,2,slot)
        plt.plot(times,series)
        plt.title(heading)
        plt.tight_layout()
        plt.locator_params(nbins=5)

    fig.tight_layout()
    fig.savefig('Outputs/{0:s}/diagnostics.pdf'.format(sim.run_name))

    return fig
Example 30
Project: adversarial-policies   Author: HumanCompatibleAI   File: visualize.py    MIT License 4 votes vote down vote up
def bar_chart(envs, victim_id, n_components, covariance, savefile=None):
    """Draw a grouped bar chart of mean log probability per environment.

    One group of bars is drawn per environment, one bar per opponent type.
    For unspecified parameters, see get_full_directory.

    :param envs: (list of str) list of environments.
    :param savefile: (None or str) path to save figure to.
    :return: the matplotlib figure that was drawn on.
    """
    # Collect one metadata frame per environment, tagged with its pretty name.
    frames = []
    for env_name in envs:
        frame = load_metadata(env_name, victim_id, n_components, covariance)
        frame["Environment"] = PRETTY_ENVS.get(env_name, env_name)
        frames.append(frame)
    longform = pd.concat(frames)
    longform["opponent_id"] = longform["opponent_id"].apply(PRETTY_OPPONENTS.get)
    longform = longform.reset_index(drop=True)

    # Margins are specified in inches and converted to figure fractions.
    width, height = plt.rcParams.get("figure.figsize")
    legend_height = 0.4
    left_margin_in = 0.55
    top_margin_in = legend_height + 0.05
    bottom_margin_in = 0.5
    gridspec_kw = dict(
        left=left_margin_in / width,
        top=1 - (top_margin_in / height),
        bottom=bottom_margin_in / height,
    )
    fig, ax = plt.subplots(1, 1, gridspec_kw=gridspec_kw)

    # Make colors consistent with previous figures
    standard_cycle = list(plt.rcParams["axes.prop_cycle"])
    palette = {}
    for label in PRETTY_OPPONENTS.values():
        palette[label] = standard_cycle[CYCLE_ORDER.index(label)]["color"]

    # Actually plot
    sns.barplot(
        data=longform,
        x="Environment",
        y="log_proba",
        hue="opponent_id",
        order=PRETTY_ENVS.values(),
        hue_order=BAR_ORDER,
        palette=palette,
        errwidth=1,
    )
    ax.set_ylabel("Mean Log Probability Density")
    plt.locator_params(axis="y", nbins=4)
    util.rotate_labels(ax, xrot=0)

    # Replace seaborn's legend with our own, placed outside the axes
    ax.get_legend().remove()
    legend_entries = ax.get_legend_handles_labels()
    util.outside_legend(
        legend_entries, 3, fig, ax, ax, legend_padding=0.05, legend_height=0.6, handletextpad=0.2
    )

    if savefile is not None:
        fig.savefig(savefile)

    return fig
Example 31
Project: spatial_patterns   Author: sim-web   File: plotting.py    GNU General Public License v3.0 4 votes vote down vote up
def input_tuning_extrema_distribution(self, populations=['exc', 'inh'],
                                          min_max=['min'],
                                          colors=None):
        """
        Plots histogram of maxima and minima of each input tuning function

        This is only interesting for gaussian process inputs.
        For every entry of ``self.psps`` and every requested population the
        values in ``self.rawdata[p]['gp_min']`` / ``['gp_max']`` are drawn as
        step histograms, with a dotted vertical line at their mean.
        Minima are plotted with alpha 1.0, maxima with alpha 0.8.
        Populations follow the usual color code.
        The legend label is looked up from ``self.radius``; only the values
        1 and 500 have a label, any other radius raises KeyError.

        Parameters
        ----------
        populations : list
            Population keys into ``self.rawdata`` and ``colors``
        min_max : list
            Either ['min'] or ['max'] or ['min', 'max']
        colors : dict, optional
            Color per population; ``self.colors`` is used when falsy
        Returns
        -------
        """
        if not colors:
            colors = self.colors
        label_dict = {'exc': {1: 'L = 2, exc.',
                              500: 'L = 1000, exc.'},
                      'inh': {1: 'L = 2, inh.',
                              500: 'L = 1000, inh.'}
                      }
        alpha_min_max = {'min': 1.0, 'max': 0.8}
        for psp in self.psps:
            self.set_params_rawdata_computed(psp, set_sim_params=True)
            for p in populations:
                # Bind the axes before the inner loop so that the labelling
                # below still works when `min_max` is empty (previously `ax`
                # was only assigned inside that loop).
                ax = plt.gca()
                for m in min_max:
                    gp_m = self.rawdata[p]['gp_' + m]
                    label = label_dict[p][self.radius]
                    plt.hist(gp_m, bins=10, alpha=alpha_min_max[m],
                         label=label, color=colors[p], histtype='step', lw=2)
                    mean = np.mean(gp_m)
                    plt.axvline(mean, color=colors[p], linestyle='dotted', lw=2)
                plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
                plt.locator_params(axis='both', nbins=3)
                plt.setp(ax,
                        xlabel='Minimum of input tuning',
                        ylabel='Frequency')
Example 32
Project: spatial_patterns   Author: sim-web   File: plotting.py    GNU General Public License v3.0 4 votes vote down vote up
def gridscore_vs_location(self, time=-1, n=1000,
                              colorbar_range='automatic',
                              time_l=None, time_r=None):
        """
        Bar chart of the mean |psi| score in four vertical strips of the arena

        For every entry of ``self.psps`` a rate map is obtained, ``n`` spike
        positions are drawn from it, and the absolute psi values of the
        spikes are averaged within each quarter of the x range.

        Parameters
        ----------
        time : float
            Time of the rate map; used when `time_l`/`time_r` are not both
            given
        n : int
            Number of spikes
        colorbar_range : str
            Not used by this method
        time_l, time_r : float, optional
            When both are truthy the rate map comes from
            ``self.get_ratemap_left_and_right(time_l, time_r)``

        Returns
        -------
        """
        from .gridscore.spikedata import SpikesFromRatemap
        from .gridscore.plotting import Plot as gsPlot
        from .gridscore.plotting import ARGUMENTS_DEFAULT
        for psp in self.psps:
            self.set_params_rawdata_computed(psp, set_sim_params=True)
            if time_l and time_r:
                ratemap = self.get_ratemap_left_and_right(time_l, time_r)
            else:
                frame = self.time2frame(time, weight=True)
                ratemap = self.get_output_rates(frame=frame, spacing=None,
                                            from_file=True, squeeze=True)

            side = 2*self.radius
            arena_limits = np.array([[0, side], [0, side]])
            spike_source = SpikesFromRatemap(ratemap=[ratemap],
                                             arena_limits=arena_limits)
            spikepositions = spike_source.get_spikepositions(n)
            gsplot = gsPlot(spikepositions=spikepositions,
                           arena_limits=arena_limits)
            gsplot.set_nonspecified_attributes_to_default(ARGUMENTS_DEFAULT)
            psi_abs = np.absolute(gsplot._get_psi_n())

            # Split the x range [0, 1) into equally wide strips.
            n_bars = 4.
            width = 1/n_bars
            left_edges = np.arange(0, 1, width)
            gridscore_in_parts = []
            for edge in left_edges:
                in_strip = gsplot._idx_in_part_of_arena(
                    part='x_range', x_range=[edge, edge + width])
                gridscore_in_parts.append(np.mean(psi_abs[in_strip]))

            plt.bar(left_edges, gridscore_in_parts, width=width,
                    color=color_cycle_blue3[0])
            ax = plt.gca()
            plt.setp(ax,
                     xlabel='Box side',
                     ylabel=r'$\Psi$ score')
            plt.locator_params(axis='y', tight=True, nbins=5)
            simpleaxis(ax)
Example 33
Project: transformer_lexical_shortcuts   Author: demelin   File: curve_plotter.py    MIT License 4 votes vote down vote up
def _average_in_windows(values, window):
    """ Averages consecutive, non-overlapping runs of `window` items; a trailing partial run is dropped """
    return [sum(values[start: start + window]) / window
            for start in range(0, len(values) - window + 1, window)]


def plot_curves(json_logs, exp_ids, mode, factor):
    """ Plots perplexity/ BLEU curves of multiple experiments for comparison against each other

    json_logs: paths passed to `load_from_json`; each log must contain a
        'validation_<mode>' dict keyed by the (stringified) update number.
    exp_ids: legend labels, one per log, in the same order as `json_logs`.
    mode: suffix of the validation key to plot; 'perplexity' puts the legend
        in the upper right, any other value in the lower right and is shown
        upper-cased on the y axis.
    factor: when > 1, both the values and the update numbers are averaged
        over consecutive windows of `factor` entries to thin out the curves.
    """
    # Set up plot
    sns.set(style='whitegrid', context='paper', font_scale=1.3)
    sns.despine()
    plt.locator_params(axis='y', nbins=10)
    plt.locator_params(axis='x', nbins=10)
    # Parse logs
    exp_vals_all = list()
    time_steps_all = list()
    for log in json_logs:
        bleu_and_ppx = load_from_json(log)
        val_dict = bleu_and_ppx['validation_{:s}'.format(mode)]
        # Keys are stringified update numbers; sort them numerically
        time_steps = sorted(int(key) for key in val_dict.keys())
        exp_vals = [val_dict[str(key)] for key in time_steps]
        # Thin-out log data. The previous loop used a strict `<` bound and
        # therefore lost the last complete window whenever the number of
        # entries was an exact multiple of `factor`.
        if factor > 1:
            exp_vals = _average_in_windows(exp_vals, factor)
            time_steps = _average_in_windows(time_steps, factor)

        exp_vals_all.append(exp_vals)
        time_steps_all.append(time_steps)

    # Cycle through the marker set so more than three experiments can be drawn
    markers = list('+.x')
    for step, bleu in enumerate(exp_vals_all):
        plt.plot(time_steps_all[step], bleu, label=exp_ids[step],
                 marker=markers[step % len(markers)], markersize=8)

    loc = 'upper right' if mode == 'perplexity' else 'lower right'
    mode = mode if mode == 'perplexity' else mode.upper()

    plt.legend(loc=loc)
    plt.xlabel('num updates')
    plt.ylabel('validation {:s}'.format(mode))

    # Adjust plot margins (https://stackoverflow.com/questions/18619880/matplotlib-adjust-figure-margin)
    plot_margin = 0.25
    x0, x1, y0, y1 = plt.axis()
    plt.axis((x0 - plot_margin,
              x1 + plot_margin,
              y0 - plot_margin,
              y1 + plot_margin))

    plt.show()
Example 34
Project: AmberUtils   Author: williamdlees   File: CalcBounds.py    MIT License 4 votes vote down vote up
def main(argv):
    """Analyse a column of energy totals: bootstrap bounds, trend and histogram.

    Reads the chosen column (default ``TOTAL``) from the input CSV, then for
    growing prefixes of the data (5, 10, 15, ... values) calls
    ``conf_intervals`` and writes one summary line per prefix.  The mean and
    its bounds are plotted against the prefix size and saved to the trend
    file; finally a histogram of the module-level ``mean_results`` is saved
    to the distribution file.
    """
    global mean_results

    parser = argparse.ArgumentParser(description='Analyse the distribution of MMPBSA/MMGBSA delta G')
    positional_args = (
        ('infile', 'input file containing energy totals (CSV format)'),
        ('sumfile', 'summary file (text format)'),
        ('trendfile', 'trend plot with confidence intervals (.png, .bmp, .pdf)'),
        ('distfile', 'distribution plot of energy totals (.png, .bmp, .pdf)'),
    )
    for arg_name, arg_help in positional_args:
        parser.add_argument(arg_name, help=arg_help)
    parser.add_argument('-c', '--column', help='name of column to use (default TOTAL)')
    # NOTE(review): `argv` is not forwarded, so argparse reads sys.argv -- confirm intended.
    args = parser.parse_args()
    column = args.column or 'TOTAL'

    font = FontProperties()
    font.set_name('Calibri')
    font.set_size(28)

    # Rows with an empty value in the chosen column are skipped.
    with open(args.infile) as infile:
        means = [float(row[column])
                 for row in csv.DictReader(infile)
                 if len(row[column]) > 0]
    means = np.array(means)

    xs = []
    results_m = []
    results_u = []
    results_l = []

    with open(args.sumfile, 'w') as fo:
        for i in range(5, len(means)+5, 5):
            if i >= 10000:
                print('Stopping after 10000 values due to limitations in the bootstrap function.')
                break
            # Reset before each call; after the loop this global holds whatever
            # the last conf_intervals() call left in it (presumably the
            # bootstrapped means -- verify against conf_intervals).
            mean_results = []
            bounds = conf_intervals(means[:i])
            fo.write("%d mean %0.2f +%0.2f -%0.2f\n" % (i, bounds[0], bounds[1], bounds[2]))
            xs.append(i)
            results_m.append(bounds[0])
            results_u.append(bounds[1]+bounds[0])
            results_l.append(bounds[2]+bounds[0])

        for series, style in ((results_m, {}),
                              (results_u, {'linestyle': '--'}),
                              (results_l, {'linestyle': '--'})):
            plt.plot(xs, series, color='g', **style)
        plt.locator_params(nbins=5, axis='y')
        plt.xlabel(u'Samples', fontproperties=font)
        plt.ylabel(u'\u0394G (kcal/mol)', fontproperties=font)
        plt.tight_layout()
        plt.savefig(args.trendfile)

    # Centre a 5-unit wide x window on the last computed mean.
    # NOTE(review): no new figure is created here, so the histogram is drawn
    # on the same axes as the trend plot -- confirm that this is intended.
    lim_l = round(bounds[0] - 2.5, 0)

    plt.xlim(lim_l, lim_l+5)
    plt.ylim(0, 900)
    plt.locator_params(nbins=5, axis='y')
    plt.xlabel(u'Bootstrapped mean \u0394G (kcal/mol)', fontproperties=font)
    plt.ylabel(u'Frequency', fontproperties=font)
    plt.hist(mean_results, bins=50)
    plt.savefig(args.distfile)
Example 35
Project: optimal_landing   Author: darioizzo   File: vis.py    GNU Lesser General Public License v3.0 4 votes vote down vote up
def compare_control(traj, traj_comp, shadow_last=0,plot_columns = 4, order=None):
    """Plot every non-time column of two trajectories on a grid of subplots.

    Each column of `traj` (other than 't') gets its own subplot showing the
    `traj` curve together with the same column of `traj_comp`.  The y range
    of every subplot is taken from `traj` only, padded by 10% on each side.

    :param traj: reference trajectory; must provide 't', `.columns` and
        column access (presumably a pandas DataFrame -- verify against caller).
    :param traj_comp: trajectory to compare against, with the same columns.
    :param shadow_last: the last `shadow_last` subplots get a light gray
        background.
    :param plot_columns: number of subplot columns in the grid.
    :param order: optional list of indices selecting/reordering the non-time
        columns.
    :return: the figure, or None when either trajectory lacks a 't' column.
    """
    xlabel = 't'
    if 't' not in traj or 't' not in traj_comp:
        print('Time is needed for control comparison')
        return None
    t = traj['t']
    t_comp = traj_comp['t']

    columns = [c for c in traj.columns if c != 't']
    if order is not None:
        columns = [columns[i] for i in order]

    # Number of rows needed to fit all columns (ceiling division).
    plot_rows, leftover = divmod(len(columns), plot_columns)
    if leftover > 0:
        plot_rows += 1

    sns.set(font_scale=1.8)
    sns.set_style("whitegrid")
    sns.set_style("ticks", {"xtick.major.size": 4, "ytick.major.size": 4})

    plt.rcParams['figure.figsize'] = (12, 5 *plot_rows)

    fig, axes = plt.subplots(nrows=plot_rows, ncols=plot_columns)
    sns.set_context(font_scale=2)

    first_shadowed = len(columns) - shadow_last
    for i, c in enumerate(columns):
        bg_color = 'lightgray' if i >= first_shadowed else None
        # `axisbg` was removed from matplotlib; `facecolor` is its replacement.
        plt.subplot(plot_rows, plot_columns, i+1, facecolor=bg_color)
        plt.xlabel(xlabel)
        plt.ylabel(c)
        l1, = plt.plot(t, traj[c])
        l2, = plt.plot(t_comp, traj_comp[c], c=sns.color_palette()[2])
        plt.locator_params(nbins=4)
        # Scan the column once instead of recomputing min/max several times.
        lowest, highest = min(traj[c]), max(traj[c])
        r = highest - lowest
        plt.ylim((lowest-0.1*r, highest+0.1*r))
        plt.xlim((t.iloc[0], t.iloc[-1]))

    plt.figlegend([l1,l2],['Optimal control', 'DNN control'],loc = 'upper center', ncol=2,prop={'size':20}, bbox_to_anchor=(0.5, 1.01 ))

    # Hide the unused cells of the grid.
    for i in range(len(columns),plot_rows*plot_columns):
        plt.subplot(plot_rows,plot_columns,i+1)
        plt.axis('off')

    plt.tight_layout()
    return fig
Example 36
Project: radvel   Author: California-Planet-Search   File: orbit_plots.py    MIT License 4 votes vote down vote up
def plot_timeseries(self):
        """
        Draw the RV measurements and the orbit model into the current Axes.

        A secondary x axis on top shows the same time range as decimal years.
        """

        ax = pl.gca()

        # zero-velocity reference line
        ax.axhline(0, color='0.5', linestyle='--')

        # per-instrument RMS of the residuals, or False when not requested
        if self.show_rms:
            rms_values = {
                like.suffix: np.std(like.residuals()) for like in self.like_list
            }
        else:
            rms_values = False

        # orbit model curve
        ax.plot(self.mplttimes, self.orbit_model, 'b-', rasterized=False, lw=self.fit_linewidth)

        # measurements: data = residuals + model
        vels = self.rawresid+self.rvmod
        plot.mtelplot(
            self.plttimes, vels, self.rverr, self.post.likelihood.telvec, ax,
            telfmts=self.telfmts, rms_values=rms_values
        )

        # x range: user supplied, or the data range padded by 1% of self.dt
        if self.set_xlim is None:
            padding = 0.01*self.dt
            ax.set_xlim(min(self.plttimes)-padding, max(self.plttimes)+padding)
        else:
            ax.set_xlim(self.set_xlim)
        pl.setp(ax.get_xticklabels(), visible=False)

        # mark the point with the largest time value
        if self.highlight_last:
            latest = np.argmax(self.plttimes)
            pl.plot(self.plttimes[latest], vels[latest], **plot.highlight_format)

        if self.legend:
            ax.legend(numpoints=1, **self.legend_kwargs)

        # years on upper axis
        axyrs = ax.twiny()
        jd_limits = np.array(list(ax.get_xlim())) + self.epoch
        decimalyear = Time(jd_limits, format='jd', scale='utc').decimalyear
        axyrs.get_xaxis().get_major_formatter().set_useOffset(False)
        axyrs.set_xlim(*decimalyear)
        axyrs.set_xlabel('Year', fontweight='bold')
        pl.locator_params(axis='x', nbins=5)

        # fixed symmetric y range of +/- yscale_sigma standard deviations
        if not self.yscale_auto:
            scale = np.std(self.rawresid+self.rvmod)
            half_range = self.yscale_sigma * scale
            ax.set_ylim(-half_range, half_range)

        ax.set_ylabel('RV [{ms:}]'.format(**plot.latex), weight='bold')
        # drop the first major y tick
        ticks = ax.yaxis.get_majorticklocs()
        ax.yaxis.set_ticks(ticks[1:])
Example 37
Project: radvel   Author: California-Planet-Search   File: orbit_plots.py    MIT License 4 votes vote down vote up
def plot_timeseries(self):
        """
        Make a plot of the RV data and Gaussian Process + orbit model in the current Axes.
        """

        ax = pl.gca()

        ax.axhline(0, color='0.5', linestyle='--')

        if self.subtract_orbit_model:
            orbit_model4data = np.zeros(self.rvmod.shape)
        else:
            orbit_model4data = self.rvmod

        ci = 0
        for like in self.like_list:
            ci = self.plot_gp_like(like, orbit_model4data, ci)

        # plot data
        plot.mtelplot(
            # data = residuals + model
            self.plttimes, self.rawresid+orbit_model4data, self.rverr,
            self.post.likelihood.telvec, ax, telfmts=self.telfmts
        )

        if self.set_xlim is not None:
            ax.set_xlim(self.set_xlim)
        else:
            ax.set_xlim(min(self.plttimes)-0.01*self.dt, max(self.plttimes)+0.01*self.dt)    
        pl.setp(ax.get_xticklabels(), visible=False)

        # legend
        if self.legend:
            ax.legend(numpoints=1, **self.legend_kwargs)

        # years on upper axis
        axyrs = ax.twiny()
        xl = np.array(list(ax.get_xlim())) + self.epoch
        decimalyear = Time(xl, format='jd', scale='utc').decimalyear
        axyrs.plot(decimalyear, decimalyear)
        axyrs.get_xaxis().get_major_formatter().set_useOffset(False)
        axyrs.set_xlim(*decimalyear)
        pl.locator_params(axis='x', nbins=5)
        axyrs.set_xlabel('Year', fontweight='bold')