Python seaborn.axes_style() Examples

The following are 30 code examples of seaborn.axes_style(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module seaborn, or try the search function.
Example #1
Source File: plot_functions.py    From idea_relations with MIT License 7 votes vote down vote up
def joint_plot(x, y, xlabel=None,
               ylabel=None, xlim=None, ylim=None,
               loc="best", color='#0485d1',
               size=8, markersize=50, kind="kde",
               scatter_color="r"):
    """Draw a shaded KDE joint plot of ``x`` against ``y`` in the "darkgrid" style.

    When both ``xlabel`` and ``ylabel`` are given, the data are wrapped in a
    DataFrame keyed by those labels so the grid is built from named columns;
    otherwise ``x`` and ``y`` are handed to the grid directly.

    ``loc``, ``color``, ``markersize`` and ``kind`` are accepted but not used
    by this function.

    Returns the ``SubsampleJointGrid`` that was drawn on.
    """
    # Layout options shared by both ways of building the grid.
    grid_options = dict(space=0.1, ratio=2, size=size, xlim=xlim, ylim=ylim)

    with sns.axes_style("darkgrid"):
        if xlabel and ylabel:
            frame = DataFrame(data={xlabel: x, ylabel: y})
            g = SubsampleJointGrid(xlabel, ylabel, data=frame, **grid_options)
        else:
            g = SubsampleJointGrid(x, y, **grid_options)

        g.plot_joint(sns.kdeplot, shade=True, cmap="Blues")
        # presumably 1000 is the number of points kept for the scatter overlay -- confirm in SubsampleJointGrid
        g.plot_sub_joint(plt.scatter, 1000, s=20, c=scatter_color, alpha=0.3)
        g.plot_marginals(sns.distplot, kde=False, rug=False)
        g.annotate(ss.pearsonr, fontsize=25, template="{stat} = {val:.2g}\np = {p:.2g}")

        # Re-label the ticks with their raw numeric positions.
        joint_ax = g.ax_joint
        y_ticks = joint_ax.get_yticks()
        joint_ax.set_yticklabels(y_ticks)
        x_ticks = joint_ax.get_xticks()
        joint_ax.set_xticklabels(x_ticks)
    return g
Example #2
Source File: plotting_utils.py    From QUANTAXIS with MIT License 6 votes vote down vote up
def customize(func):
    """
    Decorator that sets the content and style of the output figure.

    The wrapped function runs inside the module's plotting context, axes
    style and the seaborn "colorblind" palette. Callers can pass
    ``set_context=False`` to run it without any of them; the ``set_context``
    keyword is consumed here and never forwarded to ``func``.
    """

    @wraps(func)
    def call_w_context(*args, **kwargs):
        # Opt-out path: call the function untouched.
        if not kwargs.pop("set_context", True):
            return func(*args, **kwargs)

        palette = sns.color_palette("colorblind")
        with plotting_context(), axes_style(), palette:
            sns.despine(left=True)
            return func(*args, **kwargs)

    return call_w_context
Example #3
Source File: plotting_utils.py    From QUANTAXIS with MIT License 6 votes vote down vote up
def axes_style(style: str = "darkgrid", rc: dict = None):
    """
    Create the default axes style.

    Parameters
    ---
    :param style: seaborn style name
    :param rc: dict of rc settings; entries already present take precedence
        over the defaults in ``rc_default``
    :return: whatever ``sns.axes_style`` returns for the given style and rc
    """
    if rc is None:
        rc = {}

    # Project-wide defaults; currently empty.
    rc_default = {}

    for name, val in rc_default.items():
        # dict has no ``set_default`` method -- ``setdefault`` is the correct
        # spelling and only fills in keys the caller did not provide.
        rc.setdefault(name, val)

    return sns.axes_style(style=style, rc=rc)
Example #4
Source File: plot_utils.py    From jqfactor_analyzer with MIT License 6 votes vote down vote up
def customize(func):
    """Decorator that runs ``func`` inside the module's plotting context and axes style.

    Before each call the font setup is performed via ``_use_chinese(True)`` if
    ``PlotConfig.FONT_SETTED`` is falsy. Passing ``set_context=False`` skips
    the context managers; the keyword is removed before ``func`` is called.
    """

    @wraps(func)
    def call_w_context(*args, **kwargs):

        if not PlotConfig.FONT_SETTED:
            _use_chinese(True)

        use_context = kwargs.pop('set_context', True)
        if not use_context:
            return func(*args, **kwargs)

        with plotting_context(), axes_style():
            sns.despine(left=True)
            return func(*args, **kwargs)

    return call_w_context
Example #5
Source File: interpretation.py    From lumin with Apache License 2.0 6 votes vote down vote up
def plot_embedding(embed:OrderedDict, feat:str, savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Draw an annotated heatmap of the weights of a categorical entity-embedding matrix.

    Arguments:
        embed: state_dict of trained nn.Embedding; its 'weight' entry is what gets plotted
        feat: name of feature embedded, used as the y-axis label
        savename: Optional name of file to which to save the plot (extension comes from settings.format)
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    # Text styling shared by the axis labels and by the tick labels.
    label_kws = {'fontsize': settings.lbl_sz, 'color': settings.lbl_col}
    tick_kws = {'fontsize': settings.tk_sz, 'color': settings.tk_col}

    with sns.axes_style(**settings.style):
        plt.figure(figsize=(settings.w_small, settings.h_small))
        weights = to_np(embed['weight'])
        sns.heatmap(weights, annot=True, fmt='.1f', linewidths=.5, cmap=settings.div_palette,
                    annot_kws={'fontsize': settings.leg_sz})
        plt.xlabel("Embedding", **label_kws)
        plt.ylabel(feat, **label_kws)
        plt.xticks(**tick_kws)
        plt.yticks(**tick_kws)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None:
            target = settings.savepath/f'{savename}{settings.format}'
            plt.savefig(target, bbox_inches='tight')
        plt.show()
Example #6
Source File: plotting.py    From MJHMC with GNU General Public License v2.0 6 votes vote down vote up
def gauss_2d(nsamples=1000):
    """
    Simple test plot comparing the two samplers on a 1d gaussian.

    Each sampler draws ``nsamples`` samples from the same 1d gaussian; the
    first row of each result is plotted against the other as a hex-binned
    joint plot.
    """
    target = TestGaussian(ndims=1)
    baseline_sampler = HMCBase(distribution=target)
    mjhmc_sampler = MarkovJumpHMC(distribution=target, resample=False)

    with sns.axes_style("white"):
        baseline_draws = baseline_sampler.sample(nsamples)[0]
        mjhmc_draws = mjhmc_sampler.sample(nsamples)[0]
        sns.jointplot(baseline_draws, mjhmc_draws, kind='hex', stat_func=None)
Example #7
Source File: cyclic_callbacks.py    From lumin with Apache License 2.0 6 votes vote down vote up
def plot(self):
        r'''
        Plots the recorded learning-rate and momentum histories against iteration number,
        as two stacked panels (learning rate on top, momentum below).
        '''

        ps = self.plot_settings
        label_kws = dict(fontsize=ps.lbl_sz, color=ps.lbl_col)
        tick_kws = dict(labelsize=ps.tk_sz, labelcolor=ps.tk_col)

        with sns.axes_style(ps.style), sns.color_palette(ps.cat_palette):
            _, axs = plt.subplots(2, 1, figsize=(ps.w_mid, ps.h_mid))
            lr_ax, mom_ax = axs[0], axs[1]

            mom_ax.set_xlabel("Iterations", **label_kws)
            lr_ax.set_ylabel("Learning Rate", **label_kws)
            mom_ax.set_ylabel("Momentum", **label_kws)

            lr_hist = self.hist['lr']
            lr_ax.plot(range(len(lr_hist)), lr_hist)
            mom_hist = self.hist['mom']
            mom_ax.plot(range(len(mom_hist)), mom_hist)

            for ax in axs:
                for axis_name in ('x', 'y'):
                    ax.tick_params(axis=axis_name, **tick_kws)
            plt.show()
Example #8
Source File: plot_utils.py    From jqfactor_analyzer with MIT License 5 votes vote down vote up
def axes_style(style='darkgrid', rc=None):
    """Return ``sns.axes_style`` for ``style``, with project defaults filled into ``rc``.

    Keys already present in ``rc`` are kept; only missing keys are taken from
    ``rc_default`` (which is currently empty).
    """
    rc = {} if rc is None else rc

    rc_default = {}
    for key in rc_default:
        rc.setdefault(key, rc_default[key])

    return sns.axes_style(style=style, rc=rc)
Example #9
Source File: plotfunctions.py    From DataScience-webapp-with-flask with MIT License 5 votes vote down vote up
def plot_boxplot(ds, cat, num):
    """Render a box plot of ``num`` grouped by ``cat`` and return it as base64-encoded PNG bytes.

    ``ds`` is the data passed to seaborn; ``cat`` and ``num`` are also used
    as the x- and y-axis labels.
    """
    import base64
    from io import BytesIO

    sns.set()
    plt.gcf().clear()
    with sns.axes_style(style='ticks'):
        sns.factorplot(cat, num, data=ds, kind="box")
    plt.xlabel(cat)
    plt.ylabel(num)

    # Save the current figure into memory rather than to disk.
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format='png')
    png_buffer.seek(0)  # rewind to beginning of file
    return base64.b64encode(png_buffer.getvalue())
Example #10
Source File: rt-heatmap.py    From pyrocore with GNU General Public License v2.0 5 votes vote down vote up
def heatmap(self, df, imagefile):
        """ Render the heat map for `df` and save it to `imagefile`.

            `self.args` is unpacked into `df.pivot`; its third entry names the
            column whose maximum (but at least `self.CMAP_MIN_MAX`) scales the
            colour steps.
        """
        import seaborn as sns
        import matplotlib.ticker as tkr
        import matplotlib.pyplot as plt
        from  matplotlib.colors import LinearSegmentedColormap

        sns.set()
        with sns.axes_style('whitegrid'):
            _, ax = plt.subplots(figsize=(5, 11))  # inches

            # Colour stops are given in data units and divided by the top of
            # the scale so they land in [0, 1].
            top = max(df[self.args[2]].max(), self.CMAP_MIN_MAX)
            stops = {
                0.0: 'darkred',
                0.3/top: 'red',
                0.6/top: 'orangered',
                0.9/top: 'coral',
                1.0/top: 'skyblue',
                1.5/top: 'blue',
                1.9/top: 'darkblue',
                2.0/top: 'darkgreen',
                3.0/top: 'green',
                (self.CMAP_MIN_MAX - .1)/top: 'palegreen',
                1.0: 'yellow',
            }
            ordered_stops = sorted(stops.items())
            cmap = LinearSegmentedColormap.from_list('RdGrYl', ordered_stops, N=256)

            table = df.pivot(*self.args)

            sns.heatmap(
                table,
                mask=table.isnull(),
                annot=False,
                linewidths=.5,
                square=True,
                ax=ax,
                cmap=cmap,
                annot_kws=dict(stretch='condensed'),
            )
            ax.tick_params(axis='y', labelrotation=30, labelsize=8)
            # NOTE: `tkr` is only needed by this currently disabled formatter:
            # ax.get_yaxis().set_major_formatter(tkr.FuncFormatter(lambda x, p: x))
            plt.savefig(imagefile)
Example #11
Source File: plotting.py    From MJHMC with GNU General Public License v2.0 5 votes vote down vote up
def hist_2d(distr, nsamples, **kwargs):
    """
    Draws a joint KDE plot of samples taken from a distribution

    Args:
     distr: Distribution object
     nsamples: number of samples to use to generate plot
     **kwargs: forwarded to the MarkovJumpHMC constructor

    Returns:
     the object returned by sns.jointplot
    """
    draws = MarkovJumpHMC(distribution=distr, **kwargs).sample(nsamples)

    with sns.axes_style("white"):
        first_dim, second_dim = draws[0], draws[1]
        g = sns.jointplot(first_dim, second_dim, kind='kde', stat_func=None)
    return g
Example #12
Source File: budgeted_stream_plot.py    From DARENet with MIT License 5 votes vote down vote up
def _load_pickle(path):
    """Load a single pickled object from ``path`` and close the file afterwards."""
    # NOTE(security): unpickling can execute arbitrary code -- only point this
    # at result files produced by a trusted run.
    with open(path, "rb") as handle:
        return pickle.load(handle)


def main(args):
    """Plot CMC rank-1 accuracy against average budget and save the figure as a PDF.

    Reads ``distance_confidence_info.pkl``, ``margin_confidence_info.pkl`` and
    ``random_info.pkl`` from ``args.result_path``, converts their 'CMCs'
    entries to percentages, draws them as curves together with the
    module-level reference points, and writes ``args.figname + ".pdf"``.
    """
    distance_confidence_info = _load_pickle(osp.join(args.result_path, "distance_confidence_info.pkl"))
    margin_confidence_info = _load_pickle(osp.join(args.result_path, "margin_confidence_info.pkl"))
    random_info = _load_pickle(osp.join(args.result_path, "random_info.pkl"))
    # Convert the stored CMC values to percentages for the y-axis.
    for info in (distance_confidence_info, margin_confidence_info, random_info):
        info['CMCs'] = [cmc * 100 for cmc in info['CMCs']]

    with sns.axes_style("white"):
        fig = plt.figure(figsize=(6, 4.5))
        ax  = fig.add_subplot(111)
        ax.plot(random_info['resulted_budgets'], random_info['CMCs'], marker='.', linewidth=2.5, markersize=0, label="DaRe(R)+RE (random)", color=flatui[0])
        ax.plot(distance_confidence_info['resulted_budgets'], distance_confidence_info['CMCs'], marker='*', linewidth=2.5, markersize=0, label="DaRe(R)+RE (distance)", color=flatui[1])
        ax.plot(margin_confidence_info['resulted_budgets'], margin_confidence_info['CMCs'], marker='*', linewidth=2.5, markersize=0, label="DaRe(R)+RE (margin)", color=flatui[2])

        ax.scatter(SVDNet_R_RE[0], SVDNet_R_RE[1], marker='*', s=150, label="SVDNet(R)+RE", color=flatui[3])
        ax.scatter(IDE_R_KISSME[0], IDE_R_KISSME[1], marker='h', s=100, label="IDE(R)+KISSME", color=flatui[4])
        ax.scatter(IDE_C_KISSME[0], IDE_C_KISSME[1], marker='o', s=100, label="IDE(C)+KISSME", color=flatui[5])
        ax.scatter(TriNet_R[0], TriNet_R[1], marker='D', s=60, label="TriNet(R)", color=flatui[6])
        ax.scatter(SVDNet_C[0], SVDNet_C[1], marker='p', s=100, label="SVDNet(C)", color=flatui[7])
        plt.xlabel("Average Budget (in MUL-ADD)", size=15)
        # Raw string: same text as before, without the invalid "\%" escape warning.
        plt.ylabel(r"CMC Rank 1 Accuracy (\%)", size=15)

        handles, labels = ax.get_legend_handles_labels()
        label_order = ['TriNet(R)', 'SVDNet(C)', 'SVDNet(R)+RE', 'IDE(R)+KISSME', 'IDE(C)+KISSME', 'DaRe(R)+RE (random)', 'DaRe(R)+RE (distance)', 'DaRe(R)+RE (margin)']
        # Pair every handle with its own label while reordering, so the legend
        # stays correctly labelled even if an expected entry is absent.
        ordered = [
            (handle, wanted)
            for wanted in label_order
            for handle, label in zip(handles, labels)
            if label == wanted
        ]
        ax.legend([handle for handle, _ in ordered], [label for _, label in ordered], loc='lower right')
        plt.grid(linestyle='dotted')
        plt.tight_layout(pad=1, w_pad=1, h_pad=1)
        plt.xlim(3e8, 4.5e9)
        plt.ylim(55, 95)
        plt.savefig(args.figname + ".pdf", bbox_inches='tight')
        plt.close()
Example #13
Source File: opt_callbacks.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot_lr(self) -> None:
        r'''
        Plot the recorded learning rate (`self.history['lr']`) against iteration number.
        '''

        ps = self.plot_settings
        tick_kws = dict(fontsize=ps.tk_sz, color=ps.tk_col)
        label_kws = dict(fontsize=ps.lbl_sz, color=ps.lbl_col)

        with sns.axes_style(ps.style), sns.color_palette(ps.cat_palette):
            # Both figure dimensions use h_small, as in the original code.
            plt.figure(figsize=(ps.h_small, ps.h_small))
            lrs = self.history['lr']
            plt.plot(range(len(lrs)), lrs)
            plt.xticks(**tick_kws)
            plt.yticks(**tick_kws)
            plt.ylabel("Learning rate", **label_kws)
            plt.xlabel("Iterations", **label_kws)
            plt.show()
Example #14
Source File: cyclic_callbacks.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot(self) -> None:
        r'''
        Plots the recorded parameter history (`self.hist`) against iteration number,
        with the y-axis labelled by `self.param_name`.
        '''

        ps = self.plot_settings
        label_kws = dict(fontsize=ps.lbl_sz, color=ps.lbl_col)
        tick_kws = dict(fontsize=ps.tk_sz, color=ps.tk_col)

        with sns.axes_style(ps.style), sns.color_palette(ps.cat_palette):
            plt.figure(figsize=(ps.w_mid, ps.h_mid))
            plt.xlabel("Iterations", **label_kws)
            plt.ylabel(self.param_name, **label_kws)
            iterations = range(len(self.hist))
            plt.plot(iterations, self.hist)
            plt.xticks(**tick_kws)
            plt.yticks(**tick_kws)
            plt.show()
Example #15
Source File: metric_logger.py    From lumin with Apache License 2.0 5 votes vote down vote up
def reset(self) -> None:
        r'''
        Resets/initialises the logger's values and plots, and produces a placeholder plot. Should be called prior to `update_vals` or `update_plot`.

        With `self.extra_detail` set, the figure holds a large loss panel plus velocity and generalisation panels;
        otherwise it holds the loss panel only. The displayed handle is stored in `self.display`.
        '''

        # One history list per loss; the generalisation histories have one fewer entry than there are losses.
        self.loss_vals = [[] for _ in self.loss_names]
        self.vel_vals = [[] for _ in self.loss_names]
        self.gen_vals = [[] for _ in range(len(self.loss_names)-1)]
        self.mean_losses = [None for _ in self.loss_names]
        self.subepochs, self.epochs = [0], [0]
        self.count = 1
        self.log = False

        cfg = self.settings
        detailed = self.extra_detail
        # Tick and axis-label text is drawn at 80% of the configured sizes.
        tick_kws = dict(labelsize=0.8*cfg.tk_sz, labelcolor=cfg.tk_col)
        label_kws = dict(fontsize=0.8*cfg.lbl_sz, color=cfg.lbl_col)

        with sns.axes_style(**cfg.style):
            if detailed:
                self.fig = plt.figure(figsize=(cfg.w_mid, cfg.h_mid), constrained_layout=True)
                gs = self.fig.add_gridspec(2, 3)
                self.loss_ax = self.fig.add_subplot(gs[:,:-1])
                self.vel_ax = self.fig.add_subplot(gs[:1,2:])
                self.gen_ax  = self.fig.add_subplot(gs[1:2,2:])
                panels = [self.loss_ax, self.vel_ax, self.gen_ax]
                shown = self.fig
            else:
                self.fig, self.loss_ax = plt.subplots(1, figsize=(cfg.w_mid, cfg.h_mid))
                panels = [self.loss_ax]
                shown = self.loss_ax.figure

            for ax in panels:
                ax.tick_params(axis='x', **tick_kws)
                ax.tick_params(axis='y', **tick_kws)

            self.loss_ax.set_xlabel('Sub-Epoch', **label_kws)
            self.loss_ax.set_ylabel('Loss', **label_kws)
            if detailed:
                self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', **label_kws)
                self.gen_ax.set_xlabel('Epoch', **label_kws)
                self.gen_ax.set_ylabel('Validation / Train', **label_kws)

            self.display = display(shown, display_id=True)
Example #16
Source File: training.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot_train_history(histories:List[Dict[str,List[float]]], savename:Optional[str]=None, ignore_trn=True, settings:PlotSettings=PlotSettings(),
                       show:bool=True) -> None:
    r'''
    Plot histories object returned by :meth:`~lumin.nn.training.fold_train.fold_train_ensemble` showing the loss evolution over time per model trained.
    Each loss type keeps the same palette colour across models; only the first model's curves get legend labels.

    Arguments:
        histories: list of dictionaries mapping loss type to values at each (sub)-epoch
        savename: Optional name of file to which to save the plot
        ignore_trn: whether to ignore training loss (any loss type whose name contains 'trn')
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
        show: whether or not to show the plot, or just save it
    '''
    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for model_idx, history in enumerate(histories):
            for loss_idx, loss_name in enumerate(history):
                if 'trn' in loss_name and ignore_trn:
                    continue
                # Label only the first model's curves so each loss type appears once in the legend.
                extra = {'label': _lookup_name(loss_name)} if model_idx == 0 else {}
                plt.plot(history[loss_name], color=palette[loss_idx], **extra)

        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        tick_kws = dict(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.xticks(**tick_kws)
        plt.yticks(**tick_kws)
        label_kws = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xlabel("Epoch", **label_kws)
        plt.ylabel("Loss", **label_kws)
        if savename is not None:
            plt.savefig(f'{savename}{settings.format}', bbox_inches='tight')
        if show:
            plt.show()
Example #17
Source File: results.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot_binary_class_pred(df:pd.DataFrame, pred_name:str='pred', targ_name:str='gen_target', wgt_name:str=None, wgt_scale:float=1,
                           log_y:bool=False, lim_x:Tuple[float,float]=(0,1), density=True, 
                           savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Basic plotter for prediction distribution in a binary classification problem.
    One histogram is drawn per distinct value found in the target column, labelled via the settings.targ2class dictionary.

    Arguments:
        df: DataFrame with targets and predictions
        pred_name: name of column to use as predictions
        targ_name: name of column to use as targets
        wgt_name: optional name of column to use as sample weights
        wgt_scale: applies a global multiplicative rescaling to sample weights. Default 1 = no rescaling
        log_y: whether to use a log scale for the y-axis
        lim_x: limit for plotting on the x-axis
        density: whether to normalise each distribution to one, or keep set to sum of weights / datapoints
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    # The y-axis label depends on the normalisation that is applied.
    if density:
        y_label = r"$\frac{1}{N}\ \frac{dN}{dp}$"
    elif wgt_scale != 1:
        y_label = str(wgt_scale) + r"$\times\frac{dN}{dp}$"
    else:
        y_label = r"$\frac{dN}{dp}$"

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for targ in sorted(set(df[targ_name])):
            in_class = df[targ_name] == targ
            if wgt_name is None:
                hist_kws = {}
            else:
                hist_kws = {'weights': wgt_scale*df.loc[in_class, wgt_name]}
            sns.distplot(df.loc[in_class, pred_name], label=settings.targ2class[targ], hist_kws=hist_kws, norm_hist=density, kde=False)
        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        plt.xlabel("Class prediction", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xlim(lim_x)
        plt.ylabel(y_label, fontsize=settings.lbl_sz, color=settings.lbl_col)
        if log_y:
            plt.yscale('log', nonposy='clip')
            plt.grid(True, which="both")
        tick_kws = dict(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.xticks(**tick_kws)
        plt.yticks(**tick_kws)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()
Example #18
Source File: data_viewing.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot_rank_order_dendrogram(df:pd.DataFrame, threshold:float=0.8, savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) \
        -> Dict[str,Union[List[str],float]]:
    r'''
    Plots a dendrogram of features in df clustered via Spearman's rank correlation coefficient.
    Also returns the sets of features with correlation coefficients greater than the threshold

    Arguments:
        df: Pandas DataFrame containing data
        threshold: Threshold on correlation coefficient
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Returns:
        Dict of sets of features with correlation coefficients greater than the threshold and cluster distance
    '''

    corr = np.round(scipy.stats.spearmanr(df).correlation, 4)
    # Use |corr|: negating a feature is a trivial transformation, so the information is unaffected.
    distances = 1-np.abs(corr)
    corr_condensed = hc.distance.squareform(distances)
    z = hc.linkage(corr_condensed, method='average', optimal_ordering=True)
    max_dist = 1-threshold

    with sns.axes_style('white'), sns.color_palette(settings.cat_palette):
        plt.figure(figsize=(settings.w_large, (0.5*len(df.columns))))
        hc.dendrogram(z, labels=df.columns, orientation='left', leaf_font_size=settings.lbl_sz, color_threshold=max_dist)
        plt.xlabel("Distance (1 - |Spearman's Rank Correlation Coefficient|)", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()

    feats = df.columns
    n_merges = len(z)
    sets = {}

    def _members(node):
        # Ids up to n_merges refer to original features; larger ids refer to an
        # earlier merge, whose entry is consumed because it is absorbed here.
        if node <= n_merges:
            return [feats[int(node)]]
        return sets.pop(int(node))['children']

    for i, merge in enumerate(z):
        if merge[2] > max_dist:
            continue
        left = _members(merge[0])
        right = _members(merge[1])
        sets[1 + i + n_merges] = {'children': [*left, *right], 'distance': merge[2]}
    return sets
Example #19
Source File: interpretation.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot_importance(df:pd.DataFrame, feat_name:str='Feature', imp_name:str='Importance',  unc_name:str='Uncertainty', threshold:Optional[float]=None,
                    x_lbl:str='Importance via feature permutation', savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Plot feature importances as computed via `get_nn_feat_importance`, `get_ensemble_feat_importance`, or `rf_rank_features`

    Arguments:
        df: DataFrame containing columns of features, importances and, optionally, uncertainties
        feat_name: column name for features
        imp_name: column name for importances
        unc_name: column name for uncertainties (if present); used for the error bars
        threshold: if set, will draw a line at the threshold used for feature importance
        x_lbl: label to put on the x-axis
        savename: Optional name of file to which to save the plot of feature importances
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        fig, ax = plt.subplots(figsize=(settings.w_large, (0.75)*settings.lbl_sz))
        # Take error bars from the caller-specified column, not the hard-coded
        # default name, so a custom `unc_name` is honoured.
        xerr = unc_name if unc_name in df else None
        df.plot(feat_name, imp_name, 'barh', ax=ax, legend=False, xerr=xerr, error_kw={'elinewidth': 3}, color=palette[0])
        if threshold is not None:
            ax.axvline(x=threshold, label=f'Threshold {threshold}', color=palette[1], linestyle='--', linewidth=3)
            plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        ax.set_xlabel(x_lbl, fontsize=settings.lbl_sz, color=settings.lbl_col)
        ax.set_ylabel('Feature', fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        if savename is not None: plt.savefig(settings.savepath/f'{savename}{settings.format}')
        plt.show()
Example #20
Source File: plot_functions.py    From idea_relations with MIT License 5 votes vote down vote up
def start_plotting(fig_size, fig_pos, style="white", rc=None, despine=False):
    """Create a figure with a single axes under the given seaborn style.

    If ``fig_pos`` is truthy it is passed to ``fig.add_axes``; otherwise a
    regular ``111`` subplot is added. With ``despine`` set,
    ``sns.despine(left=True)`` is applied after the style context has exited.

    Returns the ``(fig, ax)`` pair.
    """
    with sns.axes_style(style, rc):
        fig = plt.figure(figsize=fig_size)
        ax = fig.add_axes(fig_pos) if fig_pos else fig.add_subplot(111)

    if despine:
        sns.despine(left=True)
    return fig, ax
Example #21
Source File: threshold.py    From lumin with Apache License 2.0 4 votes vote down vote up
def binary_class_cut_by_ams(df:pd.DataFrame, top_perc:float=5.0, min_pred:float=0.9,
                            wgt_factor:float=1.0, br:float=0.0, syst_unc_b:float=0.0,
                            pred_name:str='pred', targ_name:str='gen_target', wgt_name:str='gen_weight',
                            plot_settings:PlotSettings=PlotSettings()) -> Tuple[float,float,float]:
    r'''
    Optimise a cut on a signal-background classifier prediction by the Approximate Median Significance
    Cut which should generalise better by taking the mean class prediction of the top top_perc percentage of points as ranked by AMS

    Note: if df has no 'ams' column, one is added to df in place (-1 for rows below min_pred); an existing 'ams' column is reused as is.

    Arguments:
        df: Pandas DataFrame containing data
        top_perc: top percentage of events to consider as ranked by AMS
        min_pred: minimum prediction to consider
        wgt_factor: single multiplicative coefficient for rescaling signal and background weights before computing AMS
        br: background offset bias
        syst_unc_b: fractional systematic uncertainty on background
        pred_name: column to use as predictions
        targ_name: column to use as truth labels for signal (1) and background (0)
        wgt_name: column to use as weights for signal and background events
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Returns:
        Optimised cut
        AMS at cut
        Maximum AMS
    '''

    # TODO: Multithread AMS calculation

    # Select classes via the caller-specified target column rather than a hard-coded `gen_target`.
    sig, bkg = (df[targ_name] == 1), (df[targ_name] == 0)
    if 'ams' not in df.columns:
        df['ams'] = -1
        # For each candidate row, AMS is computed from the weighted signal and background passing a cut at that row's prediction.
        df.loc[df[pred_name] >= min_pred, 'ams'] = df[df[pred_name] >= min_pred].apply(
            lambda row: calc_ams(wgt_factor*np.sum(df.loc[(df[pred_name] >= row[pred_name]) & sig, wgt_name]),
                                 wgt_factor*np.sum(df.loc[(df[pred_name] >= row[pred_name]) & bkg, wgt_name]),
                                 br=br, unc_b=syst_unc_b), axis=1)

    sort = df.sort_values(by='ams', ascending=False)
    cuts = sort[pred_name].values[0:int(top_perc*len(sort)/100)]

    cut = np.mean(cuts)
    # Use `wgt_name` here too, so the AMS at the mean cut uses the same weights as the per-row scan above.
    ams = calc_ams(wgt_factor*np.sum(sort.loc[(sort[pred_name] >= cut) & sig, wgt_name]),
                   wgt_factor*np.sum(sort.loc[(sort[pred_name] >= cut) & bkg, wgt_name]),
                   br=br, unc_b=syst_unc_b)

    best = sort.iloc[0]  # Row with the highest AMS
    print(f'Mean cut at {cut} corresponds to AMS of {ams}')
    print(f'Maximum AMS for data is {best["ams"]} at cut of {best[pred_name]}')
    with sns.axes_style(plot_settings.style), sns.color_palette(plot_settings.cat_palette) as palette:
        plt.figure(figsize=(plot_settings.w_small, plot_settings.h_small))
        sns.distplot(cuts, label=f'Top {top_perc}%')
        plt.axvline(x=cut, label='Mean prediction', color=palette[1])
        plt.axvline(x=best[pred_name], label='Max. AMS', color=palette[2])
        plt.legend(loc=plot_settings.leg_loc, fontsize=plot_settings.leg_sz)
        plt.xticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.yticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.xlabel('Class prediction', fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
        plt.ylabel(r"$\frac{1}{N}\ \frac{dN}{dp}$", fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
        plt.show()
    return cut, ams, best["ams"]
Example #22
Source File: metric_logger.py    From lumin with Apache License 2.0 4 votes vote down vote up
def update_plot(self, best:Optional[float]=None) -> None:
        r'''
        Redraws the loss panel (and, when `self.extra_detail` is set, the velocity and generalisation panels)
        from the stored histories, then refreshes the displayed figure. Optionally shows the user-chosen best loss achieved.

        Arguments:
            best: the value of the best loss achieved so far; drawn as a dashed horizontal line on the loss panel
        '''

        # Loss
        self.loss_ax.clear()
        with sns.axes_style(**self.settings.style), sns.color_palette(self.settings.cat_palette):
            # subepochs[0] is skipped: it is the placeholder 0 that `reset` puts in the list
            for v,m in zip(self.loss_vals,self.loss_names): self.loss_ax.plot(self.subepochs[1:], v, label=m)
        # Constant line at the best loss, spanning the same x-range as the loss curves
        if best is not None: self.loss_ax.plot(self.subepochs[1:], np.ones_like(self.subepochs[1:])*best, label=f'Best = {best:.3E}', linestyle='--')
        if self.log:
            self.loss_ax.set_yscale('log', nonposy='clip')
            self.loss_ax.tick_params(axis='y', labelsize=0.8*self.settings.tk_sz, labelcolor=self.settings.tk_col, which='both')
        self.loss_ax.grid(True, which="both")
        self.loss_ax.legend(loc='upper right', fontsize=0.8*self.settings.leg_sz)
        # Axis labels are set again on every update because this method clears the axes first
        self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
        self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)

        if self.extra_detail:
            # Velocity
            self.vel_ax.clear()
            self.vel_ax.tick_params(axis='y', labelsize=0.8*self.settings.tk_sz, labelcolor=self.settings.tk_col, which='both')
            self.vel_ax.grid(True, which="both")
            with sns.color_palette(self.settings.cat_palette):
                # The legend entry shows the most recent value of each series
                for v,m in zip(self.vel_vals,self.loss_names): self.vel_ax.plot(self.epochs[1:], v, label=f'{m} {v[-1]:.2E}')
            self.vel_ax.legend(loc='lower right', fontsize=0.8*self.settings.leg_sz)
            self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)

            # Generalisation
            self.gen_ax.clear()
            self.gen_ax.grid(True, which="both")
            with sns.color_palette(self.settings.cat_palette) as palette:
                # Series are paired with loss_names[1:] and coloured from palette[1] onwards
                for i, (v,m) in enumerate(zip(self.gen_vals,self.loss_names[1:])):
                    self.gen_ax.plot(self.epochs[1:], v, label=f'{m} {v[-1]:.2f}', color=palette[i+1])
            self.gen_ax.legend(loc='upper left', fontsize=0.8*self.settings.leg_sz)
            self.gen_ax.set_xlabel('Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
            self.gen_ax.set_ylabel('Validation / Train', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
            # Once more than 5 epoch entries are stored, drop the oldest epoch and the oldest
            # value of each velocity/generalisation series
            if len(self.epochs) > 5:
                self.epochs = self.epochs[1:]
                for i in range(len(self.vel_vals)): self.vel_vals[i] = self.vel_vals[i][1:]
                for i in range(len(self.gen_vals)):  self.gen_vals[i]  = self.gen_vals[i][1:]
            
            self.display.update(self.fig)
        else:
            self.display.update(self.loss_ax.figure)
Example #23
Source File: data_viewing.py    From lumin with Apache License 2.0 4 votes vote down vote up
def plot_kdes_from_bs(x:np.ndarray, bs_stats:Dict[str,Any], name2args:Dict[str,Dict[str,Any]], 
                      feat:str, units:Optional[str]=None, moments=True,
                      savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Plots KDEs computed via :meth:`~lumin.utils.statistics.bootstrap_stats`

    Arguments:
        x: x-axis values, passed to seaborn tsplot as `time` (presumably the points at which the KDEs were evaluated)
        bs_stats: (filtered) dictionary returned by :meth:`~lumin.utils.statistics.bootstrap_stats`;
            the keys `{name}_kde`, and (when moments are shown) `{name}_mean` and `{name}_std`, are read for each name in `name2args`
        name2args: Dictionary mapping names of different distributions to arguments to pass to seaborn tsplot.
            A `label` entry is forwarded to tsplot as `condition`. The dictionaries are copied, not modified
        feat: Name of feature being plotted (for axis labels)
        units: Optional units to show on axes
        moments: whether to display mean and standard deviation of each distribution (only for distributions with a label/condition)
        savename: Optional name of file to which to save the plot of feature importances
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    # TODO: update to sns 9

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for i, (name, user_args) in enumerate(name2args.items()):
            # Work on a shallow copy: writing into the caller's dict would make a second call
            # with the same name2args append the moment summary to the legend entry again.
            args = dict(user_args)
            if 'color' not in args: args['color'] = palette[i]
            if 'label' in args: args['condition'] = args.pop('label')  # tsplot names its legend entry 'condition'
            if 'condition' in args and moments:
                mean, mean_unc = uncert_round(np.mean(bs_stats[f'{name}_mean']), np.std(bs_stats[f'{name}_mean'], ddof=1))
                std, std_unc = uncert_round(np.mean(bs_stats[f'{name}_std']), np.std(bs_stats[f'{name}_std'], ddof=1))
                args['condition'] += r', $\overline{x}=' + r'{}\pm{}\ \sigma= {}\pm{}$'.format(mean, mean_unc, std, std_unc)
            sns.tsplot(data=bs_stats[f'{name}_kde'], time=x, **args)

        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        # Dollar signs are stripped from feat because it is embedded inside an existing math expression
        y_lbl = r'$\frac{1}{N}\ \frac{dN}{d' + feat.replace('$','') + r'}$'
        if units is not None:
            x_lbl = feat + r'$\ [' + units + r']$'
            y_lbl += r'$\ [' + units + r'^{-1}]$'
        else:
            x_lbl = feat
        plt.xlabel(x_lbl, fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel(y_lbl, fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None: plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show() 
Example #24
Source File: data_viewing.py    From lumin with Apache License 2.0 4 votes vote down vote up
def compare_events(events:list) -> None:
    r'''
    Plots one or more events side by side in their transverse and longitudinal projections.
    Each event gets one column of three panels: (p_x, p_y), (p_z, p_x), and (p_z, p_y).
    The vectors drawn are those with a `*_px` column in the first event; missing `_py` or `_pz` components are drawn as 0.

    Arguments:
        events: list of DataFrames containing vector coordinates for 3 momenta (only the first row of each is plotted)

    Raises:
        ValueError: if `events` is empty
    '''

    # TODO: check typing, why list?
    # TODO: add plot settings & saving

    if len(events) == 0: raise ValueError('compare_events requires at least one event')

    with sns.axes_style('whitegrid'), sns.color_palette('tab10'):
        # squeeze=False keeps axs two-dimensional (3, len(events)) even for a single event,
        # so the axs[row, i] indexing below works for any number of events
        fig, axs = plt.subplots(3, len(events), figsize=(9*len(events), 18), gridspec_kw={'height_ratios': [1, 0.5, 0.5]},
                                squeeze=False)
        for vector in [x[:-3] for x in events[0].columns if '_px' in x.lower()]:
            for i, in_data in enumerate(events):
                x = in_data[vector + '_px'].values[0]
                try: y = in_data[vector + '_py'].values[0]
                except KeyError: y = 0
                try: z = in_data[vector + '_pz'].values[0]
                except KeyError: z = 0
                axs[0, i].plot((0, x), (0, y), label=vector)
                axs[1, i].plot((0, z), (0, x), label=vector)
                axs[2, i].plot((0, z), (0, y), label=vector)

        # One entry per row of panels: (axes in row, x label, y label, is transverse projection)
        panels = [(axs[0], r"$p_x$", r"$p_y$", True),
                  (axs[1], r"$p_z$", r"$p_x$", False),
                  (axs[2], r"$p_z$", r"$p_y$", False)]
        for row, x_lbl, y_lbl, transverse in panels:
            for ax in row:
                # A fresh boundary artist is created per axis
                if transverse:
                    ax.add_artist(plt.Circle((0, 0), 1, color='grey', fill=False, linewidth=2))
                    ax.set_xlim(-1.1, 1.1)
                else:
                    ax.add_artist(plt.Rectangle((-2, -1), 4, 2, color='grey', fill=False, linewidth=2))
                    ax.set_xlim(-2.2, 2.2)
                ax.set_ylim(-1.1, 1.1)
                ax.set_xlabel(x_lbl, fontsize=16, color='black')
                ax.set_ylabel(y_lbl, fontsize=16, color='black')
                ax.legend(loc='right', fontsize=12)
        fig.show() 
Example #25
Source File: interpretation.py    From lumin with Apache License 2.0 4 votes vote down vote up
def plot_multibody_weighted_outputs(model:AbsModel, inputs:Union[np.ndarray,Tensor], block_names:Optional[List[str]]=None, use_mean:bool=False,
                                    savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Box-plots, per body block, the absolute value of that block's outputs dotted with the matching slice of weights of the first tail layer.
    The inputs are passed through the model once, with a forward hook on the first tail layer capturing the concatenated block outputs.
    Only models whose first tail layer has a single output neuron are supported.

    Arguments:
        model: model to interpret
        inputs: input data to use for interpretation
        block_names: names for each block to use when plotting; defaults to the block indices
        use_mean: if True, will divide each block's weighted output by the number of output neurons in that block
        savename: Optional name of file to which to save the plot of feature importances
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    tail_layer = model.tail[0]
    assert tail_layer.weight.shape[0] == 1, 'This function currently only supports models whose tail block contains a single neuron in the first dense layer'
    blocks = model.body.blocks
    if block_names is None:
        block_names = [str(idx) for idx in range(len(blocks))]
    else:
        assert len(block_names) == len(blocks), 'block_names passed, but number of names does not match number of blocks'

    # The hook records what enters the first tail layer during the forward pass
    hook = FowardHook(tail_layer)
    model.predict(inputs)

    weighted, start = [], 0
    for blk in blocks:
        width = blk.get_out_size()
        stop = start + width
        # Slice out this block's columns of the tail input and the corresponding weights
        block_out = hook.input[0][:, start:stop]
        block_wgt = tail_layer.weight[0][start:stop]
        magnitude = torch.abs(block_out@block_wgt)
        if use_mean: magnitude = magnitude/width
        weighted.append(to_np(magnitude))
        start = stop

    if use_mean: y_lbl = r"Mean $|\bar{w}\cdot\bar{x}|$"
    else:        y_lbl = r"$|\bar{w}\cdot\bar{x}|$"

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        sns.boxplot(x=block_names, y=weighted)
        plt.xlabel("Block", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel(y_lbl, fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show() 
Example #26
Source File: interpretation.py    From lumin with Apache License 2.0 4 votes vote down vote up
def plot_2d_partial_dependence(model:Any, df:pd.DataFrame, feats:Tuple[str,str], train_feats:List[str], ignore_feats:Optional[List[str]]=None,
                               input_pipe:Pipeline=None, sample_sz:Optional[int]=None, wgt_name:Optional[str]=None, n_points:Tuple[int,int]=[20,20],
                               pdp_interact_kargs:Optional[Dict[str,Any]]=None, pdp_interact_plot_kargs:Optional[Dict[str,Any]]=None,
                               savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Wrapper for PDPbox to plot 2D dependence of specified pair of features using provided NN or RF.
    If features have been preprocessed using an SK-Learn Pipeline, then that can be passed in order to rescale them back to their original values.

    Arguments:
        model: any trained model with a .predict method
        df: DataFrame containing training data
        feats: pair of features for which to evaluate the partial dependence of the model
        train_feats: list of all training features including ones which were later ignored, i.e. input features considered when input_pipe was fitted
        ignore_feats: features present in training data which were not used to train the model (necessary to correctly deprocess feature using input_pipe)
        input_pipe: SK-Learn Pipeline which was used to process the training data
        sample_sz: if set, will only compute partial dependence on a random sample with replacement of the training data, sampled according to weights (if set).
            Speeds up computation and allows weighted partial dependencies to be computed.
        wgt_name: Optional column name to use as sampling weights
        n_points: pair of numbers of points at which to evaluate the model output, passed to pdp_interact as num_grid_points
        pdp_interact_kargs: optional dictionary of keyword arguments to pass to pdp_interact
        pdp_interact_plot_kargs: optional dictionary of keyword arguments to pass to pdp_interact_plot
        savename: Optional name of file to which to save the plot of feature importances
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    check_pdpbox()
    pdp_interact_kargs      = {} if pdp_interact_kargs      is None else pdp_interact_kargs
    pdp_interact_plot_kargs = {} if pdp_interact_plot_kargs is None else pdp_interact_plot_kargs

    # Resample (with replacement) whenever a sample size or a weight column is requested
    if not (sample_sz is None and wgt_name is None):
        weights = None
        if wgt_name is not None:
            weights = df[wgt_name].values.astype('float64')
            weights = weights*(1/np.sum(weights))  # normalise to unit sum
        n_rows = sample_sz if sample_sz is not None else len(df)
        df = df.sample(n_rows, weights=weights, replace=True)

    # Features the model was actually trained on
    model_feats = list(train_feats) if ignore_feats is None else [f for f in train_feats if f not in ignore_feats]
    interact = pdp.pdp_interact(model, df, model_feats, feats, num_grid_points=n_points, **pdp_interact_kargs)
    if input_pipe is not None:
        _deprocess_interact(interact, input_pipe, feats, train_feats)

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        plot_params = {'title': None, 'subtitle': None, 'cmap':settings.seq_palette}
        fig, ax = pdp.pdp_interact_plot(interact, feats, figsize=(settings.h_large, settings.h_large), plot_params=plot_params,
                                        **pdp_interact_plot_kargs)
        ax['title_ax'].remove()  # title is drawn via plt.title below
        inter_ax = ax['pdp_inter_ax']
        inter_ax.set_xlabel(feats[0], fontsize=settings.lbl_sz, color=settings.lbl_col)
        inter_ax.set_ylabel(feats[1], fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}')
        plt.show() 
Example #27
Source File: interpretation.py    From lumin with Apache License 2.0 4 votes vote down vote up
def plot_1d_partial_dependence(model:Any, df:pd.DataFrame, feat:str, train_feats:List[str], ignore_feats:Optional[List[str]]=None, input_pipe:Pipeline=None, 
                               sample_sz:Optional[int]=None, wgt_name:Optional[str]=None,  n_clusters:Optional[int]=10, n_points:int=20,
                               pdp_isolate_kargs:Optional[Dict[str,Any]]=None, pdp_plot_kargs:Optional[Dict[str,Any]]=None,
                               y_lim:Optional[Union[Tuple[float,float],List[float]]]=None, 
                               savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Wrapper for PDPbox to plot 1D dependence of specified feature using provided NN or RF.
    If features have been preprocessed using an SK-Learn Pipeline, then that can be passed in order to rescale the x-axis back to its original values.

    Arguments:
        model: any trained model with a .predict method
        df: DataFrame containing training data
        feat: feature for which to evaluate the partial dependence of the model
        train_feats: list of all training features including ones which were later ignored, i.e. input features considered when input_pipe was fitted
        ignore_feats: features present in training data which were not used to train the model (necessary to correctly deprocess feature using input_pipe)
        input_pipe: SK-Learn Pipeline which was used to process the training data
        sample_sz: if set, will only compute partial dependence on a random sample with replacement of the training data, sampled according to weights (if set).
            Speeds up computation and allows weighted partial dependencies to be computed.
        wgt_name: Optional column name to use as sampling weights
        n_clusters: number of clusters in which to group dependency lines. Set to None to show all lines
        n_points: number of points at which to evaluate the model output, passed to pdp_isolate as num_grid_points
        pdp_isolate_kargs: optional dictionary of keyword arguments to pass to pdp_isolate
        pdp_plot_kargs: optional dictionary of keyword arguments to pass to pdp_plot
        y_lim: If set, will limit y-axis plot range to tuple
        savename: Optional name of file to which to save the plot of feature importances
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    # Same dependency guard that plot_2d_partial_dependence runs before using `pdp`
    # NOTE(review): presumably verifies that PDPbox is available - confirm against check_pdpbox
    check_pdpbox()
    if pdp_isolate_kargs is None: pdp_isolate_kargs = {}
    if pdp_plot_kargs    is None: pdp_plot_kargs    = {}

    # Resample (with replacement) whenever a sample size or a weight column is requested
    if sample_sz is not None or wgt_name is not None:
        if wgt_name is None:
            weights = None
        else:
            weights = df[wgt_name].values.astype('float64')
            weights *= 1/np.sum(weights)  # normalise to unit sum
        df = df.sample(len(df) if sample_sz is None else sample_sz, weights=weights, replace=True)

    # Features the model was actually trained on
    model_feats = [f for f in train_feats if ignore_feats is None or f not in ignore_feats]
    iso = pdp.pdp_isolate(model, df, model_feats, feat, num_grid_points=n_points, **pdp_isolate_kargs)
    if input_pipe is not None: _deprocess_iso(iso, input_pipe, feat, train_feats)

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        fig, ax = pdp.pdp_plot(iso, feat, center=False,  plot_lines=True, cluster=n_clusters is not None, n_cluster_centers=n_clusters,
                               plot_params={'title': None, 'subtitle': None}, figsize=(settings.w_mid, settings.h_mid), **pdp_plot_kargs)
        ax['title_ax'].remove()  # title is drawn via plt.title below
        ax['pdp_ax'].set_xlabel(feat, fontsize=settings.lbl_sz, color=settings.lbl_col)
        ax['pdp_ax'].set_ylabel("Partial dependence", fontsize=settings.lbl_sz, color=settings.lbl_col)
        if y_lim is not None: ax['pdp_ax'].set_ylim(y_lim)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None: plt.savefig(settings.savepath/f'{savename}{settings.format}')
        plt.show() 
Example #28
Source File: general.py    From pyfinance with MIT License 4 votes vote down vote up
def corr_heatmap(
    x,
    mask_half=True,
    cmap="RdYlGn_r",
    vmin=-1,
    vmax=1,
    linewidths=0.5,
    square=True,
    figsize=(10, 10),
    **kwargs
):
    """Wrapper around seaborn.heatmap for visualizing correlation matrix.

    Parameters
    ----------
    x : DataFrame
        Underlying data (not a correlation matrix)
    mask_half : bool, default True
        If True, mask (whiteout) the upper right triangle of the matrix,
        including the diagonal
    figsize : tuple, default (10, 10)
        Size in inches applied to the figure that holds the heatmap.
        Ignored when an explicit `ax` is passed (the caller owns that
        figure) or when set to None.
    All other parameters passed to seaborn.heatmap:
    https://seaborn.pydata.org/generated/seaborn.heatmap.html

    Returns
    -------
    The Axes object returned by seaborn.heatmap

    Example
    -------
    # Generate some correlated data
    >>> import numpy as np
    >>> import pandas as pd
    >>> k = 10
    >>> size = 400
    >>> mu = np.random.randint(0, 10, k).astype(float)
    >>> r = np.random.ranf(k ** 2).reshape((k, k)) * 5
    >>> df = pd.DataFrame(np.random.multivariate_normal(mu, r, size=size))
    >>> corr_heatmap(df, figsize=(6, 6))
    """

    # Compute the correlation matrix once; it feeds both the mask and the plot.
    corr = x.corr()

    mask = None
    if mask_half:
        mask = np.zeros_like(corr.values, dtype=bool)
        mask[np.triu_indices_from(mask)] = True

    with sns.axes_style("white"):
        ax = sns.heatmap(
            corr,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            linewidths=linewidths,
            square=square,
            mask=mask,
            **kwargs
        )

    # Honour `figsize`, which was previously accepted but never applied.
    # A caller-supplied axes keeps the size of its own figure.
    if figsize is not None and kwargs.get("ax") is None:
        ax.figure.set_size_inches(figsize)
    return ax
Example #29
Source File: representation_plot.py    From srl-zoo with MIT License 4 votes vote down vote up
def prettyPlotAgainst(states, rewards, title="Representation", fit_pca=False, cmap='coolwarm'):
    """
    State dimensions are plotted one against the other (it creates a matrix of 2d representation)
    using rewards for coloring, the diagonal is a distribution plot, and the scatter plots have a density outline.
    On the diagonal, one distribution per reward value is drawn when there are fewer than 10 distinct rewards,
    otherwise a single distribution of the dimension is drawn.

    :param states: (np.ndarray) 2D array, one column per state dimension
    :param rewards: (np.ndarray) one reward per state, used for coloring
    :param title: (str) figure title (" (PCA)" is appended when fit_pca is True)
    :param fit_pca: (bool) if True, states are replaced by their PCA projection (same number of components)
    :param cmap: (str) matplotlib colormap name used for the rewards
    """
    with sns.axes_style('white'):
        n = states.shape[1]
        # squeeze=False keeps ax_mat two-dimensional even when n == 1,
        # so ax_mat[i, j] below is always valid
        fig, ax_mat = plt.subplots(n, n, figsize=(10, 10), sharex=False, sharey=False, squeeze=False)
        fig.subplots_adjust(hspace=0.2, wspace=0.2)

        if fit_pca:
            title += " (PCA)"
            states = PCA(n_components=n).fit_transform(states)

        c_idx = cm.get_cmap(cmap)
        norm = colors.Normalize(vmin=np.min(rewards), vmax=np.max(rewards))
        # Loop-invariant: the distinct reward values decide how the diagonal is drawn
        unique_rewards = np.unique(rewards)

        for i in range(n):
            for j in range(n):
                x, y = states[:, i], states[:, j]
                ax = ax_mat[i, j]
                if i != j:
                    ax.scatter(x, y, c=rewards, cmap=cmap, s=5)
                    sns.kdeplot(x, y, cmap="Greys", ax=ax, shade=True, shade_lowest=False, alpha=0.2)
                    ax.set_xlim([np.min(x), np.max(x)])
                    ax.set_ylim([np.min(y), np.max(y)])
                else:
                    if len(unique_rewards) < 10:
                        for r in unique_rewards:
                            sns.distplot(x[rewards == r], color=c_idx(norm(r)), ax=ax)
                    else:
                        sns.distplot(x, ax=ax)

                if i == 0:
                    ax.set_title("Dim {}".format(j), y=1.2)
                if i != j:
                    # Hide ticks on interior subplots
                    if i != 0 and i != n - 1:
                        ax.xaxis.set_visible(False)
                    if j != 0 and j != n - 1:
                        ax.yaxis.set_visible(False)

                    # Set up ticks only on one side for the "edge" subplots...
                    if j == 0:
                        ax.yaxis.set_ticks_position('left')
                    if j == n - 1:
                        ax.yaxis.set_ticks_position('right')
                    if i == 0:
                        ax.xaxis.set_ticks_position('top')
                    if i == n - 1:
                        ax.xaxis.set_ticks_position('bottom')

        plt.suptitle(title, fontsize=16)
        plt.show() 
Example #30
Source File: plots.py    From Comparative-Annotation-Toolkit with Apache License 2.0 4 votes vote down vote up
def improvement_plot(consensus_data, ordered_genomes, improvement_tgt):
    """
    Write a multi-page PDF of AUGUSTUS improvement metrics, one page per genome.

    Each page is a 2x2 grid comparing a transMap metric (x) with its counterpart (y) as a KDE with a scatter overlay.
    Genomes without any 'changes' rows are skipped. The PDF is written to a temporary path and moved to
    improvement_tgt.path once all pages are done.

    :param consensus_data: dict keyed by genome; consensus_data[genome]['Evaluation Improvement'] must hold
        'changes' (rows of the 8 metrics listed below) and 'unchanged' (a count used in the page title)
    :param ordered_genomes: genomes to plot, in page order
    :param improvement_tgt: target whose .path is the final PDF location
    """
    column_names = ['transMap original introns',
                    'transMap intron annotation support',
                    'transMap intron RNA support',
                    'Original introns',
                    'Intron annotation support',
                    'Intron RNA support',
                    'transMap alignment goodness',
                    'Alignment goodness']
    # One entry per subplot: (x column, y column, KDE n_levels, KDE bandwidth)
    panels = [('transMap original introns', 'Original introns', 25, 2),
              ('transMap intron annotation support', 'Intron annotation support', 25, 2),
              ('transMap intron RNA support', 'Intron RNA support', 25, 2),
              ('transMap alignment goodness', 'Alignment goodness', 20, 1)]

    def do_kdeplot(x, y, ax, n_levels=None, bw='scott'):
        # Best-effort: a failed KDE fit must not abort the page, the scatter is still drawn.
        # Catch Exception rather than everything so KeyboardInterrupt/SystemExit still propagate.
        try:
            sns.kdeplot(x, y, ax=ax, cut=0, cmap='Purples_d', shade=True, shade_lowest=False, n_levels=n_levels, bw=bw,
                        rasterized=True)
        except Exception:
            logger.warning('Unable to do a KDE fit to AUGUSTUS improvement.')

    af = luigi.local_target.atomic_file(improvement_tgt.path)
    with PdfPages(af.tmp_path) as pdf, sns.axes_style("whitegrid"):
        for genome in ordered_genomes:
            data = pd.DataFrame(consensus_data[genome]['Evaluation Improvement']['changes'])
            unchanged = consensus_data[genome]['Evaluation Improvement']['unchanged']
            if len(data) == 0:
                continue
            data.columns = column_names
            fig, axes = plt.subplots(ncols=2, nrows=2)
            axes = list(axes.flat)
            for ax in axes:
                ax.set_xlim(0, 100)
                ax.set_ylim(0, 100)

            for ax, (x_col, y_col, n_levels, bw) in zip(axes, panels):
                do_kdeplot(data[x_col], data[y_col], ax, n_levels=n_levels, bw=bw)
                sns.regplot(x=data[x_col], y=data[y_col], ax=ax,
                            color='#A9B36F', scatter_kws={"s": 3, 'alpha': 0.7, 'rasterized': True}, fit_reg=False)

            fig.suptitle('AUGUSTUS metric improvements for {:,} transcripts in {}.\n'
                         '{:,} transMap transcripts were chosen.'.format(len(data), genome, unchanged))

            for ax in axes:
                ax.set(adjustable='box', aspect='equal')
            fig.subplots_adjust(hspace=0.3)
            # NOTE(review): presumably saves the current figure as a PDF page and closes it - confirm against multipage_close
            multipage_close(pdf, tight_layout=False)
    af.move_to_final_destination()