Python matplotlib.pyplot.sca() Examples

The following code examples show how to use `matplotlib.pyplot.sca()`. They are taken from open-source Python projects. You can vote up the examples you like or vote down the ones you don't like.

Example 1
Project: heppyplotlib   Author: ebothmann   File: plot.py    MIT License 6 votes vote down vote up
def gridplot(file_name, uses_rivet_plot_info=True):
    """Convenience function to plot all :py:mod:`yoda` data objects
    from a :py:mod:`yoda` file into a subplots grid.

    Objects are laid out two per row. A file holding exactly one object
    gets a single (non-gridded) axes instead.

    :param str file_name: The path to the :py:mod:`yoda` file.
    :param bool uses_rivet_plot_info: Forwarded unchanged to :func:`plot`
        for every data object.
    :return: fig, axes_list -- ``axes_list`` is a single Axes for a
        one-object file, otherwise a 2-D array of Axes (``squeeze=False``).
    """
    all_rivet_paths = rivet_paths(file_name)

    # setup axes
    if len(all_rivet_paths) == 1:
        fig, axes_list = plt.subplots()
    else:
        ncols = 2
        # Ceiling division: enough rows to hold every path. Floor division
        # is required -- ``/`` yields a float under Python 3, which
        # ``plt.subplots`` rejects as a row count.
        nrows = (len(all_rivet_paths) - 1) // ncols + 1
        fig, axes_list = plt.subplots(nrows, ncols, squeeze=False)

    # plot into axes; np.ravel flattens the grid so it can be zipped with
    # the paths (surplus axes in the last row are left empty)
    for rivet_path, axes in zip(all_rivet_paths, np.ravel(axes_list)):
        plt.sca(axes)  # presumably plot() draws on the current axes
        plot(file_name, rivet_path, uses_rivet_plot_info=uses_rivet_plot_info)

    return fig, axes_list
Example 2
Project: Generative-adversarial-Nets-in-NLP   Author: EternalFeather   File: adversarial_real_corpus.py    Apache License 2.0 5 votes vote down vote up
def matplotformat(self, ax, plot_y, plot_name, x_max):
    """Draw an NLL-by-oracle curve on ``ax`` and label it.

    Point ``k`` of ``plot_y`` is drawn at x = 5 * k. The x axis gets
    ``x_max // 50 + 1`` evenly spaced integer ticks over ``[0, x_max]``.
    """
    # Every call below goes through the pyplot state machine, so select ax.
    plt.sca(ax)
    epochs = list(range(0, 5 * len(plot_y), 5))
    tick_count = x_max // 50 + 1
    plt.xticks(np.linspace(0, x_max, tick_count, dtype=np.int32))
    for set_label, text in ((plt.xlabel, 'Epochs'), (plt.ylabel, 'NLL by oracle')):
        set_label(text, fontsize=16)
    plt.title(plot_name)
    plt.plot(epochs, plot_y)
Example 3
Project: Generative-adversarial-Nets-in-NLP   Author: EternalFeather   File: adversarial_poem.py    Apache License 2.0 5 votes vote down vote up
def matplotformat(self, ax, plot_y, plot_name, x_max):
    """Draw an NLL-by-oracle curve on ``ax`` and label it.

    Point ``k`` of ``plot_y`` is drawn at x = 5 * k. The x axis gets
    ``x_max // 50 + 1`` evenly spaced integer ticks over ``[0, x_max]``.
    """
    epochs = list(range(0, 5 * len(plot_y), 5))
    tick_positions = np.linspace(0, x_max, x_max // 50 + 1, dtype=np.int32)

    plt.sca(ax)  # route the pyplot calls below to ax
    plt.xticks(tick_positions)
    plt.xlabel('Epochs', fontsize=16)
    plt.ylabel('NLL by oracle', fontsize=16)
    plt.title(plot_name)
    plt.plot(epochs, plot_y)
Example 4
Project: Generative-adversarial-Nets-in-NLP   Author: EternalFeather   File: adversarial_ori.py    Apache License 2.0 5 votes vote down vote up
def matplotformat(self, ax, plot_y, plot_name, x_max):
    """Draw an NLL-by-oracle curve on ``ax`` and label it.

    Point ``k`` of ``plot_y`` is drawn at x = 5 * k. The x axis gets
    ``x_max // 50 + 1`` evenly spaced integer ticks over ``[0, x_max]``.
    """
    label_size = 16
    epochs = list(range(0, 5 * len(plot_y), 5))
    n_ticks = x_max // 50 + 1

    plt.sca(ax)  # subsequent pyplot calls act on ax
    plt.xticks(np.linspace(0, x_max, n_ticks, dtype=np.int32))
    plt.xlabel('Epochs', fontsize=label_size)
    plt.ylabel('NLL by oracle', fontsize=label_size)
    plt.title(plot_name)
    plt.plot(epochs, plot_y)
Example 5
Project: Generative-adversarial-Nets-in-NLP   Author: EternalFeather   File: adversarial.py    Apache License 2.0 5 votes vote down vote up
def matplotformat(self, ax, plot_y, plot_name, x_max):
    """Draw an NLL-by-oracle curve on ``ax`` and label it.

    Point ``k`` of ``plot_y`` is drawn at x = 5 * k. The x axis gets
    ``x_max // 100 + 1`` evenly spaced integer ticks over ``[0, x_max]``
    (this variant uses a coarser spacing than its siblings).
    """
    plt.sca(ax)  # subsequent pyplot calls act on ax
    epochs = list(range(0, 5 * len(plot_y), 5))
    tick_count = x_max // 100 + 1
    plt.xticks(np.linspace(0, x_max, tick_count, dtype=np.int32))
    for set_label, text in ((plt.xlabel, 'Epochs'), (plt.ylabel, 'NLL by oracle')):
        set_label(text, fontsize=16)
    plt.title(plot_name)
    plt.plot(epochs, plot_y)
Example 6
Project: Generative-adversarial-Nets-in-NLP   Author: EternalFeather   File: adversarial_obama.py    Apache License 2.0 5 votes vote down vote up
def matplotformat(self, ax, plot_y, plot_name, x_max):
    """Draw an NLL-by-oracle curve on ``ax`` and label it.

    Point ``k`` of ``plot_y`` is drawn at x = 5 * k. The x axis gets
    ``x_max // 50 + 1`` evenly spaced integer ticks over ``[0, x_max]``.
    """
    epochs = list(range(0, 5 * len(plot_y), 5))
    ticks = np.linspace(0, x_max, x_max // 50 + 1, dtype=np.int32)

    plt.sca(ax)  # make ax the target of the pyplot calls below
    plt.xticks(ticks)
    plt.xlabel('Epochs', fontsize=16)
    plt.ylabel('NLL by oracle', fontsize=16)
    plt.title(plot_name)
    plt.plot(epochs, plot_y)
Example 7
Project: afplot   Author: sndrtj   File: region.py    MIT License 5 votes vote down vote up
def plot_single_histogram(dataframe, output, dpi=300,
                          kde_only=False, label=None):
    """Plot per-chromosome distributions of the ``af`` column and save them.

    One facet is drawn per ``chrom`` value (wrapped two per row), coloured
    by ``label``. If ``sns.distplot`` raises ``LinAlgError`` (KDE cannot be
    fitted), a plain histogram is drawn instead and a warning is issued.
    Y limits above 10 are clipped to (0, 10); x limits are set to (-0.5, 1.5).

    :param dataframe: data with ``chrom``, ``label`` and ``af`` columns.
    :param output: destination passed to ``plt.savefig``.
    :param dpi: resolution of the saved image.
    :param kde_only: draw only the KDE curve (``hist=False``).
    :param label: optional title, applied via ``plt.title`` to whichever
        axes is current.
    """
    g = sns.FacetGrid(dataframe, col="chrom", hue="label", col_wrap=2)
    try:
        # The two original branches differed only in ``hist=False``.
        distplot_kws = {"hist": False} if kde_only else {}
        try:
            g = (g.map(sns.distplot, "af", **distplot_kws).
                 add_legend().
                 set_titles(""))
        except LinAlgError:
            # Trailing space fixes the concatenated message, which used to
            # read "...data set.Defaulting to histogram".
            warn("Cannot create KDE for this data set. "
                 "Defaulting to histogram")
            g = (g.map(sns.distplot, "af", hist=True, kde=False).
                 add_legend().
                 set_titles(""))
        for x in g.axes:
            if x.get_ylim()[1] > 10:
                x.set_ylim(0, 10)
            x.set_xlim(-0.5, 1.5)
        if label is not None:
            plt.title(label)
        plt.savefig(output, dpi=dpi)
    finally:
        # Always release the figure, even if plotting or saving fails. (The
        # former loop calling plt.sca on each axes right before closing had
        # no lasting effect and was dropped.)
        plt.close(g.fig)
Example 8
Project: afplot   Author: sndrtj   File: region.py    MIT License 5 votes vote down vote up
def plot_single_scatter(dataframe, output, category="af", dpi=300, label=None):
    """Scatter ``category`` against ``pos``, one facet per chromosome, and save.

    Facets are stacked one per row, points are coloured by ``label`` (alpha
    0.3), no regression line is fitted, and every y axis is fixed to (0, 1).

    :param dataframe: data with ``pos``, ``chrom``, ``label`` and
        ``category`` columns.
    :param output: destination passed to ``plt.savefig``.
    :param category: column plotted on the y axis.
    :param dpi: resolution of the saved image.
    :param label: optional title, applied via ``plt.title`` to whichever
        axes is current.
    """
    # x/y/data are passed by keyword: seaborn >= 0.12 rejects them
    # positionally. ``col_wrap`` becomes a subplot column count and must be
    # an integer; the former 1.2 is refused by current Matplotlib and, in
    # older versions, could produce fewer rows than chromosomes.
    f = sns.lmplot(x="pos", y=category, data=dataframe, col="chrom",
                   col_wrap=1, fit_reg=False,
                   hue="label", scatter_kws={"alpha": 0.3}, aspect=3)
    try:
        f.add_legend()
        f.set_titles("")
        for x in f.axes:
            x.set_ylim(0, 1.0)
        if label is not None:
            plt.title(label)
        plt.savefig(output, dpi=dpi)
    finally:
        # Release the figure even when saving fails; the former plt.sca loop
        # right before closing had no lasting effect and was dropped.
        plt.close(f.fig)
Example 9
Project: miccai-2016-surgical-activity-rec   Author: rdipietro   File: data.py    Apache License 2.0 5 votes vote down vote up
def visualize_predictions(prediction_seqs, label_seqs, num_classes,
                          fig_width=6.5, fig_height_per_seq=0.5):
    """ Visualize predictions vs. ground truth.

    One row per sequence. ``plot_label_seq`` is called once for the label
    sequence (with third argument 1) and once for the prediction (with -1)
    on the same axes. Axes are hidden; x spans the longest label sequence
    and y spans [-2.75, 2.75].

    Args:
        prediction_seqs: A list of int NumPy arrays, each with shape
            `[duration, 1]`.
        label_seqs: A list of int NumPy arrays, each with shape `[duration, 1]`.
            Must be non-empty.
        num_classes: An integer.
        fig_width: A float. Figure width (inches).
        fig_height_per_seq: A float. Figure height per sequence (inches).

    Returns:
        A tuple of the created figure and a 1-D array of axes, one per
        sequence (also when there is only a single sequence).
    """

    num_seqs = len(label_seqs)
    max_seq_length = max(seq.shape[0] for seq in label_seqs)
    figsize = (fig_width, num_seqs*fig_height_per_seq)
    # squeeze=False always yields a (num_seqs, 1) array; column 0 is the same
    # 1-D array returned for num_seqs > 1, and a 1-element array (instead of
    # a bare, non-iterable Axes) when num_seqs == 1.
    fig, axes = plt.subplots(nrows=num_seqs, ncols=1,
                             sharex=True, figsize=figsize, squeeze=False)
    axes = axes[:, 0]

    for pred_seq, label_seq, ax in zip(prediction_seqs, label_seqs, axes):
        plt.sca(ax)  # presumably plot_label_seq draws on the current axes
        plot_label_seq(label_seq, num_classes, 1)
        plot_label_seq(pred_seq, num_classes, -1)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        plt.xlim(0, max_seq_length)
        plt.ylim(-2.75, 2.75)

    # Lay the figure out once, after every row is drawn (previously this
    # was recomputed on each loop iteration with the same final result).
    plt.tight_layout()

    return fig, axes
Example 10
Project: DSMnet   Author: wyf2017   File: test_ssim.py    Apache License 2.0 5 votes vote down vote up
def implot(im1, im2, im3, im4, im5, im6, im7, im8):
    """Show the eight images in a 4-row x 2-column grid of subplots.

    Images fill the grid row by row in argument order (``im1`` top-left,
    ``im8`` bottom-right), each drawn with default ``plt.imshow`` settings.
    """
    rows, cols = 4, 2
    images = (im1, im2, im3, im4, im5, im6, im7, im8)
    for position, image in enumerate(images, start=1):
        subplot_ax = plt.subplot(rows, cols, position)
        plt.sca(subplot_ax)
        plt.imshow(image)
Example 11
Project: dougu   Author: bheinzerling   File: plot.py    MIT License 5 votes vote down vote up
def add_colorbar(im, aspect=20, pad_fraction=0.5, **kwargs):
    """Add a vertical color bar to an image plot.

    A new axes is appended to the right of ``im.axes``. Its width is the
    image axes' height divided by ``aspect``, and the gap before it is
    ``pad_fraction`` times that width. Remaining keyword arguments go to
    ``Figure.colorbar``, whose result is returned. The pyplot "current axes"
    is saved before the new axes is added and restored afterwards.
    """
    from mpl_toolkits import axes_grid1
    image_ax = im.axes
    divider = axes_grid1.make_axes_locatable(image_ax)
    bar_width = axes_grid1.axes_size.AxesY(image_ax, aspect=1. / aspect)
    gap = axes_grid1.axes_size.Fraction(pad_fraction, bar_width)
    # Remember the active axes so later pyplot calls are not redirected to
    # the colour-bar axes.
    previously_active = plt.gca()
    bar_ax = divider.append_axes("right", size=bar_width, pad=gap)
    plt.sca(previously_active)
    return image_ax.figure.colorbar(im, cax=bar_ax, **kwargs)
Example 12
Project: neural-network-animation   Author: miloharper   File: test_contour.py    MIT License 5 votes vote down vote up
def test_given_colors_levels_and_extends():
    # Each ``extend`` mode is exercised twice: once filled (even subplot)
    # and once as line contours (odd subplot), each followed by a colorbar.
    _, axes = plt.subplots(2, 4)

    data = np.arange(12).reshape(3, 4)

    colors = ['red', 'yellow', 'pink', 'blue', 'black']
    levels = [2, 4, 8, 10]
    extend_modes = ['neither', 'min', 'max', 'both']

    for idx, ax in enumerate(axes.flatten()):
        plt.sca(ax)
        extend = extend_modes[idx // 2]

        if idx % 2:
            # Line contours: drop the last level when extending both ways.
            trim = -1 if extend == 'both' else None
            plt.contour(data, colors=colors, levels=levels[:trim],
                        extend=extend)
        else:
            # Filled contours: one colour fewer for one-sided extension.
            trim = -1 if extend in ('min', 'max') else None
            plt.contourf(data, colors=colors[:trim], levels=levels,
                         extend=extend)

        plt.colorbar()
Example 13
Project: skymap_statistics   Author: reedessick   File: mollweide.py    GNU General Public License v2.0 5 votes vote down vote up
def heatmap(post, ax, color_map='OrRd', colorbar=False, colorbar_label=''):
    '''
    Draw ``post`` on ``ax`` with ``lalinf_plot.healpix_heatmap`` using the
    named colour map, optionally adding a horizontal colour bar labelled
    ``colorbar_label``.
    '''
    plt.sca(ax)
    cmap = plt.get_cmap(color_map)
    # NOTE(review): original author asked whether this is buggy when
    # projection == "mollweide"; still unverified.
    lalinf_plot.healpix_heatmap(post, cmap=cmap)
    if not colorbar:
        return
    # FIXME (carried over): hard-coded colour-bar layout options are fragile.
    cb = plt.colorbar(orientation='horizontal', fraction=0.15, pad=0.03, shrink=0.8)
    cb.set_label(colorbar_label)
Example 14
Project: skymap_statistics   Author: reedessick   File: mollweide.py    GNU General Public License v2.0 5 votes vote down vote up
def contour(post, ax, levels=None, alpha=1.0, colors='b', linewidths=1):
    '''
    Draw contours of the cumulative-probability map of ``post`` on ``ax``.

    Entries of ``post`` are ranked from largest to smallest, and each entry
    is replaced by the running sum of all entries ranked at or above it.
    Contouring that map at ``levels`` therefore outlines the smallest sets of
    highest-valued pixels whose total reaches each level (assumes ``post`` is
    a 1-D normalised probability map -- TODO confirm).

    ``levels`` defaults to ``[0.1, 0.5, 0.9]``. A fresh list is built on
    every call so the default is never shared (or mutated) across calls.
    '''
    if levels is None:
        levels = [0.1, 0.5, 0.9]

    # Indices of the entries of post in descending order.
    order = np.argsort(post)[::-1]
    cpost = np.empty(post.shape)
    cpost[order] = np.cumsum(post[order])

    plt.sca(ax)
    lalinf_plot.healpix_contour(cpost,
                                levels=levels,
                                alpha=alpha,
                                colors=colors,
                                linewidths=linewidths)

#-------------------------------------------------

### data preparation 
Example 15
Project: Sequential-Generation   Author: Philip-Bachman   File: utils.py    MIT License 5 votes vote down vote up
def plot_scatter(x, y, f_name, x_label=None, y_label=None):
    """
    Plot a scatter plot of ``y`` against ``x`` and save it as PNG to ``f_name``.

    Axis labels default to 'Posterior KLd' (x) and 'Expected Log-likelihood'
    (y). The figure is always closed, even if drawing or saving fails.
    """
    import matplotlib.pyplot as plt
    if x_label is None:
        x_label = 'Posterior KLd'
    if y_label is None:
        y_label = 'Expected Log-likelihood'
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        # Shift the axes 5% right/up and shrink it to 96%, presumably to
        # leave room for the large axis labels.
        box = ax.get_position()
        ax.set_position([box.x0 + (0.05 * box.width), box.y0 + (0.05 * box.height),
                         0.96 * box.width, 0.96 * box.height])
        ax.set_xlabel(x_label, fontsize=22)
        ax.set_ylabel(y_label, fontsize=22)
        # ``ax.hold(True)`` was dropped: holding is the default behaviour and
        # ``Axes.hold`` no longer exists in Matplotlib 3.
        ax.scatter(x, y, s=24, alpha=0.5, c='b', marker='o')
        plt.sca(ax)
        # Keep the automatically chosen tick locations, with a larger font.
        x_locs, _ = plt.xticks()
        plt.xticks(x_locs, fontsize=18)
        y_locs, _ = plt.yticks()
        plt.yticks(y_locs, fontsize=18)
        # ``papertype`` and ``frameon`` were passed only at their defaults;
        # ``frameon`` is no longer accepted by modern Matplotlib's savefig.
        fig.savefig(f_name, dpi=None, facecolor='w', edgecolor='w',
                    orientation='portrait', format='png',
                    transparent=False, bbox_inches=None, pad_inches=0.1)
    finally:
        plt.close(fig)
    return
Example 16
Project: monsoon-onset   Author: jenfly   File: pub-figs-grl.py    MIT License 5 votes vote down vote up
def plot_mfc_budget(mfc_budget, index, year, legend=True,
                    legend_kw=None, dashes=None, netprecip=False,
                    labelpad=1.5):
    """Plot one year's budget terms, with the index time series on a twin axis.

    Left axis: the ``PRECTOT``, ``EVAP``, ``MFC`` and ``dw/dt`` series of
    ``mfc_budget`` for ``year`` (plus ``P-E`` if ``netprecip``), against
    ``day``. Vertical lines mark ``index['onset']`` and ``index['retreat']``.
    Right (twin) axis: ``index['tseries']`` in red, labelled 'mm'.

    ``legend_kw`` defaults to ``{'fontsize': 9, 'loc': 'upper left',
    'handlelength': 2.5}`` and ``dashes`` to ``[6, 2]``. Both are built fresh
    on each call rather than being shared mutable default arguments.

    Returns ``(ax1, ax2)``: the primary axes and its twin.
    """
    if legend_kw is None:
        legend_kw = {'fontsize' : 9, 'loc' : 'upper left',
                     'handlelength' : 2.5}
    if dashes is None:
        dashes = [6, 2]
    ts = mfc_budget.sel(year=year)
    ind = index.sel(year=year)
    days = ts['day'].values
    styles = {'PRECTOT' : {'color' : 'k', 'linestyle' : '--', 'dashes' : dashes},
              'EVAP' : {'color' : 'k'},
              'MFC' : {'color' : 'k', 'linewidth' : 2},
              'dw/dt' : {'color' : '0.7', 'linewidth' : 2}}
    if netprecip:
        styles['P-E'] = {'color' : 'b', 'linewidth' : 2}
    for nm in styles:
        plt.plot(days, ts[nm], label=nm, **styles[nm])
    plt.axvline(ind['onset'], color='k')
    plt.axvline(ind['retreat'], color='k')
    plt.xlabel('Day of Year')
    plt.ylabel('mm day$^{-1}$', labelpad=labelpad)
    ax1 = plt.gca()
    ax2 = plt.twinx()
    plt.sca(ax2)  # draw the index series on the twin axis
    plt.plot(days, ind['tseries'], 'r', alpha=0.6, linewidth=2, label='CMFC')
    atm.fmt_axlabels('y', 'mm', color='r', alpha=0.6)
    plt.gca().set_ylabel('mm', labelpad=labelpad)
    if legend:
        atm.legend_2ax(ax1, ax2, **legend_kw)
    return ax1, ax2
Example 17
Project: monsoon-onset   Author: jenfly   File: summarize-tseries.py    MIT License 5 votes vote down vote up
def lineplots(data1, data2=None, data1_style=None, xlims=None, xticks=None,
              ylims=None, yticks=None, length=None, legend=False,
              legend_kw=None, y2_lims=None, y2_opts=None,
              y1_label='', y2_label='', grp=None):
    """Plot each variable of ``data1`` against ``dayrel``, optionally with
    the variables of ``data2`` on a twin y axis.

    ``data1_style``, if given, maps each variable name to the format string
    passed to ``plt.plot``. A vertical line is drawn at day 0 and, if
    ``length`` is given, at ``length``.

    ``legend_kw`` defaults to ``{'fontsize': 9, 'handlelength': 2.5}`` and
    ``y2_opts`` to ``{'color': 'r', 'alpha': 0.6}``. Both are built fresh on
    each call rather than being shared mutable default arguments.

    Returns a list of two axes: the primary axes, then the twin axes (or
    the primary axes again when there is no second dataset).
    """
    if legend_kw is None:
        legend_kw = {'fontsize' : 9, 'handlelength' : 2.5}
    if y2_opts is None:
        y2_opts = {'color' : 'r', 'alpha' : 0.6}

    data1, data2 = to_dataset(data1), to_dataset(data2)

    for nm in data1.data_vars:
        if data1_style is None:
            plt.plot(data1['dayrel'], data1[nm], label=nm)
        else:
            plt.plot(data1['dayrel'], data1[nm], data1_style[nm], label=nm)
    fmt_axes(xlims, xticks, ylims, yticks)
    plt.grid(True)
    plt.axvline(0, color='k')
    if length is not None:
        plt.axvline(length, color='k')
    # NOTE(review): compares a row index with the *column* count; for
    # "label only the bottom row" grp.nrow - 1 would be expected -- confirm
    # against the grp API before changing.
    if grp is not None and grp.row == grp.ncol - 1:
        plt.xlabel('Rel Day')
    plt.ylabel(y1_label)
    axes = [plt.gca()]

    if data2 is not None:
        # Create a twin y axis and make it current for the second dataset.
        plt.sca(plt.gca().twinx())
        for nm in data2.data_vars:
            plt.plot(data2['dayrel'], data2[nm], label=nm, linewidth=2,
                     **y2_opts)
        if y2_lims is not None:
            plt.ylim(y2_lims)
        atm.fmt_axlabels('y', y2_label, **y2_opts)
    # NOTE(review): this runs even when data2 is None, so the primary axes is
    # then returned twice. Callers may rely on always getting two entries --
    # confirm before moving it into the branch above.
    axes = axes + [plt.gca()]

    if legend:
        if data2 is None:
            plt.legend(**legend_kw)
        else:
            atm.legend_2ax(axes[0], axes[1], **legend_kw)

    return axes
Example 18
Project: monsoon-onset   Author: jenfly   File: thesis-figs-jclim.py    MIT License 5 votes vote down vote up
def plot_mfc_budget(mfc_budget, index, year, legend=True,
                    legend_kw=None, dashes=None, netprecip=False,
                    labelpad=1.5):
    """Plot one year's budget terms, with the index time series on a twin axis.

    Left axis: the ``PRECTOT``, ``EVAP``, ``MFC`` and ``dw/dt`` series of
    ``mfc_budget`` for ``year`` (plus ``P-E`` if ``netprecip``), against
    ``day``. Vertical lines mark ``index['onset']`` and ``index['retreat']``.
    Right (twin) axis: ``index['tseries']`` in red, labelled 'mm'.

    ``legend_kw`` defaults to ``{'fontsize': 9, 'loc': 'upper left',
    'handlelength': 2.5}`` and ``dashes`` to ``[6, 2]``. Both are built fresh
    on each call rather than being shared mutable default arguments.

    Returns ``(ax1, ax2)``: the primary axes and its twin.
    """
    if legend_kw is None:
        legend_kw = {'fontsize' : 9, 'loc' : 'upper left',
                     'handlelength' : 2.5}
    if dashes is None:
        dashes = [6, 2]
    ts = mfc_budget.sel(year=year)
    ind = index.sel(year=year)
    days = ts['day'].values
    styles = {'PRECTOT' : {'color' : 'k', 'linestyle' : '--', 'dashes' : dashes},
              'EVAP' : {'color' : 'k'},
              'MFC' : {'color' : 'k', 'linewidth' : 2},
              'dw/dt' : {'color' : '0.7', 'linewidth' : 2}}
    if netprecip:
        styles['P-E'] = {'color' : 'b', 'linewidth' : 2}
    for nm in styles:
        plt.plot(days, ts[nm], label=nm, **styles[nm])
    plt.axvline(ind['onset'], color='k')
    plt.axvline(ind['retreat'], color='k')
    plt.xlabel('Day of Year')
    plt.ylabel('mm day$^{-1}$', labelpad=labelpad)
    ax1 = plt.gca()
    ax2 = plt.twinx()
    plt.sca(ax2)  # draw the index series on the twin axis
    plt.plot(days, ind['tseries'], 'r', alpha=0.6, linewidth=2, label='CMFC')
    atm.fmt_axlabels('y', 'mm', color='r', alpha=0.6)
    plt.gca().set_ylabel('mm', labelpad=labelpad)
    if legend:
        atm.legend_2ax(ax1, ax2, **legend_kw)
    return ax1, ax2
Example 19
Project: statistical-learning-methods-note   Author: ysh329   File: Perceptron.py    Apache License 2.0 5 votes vote down vote up
def plotChart(self, costList, misRateList, saveFigPath):
    '''
    Plot how the cost and the misclassification rate change over epochs.

    :param costList: cost-function value for each training epoch
    :param misRateList: misclassification rate for each training epoch
    :param saveFigPath: path the figure is saved to before it is shown
    :return: None
    '''
    # Import the plotting library
    import matplotlib.pyplot as plt
    # Create a new figure
    plt.figure('Perceptron Cost and Mis-classification Rate',figsize=(8, 9))
    # Two subplots stacked vertically
    ax1 = plt.subplot(211)
    ax2 = plt.subplot(212)

    # Subplot 1: cost per epoch (epochs are numbered from 1).
    # ``range`` replaces the Python-2-only ``xrange``.
    plt.sca(ax1)
    plt.plot(range(1, len(costList)+1), costList, '--b*')
    plt.xlabel('Epoch No.')
    plt.ylabel('Cost')
    plt.title('Plot of Cost Function')
    plt.grid()
    # legend() expects a *list* of labels; a bare string would be split
    # into characters and the line labelled just "C".
    ax1.legend([u"Cost"], loc='best')

    # Subplot 2: misclassification rate per epoch
    plt.sca(ax2)
    plt.plot(range(1, len(misRateList)+1), misRateList, '-r*')
    plt.xlabel('Epoch No.')
    plt.ylabel('Mis-classification Rate')
    plt.title('Plot of Mis-classification Rate')
    plt.grid()
    ax2.legend([u'Mis-classification Rate'], loc='best')

    # Save, then display the figure.
    # Save first: doing it after show() would be equivalent to creating a
    # new blank figure and saving that instead.
    plt.savefig(saveFigPath)
    plt.show()

################################### PART3 TEST ########################################
# Example
Example 20
Project: statistical-learning-methods-note   Author: ysh329   File: Dual-form_Perceptron.py    Apache License 2.0 5 votes vote down vote up
def plotChart(self, costList, misRateList, saveFigPath):
    '''
    Plot how the cost and the misclassification rate change over epochs.

    :param costList: cost-function value for each training epoch
    :param misRateList: misclassification rate for each training epoch
    :param saveFigPath: path the figure is saved to before it is shown
    :return: None
    '''
    # Import the plotting library
    import matplotlib.pyplot as plt
    # Create a new figure
    plt.figure('Perceptron Cost and Mis-classification Rate', figsize=(8, 9))
    # Two subplots stacked vertically
    ax1 = plt.subplot(211)
    ax2 = plt.subplot(212)

    # Subplot 1: cost per epoch (epochs are numbered from 1).
    # ``range`` replaces the Python-2-only ``xrange``.
    plt.sca(ax1)
    plt.plot(range(1, len(costList) + 1), costList, '--b*')
    plt.xlabel('Epoch No.')
    plt.ylabel('Cost')
    plt.title('Plot of Cost Function')
    plt.grid()
    # legend() expects a *list* of labels; a bare string would be split
    # into characters and the line labelled just "C".
    ax1.legend([u"Cost"], loc='best')

    # Subplot 2: misclassification rate per epoch
    plt.sca(ax2)
    plt.plot(range(1, len(misRateList) + 1), misRateList, '-r*')
    plt.xlabel('Epoch No.')
    plt.ylabel('Mis-classification Rate')
    plt.title('Plot of Mis-classification Rate')
    plt.grid()
    ax2.legend([u'Mis-classification Rate'], loc='best')

    # Save, then display the figure.
    # Save first: doing it after show() would be equivalent to creating a
    # new blank figure and saving that instead.
    plt.savefig(saveFigPath)
    plt.show()

################################### PART3 TEST ########################################
# Example
Example 21
Project: safe-exploration   Author: befelix   File: utils_visualization.py    MIT License 5 votes vote down vote up
def plot_ellipsoid_2D(p, q, ax, n_points=100, color="r"):
    """ Plot an ellipsoid in 2D

    The boundary is drawn as ``p + r @ [cos t, sin t]`` for ``n_points``
    values of ``t`` in ``[0, 2*pi]``, where ``r = nLa.cholesky(q).T``.

    TODO: Untested!
    NOTE(review): whether ``r`` maps the unit circle onto the ellipse of
    ``q`` itself depends on which triangle ``nLa.cholesky`` returns
    (numpy.linalg: lower, so ``r`` would be upper; scipy.linalg: upper by
    default). Confirm which module ``nLa`` is.

    Parameters
    ----------
    p: 2x1 array[float]
        Center of the ellipsoid. It must broadcast against the (2, n_points)
        boundary array, so it should be a column vector.
    q: 2x2 array[float]
        Shape matrix of the ellipsoid. It must be 2x2 (its factor is
        multiplied with the 2-row unit-circle array) and positive definite
        (otherwise the Cholesky factorisation raises).
    ax: matplotlib.Axes object
        Ax on which to plot the ellipsoid
    n_points: int
        Number of boundary points
    color: str
        Passed positionally to ``ax.plot`` (i.e. as its format string)

    Returns
    -------
    ax: matplotlib.Axes object
        The Ax containing the ellipsoid
    handle: matplotlib line
        The single line created by ``ax.plot``
    """
    plt.sca(ax)
    r = nLa.cholesky(q).T  # checks spd inside the function
    t = np.linspace(0, 2 * np.pi, n_points)
    z = [np.cos(t), np.sin(t)]
    ellipse = np.dot(r, z) + p
    handle, = ax.plot(ellipse[0, :], ellipse[1, :], color)

    return ax, handle
Example 22
Project: safe-exploration   Author: befelix   File: environments.py    MIT License 5 votes vote down vote up
def plot_state(self, ax, x=None, color="b", normalize=True):
    """ Mark a state on ``ax`` with a single circle marker.

    Parameters:
    -----------
    ax: Axes Object
        The axes to plot the state on
    x: array_like[float], optional
        A state vector of the dynamics; its first two entries are plotted.
        Defaults to ``self.current_state``.
    color: str, optional
        Marker colour
    normalize: bool, optional
        Only used when ``x`` is None: pass ``self.current_state`` through
        ``self.normalize`` first.

    Returns
    -------
    ax: Axes Object
        The axes with the state plotted
    """
    if x is not None:
        state = x
    else:
        state = self.current_state
        if normalize:
            state, _ = self.normalize(state)
    assert len(state) == self.n_s, "x needs to have the same number of states as the dynamics"
    plt.sca(ax)
    ax.plot(state[0], state[1], color=color, marker="o", mew=1.2)
    return ax
Example 23
Project: safe-exploration   Author: befelix   File: environments.py    MIT License 5 votes vote down vote up
def plot_ellipsoid_trajectory(self, p, q, vis_safety_bounds=True, ax=None,
                              color="r"):
    """ Plot the reachability ellipsoids given in observation space

    TODO: Need more principled way to transform ellipsoid to internal states

    Each center ``p[i]`` is reshaped to a column and offset by
    ``self.p_origin``; each ``q[i]`` is reshaped to ``(self.n_s, self.n_s)``
    and drawn with ``plot_ellipsoid_2D``.

    Parameters
    ----------
    p: n x n_s array[float]
        The ellipsoid centers of the trajectory
    q: n x n_s x n_s  ndarray[float]
        The shape matrices of the trajectory
    vis_safety_bounds: bool, optional
        Visualize the safety bounds of the system
    ax: Axes, optional
        Axes to draw on. If omitted, a new figure is created and shown.
    color: str, optional
        Passed through to ``plot_ellipsoid_2D``

    Returns
    -------
    ax, handles
        The axes and one line handle per ellipsoid.
    """
    created_here = ax is None
    if created_here:
        ax = plt.figure().add_subplot(111)

    plt.sca(ax)
    n, n_s = np.shape(p)
    handles = []
    for i in range(n):
        center = cas_reshape(p[i, :], (n_s, 1)) + self.p_origin.reshape((n_s, 1))
        shape = cas_reshape(q[i, :], (self.n_s, self.n_s))
        ax, handle = plot_ellipsoid_2D(center, shape, ax, color=color)
        handles.append(handle)

    if vis_safety_bounds:
        ax = self.plot_safety_bounds(ax)

    if created_here:
        plt.show()

    return ax, handles
Example 24
Project: safe-exploration   Author: befelix   File: utils_visualization.py    MIT License 5 votes vote down vote up
def plot_ellipsoid_2D(p, q, ax, n_points=100, color="r"):
    """ Plot an ellipsoid in 2D

    The boundary is drawn as ``p + r @ [cos t, sin t]`` for ``n_points``
    values of ``t`` in ``[0, 2*pi]``, where ``r = nLa.cholesky(q).T``.

    TODO: Untested!
    NOTE(review): whether ``r`` maps the unit circle onto the ellipse of
    ``q`` itself depends on which triangle ``nLa.cholesky`` returns
    (numpy.linalg: lower, so ``r`` would be upper; scipy.linalg: upper by
    default). Confirm which module ``nLa`` is.

    Parameters
    ----------
    p: 2x1 array[float]
        Center of the ellipsoid. It must broadcast against the (2, n_points)
        boundary array, so it should be a column vector.
    q: 2x2 array[float]
        Shape matrix of the ellipsoid. It must be 2x2 (its factor is
        multiplied with the 2-row unit-circle array) and positive definite
        (otherwise the Cholesky factorisation raises).
    ax: matplotlib.Axes object
        Ax on which to plot the ellipsoid
    n_points: int
        Number of boundary points
    color: str
        Passed positionally to ``ax.plot`` (i.e. as its format string)

    Returns
    -------
    ax: matplotlib.Axes object
        The Ax containing the ellipsoid
    handle: matplotlib line
        The single line created by ``ax.plot``
    """
    plt.sca(ax)
    r = nLa.cholesky(q).T  # checks spd inside the function
    t = np.linspace(0, 2 * np.pi, n_points)
    z = [np.cos(t), np.sin(t)]
    ellipse = np.dot(r, z) + p
    handle, = ax.plot(ellipse[0, :], ellipse[1, :], color)

    return ax, handle
Example 25
Project: mlcv-tutorial   Author: johny-c   File: clustering.py    GNU General Public License v3.0 5 votes vote down vote up
def scatter_iteration(ax, data, title='', colors='b', marker='o',
                      pause_time=0.3):
    """Scatter points as estimated in a single iteration of an algorithm.

    Only the first two columns of ``data`` are drawn. After drawing, the
    figure is redrawn and execution pauses so the update is visible.

    Parameters
    ----------
    ax : matplotlib.axes.Axes instance
        The axes to draw in.

    data : array, shape (n_samples, n_features)
        The data to scatter plot.

    title : str (optional)
        Title of the plot (shown in bold).

    colors : array-like, shape (n_samples,) or str (optional)
        RGBA color per sample or strings or single string.

    marker : str, (optional)
        The representation of the points.

    pause_time : float (optional)
        How long to wait so the drawing can be rendered and observed.

    """
    xs, ys = data[:, 0], data[:, 1]
    plt.sca(ax)
    plt.scatter(xs, ys, c=colors, marker=marker, lw=0, s=50)
    plt.title('{}'.format(title), fontweight='bold')
    # Flush the drawing and give the GUI event loop time to render it.
    plt.draw()
    plt.pause(pause_time)


########################################################################### 
Example 26
Project: igscout   Author: Immunotools   File: igscout.py    GNU General Public License v3.0 5 votes vote down vote up
def _OutputMaxFreqs(self, ax):
    """Plot the max and second-max frequency series on ``ax``.

    Each series is drawn against its own index positions. The x ticks are
    labelled with ``self.nucls`` and the y range is fixed to [-0.05, 1.15].
    """
    series = (
        (self.max_freq_list, 'blue', 'Max frequency'),
        (self.max2_freq_list, 'red', 'Second max frequency'),
    )
    for values, colour, name in series:
        ax.plot(range(len(values)), values, marker='o', linestyle='-', color=colour, label=name)
    plt.sca(ax)
    plt.ylim(-0.05, 1.15)
    plt.legend(loc='upper center', ncol=2)
    plt.xticks(range(len(self.max_freq_list)), self.nucls, fontsize=10)
    plt.ylabel('Frequency', fontsize=12)
Example 27
Project: igscout   Author: Immunotools   File: igscout.py    GNU General Public License v3.0 5 votes vote down vote up
def _OutputConservation(self, ax):
    """Plot ``self.conservations`` on ``ax`` with x ticks labelled by ``self.nucls``.

    The y range is fixed to [-0.05, 2.05].
    """
    positions = range(len(self.max_freq_list))
    ax.plot(positions, self.conservations, marker='o', linestyle='-')
    plt.sca(ax)
    plt.ylim(-0.05, 2.05)
    plt.xticks(positions, self.nucls, fontsize=10)
    plt.ylabel('Conservation', fontsize=12)
Example 28
Project: igscout   Author: Immunotools   File: igscout.py    GNU General Public License v3.0 5 votes vote down vote up
def _OutputNuclDist(self, ax):
    """Draw the A/C/G/T bar layers on ``ax``, one bar per position.

    Layers are drawn in the order A, C, G, T, each on top of the previous
    one. Presumably ``a_c_g_t``, ``c_g_t``, ``g_t`` and ``t`` are cumulative
    sums, so the overlay reads as a stacked bar -- TODO confirm.
    """
    positions = range(len(self.a_c_g_t))
    layers = (
        (self.a_c_g_t, 'orange', 'A'),
        (self.c_g_t, 'green', 'C'),
        (self.g_t, 'blue', 'G'),
        (self.t, 'red', 'T'),
    )
    for heights, colour, base in layers:
        ax.bar(positions, heights, color=colour, label=base)
    plt.sca(ax)
    plt.legend(loc='upper center', ncol=4)
    plt.xticks(range(len(self.nucls)), self.nucls, fontsize=10, rotation=0)
    plt.ylim(-0.05, 1.1)
Example 29
Project: psychrometric-chart-makeover   Author: buds-lab   File: axisgrid.py    MIT License 5 votes vote down vote up
def facet_axis(self, row_i, col_j):
    """Make the axis identified by these indices active and return it.

    With column wrapping, ``col_j`` is used as a flat index into
    ``self.axes`` (``row_i`` is ignored). Otherwise ``(row_i, col_j)``
    indexes the 2-D axes grid.
    """
    if self._col_wrap is None:
        ax = self.axes[row_i, col_j]
    else:
        ax = self.axes.flat[col_j]
    plt.sca(ax)  # make it the pyplot "current" axes
    return ax
Example 30
Project: psychrometric-chart-makeover   Author: buds-lab   File: axisgrid.py    MIT License 5 votes vote down vote up
def plot_marginals(self, func, **kwargs):
    """Draw univariate plots for `x` and `y` separately.

    ``func`` is called once on the x margin axes with ``self.x``
    (``vertical=False``) and once on the y margin axes with ``self.y``
    (``vertical=True``); each time its axes is made current first.

    Parameters
    ----------
    func : plotting callable
        This must take a 1d array of data as the first positional
        argument, it must plot on the "current" axes, and it must
        accept a "vertical" keyword argument to orient the measure
        dimension of the plot vertically.
    kwargs : key, value mappings
        Keyword argument are passed to the plotting function.

    Returns
    -------
    self : JointGrid instance
        Returns `self`.

    """
    margins = (
        (self.ax_marg_x, self.x, False),
        (self.ax_marg_y, self.y, True),
    )
    for margin_ax, values, vertical in margins:
        kwargs["vertical"] = vertical
        plt.sca(margin_ax)
        func(values, **kwargs)

    return self
Example 31
Project: psychrometric-chart-makeover   Author: buds-lab   File: test_axes.py    MIT License 5 votes vote down vote up
def test_pyplot_axes():
    # test focusing of Axes in other Figure: sca() on an axes of a
    # non-current figure must also make that figure current.
    fig1, ax1 = plt.subplots()
    fig2, ax2 = plt.subplots()
    try:
        plt.sca(ax1)
        assert ax1 is plt.gca()
        assert fig1 is plt.gcf()
    finally:
        # Close both figures even when an assertion fails, so open figures
        # do not leak into later tests.
        plt.close(fig1)
        plt.close(fig2)
Example 32
Project: SignLanguage_ML   Author: mareep-raljodid   File: test_axes.py    MIT License 5 votes vote down vote up
def test_pyplot_axes():
    # test focusing of Axes in other Figure: sca() on an axes of a
    # non-current figure must also make that figure current.
    fig1, ax1 = plt.subplots()
    fig2, ax2 = plt.subplots()
    try:
        plt.sca(ax1)
        assert ax1 is plt.gca()
        assert fig1 is plt.gcf()
    finally:
        # Close both figures even when an assertion fails, so open figures
        # do not leak into later tests.
        plt.close(fig1)
        plt.close(fig2)
Example 33
Project: NeuralTuringMachine   Author: MarkPKCollier   File: produce_heat_maps.py    GNU Lesser General Public License v3.0 5 votes vote down vote up
def plot_figures(figures, nrows=1, ncols=1, width_ratios=None):
    """Show ``(title, image)`` pairs as grey-scale images in a subplot grid.

    Images fill the grid in row-major order. Layout tweaks depend on the
    module-level ``TASK``:
    - 'Associative Recall': only the first subplot gets the time x label,
      and subplots 1 and 2 get x ticks at 0, 1, 2.
    - 'Copy': subplot 1 has its y ticks removed.
    Other tasks label every subplot's x axis.
    """
    # squeeze=False + ravel gives a flat array of Axes for every grid shape,
    # so positional indexing works for 1x1 and multi-row grids alike
    # (previously axeslist[1] picked a whole row in 2-D grids).
    fig, axeslist = plt.subplots(ncols=ncols, nrows=nrows,
                                 gridspec_kw={'width_ratios': width_ratios},
                                 squeeze=False)
    flat_axes = axeslist.ravel()

    # ``image`` (formerly ``fig``) no longer shadows the Figure above.
    for ind, (title, image) in enumerate(figures):
        ax = flat_axes[ind]
        ax.imshow(image, cmap='gray', interpolation='nearest')
        ax.set_title(title)
        if TASK != 'Associative Recall' or ind == 0:
            ax.set_xlabel('Time ------->')

    if TASK == 'Associative Recall':
        plt.sca(flat_axes[1])
        plt.xticks([0, 1, 2])
        plt.sca(flat_axes[2])
        plt.xticks([0, 1, 2])

    if TASK == 'Copy':
        plt.sca(flat_axes[1])
        plt.yticks([])

    plt.tight_layout()
Example 34
Project: brainpipe   Author: EtienneCmb   File: cmon_plt.py    GNU General Public License v3.0 5 votes vote down vote up
def _BorderPlot(time, x, color, kind, alpha, legend, linewidth, axes):
    """Plot the mean of ``x`` over time with a shaded spread band.

    ``x`` has shape ``(npts, dev)``. Rows follow ``time`` (the row means are
    plotted against it) and the ``dev`` columns are averaged. The band is
    mean +/- standard deviation across columns, or mean +/- standard error of
    the mean when ``kind == 'sem'``. It uses the line's colour, so it matches
    even when ``color`` lets the colour cycle choose.

    Draws on ``axes`` (made current), or on the current axes if None.
    """
    npts, dev = x.shape
    # Get the deviation/sem across the averaged columns:
    xStd = np.std(x, axis=1)
    # ``==`` rather than ``is``: identity comparison of strings is not
    # guaranteed to hold for equal values.
    if kind == 'sem':
        # SEM over the ``dev`` averaged columns (was wrongly npts, the number
        # of time points). std(ddof=0) / sqrt(n - 1) == std(ddof=1) / sqrt(n).
        xStd = xStd/np.sqrt(dev-1)
    xMean = np.mean(x, 1)
    xLow, xHigh = xMean-xStd, xMean+xStd

    # Plot :
    if axes is None:
        axes = plt.gca()
    plt.sca(axes)
    ax = plt.plot(time, xMean, color=color, label=legend, linewidth=linewidth)
    plt.fill_between(time, xLow, xHigh, alpha=alpha, color=ax[0].get_color())
Example 35
Project: pyebm   Author: 88vikram   File: visualize.py    GNU General Public License v3.0 5 votes vote down vote up
def Ordering(labels, pi0_all,pi0_mean, plotorder):
    """Show the positional variance diagram of the central ordering.

    For each feature number ``i`` and position ``j``, counts how many
    orderings in ``pi0_all`` place feature ``i`` at position ``j``
    (``item.index(i) == j``). The counts are shown as a grey-scale heat map:
    rows are feature names, columns are positions labelled 1..len(labels),
    and the colour range is [0, len(pi0_all)].

    :param labels: feature names, indexed by feature number.
    :param pi0_all: re-iterable collection of orderings (lists of feature
        numbers); must support ``len``.
    :param pi0_mean: the central ordering (feature numbers). When
        ``plotorder`` is true, rows follow it instead of pandas' sorted
        order.
    :param plotorder: whether to order rows by ``pi0_mean``.
    """
    from collections import Counter

    columns = ['Features', 'Event Position', 'Count']
    rows = []
    for i, feature in enumerate(labels):
        # Position of feature i in every ordering, tallied once.
        position_counts = Counter(item.index(i) for item in pi0_all)
        rows.extend([feature, j, position_counts[j]] for j in range(len(labels)))
    # Build the long-form table in one go: DataFrame.append was removed in
    # pandas 2.0, and appending row by row re-copied the frame (quadratic).
    datapivot = pd.DataFrame(rows, columns=columns)
    # pivot's arguments are keyword-only in pandas >= 2.0.
    datapivot = datapivot.pivot(index="Features", columns="Event Position", values="Count")
    if plotorder:
        datapivot = datapivot.reindex([labels[k] for k in pi0_mean])
    xticks = np.arange(len(labels)) + 1
    datapivot = datapivot.astype(float)
    fig, ax = plt.subplots(1,1,figsize=(7, 7))
    sns.heatmap(datapivot, cmap = 'binary', xticklabels=xticks, vmin=0, vmax=len(pi0_all),ax=ax)
    plt.sca(ax)
    plt.title('Positional variance diagram of the central ordering')
    plt.yticks(rotation=0)
    plt.show()
Example 36
Project: visualqc   Author: raamana   File: freesurfer.py    Apache License 2.0 5 votes vote down vote up
def plot_contours_in_slice(self, slice_seg, target_axis):
    """Plots contour around the data in slice (after binarization).

    For each position ``index`` in ``self.unique_labels_display``, the slice
    is binarised as ``slice_seg == index`` and contoured in
    ``self.color_for_label[index]``. Positions with no matching voxels are
    skipped. Returns the contour handles in label order.

    NOTE(review): the mask compares against the *position* ``index``, not
    the label value itself; this is correct only if ``slice_seg`` was
    relabelled to 0..n-1 upstream -- confirm.
    """
    plt.sca(target_axis)
    handles = []
    for index, _ in enumerate(self.unique_labels_display):
        mask = slice_seg == index
        if mask.any():
            handles.append(plt.contour(mask,
                                       levels=[cfg.contour_level, ],
                                       colors=(self.color_for_label[index],),
                                       linewidths=cfg.contour_line_width,
                                       alpha=self.alpha_seg,
                                       zorder=cfg.seg_zorder_freesurfer))
    return handles
Example 37
Project: ICML-2015   Author: Philip-Bachman   File: utils.py    MIT License 5 votes vote down vote up
def plot_scatter(x, y, f_name, x_label=None, y_label=None):
    """
    Plot a scatter plot of ``y`` against ``x`` and save it as PNG to ``f_name``.

    Axis labels default to 'Posterior KLd' (x) and 'Expected Log-likelihood'
    (y). The figure is always closed, even if drawing or saving fails.
    """
    import matplotlib.pyplot as plt
    if x_label is None:
        x_label = 'Posterior KLd'
    if y_label is None:
        y_label = 'Expected Log-likelihood'
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        # Shift the axes 5% right/up and shrink it to 96%, presumably to
        # leave room for the large axis labels.
        box = ax.get_position()
        ax.set_position([box.x0 + (0.05 * box.width), box.y0 + (0.05 * box.height),
                         0.96 * box.width, 0.96 * box.height])
        ax.set_xlabel(x_label, fontsize=22)
        ax.set_ylabel(y_label, fontsize=22)
        # ``ax.hold(True)`` was dropped: holding is the default behaviour and
        # ``Axes.hold`` no longer exists in Matplotlib 3.
        ax.scatter(x, y, s=24, alpha=0.5, c='b', marker='o')
        plt.sca(ax)
        # Keep the automatically chosen tick locations, with a larger font.
        x_locs, _ = plt.xticks()
        plt.xticks(x_locs, fontsize=18)
        y_locs, _ = plt.yticks()
        plt.yticks(y_locs, fontsize=18)
        # ``papertype`` and ``frameon`` were passed only at their defaults;
        # ``frameon`` is no longer accepted by modern Matplotlib's savefig.
        fig.savefig(f_name, dpi=None, facecolor='w', edgecolor='w',
                    orientation='portrait', format='png',
                    transparent=False, bbox_inches=None, pad_inches=0.1)
    finally:
        plt.close(fig)
    return
Example 38
Project: lambda-tensorflow-object-detection   Author: mikylucky   File: test_axes.py    GNU General Public License v3.0 5 votes vote down vote up
def test_pyplot_axes():
    """plt.sca() must make an Axes of another Figure current, with its Figure."""
    # test focusing of Axes in other Figure
    fig1, ax1 = plt.subplots()
    fig2, ax2 = plt.subplots()
    try:
        # Precondition: the most recently created figure/axes are current,
        # so the assertions below really prove that sca() switched them.
        assert ax2 is plt.gca()
        assert fig2 is plt.gcf()
        plt.sca(ax1)
        assert ax1 is plt.gca()
        assert fig1 is plt.gcf()
    finally:
        # Close both figures even when an assertion fails, so they do not
        # leak into later tests.
        plt.close(fig1)
        plt.close(fig2)
Example 39
Project: DIAG-NRE   Author: thunlp   File: utils.py    MIT License 5 votes
def show_word_scores_heatmap(score_tensor_tup, x_ticks, y_ticks, nrows=1, ncols=1, titles=None, figsize=(8, 8), fontsize=14):
    """Draw one heatmap (each with its own colorbar) per score tensor and show it.

    :param score_tensor_tup: a single score tensor or a tuple of them; each
        must provide ``.numpy()`` and ``.size(dim)`` (e.g. a 2-D torch tensor).
    :param x_ticks: tick labels for the columns (dimension 1) of each tensor.
    :param y_ticks: tick labels for the rows (dimension 0) of each tensor.
    :param nrows: number of subplot rows.
    :param ncols: number of subplot columns.
    :param titles: optional per-subplot titles, indexed like the tensors.
    :param figsize: figure size in inches.
    :param fontsize: tick-label font size; titles use ``fontsize + 2``.
    """
    def colorbar(mappable):
        # Attach a thin colorbar axis to the right of the image's own axes.
        ax = mappable.axes
        fig = ax.figure
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="1%", pad=0.1)
        return fig.colorbar(mappable, cax=cax)

    if not isinstance(score_tensor_tup, tuple):
        score_tensor_tup = (score_tensor_tup, )

    # squeeze=False always yields a 2-D array of Axes. Flattening it makes the
    # loops below work for the default single subplot (a bare Axes is not
    # iterable) and for 2-D grids (where iteration would yield whole rows).
    _, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize,
                          squeeze=False)
    axs = axs.ravel()

    # zip stops at the shorter sequence: surplus axes stay empty instead of
    # raising IndexError, and surplus tensors are ignored as before.
    for idx, (ax, score_tensor) in enumerate(zip(axs, score_tensor_tup)):
        img = ax.matshow(score_tensor.numpy())
        plt.sca(ax)
        plt.xticks(range(score_tensor.size(1)), x_ticks, fontsize=fontsize)
        plt.yticks(range(score_tensor.size(0)), y_ticks, fontsize=fontsize)
        if titles is not None:
            plt.title(titles[idx], fontsize=fontsize + 2)
        colorbar(img)

    for ax in axs:
        ax.set_aspect('auto')
    plt.tight_layout(h_pad=1)

    plt.show()
Example 40
Project: atmos-tools   Author: jenfly   File: utils.py    MIT License 5 votes
def ax_lims_ticks(xlims=None, xticks=None, ylims=None, yticks=None, ax=None):
    """Apply optional tick positions and limits to the x and y axes.

    If ``ax`` is given it first becomes the current pyplot axes (``plt.sca``);
    otherwise the current axes are modified. Arguments left as None are not
    touched. Settings are applied in this order: x ticks, x limits, y ticks,
    y limits.
    """
    if ax is not None:
        plt.sca(ax)
    settings = (
        (plt.xticks, xticks),
        (plt.xlim, xlims),
        (plt.yticks, yticks),
        (plt.ylim, ylims),
    )
    for apply_setting, value in settings:
        if value is not None:
            apply_setting(value)
    return None


# ---------------------------------------------------------------------- 
Example 41
Project: atmos-tools   Author: jenfly   File: utils.py    MIT License 5 votes
def subplot(self, row, col):
        """Select grid cell (row, col): remember it and make its axes current."""
        self.row = row
        self.col = col
        self.ax = self.axes[row, col]
        plt.sca(self.ax)
Example 42
Project: SPTM   Author: nsavinov   File: process_log.py    MIT License 5 votes
def add_to_plots(plots, input):
  """Add one method's success-rate-vs-steps curve to an environment's plot.

  Args:
    plots: dict mapping environment -> (figure, axes). Modified in place:
      the first time an environment is seen, a new figure/axes pair is
      created with plt.subplots() and stored under that key.
    input: tuple (environment, mode, result). `result` is a sequence of
      3-tuples; the first element is a success flag (summed to compute the
      success rate), the second is the episode length in steps, and the
      third is not used here.

  Styling comes from module-level tables (METHOD_TO_COLOR, METHOD_TO_LEGEND,
  ENVIRONMENT_TO_PAPER_TITLE and the font/linewidth constants). Nothing is
  returned; the curve is drawn through pyplot's current-axes state.

  NOTE(review): an empty `result` raises ZeroDivisionError on the
  success_rate line, and the parameter name `input` shadows the builtin.
  """
  # Failed episodes are placed one step past the navigation step budget.
  FAIL_STEPS = MAX_NUMBER_OF_STEPS_NAVIGATION + 1
  environment, mode, result = input
  steps = []
  success_rate = float(sum([value for value, _, _ in result])) / float(len(result))
  print environment, mode, success_rate
  for success, length, _ in result:
    if success:
      steps.append(length)
    else:
      steps.append(FAIL_STEPS)
  steps.sort()
  # Empirical CDF: because `steps` is sorted, (index + 1) / len(steps) is the
  # fraction of episodes that finished within `step` steps. For repeated step
  # values the last (largest) index overwrites earlier ones. All failures map
  # to FAIL_STEPS, whose value is set to the overall success rate instead.
  cumulative = {}
  for index, step in enumerate(steps):
    if step < FAIL_STEPS:
      cumulative[step] = float(index + 1) / float(len(steps))
    else:
      cumulative[step] = success_rate
  if environment in plots:
    # Reuse this environment's axes; sca also makes their figure current.
    figure, axes = plots[environment]
    plt.sca(axes)
  else:
    figure, axes = plt.subplots()
    plots[environment] = figure, axes
  sorted_cumulative = sorted(cumulative.items())
  # print sorted_cumulative
  # Anchor the curve at (0, 0) and close it at (FAIL_STEPS, success_rate);
  # all y values are then scaled by SUCCESS_SCALING.
  x = [0] + [value for value, _ in sorted_cumulative] + [FAIL_STEPS]
  y = [0] + [value for _, value in sorted_cumulative] + [success_rate]
  y = [SUCCESS_SCALING * value for value in y]
  plt.plot(x, y, METHOD_TO_COLOR[mode], linewidth=LINEWIDTH, label=METHOD_TO_LEGEND[mode])
  plt.title(ENVIRONMENT_TO_PAPER_TITLE[environment], fontsize=TITLE_FONT)
  plt.xlabel('Steps', fontsize=AXIS_LABEL_FONT)
  # Only these panels get a y-axis label (presumably the leftmost panels of
  # the paper figure -- TODO confirm).
  if ENVIRONMENT_TO_PAPER_TITLE[environment] in ['Test-1', 'Test-5', 'Val-1']:
    plt.ylabel('Success rate', fontsize=AXIS_LABEL_FONT)
  plt.axis([0, FAIL_STEPS, 0, 1.0 * SUCCESS_SCALING])
  plt.grid(linestyle='dotted')
  print ENVIRONMENT_TO_PAPER_TITLE[environment]
  # Only the 'Val-3' panel carries a legend; its sample lines are thickened
  # to LEGEND_LINE_WIDTH.
  if ENVIRONMENT_TO_PAPER_TITLE[environment] in ['Val-3']:
    leg = plt.legend(shadow=True, fontsize=LEGEND_FONT, loc='upper left', fancybox=True, framealpha=1.0)
    for legobj in leg.legendHandles:
      legobj.set_linewidth(LEGEND_LINE_WIDTH) 
Example 43
Project: matplotlib-hep   Author: ibab   File: __init__.py    MIT License 5 votes
def plot_pull(data, func):
    """Plot binned ``data`` with the model ``func`` and a pull panel below.

    The upper axes show the binned data (from ``histpoints``) together with
    the model curve ``norm * func(x)``. The lower axes show per-bin pulls
    ``(data - model) / error`` with unit error bars and a zero line.

    :param data: unbinned sample passed to ``histpoints``.
    :param func: vectorised model function; scaled by the ``norm`` returned
        from ``histpoints``.
    :return: ``(ax, bx)``, the main and pull axes; ``ax`` is left current.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Split the figure: main panel ``ax`` on top, pull panel ``bx`` below.
    ax, bx = make_split(0.8)

    # Main panel: binned data plus the model curve across the visible x range.
    plt.sca(ax)
    x, y, norm = histpoints(data)
    lower, upper = ax.get_xlim()
    xs = np.linspace(lower, upper, 200)
    plt.plot(xs, norm * func(xs), 'b-')

    # Pull panel.
    plt.sca(bx)
    resid = y[1] - norm * func(x)
    # Pick the error on the side facing the model: y[0] where the data lie on
    # or above the model, y[2] where below (presumably the lower/upper data
    # errors from histpoints -- confirm).
    above = resid >= 0
    below = resid < 0
    err = np.zeros_like(resid)
    err[above] = y[0][above]
    err[below] = y[2][below]

    # NOTE(review): bins with zero error yield inf/nan pulls.
    pull = resid / err

    plt.errorbar(x, pull, yerr=1, color='k', fmt='o')
    plt.ylim(-5, 5)
    plt.axhline(0, color='b')

    # Leave the main axes current for the caller.
    plt.sca(ax)

    return ax, bx
Example 44
Project: miccai-2016-surgical-activity-rec   Author: rdipietro   File: standardize_jigsaws.py    Apache License 2.0 4 votes
def main():
    """ Create a standardized data file from raw data.

    Every JIGSAWS trial is loaded with ``load_kinematics_and_new_labels`` and
    downsampled (30 Hz -> 5 Hz). A per-trial label visualization is saved to
    ``<data_dir>/standardized_data_labels.png``, and all data and metadata are
    pickled to ``<data_dir>/<data_filename>``.
    """

    args = define_and_process_args()

    print('Standardizing JIGSAWS..')
    print()

    print('%d classes:' % NUM_CLASSES)
    print(CLASSES)
    print()

    user_to_trial_names = {user: [get_trial_name(user, trial) for trial in trials]
                           for user, trials in USER_TO_TRIALS.items()}
    print('Users and corresponding trial names:')
    for user in ALL_USERS:
        print(user, '   ', user_to_trial_names[user])
    print()

    all_trial_names = sorted(itertools.chain(*user_to_trial_names.values()))
    print('All trial names, sorted:')
    print(all_trial_names)
    print()

    # Original data is at 30 Hz; downsampling by a factor of 6 gives 5 Hz.
    all_data = {trial_name: downsample(
                    load_kinematics_and_new_labels(args.data_dir, trial_name),
                    factor=6)
                for trial_name in all_trial_names}
    print('Downsampled to 5 Hz.')
    print()

    # squeeze=False keeps ax_list a 2-D array even for a single trial, so the
    # zip below always iterates over Axes objects.
    fig, ax_list = plt.subplots(nrows=len(all_data), ncols=1, sharex=True,
                                figsize=(15, 50), squeeze=False)
    for ax, (trial_name, data_mat) in zip(ax_list[:, 0], sorted(all_data.items())):
        plt.sca(ax)
        # Only the last column is visualized (presumably the label sequence).
        data.plot_label_seq(data_mat[:, -1:], NUM_CLASSES, 0)
        plt.title(trial_name)
    # Use the explicit figure handle rather than pyplot's "current" figure.
    fig.tight_layout()
    vis_path = os.path.join(args.data_dir, 'standardized_data_labels.png')
    fig.savefig(vis_path)
    plt.close(fig)
    print('Saved label visualization to %s.' % vis_path)
    print()

    export_dict = dict(
        dataset_name=DATASET_NAME, classes=CLASSES, num_classes=NUM_CLASSES,
        col_names=STANDARDIZED_COL_NAMES, all_users=ALL_USERS,
        user_to_trial_names=user_to_trial_names,
        all_trial_names=all_trial_names, all_data=all_data)
    standardized_data_path = os.path.join(args.data_dir, args.data_filename)
    # Pickles are binary: open in 'wb' so no newline translation can corrupt
    # the file and binary pickle protocols work as well.
    with open(standardized_data_path, 'wb') as f:
        cPickle.dump(export_dict, f)
    print('Saved standardized data file %s.' % standardized_data_path)
    print()
Example 45
Project: DSMnet   Author: wyf2017   File: stereo.py    Apache License 2.0 4 votes
def start(self):
        """Run the training loop, or a single validation pass in 'test' mode.

        In 'test' mode this only calls ``self.validate()``. Otherwise it:

        * resumes the loss/metric history from ``<dirpath>/loss.pkl`` if it
          exists;
        * trains from ``self.epoch`` up to ``args.epochs``;
        * on every ``val_freq``-th epoch (and the final one), validates, saves
          a checkpoint and the metric history, and redraws a three-panel
          Loss / EPE / D1 plot (train vs. val) to
          ``check_<mode>_<dataset>_<net>_<loss_name>.png``.
        """
        args = self.args
        if args.mode == 'test':
            self.validate()
            return

        losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = [], [], [], [], [], [], []
        path_val = os.path.join(self.dirpath, "loss.pkl")
        if os.path.exists(path_val):
            # Resume the metric history saved by an earlier run.
            state_val = torch.load(path_val)
            losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = state_val

        # Start training the model. Keep explicit handles to the progress figure
        # and its three panels, so redrawing never depends on whichever figure
        # pyplot currently considers "current".
        fig = plt.figure(figsize=(18, 5))
        ax_loss = fig.add_subplot(1, 3, 1)
        ax_epe = fig.add_subplot(1, 3, 2)
        ax_d1 = fig.add_subplot(1, 3, 3)
        plot_path = "check_%s_%s_%s_%s.png" % (args.mode, args.dataset, args.net, args.loss_name)

        def draw_curves(ax, ylabel, train_vals, val_vals):
            # Redraw one panel: the train metric per epoch, and the val metric
            # at the epochs where validation ran.
            ax.cla()
            ax.set_xlabel("epoch")
            ax.set_ylabel(ylabel)
            ax.plot(np.array(train_vals), label='train')
            ax.plot(np.array(epochs_val), np.array(val_vals), label='val')
            ax.legend()

        time_start = time.time()
        epoch0 = self.epoch
        for epoch in range(epoch0, args.epochs):
            self.epoch = epoch
            # Custom learning-rate schedule (the lr_adjust method on this object).
            self.lr_adjust(self.optim, args.lr_epoch0, args.lr_stride, args.lr, epoch)
            self.lossfun.Weight_Adjust_levels(epoch)
            msg = 'lr: %.6f | weight of levels: %s' % (self.optim.param_groups[0]['lr'], str(self.lossfun.weight_levels))
            logging.info(msg)

            # train for one epoch
            mloss, mEPE, mD1 = self.train()
            losses.append(mloss)
            EPEs.append(mEPE)
            D1s.append(mD1)

            if (epoch % args.val_freq == 0) or (epoch == args.epochs - 1):
                # evaluate on validation set
                mloss_val, mEPE_val, mD1_val = self.validate()
                epochs_val.append(epoch)
                losses_val.append(mloss_val)
                EPEs_val.append(mEPE_val)
                D1s_val.append(mD1_val)

                # Lower D1 is better: remember the best value and save a checkpoint.
                is_best = mD1_val < self.best_prec
                self.best_prec = min(mD1_val, self.best_prec)
                self.save_checkpoint(epoch, self.best_prec, is_best)
                torch.save([losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val], path_val)

                # Refresh the progress plot.
                draw_curves(ax_loss, "Loss", losses, losses_val)
                draw_curves(ax_epe, "EPE", EPEs, EPEs_val)
                draw_curves(ax_d1, "D1", D1s, D1s_val)
                fig.savefig(plot_path)

            time_curr = (time.time() - time_start)/3600.0
            # Linearly extrapolate the total training time from the epochs done so far.
            time_all = time_curr*(args.epochs - epoch0)/(epoch + 1 - epoch0)
            msg = 'Progress: %.2f | %.2f (hour)\n' % (time_curr, time_all)
            logging.info(msg)
Example 46
Project: kappa   Author: ajkerr0   File: plot.py    MIT License 4 votes
def bondsax(molecule, ax, sites=False, indices=False, faces=False, order=False, 
          atomtypes=False, linewidth=4., size_scale=1.):
    """Draw a 2-D overhead (x-y projection) view of a molecule on ``ax``.

    Bonds are always drawn as black lines behind everything else. Optional
    overlays, drawn in this order when their flag is set:

    * ``sites``     -- atoms as scatter markers coloured by element
    * ``indices``   -- the index of each atom
    * ``atomtypes`` -- the atom type of each atom
    * ``faces``     -- face centres, open/closed atoms, face numbers, normals
    * ``order``     -- bond orders at bond midpoints

    The axes get equal aspect and ``plt.show()`` is called at the end.
    """
    plt.sca(ax)

    positions = molecule.posList
    n_atoms = len(molecule)

    # Bonds: one black segment per atom pair.
    for first, second in molecule.bondList:
        xs = [positions[first][0], positions[second][0]]
        ys = [positions[first][1], positions[second][1]]
        plt.plot(xs, ys, color='k', zorder=-1, linewidth=linewidth)

    if sites:
        # RGB colour per atom, looked up from its atomic number.
        atom_rgb = np.zeros([n_atoms, 3])
        for atom in range(n_atoms):
            color_name = atomColors[molecule.zList[atom]]
            atom_rgb[atom] = colors.hex2color(colors.cnames[color_name])
        marker_sizes = 1.5*radList[molecule.zList]*size_scale
        plt.scatter(positions[:, 0], positions[:, 1], s=marker_sizes,
                    c=atom_rgb, edgecolors='k')

    if indices:
        for atom, pos in enumerate(positions):
            plt.annotate(atom, (pos[0]+.1, pos[1]+.1), color='b', fontsize=10)

    if atomtypes:
        for label, pos in zip(molecule.atomtypes, positions):
            plt.annotate(label, (pos[0]-.5, pos[1]-.5), color='b', fontsize=10)

    if faces:
        for face_number, face in enumerate(molecule.faces):
            open_atoms = [atom for atom in face.atoms if atom not in face.closed]
            centre = face.pos
            normal = face.norm
            plt.plot(centre[0], centre[1], 'rx', markersize=15., zorder=-2)
            plt.scatter(positions[open_atoms][:, 0], positions[open_atoms][:, 1],
                        s=75., c='red')
            plt.scatter(positions[face.closed][:, 0], positions[face.closed][:, 1],
                        s=40, c='purple')
            label_xy = (centre[0]-.35*normal[0], centre[1]-.35*normal[1])
            plt.annotate(face_number, label_xy, color='r', fontsize=20)
            # Only draw the normal arrow when it has an in-plane component.
            if np.linalg.norm(normal[:2]) > 0.0001:
                plt.quiver(centre[0]+.5*normal[0], centre[1]+.5*normal[1],
                           5.*normal[0], 5.*normal[1],
                           color='r', headwidth=1, units='width', width=5e-3,
                           headlength=2.5)

    if order:
        for bond_number, bond_order in enumerate(molecule.bondorder):
            first, second = molecule.bondList[bond_number]
            midpoint = (positions[first]+positions[second])/2.
            plt.annotate(bond_order, (midpoint[0], midpoint[1]), color='k',
                         fontsize=20)

    plt.axis('equal')

    plt.show()
Example 47
Project: monsoon-onset   Author: jenfly   File: utils.py    MIT License 4 votes
def plotyy(data1, data2=None, xname='dayrel', data1_styles=None,
           y2_opts=None,
           xlims=None, xticks=None, ylims=None, yticks=None, y2_lims=None,
           xlabel='', y1_label='', y2_label='', legend=False,
           legend_kw=None,
           x0_axvlines=None, grid=True):
    """Plot data1 and data2 together on different y-axes.

    data1 is drawn on the current axes (left y-axis). data2, if not None
    after ``atm.to_dataset``, is drawn on a twin axes (right y-axis).

    Parameters
    ----------
    data1, data2 : passed through ``atm.to_dataset``; each data variable is
        plotted against the coordinate ``xname``.
    data1_styles : dict, optional
        Maps a data1 variable name to either a dict of line keyword
        arguments or a format string.
    y2_opts : dict, optional
        Line keyword arguments for data2. Defaults to
        ``{'color': 'r', 'alpha': 0.6, 'linewidth': 2}``. Every entry except
        'linewidth' is also passed to ``atm.fmt_axlabels`` for the right-hand
        y label. The dict is never modified.
    xlims, xticks, ylims, yticks, y2_lims : optional limits / tick positions.
    xlabel, y1_label, y2_label : axis labels.
    legend : bool
        Draw a legend; with data2 present it is built by ``atm.legend_2ax``.
    legend_kw : dict, optional
        Legend keyword arguments. Defaults to
        ``{'fontsize': 9, 'handlelength': 2.5}``.
    x0_axvlines : iterable, optional
        x positions of vertical black lines on the left axes.
    grid : passed to ``plt.grid`` for the left axes.

    Returns
    -------
    list of two Axes: [left, right]. When data2 is None both entries are the
    same axes.
    """
    # Fresh dicts per call: the previous mutable defaults were shared across
    # calls, and 'linewidth' used to be popped from them in place.
    if y2_opts is None:
        y2_opts = {'color' : 'r', 'alpha' : 0.6, 'linewidth' : 2}
    if legend_kw is None:
        legend_kw = {'fontsize' : 9, 'handlelength' : 2.5}

    data1, data2 = atm.to_dataset(data1), atm.to_dataset(data2)

    for nm in data1.data_vars:
        if data1_styles is None:
            plt.plot(data1[xname], data1[nm], label=nm)
        elif isinstance(data1_styles[nm], dict):
            plt.plot(data1[xname], data1[nm], label=nm, **data1_styles[nm])
        else:
            plt.plot(data1[xname], data1[nm], data1_styles[nm], label=nm)
    atm.ax_lims_ticks(xlims, xticks, ylims, yticks)
    plt.grid(grid)
    if x0_axvlines is not None:
        for x0 in x0_axvlines:
            plt.axvline(x0, color='k')
    plt.xlabel(xlabel)
    plt.ylabel(y1_label)
    axes = [plt.gca()]

    if data2 is not None:
        plt.sca(plt.gca().twinx())
        for nm in data2.data_vars:
            plt.plot(data2[xname], data2[nm], label=nm, **y2_opts)
        if y2_lims is not None:
            plt.ylim(y2_lims)
        # Drop 'linewidth' before styling the label text (presumably because
        # text has no such property). Build a copy so neither the caller's
        # dict nor the defaults are mutated.
        label_opts = {key: val for key, val in y2_opts.items()
                      if key != 'linewidth'}
        atm.fmt_axlabels('y', y2_label, **label_opts)
        atm.ax_lims_ticks(xlims, xticks)
    # NOTE(review): this also runs when data2 is None, so the same axes is
    # returned twice; kept as-is so callers unpacking two axes still work.
    axes = axes + [plt.gca()]

    if legend:
        if data2 is None:
            plt.legend(**legend_kw)
        else:
            atm.legend_2ax(axes[0], axes[1], **legend_kw)

    return axes

# ---------------------------------------------------------------------- 
Example 48
Project: mlcv-tutorial   Author: johny-c   File: clustering.py    GNU General Public License v3.0 4 votes
def draw_ellipses_iteration(ax, data, covs, title='', colors='b', marker='o',
                            ellipses_to_remove=None, pause_time=0.3):
    """Draw ellipses as estimated in a single iteration of an algorithm.

    Parameters
    ----------
    ax : matplotlib.axes.Axes instance
        The axes to draw in.

    data : array, shape (n_components, n_features)
        The points to scatter plot. One ellipse is drawn per row, using
        ``data[k]`` together with ``covs[k]`` (presumably as its centre).

    covs : array, shape (n_components, n_features, n_features)
        The covariance matrices of the components.

    title : str (optional)
        Title of the plot.

    colors : array, shape (n_samples, 4) or str (optional)
        RGBA color per sample or single string.

    marker : str, (optional)
        The representation of the points.

    ellipses_to_remove : list
        List of ellipses from previous iteration(s) to be cleared.

    pause_time : float (optional)
        How long to wait so the drawing can be rendered and observed.

    Returns
    -------
    ellipses : list[matplotlib.patches.Ellipse]
        The drawn ellipses objects.

    """

    if ellipses_to_remove is not None:
        # A plain loop: removal is done only for its side effect.
        for old_ellipse in ellipses_to_remove:
            old_ellipse.remove()

    plt.sca(ax)
    plt.scatter(data[:, 0], data[:, 1], c=colors, marker=marker, lw=0, s=50)
    ellipses = [draw_ellipse(ax, data[k, :], covs[k, :, :])
                for k in range(data.shape[0])]

    plt.title('{}'.format(title), fontweight='bold')
    # Render now and pause briefly so each iteration can be watched.
    plt.draw()
    plt.pause(pause_time)
    return ellipses


########################################################################### 
Example 49
Project: empymod   Author: empymod   File: fdesign.py    Apache License 2.0 4 votes
def _plot_transform_pairs(fCI, r, k, axes, tit):
    r"""Plot the input transform pairs.

    Left panel (``axes[0]``): ``|lhs(k)|`` of every pair on log-log axes. The
    'j2' pair contributes two curves, labelled 'j0' and 'j1'.

    Right panel (``axes[1]``): ``|rhs(r)|`` of every pair. For Hankel
    ('j0', 'j1', 'j2') and Fourier ('cos', 'sin') pairs, the rhs recomputed
    from the lhs with the Key filters is added as dash-dot lines.

    When ``tit == 'fC'`` the rhs values are taken as precomputed arrays and
    no x labels are drawn.
    """
    # Left panel: lhs.
    plt.sca(axes[0])
    plt.title('|' + tit + ' lhs|')
    for pair in fCI:
        if pair.name != 'j2':
            plt.loglog(k, np.abs(pair.lhs(k)), lw=2, label=pair.name)
        else:
            both = pair.lhs(k)
            plt.loglog(k, np.abs(both[0]), lw=2, label='j0')
            plt.loglog(k, np.abs(both[1]), lw=2, label='j1')
    if tit != 'fC':
        plt.xlabel('l')
    plt.legend(loc='best')

    # Right panel: rhs.
    plt.sca(axes[1])
    plt.title('|' + tit + ' rhs|')

    # Analytical rhs of each pair ('fC' pairs store rhs as an array).
    for pair in fCI:
        rhs = pair.rhs if tit == 'fC' else pair.rhs(r)
        plt.loglog(r, np.abs(rhs), lw=2, label=pair.name)

    # rhs obtained by transforming the lhs with the Key filter.
    for pair in fCI:
        if pair.name not in ('j0', 'j1', 'j2', 'cos', 'sin'):
            continue
        # Among these names, the Hankel ones are exactly those starting with 'j'.
        filt = j0j1filt() if pair.name.startswith('j') else sincosfilt()
        kk = filt.base/r[:, None]
        if pair.name == 'j2':
            both = pair.lhs(kk)
            kr = np.dot(both[0], filt.j0)/r + np.dot(both[1], filt.j1)/r**2
        else:
            kr = np.dot(pair.lhs(kk), getattr(filt, pair.name))/r
        plt.loglog(r, np.abs(kr), '-.', lw=2, label=filt.name)

    if tit != 'fC':
        plt.xlabel('r')

    plt.legend(loc='best')
Example 50
Project: psychrometric-chart-makeover   Author: buds-lab   File: axisgrid.py    MIT License 4 votes
def map(self, func, **kwargs):
        """Plot with the same function in every subplot.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and  ``label``.

        Returns
        -------
        self

        """
        kw_color = kwargs.pop("color", None)
        # The grouping depends only on the data and the hue values, so build
        # it once rather than once per subplot.
        hue_grouped = self.data.groupby(self.hue_vals)
        for i, y_var in enumerate(self.y_vars):
            for j, x_var in enumerate(self.x_vars):
                # Bound before the hue loop so the cleanup calls below also
                # work when there are no hue levels.
                ax = self.axes[i, j]
                for k, label_k in enumerate(self.hue_names):

                    # Attempt to get data for this level, allowing for empty
                    try:
                        data_k = hue_grouped.get_group(label_k)
                    except KeyError:
                        # np.float was removed in NumPy 1.24; the builtin
                        # float gives the same float64 dtype.
                        data_k = pd.DataFrame(columns=self.data.columns,
                                              dtype=float)

                    # Re-select for every level so each one draws on this
                    # subplot even if func changed the current axes.
                    plt.sca(ax)

                    # Insert the other hue aesthetics if appropriate
                    for kw, val_list in self.hue_kws.items():
                        kwargs[kw] = val_list[k]

                    color = self.palette[k] if kw_color is None else kw_color
                    func(data_k[x_var], data_k[y_var],
                         label=label_k, color=color, **kwargs)

                self._clean_axis(ax)
                self._update_legend_data(ax)

        if kw_color is not None:
            kwargs["color"] = kw_color
        self._add_axis_labels()

        return self