Python matplotlib.pyplot.subplots() Examples

The following are 30 code examples showing how to use matplotlib.pyplot.subplots(). They are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.

Example 1
Project: dc_tts   Author: Kyubyong   File: utils.py    License: Apache License 2.0 6 votes vote down vote up
def plot_alignment(alignment, gs, dir=hp.logdir):
    """Plot an alignment matrix and save it as a PNG.

    Args:
      alignment: A numpy array with shape of (encoder_steps, decoder_steps)
      gs: (int) global step. Used in the title and in the output file name.
      dir: Output directory. Created (including missing parents) if needed.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs(dir, exist_ok=True)

    fig, ax = plt.subplots()
    im = ax.imshow(alignment)

    fig.colorbar(im)
    ax.set_title('{} Steps'.format(gs))
    fig.savefig('{}/alignment_{}.png'.format(dir, gs), format='png')
    plt.close(fig)
Example 2
Project: deep-learning-note   Author: wdxtub   File: 8_kmeans_pca.py    License: MIT License 6 votes vote down vote up
def plot_n_image(X, n):
    """Plot the first n rows of X as images in a square grid.

    Each row of X is reshaped to (pic_size, pic_size) with
    pic_size = int(sqrt(X.shape[1])), so rows are assumed to be flattened
    square images. n has to be a square number; otherwise only the first
    int(sqrt(n)) ** 2 images are drawn.
    """
    pic_size = int(np.sqrt(X.shape[1]))
    grid_size = int(np.sqrt(n))

    first_n_images = X[:n, :]

    # squeeze=False keeps ax_array 2-D even for a 1x1 grid (n in 1..3);
    # otherwise plt.subplots returns a bare Axes and ax_array[r, c] fails.
    fig, ax_array = plt.subplots(nrows=grid_size, ncols=grid_size,
                                 sharey=True, sharex=True, figsize=(8, 8),
                                 squeeze=False)

    for r in range(grid_size):
        for c in range(grid_size):
            ax = ax_array[r, c]
            ax.imshow(first_n_images[grid_size * r + c].reshape((pic_size, pic_size)))
            # Clear ticks on this axes directly rather than through the pyplot
            # state machine, which only targets the "current" axes.
            ax.set_xticks([])
            ax.set_yticks([])
Example 3
Project: keras-anomaly-detection   Author: chen0040   File: h2o_ecg_pulse_detection.py    License: MIT License 6 votes vote down vote up
def plot_bidimensional(model, test, recon_error, layer, title):
    """Scatter-plot a layer's two deep-feature columns, coloured by reconstruction error.

    Args:
      model: object exposing deepfeatures(test, layer) (presumably an H2O
        model — confirm); the result is cbind-ed with recon_error and
        converted with as_data_frame().
      test: frame passed to model.deepfeatures.
      recon_error: frame providing the 'Reconstruction.MSE' column used for colour.
      layer: zero-based layer index; columns are named 'DF.L{layer+1}.C1/C2'.
      title: plot title.
    """
    bidimensional_data = model.deepfeatures(test, layer).cbind(recon_error).as_data_frame()

    # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap is the supported replacement.
    cmap = plt.get_cmap('Spectral')

    layer_column = 'DF.L{}.C'.format(layer + 1)
    columns = [layer_column + '1', layer_column + '2']

    fig, ax = plt.subplots()
    bidimensional_data.plot(kind='scatter',
                            x=columns[0],
                            y=columns[1],
                            s=500,
                            c='Reconstruction.MSE',
                            title=title,
                            ax=ax,
                            colormap=cmap)
    for k, v in bidimensional_data[columns].iterrows():
        # Label each point with its row index at its (feature 1, feature 2) position.
        ax.annotate(str(k), xy=tuple(v), size=20, verticalalignment='bottom', horizontalalignment='left')
    fig.canvas.draw()
    plt.show()
Example 4
Project: keras-anomaly-detection   Author: chen0040   File: plot_utils.py    License: MIT License 6 votes vote down vote up
def visualize_anomaly(y_true, reconstruction_error, threshold):
    """Plot per-point reconstruction error, split by class, with a threshold line.

    Args:
      y_true: class label per data point; label 1 is shown as "Fraud", any
        other label as "Normal".
      reconstruction_error: reconstruction error per data point.
      threshold: error value drawn as a horizontal red line.
    """
    error_df = pd.DataFrame({'reconstruction_error': reconstruction_error,
                             'true_class': y_true})
    print(error_df.describe())

    fig, ax = plt.subplots()

    for name, group in error_df.groupby('true_class'):
        ax.plot(group.index, group.reconstruction_error, marker='o', ms=3.5, linestyle='',
                label="Fraud" if name == 1 else "Normal")

    # axhline always spans the full x-range; hlines between get_xlim() values
    # froze the span at the limits current at call time, so it could stop
    # short of the axes edges once autoscaling added margins.
    ax.axhline(threshold, color="r", zorder=100, label='Threshold')
    ax.legend()
    ax.set_title("Reconstruction error for different classes")
    ax.set_ylabel("Reconstruction error")
    ax.set_xlabel("Data point index")
    plt.show()
Example 5
Project: pruning_yolov3   Author: zbyuan   File: utils.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
    """Plot up to 16 training images in a square grid with target boxes overlaid.

    Args:
      imgs: batch tensor with a .cpu().numpy() interface; its shape is
        unpacked as (batch, _, height, width) and each image is transposed
        (1, 2, 0) for display, i.e. channels-first.
      targets: tensor of box rows; column 0 selects the image index and
        columns 2:6 go through xywh2xyxy and are scaled by (w, h), so they are
        presumably normalized xywh boxes — confirm against the caller.
      paths: optional per-image file paths; the file name is used as title.
      fname: output image path.
    """
    imgs = imgs.cpu().numpy()
    targets = targets.cpu().numpy()
    # targets = targets[targets[:, 1] == 21]  # plot only one class

    fig = plt.figure(figsize=(10, 10))
    bs, _, h, w = imgs.shape  # batch size, _, height, width
    bs = min(bs, 16)  # limit plot to 16 images
    # Grid side length. plt.subplot requires integer row/column counts, and
    # np.ceil returns a float, which recent matplotlib rejects.
    ns = int(np.ceil(bs ** 0.5))

    for i in range(bs):
        boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).T
        boxes[[0, 2]] *= w
        boxes[[1, 3]] *= h
        plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
        # Trace each box as a closed outline: (x1,y1)->(x2,y1)->(x2,y2)->(x1,y2)->(x1,y1)
        plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
        plt.axis('off')
        if paths is not None:
            s = Path(paths[i]).name
            plt.title(s[:40], fontdict={'size': 8})  # limit to 40 characters
    fig.tight_layout()
    fig.savefig(fname, dpi=200)
    plt.close(fig)
Example 6
Project: pruning_yolov3   Author: zbyuan   File: utils.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plot_test_txt():  # from utils.utils import *; plot_test_txt()
    """Save 2-D and 1-D histograms of box centres read from 'test.txt'.

    Columns 0-3 of each row are converted with xyxy2xywh, so they are
    presumably x1, y1, x2, y2 corners; the resulting centres (cx, cy) are
    histogrammed. Writes hist2d.jpg and hist1d.jpg to the working directory.
    """
    # ndmin=2 keeps a single-line file 2-D so that x[:, :4] still works.
    x = np.loadtxt('test.txt', dtype=np.float32, ndmin=2)
    box = xyxy2xywh(x[:, :4])
    cx, cy = box[:, 0], box[:, 1]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
    ax.set_aspect('equal')
    fig.tight_layout()
    fig.savefig('hist2d.jpg', dpi=300)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].hist(cx, bins=600)
    ax[1].hist(cy, bins=600)
    fig.tight_layout()
    fig.savefig('hist1d.jpg', dpi=200)
Example 7
Project: pruning_yolov3   Author: zbyuan   File: utils.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plot_results(start=0, stop=0):  # from utils.utils import *; plot_results()
    """Plot training curves from every 'results*.txt' file on one 2x5 grid.

    Args:
      start: first row of each results file to plot.
      stop: one past the last row to plot; 0 means plot through the end.
    """
    fig, ax = plt.subplots(2, 5, figsize=(14, 7))
    ax = ax.ravel()
    s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
         'val GIoU', 'val Objectness', 'val Classification', 'mAP', 'F1']
    # Share the y axis of each validation loss (5-7) with its training loss
    # (0-2). This is done once, up front. get_shared_y_axes().join() is
    # deprecated/removed in recent matplotlib, and Axes.sharey() raises if an
    # axes is already shared, so it must not run once per results file.
    for i in (5, 6, 7):
        ax[i].sharey(ax[i - 5])

    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows
        x = range(start, min(stop, n) if stop else n)
        for i in range(10):
            y = results[i, x]
            if i in [0, 1, 2, 5, 6, 7]:
                y[y == 0] = np.nan  # dont show zero loss values
            ax[i].plot(x, y, marker='.', label=f.replace('.txt', ''))
            ax[i].set_title(s[i])

    fig.tight_layout()
    ax[1].legend()
    fig.savefig('results.png', dpi=200)
Example 8
Project: pruning_yolov3   Author: zbyuan   File: utils.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plot_results_overlay(start=0, stop=0):  # from utils.utils import *; plot_results_overlay()
    """Plot each 'results*.txt' file on its own 1x5 figure, overlaying paired metrics.

    Panel k overlays loaded column k with column k + 5. For the first three
    panels these are the train and val losses, and zero values are hidden.
    Each figure is saved next to its source file with a .png extension.
    """
    legends = ['train', 'train', 'train', 'Precision', 'mAP', 'val', 'val', 'val', 'Recall', 'F1']
    titles = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1']
    columns = [2, 3, 4, 8, 9, 12, 13, 14, 10, 11]
    sources = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')
    for path in sorted(sources):
        data = np.loadtxt(path, usecols=columns, ndmin=2).T
        n_rows = data.shape[1]
        end = min(stop, n_rows) if stop else n_rows
        epochs = range(start, end)
        fig, axes = plt.subplots(1, 5, figsize=(14, 3.5))
        axes = axes.ravel()
        for panel in range(5):
            panel_ax = axes[panel]
            for col in (panel, panel + 5):
                series = data[col, epochs]
                if panel < 3:
                    series[series == 0] = np.nan  # hide zero loss values
                panel_ax.plot(epochs, series, marker='.', label=legends[col])
            panel_ax.set_title(titles[panel])
            panel_ax.legend()
            if panel == 0:
                panel_ax.set_ylabel(path)  # add filename
        fig.tight_layout()
        fig.savefig(path.replace('.txt', '.png'), dpi=200)
Example 9
Project: pywr   Author: pywr   File: thames.py    License: GNU General Public License v3.0 6 votes vote down vote up
def figures(ext, show):
    """Plot every recorded output in 'thames_output.h5' as a time series plus quantiles.

    Args:
      ext: file extension for saved figures (e.g. 'png'); None skips saving.
      show: if True, keep every figure open and display them all at the end;
        otherwise each figure is closed as soon as it has been handled.
    """
    for name, df in TablesRecorder.generate_dataframes('thames_output.h5'):
        # The assignment requires exactly five columns.
        df.columns = ['Very low', 'Low', 'Central', 'High', 'Very high']

        fig, (ax1, ax2) = plt.subplots(figsize=(12, 4), ncols=2, sharey='row',
                                       gridspec_kw={'width_ratios': [3, 1]})
        df['2100':'2125'].plot(ax=ax1)
        df.quantile(np.linspace(0, 1)).plot(ax=ax2)

        if name.startswith('reservoir'):
            ax1.set_ylabel('Volume [$Mm^3$]')
        else:
            ax1.set_ylabel('Flow [$Mm^3/day$]')

        for ax in (ax1, ax2):
            ax.set_title(name)
            ax.grid(True)
        fig.tight_layout()

        if ext is not None:
            fig.savefig(f'{name}.{ext}', dpi=300)

        if not show:
            # Nothing will display this figure; free it now rather than
            # accumulating one open figure per recorder.
            plt.close(fig)

    if show:
        plt.show()
Example 10
Project: kss   Author: Kyubyong   File: utils.py    License: Apache License 2.0 6 votes vote down vote up
def plot_alignment(alignment, gs, dir=hp.logdir):
    """Plot an alignment matrix and save it as a PNG.

    Args:
      alignment: A numpy array with shape of (encoder_steps, decoder_steps)
      gs: (int) global step. Used in the title and in the output file name.
      dir: Output directory. Created (including missing parents) if needed.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs(dir, exist_ok=True)

    fig, ax = plt.subplots()
    im = ax.imshow(alignment)

    fig.colorbar(im)
    ax.set_title('{} Steps'.format(gs))
    fig.savefig('{}/alignment_{}.png'.format(dir, gs), format='png')
    # Close explicitly. Without this, every call leaves a figure open, and
    # repeated calls accumulate them.
    plt.close(fig)
Example 11
Project: TaskBot   Author: EvilPsyCHo   File: plot.py    License: GNU General Public License v3.0 6 votes vote down vote up
def plot_attention(sentences, attentions, labels, **kwargs):
    """Draw an attention matrix as a heatmap with a character in every cell.

    Args:
      sentences: indexable as sentences[i][j]; the text written at cell (i, j).
      attentions: 2-D array-like of attention weights (uses .min(), .max(), .shape).
      labels: y-axis tick labels, one per row.
      **kwargs: forwarded to plt.subplots (e.g. figsize).
    """
    fig, ax = plt.subplots(**kwargs)
    # Fetch the font once instead of once per annotated cell.
    font = getChineseFont()
    im = ax.imshow(attentions, interpolation='nearest',
                   vmin=attentions.min(), vmax=attentions.max())
    fig.colorbar(im, shrink=0.5, ticks=[0, 1])
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontproperties=font)
    # Loop over data dimensions and create text annotations.
    n_rows, n_cols = attentions.shape[0], attentions.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, sentences[i][j],
                    ha="center", va="center", color="b", size=10,
                    fontproperties=font)

    ax.set_title("Attention Visual")
    fig.tight_layout()
    plt.show()
Example 12
Project: fin   Author: vsmolyakov   File: olmar.py    License: MIT License 6 votes vote down vote up
def analyze(context=None, results=None):
    """Plot portfolio value, AAPL/MSFT prices, and step size/variability on three stacked axes.

    Args:
      context: not used here (presumably required by the backtest callback
        signature — confirm).
      results: table-like object providing portfolio_value, 'AAPL', 'MSFT',
        'step_size' and 'variability' series.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(3, sharex=True)
    ax1.plot(results.portfolio_value, linewidth=2.0, label='portfolio')
    ax1.set_title('On-Line Moving Average Reversion')
    ax1.set_ylabel('Portfolio value (USD)')
    ax1.legend(loc=0)
    ax1.grid(True)

    ax2.plot(results['AAPL'], color='b', linestyle='-', linewidth=2.0, label='AAPL')
    ax2.plot(results['MSFT'], color='r', linestyle='-', linewidth=2.0, label='MSFT')
    ax2.set_ylabel('stock price (USD)')
    ax2.legend(loc=0)
    ax2.grid(True)

    # Log y-scale for the step-size and variability series.
    ax3.semilogy(results['step_size'], color='b', linestyle='-', linewidth=2.0, label='step-size')
    ax3.semilogy(results['variability'], color='r', linestyle='-', linewidth=2.0, label='variability')
    ax3.legend(loc=0)
    ax3.grid(True)

    plt.show()
Example 13
Project: fin   Author: vsmolyakov   File: momentum.py    License: MIT License 6 votes vote down vote up
def analyze(context=None, results=None, benchmark=None):
    """Plot portfolio value and AAPL price with short/long moving averages.

    The first hist_size (300) rows of every series are skipped.

    Args:
      context: not used here.
      results: table-like object providing portfolio_value, 'AAPL',
        'short_mavg' and 'long_mavg' series.
      benchmark: not used here.
    """
    hist_size = 300

    fig, (ax1, ax2) = plt.subplots(2, sharex=True)
    ax1.plot(results.portfolio_value[hist_size:], linewidth=2.0, label='portfolio')
    ax1.set_title('Dual Moving Average Strategy')
    ax1.set_ylabel('Portfolio value (USD)')
    ax1.legend(loc=0)
    ax1.grid(True)

    ax2.plot(results['AAPL'][hist_size:], linewidth=2.0, label='AAPL')
    ax2.plot(results['short_mavg'][hist_size:], color='r', linestyle='-', linewidth=2.0, label='short mavg')
    ax2.plot(results['long_mavg'][hist_size:], color='g', linestyle='-', linewidth=2.0, label='long mavg')
    ax2.set_ylabel('AAPL price (USD)')
    ax2.legend(loc=0)
    ax2.grid(True)

    plt.show()
Example 14
Project: pymoo   Author: msu-coinlab   File: traveling_salesman.py    License: Apache License 2.0 6 votes vote down vote up
def visualize(problem, x, fig=None, ax=None, show=True, label=True):
    """Draw a tour: cities as large dots, consecutive stops joined by dashed red lines.

    A new figure/axes pair is created unless both fig and ax are given.
    The last stop in x is connected back to the first. The figure title
    shows problem.get_route_length(x), and fig.show() is called when show
    is True.
    """
    with plt.style.context('ggplot'):
        if fig is None or ax is None:
            fig, ax = plt.subplots()

        coords = problem.cities
        ax.scatter(coords[:, 0], coords[:, 1], s=250)

        if label:
            # Write each city's index on top of its marker.
            for idx, city in enumerate(coords):
                ax.annotate(str(idx), xy=city, fontsize=10, ha="center", va="center", color="white")

        n_stops = len(x)
        for pos, origin in enumerate(x):
            target = x[(pos + 1) % n_stops]
            leg = [origin, target]
            ax.plot(coords[leg, 0], coords[leg, 1], 'r--')

        fig.suptitle("Route length: %.4f" % problem.get_route_length(x))

        if show:
            fig.show()
Example 15
Project: pyscf   Author: pyscf   File: mf.py    License: Apache License 2.0 6 votes vote down vote up
def plot_contour(self, w=0.0):
    """
      Plot contour with poles of Green's function in the self-energy
      SelfEnergy(w) = G(w+w')W(w')
      with respect to w' = Re(w')+Im(w')
      Poles of G(w+w') are located: w+w'-(E_n-Fermi)+i*eps sign(E_n-Fermi)==0 ==>
      w'= (E_n-Fermi) - w -i eps sign(E_n-Fermi)

      Reads self.fermi_energy and self.mo_energy. If matplotlib cannot be
      imported, prints a message and returns None without plotting.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.patches import Arc, Arrow
    except ImportError:
        # Plotting is optional; only a missing matplotlib is tolerated here.
        # The previous bare `except:` also hid unrelated errors, including KeyboardInterrupt.
        print('no matplotlib?')
        return

    fig, ax = plt.subplots()
    fe = self.fermi_energy
    ee = self.mo_energy
    # +0.5 for levels at or below the Fermi energy, -0.5 for levels above it.
    iee = 0.5 - np.array(ee > fe)
    eew = ee - fe - w
    ax.plot(eew, iee, 'r.', ms=10.0)
    patches = [
        Arc((0, 0), 4, 4, angle=0, linewidth=2, theta1=0, theta2=90, zorder=2, color='b'),
        Arc((0, 0), 4, 4, angle=0, linewidth=2, theta1=180, theta2=270, zorder=2, color='b'),
        Arrow(0, 2, 0, -4, width=0.2, color='b', hatch='o'),
        Arrow(-2, 0, 4, 0, width=0.2, color='b', hatch='o'),
    ]
    for p in patches:
        ax.add_patch(p)
    ax.set_aspect('equal')
    ax.grid(True, which='both')
    ax.axhline(y=0, color='k')
    ax.axvline(x=0, color='k')
    ax.set_ylim(-3.0, 3.0)
    plt.show()
Example 16
Project: neat-python   Author: CodeReclaimers   File: visualize.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_species(statistics, view=False, filename='speciation.svg'):
    """Draw a stacked-area chart of species sizes across generations and save it.

    Args:
      statistics: object exposing get_species_sizes(); the result is
        converted with np.array(...).T, so it is presumably one equal-length
        sequence of sizes per generation — confirm with the caller.
      view: when True, the figure is also shown with plt.show().
      filename: path the figure is written to.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    sizes_per_generation = statistics.get_species_sizes()
    generations = range(len(sizes_per_generation))
    per_species = np.array(sizes_per_generation).T

    fig, ax = plt.subplots()
    ax.stackplot(generations, *per_species)
    ax.set_title("Speciation")
    ax.set_ylabel("Size per Species")
    ax.set_xlabel("Generations")

    plt.savefig(filename)

    if view:
        plt.show()

    plt.close()
Example 17
Project: neat-python   Author: CodeReclaimers   File: visualize.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_species(statistics, view=False, filename='speciation.svg'):
    """Draw a stacked-area chart of species sizes across generations and save it.

    Args:
      statistics: object exposing get_species_sizes(); the result is
        converted with np.array(...).T, so it is presumably one equal-length
        sequence of sizes per generation — confirm with the caller.
      view: when True, the figure is also shown with plt.show().
      filename: path the figure is written to.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    sizes_per_generation = statistics.get_species_sizes()
    generations = range(len(sizes_per_generation))
    per_species = np.array(sizes_per_generation).T

    fig, ax = plt.subplots()
    ax.stackplot(generations, *per_species)
    ax.set_title("Speciation")
    ax.set_ylabel("Size per Species")
    ax.set_xlabel("Generations")

    plt.savefig(filename)

    if view:
        plt.show()

    plt.close()
Example 18
Project: neat-python   Author: CodeReclaimers   File: visualize.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_species(statistics, view=False, filename='speciation.svg'):
    """Draw a stacked-area chart of species sizes across generations and save it.

    Args:
      statistics: object exposing get_species_sizes(); the result is
        converted with np.array(...).T, so it is presumably one equal-length
        sequence of sizes per generation — confirm with the caller.
      view: when True, the figure is also shown with plt.show().
      filename: path the figure is written to.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    sizes_per_generation = statistics.get_species_sizes()
    generations = range(len(sizes_per_generation))
    per_species = np.array(sizes_per_generation).T

    fig, ax = plt.subplots()
    ax.stackplot(generations, *per_species)
    ax.set_title("Speciation")
    ax.set_ylabel("Size per Species")
    ax.set_xlabel("Generations")

    plt.savefig(filename)

    if view:
        plt.show()

    plt.close()
Example 19
Project: neat-python   Author: CodeReclaimers   File: visualize.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_species(statistics, view=False, filename='speciation.svg'):
    """Draw a stacked-area chart of species sizes across generations and save it.

    Args:
      statistics: object exposing get_species_sizes(); the result is
        converted with np.array(...).T, so it is presumably one equal-length
        sequence of sizes per generation — confirm with the caller.
      view: when True, the figure is also shown with plt.show().
      filename: path the figure is written to.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    sizes_per_generation = statistics.get_species_sizes()
    generations = range(len(sizes_per_generation))
    per_species = np.array(sizes_per_generation).T

    fig, ax = plt.subplots()
    ax.stackplot(generations, *per_species)
    ax.set_title("Speciation")
    ax.set_ylabel("Size per Species")
    ax.set_xlabel("Generations")

    plt.savefig(filename)

    if view:
        plt.show()

    plt.close()
Example 20
Project: neat-python   Author: CodeReclaimers   File: visualize.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_species(statistics, view=False, filename='speciation.svg'):
    """Draw a stacked-area chart of species sizes across generations and save it.

    Args:
      statistics: object exposing get_species_sizes(); the result is
        converted with np.array(...).T, so it is presumably one equal-length
        sequence of sizes per generation — confirm with the caller.
      view: when True, the figure is also shown with plt.show().
      filename: path the figure is written to.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    sizes_per_generation = statistics.get_species_sizes()
    generations = range(len(sizes_per_generation))
    per_species = np.array(sizes_per_generation).T

    fig, ax = plt.subplots()
    ax.stackplot(generations, *per_species)
    ax.set_title("Speciation")
    ax.set_ylabel("Size per Species")
    ax.set_xlabel("Generations")

    plt.savefig(filename)

    if view:
        plt.show()

    plt.close()
Example 21
Project: NeuroKit   Author: neuropsychology   File: complexity_dimension.py    License: MIT License 6 votes vote down vote up
def _embedding_dimension_plot(
    method, dimension_seq, min_dimension, E1=None, E2=None, f1=None, f2=None, f3=None, ax=None
):
    """Plot the curves used to choose the embedding dimension.

    For method "afnn", E1 and E2 are plotted against dimension_seq[:-1].
    For method "fnn", 100 * f1, f2 and f3 are plotted against dimension_seq
    (labelled "Test I", "Test II" and "Test I + II"). A vertical line marks
    min_dimension.

    Returns:
      The newly created Figure, or None when an existing ax was supplied.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None
    ax.set_title("Optimization of Dimension (d)")
    ax.set_xlabel("Embedding dimension $d$")
    ax.set_ylabel("$E_1(d)$ and $E_2(d)$")
    # Format strings carry only marker/linestyle: colours come from color=.
    # A colour letter in the fmt as well (e.g. "bo-") makes matplotlib warn
    # that the colour is redundantly defined.
    if method in ["afnn"]:
        ax.plot(dimension_seq[:-1], E1, "o-", label="$E_1(d)$", color="#FF5722")
        ax.plot(dimension_seq[:-1], E2, "o-", label="$E_2(d)$", color="#f44336")

    if method in ["fnn"]:
        ax.plot(dimension_seq, 100 * f1, "o--", label="Test I", color="#FF5722")
        ax.plot(dimension_seq, 100 * f2, "^--", label="Test II", color="#f44336")
        ax.plot(dimension_seq, 100 * f3, "s-", label="Test I + II", color="#852b01")

    ax.axvline(x=min_dimension, color="#E91E63", label="Optimal dimension: " + str(min_dimension))
    ax.legend(loc="upper right")

    return fig
Example 22
Project: NeuroKit   Author: neuropsychology   File: tests_signal.py    License: MIT License 6 votes vote down vote up
def test_signal_filter():
    """signal_filter should attenuate high frequencies and remove 50 Hz powerline noise."""

    signal = np.cos(np.linspace(start=0, stop=10, num=1000))  # Low freq
    signal += np.cos(np.linspace(start=0, stop=100, num=1000))  # High freq
    filtered = nk.signal_filter(signal, highcut=10)
    assert np.std(signal) > np.std(filtered)

    # Generate 10 seconds of signal with 2 Hz oscillation and added 50Hz powerline-noise.
    sampling_rate = 250
    samples = np.arange(10 * sampling_rate)

    signal = np.sin(2 * np.pi * 2 * (samples / sampling_rate))
    powerline = np.sin(2 * np.pi * 50 * (samples / sampling_rate))

    signal_corrupted = signal + powerline
    signal_clean = nk.signal_filter(signal_corrupted, sampling_rate=sampling_rate, method="powerline")

    # import matplotlib.pyplot as plt
    # figure, (ax0, ax1, ax2) = plt.subplots(nrows=3, ncols=1, sharex=True)
    # ax0.plot(signal_corrupted)
    # ax1.plot(signal)
    # ax2.plot(signal_clean * 100)

    # np.sum, not the builtin sum: the builtin iterates the array element by
    # element in Python, while np.sum reduces in C.
    assert np.allclose(np.sum(signal_clean - signal), -2, atol=0.2)
Example 23
Project: Keras-GAN   Author: eriklindernoren   File: sgan.py    License: MIT License 6 votes vote down vote up
def sample_images(self, epoch):
    """Save a 5x5 grid of generator samples to images/mnist_<epoch>.png."""
    rows, cols = 5, 5
    latent = np.random.normal(0, 1, (rows * cols, self.latent_dim))
    # Rescale images 0 - 1 (0.5 * x + 0.5 maps [-1, 1] onto [0, 1]; the
    # generator's output range is assumed, not checked here).
    samples = 0.5 * self.generator.predict(latent) + 0.5

    fig, axs = plt.subplots(rows, cols)
    # axs.flat walks the grid row by row, matching sample order.
    for idx, ax in enumerate(axs.flat):
        ax.imshow(samples[idx, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("images/mnist_%d.png" % epoch)
    plt.close()
Example 24
Project: Keras-GAN   Author: eriklindernoren   File: context_encoder.py    License: MIT License 6 votes vote down vote up
def sample_images(self, epoch, imgs):
    """Save a 3x6 grid to images/<epoch>.png: original, masked and filled-in images.

    Uses the first 6 entries of imgs. self.mask_randomly returns the masked
    batch, a second value unused here, and per-image (y1, y2, x1, x2) mask
    bounds; the generator's prediction is pasted into that window.
    """
    rows, cols = 3, 6

    masked_imgs, _missing_parts, (y1, y2, x1, x2) = self.mask_randomly(imgs)
    gen_missing = self.generator.predict(masked_imgs)

    # Apply the 0.5 * x + 0.5 display rescale to all three batches.
    imgs, masked_imgs, gen_missing = (0.5 * batch + 0.5 for batch in (imgs, masked_imgs, gen_missing))

    fig, axs = plt.subplots(rows, cols)
    for col in range(cols):
        filled_in = imgs[col].copy()
        filled_in[y1[col]:y2[col], x1[col]:x2[col], :] = gen_missing[col]
        panels = (imgs[col, :, :], masked_imgs[col, :, :], filled_in)
        for row, panel in enumerate(panels):
            axs[row, col].imshow(panel)
            axs[row, col].axis('off')
    fig.savefig("images/%d.png" % epoch)
    plt.close()
Example 25
Project: Keras-GAN   Author: eriklindernoren   File: ccgan.py    License: MIT License 6 votes vote down vote up
def sample_images(self, epoch, imgs):
    """Save a 3x6 grid to images/<epoch>.png: original, masked and generated images.

    Uses channel 0 of the first 6 entries of each batch, drawn in grayscale.
    """
    rows, cols = 3, 6

    masked_imgs = self.mask_randomly(imgs)
    gen_imgs = self.generator.predict(masked_imgs)

    # (x + 1.0) * 0.5 maps [-1, 1] onto [0, 1] for display.
    imgs, masked_imgs, gen_imgs = ((batch + 1.0) * 0.5 for batch in (imgs, masked_imgs, gen_imgs))

    # Replace negative generated pixels with 0.
    gen_imgs = np.where(gen_imgs < 0, 0, gen_imgs)

    fig, axs = plt.subplots(rows, cols)
    for col in range(cols):
        for row, batch in enumerate((imgs, masked_imgs, gen_imgs)):
            axs[row, col].imshow(batch[col, :, :, 0], cmap='gray')
            axs[row, col].axis('off')
    fig.savefig("images/%d.png" % epoch)
    plt.close()
Example 26
Project: Keras-GAN   Author: eriklindernoren   File: bigan.py    License: MIT License 6 votes vote down vote up
def sample_interval(self, epoch):
    """Save an r x c grid of generator samples to images/mnist_<epoch>.png."""
    r, c = 5, 5
    # One latent vector per grid cell. This was a hard-coded 25, which would
    # silently mismatch the grid if r or c changed.
    z = np.random.normal(size=(r * c, self.latent_dim))
    gen_imgs = self.generator.predict(z)

    # Rescale images 0 - 1 (assumes generator output in [-1, 1] — confirm).
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig("images/mnist_%d.png" % epoch)
    plt.close(fig)
Example 27
Project: Keras-GAN   Author: eriklindernoren   File: cgan.py    License: MIT License 6 votes vote down vote up
def sample_images(self, epoch):
    """Save a 2x5 grid of class-conditional samples, one per label, to images/<epoch>.png."""
    r, c = 2, 5
    # 100 is presumably the generator's latent size — TODO confirm it matches the model.
    noise = np.random.normal(0, 1, (r * c, 100))
    # One label per grid cell: 0 .. r*c-1, i.e. the digits 0-9 for a 2x5 grid.
    sampled_labels = np.arange(0, r * c).reshape(-1, 1)

    gen_imgs = self.generator.predict([noise, sampled_labels])

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            # Index the scalar: sampled_labels[cnt] is a 1-element array, and
            # formatting it with %d is deprecated in NumPy >= 1.25.
            axs[i, j].set_title("Digit: %d" % sampled_labels[cnt, 0])
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig("images/%d.png" % epoch)
    plt.close(fig)
Example 28
Project: Keras-GAN   Author: eriklindernoren   File: pixelda.py    License: MIT License 6 votes vote down vote up
def sample_images(self, epoch):
    """Save a 2x5 grid to images/<epoch>.png: domain-A images, then their translations.

    load_data is asked for 5 images. Assuming it returns exactly 5, the top
    row shows the originals and the bottom row the generator's output.
    """
    rows, cols = 2, 5

    imgs_A, _ = self.data_loader.load_data(domain="A", batch_size=5)

    # Translate images to the other domain
    fake_B = self.generator.predict(imgs_A)

    # Stack originals then translations, and rescale to 0 - 1.
    gen_imgs = 0.5 * np.concatenate([imgs_A, fake_B]) + 0.5

    fig, axs = plt.subplots(rows, cols)
    # axs.flat walks the grid row by row, matching the stacking order.
    for idx, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[idx])
        ax.axis('off')
    fig.savefig("images/%d.png" % (epoch))
    plt.close()
Example 29
Project: Keras-GAN   Author: eriklindernoren   File: wgan.py    License: MIT License 6 votes vote down vote up
def sample_images(self, epoch):
    """Save a 5x5 grid of generator samples to images/mnist_<epoch>.png."""
    rows, cols = 5, 5
    latent = np.random.normal(0, 1, (rows * cols, self.latent_dim))
    # Rescale images 0 - 1 (0.5 * x + 0.5 maps [-1, 1] onto [0, 1]; the
    # generator's output range is assumed, not checked here).
    samples = 0.5 * self.generator.predict(latent) + 0.5

    fig, axs = plt.subplots(rows, cols)
    # axs.flat walks the grid row by row, matching sample order.
    for idx, ax in enumerate(axs.flat):
        ax.imshow(samples[idx, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("images/mnist_%d.png" % epoch)
    plt.close()
Example 30
Project: Keras-GAN   Author: eriklindernoren   File: wgan_gp.py    License: MIT License 6 votes vote down vote up
def sample_images(self, epoch):
    """Save a 5x5 grid of generator samples to images/mnist_<epoch>.png."""
    rows, cols = 5, 5
    latent = np.random.normal(0, 1, (rows * cols, self.latent_dim))
    # Rescale images 0 - 1 (0.5 * x + 0.5 maps [-1, 1] onto [0, 1]; the
    # generator's output range is assumed, not checked here).
    samples = 0.5 * self.generator.predict(latent) + 0.5

    fig, axs = plt.subplots(rows, cols)
    # axs.flat walks the grid row by row, matching sample order.
    for idx, ax in enumerate(axs.flat):
        ax.imshow(samples[idx, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("images/mnist_%d.png" % epoch)
    plt.close()