Python matplotlib.pyplot.tight_layout() Examples

The following are 30 code examples of matplotlib.pyplot.tight_layout(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: data_augmentation.py    From Sound-Recognition-Tutorial with Apache License 2.0 10 votes vote down vote up
def demo_plot(audio='./data/esc10/audio/Dog/1-30226-A.ogg'):
    """Plot an audio clip next to its time-stretched and pitch-shifted versions.

    The clip is loaded at 44.1 kHz. It is time-stretched with rate 1.2 and
    pitch-shifted by 6 steps. The three waveforms are drawn in a 3x1 grid
    with identical axis limits.

    Args:
        audio: Path of the audio file to load. Defaults to the ESC-10 dog
            clip used by the tutorial.
    """
    y, sr = librosa.load(audio, sr=44100)
    # n_steps controls how far the pitch is shifted. ``sr`` is passed by
    # keyword: recent librosa releases make it keyword-only.
    y_ps = librosa.effects.pitch_shift(y, sr=sr, n_steps=6)
    # rate controls the time-stretch factor.
    y_ts = librosa.effects.time_stretch(y, rate=1.2)

    panels = [
        (311, y, 'Original waveform'),
        (312, y_ts, 'Time Stretch transformed waveform'),
        (313, y_ps, 'Pitch Shift transformed waveform'),
    ]
    for position, signal, title in panels:
        plt.subplot(position)
        plt.plot(signal)
        plt.title(title)
        # Identical limits on every panel keep the waveforms comparable.
        # Use e.g. [88000, 94000, -0.4, 0.4] to zoom into a single event.
        plt.axis([0, 200000, -0.4, 0.4])
    plt.tight_layout()
    plt.show()
Example #2
Source File: plot_utils.py    From celer with BSD 3-Clause "New" or "Revised" License 7 votes vote down vote up
def plot_path_hist(results, labels, tols, figsize, ylim=None):
    """Draw a grouped bar chart of path computation times per tolerance.

    Args:
        results: One sequence of bar heights per competitor. ``results[i]``
            is plotted at the ``len(tols)`` tolerance positions.
        labels: Legend label for each competitor.
        tols: Tolerances shown as x tick labels, formatted with ``%.0e``.
        figsize: Figure size passed to ``plt.subplots``.
        ylim: Optional y-axis limits.

    Returns:
        The created figure. It is shown without blocking.
    """
    configure_plt()
    sns.set_palette('colorblind')
    n_competitors = len(results)
    fig, ax = plt.subplots(figsize=figsize)
    width = 1. / (n_competitors + 1)
    ind = np.arange(len(tols))
    # Offset each competitor's bars so every group is centred on x = ind.
    b = (1 - n_competitors) / 2.
    for i in range(n_competitors):
        ax.bar(ind + (i + b) * width, results[i], width, label=labels[i])
    ax.set_ylabel('path computation time (s)')
    # Groups are centred on ``ind``, so the ticks go exactly there.
    ax.set_xticks(ind)
    ax.set_xticklabels(["%.0e" % tol for tol in tols])
    if ylim is not None:
        ax.set_ylim(ylim)

    ax.set_xlabel(r"$\epsilon$")
    ax.legend(loc='upper left')
    fig.tight_layout()
    plt.show(block=False)
    return fig
Example #3
Source File: pearsons_filtering.py    From simba with GNU Lesser General Public License v3.0 7 votes vote down vote up
def pearson_filter(projectPath, featuresDf, del_corr_status, del_corr_threshold, del_corr_plot_status):
    """Drop features that correlate strongly with an earlier feature.

    The correlation matrix is computed once, before any column is dropped.
    Consider every column pair (i, j) with j < i. Column i is removed when
    their correlation is >= ``del_corr_threshold`` and column j has not
    itself been marked for removal. Only the signed value is compared, so
    strongly negative correlations never trigger removal.

    Args:
        projectPath: Project folder. The optional heatmap is written to its
            ``logs`` sub-directory, which is created if missing.
        featuresDf: Feature table. It is modified in place (columns are
            deleted) and also returned.
        del_corr_status: Not used by this function.
        del_corr_threshold: Correlation at or above which a column is dropped.
        del_corr_plot_status: ``'yes'`` to save a heatmap of the remaining
            columns' correlations.

    Returns:
        ``featuresDf`` with the redundant columns removed.
    """
    print('Reducing features. Correlation threshold: ' + str(del_corr_threshold))
    col_corr = set()
    corr_matrix = featuresDf.corr()
    columns = corr_matrix.columns
    for i in range(len(columns)):
        for j in range(i):
            if (corr_matrix.iloc[i, j] >= del_corr_threshold) and (columns[j] not in col_corr):
                colname = columns[i]
                col_corr.add(colname)
                # A column can qualify against several earlier ones; delete it once.
                if colname in featuresDf.columns:
                    del featuresDf[colname]
    if del_corr_plot_status == 'yes':
        print('Creating feature correlation heatmap...')
        dateTime = datetime.now().strftime('%Y%m%d%H%M%S')
        log_dir = os.path.join(projectPath, 'logs')
        # savefig does not create missing directories.
        os.makedirs(log_dir, exist_ok=True)
        plt.matshow(featuresDf.corr())
        plt.tight_layout()
        plt.savefig(os.path.join(log_dir, 'Feature_correlations_' + dateTime + '.png'), dpi=300)
        plt.close('all')
        print('Feature correlation heatmap .png saved in project_folder/logs directory')

    return featuresDf
Example #4
Source File: visualise_att_maps_epoch.py    From Attention-Gated-Networks with MIT License 7 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show every channel of ``units`` as its own image tile with a colorbar.

    Args:
        units: Array indexed as ``units[:, :, c]``. Each channel ``c`` is
            drawn transposed. (Presumably (H, W, C) feature maps; confirm.)
        figure_id: Matplotlib figure number to reuse. It is cleared first.
        interp: Interpolation passed to ``imshow``.
        colormap: Colormap passed to ``imshow``.
        colormap_lim: Optional ``(vmin, vmax)`` applied to every tile.
    """
    plt.ion()
    filters = units.shape[2]
    # Near-square grid. At least one column, so an empty input cannot divide by zero.
    n_columns = max(1, round(math.sqrt(filters)))
    # One spare row beyond what the tiles need.
    n_rows = math.ceil(filters / n_columns) + 1
    # figsize is (width, height): columns set the width and rows the height,
    # which gives roughly 3x3-inch tiles.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # tight_layout recomputes the subplot spacing, largely overriding this call.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Epochs 
Example #5
Source File: thames.py    From pywr with GNU General Public License v3.0 6 votes vote down vote up
def figures(ext, show):
    """Plot every recorded series in ``thames_output.h5``.

    Each dataframe from ``TablesRecorder.generate_dataframes`` gets one
    figure. The 2100-2125 window is on the left and the quantile curve
    (50 quantiles from 0 to 1) is on the right.

    Args:
        ext: Extension used to save each figure as ``{name}.{ext}``, or
            ``None`` to skip saving.
        show: If true, display all figures at the end. Otherwise each
            figure is closed once saved, so figures do not accumulate.
    """
    for name, df in TablesRecorder.generate_dataframes('thames_output.h5'):
        # Presumably one column per scenario, in this order; confirm with the recorder.
        df.columns = ['Very low', 'Low', 'Central', 'High', 'Very high']

        fig, (ax1, ax2) = plt.subplots(figsize=(12, 4), ncols=2, sharey='row',
                                       gridspec_kw={'width_ratios': [3, 1]})
        df['2100':'2125'].plot(ax=ax1)
        df.quantile(np.linspace(0, 1)).plot(ax=ax2)

        if name.startswith('reservoir'):
            ax1.set_ylabel('Volume [$Mm^3$]')
        else:
            ax1.set_ylabel('Flow [$Mm^3/day$]')

        for ax in (ax1, ax2):
            ax.set_title(name)
            ax.grid(True)
        fig.tight_layout()

        if ext is not None:
            fig.savefig(f'{name}.{ext}', dpi=300)

        if not show:
            # Nothing will ever display this figure, so release it now.
            plt.close(fig)

    if show:
        plt.show()
Example #6
Source File: test_frame.py    From recruit with Apache License 2.0 6 votes vote down vote up
def test_if_scatterplot_colorbars_are_next_to_parent_axes(self):
        # Two scatter plots coloured by a third column each get a colorbar.
        # After tight_layout, the two colorbars must be spaced exactly like
        # the two parent axes.
        import matplotlib.pyplot as plt
        df = pd.DataFrame(np.random.random((1000, 3)),
                          columns=['A label', 'B label', 'C label'])

        fig, axes = plt.subplots(1, 2)
        for target in axes:
            df.plot.scatter('A label', 'B label', c='C label', ax=target)
        plt.tight_layout()

        # Assumes fig.axes is [parent0, parent1, cbar0, cbar1] (creation order).
        corners = np.array([a.get_position().get_points() for a in fig.axes])
        x_coords = corners[:, :, 0]
        gap_parents = x_coords[1, :] - x_coords[0, :]
        gap_colorbars = x_coords[3, :] - x_coords[2, :]
        assert np.isclose(gap_parents, gap_colorbars, atol=1e-7).all()
Example #7
Source File: data_provider.py    From ICDAR-2019-SROIE with MIT License 6 votes vote down vote up
def generator(vis=False):
    """Endlessly yield ``([im], bbox, im_info)`` training samples.

    Each pass shuffles the images from ``get_training_data()``. Images that
    cannot be read, or that have a missing or empty label file, are
    reported and skipped. Any other per-sample error is printed and that
    sample is skipped, so one bad file does not stop training.

    Args:
        vis: If true, draw the ground-truth boxes onto the image and show
            it before yielding. The boxes are drawn in place, so they also
            appear in the yielded image.

    Yields:
        A one-element list holding the image as read by ``cv2.imread``;
        the boxes returned by ``load_annoataion``; and ``im_info``, a
        ``(1, 3)`` array of ``[height, width, channels]``.

    Raises:
        ValueError: If there are no training images. Without this check,
            the loop would spin forever without yielding anything.
    """
    image_list = np.array(get_training_data())
    print('{} training images in {}'.format(image_list.shape[0], DATA_FOLDER))
    if image_list.shape[0] == 0:
        raise ValueError('No training images found in {}'.format(DATA_FOLDER))
    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                if im is None:
                    # cv2.imread returns None instead of raising for missing
                    # or undecodable files.
                    print("Image {} could not be read!".format(im_fn))
                    continue
                h, w, c = im.shape
                im_info = np.array([h, w, c]).reshape([1, 3])

                _, fn = os.path.split(im_fn)
                fn, _ = os.path.splitext(fn)
                txt_fn = os.path.join(DATA_FOLDER, "label", fn + '.txt')
                if not os.path.exists(txt_fn):
                    print("Ground truth for image {} not exist!".format(im_fn))
                    continue
                bbox = load_annoataion(txt_fn)
                if len(bbox) == 0:
                    print("Ground truth for image {} empty!".format(im_fn))
                    continue

                if vis:
                    for p in bbox:
                        cv2.rectangle(im, (p[0], p[1]), (p[2], p[3]), color=(0, 0, 255), thickness=1)
                    fig, axs = plt.subplots(1, 1, figsize=(30, 30))
                    # Reverse the channel order for display (OpenCV loads BGR).
                    axs.imshow(im[:, :, ::-1])
                    axs.set_xticks([])
                    axs.set_yticks([])
                    plt.tight_layout()
                    plt.show()
                    plt.close()
                yield [im], bbox, im_info

            except Exception as e:
                # Best effort: report and skip this sample instead of aborting.
                print(e)
                continue
Example #8
Source File: plot.py    From TaskBot with GNU General Public License v3.0 6 votes vote down vote up
def plot_attention(sentences, attentions, labels, **kwargs):
    """Render an attention matrix as a heat map with a token written on each cell.

    Args:
        sentences: Indexed as ``sentences[i][j]`` for every cell (i, j) of
            ``attentions``. That entry is drawn as the cell's text.
        attentions: 2-D array of attention weights.
        labels: Y-axis tick labels, one per row of ``attentions``.
        **kwargs: Forwarded to ``plt.subplots`` (e.g. ``figsize``).
    """
    fig, ax = plt.subplots(**kwargs)
    im = ax.imshow(attentions, interpolation='nearest',
                   vmin=attentions.min(), vmax=attentions.max())
    fig.colorbar(im, ax=ax, shrink=0.5, ticks=[0, 1])
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Fetch the font once. Matplotlib copies font properties per Text, so
    # sharing one instance is safe.
    font = getChineseFont()
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontproperties=font)
    # Write each cell's token on top of the heat map.
    for i in range(attentions.shape[0]):
        for j in range(attentions.shape[1]):
            ax.text(j, i, sentences[i][j],
                    ha="center", va="center", color="b", size=10,
                    fontproperties=font)

    ax.set_title("Attention Visual")
    fig.tight_layout()
    plt.show()
Example #9
Source File: simplest_raster_plot.py    From ibllib with MIT License 6 votes vote down vote up
def raster_complete(R, times, Clusters):
    '''
    Plot a rasterplot for the complete recording
    (might be slow, restrict R if so),
    ordered by insertion depth.

    R is drawn as an image spanning times[0]..times[-1] on x and
    Clusters[0]..Clusters[-1] on y. The colour scale tops out at
    T_BIN / 0.001 / 4, where T_BIN is a module-level constant.
    '''

    plt.imshow(R, aspect='auto', cmap='binary', vmax=T_BIN / 0.001 / 4,
               origin='lower', extent=np.r_[times[[0, -1]], Clusters[[0, -1]]])

    plt.xlabel('Time (s)')
    plt.ylabel('Cluster #; ordered by depth')
    # Lay out before show(): in blocking mode show() returns only after the
    # window is closed, so a later layout change is never seen.
    plt.tight_layout()

    # plt.savefig('/home/mic/Rasters/%s.svg' %(trial_number))
    # plt.close('all')
    plt.show()
Example #10
Source File: time_bench.py    From astroalign with MIT License 6 votes vote down vote up
def plot_command(self, ns):
        """Plot benchmark results from ``ns.file``, then show or save the figure.

        Args:
            ns: Parsed CLI namespace. Reads ``ns.file`` (CSV source; its
                ``.name`` appears in messages), ``ns.orientation`` (key into
                ``COLSROWS`` and ``DEFAULT_SIZES``), ``ns.size`` (optional
                figure size) and ``ns.out`` (output path, or ``None`` to
                display interactively).
        """
        import matplotlib.pyplot as plt

        results = pd.read_csv(ns.file)

        orientation = COLSROWS[ns.orientation]
        size = ns.size if ns.size else DEFAULT_SIZES[ns.orientation]

        fig, axes = plt.subplots(**orientation)
        fig.set_size_inches(*size)

        plot(results, *axes)

        fig.suptitle("")
        fig.tight_layout()
        if ns.out is None:
            print(f"Showing plot for data stored in '{ns.file.name}'...")
            # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6.
            # The title now lives on the figure manager, which non-GUI
            # backends may not have.
            manager = getattr(fig.canvas, "manager", None)
            if manager is not None:
                manager.set_window_title(f"{self.parser.prog} - {ns.file.name}")
            plt.show()
        else:
            print(f"Storing plot for data in '{ns.file.name}' -> '{ns.out}'...")
            fig.savefig(ns.out)
            print("DONE!")
Example #11
Source File: flux_bench.py    From astroalign with MIT License 6 votes vote down vote up
def plot_command(self, ns):
        """Plot flux benchmark results from ``ns.file``, then show or save the figure.

        Args:
            ns: Parsed CLI namespace. Reads ``ns.file`` (CSV source; its
                ``.name`` appears in messages), ``ns.size`` (optional figure
                size, otherwise ``DEFAULT_SIZE``) and ``ns.out`` (output
                path, or ``None`` to display interactively).
        """
        import matplotlib.pyplot as plt

        results = pd.read_csv(ns.file)

        size = ns.size if ns.size else DEFAULT_SIZE

        fig, ax = plt.subplots()
        fig.set_size_inches(*size)

        plot(results, ax)

        fig.suptitle("")
        fig.tight_layout()
        if ns.out is None:
            print(f"Showing plot for data stored in '{ns.file.name}'...")
            # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6.
            # The title now lives on the figure manager, which non-GUI
            # backends may not have.
            manager = getattr(fig.canvas, "manager", None)
            if manager is not None:
                manager.set_window_title(f"{self.parser.prog} - {ns.file.name}")
            plt.show()
        else:
            print(f"Storing plot for data in '{ns.file.name}' -> '{ns.out}'...")
            fig.savefig(ns.out)
            print("DONE!")
Example #12
Source File: time_regression.py    From astroalign with MIT License 6 votes vote down vote up
def plot_command(self, ns):
        """Plot time-regression results from ``ns.file``, then show or save the figure.

        Args:
            ns: Parsed CLI namespace. Reads ``ns.file`` (CSV source; its
                ``.name`` appears in messages), ``ns.size`` (optional figure
                size, otherwise ``DEFAULT_SIZE``) and ``ns.out`` (output
                path, or ``None`` to display interactively).
        """
        import matplotlib.pyplot as plt

        results = pd.read_csv(ns.file)

        size = ns.size if ns.size else DEFAULT_SIZE

        fig, ax = plt.subplots()
        fig.set_size_inches(*size)

        plot(results, ax)

        fig.suptitle("")
        fig.tight_layout()
        if ns.out is None:
            print(f"Showing plot for data stored in '{ns.file.name}'...")
            # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6.
            # The title now lives on the figure manager, which non-GUI
            # backends may not have.
            manager = getattr(fig.canvas, "manager", None)
            if manager is not None:
                manager.set_window_title(f"{self.parser.prog} - {ns.file.name}")
            plt.show()
        else:
            print(f"Storing plot for data in '{ns.file.name}' -> '{ns.out}'...")
            fig.savefig(ns.out)
            print("DONE!")
Example #13
Source File: visualise_fmaps.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show every channel of ``units`` as its own image tile with a colorbar.

    Args:
        units: Array indexed as ``units[:, :, c]``. Each channel ``c`` is
            drawn transposed. (Presumably (H, W, C) feature maps; confirm.)
        figure_id: Matplotlib figure number to reuse. It is cleared first.
        interp: Interpolation passed to ``imshow``.
        colormap: Colormap passed to ``imshow``.
        colormap_lim: Optional ``(vmin, vmax)`` applied to every tile.
    """
    plt.ion()
    filters = units.shape[2]
    # Near-square grid. At least one column, so an empty input cannot divide by zero.
    n_columns = max(1, round(math.sqrt(filters)))
    # One spare row beyond what the tiles need.
    n_rows = math.ceil(filters / n_columns) + 1
    # figsize is (width, height): columns set the width and rows the height,
    # which gives roughly 3x3-inch tiles.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # tight_layout recomputes the subplot spacing, largely overriding this call.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Load options 
Example #14
Source File: visualise_attention.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Show every channel of ``units`` as its own titled image tile grid.

    Args:
        units: Array indexed as ``units[:, :, c]``. Each channel ``c`` is
            drawn transposed. (Presumably (H, W, C) attention maps; confirm.)
        figure_id: Matplotlib figure number to reuse. It is cleared first.
        interp: Interpolation passed to ``imshow``.
        colormap: Colormap passed to ``imshow``.
        colormap_lim: Optional ``(vmin, vmax)`` applied to every tile.
        title: Figure-level title (``suptitle``).
    """
    plt.ion()
    filters = units.shape[2]
    # Near-square grid. At least one column, so an empty input cannot divide by zero.
    n_columns = max(1, round(math.sqrt(filters)))
    # One spare row beyond what the tiles need.
    n_rows = math.ceil(filters / n_columns) + 1
    # figsize is (width, height): columns set the width and rows the height,
    # which gives roughly 3x3-inch tiles.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # tight_layout recomputes the subplot spacing, largely overriding this call.
    plt.subplots_adjust(wspace=0, hspace=0)
    # Add the title before laying out, so the layout step can account for it
    # rather than leaving it to overlap the top row of tiles.
    plt.suptitle(title)
    plt.tight_layout()
Example #15
Source File: visualise_attention.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay the channels of ``units`` on the grey image ``input_im[:, :, 0]``.

    Everything is drawn into a single axes of figure ``figure_id``, which
    is cleared first. Each iteration redraws the opaque grey input and then
    one channel on top with transparency ``alpha``. The channel drawn last
    is therefore the one left visible.

    Args:
        input_im: Background image. Only ``input_im[:, :, 0]`` is used.
        units: Array indexed as ``units[:, :, c]`` (not transposed here).
        figure_id: Matplotlib figure number to reuse.
        interp: Interpolation passed to ``imshow``.
        colormap: Colormap for the overlay.
        colormap_lim: Optional ``(vmin, vmax)`` for the overlay.
        title: Axes title.
        alpha: Overlay opacity.
    """
    plt.ion()
    filters = units.shape[2]
    fig = plt.figure(figure_id, figsize=(5, 5))
    fig.clf()

    for i in range(filters):
        plt.imshow(input_im[:, :, 0], interpolation=interp, cmap='gray')
        plt.imshow(units[:, :, i], interpolation=interp, cmap=colormap, alpha=alpha)
        plt.axis('off')
        plt.title(title, fontsize='small')
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    if filters:
        # Add one colorbar, for the overlay that ends up visible. Adding one
        # per iteration stacked extra colorbars and shrank the image each time.
        plt.colorbar()

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

    # plt.savefig('{}/{}.png'.format(dir_name,time.time()))




## Load options 
Example #16
Source File: core.py    From prickle with MIT License 6 votes vote down vote up
def imshow(data, which, levels):
    """
        Display order book data as an image, where order book data is either of
        `df_price` or `df_volume` returned by `load_hdf5` or `load_postgres`.

        Rows are the ask columns for levels ``levels`` down to 1, followed by
        the bid columns for levels 1 up to ``levels``.

        Raises ValueError if `which` is neither 'prices' nor 'volumes'.
    """
    prefixes = {
        'prices': ('askprc.', 'bidprc.'),
        'volumes': ('askvol.', 'bidvol.'),
    }
    if which not in prefixes:
        raise ValueError(
            "which must be 'prices' or 'volumes', got {!r}".format(which))
    ask, bid = prefixes[which]
    idx = [ask + str(i) for i in range(levels, 0, -1)]
    idx.extend(bid + str(i) for i in range(1, levels + 1))
    plt.imshow(data.loc[:, idx].T, interpolation='nearest', aspect='auto')
    plt.yticks(range(levels * 2), idx)
    plt.colorbar()
    plt.tight_layout()
    plt.show()
Example #17
Source File: demo.py    From TFFRCNN with MIT License 5 votes vote down vote up
def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes onto ``ax``.

    Args:
        im: Not used by this function.
        class_name: Label written next to each box and in the title.
        dets: 2-D array. Columns 0-3 of each row are ``x1, y1, x2, y2`` and
            the last column is the score.
        ax: Axes to draw on. Its figure is laid out and redrawn.
        thresh: Minimum score for a detection to be drawn. If no detection
            passes, nothing is drawn.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    for i in inds:
        x1, y1, x2, y2 = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        ax.text(x1, y1 - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    # Act on the axes and figure we were given, not whichever figure pyplot
    # currently considers "current".
    ax.axis('off')
    fig = ax.figure
    fig.tight_layout()
    fig.canvas.draw_idle()
Example #18
Source File: visualization_utils.py    From ludwig with Apache License 2.0 5 votes vote down vote up
def predictions_distribution_plot(
        probabilities,
        algorithm_names=None,
        filename=None
):
    """Overlay histograms of predicted probabilities, one per algorithm.

    Args:
        probabilities: One sequence of values in [0, 1] per algorithm.
        algorithm_names: Optional legend labels. Missing entries get ''.
        filename: Save the figure here if given. Otherwise show it.
    """
    sns.set_style('whitegrid')

    colors = plt.get_cmap('tab10').colors

    num_algorithms = len(probabilities)

    plt.figure(figsize=(9, 9))
    plt.grid(which='both')
    plt.grid(which='minor', alpha=0.5)
    plt.grid(which='major', alpha=0.75)

    for i in range(num_algorithms):
        if algorithm_names is not None and i < len(algorithm_names):
            label = algorithm_names[i]
        else:
            label = ''
        # Cycle the ten-colour palette so more than ten algorithms do not
        # raise IndexError.
        plt.hist(probabilities[i], range=(0, 1), bins=41,
                 color=colors[i % len(colors)], label=label,
                 histtype='stepfilled', alpha=0.5, lw=2)

    plt.xlabel('Mean predicted value')
    plt.xlim([0, 1])
    plt.xticks(np.linspace(0.0, 1.0, num=21))
    plt.ylabel('Count')
    plt.legend(loc='upper center', ncol=2)

    plt.tight_layout()
    ludwig.contrib.contrib_command("visualize_figure", plt.gcf())
    if filename:
        plt.savefig(filename)
    else:
        plt.show()
Example #19
Source File: visualization_utils.py    From ludwig with Apache License 2.0 5 votes vote down vote up
def compare_classifiers_multiclass_multimetric_plot(
        scores,
        metrics,
        labels=None,
        title=None,
        filename=None
):
    """Grouped bar chart of several metrics per class.

    Args:
        scores: One sequence per metric. ``scores[i][k]`` is metric ``i``
            for class ``k``. Must be non-empty.
        metrics: Legend label for each entry of ``scores``.
        labels: Optional class names for the x ticks (defaults to indices).
        title: Optional axes title.
        filename: Save the figure here if given. Otherwise show it.
    """
    assert len(scores) > 0

    sns.set_style('whitegrid')

    fig, ax = plt.subplots()

    if title is not None:
        ax.set_title(title)

    n_metrics = len(scores)
    width = 0.9 / n_metrics
    ticks = np.arange(len(scores[0]))

    colors = plt.get_cmap('tab10').colors
    ax.set_xlabel('class')
    # Bars for metric i are centred at ticks + i * width (default
    # align='center'), so each group's midpoint is at
    # ticks + (n_metrics - 1) * width / 2.
    ax.set_xticks(ticks + width * (n_metrics - 1) / 2)
    if labels is not None:
        ax.set_xticklabels(labels, rotation=90)
    else:
        ax.set_xticklabels(ticks, rotation=90)

    for i, score in enumerate(scores):
        # Cycle the ten-colour palette so more than ten metrics do not
        # raise IndexError.
        ax.bar(ticks + i * width, score, width, label=metrics[i],
               color=colors[i % len(colors)])

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    ludwig.contrib.contrib_command("visualize_figure", plt.gcf())
    if filename:
        plt.savefig(filename)
    else:
        plt.show()
Example #20
Source File: generate.py    From TFFRCNN with MIT License 5 votes vote down vote up
def _vis_proposals(im, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    class_name = 'obj'
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw() 
Example #21
Source File: random_walk.py    From reinforcement-learning-an-introduction with MIT License 5 votes vote down vote up
def example_6_2():
    """Build the two-panel Example 6.2 figure and save it.

    ``compute_state_value()`` draws the top panel and ``rms_error()``
    draws the bottom one. The result is written to
    ``../images/example_6_2.png`` and the figure is then closed.
    """
    plt.figure(figsize=(10, 20))
    for position, draw_panel in ((1, compute_state_value), (2, rms_error)):
        plt.subplot(2, 1, position)
        draw_panel()
    plt.tight_layout()

    plt.savefig('../images/example_6_2.png')
    plt.close()
Example #22
Source File: visualization_utils.py    From ludwig with Apache License 2.0 5 votes vote down vote up
def compare_classifiers_line_plot(
        xs,
        scores,
        metric,
        algorithm_names=None,
        title=None,
        filename=None
):
    """Plot one metric-vs-k line per algorithm.

    Args:
        xs: X positions (the ``k`` values), also used as tick labels.
        scores: One sequence of metric values per algorithm, aligned with ``xs``.
        metric: Y-axis label.
        algorithm_names: Optional legend labels. Missing entries become
            ``'Algorithm {i}'``.
        title: Optional axes title.
        filename: Save the figure here if given. Otherwise show it.
    """
    sns.set_style('whitegrid')
    colors = plt.get_cmap('tab10').colors

    fig, ax = plt.subplots()

    ax.grid(which='both')
    ax.grid(which='minor', alpha=0.5)
    ax.grid(which='major', alpha=0.75)

    if title is not None:
        ax.set_title(title)

    ax.set_xticks(xs)
    ax.set_xticklabels(xs)
    ax.set_xlabel('k')
    ax.set_ylabel(metric)

    for i, score in enumerate(scores):
        if algorithm_names is not None and i < len(algorithm_names):
            label = algorithm_names[i]
        else:
            label = 'Algorithm {}'.format(i)
        # Cycle the ten-colour palette so more than ten algorithms do not
        # raise IndexError.
        ax.plot(xs, score, label=label, color=colors[i % len(colors)],
                linewidth=3, marker='o')

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    ludwig.contrib.contrib_command("visualize_figure", plt.gcf())
    if filename:
        plt.savefig(filename)
    else:
        plt.show()
Example #23
Source File: DLC_pupil_event.py    From ibllib with MIT License 5 votes vote down vote up
def plot_mean_std_around_event(event, diameter, times, eid, window_size=70):
    '''
    Plot the mean +/- std of pupil diameter in a window around each event.

    event in {'stimOn_times', 'feedback_times', 'stimOff_times'}

    Event times come from ``trials[event]``. Note that ``trials`` is a
    module-level object, not a parameter; NOTE(review): confirm it is
    loaded before calling. ``find_nearest(times, t)`` gives the sample index
    of each event. A segment of ``2 * window_size`` samples centred on it
    is collected for every event except the first and last five. Segments
    that would run past either end of ``diameter`` are skipped, so all
    segments have the same length.

    Args:
        event: Key of the event times in ``trials``.
        diameter: Pupil diameter trace (px), sampled at ``times``.
        times: Timestamps of the ``diameter`` samples.
        eid: Used as the plot title.
        window_size: Half-width of the window, in samples (frames).
    '''
    event_times = trials[event]

    n_samples = len(diameter)
    segments = []
    # skip first and last trials to get same window length
    for t in event_times[5:-5]:
        idx = find_nearest(times, t)
        start, stop = idx - window_size, idx + window_size
        # A negative start would wrap around and an overlong stop would be
        # truncated. Either gives a ragged segment, so skip such windows.
        if start < 0 or stop > n_samples:
            continue
        segments.append(diameter[start:stop])

    segments = np.array(segments)
    M = np.nanmean(segments, axis=0)
    E = np.nanstd(segments, axis=0)

    fig, ax = plt.subplots()
    x = np.arange(len(M))
    ax.fill_between(
        x,
        M - E,
        M + E,
        alpha=0.5,
        edgecolor='#CC4F1B',
        facecolor='#FF9848')
    ax.plot(x, M, color='k', linewidth=3)
    # Index window_size of each segment is the event sample itself.
    ax.axvline(x=window_size, color='r', linewidth=1, label=event)
    ax.legend()
    ax.set_ylabel('pupil diameter [px]')
    ax.set_xlabel('frames')
    ax.set_title(eid)
    fig.tight_layout()
Example #24
Source File: plots.py    From yatsm with MIT License 5 votes vote down vote up
def plot_crossvalidation_scores(kfold_scores, test_labels):
    """ Plots KFold test summary statistics

    NOTE: currently disabled. The bare ``return`` at the top of the body
    exits immediately, so calling this function draws nothing and returns
    ``None``. Everything after it is unreachable.

    Args:
      kfold_scores (np.ndarray): n by 2 shaped array of mean and standard
        deviation of KFold scores
      test_labels (list): n length list of KFold label names

    """
    # Early exit disables the plot. Remove it to re-enable the code below.
    return
    ind = np.arange(kfold_scores.shape[0])
    width = 0.5

    fig, ax = plt.subplots()
    # NOTE(review): the bars use ax.bar's default alignment, while the error
    # bars, value labels and ticks below sit at ``ind + width / 2``. That
    # only lines up with edge-aligned bars (the pre-2.0 Matplotlib default).
    # Check the alignment before re-enabling.
    bars = ax.bar(ind, kfold_scores[:, 0], width)
    # errorbar returns a 3-part container. Only the cap lines are restyled.
    _, caplines, _ = ax.errorbar(ind + width / 2.0, kfold_scores[:, 0],
                                 fmt='none',
                                 yerr=kfold_scores[:, 1],
                                 capsize=10, elinewidth=3)
    for capline in caplines:
        capline.set_linewidth(10)
        capline.set_markeredgewidth(3)
        capline.set_color('red')

    # Write "mean +/- std" inside each bar, at half its height.
    for i, bar in enumerate(bars):
        txt = r'%.3f $\pm$ %.3f' % (kfold_scores[i, 0], kfold_scores[i, 1])
        ax.text(ind[i] + width / 2.0,
                kfold_scores[i, 0] / 2.0,
                txt,
                ha='center', va='bottom', size='large')

    ax.set_xticks(ind + width / 2.0)
    ax.set_xticklabels(test_labels, ha='center')
    # plt.ylim((0, 1.0))

    plt.title('KFold Cross Validation Summary Statistics')
    plt.xlabel('Test')
    plt.ylabel(r'Accuracy ($\pm$ standard deviation)')

    plt.tight_layout()
    plt.show()
Example #25
Source File: bsds300.py    From nsf with MIT License 5 votes vote down vote up
def main():
    """Inspect the BSDS300 training split.

    Prints the data type, shape and value range, histograms each dimension
    in an 8x8 grid, then prints the dataset length and the number of full
    batches of 128.
    """
    dataset = BSDS300Dataset(split='train')
    print(type(dataset.data))
    print(dataset.data.shape)
    print(dataset.data.min(), dataset.data.max())
    _, grid = plt.subplots(8, 8, figsize=(10, 10), sharex=True, sharey=True)
    panels = grid.ravel()
    for dim_index, values in enumerate(dataset.data.T):
        panels[dim_index].hist(values, bins=100)
    # Alternative: a single histogram over every value,
    # plt.hist(dataset.data.reshape(-1), bins=250)
    plt.tight_layout()
    plt.show()
    print(len(dataset))
    batches = data.DataLoader(dataset, batch_size=128, drop_last=True)
    print(len(batches))
Example #26
Source File: gas.py    From nsf with MIT License 5 votes vote down vote up
def main():
    """Inspect the gas training split.

    Prints the data type, shape, value range and where the maximum occurs,
    then histograms each dimension in a 3x3 grid.
    """
    dataset = GasDataset(split='train')
    print(type(dataset.data))
    print(dataset.data.shape)
    print(dataset.data.min(), dataset.data.max())
    # Position(s) of the global maximum.
    print(np.where(dataset.data == dataset.data.max()))
    _, grid = plt.subplots(3, 3, figsize=(10, 10), sharex=True, sharey=True)
    panels = grid.ravel()
    for dim_index, values in enumerate(dataset.data.T):
        print(dim_index)
        panels[dim_index].hist(values, bins=100)
    plt.tight_layout()
    plt.show()
Example #27
Source File: visualize.py    From dataiku-contrib with Apache License 2.0 5 votes vote down vote up
def plot_overlaps(gt_class_ids, pred_class_ids, pred_scores,
                  overlaps, class_names, threshold=0.5):
    """Draw a grid showing how ground truth objects are classified.

    Class IDs equal to 0 are dropped from both ID arrays before plotting.
    Cell (i, j) shows the overlap of prediction i with ground truth j. When
    that overlap exceeds ``threshold``, the cell is also marked "match" or
    "wrong" depending on whether the class IDs agree.

    gt_class_ids: [N] int. Ground truth class IDs
    pred_class_ids: [N] int. Predicted class IDs
    pred_scores: [N] float. The probability scores of predicted classes
    overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
    class_names: list of all class names in the dataset
    threshold: Float. The IoU overlap above which a cell is labelled
        "match" or "wrong"
    """
    gt_class_ids = gt_class_ids[gt_class_ids != 0]
    pred_class_ids = pred_class_ids[pred_class_ids != 0]

    plt.figure(figsize=(12, 10))
    plt.imshow(overlaps, interpolation='nearest', cmap=plt.cm.Blues)
    plt.yticks(np.arange(len(pred_class_ids)),
               ["{} ({:.2f})".format(class_names[int(class_id)], pred_scores[i])
                for i, class_id in enumerate(pred_class_ids)])
    plt.xticks(np.arange(len(gt_class_ids)),
               [class_names[int(class_id)] for class_id in gt_class_ids], rotation=90)

    # Cells above half the maximum use white text, to stay readable on dark blue.
    thresh = overlaps.max() / 2.
    for i, j in itertools.product(range(overlaps.shape[0]),
                                  range(overlaps.shape[1])):
        text = ""
        if overlaps[i, j] > threshold:
            text = "match" if gt_class_ids[j] == pred_class_ids[i] else "wrong"
        color = ("white" if overlaps[i, j] > thresh
                 else "black" if overlaps[i, j] > 0
                 else "grey")
        plt.text(j, i, "{:.3f}\n{}".format(overlaps[i, j], text),
                 horizontalalignment="center", verticalalignment="center",
                 fontsize=9, color=color)

    plt.xlabel("Ground Truth")
    plt.ylabel("Predictions")
    # Lay out last, so the axis labels are included in the spacing.
    plt.tight_layout()
Example #28
Source File: cluster.py    From 2D-Motion-Retargeting with MIT License 5 votes vote down vote up
def cluster_motion(net, cluster_data, device, save_path, nr_anims=15, mode='both'):
    """Embed motion codes in 2-D and save a scatter plot coloured by animation.

    ``nr_anims`` evenly spaced items are taken from ``cluster_data[0]``.
    They are encoded with ``net.mot_encoder`` and projected with
    ``tsne_on_pca``. The figure is written to ``save_path`` and then closed.

    Args:
        net: Model exposing ``mot_encoder``.
        cluster_data: Sequence. Item 0 is the motion data (indexed with five
            axes below) and item 1 holds the animation names used as labels.
        device: Device the encoder input is moved to.
        save_path: Output image path.
        nr_anims: Number of animations to sample.
        mode: ``'body'`` or ``'view'`` select a single slice of axis 2 or 1
            respectively. Any other value encodes ``[:, :4, ::2]``.
    """
    data, animations = cluster_data[0], cluster_data[1]
    idx = np.linspace(0, data.shape[0] - 1, nr_anims, dtype=int).tolist()
    data = data[idx]
    animations = animations[idx]
    if mode == 'body':
        data = data[:, :, 0, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])
    elif mode == 'view':
        data = data[:, 3, :, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])
    else:
        data = data[:, :4, ::2, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])

    nr_anims, nr_cv = data.shape[:2]

    features = net.mot_encoder(data.contiguous().view(-1, data.shape[2], data.shape[3]).to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features)
    features_2d = features_2d.reshape(nr_anims, nr_cv, -1)
    # Duplicate the points when an animation has fewer than 5 of them
    # (presumably so small clusters stay visible; confirm intent).
    if features_2d.shape[1] < 5:
        features_2d = np.tile(features_2d, (1, 2, 1))

    fig = plt.figure(figsize=(8, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_anims))
    for i in range(nr_anims):
        x = features_2d[i, :, 0]
        y = features_2d[i, :, 1]
        plt.scatter(x, y, c=colors[i], label=animations[i])

    # The legend sits outside the axes, so reserve the right 20% for it.
    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0, 0, 0.8, 1])
    plt.savefig(save_path)
    # The figure only goes to disk. Close it so repeated calls do not pile
    # up open figures.
    plt.close(fig)
Example #29
Source File: cluster.py    From 2D-Motion-Retargeting with MIT License 5 votes vote down vote up
def cluster_view(net, cluster_data, device, save_path):
    """Embed view codes in 2-D and save a scatter plot coloured by view.

    One random index along axis 1 of ``cluster_data[0]`` is selected. The
    remaining items are encoded with ``net.static_encoder`` if present,
    otherwise ``net.view_encoder``, and projected with ``tsne_on_pca``
    (PCA disabled). The figure is written to ``save_path`` and then closed.

    Args:
        net: Model exposing ``static_encoder`` or ``view_encoder``.
        cluster_data: Sequence. Item 0 is the data and item 3 the view
            names used as labels.
        device: Device the encoder input is moved to.
        save_path: Output image path.
    """
    data, views = cluster_data[0], cluster_data[3]
    # randint's upper bound is exclusive, so this draws from every index on
    # axis 1, including the last. (Alternative: evenly spaced indices,
    # np.linspace(0, data.shape[1] - 1, 4, dtype=int).tolist().)
    idx = np.random.randint(data.shape[1])
    data = data[:, idx, :, :, :]

    nr_mc, nr_view = data.shape[0], data.shape[1]

    # The last two entries of axis 1 are excluded from the encoder input.
    encoder = net.static_encoder if hasattr(net, 'static_encoder') else net.view_encoder
    features = encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mc, nr_view, -1)

    fig = plt.figure(figsize=(7, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_view))
    for i in range(nr_view):
        x = features_2d[:, i, 0]
        y = features_2d[:, i, 1]
        plt.scatter(x, y, c=colors[i], label=views[i])

    # The legend sits outside the axes, so reserve the right 25% for it.
    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0, 0, 0.75, 1])
    plt.savefig(save_path)
    # The figure only goes to disk. Close it so repeated calls do not pile
    # up open figures.
    plt.close(fig)
Example #30
Source File: cluster.py    From 2D-Motion-Retargeting with MIT License 5 votes vote down vote up
def cluster_body(net, cluster_data, device, save_path):
    """Embed body codes in 2-D and save a scatter plot coloured by character.

    Index 0 of axis 2 of ``cluster_data[0]`` is selected. The items are
    encoded with ``net.static_encoder`` if present, otherwise
    ``net.body_encoder``, and projected with ``tsne_on_pca`` (PCA
    disabled). The figure is written to ``save_path`` and then closed.

    Args:
        net: Model exposing ``static_encoder`` or ``body_encoder``.
        cluster_data: Sequence. Item 0 is the data and item 2 the character
            names used as labels.
        device: Device the encoder input is moved to.
        save_path: Output image path.
    """
    data, characters = cluster_data[0], cluster_data[2]
    data = data[:, :, 0, :, :]

    nr_mv, nr_char = data.shape[0], data.shape[1]

    # The last two entries of axis 1 are excluded from the encoder input.
    encoder = net.static_encoder if hasattr(net, 'static_encoder') else net.body_encoder
    features = encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mv, nr_char, -1)

    fig = plt.figure(figsize=(7, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_char))
    for i in range(nr_char):
        x = features_2d[:, i, 0]
        y = features_2d[:, i, 1]
        plt.scatter(x, y, c=colors[i], label=characters[i])

    # The legend sits outside the axes, so reserve the right 25% for it.
    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0, 0, 0.75, 1])
    plt.savefig(save_path)
    # The figure only goes to disk. Close it so repeated calls do not pile
    # up open figures.
    plt.close(fig)