Python seaborn.heatmap() Examples

The following are 30 code examples of seaborn.heatmap(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module seaborn, or try the search function.
Example #1
Source File: time_align.py    From scanorama with MIT License 7 votes vote down vote up
def time_align_visualize(alignments, time, y, namespace='time_align'):
    """Plot an alignment heatmap and the maximum spanning tree of the alignments.

    Two SVG files are written:

    * ``<namespace>_heatmap.svg`` -- heatmap of ``alignments + alignments.T``
      plus the identity, flipped vertically, on a fixed 0..1 colour scale.
    * ``<namespace>.svg`` -- the maximum spanning tree of the graph whose
      weighted adjacency matrix is ``alignments``, with node ``i`` drawn at
      ``(time[i], y[i])`` and the y axis limited to [-1, 1].

    Args:
        alignments: square matrix of pairwise alignment weights; presumably
            one triangle only, since it is added to its transpose -- TODO confirm.
        time: x coordinate of each node, indexable by node id.
        y: y coordinate of each node, indexable by node id.
        namespace: filename prefix for both saved figures.
    """
    plt.figure()
    heat = np.flip(alignments + alignments.T +
                   np.eye(alignments.shape[0]), axis=0)
    sns.heatmap(heat, cmap="YlGnBu", vmin=0, vmax=1)
    plt.savefig(namespace + '_heatmap.svg')

    # ``from_numpy_matrix`` was removed in networkx 3.0; ``from_numpy_array``
    # (available since networkx 2.0) builds the same weighted graph.
    G = nx.from_numpy_array(alignments)
    G = nx.maximum_spanning_tree(G)

    # Nodes are numbered 0..n-1 by from_numpy_array.
    pos = {node: np.array([time[node], y[node]]) for node in range(len(G.nodes))}

    plt.figure()
    # ``draw`` forwards keywords to ``draw_networkx``, whose edge-selection
    # keyword is ``edgelist``; recent networkx rejects the unknown ``edges``.
    nx.draw(G, pos, edgelist=G.edges(), width=10)
    plt.ylim([-1, 1])
    plt.savefig(namespace + '.svg')
Example #2
Source File: heatmap.py    From Attention-on-Attention-for-VQA with MIT License 6 votes vote down vote up
def plot_heatmap(a, b, title='title', saveLoc='temp', shape=(6, 6)):
    """Show two attention maps side by side as heatmaps.

    Args:
        a: first attention map; reshaped to ``shape``.
        b: second attention map; reshaped to ``shape``.
        title: currently unused; kept for backward compatibility.
        saveLoc: currently unused; kept for backward compatibility.
        shape: grid both inputs are reshaped to (defaults to the original
            hard-coded 6x6).
    """
    a = a.reshape(shape)
    b = b.reshape(shape)

    fig, (ax1, ax2) = plt.subplots(1, 2)

    # Only the right-hand map draws a colour bar. No vmin/vmax is shared, so
    # the two maps are colour-scaled independently and the bar describes the
    # second map only.
    panels = ((ax1, a, "Attention 1", False),
              (ax2, b, "Attention 2", True))
    for ax, data, label, cbar in panels:
        h = sns.heatmap(data, cmap="magma", cbar=cbar, ax=ax)
        h.set_title(label)
        h.invert_yaxis()
        h.set_xlabel('')
        h.set_ylabel('')

    plt.show()
Example #3
Source File: metrics.py    From axcell with Apache License 2.0 6 votes vote down vote up
def plot_confusion_matrix(self, name):
        """Draw the confusion matrix returned by ``self.confusion_matrix(name)``.

        Zero cells are masked out. The remaining counts are annotated as
        integers on a 20x20-inch heatmap with square cells. True labels run
        along the y axis and predicted labels along the x axis.
        """
        matrix, labels = self.confusion_matrix(name)
        frame = pd.DataFrame(matrix, index=list(labels), columns=list(labels))
        plt.figure(figsize=(20, 20))
        style = dict(annot=True, square=True, fmt="d", cmap="YlGnBu",
                     linecolor="black", linewidths=0.01)
        axis = sn.heatmap(frame, mask=(matrix == 0), **style)
        axis.set_ylabel("True")
        axis.set_xlabel("Predicted")
Example #4
Source File: design_matrix.py    From nltools with MIT License 6 votes vote down vote up
def heatmap(self, figsize=(8, 6), **kwargs):
        """Visualize Design Matrix spm style. Use .plot() for typical pandas
            plotting functionality. Can pass optional keyword args to seaborn
            heatmap.

        ``cmap`` defaults to ``'gray'`` and ``cbar`` to ``False``; both may
        be overridden. An existing axes can be supplied as ``ax``, in which
        case ``figsize`` is ignored.

        Args:
            figsize: size of the new figure when no ``ax`` is supplied.
            **kwargs: forwarded to ``seaborn.heatmap``.
        """
        cmap = kwargs.pop('cmap', 'gray')
        # Pop instead of hard-coding, so callers can override these without
        # a "got multiple values for keyword argument" TypeError.
        cbar = kwargs.pop('cbar', False)
        ax = kwargs.pop('ax', None)
        if ax is None:
            _, ax = plt.subplots(1, figsize=figsize)
        ax = sns.heatmap(self, cmap=cmap, cbar=cbar, ax=ax, **kwargs)
        for spine in ax.spines.values():
            spine.set_visible(True)
        # Show only the first and last row labels.
        last_row = self.shape[0] - 1
        for i, label in enumerate(ax.get_yticklabels()):
            label.set_visible(i in (0, last_row))
        # Thick black lines along all four edges of the matrix.
        ax.axhline(linewidth=4, color="k")
        ax.axvline(linewidth=4, color="k")
        ax.axhline(y=self.shape[0], color='k', linewidth=4)
        ax.axvline(x=self.shape[1], color='k', linewidth=4)
        # Rotate this axes' labels (plt.yticks would target the current axes).
        plt.setp(ax.get_yticklabels(), rotation=0)
Example #5
Source File: visFunction.py    From uiKLine with MIT License 6 votes vote down vote up
def plotSigHeats(signals,markets,start=0,step=2,size=1,iters=6):
    """
    Plot a heatmap of backtest profit/loss over a grid of stop parameters,
    to look for a stable parameter region.

    For each ``i, j`` in ``range(iters)`` the signals are backtested with
    ``plotSigCaps(..., climit=start + i*step, wlimit=start + j*step)``. The
    last element of the returned ``caps`` (presumably final capital) is
    stored in column ``i``, row ``j``. Columns are shown on the x axis
    ('Loss Stop @') and rows on the y axis ('Profit Stop @').

    Returns:
        DataFrame with index and columns ``range(iters)`` (rows ``j``,
        columns ``i``).
    """
    sigMat = pd.DataFrame(index=range(iters), columns=range(iters))
    for i in range(iters):
        climit = start + i*step
        for j in range(iters):
            wlimit = start + j*step
            caps, _ = plotSigCaps(signals, markets, climit=climit, wlimit=wlimit, size=size, op=False)
            # One .loc write. The chained ``sigMat[i][j] = ...`` assigns into a
            # temporary under pandas copy-on-write and leaves sigMat unchanged.
            sigMat.loc[j, i] = caps[-1]
    xyLabels = [str(start + k*step) for k in range(iters)]
    # Plot a labelled copy so seaborn places every label on its own row or
    # column. Hand-placed y ticks counted from the bottom while seaborn draws
    # row 0 at the top, which reversed the y labels.
    labelled = pd.DataFrame(sigMat.values.astype(np.float64), index=xyLabels, columns=xyLabels)
    ax = sns.heatmap(labelled, annot=True, fmt='.2f', annot_kws={"weight": "bold"})
    plt.setp(ax.get_yticklabels(), rotation=0)
    plt.setp(ax.get_xticklabels(), rotation=90)
    plt.xlabel('Loss Stop @')
    plt.ylabel('Profit Stop @')
    return sigMat
Example #6
Source File: common.py    From typhon with MIT License 6 votes vote down vote up
def _plot_weights(self, title, file, layer_index=0, vmin=-5, vmax=5):
        """Save an annotated heatmap of one network layer's weights to ``file``.

        Weights are ``coefs_[layer_index]`` of the last step of
        ``self.iwp.estimator``. Rows are labelled with ``self.iwp.inputs``.
        Colours use the ``"difference"`` colormap centred at 0, limited to
        ``[vmin, vmax]``.

        Note: ``title`` is currently unused, and the global seaborn theme is
        changed as a side effect.
        """
        import seaborn as sns
        # ``sns.set`` resets the context to "notebook" unless told otherwise,
        # which silently undid a separate ``set_context("paper")``. Pass both
        # settings together, before the figure is created.
        sns.set(context="paper", font_scale=1.1)

        layers = self.iwp.estimator.steps[-1][1].coefs_
        layer = layers[layer_index]
        f, ax = plt.subplots(figsize=(18, 12))
        weights = pd.DataFrame(layer)
        weights.index = self.iwp.inputs

        # Draw a heatmap with the numeric values in each cell
        sns.heatmap(
            weights, annot=True, fmt=".1f", linewidths=.5, ax=ax,
            cmap="difference", center=0, vmin=vmin, vmax=vmax,
        )
        ax.tick_params(labelsize=18)
        f.tight_layout()
        f.savefig(file)
Example #7
Source File: visFunction.py    From uiKLine with MIT License 6 votes vote down vote up
def plotSigHeats(signals,markets,start=0,step=2,size=1,iters=6):
    """
    Plot a heatmap of backtest profit/loss over a grid of stop parameters,
    to look for a stable parameter region.

    For each ``i, j`` in ``range(iters)`` the signals are backtested with
    ``plotSigCaps(..., climit=start + i*step, wlimit=start + j*step)``. The
    last element of the returned ``caps`` (presumably final capital) is
    stored in column ``i``, row ``j``. Columns are shown on the x axis
    ('Loss Stop @') and rows on the y axis ('Profit Stop @').

    Returns:
        DataFrame with index and columns ``range(iters)`` (rows ``j``,
        columns ``i``).
    """
    sigMat = pd.DataFrame(index=range(iters), columns=range(iters))
    for i in range(iters):
        climit = start + i*step
        for j in range(iters):
            wlimit = start + j*step
            caps, _ = plotSigCaps(signals, markets, climit=climit, wlimit=wlimit, size=size, op=False)
            # One .loc write. The chained ``sigMat[i][j] = ...`` assigns into a
            # temporary under pandas copy-on-write and leaves sigMat unchanged.
            sigMat.loc[j, i] = caps[-1]
    xyLabels = [str(start + k*step) for k in range(iters)]
    # Plot a labelled copy so seaborn places every label on its own row or
    # column. Hand-placed y ticks counted from the bottom while seaborn draws
    # row 0 at the top, which reversed the y labels.
    labelled = pd.DataFrame(sigMat.values.astype(np.float64), index=xyLabels, columns=xyLabels)
    ax = sns.heatmap(labelled, annot=True, fmt='.2f', annot_kws={"weight": "bold"})
    plt.setp(ax.get_yticklabels(), rotation=0)
    plt.setp(ax.get_xticklabels(), rotation=90)
    plt.xlabel('Loss Stop @')
    plt.ylabel('Profit Stop @')
    return sigMat
Example #8
Source File: plots.py    From cgpm with Apache License 2.0 6 votes vote down vote up
def plot_heatmap(
        D, xordering=None, yordering=None, xticklabels=None,
        yticklabels=None, vmin=None, vmax=None, ax=None):
    """Draw matrix ``D`` as a heatmap, optionally permuting rows/columns.

    Columns of ``D`` run along the x axis and rows along the y axis.

    Args:
        D: 2-D array to plot; a copy is made, so the input is not modified.
        xordering: optional index array that permutes the columns.
        yordering: optional index array that permutes the rows.
        xticklabels: column labels; default ``0..D.shape[1]-1``.
        yticklabels: row labels; default ``0..D.shape[0]-1``.
        vmin, vmax: colour-scale limits forwarded to ``seaborn.heatmap``.
        ax: axes to draw on; a new figure is created when ``None``.

    Returns:
        The axes containing the heatmap.
    """
    import seaborn as sns
    D = np.copy(D)

    if ax is None:
        _, ax = plt.subplots()
    # x = columns (shape[1]) and y = rows (shape[0]). These were swapped,
    # which mislabelled or broke non-square inputs.
    if xticklabels is None:
        xticklabels = np.arange(D.shape[1])
    if yticklabels is None:
        yticklabels = np.arange(D.shape[0])
    # asarray so that list labels can also be permuted by an index array.
    if xordering is not None:
        xticklabels = np.asarray(xticklabels)[xordering]
        D = D[:, xordering]
    if yordering is not None:
        yticklabels = np.asarray(yticklabels)[yordering]
        D = D[yordering, :]

    sns.heatmap(
        D, yticklabels=yticklabels, xticklabels=xticklabels,
        linewidths=0.2, cmap='BuGn', ax=ax, vmin=vmin, vmax=vmax)
    ax.set_xticklabels(xticklabels, rotation=90)
    ax.set_yticklabels(yticklabels, rotation=0)
    return ax
Example #9
Source File: test_ext_signature.py    From feets with MIT License 6 votes vote down vote up
def test_plot_SignaturePhMag(fig_test, fig_ref):
    """The Signature extractor's SignaturePhMag plot matches a hand-built one."""
    # Render the extractor's own plot onto the test figure.
    extractor = extractors.Signature()
    params = extractor.get_default_params()
    params.update(
        feature="SignaturePhMag",
        value=[[1, 2, 3, 4]],
        ax=fig_test.subplots(),
        plot_kws={},
        time=[1, 2, 3, 4],
        magnitude=[1, 2, 3, 4],
        error=[1, 2, 3, 4],
        features={"PeriodLS": 1, "Amplitude": 10},
    )
    extractor.plot(**params)

    # Build the expected figure by hand on the reference figure.
    ref_ax = fig_ref.subplots()
    phase_bins = params["phase_bins"]
    mag_bins = params["mag_bins"]
    ref_ax.set_title(f"SignaturePhMag - {phase_bins}x{mag_bins}")
    ref_ax.set_xlabel("Phase")
    ref_ax.set_ylabel("Magnitude")
    sns.heatmap(params["value"], ax=ref_ax, **params["plot_kws"])
Example #10
Source File: spatial_heatmap.py    From NanoPlot with GNU General Public License v3.0 6 votes vote down vote up
def spatial_heatmap(array, path, title=None, color="Greens", figformat="png"):
    """Taking channel information and creating post run channel activity plots.

    Each distinct value in ``array`` (presumably one channel id per read) is
    counted. The count is written into every cell of the channel layout whose
    structure entry equals that value, and the result is drawn as a heatmap.

    Args:
        array: 1-D array of channel values, one per read.
        path: output path without extension; ``"." + figformat`` is appended.
        title: plot title; defaults to "Number of reads generated per channel".
        color: matplotlib colormap name.
        figformat: output format passed to ``Plot.save``.

    Returns:
        A one-element list holding the saved ``Plot``.
    """
    logging.info("Nanoplotter: Creating heatmap of reads per channel using %s reads.",
                 array.size)
    activity_map = Plot(
        path=path + "." + figformat,
        title="Number of reads generated per channel")
    layout = make_layout(maxval=np.amax(array))
    # Series.value_counts replaces the top-level pd.value_counts, which is
    # deprecated since pandas 2.1.
    value_counts = pd.Series(array).value_counts()
    for channel, count in value_counts.items():
        layout.template[np.where(layout.structure == channel)] = count
    plt.figure()
    ax = sns.heatmap(
        data=pd.DataFrame(layout.template, index=layout.yticks, columns=layout.xticks),
        xticklabels="auto",
        yticklabels="auto",
        square=True,
        cbar_kws={"orientation": "horizontal"},
        cmap=color,
        linewidths=0.20)
    ax.set_title(title or activity_map.title)
    activity_map.fig = ax.get_figure()
    activity_map.save(format=figformat)
    plt.close("all")
    return [activity_map]
Example #11
Source File: base_backend.py    From delira with GNU Affero General Public License v3.0 6 votes vote down vote up
def _heatmap(self, plot_kwargs=None, figure_kwargs=None, **kwargs):
        """
        Draw a seaborn heatmap inside ``self.FigureManager`` so it is pushed

        Parameters
        ----------
        plot_kwargs : dict, optional
            keyword arguments forwarded to ``seaborn.heatmap``
        figure_kwargs : dict, optional
            keyword arguments used to create the figure
        **kwargs :
            extra keyword arguments handed (as a dict) to the figure manager
            for pushing the created figure to the logging writer

        """
        figure_kwargs = {} if figure_kwargs is None else figure_kwargs
        plot_kwargs = {} if plot_kwargs is None else plot_kwargs
        manager = self.FigureManager(self._figure, figure_kwargs, kwargs)
        with manager:
            from seaborn import heatmap
            heatmap(**plot_kwargs)
Example #12
Source File: heat_map.py    From NAS-Benchmark with GNU General Public License v3.0 6 votes vote down vote up
def draw(self):
        """Save an annotated heatmap of ``self.df1`` to ``<store_path>/normal_hm.pdf``.

        The x axis is labelled with the operations (excluding "none") and the
        y axis with the possible input indices. The PDF is written at 600 dpi
        with a tight bounding box.
        """
        import os

        _, ax1 = plt.subplots(figsize=(15, 9))
        sns.heatmap(self.df1, annot=True, ax=ax1,
                    annot_kws={'size': 13, 'weight': 'bold'})
        ax1.set_xlabel('Ops without none operation', labelpad=14, fontsize='medium')
        # Label typo fixed ("Possiable" -> "Possible").
        ax1.set_ylabel('Possible Input Index', labelpad=14, fontsize='medium')
        # os.path.join also accepts a pathlib.Path store_path.
        plt.savefig(os.path.join(self.store_path, 'normal_hm.pdf'), bbox_inches='tight', dpi=600)
Example #13
Source File: experiment.py    From axcell with Apache License 2.0 6 votes vote down vote up
def _plot_confusion_matrix(self, cm, normalize, fmt=None):
        """Draw a confusion-matrix heatmap, optionally normalised per row.

        Args:
            cm: 2-D numpy array; rows are plotted as "True" (y axis) and
                columns as "Predicted" (x axis).
            normalize: when true, each row is divided by its sum. All-zero
                rows are divided by 1 so they stay zero.
            fmt: annotation format; defaults to ``"0.2f"`` when normalising
                and ``"d"`` otherwise.
        """
        if normalize:
            totals = cm.sum(axis=1)[:, None]
            totals[totals == 0] = 1  # avoid 0/0 for empty rows
            cm = cm / totals
            fmt = "0.2f" if fmt is None else fmt
        elif fmt is None:
            fmt = "d"

        names = self.get_cm_labels(cm)
        frame = pd.DataFrame(cm, index=list(names), columns=list(names))
        plt.figure(figsize=(10, 10))
        axis = sn.heatmap(frame, annot=True, square=True, fmt=fmt,
                          cmap="YlGnBu", mask=(cm == 0),
                          linecolor="black", linewidths=0.01)
        axis.set_ylabel("True")
        axis.set_xlabel("Predicted")
Example #14
Source File: basenji_sat_h5.py    From basenji with Apache License 2.0 6 votes vote down vote up
def plot_heat(ax, sat_delta_ti, min_limit):
  """ Plot satmut deltas.

    The colour scale is symmetric, [-vlim, vlim], where
    vlim = max(min_limit, largest |delta|) and NaNs are ignored.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_delta_ti (4 x L_sm array): Single target delta matrix for saturated mutagenesis region,
        min_limit (float): Minimum heatmap limit.
    """
  # nanmax ignores NaNs. A plain .max() returns NaN, and max(min_limit, nan)
  # then silently yields min_limit whatever the real data range.
  vlim = max(min_limit, np.nanmax(np.abs(sat_delta_ti)))
  sns.heatmap(
      sat_delta_ti,
      linewidths=0,
      cmap='RdBu_r',
      vmin=-vlim,
      vmax=vlim,
      xticklabels=False,
      ax=ax)
  # One nucleotide label per row, in A, C, G, T order.
  ax.yaxis.set_ticklabels('ACGT', rotation='horizontal')
Example #15
Source File: basenji_motifs_denovo.py    From basenji with Apache License 2.0 6 votes vote down vote up
def plot_kernel(kernel_weights, out_pdf):
    """Save a heatmap of one kernel's weights to ``out_pdf``.

    Each column is centred on its mean before plotting. This is done on a
    copy, so the caller's array is not modified. The diverging colormap is
    centred at zero. Columns are labelled 1..width. Rows are labelled
    A, C, G, T when there are exactly four of them, otherwise 1..depth.

    Args:
        kernel_weights: 2-D array of shape (depth, width).
        out_pdf: output file path.
    """
    depth, width = kernel_weights.shape
    fig_width = 2 + 1.5*np.log2(width)

    # Subtract column means without ``-=``. The in-place version mutated the
    # caller's array and raised a casting error for integer arrays.
    kernel_weights = kernel_weights - kernel_weights.mean(axis=0)

    # plot
    sns.set(font_scale=1.5)
    plt.figure(figsize=(fig_width, depth))
    sns.heatmap(kernel_weights, cmap='PRGn', linewidths=0.2, center=0)
    ax = plt.gca()
    ax.set_xticklabels(range(1, width+1))

    if depth == 4:
        ax.set_yticklabels('ACGT', rotation='horizontal')
    else:
        ax.set_yticklabels(range(1, depth+1), rotation='horizontal')

    plt.savefig(out_pdf)
    plt.close()
Example #16
Source File: functions.py    From Match-LSTM with MIT License 6 votes vote down vote up
def draw_heatmap_sea(x, xlabels, ylabels, answer, save_path, inches=(11, 3), bottom=0.45, linewidths=0.2):
    """
    Draw a matrix as a blue seaborn heatmap and save it to disk.

    :param x: 2-D matrix to draw
    :param xlabels: tick labels for the columns
    :param ylabels: tick labels for the rows
    :param answer: string appended to 'Answer: ' in the title
    :param save_path: file the figure is saved to
    :param inches: figure size in inches, applied after drawing
    :param bottom: bottom margin passed to subplots_adjust
    :param linewidths: width of the lines between cells
    :return: None
    """
    figure, axes = plt.subplots()
    figure.subplots_adjust(bottom=bottom)
    axes.set_title('Answer: ' + answer)
    sns.heatmap(
        x,
        ax=axes,
        cmap='Blues',
        linewidths=linewidths,
        xticklabels=xlabels,
        yticklabels=ylabels,
    )
    figure.set_size_inches(inches)
    figure.savefig(save_path)
Example #17
Source File: QARisk.py    From QUANTAXIS with MIT License 6 votes vote down vote up
def plot_signal(self, start=None, end=None):
        """
        Draw the account's buy/sell signals as a heatmap.

        Trades are indexed by 'datetime' (with 'account_cookie' dropped) and
        sliced to ``start``..``end``. These default to the account's start
        and end dates. Returns the ``matplotlib.pyplot`` module.
        """
        if start is None:
            start = self.account.start_date
        if end is None:
            end = self.account.end_date
        _, ax = plt.subplots(figsize=(20, 18))
        trades = (self.account.trade
                  .reset_index()
                  .drop('account_cookie', axis=1)
                  .set_index('datetime'))
        sns.heatmap(trades.loc[start:end], cmap="YlGnBu", linewidths=0.05, ax=ax)
        ax.set_title('SIGNAL TABLE --ACCOUNT: {}'.format(self.account.account_cookie))
        ax.set_xlabel('Code')
        ax.set_ylabel('DATETIME')
        return plt
Example #18
Source File: QARisk.py    From QUANTAXIS with MIT License 6 votes vote down vote up
def plot_dailyhold(self, start=None, end=None):
        """
        Draw the account's daily holdings as a heatmap.

        Holdings are indexed by 'date' and sliced to ``start``..``end``.
        These default to the account's start and end dates. Returns the
        ``matplotlib.pyplot`` module.
        """
        if start is None:
            start = self.account.start_date
        if end is None:
            end = self.account.end_date
        _, ax = plt.subplots(figsize=(20, 8))
        holdings = self.account.daily_hold.reset_index().set_index('date')
        sns.heatmap(holdings.loc[start:end], cmap="YlGnBu", linewidths=0.05, ax=ax)
        ax.set_title('HOLD TABLE --ACCOUNT: {}'.format(self.account.account_cookie))
        ax.set_xlabel('Code')
        ax.set_ylabel('DATETIME')

        return plt
Example #19
Source File: basenji_sat_plot.py    From basenji with Apache License 2.0 6 votes vote down vote up
def plot_heat(ax, sat_delta_ti, min_limit):
  """Draw saturated-mutagenesis deltas as a symmetric diverging heatmap.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_delta_ti (4 x L_sm array): single-target delta matrix for the
            saturated mutagenesis region.
        min_limit (float): smallest allowed colour limit. The scale spans
            [-vlim, vlim] with vlim = max(min_limit, largest |delta|),
            ignoring NaNs.
    """
  largest = np.nanmax(np.abs(sat_delta_ti))
  vlim = largest if largest > min_limit else min_limit
  heat_kwargs = dict(linewidths=0, cmap='RdBu_r', vmin=-vlim, vmax=vlim,
                     xticklabels=False, ax=ax)
  sns.heatmap(sat_delta_ti, **heat_kwargs)
  ax.yaxis.set_ticklabels('ACGT', rotation='horizontal')
Example #20
Source File: lda_plots.py    From numpy-ml with GNU General Public License v3.0 6 votes vote down vote up
def plot_unsmoothed():
    """Fit LDA on a generated corpus and save heatmaps of ``beta`` and ``gamma``.

    The left panel shows the recovered topic-word matrix and the right panel
    the recovered document-topic matrix. The figure is written to
    ``img/plot_unsmoothed.png`` at 300 dpi and all figures are then closed.
    """
    corpus, T = generate_corpus()
    model = LDA(T)
    model.train(corpus, verbose=False)

    _, (left, right) = plt.subplots(1, 2)
    panels = (
        (left, model.beta, "Words", "Recovered topic-word distribution"),
        (right, model.gamma, "Documents", "Recovered document-topic distribution"),
    )
    for axis, matrix, ylabel, heading in panels:
        drawn = sns.heatmap(matrix, xticklabels=[], yticklabels=[], ax=axis)
        drawn.set_xlabel("Topics")
        drawn.set_ylabel(ylabel)
        drawn.set_title(heading)

    plt.savefig("img/plot_unsmoothed.png", dpi=300)
    plt.close("all")
Example #21
Source File: interpretation.py    From lumin with Apache License 2.0 6 votes vote down vote up
def plot_embedding(embed:OrderedDict, feat:str, savename:Optional[str]=None, settings:Optional[PlotSettings]=None) -> None:
    r'''
    Visualise weights in provided categorical entity-embedding matrix

    Arguments:
        embed: state_dict of trained nn.Embedding
        feat: name of feature embedded
        savename: Optional name of file to which to save the plot of the embedding weights
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance;
            a fresh default ``PlotSettings()`` is used when ``None``
    '''

    # Create the default per call. A ``PlotSettings()`` default argument is a
    # single instance built at import time and shared by every call.
    if settings is None: settings = PlotSettings()
    with sns.axes_style(**settings.style):
        plt.figure(figsize=(settings.w_small, settings.h_small))
        sns.heatmap(to_np(embed['weight']), annot=True, fmt='.1f', linewidths=.5, cmap=settings.div_palette, annot_kws={'fontsize':settings.leg_sz})
        plt.xlabel("Embedding", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel(feat, fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None: plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()
Example #22
Source File: akita_sat_vcf.py    From basenji with Apache License 2.0 6 votes vote down vote up
def plot_heat(ax, sat_score_ti, min_limit=None):
  """ Plot satmut scores.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_score_ti (L_sm x 4 array): Single target score matrix for saturated mutagenesis region;
            it is plotted transposed.
        min_limit (float, optional): if the largest score is below this value, it becomes the
            colour-scale maximum; otherwise seaborn chooses the maximum.
    """

  # Guard the default None: comparing a number with None raises TypeError.
  # nanmax keeps NaN entries from hiding the real maximum.
  if min_limit is not None and np.nanmax(sat_score_ti) < min_limit:
    vmax = min_limit
  else:
    vmax = None

  sns.heatmap(
      sat_score_ti.T,
      linewidths=0,
      xticklabels=False,
      yticklabels=False,
      cmap='Blues',
      vmax=vmax,
      ax=ax)

  # yticklabels break the plot for some reason
  # ax.yaxis.set_ticklabels('ACGT', rotation='horizontal')
Example #23
Source File: stock_visualizer.py    From stock-analysis with MIT License 6 votes vote down vote up
def heatmap(self, pct_change=False, **kwargs):
        """
        Generate a seaborn heatmap for correlations between assets.

        Parameters:
            - pct_change: Whether or not to show the correlations of the
                          daily percent change in price or just use
                          the closing price.
            - kwargs: Keyword arguments to pass down to `sns.heatmap()`.
                      `annot` (default True) and `center` (default 0) can
                      be overridden here.

        Returns:
            A seaborn heatmap
        """
        pivot = self.data.pivot_table(
            values='close', index=self.data.index, columns='name'
        )
        if pct_change:
            pivot = pivot.pct_change()
        # Defaults rather than fixed keywords, so passing annot/center in
        # kwargs no longer raises "multiple values for keyword argument".
        kwargs.setdefault('annot', True)
        kwargs.setdefault('center', 0)
        return sns.heatmap(pivot.corr(), **kwargs)
Example #24
Source File: plot.py    From retentioneering-tools with Mozilla Public License 2.0 6 votes vote down vote up
def altair_step_matrix(diff, plot_name=None, title='', vmin=None, vmax=None, font_size=12, **kwargs):
    """Build an Altair heatmap of ``diff`` with each cell's value printed on it.

    Cell text is black where ``|value| < 0.8`` and white otherwise. Chart
    width and height scale with ``font_size`` and the number of columns and
    rows. ``title``, ``vmin``, ``vmax`` and ``**kwargs`` are accepted but
    not used.

    Returns:
        Tuple ``(chart, plot_name, None, diff.retention.retention_config)``.
    """
    long_form = diff.reset_index().melt('index')
    long_form.columns = ['y', 'x', 'z']
    base = alt.Chart(long_form).encode(
        x=alt.X('x:O', sort=None),
        y=alt.Y('y:O', sort=None),
    )
    cells = base.mark_rect().encode(
        color=alt.Color('z:Q', scale=alt.Scale(scheme='blues')),
    )
    text_color = alt.condition(
        abs(alt.datum.z) < 0.8,
        alt.value('black'),
        alt.value('white'),
    )
    labels = base.mark_text(align='center', fontSize=font_size).encode(
        text='z',
        color=text_color,
    )
    chart = (cells + labels).properties(
        width=3 * font_size * len(diff.columns),
        height=2 * font_size * diff.shape[0],
    )
    return chart, plot_name, None, diff.retention.retention_config
Example #25
Source File: Auto_NLP.py    From Auto_ViML with Apache License 2.0 5 votes vote down vote up
def plot_confusion_matrix(y_test,y_pred, model_name='Model'):
    """
    This plots a beautiful confusion matrix based on input: ground truths and predictions

    Rows are true labels and columns predicted labels. Together they cover
    every label that occurs in either ``y_test`` or ``y_pred``. The title
    shows ``model_name`` with micro- and macro-averaged F1 scores.

    Side effects: sets the seaborn style to 'darkgrid', injects CSS that
    widens the notebook container, and sets
    ``pd.options.display.float_format`` globally.
    """
    # Confusion matrix plotting
    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set_style('darkgrid')

    # Notebook display tweaks. ``IPython.display`` is the supported location;
    # importing ``display`` from ``IPython.core.display`` is deprecated.
    from IPython.display import display, HTML
    display(HTML("<style>.container { width:95% !important; }</style>"))
    pd.options.display.float_format = '{:,.2f}'.format

    # Get the confusion matrix and put it into a df
    from sklearn.metrics import confusion_matrix, f1_score

    # Use the sorted union of true and predicted labels, which is also
    # confusion_matrix's default. Labelling with np.unique(y_test) alone
    # raised a shape error whenever y_pred had a class absent from y_test.
    labels = np.union1d(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred, labels=labels)

    cm_df = pd.DataFrame(cm,
                         index=labels.tolist(),
                         columns=labels.tolist(),
                        )

    # Plot the heatmap
    plt.figure(figsize=(12, 8))

    sns.heatmap(cm_df,
                center=0,
                cmap=sns.diverging_palette(220, 15, as_cmap=True),
                annot=True,
                fmt='g')

    plt.title(' %s \nF1 Score(avg = micro): %0.2f \nF1 Score(avg = macro): %0.2f' %(
        model_name,f1_score(y_test, y_pred, average='micro'),f1_score(y_test, y_pred, average='macro')),
              fontsize = 13)
    plt.ylabel('True label', fontsize = 13)
    plt.xlabel('Predicted label', fontsize = 13)
    plt.show()
############################################################################################## 
Example #26
Source File: Auto_NLP.py    From Auto_ViML with Apache License 2.0 5 votes vote down vote up
def plot_confusion_matrix(y_test,y_pred, model_name='Model'):
    """
    This plots a beautiful confusion matrix based on input: ground truths and predictions

    Rows are true labels and columns predicted labels. Together they cover
    every label that occurs in either ``y_test`` or ``y_pred``. The title
    shows ``model_name`` with micro- and macro-averaged F1 scores.

    Side effects: sets the seaborn style to 'darkgrid', injects CSS that
    widens the notebook container, and sets
    ``pd.options.display.float_format`` globally.
    """
    # Confusion matrix plotting
    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set_style('darkgrid')

    # Notebook display tweaks. ``IPython.display`` is the supported location;
    # importing ``display`` from ``IPython.core.display`` is deprecated.
    from IPython.display import display, HTML
    display(HTML("<style>.container { width:95% !important; }</style>"))
    pd.options.display.float_format = '{:,.2f}'.format

    # Get the confusion matrix and put it into a df
    from sklearn.metrics import confusion_matrix, f1_score

    # Use the sorted union of true and predicted labels, which is also
    # confusion_matrix's default. Labelling with np.unique(y_test) alone
    # raised a shape error whenever y_pred had a class absent from y_test.
    labels = np.union1d(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred, labels=labels)

    cm_df = pd.DataFrame(cm,
                         index=labels.tolist(),
                         columns=labels.tolist(),
                        )

    # Plot the heatmap
    plt.figure(figsize=(12, 8))

    sns.heatmap(cm_df,
                center=0,
                cmap=sns.diverging_palette(220, 15, as_cmap=True),
                annot=True,
                fmt='g')

    plt.title(' %s \nF1 Score(avg = micro): %0.2f \nF1 Score(avg = macro): %0.2f' %(
        model_name,f1_score(y_test, y_pred, average='micro'),f1_score(y_test, y_pred, average='macro')),
              fontsize = 13)
    plt.ylabel('True label', fontsize = 13)
    plt.xlabel('Predicted label', fontsize = 13)
    plt.show()
############################################################################################## 
Example #27
Source File: multiple_linear_regression.py    From deep-learning-samples with The Unlicense 5 votes vote down vote up
def plot_correlation_heatmap(X, header):
    """Show an annotated, square heatmap of the correlations between X's columns.

    Each column of ``X`` is one variable. ``header`` labels both axes and a
    colour bar is drawn. This requires the seaborn package to be installed.
    """
    import seaborn
    correlations = np.corrcoef(X, rowvar=False)
    seaborn.heatmap(
        correlations,
        annot=True,
        cbar=True,
        square=True,
        xticklabels=header,
        yticklabels=header,
    )
    plt.show()
Example #28
Source File: visuals.py    From B-SOID with GNU General Public License v3.0 5 votes vote down vote up
def plot_tmat(tm: object):
    """
    Plot an annotated heatmap of a behavior transition matrix and show it.

    :param tm: object, transition matrix data frame; rows are plotted as the
        current-frame behavior (y axis) and columns as the next-frame
        behavior (x axis)
    :return: the created matplotlib figure
    """
    fig = plt.figure()
    fig.suptitle("Transition matrix of {} behaviors".format(tm.shape[0]))
    sn.heatmap(tm, annot=True)
    plt.xlabel("Next frame behavior")
    plt.ylabel("Current frame behavior")
    plt.show()
    return fig
Example #29
Source File: visuals.py    From B-SOID with GNU General Public License v3.0 5 votes vote down vote up
def plot_tmat(tm: object):
    """
    Plot an annotated heatmap of a behavior transition matrix.

    The figure is not shown here; display or save it from the caller.

    :param tm: object, transition matrix data frame; rows are plotted as the
        current-frame behavior (y axis) and columns as the next-frame
        behavior (x axis)
    :return: the created matplotlib figure
    """
    fig = plt.figure()
    fig.suptitle("Transition matrix of {} behaviors".format(tm.shape[0]))
    sn.heatmap(tm, annot=True)
    plt.xlabel("Next frame behavior")
    plt.ylabel("Current frame behavior")
    return fig
Example #30
Source File: viz.py    From focus with GNU General Public License v3.0 5 votes vote down vote up
def heatmap(wcor):
    """
    Render the lower triangle of a correlation matrix as a rotated heatmap image.

    The upper triangle, including the diagonal, is masked out. The rest is
    drawn on the ``RdBu_r`` colormap fixed to [-1, 1] and rasterised. The
    image is then rotated 45 degrees so the triangle points downward, and
    cropped vertically.

    :param wcor: numpy.ndarray matrix of sample correlation structure for predicted expression

    :return: numpy.ndarray (RGB, uint8) image of the rotated, cropped heatmap
    """
    # Set the size on this figure only, instead of overwriting the global
    # ``figure.figsize`` rcParam for every later figure.
    fig = plt.figure(figsize=(6.4, 6.4))
    fig.subplots_adjust(bottom=0.20, left=0.28)
    # Builtin bool: the ``np.bool`` alias was removed in NumPy 1.24.
    mask = np.zeros_like(wcor, dtype=bool)
    mask[np.triu_indices_from(mask)] = True
    ax = sns.heatmap(wcor, mask=mask, cmap="RdBu_r", square=True,
                     linewidths=0, cbar=False, xticklabels=False, yticklabels=False, ax=None,
                     vmin=-1, vmax=1)
    ax.margins(2)
    ax.set_aspect("equal", "box")
    fig.canvas.draw()

    # Save the image as a numpy array. canvas.tostring_rgb() is deprecated
    # since Matplotlib 3.8 and np.fromstring is deprecated, so use
    # buffer_rgba() and drop alpha. Its shape comes from the rendered buffer,
    # so no separate reshape is needed.
    img = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
    # The figure is not returned, so close it to avoid accumulating figures.
    plt.close(fig)

    # rotate heatmap to make upside-down triangle shape
    rows, cols, _ = img.shape
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 45, 1)
    dst = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_CONSTANT,
                         borderValue=(255, 255, 255))

    # trim extra whitespace
    crop_img = dst[int(dst.shape[0] / 2.5):int(dst.shape[0] / 1.1)]

    return crop_img