Python matplotlib.pyplot.axis() Examples

The following are 30 code examples of matplotlib.pyplot.axis(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: __init__.py    From EDeN with MIT License 11 votes vote down vote up
def plot_confusion_matrix(y_true, y_pred, size=None, normalize=False):
    """Draw the confusion matrix of ``y_pred`` against ``y_true`` as a heatmap.

    Args:
        y_true: iterable of true labels.
        y_pred: iterable of predicted labels.
        size: if not None, a new square figure of ``size`` x ``size`` is
            created before plotting.
        normalize: if True, each row is divided by its row sum and cells are
            formatted with two decimals instead of as integers.
    """
    cm = confusion_matrix(y_true, y_pred)
    fmt = "%d"
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = "%.2f"
    # Both axes must carry the same label set: deriving the x labels from
    # y_pred only and the y labels from y_true only yields too few tick
    # labels whenever one side is missing a class.
    # NOTE(review): assumes confusion_matrix orders rows/columns by the sorted
    # union of labels (as scikit-learn's does) -- confirm against the import.
    labels = sorted(set(y_true) | set(y_pred))
    if size is not None:
        plt.figure(figsize=(size, size))
    heatmap(cm, xlabel='Predicted label', ylabel='True label',
            xticklabels=labels, yticklabels=labels,
            cmap=plt.cm.Blues, fmt=fmt)
    if normalize:
        plt.title("Confusion matrix (norm.)")
    else:
        plt.title("Confusion matrix")
    plt.gca().invert_yaxis()
Example #2
Source File: movie.py    From kvae with MIT License 10 votes vote down vote up
def save_frames(images, filename):
    """Animate a batch of image sequences and save the animation to a file.

    Args:
        images: 4-D array; axis 1 is iterated as the time step. For every
            step, ``images[:, step]`` is merged into one frame with
            ``combine_multiple_img``.
        filename: output path handed to ``Animation.save``.

    Either avconv or ffmpeg needs to be installed on the system to produce
    the video.
    """
    # Unpacking (rather than indexing) keeps the original 4-D shape check.
    _, n_steps, _, _ = images.shape

    fig = plt.figure()
    try:
        im = plt.imshow(combine_multiple_img(images[:, 0]),
                        cmap=plt.cm.get_cmap('Greys'), interpolation='none')
        plt.axis('image')

        def updatefig(*args):
            # args[0] is the frame index supplied by FuncAnimation.
            im.set_array(combine_multiple_img(images[:, args[0]]))
            return im,

        ani = animation.FuncAnimation(fig, updatefig, interval=500, frames=n_steps)

        # Older matplotlib raises KeyError for an unavailable writer, newer
        # releases raise RuntimeError; fall back to ffmpeg in both cases.
        try:
            writer_cls = animation.writers['avconv']
        except (KeyError, RuntimeError):
            writer_cls = animation.writers['ffmpeg']
        ani.save(filename, writer=writer_cls(fps=3))
    finally:
        # Always release the figure, even if encoding fails.
        plt.close(fig)
Example #3
Source File: data_augmentation.py    From Sound-Recognition-Tutorial with Apache License 2.0 10 votes vote down vote up
def demo_plot():
    """Plot a sample waveform next to its time-stretched and pitch-shifted versions.

    Loads a fixed ESC-10 audio file at 44.1 kHz, applies the two librosa
    augmentations and shows the three waveforms stacked in one figure with
    identical axis limits.
    """
    audio = './data/esc10/audio/Dog/1-30226-A.ogg'
    y, sr = librosa.load(audio, sr=44100)
    # n_steps controls how far the pitch is shifted. ``sr`` is passed by
    # keyword because newer librosa releases make it keyword-only.
    y_ps = librosa.effects.pitch_shift(y, sr=sr, n_steps=6)
    # rate controls how strongly the time axis is stretched.
    y_ts = librosa.effects.time_stretch(y, rate=1.2)

    panels = (
        (311, y, 'Original waveform'),
        (312, y_ts, 'Time Stretch transformed waveform'),
        (313, y_ps, 'Pitch Shift transformed waveform'),
    )
    for position, signal, title in panels:
        plt.subplot(position)
        plt.plot(signal)
        plt.title(title)
        plt.axis([0, 200000, -0.4, 0.4])
    plt.tight_layout()
    plt.show()
Example #4
Source File: __init__.py    From EDeN with MIT License 7 votes vote down vote up
def plot_roc_curve(y_true, y_score, size=None):
    """Plot the ROC curve of ``y_score`` against ``y_true``.

    When ``size`` is given, a new ``size`` x ``size`` figure with equal axis
    scaling is opened first. The AUC is reported in the plot title.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    if size is not None:
        plt.figure(figsize=(size, size))
        plt.axis('equal')
    # ROC curve itself, then the chance diagonal for reference.
    plt.plot(fpr, tpr, lw=2, color='navy')
    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    limits = [-0.05, 1.05]
    plt.ylim(list(limits))
    plt.xlim(list(limits))
    plt.grid()
    auc = roc_auc_score(y_true, y_score)
    plt.title('Receiver operating characteristic AUC={0:0.2f}'.format(auc))
Example #5
Source File: movie.py    From kvae with MIT License 7 votes vote down vote up
def save_movie_to_frame(images, filename, idx=0, cmap='Blues'):
    """Collapse movie ``images[idx]`` into one image and save it as a PNG.

    The image is drawn on a 12x12 figure without ticks, with the colour range
    fixed to [0, 1], and written to ``filename`` at 80 dpi.
    """
    frame = movie_to_frame(images[idx])

    figure = plt.figure(figsize=[12, 12])
    colormap = plt.cm.get_cmap(cmap)
    plt.imshow(frame, cmap=colormap, interpolation='none', vmin=0, vmax=1)

    plt.axis('image')
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(figure)
Example #6
Source File: visualise_att_maps_epoch.py    From Attention-Gated-Networks with MIT License 7 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show every slice ``units[:, :, i]`` (transposed) in a grid of subplots.

    The figure ``figure_id`` is cleared and reused. Each subplot gets its own
    colorbar; ``colormap_lim``, when truthy, supplies the (low, high) colour
    limits applied to each image.
    """
    plt.ion()
    n_filters = units.shape[2]
    grid_cols = round(math.sqrt(n_filters))
    grid_rows = math.ceil(n_filters / grid_cols) + 1
    figure = plt.figure(figure_id, figsize=(grid_rows * 3, grid_cols * 3))
    figure.clf()

    for index in range(n_filters):
        axes = plt.subplot(grid_rows, grid_cols, index + 1)
        plt.imshow(units[:, :, index].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Keep the axes frame but drop the tick labels.
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            low, high = colormap_lim[0], colormap_lim[1]
            plt.clim(low, high)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Epochs 
Example #7
Source File: recall.py    From AerialDetection with Apache License 2.0 6 votes vote down vote up
def plot_num_recall(recalls, proposal_nums):
    """Plot Proposal_num-Recalls curve.

    Args:
        recalls(ndarray or list): shape (k,)
        proposal_nums(ndarray or list): same shape as `recalls`
    """
    # Normalise both inputs to plain lists so the origin point can be
    # prepended with list concatenation.
    if isinstance(proposal_nums, np.ndarray):
        _proposal_nums = proposal_nums.tolist()
    else:
        _proposal_nums = list(proposal_nums)
    if isinstance(recalls, np.ndarray):
        _recalls = recalls.tolist()
    else:
        _recalls = list(recalls)

    import matplotlib.pyplot as plt
    f = plt.figure()
    plt.plot([0] + _proposal_nums, [0] + _recalls)
    plt.xlabel('Proposal num')
    plt.ylabel('Recall')
    # Take the maximum of the normalised list: calling ``.max()`` on the raw
    # argument fails for the list input the docstring allows.
    plt.axis([0, max(_proposal_nums), 0, 1])
    f.show()
Example #8
Source File: visualise_attention.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay each slice ``units[:, :, i]`` on the first channel of ``input_im``.

    Figure ``figure_id`` (5x5) is cleared and reused. The grayscale input is
    drawn first, then the slice on top with transparency ``alpha``, a colorbar
    and ``title``. ``colormap_lim``, when truthy, supplies the colour limits.
    """
    plt.ion()
    n_filters = units.shape[2]
    figure = plt.figure(figure_id, figsize=(5, 5))
    figure.clf()

    for index in range(n_filters):
        background = input_im[:, :, 0]
        overlay = units[:, :, index]
        plt.imshow(background, interpolation=interp, cmap='gray')
        plt.imshow(overlay, interpolation=interp, cmap=colormap, alpha=alpha)
        plt.axis('off')
        plt.colorbar()
        plt.title(title, fontsize='small')
        if colormap_lim:
            low, high = colormap_lim[0], colormap_lim[1]
            plt.clim(low, high)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

    # plt.savefig('{}/{}.png'.format(dir_name,time.time()))




## Load options 
Example #9
Source File: recall.py    From AerialDetection with Apache License 2.0 6 votes vote down vote up
def plot_iou_recall(recalls, iou_thrs):
    """Plot IoU-Recalls curve.

    Args:
        recalls(ndarray or list): shape (k,)
        iou_thrs(ndarray or list): same shape as `recalls`
    """
    # Normalise both inputs to plain lists so the end point can be appended
    # with list concatenation.
    if isinstance(iou_thrs, np.ndarray):
        _iou_thrs = iou_thrs.tolist()
    else:
        _iou_thrs = list(iou_thrs)
    if isinstance(recalls, np.ndarray):
        _recalls = recalls.tolist()
    else:
        _recalls = list(recalls)

    import matplotlib.pyplot as plt
    f = plt.figure()
    plt.plot(_iou_thrs + [1.0], _recalls + [0.])
    plt.xlabel('IoU')
    plt.ylabel('Recall')
    # Take the minimum of the normalised list: calling ``.min()`` on the raw
    # argument fails for the list input the docstring allows.
    plt.axis([min(_iou_thrs), 1, 0, 1])
    f.show()
Example #10
Source File: visualize.py    From dataiku-contrib with Apache License 2.0 6 votes vote down vote up
def display_images(images, titles=None, cols=4, cmap=None, norm=None,
                   interpolation=None):
    """Show a set of images in a grid, optionally titled.

    images: list or array of image tensors in HWC format.
    titles: optional list of titles, one per image.
    cols: number of images per row.
    cmap: optional color map to use, e.g. "Blues".
    norm: optional Normalize instance mapping values to colors.
    interpolation: optional image interpolation used for display.
    """
    if titles is None:
        titles = [""] * len(images)
    n_rows = len(images) // cols + 1
    plt.figure(figsize=(14, 14 * n_rows // cols))
    for position, (image, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(n_rows, cols, position)
        plt.title(title, fontsize=9)
        plt.axis('off')
        plt.imshow(image.astype(np.uint8), cmap=cmap,
                   norm=norm, interpolation=interpolation)
    plt.show()
Example #11
Source File: recall.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def plot_num_recall(recalls, proposal_nums):
    """Plot Proposal_num-Recalls curve.

    Args:
        recalls(ndarray or list): shape (k,)
        proposal_nums(ndarray or list): same shape as `recalls`
    """
    # Normalise both inputs to plain lists so the origin point can be
    # prepended with list concatenation.
    if isinstance(proposal_nums, np.ndarray):
        _proposal_nums = proposal_nums.tolist()
    else:
        _proposal_nums = list(proposal_nums)
    if isinstance(recalls, np.ndarray):
        _recalls = recalls.tolist()
    else:
        _recalls = list(recalls)

    import matplotlib.pyplot as plt
    f = plt.figure()
    plt.plot([0] + _proposal_nums, [0] + _recalls)
    plt.xlabel('Proposal num')
    plt.ylabel('Recall')
    # Take the maximum of the normalised list: calling ``.max()`` on the raw
    # argument fails for the list input the docstring allows.
    plt.axis([0, max(_proposal_nums), 0, 1])
    f.show()
Example #12
Source File: tsne_visualizer.py    From linguistic-style-transfer with Apache License 2.0 6 votes vote down vote up
def plot_coordinates(coordinates, plot_path, markers, label_names, fig_num):
    """Scatter-plot labelled groups of 2-D coordinates and save them as SVG.

    Consecutive entries of ``markers`` delimit the row range of each group in
    ``coordinates``; group ``i`` is labelled ``label_names[i]`` and cycles
    through the module-level ``plot_markers`` and ``colors``.
    """
    matplotlib.use('svg')
    import matplotlib.pyplot as plt

    plt.figure(fig_num)
    group_bounds = zip(markers[:-1], markers[1:])
    for group, (start, stop) in enumerate(group_bounds):
        points = coordinates[start:stop]
        plt.scatter(x=points[:, 0],
                    y=points[:, 1],
                    marker=plot_markers[group % len(plot_markers)],
                    c=colors[group % len(colors)],
                    label=label_names[group], alpha=0.75)

    plt.legend(loc='upper right', fontsize='x-large')
    plt.axis('off')
    plt.savefig(fname=plot_path, format="svg", bbox_inches='tight', transparent=True)
    plt.close()
Example #13
Source File: plot_alert_pattern_subgraphs.py    From AMLSim with Apache License 2.0 6 votes vote down vote up
def plot_alerts(_g, _bank_accts, _output_png):
    """Draw graph ``_g`` with nodes coloured per bank and save it as an image.

    ``_bank_accts`` maps a bank id to the nodes belonging to it; each bank
    gets the next colour of the "tab10" colormap and a legend entry. Edges
    are annotated with their "label" attribute.
    """
    palette = plt.get_cmap("tab10")
    layout = nx.nx_agraph.graphviz_layout(_g)

    plt.figure(figsize=(12.0, 8.0))
    plt.axis('off')

    for index, bank_id in enumerate(_bank_accts.keys()):
        accounts = _bank_accts[bank_id]
        node_names = {n: n for n in accounts}
        nx.draw_networkx_nodes(_g, layout, accounts, node_size=300,
                               node_color=palette(index), label=bank_id)
        nx.draw_networkx_labels(_g, layout, node_names, font_size=10)

    labels_by_edge = nx.get_edge_attributes(_g, "label")
    nx.draw_networkx_edges(_g, layout)
    nx.draw_networkx_edge_labels(_g, layout, labels_by_edge, font_size=6)

    plt.legend(numpoints=1)
    # Let the drawing fill the whole canvas.
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.savefig(_output_png, dpi=120)
Example #14
Source File: VisMPL.py    From NURBS-Python with MIT License 6 votes vote down vote up
def set_axes_equal(ax):
    """ Sets equal aspect ratio across the three axes of a 3D plot.

    Every axis is re-centred on its current midpoint and given the same
    half-width: half of the largest current axis range.

    Contributed by Xuefeng Zhao.

    :param ax: a Matplotlib axis, e.g., as output from plt.gca().
    """
    bounds = [ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()]
    ranges = [abs(bound[1] - bound[0]) for bound in bounds]
    # Build an ndarray so the arithmetic below is element-wise; a plain list
    # only supports ``centers - radius`` when radius is a numpy scalar.
    centers = np.array([np.mean(bound) for bound in bounds])
    radius = 0.5 * max(ranges)
    lower_limits = centers - radius
    upper_limits = centers + radius
    ax.set_xlim3d([lower_limits[0], upper_limits[0]])
    ax.set_ylim3d([lower_limits[1], upper_limits[1]])
    ax.set_zlim3d([lower_limits[2], upper_limits[2]])
Example #15
Source File: show_boxes.py    From Deep-Feature-Flow-Segmentation with MIT License 6 votes vote down vote up
def show_boxes(im, dets, classes, scale = 1.0):
    """Display ``im`` with the detection boxes of every class drawn on top.

    ``dets[k]`` holds the detections of ``classes[k]``; the first four values
    of a detection are the box corners (multiplied by ``scale``). When the
    detections have five columns, the last value is shown as the score next
    to the class name. Each box gets a random colour. Returns ``im``.
    """
    plt.cla()
    plt.axis("off")
    plt.imshow(im)
    for cls_idx, cls_name in enumerate(classes):
        cls_dets = dets[cls_idx]
        for det in cls_dets:
            bbox = det[:4] * scale
            color = (rand(), rand(), rand())
            left, top = bbox[0], bbox[1]
            width = bbox[2] - bbox[0]
            height = bbox[3] - bbox[1]
            outline = plt.Rectangle((left, top), width, height, fill=False,
                                    edgecolor=color, linewidth=2.5)
            plt.gca().add_patch(outline)

            if cls_dets.shape[1] == 5:
                caption = '{:s} {:.3f}'.format(cls_name, det[-1])
                plt.gca().text(left, top, caption,
                               bbox=dict(facecolor=color, alpha=0.5),
                               fontsize=9, color='white')
    plt.show()
    return im
Example #16
Source File: visualization.py    From mac-network with Apache License 2.0 6 votes vote down vote up
def showImgAtt(img, instance, step, ax):
    """Draw ``img`` on ``ax`` with the attention map of ``step`` on top.

    The attention values come from ``instance["attentions"]["kb"][step]``,
    reshaped to the module-level ``imageDims`` and coloured with
    ``args.cmap``. Both images share one extent so they overlap exactly.
    """
    dx, dy = 0.05, 0.05
    x = np.arange(-1.5, 1.5, dx)
    y = np.arange(-1.0, 1.0, dy)
    # Only the bounding extent of the grid is needed; the full meshgrid the
    # original built was never used.
    extent = np.min(x), np.max(x), np.min(y), np.max(y)

    ax.cla()

    ax.imshow(img, interpolation = "nearest", extent = extent)
    attention = np.array(instance["attentions"]["kb"][step]).reshape(imageDims)
    ax.imshow(attention, cmap = plt.get_cmap(args.cmap),
        interpolation = "bicubic", extent = extent)

    ax.set_axis_off()
    # Also hides the axis of pyplot's current axes, which may differ from ax.
    plt.axis("off")

    ax.set_aspect("auto")
Example #17
Source File: recall.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def plot_iou_recall(recalls, iou_thrs):
    """Plot IoU-Recalls curve.

    Args:
        recalls(ndarray or list): shape (k,)
        iou_thrs(ndarray or list): same shape as `recalls`
    """
    # Normalise both inputs to plain lists so the end point can be appended
    # with list concatenation.
    if isinstance(iou_thrs, np.ndarray):
        _iou_thrs = iou_thrs.tolist()
    else:
        _iou_thrs = list(iou_thrs)
    if isinstance(recalls, np.ndarray):
        _recalls = recalls.tolist()
    else:
        _recalls = list(recalls)

    import matplotlib.pyplot as plt
    f = plt.figure()
    plt.plot(_iou_thrs + [1.0], _recalls + [0.])
    plt.xlabel('IoU')
    plt.ylabel('Recall')
    # Take the minimum of the normalised list: calling ``.min()`` on the raw
    # argument fails for the list input the docstring allows.
    plt.axis([min(_iou_thrs), 1, 0, 1])
    f.show()
Example #18
Source File: graphTools.py    From graph-neural-networks with GNU General Public License v3.0 6 votes vote down vote up
def adjacencyToLaplacian(W):
    """
    adjacencyToLaplacian: Computes the Laplacian from an Adjacency matrix

    The Laplacian is the diagonal matrix of row sums of W minus W itself.

    Input:

        W (np.array): adjacency matrix

    Output:

        L (np.array): Laplacian matrix
    """
    n_rows, n_cols = W.shape[0], W.shape[1]
    # Only square matrices are accepted
    assert n_rows == n_cols
    # Degree matrix: row sums placed on the diagonal
    degree_matrix = np.diag(np.sum(W, axis = 1))
    return degree_matrix - W
Example #19
Source File: graphTools.py    From graph-neural-networks with GNU General Public License v3.0 6 votes vote down vote up
def normalizeAdjacency(W):
    """
    NormalizeAdjacency: Computes the degree-normalized adjacency matrix

    With D the diagonal matrix holding 1/sqrt(row sum of W), the result is
    D @ W @ D.

    Input:

        W (np.array): adjacency matrix

    Output:

        A (np.array): degree-normalized adjacency matrix
    """
    n_rows, n_cols = W.shape[0], W.shape[1]
    # Only square matrices are accepted
    assert n_rows == n_cols
    # Row sums give the degree of every node
    degrees = np.sum(W, axis = 1)
    # Diagonal matrix of the inverse square root of the degrees
    inv_sqrt_degree = np.diag(1/np.sqrt(degrees))
    return inv_sqrt_degree @ W @ inv_sqrt_degree
Example #20
Source File: plot_lfads.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
                     color='r', title=None):
  """Plot the first ``n_to_plot`` columns of a 3-D array as offset traces.

  With ``bidx`` None the array is averaged over axis 0, otherwise slice
  ``bidx`` of axis 0 is used. Column ``k`` is shifted up by ``scale * k`` so
  the traces do not overlap. ``n_to_plot`` is capped at the column count.
  """
  if bidx is not None:
    vals_txn = vals_bxtxn[bidx,:,:]
  else:
    vals_txn = np.mean(vals_bxtxn, axis=0)

  _, n_columns = vals_txn.shape
  n_to_plot = min(n_to_plot, n_columns)

  offsets = scale*np.array(range(n_to_plot))
  plt.plot(vals_txn[:,0:n_to_plot] + offsets,
           color=color, lw=1.0)
  plt.axis('tight')
  if title:
    plt.title(title)
Example #21
Source File: visualise_attention.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Show every slice ``units[:, :, i]`` (transposed) in a grid of subplots.

    The figure ``figure_id`` is cleared and reused, and ``title`` becomes the
    figure-level title. Each subplot gets its own colorbar; ``colormap_lim``,
    when truthy, supplies the (low, high) colour limits of each image.
    """
    plt.ion()
    n_filters = units.shape[2]
    grid_cols = round(math.sqrt(n_filters))
    grid_rows = math.ceil(n_filters / grid_cols) + 1
    figure = plt.figure(figure_id, figsize=(grid_rows * 3, grid_cols * 3))
    figure.clf()

    for index in range(n_filters):
        axes = plt.subplot(grid_rows, grid_cols, index + 1)
        plt.imshow(units[:, :, index].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Keep the axes frame but drop the tick labels.
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            low, high = colormap_lim[0], colormap_lim[1]
            plt.clim(low, high)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.suptitle(title)
Example #22
Source File: visualise_fmaps.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show every slice ``units[:, :, i]`` (transposed) in a grid of subplots.

    The figure ``figure_id`` is cleared and reused. Each subplot gets its own
    colorbar; ``colormap_lim``, when truthy, supplies the (low, high) colour
    limits applied to each image.
    """
    plt.ion()
    total = units.shape[2]
    per_row = round(math.sqrt(total))
    row_count = math.ceil(total / per_row) + 1
    canvas = plt.figure(figure_id, figsize=(row_count * 3, per_row * 3))
    canvas.clf()

    for position in range(1, total + 1):
        panel = plt.subplot(row_count, per_row, position)
        feature_map = units[:, :, position - 1].T
        plt.imshow(feature_map, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Keep the axes frame but drop the tick labels.
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Load options 
Example #23
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 6 votes vote down vote up
def print_mutation(hyp, results, bucket=''):
    """Print one mutation result and record it in evolve.txt.

    The hyperparameter keys, hyperparameter values and results are printed,
    then results followed by values are appended as one row to evolve.txt.
    The file is rewritten with duplicate rows removed and rows ordered by
    descending ``fitness``. When ``bucket`` is non-empty, evolve.txt is
    downloaded from it with gsutil beforehand and uploaded afterwards.
    """
    n_hyp = len(hyp)
    keys_row = ('%10s' * n_hyp) % tuple(hyp.keys())
    values_row = ('%10.3g' * n_hyp) % tuple(hyp.values())
    results_row = ('%10.3g' * len(results)) % results
    print('\n%s\n%s\nEvolved fitness: %s\n' % (keys_row, values_row, results_row))

    if bucket:
        # Fetch the shared history first
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)

    with open('evolve.txt', 'a') as history:
        history.write(results_row + values_row + '\n')
    # Deduplicate rows, then store them best-first
    rows = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)
    best_first = np.argsort(-fitness(rows))
    np.savetxt('evolve.txt', rows[best_first], '%10.3g')

    if bucket:
        # Publish the updated history
        os.system('gsutil cp evolve.txt gs://%s' % bucket)
Example #24
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 6 votes vote down vote up
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
    """Plot training images overlaid with their target boxes and save them.

    Args:
        imgs: batch of images; converted with ``.cpu().numpy()`` and unpacked
            as (batch, _, height, width).
        targets: array whose column 0 is the image index within the batch and
            whose columns 2..5 are passed to ``xywh2xyxy`` and then scaled by
            the image width and height.
        paths: optional per-image paths; the file name (first 40 characters)
            is used as the subplot title.
        fname: output file for the figure.

    At most 16 images are drawn.
    """
    imgs = imgs.cpu().numpy()
    targets = targets.cpu().numpy()

    fig = plt.figure(figsize=(10, 10))
    bs, _, h, w = imgs.shape  # batch size, _, height, width
    bs = min(bs, 16)  # limit plot to 16 images
    # plt.subplot needs integer grid sizes; np.ceil alone returns a float,
    # which current matplotlib rejects.
    ns = int(np.ceil(bs ** 0.5))  # number of subplots per side

    for i in range(bs):
        boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).T
        boxes[[0, 2]] *= w
        boxes[[1, 3]] *= h
        plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
        # Closed outline of every box: the corner sequence returns to its start.
        plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
        plt.axis('off')
        if paths is not None:
            s = Path(paths[i]).name
            plt.title(s[:40], fontdict={'size': 8})  # limit to 40 characters
    fig.tight_layout()
    fig.savefig(fname, dpi=200)
    plt.close(fig)
Example #25
Source File: movie.py    From kvae with MIT License 6 votes vote down vote up
def save_movies_to_frame(images, filename, cmap='Blues'):
    """Tile several movies into one image and save it as a PNG.

    The first two axes of ``images`` are swapped, each resulting entry is
    merged with ``combine_multiple_img``, and the stack is collapsed with
    ``movie_to_frame``. The image is drawn on a 12x12 figure with the colour
    range fixed to [0, 1] and written to ``filename`` at 80 dpi.
    """
    # Build one grid image per entry along the swapped leading axis
    swapped = np.swapaxes(images, 1, 0)
    grids = np.array([combine_multiple_img(entry) for entry in swapped])

    # Collapse the grid sequence into a single picture
    frame = movie_to_frame(grids)

    figure = plt.figure(figsize=[12, 12])
    colormap = plt.cm.get_cmap(cmap)
    plt.imshow(frame, cmap=colormap, interpolation='none', vmin=0, vmax=1)
    plt.axis('image')
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(figure)
Example #26
Source File: mnist_gan.py    From gandlf with MIT License 6 votes vote down vote up
def get_mnist_data(binarize=False):
    """Load MNIST and bring images and labels into the expected format.

    With ``binarize`` the pixels become 1 where the value is >= 10 and -1
    elsewhere; otherwise they are cast to float32 and mapped through
    ``(x - 127.5) / 127.5``. A trailing axis of length one is appended to
    both the images and the labels.
    """

    def _rescale(pixels):
        if binarize:
            return np.where(pixels >= 10, 1, -1)
        return (pixels.astype(np.float32) - 127.5) / 127.5

    def _add_last_axis(array):
        return np.expand_dims(array, axis=-1)

    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    X_train = _add_last_axis(_rescale(X_train))
    X_test = _add_last_axis(_rescale(X_test))
    y_train = _add_last_axis(y_train)
    y_test = _add_last_axis(y_test)

    return (X_train, y_train), (X_test, y_test)
Example #27
Source File: test.py    From Chinese-Character-and-Calligraphic-Image-Processing with MIT License 6 votes vote down vote up
def test(self):
    """Run the restored generator on one random validation image and show it.

    A file name is picked at random from ./maps/val/ and opened from
    ``self.path``. The right half of the image is used as the condition and
    the left half as the reference; both are resized to
    (self.img_h, self.img_w) and scaled to [-1, 1]. Condition, generated
    image and reference are concatenated side by side, saved to
    ./results/1.jpg and displayed.
    """
    list_ = os.listdir("./maps/val/")
    nums_file = len(list_)
    # Restore only the variables collected under the "generator" scope.
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "generator"))
    saver.restore(self.sess, "./save_para/model.ckpt")
    rand_select = np.random.randint(0, nums_file)
    INPUTS_CONDITION = np.zeros([1, self.img_h, self.img_w, 3])
    INPUTS = np.zeros([1, self.img_h, self.img_w, 3])
    # NOTE(review): files are listed from "./maps/val/" but opened from
    # self.path -- presumably the same directory; confirm.
    img = np.array(Image.open(self.path + list_[rand_select]))
    img_h, img_w = img.shape[0], img.shape[1]
    INPUTS_CONDITION[0] = misc.imresize(img[:, img_w//2:], [self.img_h, self.img_w]) / 127.5 - 1.0
    INPUTS[0] = misc.imresize(img[:, :img_w//2], [self.img_h, self.img_w]) / 127.5 - 1.0
    [fake_img] = self.sess.run([self.inputs_fake], feed_dict={self.inputs_condition: INPUTS_CONDITION})
    out_img = np.concatenate((INPUTS_CONDITION[0], fake_img[0], INPUTS[0]), axis=1)
    # Map [-1, 1] back to 8-bit pixel values once and reuse the result.
    out_uint8 = np.uint8((out_img + 1.0)*127.5)
    Image.fromarray(out_uint8).save("./results/1.jpg")
    plt.imshow(out_uint8)
    # The string "off" is truthy and switches the grid ON in current
    # matplotlib; pass a real boolean to hide it.
    plt.grid(False)
    plt.axis("off")
    plt.show()
Example #28
Source File: test.py    From Chinese-Character-and-Calligraphic-Image-Processing with MIT License 6 votes vote down vote up
def generator(self, inputs_condition):
        """Build the generator graph and return its output tensor.

        Inside the "generator" variable scope (created with AUTO_REUSE) the
        input goes through eight ``conv2d`` layers, each followed by
        ``leaky_relu`` (with ``instanceNorm`` on layers 2-8), and then through
        eight ``deconv2d`` layers. The outputs of the first seven deconv
        layers are instance-normalised, concatenated along axis 3 with the
        matching encoder activation and passed through ReLU; the first three
        additionally go through ``tf.nn.dropout`` with argument 0.5. The last
        deconv layer is followed by ``tanh``.

        :param inputs_condition: tensor fed to the first convolution.
        :return: output of the final ``tanh`` layer.
        """
        inputs = inputs_condition
        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            # NOTE(review): the trailing shape comments are the original
            # author's annotations. Several do not match the filter counts
            # passed to conv2d/deconv2d, nor the channel doubling from
            # tf.concat -- presumably stale; verify before relying on them.
            # Encoder: conv2d layers, instance-normalised from the second on.
            inputs1 = leaky_relu(conv2d("conv1", inputs, 64, 5, 2))#128x128x128
            inputs2 = leaky_relu(instanceNorm("in1", conv2d("conv2", inputs1, 128, 5, 2)))#64x64x256
            inputs3 = leaky_relu(instanceNorm("in2", conv2d("conv3", inputs2, 256, 5, 2)))#32x32x512
            inputs4 = leaky_relu(instanceNorm("in3", conv2d("conv4", inputs3, 512, 5, 2)))#16x16x512
            inputs5 = leaky_relu(instanceNorm("in4", conv2d("conv5", inputs4, 512, 5, 2)))#8x8x512
            inputs6 = leaky_relu(instanceNorm("in5", conv2d("conv6", inputs5, 512, 5, 2)))#4x4x512
            inputs7 = leaky_relu(instanceNorm("in6", conv2d("conv7", inputs6, 512, 5, 2)))#2x2x512
            inputs8 = leaky_relu(instanceNorm("in7", conv2d("conv8", inputs7, 512, 5, 2)))#1x1x512
            # Decoder: each deconv output is concatenated with the encoder
            # activation of the same level (skip connection); dropout is
            # applied on the first three levels only.
            outputs1 = tf.nn.relu(tf.concat([tf.nn.dropout(instanceNorm("in9", deconv2d("dconv1", inputs8, 512, 5, 2)), 0.5), inputs7], axis=3))  # 2x2x512
            outputs2 = tf.nn.relu(tf.concat([tf.nn.dropout(instanceNorm("in10", deconv2d("dconv2", outputs1, 512, 5, 2)), 0.5), inputs6], axis=3))  # 4x4x512
            outputs3 = tf.nn.relu(tf.concat([tf.nn.dropout(instanceNorm("in11", deconv2d("dconv3", outputs2, 512, 5, 2)), 0.5), inputs5], axis=3))#8x8x512
            outputs4 = tf.nn.relu(tf.concat([instanceNorm("in12", deconv2d("dconv4", outputs3, 512, 5, 2)), inputs4], axis=3))#16x16x512
            outputs5 = tf.nn.relu(tf.concat([instanceNorm("in13", deconv2d("dconv5", outputs4, 256, 5, 2)), inputs3], axis=3))#32x32x256
            outputs6 = tf.nn.relu(tf.concat([instanceNorm("in14", deconv2d("dconv6", outputs5, 128, 5, 2)), inputs2], axis=3))#64x64x128
            outputs7 = tf.nn.relu(tf.concat([instanceNorm("in15", deconv2d("dconv7", outputs6, 64, 5, 2)), inputs1], axis=3))#128x128x64
            # Final layer: no normalisation or skip connection, tanh output.
            outputs8 = tf.nn.tanh((deconv2d("dconv8", outputs7, 3, 5, 2)))#256x256x3
            return outputs8 
Example #29
Source File: demo.py    From razzy-spinner with GNU General Public License v3.0 6 votes vote down vote up
def _demo_plot(learning_curve_output, teststats, trainstats=None, take=None):
    """Plot learning curves computed from rule scores and save the figure.

    Each curve starts from ``stats['initialerrors']``, subtracts every entry
    of ``stats['rulescores']`` cumulatively, and converts the error counts to
    ``1 - errors / stats['tokencount']``. Only the first ``take`` points are
    kept (all of them when ``take`` is None).

    :param learning_curve_output: file the figure is saved to.
    :param teststats: statistics dict for the test curve.
    :param trainstats: optional statistics dict for the training curve; when
        None, only the test curve is drawn.
    :param take: optional number of leading curve points to plot.
    """
    def _accuracy_curve(stats):
        # Running error count after each rule, converted to accuracy.
        errors = [stats['initialerrors']]
        for rulescore in stats['rulescores']:
            errors.append(errors[-1] - rulescore)
        return [1 - x / stats['tokencount'] for x in errors[:take]]

    testcurve = _accuracy_curve(teststats)
    plot_args = [list(range(len(testcurve))), testcurve]
    # trainstats defaults to None, so it must not be subscripted
    # unconditionally. Each curve gets its own x-range in case the two
    # differ in length.
    if trainstats is not None:
        traincurve = _accuracy_curve(trainstats)
        plot_args += [list(range(len(traincurve))), traincurve]

    import matplotlib.pyplot as plt
    plt.plot(*plot_args)
    # Only cap the upper y limit; the other limits are left unchanged.
    plt.axis([None, None, None, 1.0])
    plt.savefig(learning_curve_output)
Example #30
Source File: grid.py    From python-control with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _final_setup(ax):
    ax.set_xlabel('Real')
    ax.set_ylabel('Imaginary')
    ax.axhline(y=0, color='black', lw=1)
    ax.axvline(x=0, color='black', lw=1)
    plt.axis('equal')