Python matplotlib.pyplot.axvline() Examples

The following are 30 code examples of matplotlib.pyplot.axvline(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: m_dos_pdos_eigenvalues.py    From pyscf with Apache License 2.0 6 votes vote down vote up
def dosplot(filename=None, data=None, fermi=None):
    """Plot the total spin-resolved density of states (DOS) and save it.

    The table is read with ``np.loadtxt(filename)`` when ``filename`` is
    given; otherwise ``data`` is used as-is.  Columns are interpreted as:
    0 = energy (eV), 1 = MF spin-up, 2 = QP spin-up, 3 = MF spin-down,
    4 = QP spin-down.  Spin-down curves are drawn mirrored below the x-axis.

    Args:
        filename: Path of a text file readable by ``np.loadtxt``.
        data: Array with the column layout above (used if ``filename`` is None).
        fermi: Optional Fermi energy, drawn as a dashed vertical line.

    Raises:
        ValueError: If neither ``filename`` nor ``data`` is given.

    Side effects: writes ``dos_eigen.svg`` to the working directory and
    calls ``plt.show()``.
    """
    if filename is not None:
        data = np.loadtxt(filename)
    elif data is None:
        raise ValueError("Either filename or data must be given")

    import matplotlib.pyplot as plt

    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    energy = data.T[0]
    plt.plot(energy, data.T[1], label='MF Spin-UP', linestyle=':', color='r')
    plt.fill_between(energy, 0, data.T[1], facecolor='r', alpha=0.1, interpolate=True)
    plt.plot(energy, data.T[2], label='QP Spin-UP', color='r')
    plt.fill_between(energy, 0, data.T[2], facecolor='r', alpha=0.5, interpolate=True)
    plt.plot(energy, -data.T[3], label='MF Spin-DN', linestyle=':', color='b')
    plt.fill_between(energy, 0, -data.T[3], facecolor='b', alpha=0.1, interpolate=True)
    plt.plot(energy, -data.T[4], label='QP Spin-DN', color='b')
    plt.fill_between(energy, 0, -data.T[4], facecolor='b', alpha=0.5, interpolate=True)
    # ``is not None`` rather than ``!= None``: an array-valued comparison
    # would be ambiguous inside ``if``.
    if fermi is not None:
        plt.axvline(x=fermi, color='k', linestyle='--')
    plt.axhline(y=0, color='k')
    plt.title('Total DOS', fontsize=20)
    plt.xlabel('Energy (eV)', fontsize=15)
    plt.ylabel('Density of States (electron/eV)', fontsize=15)
    plt.legend()
    plt.savefig("dos_eigen.svg", dpi=900)
    plt.show()
Example #2
Source File: clustering.py    From malss with MIT License 6 votes vote down vote up
def plot_gap(cls, algorithm, dname):
    """Plot the gap statistic against the number of clusters.

    The curve (with ``results['gap_sk']`` error bars) is saved as
    ``<dname>/gap_<EstimatorClassName>.png``.  The selected cluster count
    ``results['gap_nc']`` is marked with a dashed grey vertical line.
    Nothing happens when ``dname`` is None.
    """
    if dname is None:
        return
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(dname, exist_ok=True)

    results = algorithm.results
    name = algorithm.estimator.__class__.__name__
    n_clusters = range(results['min_nc'], results['max_nc'] + 1)

    plt.figure()
    plt.title(name)
    plt.xlabel("Number of clusters")
    plt.ylabel("Gap statistic")

    plt.plot(n_clusters, results['gap'], 'o-', color='dodgerblue')
    plt.errorbar(n_clusters, results['gap'], results['gap_sk'], capsize=3)
    # ``color`` (not ``C``): upper-case property names only worked while
    # Matplotlib lower-cased them, which was deprecated in 3.3 and removed.
    plt.axvline(x=results['gap_nc'], ls='--', color='gray', zorder=0)
    plt.savefig('%s/gap_%s.png' % (dname, name), bbox_inches='tight', dpi=75)
    plt.close()
Example #3
Source File: clustering.py    From malss with MIT License 6 votes vote down vote up
def plot_silhouette(cls, algorithm, dname):
    """Plot the silhouette score against the number of clusters.

    The plot is saved as ``<dname>/silhouette_<EstimatorClassName>.png``.
    The selected cluster count ``results['silhouette_nc']`` is marked with
    a dashed grey vertical line.  Nothing happens when ``dname`` is None.
    """
    if dname is None:
        return
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(dname, exist_ok=True)

    results = algorithm.results
    name = algorithm.estimator.__class__.__name__

    plt.figure()
    plt.title(name)
    plt.xlabel("Number of clusters")
    plt.ylabel("Silhouette score")

    plt.plot(range(results['min_nc'], results['max_nc'] + 1),
             results['silhouette'], 'o-', color='darkorange')
    # ``color`` (not ``C``): upper-case property names only worked while
    # Matplotlib lower-cased them, which was deprecated in 3.3 and removed.
    plt.axvline(x=results['silhouette_nc'], ls='--', color='gray', zorder=0)
    plt.savefig('%s/silhouette_%s.png' % (dname, name),
                bbox_inches='tight', dpi=75)
    plt.close()
Example #4
Source File: clustering.py    From malss with MIT License 6 votes vote down vote up
def plot_davies(cls, algorithm, dname):
    """Plot the Davies-Bouldin score against the number of clusters.

    The plot is saved as ``<dname>/davies_<EstimatorClassName>.png``.
    The selected cluster count ``results['davies_nc']`` is marked with a
    dashed grey vertical line.  Nothing happens when ``dname`` is None.
    """
    if dname is None:
        return
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(dname, exist_ok=True)

    results = algorithm.results
    name = algorithm.estimator.__class__.__name__

    plt.figure()
    plt.title(name)
    plt.xlabel("Number of clusters")
    plt.ylabel("Davies-Bouldin score")

    plt.plot(range(results['min_nc'], results['max_nc'] + 1),
             results['davies'], 'o-', color='limegreen')
    # ``color`` (not ``C``): upper-case property names only worked while
    # Matplotlib lower-cased them, which was deprecated in 3.3 and removed.
    plt.axvline(x=results['davies_nc'], ls='--', color='gray', zorder=0)
    plt.savefig('%s/davies_%s.png' % (dname, name),
                bbox_inches='tight', dpi=75)
    plt.close()
Example #5
Source File: clustering.py    From malss with MIT License 6 votes vote down vote up
def plot_calinski(cls, algorithm, dname):
    """Plot the Calinski-Harabasz score against the number of clusters.

    The plot is saved as ``<dname>/calinski_<EstimatorClassName>.png``.
    The selected cluster count ``results['calinski_nc']`` is marked with a
    dashed grey vertical line.  Nothing happens when ``dname`` is None.
    """
    if dname is None:
        return
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(dname, exist_ok=True)

    results = algorithm.results
    name = algorithm.estimator.__class__.__name__

    plt.figure()
    plt.title(name)
    plt.xlabel("Number of clusters")
    plt.ylabel("Calinski and Harabasz score")

    plt.plot(range(results['min_nc'], results['max_nc'] + 1),
             results['calinski'], 'o-', color='crimson')
    # ``color`` (not ``C``): upper-case property names only worked while
    # Matplotlib lower-cased them, which was deprecated in 3.3 and removed.
    plt.axvline(x=results['calinski_nc'], ls='--', color='gray', zorder=0)
    plt.savefig('%s/calinski_%s.png' % (dname, name),
                bbox_inches='tight', dpi=75)
    plt.close()
Example #6
Source File: validation_plots.py    From TheCannon with MIT License 6 votes vote down vote up
def chisq_dist():
    """Histogram the reduced chi-squared of the validation set.

    Loads ``val_ivar_norm.npz`` and ``val_chisq.npz`` from ``DATA_DIR`` and
    divides each chi-squared by the number of pixels with positive inverse
    variance.  The histogram (25 bins over 0-3) is saved to
    ``chisq_dist.png``, with a dashed reference line at 1.
    """
    fig = plt.figure(figsize=(6, 4))
    # NpzFile holds an open file handle; close it once the array is read.
    with np.load("%s/val_ivar_norm.npz" % DATA_DIR) as archive:
        ivar = archive['arr_0']
    with np.load("%s/val_chisq.npz" % DATA_DIR) as archive:
        chisq = archive['arr_0']
    npix = np.sum(ivar > 0, axis=1)
    redchisq = chisq / npix
    nbins = 25
    # ``normed`` was removed from ``hist`` (Matplotlib 3.1); ``density`` replaces it.
    plt.hist(redchisq, bins=nbins, color='k', histtype="step",
             lw=2, density=False, alpha=0.3, range=(0, 3))
    # No legend: none of the artists carries a label.
    plt.xlabel(r"Reduced $\chi^2$", fontsize=16)
    plt.tick_params(axis='both', labelsize=16)
    plt.ylabel("Count", fontsize=16)
    plt.axvline(x=1.0, linestyle='--', c='k')
    fig.tight_layout()
    plt.savefig("chisq_dist.png")
    plt.close(fig)
Example #7
Source File: memory_cpu_profile.py    From mlens with MIT License 6 votes vote down vote up
def plot_rss(cm, t1, t2, t3):
    """Plot the memory profile.

    ``cm.rss`` is divided by 1e6 (plotted as MB) against its sample index.
    Dashed/dash-dotted lines mark ``t1 - 3`` ("load data"), ``t2`` ("fit
    start") and ``t3`` ("fit end").  Ticks at samples 0-250 are labelled
    0-25 s, which assumes 10 samples per second (TODO confirm).

    When the module-level ``PRINT`` flag is truthy, the figure is saved to
    ``dev/img/memory_profile.png``.  If that fails with an OSError (e.g. the
    directory is missing), it is saved to ``memory_profile.png`` instead.
    """
    f = plt.figure(figsize=(8, 6))
    # Index the x-axis by the series actually plotted (was cm.cpu).
    plt.plot(range(cm.rss.shape[0]), cm.rss / 1000000)
    plt.axvline(t1 - 3, color='darkcyan', linestyle='--', linewidth=1.0,
                label='load data')
    plt.axvline(t2, color='blue', linestyle='--', linewidth=1.0,
                label='fit start')
    plt.axvline(t3, color='blue', linestyle='-.', linewidth=1.0,
                label='fit end')
    plt.xticks([0, 50, 100, 150, 200, 250], [0, 5, 10, 15, 20, 25])
    plt.title("ML-Ensemble memory profile (working set)")
    plt.ylabel("Working set memory (MB)")
    plt.xlabel("Time (s)")
    plt.legend()
    plt.show()

    if PRINT:
        try:
            f.savefig("dev/img/memory_profile.png", dpi=600)
        except OSError:
            # Narrow catch: only path/IO failures trigger the fallback.
            f.savefig("memory_profile.png", dpi=600)
Example #8
Source File: memory_cpu_profile.py    From mlens with MIT License 6 votes vote down vote up
def plot_cpu(cm, t1, t2, t3):
    """Plot the CPU profile.

    ``cm.cpu`` is plotted against its sample index.  Dashed/dash-dotted
    lines mark ``t1 - 3`` ("load data"), ``t2`` ("fit start") and ``t3``
    ("fit end").  Ticks at samples 0-250 are labelled 0-25 s, which assumes
    10 samples per second (TODO confirm).

    When the module-level ``PRINT`` flag is truthy, the figure is saved to
    ``dev/cpu_profile.png``.  If that fails with an OSError, it is saved to
    ``cpu_profile.png`` instead.
    """
    f = plt.figure()
    plt.plot(range(cm.cpu.shape[0]), cm.cpu)
    plt.axvline(t1 - 3, color='darkcyan', linestyle='--', linewidth=1.0,
                label='load data')
    plt.axvline(t2, color='blue', linestyle='--', linewidth=1.0,
                label='fit start')
    plt.axvline(t3, color='blue', linestyle='-.', linewidth=1.0,
                label='fit end')
    plt.xticks([0, 50, 100, 150, 200, 250], [0, 5, 10, 15, 20, 25])
    plt.title("ML-Ensemble CPU profile")
    plt.ylabel("CPU utilization (%)")
    plt.xlabel("Time (s)")
    plt.legend()

    if PRINT:
        try:
            f.savefig("dev/cpu_profile.png", dpi=600)
        except OSError:
            # Narrow catch: only path/IO failures trigger the fallback.
            f.savefig("cpu_profile.png", dpi=600)
Example #9
Source File: measures.py    From nolds with MIT License 6 votes vote down vote up
def plot_histogram_matrix(data, name, fname=None):
  """Plot a square grid of normalized histograms, one per column of ``data``.

  Each histogram uses 25 bins over a range symmetric about zero
  (``[-a, a]`` with ``a`` the largest absolute value in the column).  It is
  normalized to sum to 1 and drawn with a red line at the column mean.
  The y-axis is fixed to ``[0, 0.5]``.

  Args:
    data: 2-D array; column ``i`` is drawn in subplot ``i``.
    name: Prefix for subplot titles, rendered as ``"name[i]"``.
    fname: If given, save the figure there; otherwise show it.
  """
  # local import to avoid dependency for non-debug use
  import matplotlib.pyplot as plt
  nhists = len(data[0])
  nbins = 25
  ylim = (0, 0.5)
  nrows = int(np.ceil(np.sqrt(nhists)))
  plt.figure(figsize=(nrows * 4, nrows * 4))
  for i in range(nhists):
    column = data[:, i]
    plt.subplot(nrows, nrows, i + 1)
    absmax = max(abs(np.max(column)), abs(np.min(column)))
    rng = (-absmax, absmax)
    h, bins = np.histogram(column, nbins, rng)
    bin_width = bins[1] - bins[0]
    h = h.astype("float32") / np.sum(h)
    # bins[:-1] are left edges; bar() centres bars by default, which would
    # shift every bar half a bin to the left.
    plt.bar(bins[:-1], h, bin_width, align="edge")
    plt.axvline(np.mean(column), color="red")
    plt.ylim(ylim)
    plt.title("{:s}[{:d}]".format(name, i))
  if fname is None:
    plt.show()
  else:
    plt.savefig(fname)
  plt.close()
Example #10
Source File: correlation_analysis.py    From copper_price_forecast with GNU General Public License v3.0 6 votes vote down vote up
def data_visualization(co_price, pcb_price):
    """
    Visualize the raw price data.

    The copper price (``co_price.price / 100``) and the PCB price
    (``pcb_price.price``) are plotted against their indexes on one figure.
    Four fixed dates are marked with dashed red vertical lines, and the
    figure is shown.
    """
    marker_dates = ('2015-04-23', '2015-10-23', '2016-04-23', '2016-10-23')

    plt.figure(figsize=(10, 6))
    plt.title('copper price(100rmb/t) vs. pcb price(rmb/sq.m.)')
    plt.xlabel('date')
    plt.ylabel('history price')

    plt.plot(co_price.index, co_price.price / 100, '-', label='co price')
    plt.plot(pcb_price.index, pcb_price.price, '-', label='pcb price')
    for day in marker_dates:
        plt.axvline(day, linewidth=1, color='r', linestyle='dashed')

    plt.legend(loc='upper right')

    plt.show()
Example #11
Source File: electronic.py    From pyiron with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_fermi_dirac(self):
    """
    Plot occupancy against eigenvalue, ordered by eigenvalue.

    A dashed black vertical line marks ``self.efermi``.

    Returns:
        The plotting module used: ``matplotlib.pylab`` when importable,
        otherwise ``matplotlib.pyplot``.
    """
    try:
        import matplotlib.pylab as plt
    except ModuleNotFoundError:
        import matplotlib.pyplot as plt
    order = np.argsort(self.eigenvalues)
    energies = self.eigenvalues[order]
    occupations = self.occupancies[order]
    plt.plot(energies, occupations, linewidth=2.0, color="blue")
    plt.axvline(self.efermi, linewidth=2.0, linestyle="dashed", color="black")
    plt.xlabel("Energies (eV)")
    plt.ylabel("Occupancy")
    return plt
Example #12
Source File: burst_plot.py    From FRETBursts with GNU General Public License v2.0 6 votes vote down vote up
def _hist_burst_taildist(data, bins, pdf, weights=None, yscale='log',
                         color=None, label=None, plot_style=None, vline=None):
    """Plot a histogram of ``data`` as a marker line on a ``yscale`` y-axis.

    NaN entries of ``data`` are dropped before binning (``bins`` is passed
    through ``_bins_array``).  The y values are the histogram's ``pdf`` when
    ``pdf`` is true, otherwise its ``counts``; the y-label is set to match.

    Args:
        weights: Optional per-sample weights.  When they have the same shape
            as ``data``, the same NaN mask is applied so they stay aligned.
        color, label: Shortcuts that override those entries of ``plot_style``.
        plot_style: Extra line properties, normalised by
            ``_normalize_kwargs(kind='line2d')`` and applied over the default
            ``marker='o'``.  The caller's dict is not modified.
        vline: Optional x position of a dashed vertical line.
    """
    valid = ~np.isnan(data)
    if weights is not None and np.shape(weights) == np.shape(data):
        # Without this, dropping NaNs from data leaves weights longer than
        # the data and np.histogram rejects them.
        weights = np.asarray(weights)[valid]
    hist = HistData(*np.histogram(data[valid],
                                  bins=_bins_array(bins), weights=weights))
    ydata = hist.pdf if pdf else hist.counts

    # Copy so the colour/label overrides never leak into the caller's dict.
    user_style = {} if plot_style is None else dict(plot_style)
    if color is not None:
        user_style['color'] = color
    if label is not None:
        user_style['label'] = label
    line_style = dict(marker='o')
    line_style.update(_normalize_kwargs(user_style, kind='line2d'))
    plt.plot(hist.bincenters, ydata, **line_style)
    if vline is not None:
        plt.axvline(vline, ls='--')
    plt.yscale(yscale)
    if pdf:
        plt.ylabel('PDF')
    else:
        plt.ylabel('# Bursts')
Example #13
Source File: hicCompartmentalization.py    From HiCExplorer with GNU General Public License v3.0 6 votes vote down vote up
def plot_polarization_ratio(polarization_ratio, plotName, labels,
                            number_of_quantiles):
    """
    Generate a plot to visualize the polarization ratio between A and B
    compartments. It shows how well the two compartments are separated.

    One circle-marked line is drawn per series in ``polarization_ratio``,
    labelled with ``labels[i]``.  Grey dashed guides are drawn at y = 1 and
    at x = ``number_of_quantiles / 2``.  The plot is drawn on the current
    figure and saved to ``plotName``; the figure is neither created nor
    closed here.
    """
    for i, ratio in enumerate(polarization_ratio):
        plt.plot(ratio, marker="o", label=labels[i])
    plt.axhline(1, c='grey', ls='--', lw=1)
    plt.axvline(number_of_quantiles / 2, c='grey', ls='--', lw=1)
    plt.legend(loc='best')
    plt.xlabel('Quantiles')
    # Fixed typo in the axis label ("signla" -> "signal").
    plt.ylabel('signal within comp. / signal between comp.')
    plt.title('compartment polarization ratio')
    plt.savefig(plotName)
Example #14
Source File: acqzoo.py    From pyGPGO with MIT License 6 votes vote down vote up
def plotGPGO(gpgo, param, index, new=True):
    """Plot one acquisition function into row ``index`` of a 5-row figure.

    The parameter range is ``list(param.values())[0][1]``, sampled at 1000
    points.  With ``new=True`` a fresh figure is started, and its first row
    shows the GP mean with a band of +/- 1.96 * sqrt(second predict value).
    Row ``index`` then shows the negated ``gpgo._acqWrapper`` values, coloured
    and labelled by ``colors[index - 2]`` / ``acq_titles[index - 2]``.  A
    vertical line marks ``gpgo.best`` after ``gpgo._optimizeAcq`` runs.
    """
    bounds = list(param.values())[0][1]
    grid = np.linspace(bounds[0], bounds[1], 1000).reshape((1000, 1))
    mean, spread = gpgo.GP.predict(grid, return_std=True)
    # The second value is square-rooted, i.e. treated as a variance.
    half_width = 1.96 * np.sqrt(spread)
    lower = mean - half_width
    upper = mean + half_width
    if new:
        plt.figure()
        plt.subplot(5, 1, 1)
        plt.fill_between(grid.flatten(), lower, upper, alpha=0.2)
        plt.plot(grid.flatten(), mean)
    plt.subplot(5, 1, index)
    acquisition = np.array(
        [-gpgo._acqWrapper(np.atleast_1d(point)) for point in grid]
    ).flatten()
    plt.plot(grid, acquisition, color=colors[index - 2], label=acq_titles[index - 2])
    gpgo._optimizeAcq(method='L-BFGS-B', n_start=1000)
    plt.axvline(x=gpgo.best)
    plt.legend(loc=0)
Example #15
Source File: example1d.py    From pyGPGO with MIT License 6 votes vote down vote up
def plotGPGO(gpgo, param, fig_index=None):
    """Plot the fitted GP and its acquisition function, then save and show.

    The top panel shows the GP mean with a +/- 1.96 * sqrt(second predict
    value) band over 1000 points spanning ``list(param.values())[0][1]``.
    The bottom panel shows the negated ``gpgo._acqWrapper`` values, with a
    black line at ``gpgo.best`` after ``gpgo._optimizeAcq`` runs.

    Args:
        gpgo: Optimizer object providing ``GP.predict``, ``_acqWrapper``,
            ``_optimizeAcq`` and ``best``.
        param: Mapping whose first value holds the ``(low, high)`` bounds at
            position ``[1]``.
        fig_index: Stem of the saved PDF, written as
            ``mthesis_text/figures/chapter3/sine/<fig_index>.pdf`` under the
            working directory.  When None, the module-level name ``i`` is used,
            which is what the function depended on implicitly before.
    """
    bounds = list(param.values())[0][1]
    x_test = np.linspace(bounds[0], bounds[1], 1000).reshape((1000, 1))
    prediction = gpgo.GP.predict(x_test, return_std=True)
    # The second value is square-rooted, i.e. treated as a variance.
    y_hat, y_std = prediction[0], np.sqrt(prediction[1])
    lower, upper = y_hat - 1.96 * y_std, y_hat + 1.96 * y_std
    fig = plt.figure()
    ax = fig.add_subplot(2, 1, 1)
    ax.set_title('Fitted Gaussian process')
    plt.fill_between(x_test.flatten(), lower, upper, alpha=0.2)
    plt.plot(x_test.flatten(), y_hat, color='red', label='Posterior mean')
    plt.legend(loc=0)
    acquisition = np.array([-gpgo._acqWrapper(np.atleast_1d(x)) for x in x_test]).flatten()
    ax = fig.add_subplot(2, 1, 2)
    ax.set_title('Acquisition function')
    plt.plot(x_test, acquisition, color='green')
    gpgo._optimizeAcq(method='L-BFGS-B', n_start=1000)
    plt.axvline(x=gpgo.best, color='black', label='Found optima')
    plt.legend(loc=0)
    plt.tight_layout()
    if fig_index is None:
        fig_index = i  # legacy fallback: module-level loop variable
    plt.savefig(os.path.join(os.getcwd(),
                             'mthesis_text/figures/chapter3/sine/{}.pdf'.format(fig_index)))
    plt.show()
Example #16
Source File: loop.py    From SampleScanner with MIT License 6 votes vote down vote up
def process(aif, sample_rate=48000):
    """Find loop points in an audio file and plot them for inspection.

    ``find_loop_points`` receives the full result of ``read_wave_file(aif)``;
    only element ``[0]`` of that result is plotted.  Two figures are shown:

    * the loop body overlaid with the ``loop_size`` samples that follow it;
    * up to ``2 * sample_rate`` samples on either side of the loop start,
      with vertical lines at the loop start and loop end.
    """
    wave = read_wave_file(aif)

    loop_start, loop_end = find_loop_points(wave)
    loop_size = loop_end - loop_start

    samples = wave[0]

    # Single parenthesised argument: prints "start, end A B" on Python 2 and 3.
    print('start, end {} {}'.format(loop_start, loop_end))

    plt.plot(samples[loop_start:loop_end])
    plt.plot(samples[loop_end:loop_start + (2 * loop_size)])
    plt.show()

    # Clamp at 0: a negative slice start would index from the END of the
    # array and silently plot the wrong region.
    window_start = max(0, loop_start - (sample_rate * 2))
    plt.plot(samples[window_start:loop_start + (sample_rate * 2)])
    # Marker positions are relative to the (possibly clamped) window start.
    offset = loop_start - window_start
    plt.axvline(offset)
    plt.axvline(offset + loop_size)
    plt.show()
Example #17
Source File: run_visual.py    From time-series-machine-learning with Apache License 2.0 6 votes vote down vote up
def main():
  """Plot model predictions for every ticker/period/target, then show them.

  Tickers, periods and targets come from ``parse_command_line`` (defaults:
  BTC_ETH and BTC_LTC, 'day', 'high').  For each combination the frame
  returned by ``predict_multiple`` (120 rows) is plotted with
  ``DataFrame.plot`` and titled with the job name.
  """
  # When set, a dashed line and a "Training stop" annotation mark this x value.
  train_date = None
  tickers, periods, targets = parse_command_line(default_tickers=['BTC_ETH', 'BTC_LTC'],
                                                 default_periods=['day'],
                                                 default_targets=['high'])

  combinations = ((ticker, period, target)
                  for ticker in tickers
                  for period in periods
                  for target in targets)
  for ticker, period, target in combinations:
    job = JobInfo('_data', '_zoo', name='%s_%s' % (ticker, period), target=target)
    source_df = read_df(job.get_source_name())
    result_df = predict_multiple(job, raw_df=source_df, rows_to_predict=120)
    result_df.index.names = ['']
    result_df.plot(title=job.name)

    if train_date is None:
      continue
    bottom = result_df['True'].min()
    plt.axvline(train_date, color='k', linestyle='--')
    plt.annotate('Training stop', xy=(train_date, bottom),
                 xytext=(result_df.index.min(), bottom), color='k',
                 arrowprops={'arrowstyle': '->', 'connectionstyle': 'arc3', 'color': 'k'})

  plt.show()
Example #18
Source File: plotting.py    From privacy with Apache License 2.0 6 votes vote down vote up
def plot_histograms(train: Iterable[float],
                    test: Iterable[float],
                    xlabel: Text = 'x',
                    thresh: float = None) -> plt.Figure:
  """Plot histograms of training versus test metrics.

  Both samples share 100 equal-width bins spanning their combined range.
  They are drawn as semi-transparent, density-normalised histograms on a
  log-scaled y-axis.

  Args:
    train: Metric values on the training set (any iterable).
    test: Metric values on the test set (any iterable).
    xlabel: Label for the x-axis.
    thresh: Optional threshold, drawn as a red vertical line whose legend
      entry shows it to three decimals.

  Returns:
    The newly created figure.
  """
  # Materialise once: a generator would be rejected by np.min, and any
  # one-shot iterable would be exhausted before reaching hist().
  if not isinstance(train, np.ndarray):
    train = np.asarray(list(train))
  if not isinstance(test, np.ndarray):
    test = np.asarray(list(test))
  xmin = min(np.min(train), np.min(test))
  xmax = max(np.max(train), np.max(test))
  if xmin == xmax:
    # All values equal: widen the range, otherwise every bin has zero width
    # and density normalisation divides by zero.
    xmin, xmax = xmin - 0.5, xmax + 0.5
  bins = np.linspace(xmin, xmax, 100)
  fig = plt.figure()
  plt.hist(test, bins=bins, density=True, alpha=0.5, label='test', log=True)
  plt.hist(train, bins=bins, density=True, alpha=0.5, label='train', log=True)
  if thresh is not None:
    plt.axvline(thresh, c='r', label=f'threshold = {thresh:.3f}')
  plt.xlabel(xlabel)
  plt.ylabel('normalized counts (density)')
  plt.legend()
  return fig
Example #19
Source File: Agent.py    From Deep-RL-agents with MIT License 6 votes vote down vote up
def predict_action(self, s, plot_distrib):
    """Return the action chosen by the network for a single state ``s``.

    ``s[None]`` adds a leading batch axis before feeding
    ``self.network.state_ph``; element 0 of the fetched batch is returned.
    When ``plot_distrib`` is truthy, figure 2 is also redrawn without
    blocking.  It shows the suggested action's value distribution as bars
    at ``self.z`` (width ``self.delta_z``) and its Q-value as a red line.
    """
    feed = {self.network.state_ph: s[None]}
    if not plot_distrib:
        return self.sess.run(self.network.actions, feed_dict=feed)[0]

    fetches = [self.network.actions,
               self.network.Q_distrib_suggested_actions,
               self.network.Q_values_suggested_actions]
    actions, distribs, values = self.sess.run(fetches, feed_dict=feed)
    figure = plt.figure(2)
    figure.clf()
    plt.bar(self.z, distribs[0], self.delta_z)
    plt.axvline(values[0], color='red', linewidth=0.7)
    plt.show(block=False)
    plt.pause(0.001)
    return actions[0]
Example #20
Source File: Agent.py    From Deep-RL-agents with MIT License 6 votes vote down vote up
def predict_action(self, s, plot_distrib):
    """Pick an action for the single state ``s``, optionally plotting it.

    The state gets a leading batch axis (``s[None]``) and is fed through
    ``self.network.state_ph``; the first entry of the resulting batch is
    returned.  With ``plot_distrib`` set, figure 2 is cleared and redrawn
    non-blockingly.  It shows bars at ``self.z`` (width ``self.delta_z``)
    for the suggested action's distribution and a red line at its Q-value.
    """
    network = self.network
    batch = s[None]
    if plot_distrib:
        out = self.sess.run([network.actions,
                             network.Q_distrib_suggested_actions,
                             network.Q_values_suggested_actions],
                            feed_dict={network.state_ph: batch})
        chosen = out[0][0]
        distribution = out[1][0]
        q_value = out[2][0]
        canvas = plt.figure(2)
        canvas.clf()
        plt.bar(self.z, distribution, self.delta_z)
        plt.axvline(q_value, color='red', linewidth=0.7)
        plt.show(block=False)
        plt.pause(0.001)
        return chosen

    result = self.sess.run(network.actions, feed_dict={network.state_ph: batch})
    return result[0]
Example #21
Source File: m_dos_pdos_eigenvalues.py    From pyscf with Apache License 2.0 5 votes vote down vote up
def pdosplot(filename=None, data=None, size=None, fermi=None):
    """Plot the spin-resolved projected DOS (PDOS) and save it.

    Rows of ``data`` (the *transposed* text file when read from
    ``filename``) are laid out as follows:

    * row 0: energy (eV);
    * then four consecutive groups of ``size`` rows each: QP spin-up,
      MF spin-up, QP spin-down and MF spin-down;
    * within each group, one row per angular momentum (s, p, d, ...).

    Spin-down curves are drawn mirrored below the x-axis.

    Args:
        filename: Path of a text file readable by ``np.loadtxt``; it is
            transposed after loading.
        data: Array already in the row layout above (used if ``filename``
            is None).
        size: Number of resolved angular momenta.  When None it is inferred
            from the row count as ``(rows - 1) // 4``.
        fermi: Optional Fermi energy, drawn as a dashed vertical line.

    Raises:
        ValueError: If neither ``filename`` nor ``data`` is given.

    Side effects: writes ``pdos.svg`` to the working directory and calls
    ``plt.show()``.
    """
    if filename is not None:
        data = np.loadtxt(filename).T
    elif data is None:
        raise ValueError("Either filename or data must be given")
    if size is None:
        # Infer rather than fail later on ``i + None`` in the indexing below.
        size = (len(data) - 1) // 4

    import matplotlib.pyplot as plt

    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    orb_name = ['$s$', '$p$', '$d$', '$f$', '$g$', '$h$', '$i$', '$k$']
    orb_colo = ['r', 'g', 'b', 'y', 'k', 'm', 'c', 'w']
    energy = data[0]
    for i, (n, c) in enumerate(zip(orb_name[0:size], orb_colo[0:size])):
        qp_up = data[i + 1]
        mf_up = data[i + size + 1]
        qp_dn = -data[i + 2 * size + 1]
        mf_dn = -data[i + 3 * size + 1]
        # GW (QP) spin up
        plt.plot(energy, qp_up, label='QP- ' + n, color=c)
        plt.fill_between(energy, 0, qp_up, facecolor=c, alpha=0.5, interpolate=True)
        # MF spin up
        plt.plot(energy, mf_up, label='MF- ' + n, linestyle=':', color=c)
        plt.fill_between(energy, 0, mf_up, facecolor=c, alpha=0.1, interpolate=True)
        # GW (QP) spin down, mirrored
        plt.plot(energy, qp_dn, label='_nolegend_', color=c)
        plt.fill_between(energy, 0, qp_dn, facecolor=c, alpha=0.5, interpolate=True)
        # MF spin down, mirrored
        plt.plot(energy, mf_dn, label='_nolegend_', linestyle=':', color=c)
        plt.fill_between(energy, 0, mf_dn, facecolor=c, alpha=0.1, interpolate=True)
    # Skip the Fermi line when it is unknown instead of drawing axvline(None).
    if fermi is not None:
        plt.axvline(x=fermi, color='k', linestyle='--')
    plt.axhline(y=0, color='k')
    plt.title('PDOS', fontsize=20)
    plt.xlabel('Energy (eV)', fontsize=15)
    plt.ylabel('Projected Density of States (electron/eV)', fontsize=15)
    plt.legend()
    plt.savefig("pdos.svg", dpi=900)
    plt.show()
Example #22
Source File: ecg_findpeaks.py    From NeuroKit with MIT License 5 votes vote down vote up
def _ecg_findpeaks_promac(signal, sampling_rate=1000, threshold=0.33, show=False, **kwargs):
    """Locate peaks by combining several ``_ecg_findpeaks_*`` detectors.

    Each detector is folded in turn into an accumulator ``x`` (same length
    as ``signal``) via ``_ecg_findpeaks_promac_addmethod``.  Then:

    1. the accumulator is scaled so its maximum is 1 (left as-is if all zero);
    2. values below ``threshold`` are zeroed;
    3. ``signal_findpeaks`` is run on the result with ``height_min=threshold``.

    Args:
        signal: ECG samples.
        sampling_rate: Forwarded to every detector.
        threshold: Minimum combined score for a peak, on the 0-1 scale.
        show: If truthy, plot the signal and combined score with a dashed
            red line at every detected peak.
        **kwargs: Forwarded to every detector.

    Returns:
        The ``"Peaks"`` entry from ``signal_findpeaks`` (presumably sample
        indices -- TODO confirm).
    """
    detectors = (
        _ecg_findpeaks_neurokit,
        _ecg_findpeaks_gamboa,
        _ecg_findpeaks_ssf,
        _ecg_findpeaks_engzee,
        _ecg_findpeaks_elgendi,
        _ecg_findpeaks_kalidas,
        _ecg_findpeaks_WT,
        _ecg_findpeaks_rodrigues,
    )
    x = np.zeros(len(signal))
    for detector in detectors:
        x = _ecg_findpeaks_promac_addmethod(signal, sampling_rate, x, detector, **kwargs)

    # Rescale; if no detector contributed anything, 0/0 would yield NaNs.
    x_max = np.max(x)
    if x_max > 0:
        x = x / x_max
    convoluted = x.copy()

    # Remove below threshold
    x[x < threshold] = 0
    # Find peaks
    peaks = signal_findpeaks(x, height_min=threshold)["Peaks"]

    if show:
        signal_plot([signal, convoluted], standardize=True)
        for peak in peaks:
            plt.axvline(x=peak, color="red", linestyle="--")

    return peaks
Example #23
Source File: events_plot.py    From NeuroKit with MIT License 5 votes vote down vote up
def _events_plot(events, color="red", linestyle="--"):
    # Check if events is list of lists
    try:
        len(events[0])
        is_listoflists = True
    except TypeError:
        is_listoflists = False

    if is_listoflists is False:
        # Loop through sublists
        for event in events:
            plt.axvline(event, color=color, linestyle=linestyle)

    else:
        # Convert color and style to list
        if isinstance(color, str):
            color_map = matplotlib.cm.get_cmap("rainbow")
            color = color_map(np.linspace(0, 1, num=len(events)))
        if isinstance(linestyle, str):
            linestyle = np.full(len(events), linestyle)

        # Loop through sublists
        for i, event in enumerate(events):
            for j in events[i]:
                plt.axvline(j, color=color[i], linestyle=linestyle[i], label=str(i))

        # Display only one legend per event type
        handles, labels = plt.gca().get_legend_handles_labels()
        newLabels, newHandles = [], []
        for handle, label in zip(handles, labels):
            if label not in newLabels:
                newLabels.append(label)
                newHandles.append(handle)
        plt.legend(newHandles, newLabels) 
Example #24
Source File: generate_loss_plots.py    From BMSG-GAN with MIT License 5 votes vote down vote up
def plot_loss(*loss_vals, plot_name="Loss plot",
              fig_size=(17, 7), save_path=None,
              legends=("discriminator", "generator")):
    """
    Plot one or more loss curves on a single new figure and optionally save it.

    :param loss_vals: (variable args) one sequence / numpy array per curve
    :param plot_name: figure title, set with ``suptitle``
    :param fig_size: forwarded to ``plt.figure(figsize=...)``
    :param save_path: if not None, the figure is saved to this path
    :param legends: legend labels; must satisfy len(legends) == len(loss_vals)
    :return: None
    """
    assert len(loss_vals) == len(legends), "Not enough labels for legends"

    figure = plt.figure(figsize=fig_size)
    figure.suptitle(plot_name)
    plt.grid(True, which="both")
    plt.ylabel("loss value")
    plt.xlabel("spaced iterations")

    # Axis lines through the origin.
    plt.axhline(y=0, color='k')
    plt.axvline(x=0, color='k')

    curves = [plt.plot(values)[0] for values in loss_vals]
    plt.legend(curves, legends, loc="upper right", fontsize=16)

    if save_path is not None:
        plt.savefig(save_path)
Example #25
Source File: DLC_pupil_event.py    From ibllib with MIT License 5 votes vote down vote up
def plot_mean_std_around_event(event, diameter, times, eid):
    '''
    Plot mean +/- std of pupil diameter in a window around each trial event.

    event in {'stimOn_times', 'feedback_times', 'stimOff_times'}

    Event times are read from the module-level ``trials`` mapping, and each
    is matched to a frame with ``find_nearest(times, t)``.  The first and
    last five events are skipped.  Any event whose +/-70-frame window would
    run past either end of ``diameter`` is also skipped, so every segment has
    the same length.  A red vertical line marks the event (frame 70 of the
    window).

    Raises:
        ValueError: If no event has a complete window.
    '''
    event_times = trials[event]

    window_size = 70

    segments = []
    # skip first and last trials
    for t in event_times[5:-5]:
        idx = find_nearest(times, t)
        start, stop = idx - window_size, idx + window_size
        # A negative start would wrap around and a truncated end would make
        # the segments ragged; both break np.array / nanmean below.
        if start < 0 or stop > len(diameter):
            continue
        segments.append(diameter[start:stop])

    if not segments:
        raise ValueError("no event of %r has a complete window" % (event,))

    segments = np.array(segments)
    M = np.nanmean(segments, axis=0)
    E = np.nanstd(segments, axis=0)

    fig, ax = plt.subplots()
    ax.fill_between(
        range(len(M)),
        M - E,
        M + E,
        alpha=0.5,
        edgecolor='#CC4F1B',
        facecolor='#FF9848')
    plt.plot(range(len(M)), M, color='k', linewidth=3)
    plt.axvline(x=window_size, color='r', linewidth=1, label=event)
    plt.legend()
    plt.ylabel('pupil diameter [px]')
    plt.xlabel('frames')
    plt.title(eid)
    plt.tight_layout()
Example #26
Source File: example_graphs.py    From SparseSC with MIT License 5 votes vote down vote up
def raw(Y, treated_units_idx, control_units_idx, treatment_period):
    """Plot raw outcome paths for control and treated units.

    ``Y`` (a DataFrame or a numpy array) holds one row per unit.  The plot
    contains:

    * every control row in light grey, plus the control mean as a dim-grey
      dashed line with x markers;
    * a dashed vertical line at ``treatment_period``;
    * treated rows in black, plus their mean drawn the same way as the
      control mean when there is more than one treated unit.

    Returns:
        ``(fig, ax)`` of the new figure named "raw".
    """
    N1 = len(treated_units_idx)
    fig, ax = plt.subplots(num="raw")
    if N1 > 1:
        lbl_t, lbl_mt = "Treateds", "Mean Treated"
    else:
        lbl_t = lbl_mt = "Treated"

    # One row selector for both input types, so both share a single code path.
    if isinstance(Y, pd.DataFrame):
        def rows(idx):
            return Y.iloc[idx, :]
    else:
        def rows(idx):
            return Y[idx, :]

    # Colour comes only from ``color=``: a "k" in the fmt string as well is
    # redundant and makes Matplotlib warn.
    plt.plot(np.transpose(rows(control_units_idx)), color="lightgray")
    plt.plot(rows(control_units_idx[0]), color="lightgray", label="Controls")
    plt.plot(np.mean(rows(control_units_idx), axis=0), "x--", color="dimgray", label="Mean Control")
    plt.axvline(x=treatment_period, linestyle="--")
    if N1 > 0:
        plt.plot(np.transpose(rows(treated_units_idx)), color="black")
        plt.plot(rows(treated_units_idx[0]), color="black", label=lbl_t)
        if N1 > 1:
            # Same style for DataFrame and array input (they used to differ).
            plt.plot(np.mean(rows(treated_units_idx), axis=0), "x--", color="black", label=lbl_mt)
    plt.xlabel("Time")
    plt.ylabel("Outcome")
    plt.legend(loc=1)
    return fig, ax
Example #27
Source File: example_graphs.py    From SparseSC with MIT License 5 votes vote down vote up
def sc_diff(est_ret, treatment_date, unit_idx, treatment_date_fit=None):
    """Plot a unit's actual-minus-synthetic-control outcome difference.

    Row ``unit_idx`` is taken from both ``est_ret.Y`` and
    ``est_ret.get_sc(treatment_date)``.  If ``est_ret.ind_CI`` is set, a grey
    band of ``diff + ci_low`` .. ``diff + ci_high`` is drawn first.  Dashed
    lines mark zero and ``treatment_date``, and a dotted line marks
    ``treatment_date_fit`` when given.

    Returns:
        ``(fig, ax)`` of the new figure named "sc_diff".
    """
    fig, ax = plt.subplots(num="sc_diff")
    is_frame = isinstance(est_ret.Y, pd.DataFrame)
    if is_frame:
        actual = est_ret.Y.iloc[unit_idx, :]
        synthetic = est_ret.get_sc(treatment_date).iloc[unit_idx, :]
    else:
        actual = est_ret.Y[unit_idx, :]
        synthetic = est_ret.get_sc(treatment_date)[unit_idx, :]

    diff = actual - synthetic
    if est_ret.ind_CI is not None:
        band = est_ret.ind_CI[treatment_date]
        x_band = actual.index if is_frame else range(len(band.ci_low))
        plt.fill_between(x_band, diff + band.ci_low, diff + band.ci_high,
                         facecolor="gray", label="CI")
    plt.axhline(y=0, linestyle="--")
    plt.plot(diff, "kx--", label="Unit Diff")
    if treatment_date_fit is None:
        plt.axvline(x=treatment_date, linestyle="--")
    else:
        plt.axvline(x=treatment_date, linestyle="--", label="Treatment")
        plt.axvline(x=treatment_date_fit, linestyle=":", label="End Fit Window")
    plt.xlabel("Time")
    plt.ylabel("Real-SC Outcome Difference")
    plt.legend(loc=1)
    return fig, ax
Example #28
Source File: example_graphs.py    From SparseSC with MIT License 5 votes vote down vote up
def sc_raw(est_ret, treatment_date, unit_idx, treatment_date_fit=None):
    """Plot a unit's actual outcome path next to its synthetic control (SC).

    Row ``unit_idx`` is taken from both ``est_ret.Y`` and
    ``est_ret.get_sc(treatment_date)``.  If ``est_ret.ind_CI`` is set, a grey
    band of ``sc + ci_low`` .. ``sc + ci_high`` is drawn first.  A dashed
    line marks ``treatment_date`` (plus a dotted line at
    ``treatment_date_fit`` when given).  Then the unit (blue) and SC (green)
    paths are drawn.

    Returns:
        ``(fig, ax)`` of the new figure named "sc_raw".
    """
    fig, ax = plt.subplots(num="sc_raw")
    is_frame = isinstance(est_ret.Y, pd.DataFrame)
    if is_frame:
        actual = est_ret.Y.iloc[unit_idx, :]
        synthetic = est_ret.get_sc(treatment_date).iloc[unit_idx, :]
    else:
        actual = est_ret.Y[unit_idx, :]
        synthetic = est_ret.get_sc(treatment_date)[unit_idx, :]

    if est_ret.ind_CI is not None:
        band = est_ret.ind_CI[treatment_date]
        x_band = actual.index if is_frame else range(len(synthetic))
        plt.fill_between(x_band, synthetic + band.ci_low, synthetic + band.ci_high,
                         facecolor="gray", label="CI")
    if treatment_date_fit is None:
        plt.axvline(x=treatment_date, linestyle="--")
    else:
        plt.axvline(x=treatment_date, linestyle="--", label="Treatment")
        plt.axvline(x=treatment_date_fit, linestyle=":", label="End Fit Window")
    plt.plot(actual, "bx-", label="Unit")
    plt.plot(synthetic, "gx--", label="SC")
    plt.xlabel("Time")
    plt.ylabel("Outcome")
    plt.legend(loc=1)
    return fig, ax
Example #29
Source File: example_graphs.py    From SparseSC with MIT License 5 votes vote down vote up
def te_plot_aa(est_ret, treatment_date):
    """Plot the confidence band of the pre-period effect vector.

    The band comes from ``est_ret.pl_res_pre.effect_vec.ci``.  It is indexed
    by its Series index when ``ci_low`` is a Series, otherwise by position,
    and drawn in grey; it is skipped when no interval is available.  Dashed
    lines mark ``treatment_date`` and zero.

    Returns:
        ``(fig, ax)`` of the new figure named "te_plot_aa".
    """
    fig, ax = plt.subplots(num="te_plot_aa")
    # ``pl_res_pre`` lives on ``est_ret``; a bare ``pl_res_pre`` is undefined
    # here and raised NameError.
    ci = est_ret.pl_res_pre.effect_vec.ci
    if ci is not None:
        ci0 = ci.ci_low
        ci1 = ci.ci_high
        base = ci0.index if isinstance(ci0, pd.Series) else range(len(ci0))
        plt.fill_between(base, ci0, ci1, facecolor="gray", label="CI")

    plt.axvline(x=treatment_date, linestyle="--")
    plt.axhline(y=0, linestyle="--")
    plt.xlabel("Time")
    plt.ylabel("Real-SC Outcome Difference")
    plt.legend(loc=1)
    return fig, ax
Example #30
Source File: example_graphs.py    From SparseSC with MIT License 5 votes vote down vote up
def te_plot(est_ret, treatment_date, treatment_date_fit=None):
    """Plot the estimated treatment effect across pre and post periods.

    The effect vectors of ``est_ret.pl_res_pre`` and ``est_ret.pl_res_post``
    are joined end to end (``pd.concat`` for Series, ``np.concatenate``
    otherwise) and drawn as a black dashed line with x markers.  When the
    pre-period effect has a confidence interval, the joined bounds are drawn
    first as a grey band.  Dashed lines mark ``treatment_date`` and zero; a
    dotted line marks ``treatment_date_fit`` when given.

    Returns:
        ``(fig, ax)`` of the new figure named "te_plot".
    """
    fig, ax = plt.subplots(num="te_plot")
    pre = est_ret.pl_res_pre.effect_vec
    post = est_ret.pl_res_post.effect_vec

    def _join(first, second, as_series):
        # Series keep their index through concat; arrays are simply stacked.
        if as_series:
            return pd.concat((first, second))
        return np.concatenate((first, second))

    effect_vec = _join(pre.effect, post.effect, isinstance(pre.effect, pd.Series))
    if pre.ci is not None:
        low_is_series = isinstance(pre.ci.ci_low, pd.Series)
        ci0 = _join(pre.ci.ci_low, post.ci.ci_low, low_is_series)
        ci1 = _join(pre.ci.ci_high, post.ci.ci_high, low_is_series)
        x_band = ci0.index if low_is_series else range(len(ci0))
        plt.fill_between(x_band, ci0, ci1, facecolor="gray", label="CI")

    plt.plot(effect_vec, "kx--", label="Treated Diff")
    if treatment_date_fit is None:
        plt.axvline(x=treatment_date, linestyle="--")
    else:
        plt.axvline(x=treatment_date, linestyle="--", label="Treatment")
        plt.axvline(x=treatment_date_fit, linestyle=":", label="End Fit Window")
    plt.axhline(y=0, linestyle="--")
    plt.xlabel("Time")
    plt.ylabel("Real-SC Outcome Difference")
    plt.legend(loc=1)
    return fig, ax