Python matplotlib.pyplot.setp() Examples

The following are 30 code examples showing how to use matplotlib.pyplot.setp(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.

Example 1
Project: TaskBot   Author: EvilPsyCHo   File: plot.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plot_attention(sentences, attentions, labels, **kwargs):
    """Draw an attention matrix as an annotated heat map and display it.

    Args:
        sentences: Indexable as ``sentences[i][j]``; the text written into
            the cell at row ``i``, column ``j``.
        attentions: 2-D array-like exposing ``min()``, ``max()`` and ``shape``.
        labels: Labels for the y axis, one per tick.
        **kwargs: Passed through unchanged to ``plt.subplots``.
    """
    fig, ax = plt.subplots(**kwargs)
    heatmap = ax.imshow(
        attentions,
        interpolation='nearest',
        vmin=attentions.min(),
        vmax=attentions.max(),
    )
    plt.colorbar(heatmap, shrink=0.5, ticks=[0, 1])

    # Tilt the x tick labels so long entries do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontproperties=getChineseFont())

    # Write the corresponding token into every cell of the matrix.
    for row in range(attentions.shape[0]):
        for col in range(attentions.shape[1]):
            ax.text(col, row, sentences[row][col],
                    ha="center", va="center", color="b", size=10,
                    fontproperties=getChineseFont())

    ax.set_title("Attention Visual")
    fig.tight_layout()
    plt.show()
Example 2
Project: tensortrade   Author: tensortrade-org   File: matplotlib_trading_chart.py    License: Apache License 2.0 6 votes vote down vote up
def render(self, current_step, net_worths, benchmarks, trades, window_size=50):
        """Redraw the trading chart for the window ending at ``current_step``.

        Updates the figure title with the latest net worth and profit
        percentage, delegates the panels to the ``_render_*`` helpers, then
        pauses briefly so the frame is displayed.

        Args:
            current_step: Index of the current step; the window ends just
                before it.
            net_worths: Sequence of net-worth values; the first and last
                entries are used for the title.
            benchmarks: Forwarded to ``_render_net_worth``.
            trades: Forwarded to ``_render_trades``.
            window_size: Maximum number of steps shown.
        """
        first_worth = round(net_worths[0], 2)
        latest_worth = round(net_worths[-1], 2)
        profit_percent = round((latest_worth - first_worth) / first_worth * 100, 2)

        title = ('Net worth: $' + str(latest_worth) +
                 ' | Profit: ' + str(profit_percent) + '%')
        self.fig.suptitle(title)

        # Clamp the left edge of the window at the start of the data.
        first_visible = max(current_step - window_size, 0)
        step_range = slice(first_visible, current_step)
        times = self.df.index.values[step_range]

        self._render_net_worth(step_range, times, current_step, net_worths, benchmarks)
        self._render_price(step_range, times, current_step)
        self._render_volume(step_range, times)
        self._render_trades(step_range, trades)

        self.price_ax.set_xticklabels(times, rotation=45, horizontalalignment='right')

        # The net-worth panel would repeat the same date labels; hide them.
        duplicate_labels = self.net_worth_ax.get_xticklabels()
        plt.setp(duplicate_labels, visible=False)

        # Give the GUI a moment to draw this frame before the next redraw.
        plt.pause(0.001)
Example 3
Project: uiKLine   Author: rjj510   File: visFunction.py    License: MIT License 6 votes vote down vote up
def plotSigHeats(signals,markets,start=0,step=2,size=1,iters=6):
    """
    Plot a heat map of backtest results over a grid of stop parameters,
    to help spot stable regions ("islands") of the parameter space.

    For every pair ``(i, j)`` in ``range(iters)`` the backtest helper
    ``plotSigCaps`` is run with ``climit = start + i*step`` and
    ``wlimit = start + j*step``; the last element of the returned ``caps``
    is stored at row ``j``, column ``i`` of the result.

    Args:
        signals: Forwarded to ``plotSigCaps``.
        markets: Forwarded to ``plotSigCaps``.
        start: First value of both swept parameters.
        step: Increment between consecutive parameter values.
        size: Forwarded to ``plotSigCaps``.
        iters: Number of values swept per parameter (the grid is
            ``iters x iters``).

    Returns:
        pandas.DataFrame: the ``iters x iters`` grid of final ``caps`` values.
    """
    sigMat = pd.DataFrame(index=range(iters),columns=range(iters))
    for i in range(iters):
        climit = start + i*step
        for j in range(iters):
            wlimit = start + j*step
            caps,_ = plotSigCaps(signals,markets,climit=climit,wlimit=wlimit,size=size,op=False)
            # Single-step label assignment (row j, column i).  The former
            # chained form ``sigMat[i][j] = ...`` writes through a temporary
            # column object and is silently lost under pandas Copy-on-Write.
            sigMat.at[j, i] = caps[-1]
    sns.heatmap(sigMat.values.astype(np.float64),annot=True,fmt='.2f',annot_kws={"weight": "bold"})
    # Ticks sit at cell centres; the y positions are listed in reverse order.
    xTicks   = [i+0.5 for i in range(iters)]
    yTicks   = [iters-i-0.5 for i in range(iters)]
    xyLabels = [str(start+i*step) for i in range(iters)]
    _, labels = plt.yticks(yTicks,xyLabels)
    plt.setp(labels, rotation=0)
    _, labels = plt.xticks(xTicks,xyLabels)
    plt.setp(labels, rotation=90)
    plt.xlabel('Loss Stop @')
    plt.ylabel('Profit Stop @')
    return sigMat
Example 4
Project: dgl   Author: dmlc   File: viz.py    License: Apache License 2.0 6 votes vote down vote up
def draw_heatmap(array, input_seq, output_seq, dirname, name):
    """Save a 2x4 grid of per-head heat maps to ``log/<dirname>/<name>.pdf``.

    Args:
        array: Indexable collection of at least 8 matrices; ``array[k]`` is
            drawn (with its last two axes swapped) in panel ``k``.
        input_seq: Labels for the y axis of every panel.
        output_seq: Labels for the x axis of every panel.
        dirname: Sub-directory of ``log`` to write into; created if missing.
        name: Figure title and base name of the PDF file.
    """
    dirname = os.path.join('log', dirname)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(dirname, exist_ok=True)

    fig, axes = plt.subplots(2, 4)
    # ``axes.flat`` walks the grid row by row, matching the head numbering.
    for cnt, ax in enumerate(axes.flat):
        ax.imshow(array[cnt].transpose(-1, -2))
        ax.set_yticks(np.arange(len(input_seq)))
        ax.set_xticks(np.arange(len(output_seq)))
        ax.set_yticklabels(input_seq, fontsize=4)
        ax.set_xticklabels(output_seq, fontsize=4)
        ax.set_title('head_{}'.format(cnt), fontsize=10)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

    fig.suptitle(name, fontsize=12)
    fig.tight_layout()
    fig.savefig(os.path.join(dirname, '{}.pdf'.format(name)))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
Example 5
Project: jMetalPy   Author: jMetal   File: chord_plot.py    License: MIT License 6 votes vote down vote up
def hover_over_bin(event, handle_tickers, handle_plots, colors, fig):
    """Highlight the first bin under the mouse event and reset all the others.

    Bins are scanned in order. The first one whose ``contains(event)``
    reports a hit is filled with its object's colour and its associated plot
    handles are shown; every other bin is filled white and its handles are
    hidden. Once a hit has been found, later bins are not hit-tested.

    Args:
        event: Event object passed to each bin's ``contains``.
        handle_tickers: Nested sequence ``[object][bin]`` of bin artists.
        handle_plots: Nested sequence ``[object][bin]`` of iterables of
            artists to show or hide.
        colors: Sequence of face colours, indexed by object.
        fig: Figure whose canvas is asked to redraw.
    """
    hit_claimed = False

    for obj_idx, tickers in enumerate(handle_tickers):
        for bin_idx, ticker in enumerate(tickers):
            hovered = False
            if not hit_claimed:
                hovered, _ = ticker.contains(event)
                if hovered:
                    hit_claimed = True

            if hovered:
                face = colors[obj_idx]
                show = True
            else:
                face = (1, 1, 1)
                show = False

            plt.setp(ticker, facecolor=face)
            for handle in handle_plots[obj_idx][bin_idx]:
                handle.set_visible(show)
            fig.canvas.draw_idle()
Example 6
Project: Computable   Author: ktraunmueller   File: plotting.py    License: MIT License 6 votes vote down vote up
def _label_axis(ax, kind='x', label='', position='top',
    ticks=True, rotate=False):
    """Make one axis of ``ax`` visible, label it and place it at ``position``.

    Args:
        ax: Axes object to modify.
        kind: ``'x'`` or ``'y'``; any other value leaves ``ax`` untouched.
        label: Axis label text.
        position: Passed to both ``set_ticks_position`` and
            ``set_label_position`` of the chosen axis.
        ticks: Accepted but not used by this function.
        rotate: If true and ``kind == 'x'``, rotate x tick labels by 90
            degrees. Has no effect for the y axis.

    Returns:
        None
    """
    from matplotlib.artist import setp

    if kind == 'y':
        ax.yaxis.set_visible(True)
        ax.set_ylabel(label, visible=True)
        ax.yaxis.set_ticks_position(position)
        ax.yaxis.set_label_position(position)
        return

    if kind != 'x':
        return

    ax.set_xlabel(label, visible=True)
    x_axis = ax.xaxis
    x_axis.set_visible(True)
    x_axis.set_ticks_position(position)
    x_axis.set_label_position(position)
    if rotate:
        setp(ax.get_xticklabels(), rotation=90)
Example 7
Project: atis   Author: lil-lab   File: visualize_attention.py    License: MIT License 6 votes vote down vote up
def render(self, filename):
        """
        Renders the attention graph over timesteps.

        The rows of the image are the entries of ``self.attentions`` stacked
        in order; columns are labelled with ``self.keys`` (along the top) and
        rows with ``self.generated_values``.

        Args:
          filename (string): filename to save the figure to.
        """
        figure, axes = plt.subplots()
        graph = np.stack(self.attentions)

        axes.imshow(graph, cmap=plt.cm.Blues, interpolation="nearest")
        axes.xaxis.tick_top()
        axes.set_xticks(range(len(self.keys)))
        axes.set_xticklabels(self.keys)
        plt.setp(axes.get_xticklabels(), rotation=90)
        axes.set_yticks(range(len(self.generated_values)))
        axes.set_yticklabels(self.generated_values)
        axes.set_aspect(1, adjustable='box')
        # Hide the tick marks (labels stay).  Booleans are required here: the
        # legacy 'off' strings are non-empty and therefore truthy, which would
        # switch the tick marks on instead of off.
        axes.tick_params(axis='x', which='both', bottom=False, top=False)
        axes.tick_params(axis='y', which='both', left=False, right=False)

        figure.savefig(filename)
Example 8
Project: DevilYuan   Author: moyuanz   File: DyStockDataViewer.py    License: MIT License 6 votes vote down vote up
def plotTimeShareChart(self, code, date, n):
        """Show the time-share chart of ``code`` for the day ``n`` offsets from ``date``.

        The target day is resolved through
        ``self._daysEngine.codeTDayOffset``; if that returns ``None`` nothing
        is drawn.

        Args:
            code: Stock code to plot.
            date: Reference date.
            n: Offset handed to ``codeTDayOffset``.
        """
        date = self._daysEngine.codeTDayOffset(code, date, n)
        if date is None:
            return

        DyMatplotlib.newFig()

        # Stock time-share chart, filling almost the whole figure.
        self._plotTimeShareChart(code, date, left=0.05, right=0.95, top=0.95, bottom=0.05)

        # Layout: hide the x tick labels on every other axes of the figure.
        figure = plt.gcf()
        hidden_labels = [axes.get_xticklabels() for axes in figure.axes[::2]]
        plt.setp(hidden_labels, visible=False)
        figure.show()
Example 9
Project: DevilYuan   Author: moyuanz   File: DyStockDataViewer.py    License: MIT License 6 votes vote down vote up
def _plotAckRWExtremas(self, event):
        """Plot a candlestick chart with regional local extrema overlaid.

        Args:
            event: Object whose ``data`` mapping provides ``'code'``,
                ``'df'`` (its index supplies the date range) and
                ``'regionalLocals'``.
        """
        payload = event.data
        code = payload['code']
        df = payload['df']
        regionalLocals = payload['regionalLocals']

        DyMatplotlib.newFig()
        figure = plt.gcf()

        # The plotted range spans the first to the last entry of the index.
        index = df.index
        day_format = '%Y-%m-%d'
        startDay = index[0].strftime(day_format)
        endDay = index[-1].strftime(day_format)

        # Candlestick chart of the stock in the upper half of the figure.
        self._plotCandleStick(code, startDate=startDay, endDate=endDay, baseDate=endDay, left=0.05, right=0.95, top=0.95, bottom=0.5, maIndicator='close')

        self._plotRegionalLocals(figure.axes[0], index, regionalLocals)

        # Only the last axes keeps its x tick labels.
        upper_labels = [axes.get_xticklabels() for axes in figure.axes[:-1]]
        plt.setp(upper_labels, visible=False)
        figure.show()
Example 10
Project: DevilYuan   Author: moyuanz   File: DyStockDataViewer.py    License: MIT License 6 votes vote down vote up
def _plotAckHSARs(self, event):
        """Plot a candlestick chart with the supplied ``hsars`` overlaid.

        Args:
            event: Object whose ``data`` mapping provides ``'code'``,
                ``'df'`` (its index supplies the date range) and ``'hsars'``.
        """
        payload = event.data
        code = payload['code']
        df = payload['df']
        hsars = payload['hsars']

        DyMatplotlib.newFig()
        figure = plt.gcf()

        # The plotted range spans the first to the last entry of the index.
        index = df.index
        day_format = '%Y-%m-%d'
        startDay = index[0].strftime(day_format)
        endDay = index[-1].strftime(day_format)

        # Candlestick chart of the stock in the upper half of the figure.
        self._plotCandleStick(code, startDate=startDay, endDate=endDay, baseDate=endDay, left=0.05, right=0.95, top=0.95, bottom=0.5, maIndicator='close')

        self._plotHSARs(figure.axes[0], hsars)

        # Only the last axes keeps its x tick labels.
        upper_labels = [axes.get_xticklabels() for axes in figure.axes[:-1]]
        plt.setp(upper_labels, visible=False)
        figure.show()
Example 11
Project: DevilYuan   Author: moyuanz   File: DyStockDataViewer.py    License: MIT License 6 votes vote down vote up
def plotAckKama(self, event):
        """Plot a fixed stock's K-chart above its KAMA-based K-chart.

        The stock code and date range are hard-coded. Nothing is drawn if the
        days engine fails to load the data.

        Args:
            event: Not used by this method.
        """
        code = '002551.SZ'
        startDate = '2015-07-01'
        endDate = '2016-03-01'

        # Load the data; -200 is passed as the first element of the range
        # (NOTE(review): presumably a look-back offset - confirm).
        loaded = self._daysEngine.load([-200, startDate, endDate], codes=[code])
        if not loaded:
            return

        DyMatplotlib.newFig()

        # Basic K-chart in the upper half of the figure.
        periods = self._plotCandleStick(code, startDate=startDate, endDate=endDate, netCapitalFlow=True, left=0.05, right=0.95, top=0.95, bottom=0.5)

        # Customised (KAMA) K-chart in the lower half, over the same periods.
        self._plotKamaCandleStick(code, periods=periods, left=0.05, right=0.95, top=0.45, bottom=0.05)

        # Layout: only the last axes keeps its x tick labels.
        figure = plt.gcf()
        upper_labels = [axes.get_xticklabels() for axes in figure.axes[:-1]]
        plt.setp(upper_labels, visible=False)
        figure.show()
Example 12
Project: SpectralMachine   Author: feranick   File: SpectraKeras_MLP.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plotWeights(En, A, model):
    """Plot the first weight column of each layer above the sample spectrum.

    Each layer whose weights can be plotted gets its own subplot (stacked
    vertically, starting at subplot index 511); layers that cannot be plotted
    are skipped. The last subplot shows ``A[0]`` against ``En``. The figure
    is written to ``keras_MLP_weights.png``.

    Args:
        En: 1-D sequence of x-axis values (labelled 'Raman shift [1/cm]').
        A: Indexable collection of spectra; only ``A[0]`` is drawn.
        model: Object exposing ``layers``; each layer is queried with
            ``get_weights()`` and ``get_config()['name']``.
    """
    import matplotlib.pyplot as plt
    plt.figure(tight_layout=True)
    plotInd = 511
    for layer in model.layers:
        try:
            w_layer = layer.get_weights()[0]
            ax = plt.subplot(plotInd)
            # One x position per weight row, spread over the range of En.
            # linspace guarantees exactly shape[0] points; a float-step
            # arange can return one extra element, which made np.interp
            # raise and the layer silently disappear from the plot.
            newX = np.linspace(En[0], En[-1], w_layer.shape[0], endpoint=False)
            plt.plot(En, np.interp(En, newX, w_layer[:,0]), label=layer.get_config()['name'])
            plt.legend(loc='upper right')
            plt.setp(ax.get_xticklabels(), visible=False)
            plotInd +=1
        except Exception:
            # Best effort: skip layers without (plottable) weights, but let
            # KeyboardInterrupt / SystemExit propagate.
            continue

    ax1 = plt.subplot(plotInd)
    ax1.plot(En, A[0], label='Sample data')

    plt.xlabel('Raman shift [1/cm]')
    plt.legend(loc='upper right')
    plt.savefig('keras_MLP_weights' + '.png', dpi = 160, format = 'png')  # Save plot

#************************************ 
Example 13
Project: SpectralMachine   Author: feranick   File: SpectraKeras_CNN.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plotWeights(En, A, model):
    """Plot the first weight column of each layer above the sample spectrum.

    Each layer whose weights can be plotted gets its own subplot (stacked
    vertically, starting at subplot index 511); layers that cannot be plotted
    are skipped. The last subplot shows ``A[0]`` against ``En``. The figure
    is written to ``keras_MLP_weights.png``.

    Args:
        En: 1-D sequence of x-axis values (labelled 'Raman shift [1/cm]').
        A: Indexable collection of spectra; only ``A[0]`` is drawn.
        model: Object exposing ``layers``; each layer is queried with
            ``get_weights()`` and ``get_config()['name']``.
    """
    import matplotlib.pyplot as plt
    plt.figure(tight_layout=True)
    plotInd = 511
    for layer in model.layers:
        try:
            w_layer = layer.get_weights()[0]
            ax = plt.subplot(plotInd)
            # One x position per weight row, spread over the range of En.
            # linspace guarantees exactly shape[0] points; a float-step
            # arange can return one extra element, which made np.interp
            # raise and the layer silently disappear from the plot.
            newX = np.linspace(En[0], En[-1], w_layer.shape[0], endpoint=False)
            plt.plot(En, np.interp(En, newX, w_layer[:,0]), label=layer.get_config()['name'])
            plt.legend(loc='upper right')
            plt.setp(ax.get_xticklabels(), visible=False)
            plotInd +=1
        except Exception:
            # Best effort: skip layers without (plottable) weights, but let
            # KeyboardInterrupt / SystemExit propagate.
            continue

    ax1 = plt.subplot(plotInd)
    ax1.plot(En, A[0], label='Sample data')

    plt.xlabel('Raman shift [1/cm]')
    plt.legend(loc='upper right')
    # NOTE(review): the output name says "MLP" although this copy is listed
    # under SpectraKeras_CNN.py - confirm whether that is intended.
    plt.savefig('keras_MLP_weights' + '.png', dpi = 160, format = 'png')  # Save plot

#************************************ 
Example 14
Project: seq2seq-summarizer   Author: ymfa   File: utils.py    License: MIT License 6 votes vote down vote up
def show_attention_map(src_words, pred_words, attention, pointer_ratio=None):
  """Draw an attention map with source words on top and predictions on the left.

  The attention matrix is flipped vertically before drawing and the
  prediction labels are reversed to match. The figure is neither shown nor
  saved here.

  Args:
    src_words: Labels for the x axis.
    pred_words: Labels for the y axis.
    attention: 2-D array of attention weights.
    pointer_ratio: Optional sequence of values printed (to 3 decimals) on a
      secondary y axis labelled 'Copy probability'.
  """
  fig, ax = plt.subplots(figsize=(16, 4))
  plt.pcolormesh(np.flipud(attention), cmap="GnBu")

  # Ticks sit at cell centres, hence the 0.5 offset.
  ax.set_xticks(np.arange(len(src_words)) + 0.5)
  ax.set_xticklabels(src_words, fontsize=14)
  ax.set_yticks(np.arange(len(pred_words)) + 0.5)
  ax.set_yticklabels(reversed(pred_words), fontsize=14)

  if pointer_ratio is not None:
    n_pred = len(pred_words)
    twin = ax.twinx()
    tick_positions = np.concatenate([np.arange(0.5, n_pred), [n_pred]])
    twin.set_yticks(tick_positions)
    twin.set_yticklabels('%.3f' % v for v in np.flipud(pointer_ratio))
    twin.set_ylabel('Copy probability', rotation=-90, va="bottom")

  # Move the horizontal axis labelling to the top of the plot.
  ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
  # Rotate the x tick labels and anchor them at their right end.
  plt.setp(ax.get_xticklabels(), rotation=-45, ha="right", rotation_mode="anchor")
Example 15
Project: FRETBursts   Author: tritemio   File: burst_plot.py    License: GNU General Public License v2.0 6 votes vote down vote up
def _alex_plot_style(g, colorbar=True):
    """Apply the ALEX joint-plot styling to ``g`` and optionally add a colorbar.

    Args:
        g: Joint-grid-like object exposing ``ax_joint``, ``ax_marg_x``,
            ``ax_marg_y`` and ``set_axis_labels``.
        colorbar: If true, add a colorbar in a new axes whose vertical extent
            matches the joint axes.
    """
    top_marginal = g.ax_marg_x
    side_marginal = g.ax_marg_y

    g.set_axis_labels(xlabel="E", ylabel="S")
    for marginal in (top_marginal, side_marginal):
        marginal.grid(True)
    top_marginal.set_xlabel('')
    side_marginal.set_ylabel('')

    # Make the count-axis tick labels of both marginals visible.
    plt.setp(side_marginal.get_xticklabels(), visible=True)
    plt.setp(top_marginal.get_yticklabels(), visible=True)
    top_marginal.locator_params(axis='y', tight=True, nbins=3)
    side_marginal.locator_params(axis='x', tight=True, nbins=3)

    if not colorbar:
        return

    # Corner points of the joint axes: (x0, y0) and (x1, y1).
    (x0, y0), (x1, y1) = g.ax_joint.get_position().get_points()
    cax = plt.axes([1., y0, (x1 - x0) * 0.045, y1 - y0])
    plt.colorbar(cax=cax)
Example 16
Project: AE_ts   Author: RobRomijnders   File: AE_ts_model.py    License: MIT License 6 votes vote down vote up
def plot_data(X_train, y_train, plot_row=5):
    """Show a grid of randomly chosen samples, one column per class.

    For every class, ``plot_row`` samples are drawn at random (with
    ``np.random.choice``) and plotted top to bottom in that class's column.

    Args:
        X_train: 2-D array; row ``k`` is the series for sample ``k``.
        y_train: 1-D array of class labels.
        plot_row: Number of samples (rows) plotted per class.

    Returns:
        None
    """
    counts = dict(Counter(y_train))
    classes = np.unique(y_train)
    f, axarr = plt.subplots(plot_row, len(classes))
    last_row = plot_row - 1

    for label in classes:  # one column per class
        # The class value doubles as the column index
        # (NOTE(review): assumes labels are 0..num_classes-1 - confirm).
        c = int(label)
        members = np.where(y_train == c)[0]
        picks = np.random.choice(members, size=plot_row)
        for n, sample in enumerate(picks):  # one row per drawn sample
            cell = axarr[n, c]
            cell.plot(X_train[sample, :])
            # Only the bottom row and the left column keep their tick labels.
            if n == 0:
                cell.set_title('Class %.0f (%.0f)' % (c, counts[float(c)]))
            if n != last_row:
                plt.setp([cell.get_xticklabels()], visible=False)
            if c != 0:
                plt.setp([cell.get_yticklabels()], visible=False)

    # No vertical (hspace) or horizontal (wspace) gap between subplots.
    f.subplots_adjust(hspace=0)
    f.subplots_adjust(wspace=0)
    plt.show()
    return
Example 17
Project: python3_ios   Author: holzschu   File: test_artist.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_setp():
    """Exercise ``setp`` with empty lists, arbitrary iterables and ``file=``."""
    # Empty and nested-empty lists must be accepted.
    for empty in ([], [[]]):
        plt.setp(empty)

    # Arbitrary iterables: a chain of line lists and a dict values view.
    fig, axes = plt.subplots()
    first_lines = axes.plot(range(3))
    second_lines = axes.plot(range(3))
    martist.setp(chain(first_lines, second_lines), 'lw', 5)
    plt.setp(axes.spines.values(), color='green')

    # With ``file=`` the property description is written to that stream.
    buffer = io.StringIO()
    plt.setp(first_lines, 'zorder', file=buffer)
    assert buffer.getvalue() == '  zorder: float\n'
Example 18
Project: uiKLine   Author: moonnejs   File: visFunction.py    License: MIT License 6 votes vote down vote up
def plotSigHeats(signals,markets,start=0,step=2,size=1,iters=6):
    """
    Plot a heat map of backtest results over a grid of stop parameters,
    to help spot stable regions ("islands") of the parameter space.

    For every pair ``(i, j)`` in ``range(iters)`` the backtest helper
    ``plotSigCaps`` is run with ``climit = start + i*step`` and
    ``wlimit = start + j*step``; the last element of the returned ``caps``
    is stored at row ``j``, column ``i`` of the result.

    Args:
        signals: Forwarded to ``plotSigCaps``.
        markets: Forwarded to ``plotSigCaps``.
        start: First value of both swept parameters.
        step: Increment between consecutive parameter values.
        size: Forwarded to ``plotSigCaps``.
        iters: Number of values swept per parameter (the grid is
            ``iters x iters``).

    Returns:
        pandas.DataFrame: the ``iters x iters`` grid of final ``caps`` values.
    """
    sigMat = pd.DataFrame(index=range(iters),columns=range(iters))
    for i in range(iters):
        climit = start + i*step
        for j in range(iters):
            wlimit = start + j*step
            caps,_ = plotSigCaps(signals,markets,climit=climit,wlimit=wlimit,size=size,op=False)
            # Single-step label assignment (row j, column i).  The former
            # chained form ``sigMat[i][j] = ...`` writes through a temporary
            # column object and is silently lost under pandas Copy-on-Write.
            sigMat.at[j, i] = caps[-1]
    sns.heatmap(sigMat.values.astype(np.float64),annot=True,fmt='.2f',annot_kws={"weight": "bold"})
    # Ticks sit at cell centres; the y positions are listed in reverse order.
    xTicks   = [i+0.5 for i in range(iters)]
    yTicks   = [iters-i-0.5 for i in range(iters)]
    xyLabels = [str(start+i*step) for i in range(iters)]
    _, labels = plt.yticks(yTicks,xyLabels)
    plt.setp(labels, rotation=0)
    _, labels = plt.xticks(xTicks,xyLabels)
    plt.setp(labels, rotation=90)
    plt.xlabel('Loss Stop @')
    plt.ylabel('Profit Stop @')
    return sigMat
Example 19
Project: python_primer   Author: noahwaterfieldprice   File: PiecewiseConstant.py    License: MIT License 6 votes vote down vote up
def _test():
    """Plot a PiecewiseConstant plus Indicator/Heaviside variants and show them.

    The piecewise-constant function fills the top third of the figure; the
    four other functions are laid out in a 2x2 grid underneath.
    """
    PC = PiecewiseConstant([(0.4, 1), (0.2, 1.5), (0.1, 3)], xmax=4)
    indicators = (Indicator(-3, 5), Indicator(-3, 5, eps=1))
    heavisides = (Heaviside(), Heaviside(eps=1))

    top_ax = plt.subplot(311)
    grid_axes = [plt.subplot(position) for position in (323, 324, 325, 326)]

    x, y = PC.plot()
    top_ax.plot(x, y)
    top_ax.set_ylim([0, 0.5])
    top_ax.set_title('PiecewiseConstant')

    titles = ['Indicator', 'Indicator (eps=1)',
              'Heaviside', 'Heaviside (eps=1)']
    panels = zip(indicators + heavisides, grid_axes, titles)
    for func, ax, title in panels:
        x, y = func.plot(-6, 8)
        ax.plot(x, y)
        ax.set_ylim([-0.5, 1.5])
        ax.set_title(title)

    # The upper row of the grid does not need its own x tick labels.
    for ax in grid_axes[:2]:
        plt.setp(ax.get_xticklabels(), visible=False)
    plt.show()
Example 20
Project: DETAD   Author: HumamAlwassel   File: false_postive_analysis.py    License: MIT License 5 votes vote down vote up
def subplot_fp_profile(fig, ax, values, labels, colors, xticks, xlabel, ylabel, title,
                       fontsize=14, bottom=0, top=100, bar_width=1, spacing=0.85,
                       grid_color='gray', grid_linestyle=':', grid_lw=1, 
                       ncol=1, legend_loc='best'):
    """Draw a stacked-percentage bar profile on ``ax`` and return its legend.

    ``values`` is scaled by 100 and accumulated along axis 1; the cumulative
    columns are drawn back to front so that each bar appears stacked.

    Args:
        fig: Accepted for interface compatibility; not used here.
        ax: Axes to draw on. All tick settings are applied to this axes.
        values: 2-D array-like; one row per bar, one column per category.
        labels: Bar labels; the first ``len(values)`` entries are used.
        colors: Bar colours, indexed by category.
        xticks: Legend labels, indexed by category.
        xlabel, ylabel, title: Axis labels and title.
        fontsize: Base font size; tick and legend text use ``fontsize/1.2``.
        bottom, top: y-axis limits.
        bar_width: Width of each bar.
        spacing: Factor applied to the bar spacing.
        grid_color, grid_linestyle, grid_lw: Style of the initial y grid.
        ncol: Number of legend columns.
        legend_loc: Legend location.

    Returns:
        The legend object created by ``ax.legend``.
    """
    ax.yaxis.grid(color=grid_color, linestyle=grid_linestyle, lw=grid_lw)

    cumsum_values = np.cumsum(np.array(values)*100, axis=1)
    index = np.linspace(0, spacing*bar_width*len(values), len(values))
    # Draw the largest cumulative totals first so smaller ones overlay them.
    for i in reversed(range(cumsum_values.shape[1])):
        ax.bar(index, cumsum_values[:, i], bar_width,
               capsize=i,
               color=colors[i],
               label=xticks[i], zorder=0)

    tick_fontsize = fontsize/1.2
    lgd = ax.legend(loc=legend_loc, ncol=ncol, fontsize=tick_fontsize, edgecolor='k')

    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    # Configure ticks on the axes that was passed in.  pyplot's xticks/yticks
    # act on the *current* axes, which is not necessarily ``ax``.
    ax.set_xticks(np.array(index))
    ax.set_xticklabels(np.array(labels[:len(values)]), fontsize=tick_fontsize, rotation=90)
    ax.set_yticks(np.linspace(0, 1, 11)*100)
    plt.setp(ax.get_yticklabels(), fontsize=tick_fontsize)
    ax.set_ylim(bottom=bottom, top=top)
    ax.set_xlim(left=index[0]-1.25*bar_width, right=index[-1]+1.0*bar_width)
    ax.set_title(title, fontsize=fontsize)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.yaxis.grid(True, linestyle='dotted')
    ax.set_axisbelow(True)
    ax.yaxis.set_tick_params(size=10, direction='in', width=2)
    for axis in ['bottom','left']:
        ax.spines[axis].set_linewidth(2.5)

    return lgd
Example 21
Project: mac-network   Author: stanfordnlp   File: visualization.py    License: Apache License 2.0 5 votes vote down vote up
def showTableAtt(instance, table, x, y, name):
    """Render an attention table as a heat map and save it to disk.

    When ``args.trans`` is set the table is transposed, the row and column
    labels are swapped and the x ticks are moved to the top.

    Args:
        instance: Passed to ``outTableAttName`` to build the output path.
        table: 2-D array of attention values.
        x: Row labels of ``table``.
        y: Column labels of ``table``.
        name: Passed to ``outTableAttName`` to build the output path.
    """
    fig2, bx = plt.subplots(1, 1)
    bx.cla()

    sns.set(font_scale = fontScale)

    transposed = args.trans
    if transposed:
        table = np.transpose(table)
        x, y = y, x

    tableMap = pandas.DataFrame(data = table, index = x, columns = y)

    bx = sns.heatmap(tableMap, cmap = "Purples", cbar = False, linewidths = .5, linecolor = "gray", square = True)

    # x ticks: on top and unrotated when transposed, otherwise tilted.
    if transposed:
        bx.xaxis.tick_top()
    _, x_labels = plt.xticks()
    x_rotation = 0 if transposed else 60
    plt.setp(x_labels, rotation = x_rotation)

    # y ticks: always horizontal.
    _, y_labels = plt.yticks()
    plt.setp(y_labels, rotation = 0)

    plt.savefig(outTableAttName(instance, name), dpi = 720)
Example 22
Project: RLTrader   Author: notadamking   File: TradingChart.py    License: GNU General Public License v3.0 5 votes vote down vote up
def render(self, current_step, net_worths, benchmarks, trades, window_size=200):
        """Redraw the trading chart for the window ending at ``current_step``.

        Updates the figure title with the latest net worth and profit
        percentage, delegates the panels to the ``_render_*`` helpers, labels
        the price axis with formatted dates, then pauses briefly so the frame
        is displayed.

        Args:
            current_step: Index of the current step (included in the window).
            net_worths: Sequence of net-worth values; the first and last
                entries are used for the title.
            benchmarks: Forwarded to ``_render_net_worth``.
            trades: Forwarded to ``_render_trades``.
            window_size: Maximum number of past steps shown.
        """
        first_worth = round(net_worths[0], 2)
        latest_worth = round(net_worths[-1], 2)
        profit_percent = round((latest_worth - first_worth) / first_worth * 100, 2)

        title = 'Net worth: $' + str(latest_worth) + ' | Profit: ' + str(profit_percent) + '%'
        self.fig.suptitle(title)

        # Clamp the left edge of the window at the start of the data.
        first_visible = max(current_step - window_size, 0)
        step_range = slice(first_visible, current_step + 1)
        times = self.df['Date'].values[step_range]

        self._render_net_worth(step_range, times, current_step, net_worths, benchmarks)
        self._render_price(step_range, times, current_step)
        self._render_volume(step_range, times)
        self._render_trades(step_range, trades)

        # The 'Date' column is parsed as seconds and formatted for display.
        formatted_dates = pd.to_datetime(self.df['Date'], unit='s').dt.strftime('%m/%d/%Y %H:%M')
        date_labels = formatted_dates.values[step_range]

        self.price_ax.set_xticklabels(date_labels, rotation=45, horizontalalignment='right')

        # The net-worth panel would repeat the same date labels; hide them.
        duplicate_labels = self.net_worth_ax.get_xticklabels()
        plt.setp(duplicate_labels, visible=False)

        # Give the GUI a moment to draw this frame before the next redraw.
        plt.pause(0.001)
Example 23
Project: recruit   Author: Frank-qlu   File: _tools.py    License: Apache License 2.0 5 votes vote down vote up
def _set_ticks_props(axes, xlabelsize=None, xrot=None,
                     ylabelsize=None, yrot=None):
    """Apply tick-label font size and rotation to every axes in ``axes``.

    Each setting is applied only when it is not ``None``.

    Args:
        axes: One or more axes; flattened with ``_flatten``.
        xlabelsize: Font size for x tick labels.
        xrot: Rotation for x tick labels.
        ylabelsize: Font size for y tick labels.
        yrot: Rotation for y tick labels.

    Returns:
        The ``axes`` argument, unchanged.
    """
    import matplotlib.pyplot as plt

    # (label getter name, property name, requested value), in apply order.
    settings = (
        ('get_xticklabels', 'fontsize', xlabelsize),
        ('get_xticklabels', 'rotation', xrot),
        ('get_yticklabels', 'fontsize', ylabelsize),
        ('get_yticklabels', 'rotation', yrot),
    )

    for ax in _flatten(axes):
        for getter_name, prop, value in settings:
            if value is None:
                continue
            tick_labels = getattr(ax, getter_name)()
            plt.setp(tick_labels, **{prop: value})
    return axes
Example 24
Project: cgpm   Author: probcomp   File: plots.py    License: Apache License 2.0 5 votes vote down vote up
def plot_clustermap(D, xticklabels=None, yticklabels=None):
    """Draw a seaborn clustermap of ``D`` with readable tick labels.

    Args:
        D: 2-D array-like with a ``shape`` attribute.
        xticklabels: Labels for the x axis; defaults to
            ``range(D.shape[0])``.
        yticklabels: Labels for the y axis; defaults to
            ``range(D.shape[1])``.

    Returns:
        The object returned by ``seaborn.clustermap``.
    """
    import seaborn as sns

    # NOTE(review): the x default uses shape[0] and the y default shape[1];
    # this only lines up for square matrices - confirm that D is square.
    if xticklabels is None:
        xticklabels = range(D.shape[0])
    if yticklabels is None:
        yticklabels = range(D.shape[1])

    grid = sns.clustermap(
        D,
        yticklabels=yticklabels,
        xticklabels=xticklabels,
        linewidths=0.2,
        cmap='BuGn',
    )
    heatmap_ax = grid.ax_heatmap
    plt.setp(heatmap_ax.get_yticklabels(), rotation=0)
    plt.setp(heatmap_ax.get_xticklabels(), rotation=90)
    return grid
Example 25
Project: pylops   Author: equinor   File: prestack.py    License: GNU Lesser General Public License v3.0 5 votes vote down vote up
def plotmodel(axs, m, x, z, vmin, vmax,
              params=('VP', 'VS', 'Rho'),
              cmap='gist_rainbow', title=None):
    """Quick visualization of a model, one panel per parameter.

    Args:
        axs: Indexable collection of axes; panels 1 and 2 get their y tick
            labels hidden, so at least three axes are required.
        m: Model array; ``m[:, ip]`` is the image for parameter ``ip``.
        x: Horizontal axis values; the first and last give the extent.
        z: Vertical axis values; the first and last give the extent.
        vmin: Lower colour limit.
        vmax: Upper colour limit.
        params: Parameter names, used in the panel titles.
        cmap: Colormap name.
        title: Text appended to every panel title.
    """
    for ip, param in enumerate(params):
        panel = axs[ip]
        # z runs from top (z[0]) to bottom (z[-1]) of the image.
        image_extent = (x[0], x[-1], z[-1], z[0])
        panel.imshow(m[:, ip], extent=image_extent,
                     vmin=vmin, vmax=vmax, cmap=cmap)
        panel.set_title('%s - %s' %(param, title))
        panel.axis('tight')

    # Only the first panel keeps its y tick labels.
    for panel in (axs[1], axs[2]):
        plt.setp(panel.get_yticklabels(), visible=False)

# data 
Example 26
Project: dgl   Author: dmlc   File: viz.py    License: Apache License 2.0 5 votes vote down vote up
def att_animation(maps_array, mode, src, tgt, head_id):
    """Build an animation of one attention head's weights, one frame per map.

    Args:
        maps_array: Sequence of attention maps; each entry is indexed as
            ``maps[mode2id[mode]][head_id]``.
        mode: Key into the module-level ``mode2id`` mapping.
        src: Source tokens, used as y tick labels of the heat map.
        tgt: Target tokens, used as x tick labels of the heat map.
        head_id: Index of the attention head to animate.

    Returns:
        The ``animation.FuncAnimation`` object driving the figure.
    """
    # One weight matrix per frame.
    weights = [maps[mode2id[mode]][head_id] for maps in maps_array]
    fig, axes = plt.subplots(1, 2)

    def weight_animate(i):
        # Draws frame ``i``: heat map on the left axes, graph view on the right.
        # NOTE(review): relies on a module-level ``colorbar`` variable that
        # must already exist before the first frame - confirm it is
        # initialised elsewhere in the file.
        global colorbar
        if colorbar:
            # Drop the previous frame's colorbar before drawing a new one.
            colorbar.remove()
        # Clears pyplot's current axes.
        plt.cla()
        axes[0].set_title('heatmap')
        axes[0].set_yticks(np.arange(len(src)))
        axes[0].set_xticks(np.arange(len(tgt)))
        axes[0].set_yticklabels(src)
        axes[0].set_xticklabels(tgt)
        plt.setp(axes[0].get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

        fig.suptitle('epoch {}'.format(i))
        # Swap the last two axes of the weight matrix before plotting.
        weight = weights[i].transpose(-1, -2)
        heatmap = axes[0].pcolor(weight, vmin=0, vmax=1, cmap=plt.cm.Blues)
        colorbar = plt.colorbar(heatmap, ax=axes[0], fraction=0.046, pad=0.04)
        axes[0].set_aspect('equal')
        axes[1].axis("off")
        graph_att_head(src, tgt, weight, axes[1], 'graph')


    ani = animation.FuncAnimation(fig, weight_animate, frames=len(weights), interval=500, repeat_delay=2000)
    return ani 
Example 27
Project: DeepDIVA   Author: DIVA-DIA   File: dataset_bidimensional.py    License: GNU Lesser General Public License v3.0 5 votes vote down vote up
def _visualize_distribution(train, val, test, save_path, marker_size=1):
    """
    This routine saves a figure with three images for train, val and test respectively where
    each image is a visual representation of the split distribution with class colors.

    Parameters
    ----------
    train, val, test : ndarray[float] of size (n,3)
        The three splits. Each row is (x,y,label)
    save_path : String
        Path where to save the figure (the file format follows the extension)
    marker_size : float
        Size of the marker representing each datapoint. For big dataset make this small

    Returns
    -------
        None
    """
    fig, axs = plt.subplots(ncols=3, sharex='all', sharey='all')
    # 'box' works with shared axes; the former 'box-forced' value has been
    # removed from Matplotlib and raises a ValueError there.
    plt.setp(axs.flat, aspect=1.0, adjustable='box')
    colormap = plt.get_cmap('Set1')
    splits = (('train', train), ('val', val), ('test', test))
    for ax, (split_name, split) in zip(axs, splits):
        # Columns 0/1 are the coordinates, column 2 is the class label.
        ax.scatter(split[:, 0], split[:, 1], c=split[:, 2], s=marker_size, cmap=colormap)
        ax.set_title(split_name)
    fig.canvas.draw()
    fig.savefig(save_path)
    fig.clf()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
Example 28
Project: TheCannon   Author: annayqho   File: plot_apogee_lamost_cannon.py    License: MIT License 5 votes vote down vote up
def create_grid():
    """Create a 3x3 grid of subplots whose rows share their x and y axes.

    In every row the second and third axes share both axes with the first
    one, have their y tick labels hidden and their first major x tick made
    invisible.

    Returns:
        tuple: ``(fig, axarr)`` where ``axarr`` is a 3-tuple of rows, each a
        3-tuple of axes.
    """
    fig = plt.figure(figsize=(15,20))

    rows = []
    for row in range(3):
        first_index = 3 * row + 1
        lead = fig.add_subplot(3, 3, first_index)
        cells = [lead]
        for offset in (1, 2):
            follower = fig.add_subplot(3, 3, first_index + offset,
                                       sharex=lead, sharey=lead)
            plt.setp(follower.get_yticklabels(), visible=False)
            # Hide the first x tick so it does not collide with its neighbour.
            follower.xaxis.get_major_ticks()[0].set_visible(False)
            cells.append(follower)
        rows.append(tuple(cells))

    fig.subplots_adjust(wspace=0)
    fig.subplots_adjust(hspace=0.2)
    axarr = tuple(rows)
    return fig, axarr
Example 29
Project: vnpy_crypto   Author: birforce   File: _tools.py    License: MIT License 5 votes vote down vote up
def _set_ticks_props(axes, xlabelsize=None, xrot=None,
                     ylabelsize=None, yrot=None):
    """Apply tick-label font size and rotation to every axes in ``axes``.

    Each setting is applied only when it is not ``None``.

    Args:
        axes: One or more axes; flattened with ``_flatten``.
        xlabelsize: Font size for x tick labels.
        xrot: Rotation for x tick labels.
        ylabelsize: Font size for y tick labels.
        yrot: Rotation for y tick labels.

    Returns:
        The ``axes`` argument, unchanged.
    """
    import matplotlib.pyplot as plt

    # (label getter name, property name, requested value), in apply order.
    settings = (
        ('get_xticklabels', 'fontsize', xlabelsize),
        ('get_xticklabels', 'rotation', xrot),
        ('get_yticklabels', 'fontsize', ylabelsize),
        ('get_yticklabels', 'rotation', yrot),
    )

    for ax in _flatten(axes):
        for getter_name, prop, value in settings:
            if value is None:
                continue
            tick_labels = getattr(ax, getter_name)()
            plt.setp(tick_labels, **{prop: value})
    return axes
Example 30
Project: helen   Author: kishwarshafin   File: TestInterface.py    License: MIT License 5 votes vote down vote up
def save_rle_confusion_matrix(stats_dictionary, output_directory):
    """Plot the RLE confusion matrix and save it as a PNG.

    Non-zero cells are annotated with their count: green on the diagonal,
    red elsewhere. The image is written to
    ``<output_directory>/RLE_CONFUSION_MATRIX.png``.

    Args:
        stats_dictionary: Mapping whose ``'rle_confusion_matrix'`` entry is a
            square matrix of counts.
        output_directory: Directory (as a string) to write the image into.
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    # The builtin ``int`` replaces the ``np.int`` alias, which was removed in
    # NumPy 1.24 and raises AttributeError there.
    cf = np.array(stats_dictionary['rle_confusion_matrix'], dtype=int)
    ax.imshow(cf)
    rle_labels = [str(i) for i in range(0, ImageSizeOptions.TOTAL_RLE_LABELS)]
    n_labels = len(rle_labels)

    # We want to show all ticks...
    ax.set_xticks(np.arange(n_labels))
    ax.set_yticks(np.arange(n_labels))
    # ... and label them with the respective list entries
    ax.set_xticklabels(rle_labels)
    ax.set_yticklabels(rle_labels)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(n_labels):
        for j in range(n_labels):
            if cf[i, j] > 0:
                # Diagonal cells in green, off-diagonal cells in red.
                cell_color = "g" if i == j else "r"
                ax.text(j, i, cf[i, j], ha="center", va="center", color=cell_color)

    ax.set_title("RLE Confusion Matrix")
    fig.tight_layout()
    plt.savefig(output_directory + "/RLE_CONFUSION_MATRIX.png", dpi=100)