Python matplotlib.pyplot.axvspan() Examples

The following are 23 code examples of matplotlib.pyplot.axvspan(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: test_axes.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_axvspan_epoch():
    """Shade a one-day interval given as jpl_units ``Epoch`` values.

    Registers the ``jpl_units`` test converters, draws an ``axvspan``
    between two ``Epoch`` instances one day apart, then widens the x-limits
    by five ``Duration`` days on each side. No assertion is made here; the
    function only exercises the unit-aware ``axvspan``/``set_xlim`` paths.
    """
    from datetime import datetime
    import matplotlib.testing.jpl_units as units

    units.register()

    start = units.Epoch("ET", dt=datetime(2009, 1, 20))
    stop = units.Epoch("ET", dt=datetime(2009, 1, 21))
    one_day = units.Duration("ET", units.day.convert("sec"))

    plt.figure()
    plt.axvspan(start, stop, facecolor="blue", alpha=0.25)

    # Pad the view by five days on either side of the shaded interval.
    margin = 5.0 * one_day
    plt.gca().set_xlim(start - margin, stop + margin)
Example #2
Source File: test_axes.py    From twitter-stock-recommendation with MIT License 6 votes vote down vote up
def test_axvspan_epoch():
    """Shade a one-day interval given as jpl_units ``Epoch`` values.

    Registers the ``jpl_units`` test converters, draws an ``axvspan``
    between two ``Epoch`` instances one day apart, then widens the x-limits
    by five ``Duration`` days on each side. No assertion is made here; the
    function only exercises the unit-aware ``axvspan``/``set_xlim`` paths.
    """
    from datetime import datetime
    import matplotlib.testing.jpl_units as units

    units.register()

    start = units.Epoch("ET", dt=datetime(2009, 1, 20))
    stop = units.Epoch("ET", dt=datetime(2009, 1, 21))
    one_day = units.Duration("ET", units.day.convert("sec"))

    plt.figure()
    plt.axvspan(start, stop, facecolor="blue", alpha=0.25)

    # Pad the view by five days on either side of the shaded interval.
    margin = 5.0 * one_day
    plt.gca().set_xlim(start - margin, stop + margin)
Example #3
Source File: object_storage_timeline.py    From PerfKitBenchmarker with Apache License 2.0 6 votes vote down vote up
def on_motion(self, event):
    """Stretch the selection span to follow the mouse and blit the update.

    Does nothing while no selection is active (``self.span`` is None) or
    when the pointer is outside the axes (``event.xdata`` is None).
    Otherwise the old span is replaced by one covering
    ``[self.start, event.xdata]``, only the span is redrawn via blitting,
    and ``self.updater`` is notified of the new range.
    """
    # Outside the axes matplotlib reports xdata as None. Bail out *before*
    # removing the current span so the existing selection stays intact.
    if self.span is None or event.xdata is None:
        return

    self.span.remove()

    self.end = event.xdata
    self.span = plt.axvspan(self.start, self.end, color='blue', alpha=0.5)

    canvas = self.figure.canvas
    axes = self.span.axes
    # Restore the cached background pixels, then cache them again for the
    # next motion event.
    canvas.restore_region(self.background)
    self.background = canvas.copy_from_bbox(axes.bbox)

    # Redraw just the current span and blit only the axes area.
    axes.draw_artist(self.span)
    canvas.blit(axes.bbox)

    self.updater.update(self.start, self.end)
Example #4
Source File: object_storage_timeline.py    From PerfKitBenchmarker with Apache License 2.0 6 votes vote down vote up
def on_press(self, event):
    """Start a new selection span on a right-click inside the axes.

    Ignores any button other than 3 (right), clicks while a span already
    exists, and clicks outside the axes (``event.xdata`` is None).
    Otherwise it creates a zero-width span at the click position, caches
    the canvas background for blitting, draws the span and notifies
    ``self.updater``.
    """
    if event.button != 3:
        # Only continue for right mouse button
        return
    if self.span is not None:
        return
    # A click outside the axes has no data coordinate to anchor the span to.
    if event.xdata is None:
        return

    self.start = event.xdata
    self.end = event.xdata
    self.span = plt.axvspan(self.start, self.end, color='blue', alpha=0.5)

    # Render the full canvas and cache the axes region for later blitting.
    canvas = self.figure.canvas
    axes = self.span.axes
    canvas.draw()
    self.background = canvas.copy_from_bbox(axes.bbox)

    # Now redraw just the span and blit only the axes area.
    axes.draw_artist(self.span)
    canvas.blit(axes.bbox)

    self.updater.update(self.start, self.end)
Example #5
Source File: test_axes.py    From coffeegrindsize with MIT License 6 votes vote down vote up
def test_axvspan_epoch():
    """Shade a one-day interval given as jpl_units ``Epoch`` values.

    Registers the ``jpl_units`` test converters, draws an ``axvspan``
    between two ``Epoch`` instances one day apart, then widens the x-limits
    by five ``Duration`` days on each side. No assertion is made here; the
    function only exercises the unit-aware ``axvspan``/``set_xlim`` paths.
    """
    from datetime import datetime
    import matplotlib.testing.jpl_units as units

    units.register()

    start = units.Epoch("ET", dt=datetime(2009, 1, 20))
    stop = units.Epoch("ET", dt=datetime(2009, 1, 21))
    one_day = units.Duration("ET", units.day.convert("sec"))

    plt.figure()
    plt.axvspan(start, stop, facecolor="blue", alpha=0.25)

    # Pad the view by five days on either side of the shaded interval.
    margin = 5.0 * one_day
    plt.gca().set_xlim(start - margin, stop + margin)
Example #6
Source File: test_axes.py    From ImageFusion with MIT License 6 votes vote down vote up
def test_axvspan_epoch():
    """Shade a one-day interval given as jpl_units ``Epoch`` values.

    Registers the ``jpl_units`` test converters, draws an ``axvspan``
    between two ``Epoch`` instances one day apart, then widens the x-limits
    by five ``Duration`` days on each side. No assertion is made here; the
    function only exercises the unit-aware ``axvspan``/``set_xlim`` paths.
    """
    from datetime import datetime
    import matplotlib.testing.jpl_units as units

    units.register()

    start = units.Epoch("ET", dt=datetime(2009, 1, 20))
    stop = units.Epoch("ET", dt=datetime(2009, 1, 21))
    one_day = units.Duration("ET", units.day.convert("sec"))

    plt.figure()
    plt.axvspan(start, stop, facecolor="blue", alpha=0.25)

    # Pad the view by five days on either side of the shaded interval.
    margin = 5.0 * one_day
    plt.gca().set_xlim(start - margin, stop + margin)
Example #7
Source File: test_axes.py    From neural-network-animation with MIT License 6 votes vote down vote up
def test_axvspan_epoch():
    """Shade a one-day interval given as jpl_units ``Epoch`` values.

    Registers the ``jpl_units`` test converters, draws an ``axvspan``
    between two ``Epoch`` instances one day apart, then widens the x-limits
    by five ``Duration`` days on each side. No assertion is made here; the
    function only exercises the unit-aware ``axvspan``/``set_xlim`` paths.
    """
    from datetime import datetime
    import matplotlib.testing.jpl_units as units

    units.register()

    start = units.Epoch("ET", dt=datetime(2009, 1, 20))
    stop = units.Epoch("ET", dt=datetime(2009, 1, 21))
    one_day = units.Duration("ET", units.day.convert("sec"))

    plt.figure()
    plt.axvspan(start, stop, facecolor="blue", alpha=0.25)

    # Pad the view by five days on either side of the shaded interval.
    margin = 5.0 * one_day
    plt.gca().set_xlim(start - margin, stop + margin)
Example #8
Source File: test_axes.py    From twitter-stock-recommendation with MIT License 5 votes vote down vote up
def test_twinx_knows_limits():
    """A twinx axes reports the same x view limits as a plain control axes.

    Both figures contain ``axvspan(1, 2)`` plus a line over x in
    ``[0, 0.5]``. In the first figure the line goes onto the twin axes; in
    the control figure it goes onto the original axes.
    """
    _, host_ax = plt.subplots()
    host_ax.axvspan(1, 2)
    twin_ax = host_ax.twinx()
    twin_ax.plot([0, 0.5], [1, 2])

    # Control figure: same artists, no twin.
    _, control_ax = plt.subplots()
    control_ax.axvspan(1, 2)
    control_ax.plot([0, 0.5], [1, 2])

    assert_array_equal(twin_ax.viewLim.intervalx, control_ax.viewLim.intervalx)
Example #9
Source File: initialize.py    From qkit with GNU General Public License v2.0 5 votes vote down vote up
def crop_recording_window(self):
        """Interactively choose the acquisition window on a freshly acquired trace.

        Stops the spectrometer, configures a 0..512 window with 1e4 averages
        and a single segment, acquires one trace (``msp``) and displays it
        through ``widgets.interact``. The controls are two IntSliders
        (start/end, range 0..len(msp)) and a "Done!" checkbox. While the
        checkbox is unchecked, every widget change re-plots the trace with
        the regions outside ``[start, end]`` greyed out. Checking it stores
        ``[start, end]`` in ``self._sample.acqu_window``, applies it with
        ``set_window`` and disables all three widgets.
        """
        # Configure a 0..512 window, 1e4 averages and one segment, then
        # acquire the trace that is shown to the user.
        self._sample.mspec.spec_stop()
        self._sample.mspec.set_averages(1e4)
        self._sample.mspec.set_window(0,512)
        self._sample.mspec.set_segments(1)
        msp = self._sample.mspec.acquire()
        
        # Called by widgets.interact on widget changes; the parameter names
        # must match the start=/end=/done= keywords passed to interact below.
        def pltfunc(start,end,done):
            if done:
                # Commit the selection and lock the widgets.
                self._sample.acqu_window = [start,end]
                self._sample.mspec.set_window(start,end)
                self._sw.disabled = True
                self._ew.disabled = True
                self._dw.disabled = True
                self._dw.description = "acqu_window set to [{:d}:{:d}]".format(start,end)
            else:
                # Preview: grey out [0, start] and [end, len(msp)].
                plt.figure(figsize=(15,5))
                plt.plot(msp)
                plt.axvspan(0,start,color='k',alpha=.2)
                plt.axvspan(end,len(msp),color='k',alpha=.2)
                plt.xlim(0,len(msp))
                plt.show()
        # Sliders start at the currently stored acquisition window.
        self._sw =  widgets.IntSlider(min=0,max=len(msp),step=1,value=self._sample.acqu_window[0],continuous_update=True)
        self._ew = widgets.IntSlider(min=0,max=len(msp),step=1,value=self._sample.acqu_window[1],continuous_update=True)
        self._dw = widgets.Checkbox(value=False,description="Done!",indent=True)
        self._wgt = widgets.interact(pltfunc,start=self._sw,end=self._ew,done=self._dw)
        # NOTE(review): assuming widgets.interact returns without waiting for
        # the user (ipywidgets behaviour), this immediately re-applies the
        # previously stored acqu_window in place of the 0..512 window set
        # above -- confirm this is intended.
        self._sample.mspec.set_window(*self._sample.acqu_window) 
Example #10
Source File: analysis.py    From px4tools with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def background_flight_modes(data):
    """
    Overlays a background color for each flight mode. Can be called to style a graph.

    Every contiguous run of samples with the same ``STAT_MainState`` value
    is shaded from its first to its last timestamp. The colour is
    ``FLIGHT_MODE_COLOR[FLIGHT_MODES[mode]]``, drawn with alpha 0.1 at
    zorder 0. A mode that is entered several times gets one span per visit,
    but its name is used as a legend label only once. Missing (NaN) states
    are ignored.
    """
    import matplotlib.pyplot as plt
    state = data.STAT_MainState.dropna()
    # A new run starts wherever the state differs from the previous sample.
    # Shading first-to-last occurrence per mode (as before) would cover any
    # other modes flown in between.
    run_id = (state != state.shift()).cumsum()
    labelled = set()
    for _, run in state.groupby(run_id.values):
        mode_name = FLIGHT_MODES[int(run.iloc[0])]
        mode_color = FLIGHT_MODE_COLOR[mode_name]
        # Only the first span of each mode gets a legend entry.
        label = '_nolegend_' if mode_name in labelled else mode_name
        labelled.add(mode_name)
        plt.axvspan(
            run.index[0], run.index[-1], alpha=0.1, color=mode_color,
            label=label, zorder=0)
Example #11
Source File: example_rbasex_block.py    From PyAbel with MIT License 5 votes vote down vote up
def plots(row,
          im_title, im, im_mask,
          tr_title, tr, tr_mask,
          pr_title, r, P0, P2):
    """Draw one row of the comparison figure.

    The row consists of the (optional) input image, the transformed image
    and the P0/P2 profiles. The profiles are overlaid on the reference
    curves ``P0_src``/``P2_src`` (names defined outside this function).
    Image pixels whose mask value is 0 are masked out.

    Parameters
    ----------
    row : int
        Zero-based row of the 3-row grid to draw into.
    im_title, im, im_mask
        Title, data and mask of the input image; the panel is skipped when
        ``im`` is None.
    tr_title, tr, tr_mask
        Title, data and mask of the transformed image.
    pr_title, r, P0, P2
        Title, radii and the two profiles to plot.
    """
    def show_masked(cell, title, image, mask, **imshow_kw):
        # One image panel in the left half of the 3x4 grid.
        plt.subplot(3, 4, cell)
        plt.title(title)
        plt.imshow(np.ma.masked_where(mask == 0, image), **imshow_kw)
        plt.axis('off')

    if im is not None:
        show_masked(4 * row + 1, im_title, im, im_mask, cmap='hot')

    show_masked(4 * row + 2, tr_title, tr, tr_mask,
                vmin=-vlim, vmax=vlim, cmap='seismic')

    # Profiles occupy the right half of the row (3x2 grid).
    plt.subplot(3, 2, 2 * row + 2)
    plt.title(pr_title)
    plt.axvspan(0, mask_r, color='lightgray')  # shade region without valid data
    for reference, fmt in ((P0_src, 'C0--'), (P2_src, 'C3--')):
        plt.plot(r_src, reference, fmt, lw=1)
    for profile, fmt, label in ((P0, 'C0', '$P_0(r)$'),
                                (P2, 'C3', '$P_2(r)$')):
        plt.plot(r, profile, fmt, lw=1, label=label)
    plt.xlim((0, R))
    plt.ylim(ylim)
    plt.legend()
Example #12
Source File: IDAtropy.py    From IDAtropy with GNU General Public License v3.0 5 votes vote down vote up
def segment_changed(self, item):
    """Show or hide the shaded chart span of a segment when its item changes.

    When the item is checked, a span over the segment's ``chart_offsets``
    is drawn in the colour assigned to its row (at most one span per
    segment). When it is unchecked, the segment's span is removed if
    present. The canvas is redrawn in both cases.
    """
    row = item.row()
    seg_name = item.text()

    if item.checkState() == QtCore.Qt.Checked:
        # The item may be re-signalled while already checked; drawing again
        # would orphan the first span so it could never be removed.
        if seg_name not in self.spans:
            start, end = self.segments[seg_name]['chart_offsets']
            color = self.colors[row % len(self.colors)]
            self.spans[seg_name] = plt.axvspan(start, end, color=color, alpha=0.6)
    else:
        span = self.spans.pop(seg_name, None)
        if span is not None:
            span.remove()
    self.canvas.draw()
Example #13
Source File: fig2.py    From OASIS with GNU General Public License v3.0 5 votes vote down vote up
def cb(y, P, counter, current):
    """Draw the current solution as a standalone 3x3-inch figure.

    Parameters
    ----------
    y : numpy array
        Data trace; a copy of its values colours the scatter points.
    P : sequence of (v, w, f, l)
        Pools; entries ``f:f + l`` of the solution are set to
        ``max(v, 0) / w * g**k`` for k = 0..l-1. Every other pool (starting
        with the first) is shaded grey.
    counter : int
        Frame number used in the ``fig/%d.pdf`` file name when
        ``save_figs`` is set.
    current : int
        Index highlighted with a blue '+' marker.

    ``g``, ``trueSpikes``, ``simpleaxis`` and ``save_figs`` are names
    defined outside this function.
    """
    t = np.arange(len(y))
    solution = np.empty(len(y))
    for value, weight, first, length in P:
        # Geometric decay by factor g from the pool's (clipped) start value.
        solution[first:first + length] = max(value, 0) / weight * g**np.arange(length)

    top = 1.65
    plt.figure(figsize=(3, 3))
    color = y.copy()
    plt.plot(solution, c='k', zorder=-11, lw=1.2)
    plt.scatter(t, solution, s=60, cmap=plt.cm.Spectral,
                c=color, clip_on=False, zorder=11)
    plt.scatter([t[current]], [solution[current]],
                s=200, lw=2.5, marker='+', color='b', clip_on=False, zorder=11)
    for _, _, first, length in P[::2]:
        plt.axvspan(first, first + length, alpha=0.1, color='k', zorder=-11)
    # Ground-truth spikes as red vertical lines behind everything else.
    for x in np.where(trueSpikes)[0]:
        plt.plot([x, x], [0, top], lw=1.5, c='r', zorder=-12)
    plt.xlim((0, len(y) - .5))
    plt.ylim((0, top))
    simpleaxis(plt.gca())
    plt.xticks([])
    plt.yticks([])
    if save_figs:
        plt.savefig('fig/%d.pdf' % counter)
    plt.show()


# generate data 
Example #14
Source File: backtest.py    From sanpy with MIT License 5 votes vote down vote up
def plot_backtest(self, viz=None):
        ''' Plot ``self.performance`` against ``self.benchmark``.

        param viz: None OR "trades" OR "hodl".
            "trades" draws red/green vertical lines at the sell/buy points
            in ``self.nr_trades``; "hodl" shades every period in which
            ``self.trades`` is truthy (a position is held), using
            ``self.strategy_returns.index`` for the time axis.
        '''
        plt.figure(figsize=(15, 8))
        plt.plot(self.performance, label="performance")
        plt.plot(self.benchmark, label="holding")

        if viz == 'trades':
            min_y = min(self.performance.min(), self.benchmark.min())
            max_y = max(self.performance.max(), self.benchmark.max())
            plt.vlines(self.nr_trades['sell'], min_y, max_y, color='red')
            plt.vlines(self.nr_trades['buy'], min_y, max_y, color='green')
        elif viz == 'hodl':
            index = self.strategy_returns.index
            hodl_periods = []
            # start is None while flat. A position held from the very first
            # sample opens a period at index[0]; previously `start` was never
            # bound in that case and the code raised UnboundLocalError.
            start = None
            last = None
            for i, holding in enumerate(self.trades):
                last = i
                if holding and start is None:
                    start = index[i]
                elif not holding and start is not None:
                    hodl_periods.append([start, index[i]])
                    start = None
            if start is not None:
                # Still holding at the end: close the period at the last sample.
                hodl_periods.append([start, index[last]])
            for period_start, period_end in hodl_periods:
                plt.axvspan(period_start, period_end, color='#aeffa8')

        plt.legend()
        plt.show()
Example #15
Source File: application.py    From seasonal with MIT License 5 votes vote down vote up
def _periodogram_plot(title, column, data, trend, peaks):
    """Display periodogram results using matplotlib.

    Draws three stacked panels in figure 1:

    1. the raw series (with the trend overlaid when ``trend`` is given);
    2. the detrended series, or just a placeholder title when ``trend`` is
       None;
    3. the periodogram as a stem plot. Each ``(period, score, pmin, pmax)``
       peak is marked by a dashed line, a shaded ``[pmin, pmax]`` band and
       two text annotations.
    """
    periods, power = periodogram(data)
    plt.figure(1)

    plt.subplot(311)
    plt.title(title)
    plt.plot(data, label=column)
    has_trend = trend is not None
    if has_trend:
        plt.plot(trend, linewidth=3, label="broad trend")
    plt.legend()

    plt.subplot(312)
    if has_trend:
        plt.title("detrended")
        plt.plot(data - trend)
    else:
        plt.title("(no detrending specified)")

    plt.subplot(313)
    plt.title("periodogram")
    plt.stem(periods, power)
    for period, score, pmin, pmax in peaks:
        plt.axvline(period, linestyle='dashed', linewidth=2)
        plt.axvspan(pmin, pmax, alpha=0.2, color='b')
        plt.annotate("{}".format(period), (period, score * 0.8))
        plt.annotate("{}...{}".format(pmin, pmax), (pmin, score * 0.5))

    plt.tight_layout()
    plt.show()
Example #16
Source File: test_axes.py    From coffeegrindsize with MIT License 5 votes vote down vote up
def test_twinx_knows_limits():
    """A twinx axes reports the same x view limits as a plain control axes.

    Both figures contain ``axvspan(1, 2)`` plus a line over x in
    ``[0, 0.5]``. In the first figure the line goes onto the twin axes; in
    the control figure it goes onto the original axes.
    """
    _, host_ax = plt.subplots()
    host_ax.axvspan(1, 2)
    twin_ax = host_ax.twinx()
    twin_ax.plot([0, 0.5], [1, 2])

    # Control figure: same artists, no twin.
    _, control_ax = plt.subplots()
    control_ax.axvspan(1, 2)
    control_ax.plot([0, 0.5], [1, 2])

    assert_array_equal(twin_ax.viewLim.intervalx, control_ax.viewLim.intervalx)
Example #17
Source File: gettingStarted.py    From pyABF with MIT License 5 votes vote down vote up
def advanced_10a_digital_output_shading(self):
        """
        ## Shading Epochs

        In this ABF digital output 4 is high during epoch C. Let's highlight
        this by plotting sweeps and shading that epoch.

        `print(abf.epochPoints)` yields `[0, 3125, 7125, 23125, 23145, 200000]`
        and I know the epoch I'm interested in is bound by index 3 and 4.
        """
        # NOTE(review): the docstring above is Markdown tutorial text and is
        # kept verbatim in case it is extracted by a documentation generator.

        import pyabf
        abf = pyabf.ABF("data/abfs/17o05026_vc_stim.abf")

        # Overlay every sweep in the same semi-transparent colour.
        plt.figure(figsize=self.figsize)
        for sweepNumber in abf.sweepList:
            abf.setSweep(sweepNumber)
            plt.plot(abf.sweepX, abf.sweepY, color='C0', alpha=.5, lw=.5)
        plt.ylabel(abf.sweepLabelY)
        plt.xlabel(abf.sweepLabelX)
        plt.title("Shade a Specific Epoch")
        # Zoom to x in [1.10, 1.25] and y in [-150, 50].
        plt.axis([1.10, 1.25, -150, 50])

        # Convert the epoch's p1/p2 points into x-axis units by multiplying
        # with dataSecPerPoint (presumably seconds per sample), then shade it.
        epochNumber = 3
        t1 = abf.sweepEpochs.p1s[epochNumber] * abf.dataSecPerPoint
        t2 = abf.sweepEpochs.p2s[epochNumber] * abf.dataSecPerPoint
        plt.axvspan(t1, t2, color='r', alpha=.3, lw=0)
        plt.grid(alpha=.2)
        self.saveAndClose() 
Example #18
Source File: fig2.py    From OASIS with GNU General Public License v3.0 4 votes vote down vote up
def cb(y, P, counter, current):
    """Draw one frame of the two-panel (fluorescence / spikes) animation.

    The top panel (``ax1``) shows the current solution built from the pools
    in ``P``. The bottom panel (``ax2``) shows the implied spike amplitudes
    ``s[t] = solution[t] - g * solution[t - 1]`` (0 at t = 0) as stems.
    Both panels shade every other pool and draw the ground-truth spikes
    (``trueSpikes``) as red lines; the point at ``current`` gets a blue
    '+'. The frame is optionally saved as ``video/%03d.pdf``, briefly shown
    with ``plt.pause``, and both axes are then cleared for the next frame.

    ``ax1``, ``ax2``, ``g``, ``trueSpikes``, ``simpleaxis`` and
    ``save_figs`` are names defined outside this function.
    """
    t = np.arange(len(y))
    solution = np.empty(len(y))
    for i, (v, w, f, l) in enumerate(P):
        # Only the first pool's value is clipped at zero.
        head = v if i else max(v, 0)
        solution[f:f + l] = head / w * g**np.arange(l)
    spikes = np.r_[[0], solution[1:] - g * solution[:-1]]
    color = y.copy()

    def shade_pools(ax):
        # Light-grey band over every other pool, behind the data.
        for _, _, first, length in P[::2]:
            ax.axvspan(first, first + length, alpha=0.1, color='k', zorder=-11)

    def mark_true_spikes(ax, top):
        # Red vertical line at every ground-truth spike index.
        for x in np.where(trueSpikes)[0]:
            ax.plot([x, x], [0, top], lw=1.5, c='r', zorder=-12)

    # Top panel: fluorescence solution.
    ax1.plot(solution, c='k', zorder=-11, lw=1.3, clip_on=False)
    ax1.scatter(t, solution, s=40, cmap=plt.cm.Spectral,
                c=color, clip_on=False, zorder=11)
    ax1.scatter([t[current]], [solution[current]],
                s=120, lw=2.5, marker='+', color='b', clip_on=False, zorder=11)
    shade_pools(ax1)
    mark_true_spikes(ax1, 2.3)
    ax1.set_xlim((0, len(y) - .5))
    ax1.set_ylim((0, 2.3))
    simpleaxis(ax1)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_ylabel('Fluorescence')

    # Bottom panel: spike amplitudes as stems.
    for i, s in enumerate(spikes):
        ax2.plot([i, i], [0, s], c='k', zorder=-11, lw=1.4, clip_on=False)
    ax2.scatter(t, spikes,
                s=40, cmap=plt.cm.Spectral, c=color, clip_on=False, zorder=11)
    ax2.scatter([t[current]], [spikes[current]],
                s=120, lw=2.5, marker='+', color='b', clip_on=False, zorder=11)
    shade_pools(ax2)
    mark_true_spikes(ax2, 1.55)
    ax2.set_xlim((0, len(y) - .5))
    ax2.set_ylim((0, 1.55))
    simpleaxis(ax2)
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_xlabel('Time', labelpad=35, x=.5)
    ax2.set_ylabel('Spikes')

    plt.subplots_adjust(left=0.032, right=.995, top=.995, bottom=0.19, hspace=0.22)
    if save_figs:
        plt.savefig('video/%03d.pdf' % counter)
    plt.pause(1e-9)
    ax1.clear()
    ax2.clear()


# generate data 
Example #19
Source File: test_axes.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def test_twinx_knows_limits():
    """A twinx axes reports the same x view limits as a plain control axes.

    Both figures contain ``axvspan(1, 2)`` plus a line over x in
    ``[0, 0.5]``. In the first figure the line goes onto the twin axes; in
    the control figure it goes onto the original axes.
    """
    _, host_ax = plt.subplots()
    host_ax.axvspan(1, 2)
    twin_ax = host_ax.twinx()
    twin_ax.plot([0, 0.5], [1, 2])

    # Control figure: same artists, no twin.
    _, control_ax = plt.subplots()
    control_ax.axvspan(1, 2)
    control_ax.plot([0, 0.5], [1, 2])

    assert_array_equal(twin_ax.viewLim.intervalx, control_ax.viewLim.intervalx)
Example #20
Source File: eeg.py    From EEG with MIT License 4 votes vote down vote up
def plotERPElectrodes(data, trialNumList, events, trialDur=None, fs=2048.,
    baselineDur=0.1, electrodes='Fp1', normalize=False, facet=False,
    startOffset=0):
    """
    Plot the ERP (average across trials time-locked to specific events) of
    each electrode as single lines on the same figure with or without facetting.

    Parameters
    ----------
    data : instance of pandas.core.DataFrame
        Data containing the time series to average and plot. Each column is
        an electrode.
    trialNumList : array-like of int
        List of all trials to average.
    events : instance of pandas.core.DataFrame
        Dataframe containing the list of events obtained with
        mne.find_events(raw).
    trialDur : float
        Trial duration in seconds.
    fs : float
        Sampling frequency of data in Hz.
    baselineDur : float, defaults to 0.1
        Duration of the baseline in seconds. If normalize is True, normalize
        each electrode with a baseline of duration `baselineDur`. In the
        non-faceted plot the interval [-baselineDur, 0] is shaded grey.
    electrodes : column label | array-like of column labels, defaults to 'Fp1'
        Electrode column(s) of `data` to plot. A single label (e.g. 'Fp1')
        may be passed on its own.
    normalize : bool, defaults to False
        If True data will be normalized.
    facet : bool, default to False
        If True, each electrode will be plotted on a different facet.
    startOffset : number, defaults to 0
        Passed through to `getTrialsAverage`.

    Returns
    -------
    None
        In the non-faceted case the figure is displayed with plt.show();
        nothing is returned.
    """
    # A single label must not be iterated: the default 'Fp1' would otherwise
    # be looped over as the columns 'F', 'p' and '1'. (Python 2 str has no
    # __iter__, so the hasattr test covers both str and unicode there.)
    if isinstance(electrodes, str) or not hasattr(electrodes, '__iter__'):
        electrodes = [electrodes]

    print('Average of %d trials' % len(trialNumList))
    meanTrials = pd.DataFrame()
    for electrode in electrodes:
        # getTrialsAverage also returns the individual trials; only the
        # mean is used here.
        meanTrials[electrode], _ = getTrialsAverage(data=data[electrode],
            events=events, trialDur=trialDur, trialNumList=trialNumList,
            baselineDur=baselineDur, normalize=normalize, fs=fs, startOffset=startOffset)

    if facet:
        print('Faceting...')
        meanTrials.plot(subplots=True)
    else:
        plt.figure()
        plt.plot(meanTrials)
        # Mark the event onset and shade the baseline window before it.
        plt.axvline(x=0, color='grey', linestyle='dotted')
        plt.axvspan(-baselineDur, 0, alpha=0.3, color='grey')
        plt.xlabel('Time (s)')
        plt.show()
Example #21
Source File: plotting.py    From msaf with MIT License 4 votes vote down vote up
def plot_one_track(file_struct, est_times, est_labels, boundaries_id, labels_id,
                   title=None):
    """Plots the results of one track, with ground truth if it exists.

    Each row of the figure shows one set of boundaries as vertical lines:
    the ground truth (green, when ``file_struct.ref_file`` can be read) and
    then the estimation (blue). When ``labels_id`` is given, each segment is
    additionally shaded with a colour chosen by its label.

    Parameters
    ----------
    file_struct : object
        Must provide ``ref_file`` (JAMS reference) and ``audio_file``.
    est_times : array-like
        Estimated boundary times.
    est_labels : array-like
        Estimated segment labels.
    boundaries_id : str
        Boundary algorithm id, used as the row name.
    labels_id : str or None
        Label algorithm id; appended to the row name. If None, no label
        spans are drawn.
    title : str or None
        Plot title, passed to ``_plot_formatting``.
    """
    import matplotlib.pyplot as plt
    # Set up the boundaries id
    bid_lid = boundaries_id
    if labels_id is not None:
        bid_lid += " + " + labels_id
    try:
        # Read file
        jam = jams.load(file_struct.ref_file)
        ann = jam.search(namespace='segment_.*')[0]
        ref_inters, ref_labels = ann.to_interval_values()

        # To times
        ref_times = utils.intervals_to_times(ref_inters)
        all_boundaries = [ref_times, est_times]
        all_labels = [ref_labels, est_labels]
        algo_ids = ["GT", bid_lid]
    except Exception:
        # No usable reference: plot the estimation alone. (Exception rather
        # than a bare except so KeyboardInterrupt/SystemExit propagate.)
        logging.warning("No references found in %s. Not plotting groundtruth",
                        file_struct.ref_file)
        all_boundaries = [est_times]
        all_labels = [est_labels]
        algo_ids = [bid_lid]

    N = len(all_boundaries)

    # Index the labels to normalize them
    all_labels = [mir_eval.util.index_labels(labels)[0] for labels in all_labels]

    # Get color map
    cm = plt.get_cmap('gist_rainbow')
    # With a single distinct label every index is 0; clamp to 1 to avoid a
    # ZeroDivisionError when mapping labels to colours below.
    max_label = max(max(max(labels) for labels in all_labels), 1)

    figsize = (8, 4)
    plt.figure(1, figsize=figsize, dpi=120, facecolor='w', edgecolor='k')
    for i, boundaries in enumerate(all_boundaries):
        # Row i occupies the vertical band [i / N, (i + 1) / N] of the axes.
        color = "g" if i == 0 else "b"
        for b in boundaries:
            plt.axvline(b, i / float(N), (i + 1) / float(N), color=color)
        if labels_id is not None:
            labels = all_labels[i]
            inters = utils.times_to_intervals(boundaries)
            for label, inter in zip(labels, inters):
                plt.axvspan(inter[0], inter[1], ymin=i / float(N),
                            ymax=(i + 1) / float(N), alpha=0.6,
                            color=cm(label / float(max_label)))
        plt.axhline(i / float(N), color="k", linewidth=1)

    # Format plot
    _plot_formatting(title, os.path.basename(file_struct.audio_file), algo_ids,
                     all_boundaries[0][-1], N, None)
Example #22
Source File: plotting.py    From msaf with MIT License 4 votes vote down vote up
def plot_labels(all_labels, gt_times, est_file, algo_ids=None, title=None,
                output_file=None):
    """Plots all the labels.

    Every label array is drawn as one horizontal band of coloured spans over
    the ground-truth segments, with the ground-truth boundaries drawn as
    green vertical lines. The arguments passed in are not modified.

    Parameters
    ----------
    all_labels: list
        A list of np.arrays containing the labels of the boundaries, one array
        for each algorithm.
    gt_times: np.array
        Array with the ground truth boundaries.
    est_file: str
        Path to the estimated file (JSON file)
    algo_ids : list
        List of algorithm ids to read boundaries from.
        If None, all algorithm ids are read.
    title : str
        Title of the plot. If None, the name of the file is printed instead.
    output_file : str
        Passed through to ``_plot_formatting`` (presumably where the figure
        is saved; None to not save -- confirm).
    """
    import matplotlib.pyplot as plt
    N = len(all_labels)  # Number of lists of labels
    if algo_ids is None:
        algo_ids = io.get_algo_ids(est_file)

    # Translate ids into a new list; translating in place used to rewrite
    # the caller's list.
    algo_ids = ["GT"] + [translate_ids[algo_id] for algo_id in algo_ids]

    # Index the labels to normalize them (into a new list, for the same reason).
    all_labels = [mir_eval.util.index_labels(labels)[0] for labels in all_labels]

    # Get color map
    cm = plt.get_cmap('gist_rainbow')
    # With a single distinct label every index is 0; clamp to 1 to avoid a
    # ZeroDivisionError when mapping labels to colours below.
    max_label = max(max(max(labels) for labels in all_labels), 1)

    # To intervals
    gt_inters = utils.times_to_intervals(gt_times)

    # Plot labels: row i occupies the vertical band [i / N, (i + 1) / N].
    figsize = (6, 4)
    plt.figure(1, figsize=figsize, dpi=120, facecolor='w', edgecolor='k')
    for i, labels in enumerate(all_labels):
        for label, inter in zip(labels, gt_inters):
            plt.axvspan(inter[0], inter[1], ymin=i / float(N),
                        ymax=(i + 1) / float(N), alpha=0.6,
                        color=cm(label / float(max_label)))
        plt.axhline(i / float(N), color="k", linewidth=1)

    # Draw the boundary lines
    for bound in gt_times:
        plt.axvline(bound, color="g")

    # Format plot
    _plot_formatting(title, est_file, algo_ids, gt_times[-1], N,
                     output_file)
Example #23
Source File: plotting.py    From msaf with MIT License 4 votes vote down vote up
def plot_tree(T, res=None, title=None, cmap_id="Pastel2"):
    """Plots a given tree, containing hierarchical segmentation.

    Every non-root level of ``T`` becomes one horizontal band (the first
    level at the top) of ``axvspan``s, one per segment, coloured by the
    segment label.

    Parameters
    ----------
    T: mir_eval.segment.tree
        A tree object containing the hierarchical segmentation.
    res: float
        Frame-rate resolution of the tree (None to use seconds).
    title: str
        Title for the plot. `None` for no title.
    cmap_id: str
        Color Map ID
    """
    import matplotlib.pyplot as plt

    def round_time(t, res=0.1):
        # Truncate t to a multiple of res (towards zero).
        v = int(t / float(res)) * res
        return v

    # Get color map
    cmap = plt.get_cmap(cmap_id)

    # Get segments by level
    level_bounds = [T.get_segments_in_level(level)
                    for level in T.levels if level != "root"]

    # Decided up front so it is defined even when there are no segments.
    xlabel = "Time (seconds)" if res is None else "Time (frames)"

    # Plot axvspans for each segment
    B = float(len(level_bounds))
    end = None
    for i, segments in enumerate(level_bounds):
        labels = utils.segment_labels_to_floats(segments)
        for segment, label in zip(segments, labels):
            if res is None:
                start = segment.start
                end = segment.end
            else:
                start = int(round_time(segment.start, res=res) / res)
                end = int(round_time(segment.end, res=res) / res)
            plt.axvspan(start, end,
                        ymax=(len(level_bounds) - i) / B,
                        ymin=(len(level_bounds) - i - 1) / B,
                        facecolor=cmap(label))

    # Plot labels: one y tick centred in each level's band. np.linspace
    # requires an integer `num`; a float is rejected by modern NumPy.
    L = float(len(T.levels) - 1)
    plt.yticks(np.linspace(0, (L - 1) / L, num=int(L)) + 1 / L / 2.,
               T.levels[1:][::-1])
    plt.xlabel(xlabel)
    if title is not None:
        plt.title(title)
    if end is not None:
        plt.gca().set_xlim([0, end])