Python matplotlib.pyplot.gca() Examples

The following are 30 code examples of matplotlib.pyplot.gca(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: __init__.py    From EDeN with MIT License 11 votes vote down vote up
def plot_confusion_matrix(y_true, y_pred, size=None, normalize=False):
    """Plot the confusion matrix of ``y_true`` versus ``y_pred`` as a heatmap.

    Parameters
    ----------
    y_true : iterable
        Ground-truth labels.
    y_pred : iterable
        Predicted labels.
    size : number, optional
        When given, a new square figure of ``size`` x ``size`` is created
        before plotting; otherwise the current figure is used.
    normalize : bool
        When True every row is divided by its row sum and cells are
        formatted with two decimals instead of as integers.
    """
    cm = confusion_matrix(y_true, y_pred)
    fmt = "%d"
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = "%.2f"
    # Both axes must carry the same label set: a class that occurs only in
    # y_true (or only in y_pred) still owns a row AND a column of the matrix.
    # NOTE(review): assumes confusion_matrix orders rows/columns by the
    # sorted union of labels (scikit-learn behaviour) -- confirm.
    labels = sorted(set(y_true) | set(y_pred))
    if size is not None:
        plt.figure(figsize=(size, size))
    heatmap(cm, xlabel='Predicted label', ylabel='True label',
            xticklabels=labels, yticklabels=labels,
            cmap=plt.cm.Blues, fmt=fmt)
    if normalize:
        plt.title("Confusion matrix (norm.)")
    else:
        plt.title("Confusion matrix")
    plt.gca().invert_yaxis()
Example #2
Source File: Micaps11Data.py    From PyMICAPS with GNU General Public License v2.0 10 votes vote down vote up
def ConvertPacth(self, ax, patch):
        """Rebuild ``patch`` in display (pixel) coordinates and add it to the current axes.

        Every vertex of the patch's path is pushed through ``ax.transData``;
        the transformed points form a new black, unfilled ``PathPatch`` that
        is added to ``plt.gca()`` and returned.

        :param ax: axes whose data transform is used for the conversion
        :param patch: patch providing the source path via ``get_path()``
        :return: the newly created ``PathPatch``
        """
        pixel_points = []
        for vertex in patch.get_path().vertices:
            # Transform one (x, y) vertex at a time as a 1x2 array
            single_point = np.vstack([vertex[0], vertex[1]]).T
            xpix, ypix = ax.transData.transform(single_point).T
            pixel_points.append((xpix[0], ypix[0]))

        from matplotlib import patches
        from matplotlib.path import Path
        outline = patches.PathPatch(Path(pixel_points), linewidth=1,
                                    facecolor='none', edgecolor='k')
        plt.gca().add_patch(outline)
        return outline
Example #3
Source File: arts_lookup.py    From typhon with MIT License 10 votes vote down vote up
def _add_opacity_legend(ax=None):
    """Attach the fixed three-entry legend of an opacity lookup table plot.

    The entries are proxy lines (no data): 'species opacity', 'total opacity'
    and 'opacity=1'. The current axes are used when ``ax`` is None.
    """
    target = plt.gca() if ax is None else ax

    proxies = [
        Line2D([], [], label='species opacity'),
        Line2D([], [], color='k', linewidth=1., label='total opacity'),
        Line2D([], [], color='k', linestyle='--', linewidth=1.,
               label='opacity=1'),
    ]

    target.legend(handles=proxies,
                  labels=[proxy.get_label() for proxy in proxies],
                  fontsize='xx-small', loc='upper left', ncol=6)
Example #4
Source File: SimplicialComplex.py    From OpenTDA with Apache License 2.0 8 votes vote down vote up
def drawComplex(origData, ripsComplex, axes=[-6,8,-6,6]):
  """Draw a simplicial complex: labelled points, red edges, blue triangles.

  :param origData: 2-D array of points; columns 0 and 1 are used as x and y
  :param ripsComplex: iterable of simplices, each a sized collection of
      indices into origData; only simplices of size 2 and 3 are drawn
  :param axes: axis limits handed to plt.axis (not modified here)
  """
  plt.clf()
  plt.axis(axes)
  plt.scatter(origData[:,0], origData[:,1])  # plotting just for clarity
  for i, _ in enumerate(origData):
    plt.annotate(i, (origData[i][0]+0.05, origData[i][1]))  # add labels

  ax = plt.gca()

  # Polygons are patches: Axes.add_line only accepts Line2D artists and
  # rejects anything else on current matplotlib, so use add_patch.

  # add lines for edges
  for edge in [e for e in ripsComplex if len(e)==2]:
    pt1, pt2 = [origData[pt] for pt in edge]
    edge_patch = plt.Polygon([pt1,pt2], closed=None, fill=None, edgecolor='r')
    ax.add_patch(edge_patch)

  # add triangles
  for triangle in [t for t in ripsComplex if len(t)==3]:
    pt1, pt2, pt3 = [origData[pt] for pt in triangle]
    face_patch = plt.Polygon([pt1,pt2,pt3], closed=False, color="blue",
                             alpha=0.3, fill=True, edgecolor=None)
    ax.add_patch(face_patch)
  plt.show()
Example #5
Source File: FilteredSimplicialComplex.py    From OpenTDA with Apache License 2.0 8 votes vote down vote up
def drawComplex(origData, ripsComplex, axes=[-6,8,-6,6]):
  """Draw a simplicial complex: labelled points, red edges, blue triangles.

  :param origData: 2-D array of points; columns 0 and 1 are used as x and y
  :param ripsComplex: iterable of simplices, each a sized collection of
      indices into origData; only simplices of size 2 and 3 are drawn
  :param axes: axis limits handed to plt.axis (not modified here)
  """
  plt.clf()
  plt.axis(axes)
  plt.scatter(origData[:,0], origData[:,1])  # plotting just for clarity
  for i, _ in enumerate(origData):
    plt.annotate(i, (origData[i][0]+0.05, origData[i][1]))  # add labels

  ax = plt.gca()

  # Polygons are patches: Axes.add_line only accepts Line2D artists and
  # rejects anything else on current matplotlib, so use add_patch.

  # add lines for edges
  for edge in [e for e in ripsComplex if len(e)==2]:
    pt1, pt2 = [origData[pt] for pt in edge]
    edge_patch = plt.Polygon([pt1,pt2], closed=None, fill=None, edgecolor='r')
    ax.add_patch(edge_patch)

  # add triangles
  for triangle in [t for t in ripsComplex if len(t)==3]:
    pt1, pt2, pt3 = [origData[pt] for pt in triangle]
    face_patch = plt.Polygon([pt1,pt2,pt3], closed=False, color="blue",
                             alpha=0.3, fill=True, edgecolor=None)
    ax.add_patch(face_patch)
  plt.show()
Example #6
Source File: plot_name.py    From pyhanlp with Apache License 2.0 8 votes vote down vote up
def newline(p1, p2, color=None, marker=None):
    """
    Draw the line through p1 and p2, extended to the bounds of the current axes.

    A vertical line (equal x of both points) spans the current y bounds;
    any other line spans the current x bounds.

    https://stackoverflow.com/questions/36470343/how-to-draw-a-line-with-matplotlib
    :param p1: first point, indexed as (x, y)
    :param p2: second point, indexed as (x, y)
    :param color: colour forwarded to Line2D
    :param marker: marker forwarded to Line2D
    :return: the Line2D that was added to the axes
    """
    axes = plt.gca()
    x_low, x_high = axes.get_xbound()

    if p2[0] != p1[0]:
        # Evaluate the line equation at both horizontal bounds
        slope = (p2[1] - p1[1]) / (p2[0] - p1[0])
        y_high = p1[1] + slope * (x_high - p1[0])
        y_low = p1[1] + slope * (x_low - p1[0])
    else:
        x_low = x_high = p1[0]
        y_low, y_high = axes.get_ybound()

    segment = mlines.Line2D([x_low, x_high], [y_low, y_high],
                            color=color, marker=marker)
    axes.add_line(segment)
    return segment
Example #7
Source File: plotting.py    From kvae with MIT License 7 votes vote down vote up
def hinton(matrix, max_weight=None, ax=None):
    """Draw a Hinton diagram of a weight matrix.

    Each entry becomes a square centred on its index: white for positive
    values, black otherwise, with a side length of sqrt(|w| / max_weight).
    When ``max_weight`` is falsy it defaults to the next power of two at or
    above the largest absolute entry. Draws on ``ax`` or the current axes.
    """
    if ax is None:
        ax = plt.gca()

    if not max_weight:
        largest = np.abs(matrix).max()
        max_weight = 2 ** np.ceil(np.log(largest) / np.log(2))

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(plt.NullLocator())

    for (x, y), weight in np.ndenumerate(matrix):
        shade = 'white' if weight > 0 else 'black'
        side = np.sqrt(np.abs(weight) / max_weight)
        corner = [x - side / 2, y - side / 2]
        ax.add_patch(plt.Rectangle(corner, side, side,
                                   facecolor=shade, edgecolor=shade))

    ax.autoscale_view()
    ax.invert_yaxis()
Example #8
Source File: figures.py    From pywr with GNU General Public License v3.0 7 votes vote down vote up
def plot_QQ(A, B, ax=None):
    """Scatter ``A.values`` against ``B.values`` with a 1:1 "Equality" line.

    Both axes are set to the same range ``[0, limit]`` where ``limit`` is the
    larger of the two upper axis limits after the scatter was drawn.

    :param A: data for the y axis; must expose ``.values`` and ``.name``
    :param B: data for the x axis; must expose ``.values`` and ``.name``
    :param ax: axes to draw on; the current axes when None
    :return: the axes that were drawn on
    """
    if ax is None:
        ax = plt.gca()
    ax.scatter(B.values, A.values, color=c["Cfill"], edgecolor=c["Cedge"], clip_on=False)
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    # Use the UPPER limit of both axes; the previous ylim[0] picked the lower
    # y limit and could cut off points whose y exceeds the x range.
    limit = max(xlim[1], ylim[1])
    ax.plot([0, limit], [0, limit], '-k')
    ax.set_xlim(0, limit)
    ax.set_ylim(0, limit)
    ax.grid(True)
    set_000formatter(ax.get_xaxis())
    set_000formatter(ax.get_yaxis())
    ax.set_xlabel(B.name)
    ax.set_ylabel(A.name)
    ax.legend(["Equality"], loc="best")
    return ax
Example #9
Source File: figures.py    From pywr with GNU General Public License v3.0 7 votes vote down vote up
def plot_percentiles(A, B, ax=None):
    """Plot the percentile curves of ``A`` and ``B`` on a logit-scaled x axis.

    Scores are taken at 1000 percentiles between 0.1 and 99.9 and drawn in
    reversed order against the cumulative frequency (0.001 .. 0.999).

    :param A: data exposing ``.values`` and ``.name``
    :param B: data exposing ``.values`` and ``.name``
    :param ax: axes to draw on; the current axes when None
    :return: the axes that were drawn on
    """
    if ax is None:
        ax = plt.gca()
    percentiles = np.linspace(0.001, 0.999, 1000) * 100
    A_pct = scipy.stats.scoreatpercentile(A.values, percentiles)
    B_pct = scipy.stats.scoreatpercentile(B.values, percentiles)
    percentiles = percentiles / 100.0
    ax.plot(percentiles, B_pct[::-1], color=c["Bfill"], clip_on=False, linewidth=2)
    ax.plot(percentiles, A_pct[::-1], color=c["Afill"], clip_on=False, linewidth=2)
    ax.set_xlabel("Cumulative frequency")
    ax.grid(True)
    ax.xaxis.grid(True, which="both")
    set_000formatter(ax.get_yaxis())
    ax.set_xscale("logit")
    # Blank the minor labels and pin the major ticks to fixed frequencies,
    # labelled as percentages.
    ax.set_xticklabels([], minor=True)
    ax.set_xticks([0.01, 0.1, 0.5, 0.9, 0.99])
    ax.set_xticklabels(["1", "10", "50", "90", "99"])
    ax.set_xlim(0.001, 0.999)
    # Legend order follows the plotting order above: B first, then A
    ax.legend([B.name, A.name], loc="best")
    return ax
Example #10
Source File: quadPlot.py    From quadcopter-simulation with BSD 3-Clause "New" or "Revised" License 7 votes vote down vote up
def set_frame(frame):
    """Update the 3-D line artists of the current axes from ``frame``.

    The first three lines of the axes receive the column pairs (0, 2),
    (1, 3) and (4, 5) of ``frame``; the last line receives the trajectory
    accumulated in the module-level ``history`` buffer, to which column 4
    of ``frame`` is appended at index ``count``.
    """
    global history, count

    # Each segment is a 3x2 slice whose rows are unpacked as x, y, z
    segments = [frame[:, [0, 2]], frame[:, [1, 3]], frame[:, [4, 5]]]
    artists = plt.gca().get_lines()
    for artist, (xs, ys, zs) in zip(artists[:3], segments):
        artist.set_data(xs, ys)
        artist.set_3d_properties(zs)

    # Record the newest position; the counter stops at the last slot so the
    # final entry is overwritten once the buffer is full.
    history[count] = frame[:, 4]
    if count < np.size(history, 0) - 1:
        count += 1

    trail = history[:count]
    trajectory = artists[-1]
    trajectory.set_data(trail[:, 0], trail[:, 1])
    trajectory.set_3d_properties(trail[:, -1])
Example #11
Source File: plotting.py    From kvae with MIT License 7 votes vote down vote up
def plot_ball_trajectory(var, filename, idx=0, scale=30, cmap='Blues'):
    """Render the trajectory ``var[idx]`` as a ball collection and save it as PNG.

    The figure is 4x4 inches without ticks, uses equal axis scaling, is
    written to ``filename`` at 80 dpi and closed afterwards.
    """
    # Radius derived from the extent of the first two components of var.
    # NOTE(review): this value is never used -- the call below passes r=1,
    # so `scale` currently has no effect on the output. Confirm the intent.
    lows = np.min(var[:, :, :2], axis=(0, 1))
    highs = np.max(var[:, :, :2], axis=(0, 1))
    x_min, y_min = lows
    x_max, y_max = highs
    r = max((x_max - x_min), (y_max - y_min)) / scale

    fig = plt.figure(figsize=[4, 4])
    ax = fig.gca()
    ax.add_collection(construct_ball_trajectory(var[idx], r=1, cmap=cmap))
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    ax.axis("equal")
    ax.set_xlabel('$a_{t,1}$', fontsize=24)
    ax.set_ylabel('$a_{t,2}$', fontsize=24)

    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close()
Example #12
Source File: display.py    From radiometric_normalization with Apache License 2.0 6 votes vote down vote up
def plot_pixels(file_name, candidate_data_single_band,
                reference_data_single_band, limits=None, fit_line=None):
    """Save a hexbin plot of candidate versus reference values to ``file_name``.

    A black 1:1 line is always drawn; when ``fit_line`` is given, a green
    line ``gain * x + offset`` is drawn as well. When ``limits`` is falsy
    the range is ``[0, max(upper y limit, upper x limit)]`` of the hexbin
    plot. Both axes use the same limits.
    """
    logging.info('Display: Creating pixel plot - {}'.format(file_name))
    fig = plt.figure()
    plt.hexbin(
        candidate_data_single_band, reference_data_single_band, mincnt=1)
    if not limits:
        y_top = plt.gca().get_ylim()[1]
        x_top = plt.gca().get_xlim()[1]
        limits = [0, max(y_top, x_top)]
    plt.plot(limits, limits, 'k-')
    if fit_line:
        # Evaluate the fitted line at both ends of the plotted range
        fitted = [limits[0] * fit_line.gain + fit_line.offset,
                  limits[1] * fit_line.gain + fit_line.offset]
        plt.plot(limits, fitted, 'g-')
    plt.xlim(limits)
    plt.ylim(limits)
    plt.xlabel('Candidate DNs')
    plt.ylabel('Reference DNs')
    fig.savefig(file_name, bbox_inches='tight')
    plt.close(fig)
Example #13
Source File: covariancematrix.py    From typhon with MIT License 6 votes vote down vote up
def plot_covariance_matrix(covariance_matrix, ax = None):
    """
    Draw every block of a covariance matrix with ``pcolormesh``.

    Each block is placed at its row/column offset with cell edges on
    half-integer coordinates. Afterwards the x limits are set to span all
    columns and the y limits to span all rows, with the y axis inverted.

    Parameters:
        covariance_matrix(:class:`CovarianceMatrix`): The covariance matrix
            to plot
        ax(matplotlib.axes): An axes object into which to plot the
            covariance matrix; the current axes when None.
    """
    if ax is None:
        ax = plt.gca()

    for block in covariance_matrix.blocks:
        row_stop = block.row_start + block.matrix.shape[0]
        col_stop = block.column_start + block.matrix.shape[1]
        # One more edge than cells; the -0.5 shift centres cells on indices
        y_edges = np.arange(block.row_start, row_stop + 1) - 0.5
        x_edges = np.arange(block.column_start, col_stop + 1) - 0.5
        ax.pcolormesh(x_edges, y_edges, np.array(block.matrix.toarray()))

    last_row = max([blk.row_start + blk.matrix.shape[0]
                    for blk in covariance_matrix.blocks])
    last_col = max([blk.column_start + blk.matrix.shape[1]
                    for blk in covariance_matrix.blocks])
    ax.set_xlim([-0.5, last_col + 0.5])
    ax.set_ylim([last_row + 0.5, -0.5])
Example #14
Source File: plots.py    From ibllib with MIT License 6 votes vote down vote up
def squares(tscale, polarity, ax=None, yrange=[-1, 1], **kwargs):
    """
    Plot rising and falling fronts as a square wave.

    Fronts are sorted by time; each front holds its level until the time of
    the next front. Polarities -1 / 1 are mapped linearly onto
    yrange[0] / yrange[1].

    :param tscale: numpy array with the times of the fronts
    :param polarity: numpy array with the polarity of each front
        (1: rising, -1: falling)
    :param ax: matplotlib axes object; the current axes when not given
    :param yrange: two values giving the low and high level of the wave
    :param kwargs: forwarded to ``ax.plot``
    :return: None
    """
    ax = ax or plt.gca()
    order = np.argsort(tscale)
    starts = tscale[order]
    levels = polarity[order]
    # Every front ends where the next one starts; the last ends on itself
    stops = np.r_[starts[1:], starts[-1]]
    times = np.column_stack((starts, stops)).ravel()
    low, high = yrange[0], yrange[1]
    heights = (np.repeat(levels, 2) + 1) / 2 * (high - low) + low
    ax.plot(times, heights, **kwargs)
Example #15
Source File: plots.py    From ibllib with MIT License 6 votes vote down vote up
def vertical_lines(x, ymin=0, ymax=1, ax=None, **kwargs):
    """
    From a x vector, draw separate vertical lines at each x location ranging from ymin to ymax

    All lines are drawn with a single ``ax.plot`` call; a NaN point after
    each line breaks the path so that consecutive lines are not connected.

    :param x: vector of x values where to display lines (numpy array or
        sequence; integer input is accepted and converted to float)
    :param ymin: lower end of the lines (scalar)
    :param ymax: higher end of the lines (scalar)
    :param ax: (optional) matplotlib axis instance; current axes when None
    :param kwargs: forwarded to ``ax.plot``
    :return: None
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        # NaN separators cannot be stored in an integer array
        x = x.astype(float)
    # Rows: line start, line end, NaN separator (np.tile returns a copy)
    x = np.tile(x, (3, 1))
    x[2, :] = np.nan
    y = np.zeros_like(x)
    y[0, :] = ymin
    y[1, :] = ymax
    y[2, :] = np.nan
    if ax is None:
        ax = plt.gca()
    ax.plot(x.T.flatten(), y.T.flatten(), **kwargs)
Example #16
Source File: plotting.py    From OpenTDA with Apache License 2.0 6 votes vote down vote up
def drawComplex(data, ph, axes=[-6, 8, -6, 6]):
    """Draw a simplicial complex: labelled points, red edges, blue triangles.

    :param data: 2-D array of points; columns 0 and 1 are used as x and y
    :param ph: object whose ``ripsComplex`` attribute is an iterable of
        simplices, each a sized collection of indices into ``data``; only
        simplices of size 2 and 3 are drawn
    :param axes: axis limits [x1, x2, y1, y2] handed to plt.axis
    """
    plt.clf()
    plt.axis(axes)  # axes = [x1, x2, y1, y2]
    plt.scatter(data[:, 0], data[:, 1])  # plotting just for clarity
    for i, _ in enumerate(data):
        plt.annotate(i, (data[i][0] + 0.05, data[i][1]))  # add labels

    ax = plt.gca()

    # Polygons are patches: Axes.add_line only accepts Line2D artists and
    # rejects anything else on current matplotlib, so use add_patch.

    # add lines for edges
    for edge in [e for e in ph.ripsComplex if len(e) == 2]:
        pt1, pt2 = [data[pt] for pt in edge]
        edge_patch = plt.Polygon([pt1, pt2], closed=None, fill=None,
                                 edgecolor='r')
        ax.add_patch(edge_patch)

    # add triangles
    for triangle in [t for t in ph.ripsComplex if len(t) == 3]:
        pt1, pt2, pt3 = [data[pt] for pt in triangle]
        face_patch = plt.Polygon([pt1, pt2, pt3], closed=False,
                                 color="blue", alpha=0.3, fill=True,
                                 edgecolor=None)
        ax.add_patch(face_patch)
    plt.show()
Example #17
Source File: plot_images_grid.py    From Deep-SAD-PyTorch with MIT License 6 votes vote down vote up
def plot_images_grid(x: torch.tensor, export_img, title: str = '', nrow=8, padding=2, normalize=False, pad_value=0):
    """Arrange a 4D image tensor (B x C x H x W) in a grid and save it to ``export_img``.

    The grid is built with ``make_grid``, shown without axes, optionally
    titled, saved with a tight bounding box, and the figure is cleared.
    """
    grid_tensor = make_grid(x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value)
    # Move the channel axis last, as imshow expects (H, W, C)
    image = np.transpose(grid_tensor.cpu().numpy(), (1, 2, 0))

    plt.imshow(image, interpolation='nearest')

    axes = plt.gca()
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_visible(False)

    if title != '':
        plt.title(title)

    plt.savefig(export_img, bbox_inches='tight', pad_inches=0.1)
    plt.clf()
Example #18
Source File: plt3d.py    From connecting_the_dots with MIT License 6 votes vote down vote up
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs):
  """Draw a camera as a wireframe pyramid on 3-D axes.

  The apex is the camera centre obtained from
  ``geometry.translation_to_cameracenter(R, t)``; the base is a square with
  corners at (+-size, +-size, 3*size) rotated by ``R.T`` and offset by the
  centre. The centre marker carries ``label``; all edges are excluded from
  the legend. The marker is skipped when ``marker_C`` is ''.
  """
  if ax is None:
    ax = plt.gca()
  center = geometry.translation_to_cameracenter(R, t).ravel()

  # Base corners in drawing order around the square
  corners = []
  for sign_x, sign_y in ((-1, -1), (-1, +1), (+1, +1), (+1, -1)):
    offset = np.array([[sign_x * size], [sign_y * size], [3 * size]], dtype=np.float32)
    corners.append(center + R.T.dot(offset).ravel())

  edge_style = dict(color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth)

  if marker_C != '':
    ax.plot([center[0]], [center[1]], [center[2]], marker=marker_C, color=color, label=label, **kwargs)
  # Four edges from the centre to each base corner
  for corner in corners:
    ax.plot([center[0], corner[0]], [center[1], corner[1]], [center[2], corner[2]], **edge_style, **kwargs)
  # Closed outline of the base
  outline = corners + [corners[0]]
  ax.plot([p[0] for p in outline], [p[1] for p in outline], [p[2] for p in outline], **edge_style, **kwargs)
Example #19
Source File: show_boxes.py    From Deep-Feature-Flow-Segmentation with MIT License 6 votes vote down vote up
def show_boxes(im, dets, classes, scale = 1.0):
    """Show ``im`` with the detection boxes of every class drawn on top.

    ``dets[k]`` holds the detections of ``classes[k]``; the first four
    values of a detection are the box corners (x1, y1, x2, y2), multiplied
    by ``scale``. When the detections have five columns the last value is
    shown as the score next to the class name. Each box gets a random
    colour. Returns ``im`` unchanged.
    """
    plt.cla()
    plt.axis("off")
    plt.imshow(im)
    for cls_idx, cls_name in enumerate(classes):
        cls_dets = dets[cls_idx]
        for det in cls_dets:
            x1, y1, x2, y2 = det[:4] * scale
            color = (rand(), rand(), rand())
            outline = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                    fill=False, edgecolor=color,
                                    linewidth=2.5)
            plt.gca().add_patch(outline)

            if cls_dets.shape[1] != 5:
                continue
            caption = '{:s} {:.3f}'.format(cls_name, det[-1])
            plt.gca().text(x1, y1, caption,
                           bbox=dict(facecolor=color, alpha=0.5),
                           fontsize=9, color='white')
    plt.show()
    return im
Example #20
Source File: simulation.py    From striatum with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def plot_tuning_curve(tuning_region, ctr_tuning, label):
    """Plot CTR against the tuning parameter and show the figure.

    Parameters
    ----------
    tuning_region: array
        Values of the tuning parameter (x axis).

    ctr_tuning: array
        CTR obtained for each value of the tuning parameter (y axis).

    label: string
        Legend label of the curve.
    """
    plt.plot(tuning_region, ctr_tuning, 'ro-', label=label)
    plt.xlabel('parameter value')
    plt.ylabel('CTR')
    plt.legend()
    # The y axis is fixed to the range [0, 1]
    plt.gca().set_ylim([0, 1])
    plt.title("Parameter Tunning Curve")
    plt.show()
Example #21
Source File: movielens_bandit.py    From striatum with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def main():
    """Evaluate several bandit policies on the first 10000 rows and plot their regret."""
    streaming_batch, user_feature, actions, reward_list, action_context = get_data()
    streaming_batch_small = streaming_batch.iloc[0:10000]

    # conduct regret analyses
    experiment_bandit = ['LinUCB', 'LinThompSamp', 'Exp4P', 'UCB1', 'Exp3', 'random']
    col = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
    regret = {}
    # Each bandit is paired with the colour at the same position
    for bandit, line_color in zip(experiment_bandit, col):
        policy = policy_generation(bandit, actions)
        seq_error = policy_evaluation(policy, bandit, streaming_batch_small, user_feature, reward_list,
                                      actions, action_context)
        regret[bandit] = regret_calculation(seq_error)
        steps = range(len(streaming_batch_small))
        plt.plot(steps, regret[bandit], c=line_color, ls='-', label=bandit)
        plt.xlabel('time')
        plt.ylabel('regret')
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.gca().set_ylim([0, 1])
        plt.title("Regret Bound with respect to T")
    plt.show()
Example #22
Source File: VisMPL.py    From NURBS-Python with MIT License 6 votes vote down vote up
def set_axes_equal(ax):
        """ Sets equal aspect ratio across the three axes of a 3D plot.

        Every axis is re-centred on its current midpoint and given the same
        half-width: half of the largest current axis range.

        Contributed by Xuefeng Zhao.

        :param ax: a Matplotlib axis, e.g., as output from plt.gca().
        """
        # Work on an array so the arithmetic below does not depend on the
        # limits being NumPy scalars (list - float is a TypeError).
        bounds = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()],
                          dtype=float)
        centers = bounds.mean(axis=1)
        radius = 0.5 * np.abs(bounds[:, 1] - bounds[:, 0]).max()
        lower_limits = centers - radius
        upper_limits = centers + radius
        ax.set_xlim3d([lower_limits[0], upper_limits[0]])
        ax.set_ylim3d([lower_limits[1], upper_limits[1]])
        ax.set_zlim3d([lower_limits[2], upper_limits[2]])
Example #23
Source File: show_offset.py    From Deep-Feature-Flow-Segmentation with MIT License 5 votes vote down vote up
def show_boxes_simple(bbox, color='r', lw=2):
    """Draw the unfilled box ``bbox`` = (x1, y1, x2, y2) on the current axes."""
    left, top = bbox[0], bbox[1]
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    outline = plt.Rectangle((left, top), width, height,
                            fill=False, edgecolor=color, linewidth=lw)
    plt.gca().add_patch(outline)
Example #24
Source File: DOTA.py    From AerialDetection with Apache License 2.0 5 votes vote down vote up
def showAnns(self, objects, imgId, range):
        """
        Show an image with its annotated polygons drawn on top.

        Each object is drawn as a translucent filled polygon and an outline
        in the same random colour; a red circle of radius 5 marks the first
        point of every polygon.

        :param objects: objects to show; each provides its points under 'poly'
        :param imgId: img to show
        :param range: display range in the img (not used in this method)
        :return:
        """
        img = self.loadImgs(imgId)[0]
        plt.imshow(img)
        plt.axis('off')

        ax = plt.gca()
        ax.set_autoscale_on(False)

        marker_radius = 5
        shapes, shades, markers = [], [], []
        for obj in objects:
            # Random colour with every channel in [0.4, 1.0)
            shade = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            outline = obj['poly']
            shapes.append(Polygon(outline))
            shades.append(shade)
            first = outline[0]
            markers.append(Circle((first[0], first[1]), marker_radius))

        ax.add_collection(
            PatchCollection(shapes, facecolors=shades, linewidths=0, alpha=0.4))
        ax.add_collection(
            PatchCollection(shapes, facecolors='none', edgecolors=shades, linewidths=2))
        ax.add_collection(PatchCollection(markers, facecolors='red'))
Example #25
Source File: Map.py    From PyMICAPS with GNU General Public License v2.0 5 votes vote down vote up
def DrawClipBorders(clipborders):
        """Draw the border of the first clip region and return its patch.

        Only ``clipborders[0]`` is used. Returns None when that entry has no
        path; otherwise an unfilled ``PathPatch`` is added to the current
        axes and returned.
        """
        # Draw the clip-region border and return it
        border = clipborders[0]
        path = border.path
        linewidth = border.linewidth
        linecolor = border.linecolor
        if path is None:
            return None
        patch = patches.PathPatch(path, linewidth=linewidth, facecolor='none', edgecolor=linecolor)
        plt.gca().add_patch(patch)
        return patch
Example #26
Source File: common.py    From typhon with MIT License 5 votes vote down vote up
def _plot_matrix(
            matrix, classes, normalize=False, ax=None, **kwargs
    ):
        """Draw ``matrix`` as an annotated image with ``classes`` as tick labels.

        With ``normalize=True`` every row is divided by its row sum and the
        cells are printed with two decimals, otherwise as integers. Cell
        text is white on cells above half of the maximum, black elsewhere.
        Extra keyword arguments go to ``imshow`` (default cmap: "Blues").
        Returns the image created by ``imshow``.
        """
        if normalize:
            row_totals = matrix.sum(axis=1)[:, np.newaxis]
            matrix = matrix.astype('float') / row_totals

        # Caller-supplied keywords override the default colour map
        imshow_kwargs = dict(cmap="Blues")
        imshow_kwargs.update(kwargs)

        if ax is None:
            ax = plt.gca()

        img = ax.imshow(matrix, interpolation='nearest', **imshow_kwargs)
        positions = np.arange(len(classes))
        ax.set_xticks(positions)
        ax.set_xticklabels(classes, rotation=45)
        ax.set_yticks(positions)
        ax.set_yticklabels(classes)

        cell_format = '.2f' if normalize else 'd'
        half_max = matrix.max() / 2.
        for row in range(matrix.shape[0]):
            for column in range(matrix.shape[1]):
                value = matrix[row, column]
                text_color = "white" if value > half_max else "black"
                ax.text(column, row, format(value, cell_format),
                        horizontalalignment="center",
                        color=text_color)

        return img
Example #27
Source File: Map.py    From PyMICAPS with GNU General Public License v2.0 5 votes vote down vote up
def DrawBorders(m, products):
        """
        Draw county and city borders.

        Areas whose ``draw`` flag is falsy are skipped. Shapefile areas
        ('SHP') are drawn through ``readshapefile``; all other areas are
        drawn from their pre-built path, either as an unfilled patch
        (``polygon == 'ON'``) or as a line. Any exception is caught and
        printed together with the product's xml file name and a timestamp.

        :param m: canvas object (plt, or the projected map object)
        :param products: product parameters
        :return:
        """
        try:
            for area in products.map.borders:
                if not area.draw:
                    continue

                if area.filetype != 'SHP':
                    # Text file: the points of the path have already been
                    # projected before drawing
                    if area.path is None:
                        continue
                    if area.polygon != 'ON':
                        x, y = list(zip(*area.path.vertices))
                        m.plot(x, y, 'k-', linewidth=area.linewidth, color=area.linecolor)
                        continue
                    area_patch = patches.PathPatch(area.path,
                                                   linewidth=area.linewidth,
                                                   linestyle='solid',
                                                   facecolor='none',
                                                   edgecolor=area.linecolor)
                    plt.gca().add_patch(area_patch)
                    continue

                # Shapefile
                if m is plt:
                    Map.readshapefile(area.file.replace('.shp', ''),
                                      os.path.basename(area.file),
                                      color=area.linecolor,
                                      linewidth=area.linewidth)
                else:
                    m.readshapefile(area.file.replace('.shp', ''),
                                    os.path.basename(area.file),
                                    color=area.linecolor)
        except Exception as err:
            print(u'【{0}】{1}-{2}'.format(products.xmlfile, err, datetime.now()))
Example #28
Source File: formatter.py    From typhon with MIT License 5 votes vote down vote up
def set_xaxis_formatter(formatter, ax=None):
    """Use ``formatter`` for both the major and the minor xticks.

    The current axes are used when ``ax`` is None.
    """
    xaxis = (plt.gca() if ax is None else ax).xaxis
    for install in (xaxis.set_major_formatter, xaxis.set_minor_formatter):
        install(formatter)
Example #29
Source File: formatter.py    From typhon with MIT License 5 votes vote down vote up
def set_yaxis_formatter(formatter, ax=None):
    """Use ``formatter`` for both the major and the minor yticks.

    The current axes are used when ``ax`` is None.
    """
    yaxis = (plt.gca() if ax is None else ax).yaxis
    for install in (yaxis.set_major_formatter, yaxis.set_minor_formatter):
        install(formatter)
Example #30
Source File: Map.py    From PyMICAPS with GNU General Public License v2.0 5 votes vote down vote up
def DrawGridLine(products, m):
        """Configure the axes (plain pyplot) or draw the graticule (map object).

        With ``m is plt`` the axis mode from the projection settings is
        applied; when it is 'on', tick formatters and tick-label fonts are
        set and the visibility of each axis follows ``lonlabels[3]`` /
        ``latlabels[0]``. Otherwise parallels and meridians are drawn every
        10 degrees when the axis mode is 'on'.
        """
        pj = products.map.projection
        if m is plt:
            # Axes
            plt.axis(pj.axis)

            # Set the display format of the axis tick values
            if pj.axis == 'on':
                x_formatter = FormatStrFormatter(pj.axisfmt[0])
                y_formatter = FormatStrFormatter(pj.axisfmt[1])
                xaxis = plt.gca().xaxis
                yaxis = plt.gca().yaxis
                xaxis.set_major_formatter(x_formatter)
                yaxis.set_major_formatter(y_formatter)
                for axis in (xaxis, yaxis):
                    for label in axis.get_ticklabels():
                        label.set_fontproperties('DejaVu Sans')
                        label.set_fontsize(10)

                xaxis.set_visible(pj.lonlabels[3] == 1)
                yaxis.set_visible(pj.latlabels[0] == 1)

            return
        else:
            # draw parallels and meridians.
            if pj.axis == 'on':
                text_style = dict(family='DejaVu Sans', fontsize=10)
                m.drawparallels(np.arange(-80., 81., 10.), labels=pj.latlabels, **text_style)
                m.drawmeridians(np.arange(-180., 181., 10.), labels=pj.lonlabels, **text_style)