Python matplotlib.pyplot.colorbar() Examples

The following are 30 code examples of matplotlib.pyplot.colorbar(), each taken from an open-source project; the project, source file, and license are listed above each example. You can vote up the ones you like or vote down the ones you don't like. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: visualise_att_maps_epoch.py — from Attention-Gated-Networks (MIT License), 7 votes
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Plot every channel of a feature-map stack as a grid of images with colorbars.

    :param units: array indexed as ``units[:, :, i]`` for channel ``i``;
        ``units.shape[2]`` is the number of channels. Each 2-D slice is
        transposed before display.
    :param figure_id: matplotlib figure number to (re)use; it is cleared first.
    :param interp: ``imshow`` interpolation mode.
    :param colormap: colormap passed to ``imshow``.
    :param colormap_lim: optional ``(vmin, vmax)`` applied to every panel.
    """
    plt.ion()
    filters = units.shape[2]
    # Roughly square grid. max(1, ...) avoids a ZeroDivisionError when there
    # are no channels; int() keeps subplot() happy on Python 2 as well.
    n_columns = max(1, int(round(math.sqrt(filters))))
    # float() forces true division on Python 2 too. The "+ 1" extra (empty)
    # row is kept so the layout matches existing callers.
    n_rows = int(math.ceil(filters / float(n_columns))) + 1
    # figsize is (width, height): width scales with columns, height with rows.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # NOTE(review): tight_layout() recomputes subplot spacing and likely
    # overrides this subplots_adjust call.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Epochs 
Example #2
Source File: NavierStokes.py — from PINNs (MIT License), 7 votes
def plot_solution(X_star, u_star, index):
    """Resample scattered 2-D samples onto a regular grid and draw them.

    ``X_star`` holds the sample coordinates (column 0 is x, column 1 is y)
    and ``u_star`` the sampled values. The values are interpolated with cubic
    ``griddata`` onto a 200 x 200 grid spanning the bounding box of
    ``X_star``, then drawn with ``pcolor`` in figure number ``index`` with a
    colorbar.
    """
    resolution = 200
    lower, upper = X_star.min(0), X_star.max(0)
    grid_x, grid_y = np.meshgrid(
        np.linspace(lower[0], upper[0], resolution),
        np.linspace(lower[1], upper[1], resolution),
    )

    values = griddata(X_star, u_star.flatten(), (grid_x, grid_y), method='cubic')

    plt.figure(index)
    plt.pcolor(grid_x, grid_y, values, cmap='jet')
    plt.colorbar()
Example #3
Source File: dataset.py — from neural-combinatorial-optimization-rl-tensorflow (MIT License), 6 votes
def visualize_sampling(self, permutations):
        """Show how often each city is chosen at each decoding step.

        Column ``t`` of ``permutations`` (after transposing) holds the cities
        chosen at step ``t`` across the batch. Their counts are accumulated in
        a square ``len(permutations[0])``-sized grid (rows: step, columns:
        city) and rendered as a grayscale heatmap with a colorbar. As a side
        effect, the global ``rcParams`` font size is set to 22.
        """
        n_steps = len(permutations[0])
        counts_grid = np.zeros([n_steps, n_steps])  # rows: step t, columns: city i

        for step, chosen in enumerate(np.transpose(permutations)):
            cities, hits = np.unique(chosen, return_counts=True, axis=0)
            for city, hit in zip(cities, hits):
                counts_grid[step][city] += hit

        fig = plt.figure()
        rcParams.update({'font.size': 22})
        axes = fig.add_subplot(1, 1, 1)
        axes.set_aspect('equal')
        plt.imshow(counts_grid, interpolation='nearest', cmap='gray')
        plt.colorbar()
        plt.title('Sampled permutations')
        plt.ylabel('Time t')
        plt.xlabel('City i')
        plt.show()
Example #4
Source File: plot.py — from TaskBot (GNU General Public License v3.0), 6 votes
def plot_attention(sentences, attentions, labels, **kwargs):
    """Draw an attention matrix as a heatmap annotated with its tokens.

    :param sentences: nested sequence; ``sentences[i][j]`` is the text drawn
        in cell ``(i, j)``.
    :param attentions: 2-D array of attention weights.
    :param labels: one y-axis label per row.
    :param kwargs: forwarded to ``plt.subplots`` (e.g. ``figsize``).
    """
    # getChineseFont() is loop-invariant: build it once instead of once per
    # annotated cell (matplotlib copies font properties into each Text).
    font = getChineseFont()

    fig, ax = plt.subplots(**kwargs)
    im = ax.imshow(attentions, interpolation='nearest',
                   vmin=attentions.min(), vmax=attentions.max())
    # NOTE(review): ticks are fixed at 0 and 1; they only appear if they fall
    # inside [attentions.min(), attentions.max()].
    fig.colorbar(im, ax=ax, shrink=0.5, ticks=[0, 1])
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontproperties=font)

    # Annotate each cell with its token.
    n_rows, n_cols = attentions.shape[0], attentions.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, sentences[i][j],
                    ha="center", va="center", color="b", size=10,
                    fontproperties=font)

    ax.set_title("Attention Visual")
    fig.tight_layout()
    plt.show()
Example #5
Source File: prod_basis.py — from pyscf (Apache License 2.0), 6 votes
def generate_png_chess_dp_vertex(self):
    """Produces pictures of the dominant product vertex a chessboard convention

    Writes one PNG per matrix returned by
    ``self.get_dp_vertex_doubly_sparse()``, named ``chess-v-000000.png``,
    ``chess-v-000001.png``, ... (relative paths, so they go to the current
    working directory). Objects that provide ``toarray()`` (e.g. sparse
    matrices) are densified first. Interactive mode is switched off and left
    off; each figure is closed after saving.
    """
    import matplotlib.pylab as plt
    plt.ioff()
    dab2v = self.get_dp_vertex_doubly_sparse()
    for i, ab in enumerate(dab2v):
        fname = "chess-v-{:06d}.png".format(i)
        print('Matrix No.#{}, Size: {}, Type: {}'.format(i+1, ab.shape, type(ab)), fname)
        # Comparing type(ab) with a string is always unequal, which would
        # call toarray() on dense arrays and crash. Densify only objects that
        # actually support it.
        if hasattr(ab, 'toarray'):
            ab = ab.toarray()
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.set_aspect('equal')
        plt.imshow(ab, interpolation='nearest', cmap=plt.cm.ocean)
        plt.colorbar()
        plt.savefig(fname)
        plt.close(fig)
Example #6
Source File: test_mesh_io.py — from simnibs (GNU General Public License v3.0), 6 votes
def test_interpolate_grid_const_nn(self, sphere3_msh):
        """'assign' interpolation of element tag1 onto a 200 x 10 x 1 grid.

        The affine places grid index (0, 0, 0) at (-100.5, -5, 0) with unit
        spacing. Probe points along y index 5 are compared with expected tag
        values (the last probe expects 0).
        """
        tags = sphere3_msh.elm.tag1
        field = mesh_io.ElementData(tags, mesh=sphere3_msh)
        grid_shape = (200, 10, 1)
        affine = np.array([[1, 0, 0, -100.5],
                           [0, 1, 0, -5],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=float)
        interp = field.interpolate_to_grid(grid_shape, affine, method='assign')
        # Debug visualisation (uncomment to inspect the grid):
        #   import matplotlib.pyplot as plt
        #   plt.imshow(np.squeeze(interp))
        #   plt.colorbar()
        #   plt.show()
        #   assert False  # force a failure so the plot is shown
        expected = [(100, 3), (187, 4), (193, 5), (198, 0)]
        for x_index, tag in expected:
            assert np.isclose(interp[x_index, 5, 0], tag)
Example #7
Source File: test_mesh_io.py — from simnibs (GNU General Public License v3.0), 6 votes
def test_interpolate_grid_rotate_nn(self, sphere3_msh):
        """'assign' interpolation through a pi/4-rotated affine keeps quadrant labels.

        Elements are labelled 1-4 by the x-y quadrant of their barycenter
        (strict inequalities, so elements exactly on an axis stay 0). The
        field is then sampled on a 200 x 200 x 1 rotated grid and four probe
        points are checked.
        """
        data = np.zeros(sphere3_msh.elm.nr)
        bary = sphere3_msh.elements_baricenters().value
        field = mesh_io.ElementData(data, mesh=sphere3_msh)
        x_pos, y_pos = bary[:, 0], bary[:, 1]
        quadrants = [
            (1., (x_pos > 0) & (y_pos > 0)),
            (2., (x_pos < 0) & (y_pos > 0)),
            (3., (x_pos < 0) & (y_pos < 0)),
            (4., (x_pos > 0) & (y_pos < 0)),
        ]
        for label, mask in quadrants:
            field.value[mask] = label

        cos45, sin45 = np.cos(np.pi/4.), np.sin(np.pi/4.)
        affine = np.array([[cos45, sin45, 0, -141],
                           [-sin45, cos45, 0, 0],
                           [0, 0, 1, .5],
                           [0, 0, 0, 1]], dtype=float)
        interp = field.interpolate_to_grid((200, 200, 1), affine, method='assign')
        # Debug visualisation (uncomment to inspect the grid):
        #   import matplotlib.pyplot as plt
        #   plt.imshow(np.squeeze(interp))
        #   plt.colorbar()
        #   plt.show()
        probes = [((190, 100), 4), ((100, 190), 1), ((10, 100), 2), ((100, 10), 3)]
        for (row, col), label in probes:
            assert np.isclose(interp[row, col, 0], label)
Example #8
Source File: test_mesh_io.py — from simnibs (GNU General Public License v3.0), 6 votes
def test_interpolate_grid_rotate_nodedata(self, sphere3_msh):
        """Default interpolation of node data through a pi/4-rotated affine.

        Nodes are labelled 1-4 by x-y quadrant using inclusive inequalities,
        applied in order so later labels win on the axes. The field is sampled
        on a 200 x 200 x 1 rotated grid and four probe points are checked.
        """
        data = np.zeros(sphere3_msh.nodes.nr)
        coords = sphere3_msh.nodes.node_coord.copy()
        field = mesh_io.NodeData(data, mesh=sphere3_msh)
        x_pos, y_pos = coords[:, 0], coords[:, 1]
        # Order matters: the masks overlap on the axes.
        quadrants = [
            (1., (x_pos >= 0) & (y_pos >= 0)),
            (2., (x_pos <= 0) & (y_pos >= 0)),
            (3., (x_pos <= 0) & (y_pos <= 0)),
            (4., (x_pos >= 0) & (y_pos <= 0)),
        ]
        for label, mask in quadrants:
            field.value[mask] = label

        cos45, sin45 = np.cos(np.pi/4.), np.sin(np.pi/4.)
        affine = np.array([[cos45, sin45, 0, -141],
                           [-sin45, cos45, 0, 0],
                           [0, 0, 1, .5],
                           [0, 0, 0, 1]], dtype=float)
        interp = field.interpolate_to_grid((200, 200, 1), affine)
        # Debug visualisation (uncomment to inspect the grid):
        #   import matplotlib.pyplot as plt
        #   plt.imshow(np.squeeze(interp), interpolation='nearest')
        #   plt.colorbar()
        #   plt.show()
        probes = [((190, 100), 4), ((100, 190), 1), ((10, 100), 2), ((100, 10), 3)]
        for (row, col), label in probes:
            assert np.isclose(interp[row, col, 0], label)
Example #9
Source File: test_mesh_io.py — from simnibs (GNU General Public License v3.0), 6 votes
def test_interpolate_grid_elmdata_linear(self, sphere3_msh):
        """Linear, continuous interpolation of the barycenter x coordinate.

        The field is sampled on a 130 x 130 x 1 grid whose origin is shifted
        by -65 in x and y. Every sample must be within 1 of ``i - 64.5``,
        where ``i`` is the index along grid axis 0.
        """
        x_values = sphere3_msh.elements_baricenters().value[:, 0]
        field = mesh_io.ElementData(x_values, mesh=sphere3_msh)
        affine = np.array([[1, 0, 0, -65],
                           [0, 1, 0, -65],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=float)
        # np.indices(...)[0] equals meshgrid(arange, arange, indexing='ij')[0].
        row_index = np.indices((130, 130))[0]
        interp = field.interpolate_to_grid((130, 130, 1), affine, method='linear', continuous=True)
        # Debug visualisation (uncomment to inspect the grid):
        #   import matplotlib.pyplot as plt
        #   plt.figure()
        #   plt.imshow(np.squeeze(interp))
        #   plt.colorbar()
        #   plt.show()
        assert np.allclose(interp[:, :, 0], row_index - 64.5, atol=1)
Example #10
Source File: test_mesh_io.py — from simnibs (GNU General Public License v3.0), 6 votes
def test_interpolate_grid_elmdata_dicontinuous(self, sphere3_msh):
        """Linear, discontinuous interpolation of tag1 stays piecewise constant.

        The field is sampled on a 200 x 130 x 1 grid (y axis flipped by the
        affine). Three runs of x indices along y index 65 must match tags
        5, 4 and 3 to within 0.1.
        """
        field = mesh_io.ElementData(sphere3_msh.elm.tag1, mesh=sphere3_msh)
        affine = np.array([[1, 0, 0, -100.1],
                           [0,-1, 0, 65.1],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=float)
        interp = field.interpolate_to_grid((200, 130, 1), affine, method='linear', continuous=False)
        # Debug visualisation (uncomment to inspect the grid):
        #   import matplotlib.pyplot as plt
        #   plt.figure()
        #   plt.imshow(np.squeeze(interp))
        #   plt.colorbar()
        #   plt.show()
        center_line = interp[:, 65, 0]
        for start, stop, tag in ((6, 10, 5), (11, 15, 4), (16, 100, 3)):
            assert np.allclose(center_line[start:stop], tag, atol=1e-1)
Example #11
Source File: visualise_fmaps.py — from Attention-Gated-Networks (MIT License), 6 votes
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Plot every channel of a feature-map stack as a grid of images with colorbars.

    :param units: array indexed as ``units[:, :, i]`` for channel ``i``;
        ``units.shape[2]`` is the number of channels. Each 2-D slice is
        transposed before display.
    :param figure_id: matplotlib figure number to (re)use; it is cleared first.
    :param interp: ``imshow`` interpolation mode.
    :param colormap: colormap passed to ``imshow``.
    :param colormap_lim: optional ``(vmin, vmax)`` applied to every panel.
    """
    plt.ion()
    filters = units.shape[2]
    # Roughly square grid. max(1, ...) avoids a ZeroDivisionError when there
    # are no channels; int() keeps subplot() happy on Python 2 as well.
    n_columns = max(1, int(round(math.sqrt(filters))))
    # float() forces true division on Python 2 too. The "+ 1" extra (empty)
    # row is kept so the layout matches existing callers.
    n_rows = int(math.ceil(filters / float(n_columns))) + 1
    # figsize is (width, height): width scales with columns, height with rows.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # NOTE(review): tight_layout() recomputes subplot spacing and likely
    # overrides this subplots_adjust call.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Load options 
Example #12
Source File: visualise_attention.py — from Attention-Gated-Networks (MIT License), 6 votes
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Plot every channel of a feature-map stack as a grid of images with colorbars.

    :param units: array indexed as ``units[:, :, i]`` for channel ``i``;
        ``units.shape[2]`` is the number of channels. Each 2-D slice is
        transposed before display.
    :param figure_id: matplotlib figure number to (re)use; it is cleared first.
    :param interp: ``imshow`` interpolation mode.
    :param colormap: colormap passed to ``imshow``.
    :param colormap_lim: optional ``(vmin, vmax)`` applied to every panel.
    :param title: figure-level title, added with ``suptitle``.
    """
    plt.ion()
    filters = units.shape[2]
    # Roughly square grid. max(1, ...) avoids a ZeroDivisionError when there
    # are no channels; int() keeps subplot() happy on Python 2 as well.
    n_columns = max(1, int(round(math.sqrt(filters))))
    # float() forces true division on Python 2 too. The "+ 1" extra (empty)
    # row is kept so the layout matches existing callers.
    n_rows = int(math.ceil(filters / float(n_columns))) + 1
    # figsize is (width, height): width scales with columns, height with rows.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # NOTE(review): tight_layout() recomputes subplot spacing and likely
    # overrides this subplots_adjust call.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.suptitle(title)
Example #13
Source File: visualise_attention.py — from Attention-Gated-Networks (MIT License), 6 votes
def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay feature maps on a grayscale image in a single axes.

    :param input_im: background image; channel ``input_im[:, :, 0]`` is drawn in gray.
    :param units: maps indexed as ``units[:, :, i]``; each one is drawn over
        the background with opacity ``alpha``.
    :param figure_id: figure number to (re)use; it is cleared first.
    :param interp: ``imshow`` interpolation mode.
    :param colormap: colormap for the overlays.
    :param colormap_lim: optional ``(vmin, vmax)`` applied to every overlay.
    :param title: axes title.
    :param alpha: overlay opacity.
    """
    plt.ion()
    filters = units.shape[2]
    fig = plt.figure(figure_id, figsize=(5, 5))
    fig.clf()

    for i in range(filters):
        # The opaque background is redrawn before every overlay, which hides
        # earlier overlays where the images coincide.
        plt.imshow(input_im[:, :, 0], interpolation=interp, cmap='gray')
        plt.imshow(units[:, :, i], interpolation=interp, cmap=colormap, alpha=alpha)
        plt.axis('off')
        plt.title(title, fontsize='small')
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # One colorbar for the last overlay. Adding a colorbar on every loop
    # iteration would stack several of them next to the same axes, each one
    # shrinking the image further.
    if filters:
        plt.colorbar()

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

    # plt.savefig('{}/{}.png'.format(dir_name,time.time()))




## Load options 
Example #14
Source File: core.py — from prickle (MIT License), 6 votes
def imshow(data, which, levels):
    """
        Display order book data as an image, where order book data is either of
        `df_price` or `df_volume` returned by `load_hdf5` or `load_postgres`.

        :param data: DataFrame-like object indexed with ``.loc`` that has
            ``askprc.N``/``bidprc.N`` or ``askvol.N``/``bidvol.N`` columns.
        :param which: ``'prices'`` or ``'volumes'``.
        :param levels: number of book levels per side to show.
        :raises ValueError: if ``which`` is neither ``'prices'`` nor ``'volumes'``.
    """
    prefixes = {'prices': ('askprc.', 'bidprc.'),
                'volumes': ('askvol.', 'bidvol.')}
    if which not in prefixes:
        raise ValueError("which must be 'prices' or 'volumes', got {!r}".format(which))
    ask, bid = prefixes[which]

    # Ask columns from level `levels` down to 1, then bid columns from 1 up to
    # `levels`, so level 1 of both sides sits in the middle of the image.
    idx = [ask + str(i) for i in range(levels, 0, -1)]
    idx.extend(bid + str(i) for i in range(1, levels + 1))
    plt.imshow(data.loc[:, idx].T, interpolation='nearest', aspect='auto')
    plt.yticks(range(0, levels * 2, 1), idx)
    plt.colorbar()
    plt.tight_layout()
    plt.show()
Example #15
Source File: pixel.py — from yatsm (MIT License), 6 votes
def plot_DOY(dates, y, mpl_cmap):
    """ Create a DOY plot

    Scatter ``y`` against day of year, coloured by calendar year, with a
    colorbar and minor month labels on the x axis (limited to 1-366).

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap
    """
    doy = np.array([d.timetuple().tm_yday for d in dates])
    year = np.array([d.year for d in dates])

    sp = plt.scatter(doy, y, c=year, cmap=mpl_cmap,
                     marker='o', edgecolors='none', s=35)
    plt.colorbar(sp)

    # NOTE(review): x values are day-of-year numbers, which MonthLocator and
    # DateFormatter interpret as matplotlib date numbers; month tick alignment
    # therefore depends on matplotlib's date epoch. Confirm.
    months = mpl.dates.MonthLocator()  # every month
    months_fmt = mpl.dates.DateFormatter('%b')

    # Configure the axes holding the scatter. plt.axes() without arguments
    # would add a new, overlapping axes instead of returning this one.
    ax = plt.gca()
    ax.tick_params(axis='x', which='minor', direction='in', pad=-10)
    ax.xaxis.set_minor_locator(months)
    ax.xaxis.set_minor_formatter(months_fmt)

    plt.xlim(1, 366)
    plt.xlabel('Day of Year')
Example #16
Source File: pixel.py — from yatsm (MIT License), 6 votes
def plot_VAL(dates, y, mpl_cmap, reps=2):
    """ Create a "Valerie Pasquarella" plot (repeated DOY plot)

    The day-of-year scatter is drawn ``reps + 1`` times. Copy ``r`` is
    shifted right by ``r * 366`` days, and points are coloured by calendar
    year with a colorbar.

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap
        reps (int, optional): number of additional repetitions
    """
    doy = np.array([d.timetuple().tm_yday for d in dates])
    year = np.array([d.year for d in dates])

    copies = reps + 1
    # The unshifted series followed by one shifted copy per extra repetition.
    shifted_doy = np.concatenate([doy] + [doy + r * 366 for r in range(1, reps + 1)])
    tiled_year = np.tile(year, copies)
    tiled_y = np.tile(y, copies)

    points = plt.scatter(shifted_doy, tiled_y, c=tiled_year, cmap=mpl_cmap,
                         marker='o', edgecolors='none', s=35)
    plt.colorbar(points)
    plt.xlabel('Day of Year')
Example #17
Source File: plotting.py — from qb (MIT License), 6 votes
def plot_confusion(title, true_labels, predicted_labels, normalized=True):
    """Plot a confusion matrix as a heatmap and return the figure and axes.

    :param title: axes title.
    :param true_labels: ground-truth labels.
    :param predicted_labels: predicted labels, aligned with ``true_labels``.
    :param normalized: if true, divide each row by its total so each row sums
        to 1. Rows with no samples (labels seen only in predictions) stay 0.
    :returns: ``(fig, ax)``.
    """
    # Sorted so the axis order is deterministic between runs; iterating a set
    # of strings depends on hash randomisation.
    labels = sorted(set(true_labels) | set(predicted_labels))

    matrix = confusion_matrix(true_labels, predicted_labels, labels=labels)
    if normalized:
        row_totals = matrix.sum(axis=1, keepdims=True)
        # Avoid 0/0 (NaN plus a RuntimeWarning) for all-zero rows.
        matrix = np.divide(matrix.astype('float'), row_totals,
                           out=np.zeros(matrix.shape, dtype=float),
                           where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(title)
    # No colorbar is drawn; add fig.colorbar(<image>, ax=ax) if one is wanted.
    tick_marks = np.arange(len(labels))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(labels)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.grid(False)
    return fig, ax
Example #18
Source File: ephys_qc_raw.py — from ibllib (MIT License), 6 votes
def _plot_rmsmap(outfil, typ, savefig=True):
    """Plot an ephys RMS quality-control map and optionally save it as a PNG.

    The figure has three panels: the per-channel median RMS (left), the RMS
    over time per channel as ``20 * log10(rms)`` (centre), and a colorbar
    (right).

    :param outfil: folder passed to ``alf.io.load_object`` to load the
        ``_iblqc_ephysTimeRms<TYP>`` object; the PNG is written to the same
        folder as ``<typ>_rms.png``.
    :param typ: ``'ap'`` or ``'lf'``; selects the object and the colour and
        axis limits.
    :param savefig: if true, save the figure at 150 dpi.
    """
    from pathlib import Path

    # Build the path from the ``outfil`` argument rather than a module-level
    # ``outpath`` name, which is not defined by this function.
    outpath = Path(outfil)
    rmsmap = alf.io.load_object(outpath, '_iblqc_ephysTimeRms' + typ.upper())
    n_channels = rmsmap['rms'].shape[1]

    plt.figure(figsize=[12, 4.5])
    axim = plt.axes([0.2, 0.1, 0.7, 0.8])
    axrms = plt.axes([0.05, 0.1, 0.15, 0.8])
    axcb = plt.axes([0.92, 0.1, 0.02, 0.8])

    # Median over axis 0 (time) per channel, scaled by 1e6 (presumably
    # V -> uV; TODO confirm units). The last channel is dropped so the trace
    # matches y positions 1 .. n_channels - 1.
    axrms.plot(np.median(rmsmap['rms'], axis=0)[:-1] * 1e6, np.arange(1, n_channels))
    axrms.set_ylim(0, n_channels)

    # The 1e-15 offset avoids log10(0) for silent channels.
    im = axim.imshow(20 * np.log10(rmsmap['rms'].T + 1e-15), aspect='auto', origin='lower',
                     extent=[rmsmap['timestamps'][0], rmsmap['timestamps'][-1],
                             0, n_channels])
    axim.set_xlabel(r'Time (s)')
    axim.set_ylabel(r'Channel Number')
    plt.colorbar(im, cax=axcb)
    # Fixed colour range and (inverted) RMS axis limits per band.
    if typ == 'ap':
        im.set_clim(-110, -90)
        axrms.set_xlim(100, 0)
    elif typ == 'lf':
        im.set_clim(-100, -60)
        axrms.set_xlim(500, 0)
    axim.set_xlim(0, 4000)
    if savefig:
        plt.savefig(outpath / (typ + '_rms.png'), dpi=150)
Example #19
Source File: preprocessing.py — from Geocoding-with-Map-Vector (GNU General Public License v3.0), 6 votes
def visualise_2D_grid(x, title, log=False):
    """
    Display 2D array data with a title. Optional: log for better visualisation of small values.
    :param x: 2D numpy array you want to visualise
    :param title: of the chart because it's nice to have one :-)
    :param log: True in order to log the values and make for better visualisation, False for raw numbers
    """
    if log:
        # Zeros become -inf, which imshow treats as invalid and draws in the
        # "bad" colour (white, set below), so the divide warning is expected noise.
        with np.errstate(divide='ignore'):
            x = np.log10(x)
    cmap = colors.LinearSegmentedColormap.from_list('my_colormap', ['lightgrey', 'darkgrey', 'dimgrey', 'black'])
    cmap.set_bad(color='white')
    img = plt.imshow(x, cmap=cmap, interpolation='nearest')
    # The colorbar takes its colormap from the image; no cmap argument needed.
    plt.colorbar(img)
    plt.title(title)
    # plt.savefig(title + u".png", dpi=200, transparent=True)  # Uncomment to save to file
    plt.show()
Example #20
Source File: Plotter.py — from nmp_qc (MIT License), 5 votes
def plot_graph(self, am, position=None, cls=None, fig_name='graph.png'):
        """Draw the graph given by adjacency matrix ``am`` and save it in ``self.plotdir``.

        :param am: adjacency matrix (numpy array) passed to networkx.
        :param position: optional node layout; a circular layout is used when omitted.
        :param cls: optional per-node values mapped through a red-to-blue
            colormap normalised to [0, 1] (a colorbar is added); all nodes are
            drawn red when omitted.
        :param fig_name: file name of the saved image inside ``self.plotdir``.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")

            # from_numpy_matrix was removed in networkx 3.0; from_numpy_array
            # (networkx >= 2.0) replaces it. Fall back for older versions.
            from_numpy = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
            g = from_numpy(am)

            if position is None:
                position = nx.drawing.circular_layout(g)

            fig = plt.figure()
            # Create the axes once. A second fig.add_subplot(111) call adds
            # another, overlapping axes in recent matplotlib, which would put
            # the colorbar and the graph on different axes.
            ax = fig.add_subplot(111)

            if cls is None:
                cls = 'r'
            else:
                # User-defined red -> blue colormap.
                cm1 = mcol.LinearSegmentedColormap.from_list("MyCmapName", ["r", "b"])
                # Map node values from [0, 1] onto the colormap.
                cnorm = mcol.Normalize(vmin=0, vmax=1)
                # ScalarMappable converts values to colours and can be passed
                # to plt.colorbar().
                cpick = cm.ScalarMappable(norm=cnorm, cmap=cm1)
                cpick.set_array([])
                cls = cpick.to_rgba(cls)
                plt.colorbar(cpick, ax=ax)

            nx.draw(g, pos=position, node_color=cls, ax=ax)

            fig.savefig(os.path.join(self.plotdir, fig_name))
            # This method only writes the file; release the figure.
            plt.close(fig)
Example #21
Source File: utils.py — from sklearn-audio-transfer-learning (ISC License), 5 votes
def matrix_visualization(matrix,title=None):
    """ Visualize 2D matrices like spectrograms or feature maps.

    The matrix is transposed and flipped vertically before display, so the
    first row of the image corresponds to the last column of ``matrix``.
    A colorbar is added, an optional title is set, and the figure is shown.
    """
    plt.figure()
    plt.imshow(np.flipud(matrix.T),interpolation=None)
    plt.colorbar()
    if title is not None:
        plt.title(title)
    plt.show()
Example #22
Source File: dataset.py — from neural-combinatorial-optimization-rl-tensorflow (MIT License), 5 votes
def visualize_sampling(self, permutations):
        """Show how often each city is chosen at each decoding step.

        Column ``t`` of ``permutations`` (after transposing) holds the cities
        chosen at step ``t`` across the batch. Their counts are accumulated in
        a square ``len(permutations[0])``-sized grid (rows: step, columns:
        city) and rendered as a grayscale heatmap with a colorbar. As a side
        effect, the global ``rcParams`` font size is set to 22.
        """
        n_steps = len(permutations[0])
        counts_grid = np.zeros([n_steps, n_steps])  # rows: step t, columns: city i

        for step, chosen in enumerate(np.transpose(permutations)):
            cities, hits = np.unique(chosen, return_counts=True, axis=0)
            for city, hit in zip(cities, hits):
                counts_grid[step][city] += hit

        fig = plt.figure()
        rcParams.update({'font.size': 22})
        axes = fig.add_subplot(1, 1, 1)
        axes.set_aspect('equal')
        plt.imshow(counts_grid, interpolation='nearest', cmap='gray')
        plt.colorbar()
        plt.title('Sampled permutations')
        plt.ylabel('Time t')
        plt.xlabel('City i')
        plt.show()

    # Heatmap of attention (x=cities; y=steps) 
Example #23
Source File: dataset.py — from neural-combinatorial-optimization-rl-tensorflow (MIT License), 5 votes
def visualize_attention(self, attention):
        """Render an attention matrix as a heatmap with a colorbar.

        Rows are labelled 'Step t' and columns 'Attention_t'. As a side
        effect, the global ``rcParams`` font size is set to 22.
        """
        fig = plt.figure()
        rcParams.update({'font.size': 22})
        axes = fig.add_subplot(1, 1, 1)
        axes.set_aspect('equal')

        plt.imshow(attention, interpolation='nearest', cmap='hot')
        plt.colorbar()

        plt.title('Attention distribution')
        plt.ylabel('Step t')
        plt.xlabel('Attention_t')
        plt.show()
Example #24
Source File: bdk_demo.py — from dynamic-training-with-apache-mxnet-on-aws (Apache License 2.0), 5 votes
def run_synthetic_SGLD():
    """Sample a 2-parameter synthetic posterior with the 'sgld' optimizer and plot it.

    Draws 100 synthetic points, then runs 1,000,000 updates. Each update uses
    one randomly chosen point, with its gradient rescaled by
    ``N / minibatch_size``. Every iterate of ``theta`` is stored, and the
    samples are shown as a 200 x 200 ``hist2d`` with a colorbar. Timing is
    printed every 100,000 iterations.
    """
    theta1 = 0
    theta2 = 1
    sigma1 = numpy.sqrt(10)
    sigma2 = 1
    sigmax = numpy.sqrt(2)
    X = load_synthetic(theta1=theta1, theta2=theta2, sigmax=sigmax, num=100)
    minibatch_size = 1
    total_iter_num = 1000000
    lr_scheduler = SGLDScheduler(begin_rate=0.01, end_rate=0.0001, total_iter_num=total_iter_num,
                                 factor=0.55)
    optimizer = mx.optimizer.create('sgld',
                                    learning_rate=None,
                                    rescale_grad=1.0,
                                    lr_scheduler=lr_scheduler,
                                    wd=0)
    updater = mx.optimizer.get_updater(optimizer)
    theta = mx.random.normal(0, 1, (2,), mx.cpu())
    grad = nd.empty((2,), mx.cpu())
    samples = numpy.zeros((2, total_iter_num))
    # Loop-invariant: scale a single-sample gradient up to the full data set.
    rescale = X.shape[0] / float(minibatch_size)
    start = time.time()
    # range, not the Python-2-only xrange, so this also runs on Python 3.
    for i in range(total_iter_num):
        if (i + 1) % 100000 == 0:
            end = time.time()
            print("Iter:%d, Time spent: %f" % (i + 1, end - start))
            start = time.time()
        ind = numpy.random.randint(0, X.shape[0])
        synthetic_grad(X[ind], theta, sigma1, sigma2, sigmax, rescale_grad=rescale, grad=grad)
        updater('theta', grad, theta)
        samples[:, i] = theta.asnumpy()
    plt.hist2d(samples[0, :], samples[1, :], (200, 200), cmap=plt.cm.jet)
    plt.colorbar()
    plt.show()
Example #25
Source File: plot_lfads.py — from DOTA_models (Apache License 2.0), 5 votes
def _plot_item(W, name, full_name, nspaces):
  """Visualise one model variable ``W`` in a new figure.

  Scalars are printed instead of plotted, so no empty figure is opened.
  1-D arrays and arrays with a single row or column are drawn as a stem
  plot. Anything else is drawn as an image of ``abs(W)`` with a colorbar.

  :param W: numpy array holding the variable's value.
  :param name: short name, used when printing a scalar.
  :param full_name: plot title.
  :param nspaces: unused; kept for interface compatibility.
  """
  if W.shape == ():
    print(name, ": ", W)
    return
  plt.figure()
  # ndim is checked first so 1-D arrays never reach W.shape[1] (IndexError).
  if W.ndim == 1 or W.shape[0] == 1 or W.shape[1] == 1:
    # Row vector, column vector or 1-D array: flatten for stem().
    plt.stem(np.ravel(W))
    plt.title(full_name)
  else:
    plt.imshow(np.abs(W), interpolation='nearest', cmap='jet')
    plt.colorbar()
    plt.title(full_name)
Example #26
Source File: plot_lfads.py — from DOTA_models (Apache License 2.0), 5 votes
def plot_priors():
  """Compare prior and posterior g0 statistics from ``train_modelvals``.

  Reads ``prior_g0_mean``, ``prior_g0_var``, ``posterior_g0_mean`` and
  ``posterior_g0_var`` from the module-level ``train_modelvals`` and draws
  three figures:
  - overlaid histograms (posterior blue, prior green) of means and variances;
  - a 2 x 2 grid of transposed heatmaps with colorbars;
  - a stem plot of the sorted log std of the posterior means.
  """
  prior_mean = train_modelvals['prior_g0_mean']
  prior_var = train_modelvals['prior_g0_var']
  post_mean = train_modelvals['posterior_g0_mean']
  post_var = train_modelvals['posterior_g0_var']

  plt.figure(figsize=(10,4), tight_layout=True)
  histograms = [
      (post_mean, prior_mean, 'Histogram of Prior/Posterior Mean Values'),
      (post_var, prior_var, 'Histogram of Prior/Posterior Log Variance Values'),
  ]
  for panel, (post, prior, caption) in enumerate(histograms, start=1):
    plt.subplot(1, 2, panel)
    plt.hist(post.flatten(), bins=20, color='b')
    plt.hist(prior.flatten(), bins=20, color='g')
    plt.title(caption)

  plt.figure(figsize=(10,10), tight_layout=True)
  heatmaps = [
      (prior_mean, 'Prior g0 means'),
      (post_mean, 'Posterior g0 means'),
      (prior_var, 'Prior g0 variance Values'),
      (post_var, 'Posterior g0 variance Values'),
  ]
  for panel, (values, caption) in enumerate(heatmaps, start=1):
    plt.subplot(2, 2, panel)
    plt.imshow(values.T, interpolation='nearest', cmap='jet')
    plt.colorbar(fraction=0.025, pad=0.04)
    plt.title(caption)

  plt.figure(figsize=(10,5))
  plt.stem(np.sort(np.log(post_mean.std(axis=0))))
  plt.title('Log standard deviation of h0 means')
Example #27
Source File: plot.py — from TaskBot (GNU General Public License v3.0), 5 votes
def plot_confusion_matrix(y_true, y_test, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.

    :param y_true: ground-truth labels.
    :param y_test: predicted labels, compared against ``y_true``.
    :param classes: tick labels for both axes, in matrix order.
    :param normalize: divide each row by its total; all-zero rows stay 0.
    :param title: plot title.
    :param cmap: colormap for the heatmap.
    """
    cm = confusion_matrix(y_true, y_test)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A label that only occurs in predictions has an all-zero row; keep it
        # at 0 instead of producing NaN from 0/0.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate each cell: white text on cells above half the maximum, black otherwise.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    # Set the axis labels before tight_layout so they are included in the
    # layout computation and not clipped.
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
Example #28
Source File: audio.py — from argus-freesound (MIT License), 5 votes
def show_melspectrogram(mels, title='Log-frequency power spectrogram'):
    """Display ``mels`` with ``librosa.display.specshow`` plus a dB-formatted colorbar.

    Time/mel axis scaling uses ``sampling_rate``, ``hop_length``, ``fmin``
    and ``fmax`` from the module-level ``config``.
    """
    import matplotlib.pyplot as plt

    display_kwargs = dict(
        x_axis='time',
        y_axis='mel',
        sr=config.sampling_rate,
        hop_length=config.hop_length,
        fmin=config.fmin,
        fmax=config.fmax,
    )
    librosa.display.specshow(mels, **display_kwargs)
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.show()
Example #29
Source File: plotGraphs.py — from peershark (MIT License), 5 votes
def plotGraph(x, y, z, filename):
	"""Save a translucent scatter of ``y`` against ``x`` to ``filename``.

	The axis limits are fixed to x in [0, 5000] and y in [0, 200]. The current
	figure is cleared afterwards. ``z`` is not used.
	"""
	# NOTE(review): no ``c=`` is passed, so ``cmap`` has no effect (and a
	# colorbar would have nothing to show). Perhaps ``c=z`` was intended; confirm.
	plt.axis([0, 5000, 0, 200])
	plt.scatter(x, y, alpha=0.10, cmap=plt.cm.cool, edgecolors='None')
	pylab.savefig(filename, bbox_inches=0)
	plt.clf()

#scale input data and plot graphs 
Example #30
Source File: eval.py — from tartarus (MIT License), 5 votes
def plot_confusion_matrix(cm, labels, title='Confusion matrix', cmap=plt.cm.Blues):
        """Draw a precomputed confusion matrix with a colorbar and class tick labels.

        :param cm: 2-D confusion matrix; rows are labelled 'True label' and
            columns 'Predicted label'.
        :param labels: tick labels for both axes, in matrix order.
        :param title: not drawn (the title call is disabled); kept for
            interface compatibility.
        :param cmap: colormap for the heatmap.
        """
        csfont = {'fontname': 'Times', 'fontsize': '17'}
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        #plt.title(title, **csfont)
        plt.colorbar()
        tick_marks = np.arange(len(labels))
        plt.xticks(tick_marks, labels, rotation=90)
        plt.yticks(tick_marks, labels)
        # Set the axis labels before tight_layout so they are included in the
        # layout computation and not clipped.
        plt.ylabel('True label', **csfont)
        plt.xlabel('Predicted label', **csfont)
        plt.tight_layout()