Python matplotlib.cm.rainbow() Examples

The following code examples show how to use matplotlib.cm.rainbow(). They are taken from open source Python projects. You can vote up the examples you like or vote down the ones you don't like.

Example 1
Project: DiCoNet   Author: alexnowakvila   File: kmeans.py    MIT License 8 votes vote down vote up
def plot_clusters(num, e, centers, points, fig, model):
    """Plot the points of batch element ``fig`` coloured by cluster id.

    Cluster ids are read from ``e[fig]``; one rainbow colour is used per
    cluster. When ``centers`` is given, each cluster centre is drawn as a
    star in the cluster's colour. The figure is written to
    ``./plots/clustering_it_<num>_<model>.png``.
    """
    plt.figure(0)
    plt.clf()
    axes = plt.gca()
    axes.set_xlim([-0.05, 1.05])
    axes.set_ylim([-0.05, 1.05])

    assignment = e[fig]
    clusters = assignment.max() + 1
    palette = cm.rainbow(np.linspace(0, 1, clusters))
    for idx in range(clusters):
        # Drop the alpha channel before converting to a hex colour string.
        hex_color = rgb2hex(palette[idx][:-1])
        member_mask = assignment == idx
        xs = torch.masked_select(points[fig, :, 0], member_mask)
        ys = torch.masked_select(points[fig, :, 1], member_mask)
        plt.plot(xs.cpu().numpy(), ys.cpu().numpy(), 'o', c=hex_color)
        if centers is None:
            continue
        center = centers[idx]
        plt.plot([center.data[0]], [center.data[1]], '*', c=hex_color)

    plt.title('clustering')
    plt.savefig('./plots/clustering_it_{}_{}.png'.format(num, model))
Example 2
Project: wassersteinms   Author: mciach   File: spectrum.py    MIT License 6 votes vote down vote up
def plot_all(spectra, show=True, profile=False, cmap=None):
        """Plot every spectrum in ``spectra`` on the current figure.

        Args:
            spectra: sequence of objects exposing
                ``plot(show=..., profile=..., color=...)``.
            show: when true, clear the current figure first and call
                ``plt.show()`` at the end.
            profile: forwarded to each spectrum's ``plot`` call.
            cmap: either a callable ``cmap(index, alpha=...)`` returning a
                colour, or a ready-made sequence of colours. When falsy,
                a rainbow palette is generated.

        The generated colour list starts with a dark, mostly opaque black
        entry, so the first spectrum is drawn in black and spectrum ``i``
        gets palette entry ``i - 1``.
        """
        import matplotlib.pyplot as plt
        import matplotlib.cm as cm
        import numpy as np
        black = [0, 0, 0, 0.8]
        if not cmap:
            rainbow = cm.rainbow(np.linspace(0, 1, len(spectra)))
            colors = [black] + [list(rgba[:3]) + [0.6] for rgba in rainbow]
        else:
            try:
                colors = [black] + [cmap(x, alpha=1) for x in range(len(spectra))]
            except Exception:
                # cmap could not be called as a colormap; treat it as an
                # explicit colour sequence. ``Exception`` (rather than a
                # bare except) lets KeyboardInterrupt/SystemExit propagate.
                colors = cmap
        if show:
            plt.clf()
        for i, spectre in enumerate(spectra):
            spectre.plot(show=False, profile=profile, color=colors[i])
        # Alternative placement (legend below the plot):
        # plt.legend(loc=9, bbox_to_anchor=(0.5, -0.1), ncol=len(spectra))
        plt.legend(loc=0, ncol=1)
        if show:
            plt.show()
Example 3
Project: motion-classification   Author: matthiasplappert   File: evaluate.py    MIT License 6 votes vote down vote up
def _plot_proto_symbol_space(coordinates, target_names, name, args):
    """Project ``coordinates`` to 2D with t-SNE and plot one labelled dot per sample.

    Args:
        coordinates: input passed to ``TSNE().fit_transform``.
        target_names: per-sample labels, used for annotations and the legend.
        name: base name of the output file (``.pdf`` is appended).
        args: object with an ``output_dir`` attribute; when it is ``None``
            the plot is shown interactively instead of being saved.
    """
    # Reduce to 2D so that we can plot it.
    coordinates_2d = TSNE().fit_transform(coordinates)

    n_samples = coordinates_2d.shape[0]
    x = coordinates_2d[:, 0]
    y = coordinates_2d[:, 1]
    colors = cm.rainbow(np.linspace(0, 1, n_samples))

    fig = plt.figure(1)
    plt.clf()
    ax = fig.add_subplot(111)
    dots = []
    # ``range`` works on both Python 2 and 3; ``xrange`` is undefined on Python 3.
    for idx in range(n_samples):
        dots.append(ax.plot(x[idx], y[idx], "o", c=colors[idx], markersize=15)[0])
        ax.annotate(target_names[idx], xy=(x[idx], y[idx]))
    lgd = ax.legend(dots, target_names, ncol=4, numpoints=1, loc='upper center',
                    bbox_to_anchor=(0.5, -0.1))
    ax.grid(True)

    if args.output_dir is not None:
        path = os.path.join(args.output_dir, name + '.pdf')
        print('Saved plot to file "%s"' % path)
        # Pass the legend explicitly so the tight bounding box includes it.
        fig.savefig(path, bbox_extra_artists=(lgd,), bbox_inches='tight')
    else:
        plt.show()
Example 4
Project: simulator   Author: P2PSP   File: play.py    GNU General Public License v3.0 6 votes vote down vote up
def draw_buffer(self):
        """Create the "Buffer Status" figure and initialise its bookkeeping state.

        Two animated marker-only lines (labelled IN and OUT) are created,
        the axes are sized from the total peer count and
        ``self.get_buffer_size()``, and one rainbow colour per peer is
        stored in ``self.buffer_colors``.
        """
        figure, axes = plt.subplots()
        self.buffer_figure = figure
        self.buffer_ax = axes

        placeholder = [1, 1]  # dummy data for both lines
        self.lineIN, = axes.plot(placeholder, placeholder, color='#000000', ls="None",
                                 label="IN", marker='o', animated=True)
        self.lineOUT, = axes.plot(placeholder, placeholder, color='#CCCCCC', ls="None",
                                  label="OUT", marker='o', animated=True)
        figure.suptitle("Buffer Status", size=16)
        plt.legend(loc=2, numpoints=1)

        total_peers = (self.number_of_monitors
                       + self.number_of_peers
                       + self.number_of_malicious)
        self.buffer_colors = cm.rainbow(np.linspace(0, 1, total_peers))
        plt.axis([0, total_peers + 1, 0, self.get_buffer_size()])
        plt.xticks(range(total_peers + 1))

        self.buffer_order = {}
        self.buffer_index = 1
        self.buffer_labels = axes.get_xticks().tolist()
        plt.grid()
        figure.canvas.draw()
Example 5
Project: DiCoNet   Author: alexnowakvila   File: Logger.py    MIT License 6 votes vote down vote up
def plot_accuracies(self, accuracies, scales=(), mode='train', fig=0):
        """Plot one accuracy curve per entry of ``accuracies`` and save it.

        Args:
            accuracies: iterable of accuracy sequences, one per curve.
            scales: labels for the curves. The palette is sized from
                ``len(scales)``, so it must have at least as many entries
                as ``accuracies``. The default is an immutable empty tuple
                rather than a shared mutable list.
            mode: ``'train'`` labels the x axis ``iterations``; any other
                value labels it ``iterations x 1000``. Also used in the
                output file name.
            fig: matplotlib figure number to draw on.

        The image is written to ``<self.path>/accuracies_<mode>.png``.
        """
        plt.figure(fig)
        plt.clf()
        colors = cm.rainbow(np.linspace(0, 1, len(scales)))
        names = [str(sc) for sc in scales]
        handles = []
        for i, acc in enumerate(accuracies):
            line, = plt.plot(range(len(acc)), acc, color=colors[i])
            handles.append(line)
        plt.ylabel('accuracy')
        plt.legend(handles, names, loc=2, prop={'size': 6})
        if mode == 'train':
            plt.xlabel('iterations')
        else:
            plt.xlabel('iterations x 1000')
        path = os.path.join(self.path, 'accuracies_{}.png'.format(mode))
        plt.savefig(path)
Example 6
Project: DiCoNet   Author: alexnowakvila   File: Logger.py    MIT License 6 votes vote down vote up
def plot_norm_points(self, Inputs_N, e, Perms, scales, fig=1):
        """Draw one subplot per permutation with points coloured by node id.

        Only the first batch element of each tensor is used. Node ids come
        from the sorted ``e`` and are integer-halved after every subplot.
        The figure is saved to ``<self.path>/visualize_example.png``.
        """
        coords = Inputs_N[0][0].data.cpu().numpy()
        sorted_ids = torch.sort(e, 1)[0][0].data.cpu().numpy()
        perm_arrays = [perm[0].data.cpu().numpy() for perm in Perms]

        plt.figure(fig)
        plt.clf()
        level_ids = sorted_ids.copy()
        n_levels = len(perm_arrays)
        for level, perm in enumerate(perm_arrays):
            plt.subplot(1, n_levels, level + 1)
            n_nodes = 2 ** (scales - level)
            palette = cm.rainbow(np.linspace(0, 1, n_nodes))
            # Keep the positive entries of the permutation and shift them to 0-based.
            active = perm[np.where(perm > 0)[0]] - 1
            level_points = coords[active]
            level_labels = level_ids[active]
            for node in range(n_nodes):
                members = np.where(level_labels == node)[0]
                selected = level_points[members]
                plt.scatter(selected[:, 0], selected[:, 1], c=palette[node])
            level_ids //= 2

        out_path = os.path.join(self.path, 'visualize_example.png')
        plt.savefig(out_path)
Example 7
Project: DiCoNet   Author: alexnowakvila   File: Logger.py    MIT License 6 votes vote down vote up
def plot_rates(self, discard_rates, fig=1):
        """Plot the discard-rate curves, one stacked subplot per entry.

        Each entry of ``discard_rates`` is a collection of series; every
        series is drawn in its own rainbow colour and labelled
        ``scale <index>``. The figure is saved to ``<self.path>/rates.png``.
        """
        plt.figure(fig)
        plt.clf()
        plt.title('Discard Rates')
        n_rows = len(discard_rates)
        for row, rate in enumerate(discard_rates):
            plt.subplot(n_rows, 1, row + 1)
            n_series = len(rate)
            palette = cm.rainbow(np.linspace(0, 1, n_series))
            handles = []
            for series_idx, series in enumerate(rate):
                line, = plt.plot(series, c=palette[series_idx])
                handles.append(line)
            labels = ['scale {}'.format(k) for k in range(n_series)]
            plt.legend(handles, labels, loc=2, prop={'size': 6})
        plt.xlabel('Iterations')
        plt.ylabel('discard rates')
        out_path = os.path.join(self.path, 'rates.png')
        plt.savefig(out_path)
Example 8
Project: tools   Author: sgs-us   File: lbPlot.py    GNU General Public License v3.0 6 votes vote down vote up
def lb2D():
    """Draw a 3x3 grid of unit cells with arrows from the centre and save it.

    One arrow is drawn from the centre of the middle cell towards each of
    the eight neighbouring offsets; the centre is marked with a black
    circle. Everything is added to the module-level ``ax2`` and the
    module-level ``fig2`` is saved as ``d2q9.png``.
    """
    # One colour per offset; only the eight non-zero offsets consume one.
    palette = iter(cm.rainbow(np.linspace(0, 1, 9)))
    origin = [0.5, 0.5]
    for dx, dy in product([-1, 0, 1], repeat=2):
        cell = Rectangle((0 + dx, 0 + dy),   # lower-left corner
                         1,                  # width
                         1,                  # height
                         ec="black",         # edge colour
                         fc=[0, 0, 0, 0],    # transparent face
                         zorder=1)
        ax2.add_patch(cell)
        if dx == 0 and dy == 0:
            continue
        arrow = Arrow(origin[0], origin[1], dx, dy, width=0.5,
                      color=next(palette), zorder=0)
        ax2.add_patch(arrow)
    # draw c_0
    centre = Circle((origin[0], origin[1]), 0.125, color="black", zorder=2)
    ax2.add_patch(centre)
    fig2.savefig("d2q9.png")
Example 9
Project: tools   Author: sgs-us   File: lbPlot.py    GNU General Public License v3.0 6 votes vote down vote up
def lb3D(q=19):
    """Draw the arrows of a 3D stencil with ``q`` directions and save the figure.

    Args:
        q: stencil size; must be 15, 19 or 27 (checked with ``assert``).

    For every offset in ``{-1, 0, 1}^3`` a cube is drawn via ``drawCube``;
    offsets whose Euclidean length is in the allowed set for ``q`` also get
    an arrow from the centre of the unit cube. The module-level ``fig`` is
    saved as ``d3q<q>.png``. Previously the name was hard-coded to
    ``d3q19.png`` for every ``q``.
    """
    assert (q == 15 or q == 19 or q == 27)
    colors = iter(cm.rainbow(np.linspace(0, 1, q)))
    # Offset lengths that belong to each stencil.
    allowed_by_q = {
        15: [1, sqrt(3)],
        19: [1, sqrt(2)],
        27: [1, sqrt(2), sqrt(3)],
    }
    allowedL = allowed_by_q[q]
    for offset in product([-1, 0, 1], repeat=3):
        o = np.array(offset)
        length = LA.norm(o)
        if length in allowedL:
            arrow = Arrow3D([0.5, 0.5 + offset[0]],
                            [0.5, 0.5 + offset[1]],
                            [0.5, 0.5 + offset[2]],
                            mutation_scale=20, lw=7, arrowstyle="-|>",
                            color=next(colors))
            ax.add_artist(arrow)
        drawCube([0, 1], o)
    fig.savefig("d3q{}.png".format(q))
Example 10
Project: GridCell-3D   Author: jianwen-xie   File: utils.py    MIT License 6 votes vote down vote up
def draw_heatmap(data, save_path, xlabels=None, ylabels=None):
    """Render ``data`` as a rainbow heatmap with a colour bar and save it.

    The colour range spans the smallest and largest value found in the
    2D ``data``. Tick labels are only set for the axes whose label list
    is provided. The figure is closed after saving.
    """
    cmap = cm.get_cmap('rainbow', 1000)
    figure = plt.figure(facecolor='w')
    ax = figure.add_subplot(1, 1, 1, position=[0.1, 0.15, 0.8, 0.8])
    if xlabels is not None:
        ax.set_xticks(range(len(xlabels)))
        ax.set_xticklabels(xlabels)
    if ylabels is not None:
        ax.set_yticks(range(len(ylabels)))
        ax.set_yticklabels(ylabels)

    # Scan all cells once to find the value range, seeded with the first cell.
    lowest = highest = data[0][0]
    for row in data:
        for value in row:
            highest = max(highest, value)
            lowest = min(lowest, value)

    image = ax.imshow(data, interpolation='nearest', cmap=cmap, aspect='auto',
                      vmin=lowest, vmax=highest)
    plt.colorbar(mappable=image, cax=None, ax=None, shrink=0.5)
    plt.savefig(save_path)
    plt.close()
Example 11
Project: GridCell-3D   Author: jianwen-xie   File: utils.py    MIT License 6 votes vote down vote up
def draw_heatmap(data, save_path, xlabels=None, ylabels=None):
    """Render ``data`` as a rainbow heatmap with a colour bar and save it.

    The colour range spans the smallest and largest value found in the
    2D ``data``. Tick labels are only set for the axes whose label list
    is provided. The figure is closed after saving.
    """
    cmap = cm.get_cmap('rainbow', 1000)
    figure = plt.figure(facecolor='w')
    ax = figure.add_subplot(1, 1, 1, position=[0.1, 0.15, 0.8, 0.8])
    if xlabels is not None:
        ax.set_xticks(range(len(xlabels)))
        ax.set_xticklabels(xlabels)
    if ylabels is not None:
        ax.set_yticks(range(len(ylabels)))
        ax.set_yticklabels(ylabels)

    # Scan all cells once to find the value range, seeded with the first cell.
    lowest = highest = data[0][0]
    for row in data:
        for value in row:
            highest = max(highest, value)
            lowest = min(lowest, value)

    image = ax.imshow(data, interpolation='nearest', cmap=cmap, aspect='auto',
                      vmin=lowest, vmax=highest)
    plt.colorbar(mappable=image, cax=None, ax=None, shrink=0.5)
    plt.savefig(save_path)
    plt.close()
Example 12
Project: GridCell-3D   Author: jianwen-xie   File: utils.py    MIT License 6 votes vote down vote up
def draw_path_to_target_gif(file_name, place_len, place_seq, target, col=(255, 0, 0)):
    """Write an animated GIF that reveals a path one segment per frame.

    Args:
        file_name: output path passed to ``imageio.mimsave``.
        place_len: side length, in pixels, of the square white canvas.
        place_seq: sequence of 2D positions; the first one is marked with
            a filled circle in ``col``.
        target: 2D position marked with a filled circle in ``(0, 0, 255)``.
        col: colour used for the start marker and the path segments.

    Frame ``i`` shows both markers plus the first ``i`` path segments.
    Each frame is built by copying the previous one and adding a single
    segment, which yields the same pixels as redrawing the whole path but
    without the quadratic number of line draws. The unused colormap
    lookup from the original was dropped.
    """
    canvas = np.ones((place_len, place_len, 3), dtype="uint8") * 255
    cv2.circle(canvas, tuple(target), 2, (0, 0, 255), -1)
    cv2.circle(canvas, tuple(place_seq[0]), 2, col, -1)

    canvas_list = [canvas]
    for i in range(1, len(place_seq)):
        # Start from the previous frame so earlier segments are kept.
        canvas = canvas.copy()
        cv2.line(canvas, tuple(place_seq[i - 1]), tuple(place_seq[i]), col, 1)
        canvas_list.append(canvas)

    imageio.mimsave(file_name, canvas_list, 'GIF', duration=0.3)
Example 13
Project: GoogleAi   Author: nattimmis   File: plot.py    Apache License 2.0 6 votes vote down vote up
def plot_multi_obj_opt(smiles, target_mol, idx=0):
  """Scatter similarity against QED for six weight settings and save a PDF.

  Molecules are read from ``smiles['weight_<i>']`` for ``i`` in 0..5 and
  their properties come from ``get_properties``. Dashed grey lines mark
  the target molecule's own values. The title is looked up in
  ``all_molecules_with_id.json`` and the figure is written to
  ``batch/mult_obj_gen_<idx>.pdf``.
  """
  with open('all_molecules_with_id.json') as f:
    molid = json.load(f)
  palette = cm.rainbow(np.linspace(0, 1, 6))
  plt.figure()
  for i, color in enumerate(palette):
    props = [get_properties(ss, target_molecule=target_mol)
             for ss in smiles['weight_%i' % i]]
    sim, qed = zip(*props)
    plt.scatter(sim, qed, label='w=%.1f' % (i * 0.2), color=color)
  target_sim, target_qed = get_properties(target_mol, target_mol)
  plt.axvline(x=target_sim, ls='dashed', color='grey')
  plt.axhline(y=target_qed, ls='dashed', color='grey')
  legend = plt.legend()
  legend.get_frame().set_alpha(0.95)
  plt.ylim((-0.2, 1))
  plt.xlabel('Similarity')
  plt.ylabel('QED')
  plt.title(molid[target_mol])
  plt.subplots_adjust(left=0.16, bottom=0.16, right=0.92, top=0.88)
  plt.savefig('batch/mult_obj_gen_{}.pdf'.format(idx))
Example 14
Project: airfoil-opt-gan   Author: IDEALLab   File: shape_plot.py    MIT License 6 votes vote down vote up
def plot_shape(xys, z1, z2, ax, scale, scatter, symm_axis, **kwargs):
    """Draw the outline ``xys`` scaled by ``scale`` and shifted to ``(z1, z2)``.

    With ``scatter`` true the points are drawn with ``ax.scatter`` (rainbow
    colours along the outline unless ``c`` is supplied); otherwise they are
    drawn with ``ax.plot``. ``symm_axis`` of ``'y'`` or ``'x'`` additionally
    fills the region between the outline and its mirror image in grey.
    Extra keyword arguments are forwarded to the drawing call.
    """
    # Both axes use the raw scale; no normalisation by the shape's extent.
    xscl = yscl = scale

    if scatter and 'c' not in kwargs:
        kwargs['c'] = cm.rainbow(np.linspace(0, 1, xys.shape[0]))

    shifted = [(x * xscl + z1, y * yscl + z2) for (x, y) in xys]
    if scatter:
        ax.scatter(*zip(*shifted), edgecolors='none', **kwargs)
    else:
        ax.plot(*zip(*shifted), **kwargs)

    if symm_axis == 'y':
        # Each triple is (y position, mirrored x, original x).
        band = [(y * yscl + z2, -x * xscl + z1, x * xscl + z1) for (x, y) in xys]
        plt.fill_betweenx(*zip(*band), color='gray', alpha=.2)
    elif symm_axis == 'x':
        # Each triple is (x position, mirrored y, original y).
        band = [(x * xscl + z1, -y * yscl + z2, y * yscl + z2) for (x, y) in xys]
        plt.fill_between(*zip(*band), color='gray', alpha=.2)
Example 15
Project: WebAppEx   Author: karlafej   File: compute.py    MIT License 6 votes vote down vote up
def get_plot(x, y, k, iris=iris):
    """Cluster ``iris.data`` with k-means and return the plot as base64 PNG text.

    Args:
        x: column index of ``iris.data`` used for the horizontal axis.
        y: column index of ``iris.data`` used for the vertical axis.
        k: number of clusters.
        iris: dataset object providing ``data`` and ``feature_names``.

    Returns:
        The PNG image of the scatter plot, base64-encoded and decoded to ``str``.

    The figure is closed once it has been rendered. Pyplot keeps every
    figure alive until it is explicitly closed, so without this each call
    would leave another figure in memory.
    """
    k_means = KMeans(n_clusters=k)
    k_means.fit(iris.data)
    colormap = rainbow(np.linspace(0, 1, k))
    fig = plt.figure()
    try:
        splt = fig.add_subplot(1, 1, 1)
        splt.scatter(iris.data[:, x], iris.data[:, y], c=colormap[k_means.labels_], s=40)
        splt.scatter(k_means.cluster_centers_[:, x], k_means.cluster_centers_[:, y],
                     c='black', marker='x')
        splt.set_xlabel(iris.feature_names[x])
        splt.set_ylabel(iris.feature_names[y])

        figfile = BytesIO()
        fig.savefig(figfile, format='png')
    finally:
        plt.close(fig)
    # getvalue() returns the whole buffer regardless of the stream position.
    figdata_png = base64.b64encode(figfile.getvalue()).decode()
    return figdata_png
Example 16
Project: adni_rs_fmri_analysis   Author: mrahim   File: canica_extract_regions.py    GNU General Public License v2.0 6 votes vote down vote up
def extract_region_i(maps, i):
    """Threshold map ``i``, label its connected regions and plot them.

    The map is binarised at its ``100 - 100/42`` percentile, connected
    components are labelled, and components with fewer than 1000 voxels
    are reset to label 0. The surviving regions are shown with
    ``plot_roi`` and the map itself with ``plot_stat_map``.
    """
    component = maps[i, ...]
    th_value = np.percentile(component, 100.-(100./42.))
    data = np.absolute(array_to_nii(component, mask).get_data())
    data[data <= th_value] = 0
    data[data > th_value] = 1
    data_lab = label(data)[0]

    # Discard small components by merging them into label 0.
    for region_id in np.unique(data_lab):
        region = data_lab == region_id
        if np.count_nonzero(region) < 1000:
            data_lab[region] = 0

    img_l = nib.Nifti1Image(data_lab, mask_affine)
    roi_title = map_title + '_roi_' + str(i)
    plot_roi(img_l, title=roi_title, cmap=cm.rainbow)
    stat_title = map_title + '_' + str(i)
    plot_stat_map(index_img(img, i), title=stat_title, threshold=0)
Example 17
Project: pohmm-keystroke   Author: vmonaco   File: plotting.py    BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def plot_continuous_identification_example(n_users=3, n_events=10, seed=2015):
    """Plot random per-user events on a top panel and merged on a bottom panel.

    Event times and user ids are drawn from ``np.random`` after seeding it
    with ``seed``. The top axes show each user's events on that user's own
    row; the bottom axes show all events on one line labelled ``Global``.
    Returns ``None``.
    """
    import matplotlib.cm as cm
    from matplotlib import gridspec

    np.random.seed(seed)
    time = np.random.randint(0, n_events * 10, n_events)
    user = np.random.randint(1, n_users + 1, n_events)
    zero = np.zeros(n_events)

    plt.figure(figsize=(6, 4))
    gs = gridspec.GridSpec(2, 1, height_ratios=[n_users, 1])
    palette = cm.rainbow(np.linspace(0, 1, n_users))
    user_ids = range(1, n_users + 1)

    # Top panel: one row per user.
    ax0 = plt.subplot(gs[0])
    for uid in user_ids:
        mine = user == uid
        ax0.scatter(time[mine], user[mine], color=palette[uid - 1])
        ax0.axhline(uid, color='k', alpha=0.1)
    ax0.xaxis.set_ticklabels([])
    ax0.set_yticks(np.arange(1, n_users + 1))
    ax0.set_ylabel('User')

    # Bottom panel: every event on a single line.
    ax1 = plt.subplot(gs[1])
    ax1.axhline(0, color='k', alpha=0.1)
    for uid in user_ids:
        mine = user == uid
        ax1.scatter(time[mine], zero[mine], color=palette[uid - 1])
    ax1.set_yticks([])
    ax1.set_ylabel('Global')
    plt.xlabel('Time')

    plt.tight_layout()
Example 18
Project: tf-example-models   Author: aakhundov   File: tf_kmeans.py    Apache License 2.0 5 votes vote down vote up
def plot_clustered_data(points, c_means, c_assignments):
    """Show the data points coloured by cluster, with each cluster mean in black."""
    palette = cm.rainbow(np.linspace(0, 1, CLUSTERS))

    for cluster in range(CLUSTERS):
        members = points[c_assignments == cluster]
        plt.plot(members[:, 0], members[:, 1], ".", color=palette[cluster], zorder=0)
        # Means get a higher zorder so they are drawn above the points.
        plt.plot(c_means[cluster, 0], c_means[cluster, 1], ".", color="black", zorder=1)

    plt.show()


# PREPARING DATA

# generating DATA_POINTS points from a GMM with CLUSTERS components 
Example 19
Project: 2D-Motion-Retargeting   Author: ChrisWu1997   File: cluster.py    MIT License 5 votes vote down vote up
def cluster_body(net, cluster_data, device, save_path):
    """Embed the encoder features in 2D and plot them coloured per character.

    Uses ``net.static_encoder`` when the network has one, otherwise
    ``net.body_encoder``. Only index 0 of the third data axis is used.
    The scatter plot is saved to ``save_path``.
    """
    data, characters = cluster_data[0], cluster_data[2]
    data = data[:, :, 0, :, :]

    nr_mv, nr_char = data.shape[0], data.shape[1]

    if hasattr(net, 'static_encoder'):
        encoder = net.static_encoder
    else:
        encoder = net.body_encoder
    # Merge the two leading axes and drop the last two rows of the next axis.
    flat = data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :]
    features = encoder(flat.to(device))
    n_rows = features.shape[0]
    features = features.detach().cpu().numpy().reshape(n_rows, -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mv, nr_char, -1)

    plt.figure(figsize=(7, 4))
    palette = cm.rainbow(np.linspace(0, 1, nr_char))
    for char_idx in range(nr_char):
        xs = features_2d[:, char_idx, 0]
        ys = features_2d[:, char_idx, 1]
        plt.scatter(xs, ys, c=palette[char_idx], label=characters[char_idx])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0,0,0.75,1])
    plt.savefig(save_path)
Example 20
Project: 2D-Motion-Retargeting   Author: ChrisWu1997   File: cluster.py    MIT License 5 votes vote down vote up
def cluster_view(net, cluster_data, device, save_path):
    """Embed the encoder features in 2D and plot them coloured per view.

    A random index along the second data axis is chosen with
    ``np.random.randint``. Uses ``net.static_encoder`` when the network has
    one, otherwise ``net.view_encoder``. The plot is saved to ``save_path``.
    """
    data, views = cluster_data[0], cluster_data[3]
    idx = np.random.randint(data.shape[1] - 1)
    data = data[:, idx, :, :, :]

    nr_mc, nr_view = data.shape[0], data.shape[1]

    if hasattr(net, 'static_encoder'):
        encoder = net.static_encoder
    else:
        encoder = net.view_encoder
    # Merge the two leading axes and drop the last two rows of the next axis.
    flat = data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :]
    features = encoder(flat.to(device))
    n_rows = features.shape[0]
    features = features.detach().cpu().numpy().reshape(n_rows, -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mc, nr_view, -1)

    plt.figure(figsize=(7, 4))
    palette = cm.rainbow(np.linspace(0, 1, nr_view))
    for view_idx in range(nr_view):
        xs = features_2d[:, view_idx, 0]
        ys = features_2d[:, view_idx, 1]
        plt.scatter(xs, ys, c=palette[view_idx], label=views[view_idx])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0, 0, 0.75, 1])
    plt.savefig(save_path)
Example 21
Project: 2D-Motion-Retargeting   Author: ChrisWu1997   File: cluster.py    MIT License 5 votes vote down vote up
def cluster_motion(net, cluster_data, device, save_path, nr_anims=15, mode='both'):
    """Embed motion-encoder features in 2D and plot them coloured per animation.

    ``nr_anims`` animations are picked at evenly spaced indices. ``mode``
    selects which slice of the data is encoded: ``'body'``, ``'view'``, or
    anything else for the combined slice. The plot is saved to ``save_path``.
    """
    data, animations = cluster_data[0], cluster_data[1]
    idx = np.linspace(0, data.shape[0] - 1, nr_anims, dtype=int).tolist()
    data = data[idx]
    animations = animations[idx]

    if mode == 'body':
        subset = data[:, :, 0, :, :]
    elif mode == 'view':
        subset = data[:, 3, :, :, :]
    else:
        subset = data[:, :4, ::2, :, :]
    # The trailing sizes are read from the unsliced data.
    data = subset.reshape(nr_anims, -1, data.shape[3], data.shape[4])

    nr_anims, nr_cv = data.shape[:2]

    flat = data.contiguous().view(-1, data.shape[2], data.shape[3])
    features = net.mot_encoder(flat.to(device))
    n_rows = features.shape[0]
    features = features.detach().cpu().numpy().reshape(n_rows, -1)

    features_2d = tsne_on_pca(features)
    features_2d = features_2d.reshape(nr_anims, nr_cv, -1)
    if features_2d.shape[1] < 5:
        # Repeat the second axis once when there are fewer than five entries.
        features_2d = np.tile(features_2d, (1, 2, 1))

    plt.figure(figsize=(8, 4))
    palette = cm.rainbow(np.linspace(0, 1, nr_anims))
    for anim_idx in range(nr_anims):
        xs = features_2d[anim_idx, :, 0]
        ys = features_2d[anim_idx, :, 1]
        plt.scatter(xs, ys, c=palette[anim_idx], label=animations[anim_idx])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0,0,0.8,1])
    plt.savefig(save_path)
Example 22
Project: ml-deepranking   Author: urakozz   File: visualsearch_train.py    MIT License 5 votes vote down vote up
def plot_with_labels(low_dim_embs, labels, filename='tsne_c5s5r5.png'):
    """Scatter 2D embeddings coloured and annotated by label, then save them.

    The palette has one colour per distinct label and is indexed by the
    label value itself, so labels are used directly as palette indices.
    """
    assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
    n_classes = len(np.unique(np.array(labels)))
    palette = cm.rainbow(np.linspace(0, 1, n_classes))
    plt.figure(figsize=(18, 18))  # in inches
    for row, label in enumerate(labels):
        x, y = low_dim_embs[row, :]
        plt.scatter(x, y, color=palette[label])
        plt.annotate(label, xy=(x, y), xytext=(5, 2),
                     textcoords='offset points', ha='right', va='bottom')

    plt.savefig(filename)
Example 23
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 24
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 25
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 26
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 27
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 28
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 29
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 30
Project: Robotics-EIE3   Author: martinferianc   File: drawing_v.py    MIT License 5 votes vote down vote up
def __init__(self, map_size=210, virtual=False):
        """Store the map geometry and prepare a 20-colour rainbow palette.

        ``map_size`` is in cm. The canvas is 768 pixels wide, the margin is
        5% of the map size, and ``scale`` converts map units (including a
        margin on both sides) to pixels.
        """
        canvas_size = 768              # in pixels
        margin = 0.05 * map_size
        self.map_size = map_size       # in cm
        self.canvas_size = canvas_size
        self.margin = margin
        self.scale = canvas_size / (map_size + 2 * margin)
        self.virtual = virtual
        n_colors = 20  # palette size
        self.colors = cm.rainbow(np.linspace(0, 1, n_colors))
        self.counter = 0
Example 31
Project: tredify   Author: gsalvatori   File: Bar.py    GNU General Public License v3.0 5 votes vote down vote up
def get_colors(self):
        """Return one rainbow RGBA colour per entry of ``self.dict``."""
        n_entries = len(self.dict)
        positions = np.linspace(0, 1, n_entries)
        return cm.rainbow(positions)
Example 32
Project: tredify   Author: gsalvatori   File: Bar.py    GNU General Public License v3.0 5 votes vote down vote up
def get_colors(self):
        """Return one rainbow RGBA colour per entry of ``self.dict``."""
        n_entries = len(self.dict)
        positions = np.linspace(0, 1, n_entries)
        return cm.rainbow(positions)
Example 33
Project: tredify   Author: gsalvatori   File: Scatter.py    GNU General Public License v3.0 5 votes vote down vote up
def get_colors(self):
        """Return one rainbow RGBA colour per entry of ``self.dict``."""
        n_entries = len(self.dict)
        positions = np.linspace(0, 1, n_entries)
        return cm.rainbow(positions)
Example 34
Project: tredify   Author: gsalvatori   File: Scatter.py    GNU General Public License v3.0 5 votes vote down vote up
def get_colors(self):
        """Return one rainbow RGBA colour per entry of ``self.dict``."""
        n_entries = len(self.dict)
        positions = np.linspace(0, 1, n_entries)
        return cm.rainbow(positions)
Example 35
Project: simulator   Author: P2PSP   File: play.py    GNU General Public License v3.0 5 votes vote down vote up
def draw_buffer(self):
        """Create the pyqtgraph "Buffer Status" window and its plot items.

        One marker-only plot item per peer is created for IN data (each in
        its own colour) plus a single grey item for OUT data. Per-peer
        sets and the bookkeeping fields used by later updates are reset.
        """
        window = pg.GraphicsLayoutWidget()
        self.buff_win = window
        window.setWindowTitle('Buffer Status')
        window.resize(800, 700)

        n_peers = self.number_of_monitors + self.number_of_peers + self.number_of_malicious
        self.total_peers = n_peers
        self.p4 = window.addPlot()
        self.p4.showGrid(x=True, y=True, alpha=100)   # grid lines along both axes
        # Left (y) axis: ticks every 5 units, minor spacing of 1.
        self.p4.getAxis('left').setTickSpacing(5, 1)

        # Pick a matplotlib palette based on the number of peers.
        if n_peers < 8:
            palette = cm.Set2(np.linspace(0, 1, 8))
        elif n_peers < 12:
            palette = cm.Set3(np.linspace(0, 1, 12))
        else:
            palette = cm.rainbow(np.linspace(0, 1, n_peers + 1))
        # NOTE(review): the four palette components are passed to
        # pg.hsvColor positionally - confirm this is the intended colour space.
        self.QColors = [pg.hsvColor(entry[0], entry[1], entry[2], entry[3])
                        for entry in palette]

        # Per-peer sets, filled in by later updates.
        self.Data = []
        self.OutData = []

        # One plot item per peer, so each item needs only a single brush.
        self.lineIN = []
        for peer in range(n_peers):
            item = self.p4.plot(pen=None, symbolBrush=self.QColors[peer],
                                name='IN', symbol='o', clear=False)
            self.lineIN.append(item)
            self.Data.append(set())
            self.OutData.append(set())

        # A single grey plot item for OUT data.
        self.lineOUT = self.p4.plot(pen=None, symbolBrush=mkColor('#CCCCCC'),
                                    name='OUT', symbol='o', clear=False)
        self.p4.setRange(xRange=[0, n_peers], yRange=[0, self.get_buffer_size()])
        window.show()

        self.buffer_order = {}
        self.buffer_index = 0
        self.buffer_labels = []
        self.lastUpdate = pg.ptime.time()
        self.avgFps = 0.0
Example 36
Project: diemaschine   Author: mowoe   File: websocket-server.py    GNU General Public License v3.0 5 votes vote down vote up
def sendTSNE(self, people):
        """Send a t-SNE scatter plot of the stored data as a ``TSNE_DATA`` message.

        Args:
            people: sequence indexed by label giving each person's name;
                label ``-1`` is shown as ``"Unknown"``.

        Does nothing when ``self.getData()`` returns ``None``. Otherwise
        the data is reduced with PCA (50 components) and t-SNE (2
        components), plotted with one colour per label, encoded as a
        base64 PNG data URL and sent via ``self.sendMessage``.

        The figure is closed after rendering so repeated calls do not
        accumulate open figures, and the legend is built once after all
        series are drawn instead of on every loop iteration.
        """
        d = self.getData()
        if d is None:
            return
        (X, y) = d

        X_pca = PCA(n_components=50).fit_transform(X, X)
        tsne = TSNE(n_components=2, init='random', random_state=0)
        X_r = tsne.fit_transform(X_pca)

        yVals = list(np.unique(y))
        colors = cm.rainbow(np.linspace(0, 1, len(yVals)))

        fig = plt.figure()
        try:
            for c, i in zip(colors, yVals):
                name = "Unknown" if i == -1 else people[i]
                plt.scatter(X_r[y == i, 0], X_r[y == i, 1], c=c, label=name)
            if yVals:
                plt.legend()

            imgdata = StringIO.StringIO()
            plt.savefig(imgdata, format='png')
        finally:
            plt.close(fig)
        imgdata.seek(0)

        content = 'data:image/png;base64,' + \
                  urllib.quote(base64.b64encode(imgdata.buf))
        msg = {
            "type": "TSNE_DATA",
            "content": content
        }
        self.sendMessage(json.dumps(msg))
Example 37
Project: Python-Data-Analysis-Learning-Notes   Author: Asurada2015   File: 401_CNN.py    MIT License 5 votes vote down vote up
def plot_with_labels(lowDWeights, labels):
    """Draw each label as text at its 2D position on a rainbow-coloured background.

    The axes are cleared first, the limits are fitted to the data, and the
    figure is shown followed by a short pause.
    """
    plt.cla()
    X = lowDWeights[:, 0]
    Y = lowDWeights[:, 1]
    for x, y, s in zip(X, Y, labels):
        # Map the label onto a colormap index via 255 * s / 9.
        background = cm.rainbow(int(255 * s / 9))
        plt.text(x, y, s, backgroundcolor=background, fontsize=9)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())
    plt.title('Visualize last layer')
    plt.show()
    plt.pause(0.01)
Example 38
Project: MOSCA   Author: iquasere   File: binning.py    GNU General Public License v3.0 5 votes vote down vote up
def plot_clusters(self, data, taxa_level, best_clusters, points, output,
                      label = True, subtitle_size = 20):
        """Scatter-plot 2-D points coloured by the taxon of their cluster.

        :param data: Excel file read with a two-level index; the rows whose
            second index level equals
            ``'<n>.Taxonomic lineage (<TAXA_LEVEL>)'`` are selected.
        :param taxa_level: one of 'superkingdom', 'phylum', 'class', 'order',
            'family', 'genus' or 'species' (any other value raises KeyError).
        :param best_clusters: tab-separated file concatenated column-wise to
            the points; it must provide the 'cluster' column used for the
            merge.
        :param points: header-less CSV with two columns, named 'lat' and
            'lon' here.
        :param output: destination passed to ``plt.savefig``.
        :param label: when truthy, a legend is added.
        :param subtitle_size: marker size used for the legend entries.
        """
        points = pd.read_csv(points, header = None)
        points.columns = ['lat','lon']
        best_clusters = pd.read_csv(best_clusters, sep = '\t')
        points = pd.concat([points, best_clusters], axis = 1)
        numeration = {'superkingdom':'1', 'phylum':'2', 'class':'3', 'order':'4',
                      'family':'5', 'genus':'6', 'species':'7'}
        column = numeration[taxa_level] + '.Taxonomic lineage (' + taxa_level.upper() + ')'
        data = pd.read_excel(data, index_col = [0,1])
        partial = data.xs(column, level = 1)
        points = pd.merge(points, partial, left_on = 'cluster', right_index = True)
        points = points[['lat','lon','Taxa']]
        points = points.fillna(value = 'Not identified')

        taxa = list(set(points['Taxa']))
        colors = cm.rainbow(np.linspace(0, 1, len(taxa)))
        plt.gcf().clear()
        for taxon, color in zip(taxa, colors):
            partial_points = points[points['Taxa'] == taxon]
            plt.scatter(partial_points['lat'], partial_points['lon'], 0.1,
                        color = color, label = taxon, marker = 'o')
        if label:
            legend = plt.legend(loc='best')
            # ``legendHandles`` was renamed ``legend_handles`` in newer
            # matplotlib releases; support both spellings.
            handles = getattr(legend, 'legend_handles', None)
            if handles is None:
                handles = legend.legendHandles
            # The points are drawn with size 0.1, far too small to read in
            # the legend, so the legend markers are enlarged.
            for handle in handles[:len(taxa)]:
                handle._sizes = [subtitle_size]
        plt.savefig(output, bbox_inches='tight')
Example 39
Project: apachecn_ml   Author: ys1305   File: CNN-DigitRecognizer.py    GNU General Public License v3.0 5 votes vote down vote up
def plot_with_labels(lowDWeights, labels):
    """Draw each label as text at its 2-D position on the current axes.

    The axes are cleared, every label is written at the coordinates taken
    from columns 0 and 1 of ``lowDWeights`` on a background colour looked up
    in the rainbow colormap at index ``int(255 * label / 9)``, the axis
    limits are fitted to the data, and the figure is shown followed by a
    0.01 s ``plt.pause``.
    """
    # Imported once per call rather than once per plotted label.
    from matplotlib import cm

    plt.cla()
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    for x, y, s in zip(X, Y, labels):
        c = cm.rainbow(int(255 * s / 9))
        plt.text(x, y, s, backgroundcolor=c, fontsize=9)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())
    plt.title('Visualize last layer')
    plt.show()
    plt.pause(0.01)
Example 40
Project: python-cope   Author: dinhhuy2109   File: SE3lib.py    GNU General Public License v3.0 5 votes vote down vote up
def Visualize(Tlist,sigmalist, nsamples = 100):
  """
  Visualize an estimation (a point will be used to represent the translation position of a transformation)
  @param Tlist:     a list of Transformations
  @param sigmalist: a list of corresponding sigmas (6x6, must admit a Cholesky factorization)
  @param nsamples:  the number of samples generated for each (T,sigma)
  @return:          True

  Pairs are formed positionally; iteration stops at the shorter of the two
  lists.  The figure is shown without blocking.
  """
  import matplotlib.cm as cm
  fig = plt.figure()
  ax = fig.add_subplot(111, projection='3d')
  colors = cm.rainbow(np.linspace(0, 1, len(Tlist)))
  for T, sigma, color in zip(Tlist, sigmalist, colors):
    cholsigma = np.linalg.cholesky(sigma).T
    positions = np.empty((nsamples, 3))
    for k in range(nsamples):
      vecsample = np.dot(cholsigma, np.random.randn(6))
      #vecsample = np.dot(cholsigma, np.random.uniform(-1,1,size = 6))
      Tsample = np.dot(VecToTran(vecsample), T)
      # Only the translation column of the sampled transform is plotted.
      positions[k] = (Tsample[0,3], Tsample[1,3], Tsample[2,3])
    # One scatter call per transformation instead of one per sample;
    # ``color=`` keeps the RGBA row from being read as per-point values.
    ax.scatter(positions[:,0], positions[:,1], positions[:,2], color = color)

  ax.set_autoscaley_on(False)
  ax.set_xlim([-0.5, 0.5])
  ax.set_ylim([-0.5, 0.5])
  ax.set_zlim([-0.5, 0.5])
  ax.set_xlabel('X Label')
  ax.set_ylabel('Y Label')
  ax.set_zlabel('Z Label')
  # ``block`` must be given by keyword on current matplotlib releases.
  plt.show(block=False)
  return True
Example 41
Project: Facial-Recognition-with-DNN   Author: eashanadhikarla   File: websocket-server.py    GNU General Public License v3.0 5 votes vote down vote up
def sendTSNE(self, people):
        """Render a t-SNE scatter plot of the stored data and send it to the client.

        The feature matrix returned by ``self.getData()`` is reduced with PCA
        (50 components), embedded in 2-D with t-SNE and drawn with one rainbow
        colour per distinct label.  The PNG is then sent through
        ``self.sendMessage`` as a JSON ``TSNE_DATA`` message whose content is a
        percent-quoted, base64-encoded data URI.

        :param people: indexable by label value; gives the display name of
            each label.  The label ``-1`` is always shown as "Unknown".
        :return: None.  Nothing is sent when ``self.getData()`` returns None.
        """
        d = self.getData()
        if d is None:
            return
        X, y = d

        X_pca = PCA(n_components=50).fit_transform(X, X)
        tsne = TSNE(n_components=2, init='random', random_state=0)
        X_r = tsne.fit_transform(X_pca)

        yVals = list(np.unique(y))
        colors = cm.rainbow(np.linspace(0, 1, len(yVals)))

        fig = plt.figure()
        try:
            for c, i in zip(colors, yVals):
                name = "Unknown" if i == -1 else people[i]
                # ``color=`` (not ``c=``): a bare RGBA row passed as ``c`` is
                # interpreted as four scalar values when the group happens to
                # contain exactly four points.
                plt.scatter(X_r[y == i, 0], X_r[y == i, 1], color=c, label=name)
            if yVals:
                # Build the legend once, after every group has been drawn.
                plt.legend()

            imgdata = StringIO.StringIO()
            plt.savefig(imgdata, format='png')
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each call would leak one figure.
            plt.close(fig)

        content = 'data:image/png;base64,' + \
                  urllib.quote(base64.b64encode(imgdata.getvalue()))
        msg = {
            "type": "TSNE_DATA",
            "content": content
        }
        self.sendMessage(json.dumps(msg))
Example 42
Project: tools   Author: sgs-us   File: lbPlot.py    GNU General Public License v3.0 5 votes vote down vote up
def lbD3Q19EdgeStream(i):
    """Draw the cube neighbourhood and the arrows allowed for edge ``i``.

    The module-level ``ax`` is cleared, a cube is drawn at each of the 27
    offsets in {-1, 0, 1}^3, and an arrow from the centre (0.5, 0.5, 0.5) is
    added for every offset listed in row ``i`` of the table below.  Arrow
    colours are three rainbow samples taken from [i/12, (i+1)/12].  The
    module-level ``fig`` is saved as ``d3q19Stream_<ii>.png``.
    """
    # Three permitted offsets per edge index (rows e0 .. e11).
    edge_offsets = [
        [[ 0, -1,  0], [ 0,  0, -1], [ 0, -1, -1]],  # e0
        [[ 0,  1,  0], [ 0,  0, -1], [ 0,  1, -1]],  # e1
        [[ 0, -1,  0], [ 0,  0,  1], [ 0, -1,  1]],  # e2
        [[ 0,  1,  0], [ 0,  0,  1], [ 0,  1,  1]],  # e3
        [[-1,  0,  0], [ 0,  0, -1], [-1,  0, -1]],  # e4
        [[ 1,  0,  0], [ 0,  0, -1], [ 1,  0, -1]],  # e5
        [[-1,  0,  0], [ 0,  0,  1], [-1,  0,  1]],  # e6
        [[ 1,  0,  0], [ 0,  0,  1], [ 1,  0,  1]],  # e7
        [[-1,  0,  0], [ 0, -1,  0], [-1, -1,  0]],  # e8
        [[ 1,  0,  0], [ 0, -1,  0], [ 1, -1,  0]],  # e9
        [[-1,  0,  0], [ 0,  1,  0], [-1,  1,  0]],  # e10
        [[ 1,  0,  0], [ 0,  1,  0], [ 1,  1,  0]],  # e11
    ]
    ax.cla()
    ax.view_init(elev=10.)

    palette = iter(cm.rainbow(np.linspace(i/12., (i+1)/12., 3)))
    for shift in product([-1, 0, 1], repeat=3):
        drawCube([0, 1], np.array(shift))
        if list(shift) not in edge_offsets[i]:
            continue
        dx, dy, dz = shift
        ax.add_artist(Arrow3D([0.5, 0.5 + dx],
                              [0.5, 0.5 + dy],
                              [0.5, 0.5 + dz],
                              mutation_scale=20, lw=7, arrowstyle="-|>",
                              color=next(palette)))
    fig.savefig("d3q19Stream_"+str(i).zfill(2)+".png")
Example 43
Project: SceneChangeDet   Author: gmayday1997   File: tsne_visual.py    MIT License 5 votes vote down vote up
def plot_with_labels(lowDWeights, labels,sz):
    """Draw the labels of two paired 2-D embeddings on the current axes.

    ``lowDWeights[0]`` and ``lowDWeights[1]`` each supply points in columns 0
    and 1.  The k-th label is written at the k-th point of both sets, on a
    background colour looked up in the rainbow colormap at index
    ``int(255 * k / sz)``.  The axis limits cover both point sets; the figure
    is then shown, followed by a 0.01 s ``plt.pause``.
    """
    plt.cla()
    X_t0,Y_t0 = lowDWeights[0][:,0],lowDWeights[0][:,1]
    X_t1,Y_t1 = lowDWeights[1][:,0],lowDWeights[1][:,1]
    for idx,(x_t0,y_t0,x_t1,y_t1,lab) in enumerate(zip(X_t0,Y_t0,X_t1,Y_t1,labels)):
        c = cm.rainbow(int(255 * idx/sz))
        plt.text(x_t0,y_t0,lab,backgroundcolor=c,fontsize=9)
        plt.text(x_t1,y_t1,lab,backgroundcolor=c,fontsize=9)
    # Both sets are drawn, so the limits must span both of them (the lower
    # and upper bounds used to come from different sets).
    plt.xlim(min(X_t0.min(), X_t1.min()), max(X_t0.max(), X_t1.max()))
    plt.ylim(min(Y_t0.min(), Y_t1.min()), max(Y_t0.max(), Y_t1.max()))
    plt.title('Visualize last layer')
    plt.show()
    plt.pause(0.01)
Example 44
Project: SceneChangeDet   Author: gmayday1997   File: tsne_visual.py    MIT License 5 votes vote down vote up
def plot_with_labels_feat_cat(lowDWeights, labels,save_dir,title):
    """Plot labelled 2-D points as coloured text boxes and save the figure.

    The current axes are cleared, each label is written at the coordinates
    taken from columns 0 and 1 of ``lowDWeights`` on a background colour
    looked up in the rainbow colormap at index ``int(255 * label / 2)``, the
    axis limits are set to twice the data minimum/maximum, and the figure is
    saved to ``save_dir``, which is then printed.
    """
    plt.cla()
    X,Y = lowDWeights[:,0],lowDWeights[:,1]
    for x,y,lab in zip(X,Y,labels):
        color = cm.rainbow(int(255 * lab/2))
        # NOTE(review): fontsize=0 is kept from the original; presumably only
        # the coloured background is meant to be visible - confirm.
        plt.text(x,y,lab,backgroundcolor=color,fontsize=0)
    plt.xlim(X.min() *2 , X.max() *2)
    plt.ylim(Y.min()*2, Y.max()*2)
    plt.title(title)
    plt.savefig(save_dir)
    # Parenthesised form prints the same text on Python 2 and Python 3.
    print(save_dir)
Example 45
Project: SceneChangeDet   Author: gmayday1997   File: tsne_visual.py    MIT License 5 votes vote down vote up
def plot_with_labels_feat_cat_without_text(lowDWeights, labels,save_dir):
    """Plot 2-D points coloured by binary label and save the figure.

    Points come from columns 0 and 1 of ``lowDWeights``.  Label 0 is drawn in
    blue and label 1 in red; any other label is skipped.  The axis limits are
    set to twice the data minimum/maximum and the figure is saved to
    ``save_dir``, which is then printed.
    """
    plt.cla()
    X,Y = lowDWeights[:,0],lowDWeights[:,1]
    for x,y,lab in zip(X,Y,labels):
        # A marker is required: a line through a single point draws nothing.
        if lab == 0:
            plt.plot(x,y,'b.')
        elif lab == 1:
            plt.plot(x,y,'r.')
    plt.xlim(X.min() *2 , X.max() *2)
    plt.ylim(Y.min()*2, Y.max()*2)
    plt.title('Visualize last layer')
    plt.savefig(save_dir)
    # Parenthesised form prints the same text on Python 2 and Python 3.
    print(save_dir)
Example 46
Project: nrg_mapping   Author: GiggleLiu   File: utils.py    MIT License 5 votes vote down vote up
def plot_pauli_components(x,y,method='plot',ax=None,label=r'\sigma',**kwargs):
    '''
    Plot data by pauli components.

    Parameters
        :x,y: Datas, `x` is 1D and `y` has shape (len, 2, 2).
        :ax: The axis to plot, will use gca() to get one if None.
        :label: The symbol used in the legend entries.
        :method: `plot` or `scatter`
        :kwargs: The key word arguments for plot/scatter.

    Return:
        A list of plot instances, one per component (0, x, y, z).
    '''
    if ax is None: ax=gca()
    assert(ndim(x)==1 and ndim(y)==3 and y.shape[2]==2 and y.shape[1]==2)
    assert(method=='plot' or method=='scatter')
    subscripts=['0','x','y','z']

    yv=array([s2vec(yi) for yi in y]).real
    colormap=cm.rainbow(linspace(0,0.8,4))
    plts=[]
    # Draw on the requested axis; previously `ax` was resolved but ignored and
    # everything went to the current axes.
    for i in range(4):
        if method=='plot':
            plts+=ax.plot(x,yv[:,i],color=colormap[i],**kwargs)
        else:
            plts.append(ax.scatter(x,yv[:,i],edgecolors=colormap[i],facecolors='none',**kwargs))
    ax.legend(plts,[r'$%s_%s$'%(label,sub) for sub in subscripts])
    return plts
Example 47
Project: nrg_mapping   Author: GiggleLiu   File: discretization.py    MIT License 5 votes vote down vote up
def check_disc(rhofunc,discmodel,wlist,smearing=0.02,mode='eval'):
    '''
    Check the discretization quality by plotting the original hybridization
    function against the one recovered from the discretized model - the
    multiple-band Green's function version.

    Parameters:
        :rhofunc: function, The original hybridization function.
        :discmodel: <DiscModel>, The discretized model.
        :wlist: 1D array, the frequency space.
        :smearing: float, smearing constant (divided by the number of z values before use).
        :mode: 'eval' to compare eigenvalues band by band, or 'pauli' to compare pauli components (2 band systems only).
    '''
    print('Start checking the mapping of discretized model!')
    t0=time.time()
    nband=discmodel.nband
    Elist,Tlist=discmodel.Elist,discmodel.Tlist
    zlist=discmodel.z
    nz=len(zlist)
    assert(mode=='pauli' or mode=='eval')
    #calculate GL
    AL=hybri_sun(tlist=Tlist,elist=Elist,wlist=wlist,smearing=smearing*1./nz)
    odatas=array([rhofunc(w) for w in wlist])
    if mode=='pauli':
        if nband!=2:
            raise Exception('Check pauli is for 2 band system, not %s!'%nband)
        plot_pauli_components(wlist,odatas,label='rho')
        plot_pauli_components(wlist,AL,method='scatter',label=r"\rho'")
    else:
        AV=array([eigvalsh(A) for A in AL]) if nband>1 else AL.reshape(AL.shape[:2])
        odatas=array([eigvalsh(o) for o in odatas]) if nband>1 else odatas.reshape(AL.shape[:2])
        # At least 4 colors keeps the palette unchanged for up to 4 bands,
        # more bands get more colors instead of an IndexError.
        colormap=cm.rainbow(linspace(0,0.8,max(4,nband)))
        plts=[]
        for i in range(nband):
            plts+=plot(wlist,odatas[:,i],lw=3,color=colormap[i])
        for i in range(nband):
            sct=scatter(wlist,AV[:,i],s=30,edgecolors=colormap[i],facecolors='none')
            plts.append(sct)
        legend(plts,[r'$\rho_%s$'%i for i in range(nband)]+[r"$\rho'_%s$"%i for i in range(nband)],ncol=2)
    xlabel('$\\omega$',fontsize=16)
    print('Time Elapsed: %s s'%(time.time()-t0))
Example 48
Project: dex   Author: Innixma   File: graph_helper.py    MIT License 5 votes vote down vote up
def graph_simple(x_list, y_list, names_list, title, y_label, x_label, savefig_name=""):
    """Plot several named series on one figure, optionally saving it.

    The i-th series is ``x_list[i]`` against ``y_list[i]``, labelled
    ``names_list[i]`` and drawn in its own rainbow colour.  The axis limits
    are fitted to the overall data range, a grid and an upper-left legend are
    added, and the figure is closed before returning.

    :param x_list: sequence of x-value arrays, one per series.
    :param y_list: sequence of y-value arrays; its length sets the number of
        series.
    :param names_list: legend label for each series.
    :param title: figure title.
    :param y_label: y-axis label.
    :param x_label: x-axis label.
    :param savefig_name: path to save the figure to; nothing is saved when it
        is the empty string.
    """
    colors = cm.rainbow(np.linspace(0, 1, len(y_list)))
    plt.figure()
    # Infinite sentinels: the data range is found correctly even when every
    # value is negative or larger than any fixed constant.
    max_y = max_x = float('-inf')
    min_y = min_x = float('inf')
    for i, y in enumerate(y_list):
        name = names_list[i]
        x = x_list[i]

        max_y = max(max_y, np.max(y))
        min_y = min(min_y, np.min(y))
        max_x = max(max_x, np.max(x))
        min_x = min(min_x, np.min(x))

        plt.plot(x, y, c=colors[i], label=name)

    plt.legend(loc=2)
    if len(y_list) > 0:
        # With no series there is no data range to apply.
        plt.xlim(min_x, max_x)
        plt.ylim(min_y, max_y)
    plt.grid(True)
    plt.ylabel(y_label)
    plt.xlabel(x_label)
    plt.title(title)
    if savefig_name != "":
        plt.savefig(savefig_name)
    plt.close()
Example 49
Project: GridCell-3D   Author: jianwen-xie   File: utils.py    MIT License 5 votes vote down vote up
def draw_heatmap_2D(data, vmin=None, vmax=None):
    """Show a 2-D array as a rainbow heat map on the current axes, axes hidden.

    :param data: 2-D sequence of numbers (iterated row by row).
    :param vmin: value mapped to the low end of the colormap; defaults to the
        smallest entry of ``data``.
    :param vmax: value mapped to the high end of the colormap; defaults to the
        largest entry of ``data``.
    """
    # ``pyplot.get_cmap`` is used because ``matplotlib.cm.get_cmap`` is not
    # available on recent matplotlib releases.
    cmap = plt.get_cmap('rainbow', 1000)

    # Each bound falls back to the data independently, so an explicit
    # ``vmax`` is no longer discarded when ``vmin`` is omitted.
    if vmin is None:
        vmin = min(min(row) for row in data)
    if vmax is None:
        vmax = max(max(row) for row in data)
    plt.imshow(data, interpolation='nearest', cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
    plt.axis('off')
Example 50
Project: GridCell-3D   Author: jianwen-xie   File: utils.py    MIT License 5 votes vote down vote up
def draw_path_integral(place_len, place_seq, col=(255, 0, 0), target=None):
    """Draw a path on a white square canvas, show it, and return the canvas.

    :param place_len: side length of the square canvas in pixels.
    :param place_seq: sequence of 2-D positions; values are rounded to
        integer pixel coordinates.
    :param col: colour of the path and of its marker dot.
    :param target: optional 2-D position.  When given, it is marked with a
        (0, 0, 255) dot and the path's first point is marked in ``col``;
        otherwise the path's last point is marked in ``col``.
        NOTE(review): the original body read ``target`` without defining it;
        it is exposed here as an optional argument - confirm no module-level
        ``target`` was intended.
    :return: the ``(place_len, place_len, 3)`` uint8 canvas.
    """
    place_seq = np.round(place_seq).astype(int)
    # ``pyplot.get_cmap`` is used because ``matplotlib.cm.get_cmap`` is not
    # available on recent matplotlib releases.
    cmap = plt.get_cmap('rainbow', 1000)

    def as_point(position):
        # Plain Python ints: OpenCV may reject NumPy integer scalars.
        return tuple(int(v) for v in position)

    canvas = np.ones((place_len, place_len, 3), dtype="uint8") * 255
    if target is not None:
        cv2.circle(canvas, as_point(target), 2, (0, 0, 255), -1)
        cv2.circle(canvas, as_point(place_seq[0]), 2, col, -1)
    else:
        cv2.circle(canvas, as_point(place_seq[-1]), 2, col, -1)
    for start, end in zip(place_seq[:-1], place_seq[1:]):
        cv2.line(canvas, as_point(start), as_point(end), col, 1)

    plt.imshow(np.swapaxes(canvas, 0, 1), interpolation='nearest', cmap=cmap, aspect='auto')
    return canvas