Python matplotlib.pyplot.Axes() Examples

The following are 30 code examples of matplotlib.pyplot.Axes(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: plot.py    From umap with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def show(plot_to_show):
    """Render the result of a plotting command.

    Parameters
    ----------
    plot_to_show: Output of a plotting command (matplotlib axis or bokeh figure)
        The object to display.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``plot_to_show`` is neither a ``plt.Axes`` nor a ``bpl.Figure``.
    """
    # Matplotlib axes go through the static path, which takes no argument.
    if isinstance(plot_to_show, plt.Axes):
        show_static()
        return

    # Bokeh figures are handed to the interactive path.
    if isinstance(plot_to_show, bpl.Figure):
        show_interactive(plot_to_show)
        return

    raise ValueError(
        "The type of ``plot_to_show`` was not valid, or not understood."
    )
Example #2
Source File: validate.py    From mpl-probscale with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def axes_object(ax):
    """Validate ``ax`` and return it together with its figure.

    If ``ax`` is None, the axes returned by ``pyplot.gca()`` are used.
    Otherwise ``ax`` must be a ``pyplot.Axes`` instance.

    Returns the pair ``(fig, ax)``, figure first.

    Raises ``ValueError`` when ``ax`` is neither None nor an Axes.
    """
    if ax is None:
        ax = pyplot.gca()
    elif not isinstance(ax, pyplot.Axes):
        raise ValueError("`ax` must be a matplotlib Axes instance or None")

    return ax.figure, ax
Example #3
Source File: visualiser.py    From gym-jsbsim with MIT License 6 votes vote down vote up
def _prepare_state_printing(self, ax: plt.Axes):
        """Write one label per printed property and create its value slot.

        For every entry of ``self.print_props`` a label text is drawn at
        ``TEXT_X_POSN_LABEL`` and an empty text at ``TEXT_X_POSN_VALUE``, both
        in axes coordinates. The empty texts are kept in ``self.value_texts``
        (a tuple) so they can be rewritten on each plot call.
        """
        n_props = len(self.print_props)
        ys = [self.TEXT_Y_POSN_INITIAL + i * self.TEXT_Y_INCREMENT
              for i in range(n_props)]

        # Static labels: the property names.
        for prop, y in zip(self.print_props, ys):
            ax.text(self.TEXT_X_POSN_LABEL, y, str(prop.name),
                    transform=ax.transAxes, **self.LABEL_TEXT_KWARGS)

        # Placeholder Text objects, created in the same top-to-bottom order.
        self.value_texts = tuple(
            ax.text(self.TEXT_X_POSN_VALUE, y, '', transform=ax.transAxes,
                    **self.VALUE_TEXT_KWARGS)
            for y in ys
        )
Example #4
Source File: plot2d.py    From kite with GNU General Public License v3.0 6 votes vote down vote up
def plot(self, **kwargs):
        """Plot current quadtree

        Shows ``self._quadtree.leaf_matrix_means`` under the title
        'Quadtree Means'. All keyword arguments are forwarded unchanged to
        ``self._initImagePlot``. ``plt.show()`` is called at the end only when
        ``self._show_plt`` is truthy.

        NOTE(review): ``axes`` and ``figure`` below are not named parameters of
        this method; presumably they are picked out of ``**kwargs`` by
        ``_initImagePlot`` - confirm against that helper.

        :param axes: Axes instance to plot in, defaults to None
        :type axes: [:py:class:`matplotlib.Axes`], optional
        :param figure: Figure instance to plot in, defaults to None
        :type figure: [:py:class:`matplotlib.Figure`], optional
        :param **kwargs: kwargs are passed into `plt.imshow`
        :type **kwargs: dict
        """
        # Set up the image plot before assigning data/title.
        self._initImagePlot(**kwargs)
        self.data = self._quadtree.leaf_matrix_means
        self.title = 'Quadtree Means'

        self._addInfoText()

        if self._show_plt:
            plt.show()
Example #5
Source File: qubit_characterizations.py    From Cirq with Apache License 2.0 6 votes vote down vote up
def plot(self, ax: Optional[plt.Axes] = None,
             **plot_kwargs: Any) -> plt.Axes:
        """Draw the average ground state probability against the number of
        Cliffords of the RB study.

        Args:
            ax: the plt.Axes to draw on. When omitted, a fresh 8x8 figure is
                created, drawn on, and shown.
            **plot_kwargs: Forwarded to 'plt.Axes.plot'.
        Returns:
            The plt.Axes that was drawn on.
        """
        # ``fig`` stays None unless this call creates the figure itself.
        fig = None
        if not ax:
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))

        ax.set_ylim([0, 1])
        ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs)
        ax.set_xlabel(r"Number of Cliffords")
        ax.set_ylabel('Ground State Probability')

        # Only show a figure that was created here.
        if fig is not None:
            fig.show()
        return ax
Example #6
Source File: cross_entropy_benchmarking.py    From Cirq with Apache License 2.0 6 votes vote down vote up
def plot(self, ax: Optional[plt.Axes] = None,
             **plot_kwargs: Any) -> plt.Axes:
        """Draw the average XEB fidelity against the number of cycles.

        Args:
            ax: the plt.Axes to draw on. When omitted, a fresh 8x8 figure is
                created, drawn on, and shown.
            **plot_kwargs: Forwarded to 'plt.Axes.plot'.
        Returns:
            The plt.Axes that was drawn on.
        """
        # ``fig`` stays None unless this call creates the figure itself.
        fig = None
        if not ax:
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))

        cycle_counts = [point.num_cycle for point in self.data]
        xeb_values = [point.xeb_fidelity for point in self.data]

        ax.set_ylim([0, 1.1])
        ax.plot(cycle_counts, xeb_values, 'ro-', **plot_kwargs)
        ax.set_xlabel('Number of Cycles')
        ax.set_ylabel('XEB Fidelity')

        # Only show a figure that was created here.
        if fig is not None:
            fig.show()
        return ax
Example #7
Source File: heatmap.py    From Cirq with Apache License 2.0 6 votes vote down vote up
def _write_annotations(self, mesh: mpl_collections.Collection,
                           ax: plt.Axes) -> None:
        """Draw the annotation text at the center of each cell. Internal.

        Cells without an entry (or with an empty entry) in ``self.annot_map``
        are skipped. Text is black on light cells and white on dark ones;
        ``self.annot_kwargs`` overrides the default text options.
        """
        for path, facecolor in zip(mesh.get_paths(), mesh.get_facecolors()):
            # The cell is assumed to be a square centered at (x=col, y=row),
            # so the mean of its first four vertices gives the integer center.
            corners = path.vertices[:4]
            row = int(round(np.mean([corner[1] for corner in corners])))
            col = int(round(np.mean([corner[0] for corner in corners])))

            annotation = self.annot_map.get((row, col), '')
            if annotation:
                if relative_luminance(facecolor) > 0.4:
                    text_color = 'black'
                else:
                    text_color = 'white'
                text_kwargs = {'color': text_color, 'ha': "center", 'va': "center"}
                text_kwargs.update(self.annot_kwargs)
                ax.text(col, row, annotation, **text_kwargs)
Example #8
Source File: test_plot.py    From mars with Apache License 2.0 6 votes vote down vote up
def assert_is_valid_plot_return_object(objs):  # pragma: no cover
    """Assert that ``objs`` looks like the return value of a plotting call.

    A ``pd.Series`` or ``np.ndarray`` is flattened with ``ravel()`` and every
    element must be a matplotlib ``Axes`` or a ``dict``. Any other value must
    itself be a matplotlib ``Artist``, a ``tuple`` or a ``dict``.

    Raises ``AssertionError`` with a message naming the offending type.
    """
    # Imported lazily so the module does not require matplotlib at import time.
    import matplotlib.pyplot as plt

    if isinstance(objs, (pd.Series, np.ndarray)):
        for el in objs.ravel():
            msg = (
                "one of 'objs' is not a matplotlib Axes instance, "
                "type encountered {}".format(repr(type(el).__name__))
            )
            assert isinstance(el, (plt.Axes, dict)), msg
    else:
        # Message previously read "single ArtistArtist instance" (typo).
        msg = (
            "objs is neither an ndarray of Artist instances nor a single "
            "Artist instance, tuple, or dict, 'objs' is a {}".format(
                repr(type(objs).__name__))
        )
        assert isinstance(objs, (plt.Artist, tuple, dict)), msg
Example #9
Source File: cnn_main.py    From Convolutional-Networks-for-Stock-Predicting with MIT License 6 votes vote down vote up
def plot_data(data):
    """Save one borderless line chart per 30-item sliding window of each group.

    For every group in ``data`` a window of 30 consecutive items is moved one
    step at a time; the first two fields of each item are collected as the
    "high" and "low" series, and all but the last value of each series are
    plotted (blue and green) on a full-figure axes with the axis hidden.
    Files are numbered consecutively across all groups, and the number of
    files written is printed at the end.
    """
    t = np.arange(0, 29, 1)
    file_name_number = 0
    fig = plt.figure(frameon=False, figsize=(width, height))
    for group in data:
        # Window end runs from 30 up to and including len(group) - 5.
        for window_end in range(30, len(group) - 4):
            window = group[window_end - 30:window_end]
            high = [item[0] for item in window]
            low = [item[1] for item in window]

            file_name = r'\fig_' + str(file_name_number)
            ax = plt.Axes(fig, [0., 0., 1., 1.])
            ax.set_axis_off()
            fig.add_axes(ax)
            ax.plot(t, high[0:-1], 'b', t, low[0:-1], 'g')
            fig.savefig(r'\figures_v2' + file_name, dpi=100)
            # Reuse the same figure for the next window.
            fig.clf()
            file_name_number += 1
    print('Created %d files!' % file_name_number)
Example #10
Source File: main.py    From Convolutional-Networks-for-Stock-Predicting with MIT License 6 votes vote down vote up
def plot_data(data):
    """Save one borderless line chart per 30-item sliding window of each group.

    For every group in ``data`` a window of 30 consecutive items is moved one
    step at a time; the first two fields of each item are collected as the
    "high" and "low" series, and all but the last value of each series are
    plotted (blue and green) on a full-figure axes with the axis hidden.
    Files are numbered consecutively across all groups, and the number of
    files written is printed at the end.
    """
    t = np.arange(0, 29, 1)
    file_name_number = 0
    fig = plt.figure(frameon=False)
    for group in data:
        # Window end runs from 30 up to and including len(group) - 5.
        for window_end in range(30, len(group) - 4):
            window = group[window_end - 30:window_end]
            high = [item[0] for item in window]
            low = [item[1] for item in window]

            file_name = r'\fig_' + str(file_name_number)
            ax = plt.Axes(fig, [0., 0., 1., 1.])
            ax.set_axis_off()
            fig.add_axes(ax)
            ax.plot(t, high[0:-1], 'b', t, low[0:-1], 'g')
            fig.savefig(r'\figures' + file_name)
            # Reuse the same figure for the next window.
            fig.clf()
            file_name_number += 1
    print('Created %d files!' % file_name_number)
Example #11
Source File: fid_pics.py    From adagan with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def save_pic(pic, path, exp):
    """Write ``pic`` to ``path`` as a PNG with no frame or axis.

    A 4-dimensional ``pic`` is reduced to its first element. The figure is
    sized ``(width, height)`` inches and saved at ``dpi=1``. When
    ``exp.symmetrize`` is set the picture is mapped with ``(pic + 1) / 2``;
    for ``exp.dataset == 'mnist'`` only channel 0 is kept, inverted, and
    drawn with the 'Greys' colormap.
    """
    if len(pic.shape) == 4:
        pic = pic[0]
    height, width = pic.shape[0], pic.shape[1]

    fig = plt.figure(frameon=False, figsize=(width, height))
    # Axes spanning the whole figure, with ticks and frame hidden.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    if exp.symmetrize:
        pic = (pic + 1.) / 2.

    if exp.dataset == 'mnist':
        pic = 1. - pic[:, :, 0]
        ax.imshow(pic, cmap='Greys', interpolation='none')
    else:
        ax.imshow(pic, interpolation='none')

    fig.savefig(path, dpi=1, format='png')
    plt.close()
Example #12
Source File: testing.py    From elasticintel with GNU General Public License v3.0 5 votes vote down vote up
def assert_is_valid_plot_return_object(objs):
    """Assert that ``objs`` looks like the return value of a plotting call.

    A ``pd.Series`` or ``np.ndarray`` is flattened with ``ravel()`` and every
    element must be a matplotlib ``Axes`` or a ``dict``. Any other value must
    itself be a matplotlib ``Artist``, a ``tuple`` or a ``dict``.
    """
    import matplotlib.pyplot as plt

    # Single-object case first: not a container of axes.
    if not isinstance(objs, (pd.Series, np.ndarray)):
        assert isinstance(objs, (plt.Artist, tuple, dict)), (
            'objs is neither an ndarray of Artist instances nor a '
            'single Artist instance, tuple, or dict, "objs" is a {name!r}'
        ).format(name=objs.__class__.__name__)
        return

    for el in objs.ravel():
        msg = (
            "one of 'objs' is not a matplotlib Axes instance, type "
            "encountered {name!r}"
        ).format(name=el.__class__.__name__)
        assert isinstance(el, (plt.Axes, dict)), msg
Example #13
Source File: draw_cad.py    From ezdxf with MIT License 5 votes vote down vote up
def _main():
    """Command-line entry point: draw a CAD file and save or display it.

    ``--supported_formats`` lists the output formats of the figure canvas and
    exits. Otherwise the given layout of ``cad_file`` is rendered through the
    matplotlib backend and either saved to ``--out`` (at ``--dpi``) or shown.
    Exits with status 1 when no file is given or the layout is not found.
    """
    parser = argparse.ArgumentParser(description='draw the given CAD file and save it to a file or view it')
    # (flag, add_argument options) in the original declaration order.
    argument_specs = (
        ('cad_file', {'nargs': '?'}),
        ('--supported_formats', {'action': 'store_true'}),
        ('--layout', {'default': 'Model'}),
        ('--out', {'required': False}),
        ('--dpi', {'type': int, 'default': 300}),
    )
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    if args.supported_formats:
        probe_fig = plt.figure()
        supported = probe_fig.canvas.get_supported_filetypes()
        for extension, description in supported.items():
            print(f'{extension}: {description}')
        sys.exit()

    if args.cad_file is None:
        print('no CAD file specified')
        sys.exit(1)

    doc = ezdxf.readfile(args.cad_file)
    try:
        layout = doc.layouts.get(args.layout)
    except KeyError:
        print(f'could not find layout "{args.layout}". Valid layouts: {[l.name for l in doc.layouts]}')
        sys.exit(1)

    fig: plt.Figure = plt.figure()
    # A single axes covering the whole figure.
    ax: plt.Axes = fig.add_axes([0, 0, 1, 1])
    backend = MatplotlibBackend(ax)
    Frontend(RenderContext(doc), backend).draw_layout(layout, finalize=True)

    if args.out is None:
        plt.show()
        return

    print(f'saving to "{args.out}"')
    fig.savefig(args.out, dpi=args.dpi)
    plt.close(fig)
Example #14
Source File: plots.py    From causallib with Apache License 2.0 5 votes vote down vote up
def plot_counterfactual_common_support_folds(predictions, hue_by, cv, alpha_by_density=True, ax=None):
    """Scatter predicted y0 against predicted y1 over all folds, colored by treatment.

    Args:
        predictions (list[pd.Series]): One entry per fold holding the outcome prediction values
                                       (column 0 is used as y0, column 1 as y1).
        hue_by (pd.Series): Group (treatment) assignment of the entire dataset;
                            it is sliced with the positional indices in `cv`.
        cv (list[np.array]): One entry per fold holding the row positions (iloc locations)
                             of the samples participating in that fold.
        alpha_by_density (bool): Whether point transparency is derived from a density estimation.
                                 This can be slow for a large number of points.
                                 If False, a simple fast heuristic is used.
        ax (plt.Axes): The axes on which the plot will be displayed. Optional.

    Returns:
        The axes holding the plot.
    """
    # Per-fold mean effect: mean of (y1 - y0).
    effect_folds = [(fold_pred.iloc[:, 1] - fold_pred.iloc[:, 0]).mean() for fold_pred in predictions]
    predictions = pd.concat(predictions)  # type: pd.DataFrame
    treatment = pd.concat([hue_by.iloc[fold_idx] for fold_idx in cv])  # type: pd.Series

    y0 = predictions.iloc[:, 0]
    y1 = predictions.iloc[:, 1]
    ax = _scatter_hue(y0, y1, treatment, alpha_by_density, ax=ax)

    effect_label = r"mean effect={:.2g}".format(np.mean(effect_folds))
    # A spread is only reported when there is more than one fold.
    if len(effect_folds) > 1:
        effect_label += r"$\pm${:.2g}".format(np.std(effect_folds))

    # Empty line in the background color: adds a legend entry with no visible marker.
    ax.plot([], [], color=ax.get_facecolor(), label=effect_label)
    _add_diagonal(ax)
    ax.legend(loc="best")
    ax.set_xlabel(r"Predicted $Y^0$")
    ax.set_ylabel(r"Predicted $Y^1$")
    ax.set_title("Predicted Common Support")
    return ax
Example #15
Source File: ale.py    From alibi with Apache License 2.0 5 votes vote down vote up
def _plot_one_ale_num(exp: Explanation,
                      feature: int,
                      targets: List[int],
                      constant: bool = False,
                      ax: 'plt.Axes' = None,
                      legend: bool = True,
                      line_kw: dict = None) -> 'plt.Axes':
    """
    Plots the ALE of exactly one feature on one axes.

    Parameters
    ----------
    exp
        Explanation providing `feature_values`, `ale_values`, `constant_value`,
        `feature_deciles`, `feature_names` and `target_names`.
    feature
        Index of the feature to plot.
    targets
        Indices of the target columns of the ALE values to plot.
    constant
        If true, `exp.constant_value` is added to the plotted ALE values.
    ax
        Axes to draw on; the current pyplot axes are used when omitted.
    legend
        Whether a legend may be added.
    line_kw
        Keyword arguments forwarded to `ax.plot`. May be omitted.

    Returns
    -------
    The axes that were drawn on.
    """
    import matplotlib.pyplot as plt
    from matplotlib import transforms

    if ax is None:
        ax = plt.gca()

    # The signature defaults `line_kw` to None; unpacking None with ** would
    # raise TypeError, so fall back to "no extra options".
    if line_kw is None:
        line_kw = {}

    # add zero baseline
    ax.axhline(0, color='grey')

    lines = ax.plot(
        exp.feature_values[feature],
        exp.ale_values[feature][:, targets] + constant * exp.constant_value,
        **line_kw
    )

    # add decile markers to the bottom of the plot
    # (x in data coordinates, y in axes coordinates)
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    ax.vlines(exp.feature_deciles[feature][1:], 0, 0.05, transform=trans)

    ax.set_xlabel(exp.feature_names[feature])
    ax.set_ylabel('ALE')

    if legend:
        # if no explicit labels passed, just use target names;
        # a missing 'label' key counts as "no explicit label".
        if line_kw.get('label') is None:
            ax.legend(lines, exp.target_names[targets])

    return ax
Example #16
Source File: rscls.py    From Remote-Sensing-Image-Classification with MIT License 5 votes vote down vote up
def save_cmap(img, cmap, fname):
    """Save ``img`` rendered with ``cmap`` to ``fname`` without axes or margins.

    The figure is one inch tall with the image's aspect ratio and is saved
    at ``dpi`` equal to the image height, then the current figure is closed.
    """
    shape = np.shape(img)
    height = float(shape[0])
    width = float(shape[1])

    fig = plt.figure()
    fig.set_size_inches(width / height, 1, forward=False)

    # Axes spanning the whole figure, with ticks and frame hidden.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(img, cmap=cmap)

    plt.savefig(fname, dpi=height)
    plt.close()
Example #17
Source File: rscls.py    From Remote-Sensing-Image-Classification with MIT License 5 votes vote down vote up
def save_cmap(img, cmap, fname):
    """Save ``img`` rendered with ``cmap`` to ``fname`` without axes or margins.

    The figure is one inch tall with the image's aspect ratio and is saved
    at ``dpi`` equal to the image height, then the current figure is closed.
    """
    shape = np.shape(img)
    height = float(shape[0])
    width = float(shape[1])

    fig = plt.figure()
    fig.set_size_inches(width / height, 1, forward=False)

    # Axes spanning the whole figure, with ticks and frame hidden.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(img, cmap=cmap)

    plt.savefig(fname, dpi=height)
    plt.close()
Example #18
Source File: plots.py    From causallib with Apache License 2.0 5 votes vote down vote up
def plot_calibration_folds(predictions, targets, cv, n_bins=10, plot_se=True,
                           plot_rug=False, plot_histogram=False, quantile=False, ax=None):
    """Draw one calibration curve per fold on a shared axes.

    Args:
        predictions (list[pd.Series]): One entry per fold - probability ("scores") predictions.
        targets (pd.Series): True labels of the overall data (not divided to folds).
        cv (list[np.array]): One entry per fold - positional (iloc) indices used to slice `targets`.
        n_bins (int): Number of bins to evaluate in the plot.
        plot_se (bool): Whether to plot standard errors around the mean bin-probability estimation.
        plot_rug: Forwarded to the single-curve plotting helper.
        plot_histogram: Forwarded to the single-curve plotting helper.
        quantile (bool): If true, the binning of the calibration curve is by quantiles. Default is false.
        ax (plt.Axes): Optional axes to draw on.

    Returns:
        The axes holding all fold curves, the diagonal and the legend.
    """
    for fold_number, fold_rows in enumerate(cv):
        fold_scores = predictions[fold_number]
        fold_labels = targets.iloc[fold_rows]

        # The diagonal is added once after the loop, not per fold.
        ax = _plot_calibration_single(
            y_true=fold_labels, y_prob=fold_scores, n_bins=n_bins, plot_diagonal=False,
            plot_se=plot_se, plot_rug=plot_rug, plot_histogram=plot_histogram,
            quantile=quantile, label="fold {}".format(fold_number), ax=ax,
        )

    _add_diagonal(ax)
    ax.legend(loc="best")
    ax.set_title("Calibration")
    return ax
Example #19
Source File: post.py    From lcnn with MIT License 5 votes vote down vote up
def imshow(im):
    """Show ``im`` on a new borderless figure with the image's aspect ratio.

    The current figure is closed first. The new figure is one inch tall, its
    single axes fills it completely, and the limits are set to the pixel
    edges with the y axis pointing down.
    """
    plt.close()
    n_rows, n_cols = im.shape[0], im.shape[1]

    fig = plt.figure()
    fig.set_size_inches(float(n_cols) / float(n_rows), 1, forward=False)

    # Axes spanning the whole figure, with ticks and frame hidden.
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    fig.add_axes(ax)

    plt.xlim([-0.5, n_cols - 0.5])
    plt.ylim([n_rows - 0.5, -0.5])
    plt.imshow(im)
Example #20
Source File: testing.py    From Splunking-Crime with GNU Affero General Public License v3.0 5 votes vote down vote up
def assert_is_valid_plot_return_object(objs):
    """Assert that ``objs`` looks like the return value of a plotting call.

    A ``pd.Series`` or ``np.ndarray`` is flattened with ``ravel()`` and every
    element must be a matplotlib ``Axes`` or a ``dict``. Any other value must
    itself be a matplotlib ``Artist``, a ``tuple`` or a ``dict``.
    """
    import matplotlib.pyplot as plt

    # Single-object case first: not a container of axes.
    if not isinstance(objs, (pd.Series, np.ndarray)):
        assert isinstance(objs, (plt.Artist, tuple, dict)), (
            'objs is neither an ndarray of Artist instances nor a '
            'single Artist instance, tuple, or dict, "objs" is a {name!r}'
        ).format(name=objs.__class__.__name__)
        return

    for el in objs.ravel():
        msg = (
            "one of 'objs' is not a matplotlib Axes instance, type "
            "encountered {name!r}"
        ).format(name=el.__class__.__name__)
        assert isinstance(el, (plt.Axes, dict)), msg
Example #21
Source File: draw-wireframe.py    From lcnn with MIT License 5 votes vote down vote up
def imshow(im):
    """Show ``im`` on a new borderless figure with the image's aspect ratio.

    The new figure is one inch tall, its single axes fills it completely,
    and the limits are set to the pixel edges with the y axis pointing down.
    """
    n_rows, n_cols = im.shape[0], im.shape[1]

    fig = plt.figure()
    fig.set_size_inches(float(n_cols) / float(n_rows), 1, forward=False)

    # Axes spanning the whole figure, with ticks and frame hidden.
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    fig.add_axes(ax)

    plt.xlim([-0.5, n_cols - 0.5])
    plt.ylim([n_rows - 0.5, -0.5])
    plt.imshow(im)
Example #22
Source File: matplotlib_backend.py    From ezdxf with MIT License 5 votes vote down vote up
def __init__(self, ax: plt.Axes,
                 *,
                 adjust_figure: bool = True,
                 line_width: float = 0.5,
                 point_size: float = 2.0,
                 point_size_relative: bool = True,
                 font: FontProperties = FontProperties(),
                 ):
        """Set up the backend to draw on ``ax``.

        Hides both axis objects and all spines of ``ax``, disables
        autoscaling and sets an equal aspect ratio ('datalim'). All
        parameters after ``ax`` are keyword-only and are stored on the
        instance (``adjust_figure`` as ``_adjust_figure``).

        NOTE(review): the ``font`` default is a single ``FontProperties``
        instance created when the function is defined, so it is shared by
        every call that omits ``font`` - confirm it is never mutated.
        """
        super().__init__()
        self.ax = ax
        self._adjust_figure = adjust_figure

        # like set_axis_off, except that the face_color can still be set
        self.ax.xaxis.set_visible(False)
        self.ax.yaxis.set_visible(False)
        for s in self.ax.spines.values():
            s.set_visible(False)

        # Fixed view limits with equal data scaling on both axes.
        self.ax.autoscale(False)
        self.ax.set_aspect('equal', 'datalim')
        self._current_z = 0
        self.line_width = line_width
        self.point_size = point_size
        self.point_size_relative = point_size_relative
        self.font = font
        # Measurements are computed once for the chosen font.
        self._font_measurements = _get_font_measurements(font)
Example #23
Source File: generate_pic.py    From Double-Branch-Dual-Attention-Mechanism-Network with GNU Affero General Public License v3.0 5 votes vote down vote up
def classification_map(map, ground_truth, dpi, save_path):
    """Save ``map`` as a borderless image sized from ``ground_truth``.

    The figure measures ``2 * shape / dpi`` inches in each direction (width
    from ``ground_truth.shape[1]``, height from ``ground_truth.shape[0]``)
    and is written to ``save_path`` at ``dpi``. Always returns 0.
    """
    fig = plt.figure(frameon=False)
    width_in = ground_truth.shape[1] * 2.0 / dpi
    height_in = ground_truth.shape[0] * 2.0 / dpi
    fig.set_size_inches(width_in, height_in)

    # Axes spanning the whole figure, with every axis decoration hidden.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_visible(False)
    fig.add_axes(ax)

    ax.imshow(map)
    fig.savefig(save_path, dpi=dpi)

    return 0
Example #24
Source File: summary_analysis.py    From bsuite with Apache License 2.0 5 votes vote down vote up
def _radar(
    df: pd.DataFrame, ax: plt.Axes, label: str, all_tags: Sequence[str],
    color: str, alpha: float = 0.2, edge_alpha: float = 0.85, zorder: int = 2,
    edge_style: str = '-'):
  """Plot utility for generating the underlying radar plot.

  The mean score per tag is taken from ``df``; a tag without exactly one
  aggregated row is reported on stdout and scored as zero. Scores are
  floored at 0.05 and the polygon is closed by repeating the first point.
  """
  mean_by_tag = df.groupby('tag').mean().reset_index()

  scores = []
  for tag in all_tags:
    rows = mean_by_tag[mean_by_tag['tag'] == tag]
    if len(rows) != 1:
      print('{} bsuite scores found for tag {!r} with setting {!r}. '
            'Replacing with zero.'.format(len(rows), tag, label))
      scores.append(0.)
    else:
      scores.append(float(rows['score']))
  values = np.maximum(scores, 0.05)  # don't let radar collapse to 0.
  values = np.concatenate((values, [values[0]]))

  # One spoke per tag, plus the first angle again to close the outline.
  angles = np.linspace(0, 2*np.pi, len(all_tags), endpoint=False)
  angles = np.concatenate((angles, [angles[0]]))

  ax.plot(angles, values, '-', linewidth=5, label=label,
          c=color, alpha=edge_alpha, zorder=zorder, linestyle=edge_style)
  ax.fill(angles, values, alpha=alpha, color=color, zorder=zorder)
  ax.set_thetagrids(
      angles * 180/np.pi, map(_tag_pretify, all_tags), fontsize=18)

  # To avoid text on top of gridlines, we flip horizontalalignment
  # based on label location
  text_angles = np.rad2deg(angles)
  tick_labels = ax.get_xticklabels()[:-1]
  for tick_label, angle in zip(tick_labels, text_angles[:-1]):
    side = 'right' if 90 <= angle <= 270 else 'left'
    tick_label.set_horizontalalignment(side)
Example #25
Source File: t2_decay_experiment.py    From Cirq with Apache License 2.0 5 votes vote down vote up
def plot_bloch_vector(self,
                          ax: Optional[plt.Axes] = None,
                          **plot_kwargs: Any) -> plt.Axes:
        """Draw the estimated Bloch vector length against the delay time.

        The estimate is the square of the Pauli X expectation plus the square
        of the Pauli Y expectation, i.e. the state projected into the XY
        plane.

        The Z expectation is left out on purpose: T1 related amplitude
        damping generally pushes it towards |0> (expectation <Z> = -1), which
        would significantly distort the T2 numbers.

        Args:
            ax: the plt.Axes to draw on. When omitted, a fresh 8x8 figure is
                created, drawn on, and shown.
            **plot_kwargs: Forwarded to 'plt.Axes.plot'.

        Returns:
            The plt.Axes that was drawn on.
        """
        # ``fig`` stays None unless this call creates the figure itself.
        fig = None
        if not ax:
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        assert ax is not None
        ax.set_ylim(ymin=0, ymax=1)

        # Squared length of the XY projection: <X>^2 + <Y>^2.
        x_squared = self._expectation_pauli_x**2
        y_squared = self._expectation_pauli_y**2
        bloch_vector = x_squared + y_squared

        delays = self._expectation_pauli_x['delay_ns']
        ax.plot(delays, bloch_vector, 'r+-', **plot_kwargs)
        ax.set_xlabel(
            r"Delay between initialization and measurement (nanoseconds)")
        ax.set_ylabel('Bloch Vector X-Y Projection Squared')
        ax.set_title('T2 Decay Experiment Data')

        # Only show a figure that was created here.
        if fig is not None:
            fig.show()
        return ax
Example #26
Source File: t2_decay_experiment.py    From Cirq with Apache License 2.0 5 votes vote down vote up
def plot_expectations(self,
                          ax: Optional[plt.Axes] = None,
                          **plot_kwargs: Any) -> plt.Axes:
        """Draw the Pauli X and Y expectation values against the delay time.

        Args:
            ax: the plt.Axes to draw on. When omitted, a fresh 8x8 figure is
                created, drawn on, and shown.
            **plot_kwargs: Forwarded to 'plt.Axes.plot'.

        Returns:
            The plt.Axes that was drawn on.
        """
        # ``fig`` stays None unless this call creates the figure itself.
        fig = None
        if not ax:
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        assert ax is not None
        ax.set_ylim(ymin=-2, ymax=2)

        # One curve per operator, each in its own color: X blue, Y green.
        curves = (
            (self._expectation_pauli_x, 'bo-', '<X>'),
            (self._expectation_pauli_y, 'go-', '<Y>'),
        )
        for expectation, fmt, name in curves:
            ax.plot(expectation['delay_ns'],
                    expectation['value'],
                    fmt,
                    label=name,
                    **plot_kwargs)

        ax.set_xlabel(
            r"Delay between initialization and measurement (nanoseconds)")
        ax.set_ylabel('Pauli Operator Expectation')
        ax.set_title('T2 Decay Pauli Expectations')
        ax.legend()

        # Only show a figure that was created here.
        if fig is not None:
            fig.show()
        return ax
Example #27
Source File: qubit_characterizations.py    From Cirq with Apache License 2.0 5 votes vote down vote up
def _matrix_bar_plot(mat: np.ndarray,
                     z_label: str,
                     ax: plt.Axes,
                     kets: Sequence[str] = None,
                     title: str = None,
                     ylim: Tuple[int, int] = (-1, 1),
                     **bar3d_kwargs: Any) -> None:
    """Draw the entries of ``mat`` as 3D bars on ``ax``.

    Each entry becomes a 0.3 x 0.3 bar starting at z = 0 whose height is the
    entry value. ``ylim`` sets the z-axis limits. When ``kets`` is given it
    labels the ticks of both horizontal axes; ``title`` is applied when given.
    Extra keyword arguments are forwarded to ``ax.bar3d``.
    """
    num_rows, num_cols = mat.shape
    n_bars = mat.size

    # meshgrid yields (column grid, row grid); rows run along x, columns along y.
    col_grid, row_grid = np.meshgrid(range(num_cols), range(num_rows))
    x_indices = np.array(row_grid).flatten()
    y_indices = np.array(col_grid).flatten()
    z_indices = np.zeros(n_bars)

    bar_width = np.ones(n_bars) * 0.3
    bar_depth = np.ones(n_bars) * 0.3
    bar_height = mat.flatten()
    ax.bar3d(x_indices,
             y_indices,
             z_indices,
             bar_width,
             bar_depth,
             bar_height,
             color='#ff0080',
             alpha=1.0,
             **bar3d_kwargs)

    ax.set_zlabel(z_label)
    z_low, z_high = ylim[0], ylim[1]
    ax.set_zlim3d(z_low, z_high)

    if kets is not None:
        # Ticks are shifted by half a bar width so labels sit under the bars.
        ax.set_xticks(np.arange(num_cols) + 0.15)
        ax.set_yticks(np.arange(num_rows) + 0.15)
        ax.set_xticklabels(kets)
        ax.set_yticklabels(kets)

    if title is not None:
        ax.set_title(title)
Example #28
Source File: t1_decay_experiment.py    From Cirq with Apache License 2.0 5 votes vote down vote up
def plot(self, ax: Optional[plt.Axes] = None,
             **plot_kwargs: Any) -> plt.Axes:
        """Draw the excited state probability against the amount of delay.

        The probability is ``true_count / (false_count + true_count)`` for
        each delay in ``self._data``.

        Args:
            ax: the plt.Axes to draw on. When omitted, a fresh 8x8 figure is
                created, drawn on, and shown.
            **plot_kwargs: Forwarded to 'plt.Axes.plot'.

        Returns:
            The plt.Axes that was drawn on.
        """
        # ``fig`` stays None unless this call creates the figure itself.
        fig = None
        if not ax:
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        assert ax is not None
        ax.set_ylim(ymin=0, ymax=1)

        delays = self._data['delay_ns']
        excited = self._data['true_count']
        ground = self._data['false_count']
        excited_fraction = excited / (ground + excited)

        ax.plot(delays, excited_fraction, 'ro-', **plot_kwargs)
        ax.set_xlabel(
            r"Delay between initialization and measurement (nanoseconds)")
        ax.set_ylabel('Excited State Probability')
        ax.set_title('T1 Decay Experiment Data')

        # Only show a figure that was created here.
        if fig is not None:
            fig.show()
        return ax
Example #29
Source File: qubit_characterizations_test.py    From Cirq with Apache License 2.0 5 votes vote down vote up
def test_tomography_plot_raises_for_incorrect_number_of_axes():
    """The tomography plot must reject a single axes and a 3-axes array."""
    backend = sim.Simulator()
    target = GridQubit(0, 0)
    half_x_circuit = circuits.Circuit(ops.X(target)**0.5)
    tomography = single_qubit_state_tomography(backend, target,
                                               half_x_circuit, 1000)

    # A bare Axes instead of a list of axes.
    with pytest.raises(TypeError):  # ax is not a List[plt.Axes]
        single_axes = plt.subplot()
        tomography.plot(single_axes)

    # A collection of axes, but with three entries.
    with pytest.raises(ValueError):
        _, three_axes = plt.subplots(1, 3)
        tomography.plot(three_axes)
Example #30
Source File: heatmap.py    From Cirq with Apache License 2.0 5 votes vote down vote up
def _plot_colorbar(self, mappable: mpl.cm.ScalarMappable,
                       ax: plt.Axes) -> mpl.colorbar.Colorbar:
        """Attach a colorbar for ``mappable`` next to ``ax``. Internal.

        The colorbar gets its own axes appended to ``ax`` using
        ``self.colorbar_location_options``; its orientation follows the
        'position' option ('right' when absent). ``self.colorbar_options``
        is forwarded to the figure's ``colorbar`` call.
        """
        divider = axes_grid1.make_axes_locatable(ax)
        colorbar_ax = divider.append_axes(**self.colorbar_location_options)

        # Bars on the left/right are vertical, anything else is horizontal.
        position = self.colorbar_location_options.get('position', 'right')
        if position in ('left', 'right'):
            orien = 'vertical'
        else:
            orien = 'horizontal'

        colorbar = ax.figure.colorbar(mappable,
                                      colorbar_ax,
                                      ax,
                                      orientation=orien,
                                      **self.colorbar_options)
        colorbar_ax.tick_params(axis='y', direction='out')
        return colorbar