Python matplotlib.pylab.close() Examples

The following are 30 code examples of matplotlib.pylab.close(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pylab, or try the search function.
Example #1
Source File: utils.py    From hands-detection with MIT License    7 votes
def visualize_voxel_spectral(points, vis_size=128):
  """Render a voxel grid as a colour-mapped isosurface image.

  Args:
    points: 3-D array of voxel values (presumably occupancies). Values are
      rounded with ``np.rint``, axes 0 and 2 are swapped, and the surface is
      extracted at level 0 with voxel spacing 0.1.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    # Newer scikit-image returns (verts, faces, normals, values); older
    # releases return only (verts, faces). Keep the first two either way.
    verts, faces = measure.marching_cubes(
        points, 0, spacing=(0.1, 0.1, 0.1))[:2]
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #2
Source File: utils.py    From models with Apache License 2.0    6 votes
def visualize_voxel_spectral(points, vis_size=128):
  """Render a voxel grid as a colour-mapped isosurface image.

  Args:
    points: 3-D array of voxel values (presumably occupancies). Values are
      rounded with ``np.rint``, axes 0 and 2 are swapped, and the surface is
      extracted at level 0 with voxel spacing 0.1.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    if hasattr(measure, 'marching_cubes_classic'):
      verts, faces = measure.marching_cubes_classic(
          points, 0, spacing=(0.1, 0.1, 0.1))
    else:
      # marching_cubes_classic was removed in scikit-image 0.19; its
      # replacement is marching_cubes(method='lorensen'), which also returns
      # normals and values that are not needed here.
      verts, faces = measure.marching_cubes(
          points, 0, spacing=(0.1, 0.1, 0.1), method='lorensen')[:2]
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #3
Source File: utils.py    From Gun-Detector with Apache License 2.0    6 votes
def visualize_voxel_spectral(points, vis_size=128):
  """Render a voxel grid as a colour-mapped isosurface image.

  Args:
    points: 3-D array of voxel values (presumably occupancies). Values are
      rounded with ``np.rint``, axes 0 and 2 are swapped, and the surface is
      extracted at level 0 with voxel spacing 0.1.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    if hasattr(measure, 'marching_cubes_classic'):
      verts, faces = measure.marching_cubes_classic(
          points, 0, spacing=(0.1, 0.1, 0.1))
    else:
      # marching_cubes_classic was removed in scikit-image 0.19; its
      # replacement is marching_cubes(method='lorensen'), which also returns
      # normals and values that are not needed here.
      verts, faces = measure.marching_cubes(
          points, 0, spacing=(0.1, 0.1, 0.1), method='lorensen')[:2]
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #4
Source File: utils.py    From yolo_v2 with Apache License 2.0    6 votes
def visualize_voxel_spectral(points, vis_size=128):
  """Render a voxel grid as a colour-mapped isosurface image.

  Args:
    points: 3-D array of voxel values (presumably occupancies). Values are
      rounded with ``np.rint``, axes 0 and 2 are swapped, and the surface is
      extracted at level 0 with voxel spacing 0.1.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    if hasattr(measure, 'marching_cubes_classic'):
      verts, faces = measure.marching_cubes_classic(
          points, 0, spacing=(0.1, 0.1, 0.1))
    else:
      # marching_cubes_classic was removed in scikit-image 0.19; its
      # replacement is marching_cubes(method='lorensen'), which also returns
      # normals and values that are not needed here.
      verts, faces = measure.marching_cubes(
          points, 0, spacing=(0.1, 0.1, 0.1), method='lorensen')[:2]
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #5
Source File: utils.py    From g-tensorflow-models with Apache License 2.0    6 votes
def visualize_voxel_spectral(points, vis_size=128):
  """Render a voxel grid as a colour-mapped isosurface image.

  Args:
    points: 3-D array of voxel values (presumably occupancies). Values are
      rounded with ``np.rint``, axes 0 and 2 are swapped, and the surface is
      extracted at level 0 with voxel spacing 0.1.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    if hasattr(measure, 'marching_cubes_classic'):
      verts, faces = measure.marching_cubes_classic(
          points, 0, spacing=(0.1, 0.1, 0.1))
    else:
      # marching_cubes_classic was removed in scikit-image 0.19; its
      # replacement is marching_cubes(method='lorensen'), which also returns
      # normals and values that are not needed here.
      verts, faces = measure.marching_cubes(
          points, 0, spacing=(0.1, 0.1, 0.1), method='lorensen')[:2]
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #6
Source File: demo_ui.py    From spriteworld with Apache License 2.0    6 votes
def _setup_callbacks(self):
    """Install the UI's default matplotlib callbacks.

    * Releasing the Escape key closes ``self._fig`` and calls ``sys.exit()``.
    * matplotlib's default keyboard shortcuts are disabled by blanking their
      ``keymap.*`` rcParams (a global setting, not per-figure).
    * The ``matplotlib`` logger is set to WARNING.
    """

    # Pressing escape should stop the UI
    def _onkeypress(event):
      if event.key == 'escape':
        # Stop UI
        logging.info('Pressed escape, stopping UI.')
        plt.close(self._fig)
        sys.exit()

    self._fig.canvas.mpl_connect('key_release_event', _onkeypress)

    # Disable default keyboard shortcuts. Not every key exists in every
    # matplotlib release (e.g. 'keymap.all_axes' was removed), and assigning
    # an unknown rcParam raises KeyError, so only blank the keys that exist.
    for key in ('keymap.fullscreen', 'keymap.home', 'keymap.back',
                'keymap.forward', 'keymap.pan', 'keymap.zoom', 'keymap.save',
                'keymap.quit', 'keymap.grid', 'keymap.yscale', 'keymap.xscale',
                'keymap.all_axes'):
      if key in plt.rcParams:
        plt.rcParams[key] = ''

    # Disable logging of some matplotlib events
    log.getLogger('matplotlib').setLevel('WARNING')
Example #7
Source File: drawing.py    From BIRL with BSD 3-Clause "New" or "Revised" License    6 votes
def export_figure(path_fig, fig):
    """ export the figure and close it afterwards

    The figure margins are removed (``subplots_adjust`` to the full canvas)
    before saving, and the figure is closed even if saving fails.

    :param str path_fig: path to the new figure image; a bare file name is
        saved into the current working directory
    :param fig: object
    :raises AssertionError: if the target folder does not exist

    >>> path_fig = './sample_figure.jpg'
    >>> export_figure(path_fig, plt.figure())
    >>> os.remove(path_fig)
    """
    # A bare file name has an empty dirname, which means the current folder;
    # os.path.isdir('') is False, so map it to '.' explicitly.
    folder = os.path.dirname(path_fig) or '.'
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept for callers that already catch it.
    if not os.path.isdir(folder):
        raise AssertionError('missing folder "%s"' % folder)
    try:
        fig.subplots_adjust(left=0., right=1., top=1., bottom=0.)
        logging.debug('exporting Figure: %s', path_fig)
        fig.savefig(path_fig)
    finally:
        plt.close(fig)
Example #8
Source File: prod_basis.py    From pyscf with Apache License 2.0    6 votes
def generate_png_chess_dp_vertex(self):
    """Produce pictures of the dominant product vertices in a chessboard convention.

    For every matrix returned by ``self.get_dp_vertex_doubly_sparse()`` a
    heat-map with a colorbar is saved as ``chess-v-NNNNNN.png`` (NNNNNN is
    the zero-based index) in the current working directory. Interactive
    mode is switched off (and left off) so no windows are shown.
    """
    import matplotlib.pylab as plt
    plt.ioff()
    dab2v = self.get_dp_vertex_doubly_sparse()
    for i, ab in enumerate(dab2v):
        fname = "chess-v-{:06d}.png".format(i)
        print('Matrix No.#{}, Size: {}, Type: {}'.format(i+1, ab.shape, type(ab)), fname)
        # Entries may be dense ndarrays or objects with .toarray() (presumably
        # scipy.sparse); only densify the latter. Comparing type(ab) with a
        # string is always unequal, so that test cannot distinguish them.
        if hasattr(ab, 'toarray'):
            ab = ab.toarray()
        fig = plt.figure()
        try:
            ax = fig.add_subplot(1, 1, 1)
            ax.set_aspect('equal')
            im = ax.imshow(ab, interpolation='nearest', cmap=plt.cm.ocean)
            fig.colorbar(im, ax=ax)
            fig.savefig(fname)
        finally:
            plt.close(fig)
Example #9
Source File: plotting_utils.py    From fac-via-ppg with Apache License 2.0    6 votes
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix and return the figure as an image.

    Args:
        alignment: matrix drawn with ``imshow`` (origin lower-left); columns
            run along the 'Decoder timestep' axis, rows along 'Encoder timestep'.
        info: optional extra text appended below the x-axis label.

    Returns:
        Whatever ``save_figure_to_numpy(fig)`` returns for the drawn figure
        (presumably an image array; confirm against that helper).
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(alignment, aspect='auto', origin='lower',
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        xlabel = 'Decoder timestep'
        if info is not None:
            xlabel += '\n\n' + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Encoder timestep')
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures in pyplot.
        plt.close(fig)
Example #10
Source File: utils.py    From object_detection_with_tensorflow with MIT License    6 votes
def visualize_voxel_spectral(points, vis_size=128):
  """Render a voxel grid as a colour-mapped isosurface image.

  Args:
    points: 3-D array of voxel values (presumably occupancies). Values are
      rounded with ``np.rint``, axes 0 and 2 are swapped, and the surface is
      extracted at level 0 with voxel spacing 0.1.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    if hasattr(measure, 'marching_cubes_classic'):
      verts, faces = measure.marching_cubes_classic(
          points, 0, spacing=(0.1, 0.1, 0.1))
    else:
      # marching_cubes_classic was removed in scikit-image 0.19; its
      # replacement is marching_cubes(method='lorensen'), which also returns
      # normals and values that are not needed here.
      verts, faces = measure.marching_cubes(
          points, 0, spacing=(0.1, 0.1, 0.1), method='lorensen')[:2]
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #11
Source File: camera_test.py    From camera.py with MIT License    6 votes
def calibrate_division_model_test():
    """Visually check ``camera.calibrate_division_model`` on annotated lines.

    Loads ``test/kamera2.png`` and the point annotations in
    ``test/kamera2_lines.xml``, where every ``points_per_line`` consecutive
    points belong to one line. It calibrates a division model from them.
    For each line it then plots the annotated points (green), the
    ``camera.fit_line`` fit (yellow) and the ``c.undistort`` result (red)
    over the grayscale image. ``plt.show()`` may block until the window is
    closed, depending on the backend.
    """
    img = rgb2gray(plt.imread('test/kamera2.png'))
    y0 = np.array(img.shape)[::-1][np.newaxis].T / 2.
    z_n = np.linalg.norm(np.array(img.shape) / 2.)
    points = pilab_annotate_load('test/kamera2_lines.xml')
    points_per_line = 5
    # Floor division: ``/`` yields a float on Python 3, which range() rejects.
    # Trailing points that do not fill a whole line are ignored.
    num_lines = points.shape[0] // points_per_line
    lines_coords = np.array([points[i * points_per_line:(i + 1) * points_per_line]
                             for i in range(num_lines)])
    c = camera.calibrate_division_model(lines_coords, y0, z_n)

    import matplotlib.cm as cm
    fig = plt.figure()
    plt.imshow(img, cmap=cm.gray)
    for line in range(num_lines):
        x = lines_coords[line, :, 0]
        plt.plot(x, lines_coords[line, :, 1], 'g')
        mc = camera.fit_line(lines_coords[line].T)
        plt.plot(x, mc[0] * x + mc[1], 'y')
        xy = c.undistort(lines_coords[line].T)
        plt.plot(xy[0, :], xy[1, :], 'r')
    plt.show()
    plt.close(fig)
Example #12
Source File: plotting_utils.py    From nonparaSeq2seqVC_code with MIT License    6 votes
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix and return the figure as an image.

    Args:
        alignment: matrix drawn with ``imshow`` (origin lower-left); columns
            run along the 'Decoder timestep' axis, rows along 'Encoder timestep'.
        info: optional extra text appended below the x-axis label.

    Returns:
        Whatever ``save_figure_to_numpy(fig)`` returns for the drawn figure
        (presumably an image array; confirm against that helper).
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(alignment, aspect='auto', origin='lower',
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        xlabel = 'Decoder timestep'
        if info is not None:
            xlabel += '\n\n' + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Encoder timestep')
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures in pyplot.
        plt.close(fig)
Example #13
Source File: resnet_wgan_gp_cifar10_train.py    From Hands-On-Generative-Adversarial-Networks-with-Keras with MIT License    6 votes
def plot_losses(losses_d, losses_g, filename):
    """Plot discriminator and generator loss curves and save them to a file.

    Args:
        losses_d: per-step discriminator losses convertible to a 2-D array;
            columns 0-3 are plotted as "losses_d", "losses_d_real",
            "losses_d_fake" and "losses_d_gp".
        losses_g: per-step generator losses, plotted as "losses_g".
        filename: destination passed to ``savefig``.
    """
    losses_d = np.array(losses_d)
    fig, axes = plt.subplots(3, 2, figsize=(8, 8))
    try:
        axes = axes.flatten()
        curves = [
            ("losses_d", losses_d[:, 0]),
            ("losses_d_real", losses_d[:, 1]),
            ("losses_d_fake", losses_d[:, 2]),
            ("losses_d_gp", losses_d[:, 3]),
            ("losses_g", losses_g),
        ]
        for ax, (title, values) in zip(axes, curves):
            ax.plot(values)
            ax.set_title(title)
        # The 3x2 grid has more cells than curves; hide the unused ones
        # instead of saving an empty framed axes.
        for ax in axes[len(curves):]:
            ax.set_axis_off()
        fig.tight_layout()
        fig.savefig(filename)
    finally:
        plt.close(fig)
Example #14
Source File: metrics.py    From tacotron2 with BSD 3-Clause "New" or "Revised" License    6 votes
def plot_alignment(alignments, text, _id, global_step, path):
    """Save one alignment heat-map per layer into a single PNG.

    Args:
        alignments: sequence of matrices; entry ``i`` is drawn in row ``i``
            of a one-column grid titled "layer i+1", with the
            'Decoder timestep' x-axis and 'Encoder timestep' y-axis.
        text: input text shown in the figure title.
        _id: record identifier shown in the figure title.
        global_step: training step shown in the figure title.
        path: destination file, written in PNG format.
    """
    num_alignment = len(alignments)
    fig = plt.figure(figsize=(12, 16))
    try:
        for i, alignment in enumerate(alignments, start=1):
            ax = fig.add_subplot(num_alignment, 1, i)
            im = ax.imshow(
                alignment,
                aspect='auto',
                origin='lower',
                interpolation='none')
            fig.colorbar(im, ax=ax)
            ax.set_xlabel('Decoder timestep')
            ax.set_ylabel('Encoder timestep')
            ax.set_title("layer {}".format(i))
        fig.subplots_adjust(wspace=0.4, hspace=0.6)
        fig.suptitle(f"record ID: {_id}\nglobal step: {global_step}\ninput text: {str(text)}")
        fig.savefig(path, format='png')
    finally:
        # Close this figure explicitly, even on error, instead of whichever
        # figure happens to be current.
        plt.close(fig)
Example #15
Source File: plot.py    From Tacotron2-PyTorch with MIT License    6 votes
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix and return the figure as an image.

    Args:
        alignment: matrix drawn with ``imshow`` (origin lower-left); columns
            run along the 'Decoder timestep' axis, rows along 'Encoder timestep'.
        info: optional extra text appended below the x-axis label.

    Returns:
        Whatever ``save_figure_to_numpy(fig)`` returns for the drawn figure
        (presumably an image array; confirm against that helper).
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(alignment, aspect='auto', origin='lower',
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        xlabel = 'Decoder timestep'
        if info is not None:
            xlabel += '\n\n' + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Encoder timestep')
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures in pyplot.
        plt.close(fig)
Example #16
Source File: utils.py    From object_detection_kitti with Apache License 2.0    6 votes
def visualize_voxel_spectral(points, vis_size=128):
  """Render a voxel grid as a colour-mapped isosurface image.

  Args:
    points: 3-D array of voxel values (presumably occupancies). Values are
      rounded with ``np.rint``, axes 0 and 2 are swapped, and the surface is
      extracted at level 0 with voxel spacing 0.1.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    if hasattr(measure, 'marching_cubes_classic'):
      verts, faces = measure.marching_cubes_classic(
          points, 0, spacing=(0.1, 0.1, 0.1))
    else:
      # marching_cubes_classic was removed in scikit-image 0.19; its
      # replacement is marching_cubes(method='lorensen'), which also returns
      # normals and values that are not needed here.
      verts, faces = measure.marching_cubes(
          points, 0, spacing=(0.1, 0.1, 0.1), method='lorensen')[:2]
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #17
Source File: variable_describe.py    From convis with GNU General Public License v3.0    6 votes
def _plot_to_string():
    """Serialize the current pyplot figure to a URL-quoted base64 string.

    The current figure is saved in memory (default ``savefig`` format) and
    then closed. The image bytes are base64-encoded with a newline every 76
    characters and URL-quoted.
    """
    try:
        from StringIO import StringIO  # Python 2
    except ImportError:
        from io import BytesIO as StringIO  # Python 3
    try:
        from urllib import quote  # Python 2
    except ImportError:
        from urllib.parse import quote  # Python 3
    import base64
    import matplotlib.pylab as plt
    # base64.encodestring was removed in Python 3.9; encodebytes is the
    # identical replacement (Python 2 only has encodestring).
    encode = getattr(base64, 'encodebytes', None) or base64.encodestring
    imgdata = StringIO()
    plt.savefig(imgdata)
    plt.close()
    # getvalue() returns the full buffer on both StringIO and BytesIO.
    image = encode(imgdata.getvalue())
    return str(quote(image))
Example #18
Source File: data_augmentation.py    From ConvNetQuake with MIT License    6 votes
def plot_true_and_augmented_data(sample,noised_sample,label,n_examples):
    """Save per-channel waveform plots of a window and its augmented copy.

    Both PDFs go under the directory part of ``FLAGS.output``:
    ``augmented_data/augmented_NNN.pdf`` for ``noised_sample`` and
    ``true_data/true__NNN.pdf`` for ``sample`` (NNN = ``n_examples``).

    Args:
        sample: 2-D array (samples, channels) of the original window.
        noised_sample: 2-D array (samples, channels) of the augmented window.
        label: cluster id shown in the plot title.
        n_examples: window index used in the title and file names.
    """
    output_dir = os.path.split(FLAGS.output)[0]
    # Save augmented data
    _save_channel_plot(
        noised_sample, label, n_examples,
        os.path.join(output_dir, "augmented_data",
                     'augmented_{:03d}.pdf'.format(n_examples)))
    # Save true data
    _save_channel_plot(
        sample, label, n_examples,
        os.path.join(output_dir, "true_data",
                     'true__{:03d}.pdf'.format(n_examples)))


def _save_channel_plot(waveform, label, n_examples, path):
    """Plot each column of ``waveform`` in its own row and save to ``path``."""
    n_channels = waveform.shape[1]
    # Keep the original three-row layout, growing it only for wider inputs
    # (more than three channels used to raise IndexError). squeeze=False keeps
    # a 2-D axes array even for a single row.
    fig, axes = plt.subplots(max(3, n_channels), 1, squeeze=False)
    try:
        for t in range(n_channels):
            ax = axes[t, 0]
            ax.plot(waveform[:, t])
            ax.set_xlabel('time (samples)')
            ax.set_ylabel('amplitude')
        axes[0, 0].set_title('window {:03d}, cluster_id: {}'.format(n_examples, label))
        fig.savefig(path)
    finally:
        # Close this new figure explicitly. The former plt.clf() call cleared
        # whatever figure was current, possibly one owned by the caller.
        plt.close(fig)
Example #19
Source File: PlotComps.py    From refinery with MIT License    6 votes
def plotModelInNewFigure(jobpath, hmodel, args):
  """Plot the components of ``hmodel`` in a new pylab figure.

  If ``args.doPlotData`` is set, the job's data is plotted first. The plot
  routine is then chosen by observation-model name, falling back to
  ``args.dataName``:
    * 'ZMGauss' with ``hmodel.obsModel.D > 2``: GaussViz.plotCovMatFromHModel.
    * any other 'Gauss': GaussViz.plotGauss2DFromHModel.
    * dataName containing 'bars': the new figure is closed and
      BarsViz.plotBarsFromHModel is called (presumably it draws its own
      figure). Ground-truth data is passed when ``args.doPlotTruth`` is set.

  Raises:
    NotImplementedError: for any other data/observation-model combination.
      The figure created here is closed before raising.
  """
  figHandle = pylab.figure()
  if args.doPlotData:
    Data = loadData(jobpath)
    plotData(Data)

  if hmodel.getObsModelName().count('ZMGauss') and hmodel.obsModel.D > 2:
    bnpy.viz.GaussViz.plotCovMatFromHModel(hmodel)
  elif hmodel.getObsModelName().count('Gauss'):
    bnpy.viz.GaussViz.plotGauss2DFromHModel(hmodel)
  elif args.dataName.lower().count('bars') > 0:
    pylab.close(figHandle)
    if args.doPlotTruth:
      Data = loadData(jobpath)
    else:
      Data = None
    bnpy.viz.BarsViz.plotBarsFromHModel(hmodel, Data=Data, 
                                        sortBySize=args.doSort, doShowNow=False)
  else:
    # Don't leak the (possibly data-filled) figure when bailing out.
    pylab.close(figHandle)
    raise NotImplementedError('Unrecognized data/obsmodel combo')
Example #20
Source File: diagnostics.py    From photometrypipeline with GNU General Public License v3.0    6 votes
def append_website(self, filename, content,
                   replace_from='X?!do not replace anything!?X',
                   keep_at='</BODY>',):
    """Insert content into an existing HTML file, rewriting it in place.

    Every line containing ``replace_from`` is dropped and starts a deletion
    run; lines are then dropped until a line containing ``keep_at``.
    ``content`` (a string or an iterable of strings) is written immediately
    before every line containing ``keep_at``, and that line itself is kept.
    With the default ``replace_from`` sentinel nothing is removed and
    ``content`` is simply inserted before ``</BODY>``.

    The file is read and written in the platform default text encoding.
    """
    # read existing code; ``with`` closes the handle deterministically
    with open(filename, 'r') as infile:
        existing_html = infile.readlines()

    # insert content into existing html; the output file is closed even if
    # a write fails
    with open(filename, 'w') as outf:
        delete = False
        for line in existing_html:
            if replace_from in line:
                delete = True
                continue
            if keep_at in line:
                outf.writelines(content)
                delete = False
            if not delete:
                outf.write(line)
Example #21
Source File: helpers.py    From NeMo with Apache License 2.0    6 votes
def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
    """Scatter-plot gate targets against predictions and return the image.

    Args:
        gate_targets: sequence of target gate values, plotted green '+'.
        gate_outputs: sequence of predicted gate values, plotted red '.'.

    Returns:
        Whatever ``save_figure_to_numpy(fig)`` returns for the drawn figure
        (presumably an image array; confirm against that helper).
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.scatter(
            range(len(gate_targets)), gate_targets, alpha=0.5, color='green', marker='+', s=1, label='target',
        )
        ax.scatter(
            range(len(gate_outputs)), gate_outputs, alpha=0.5, color='red', marker='.', s=1, label='predicted',
        )

        ax.set_xlabel("Frames (Green target, Red predicted)")
        ax.set_ylabel("Gate State")
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures in pyplot.
        plt.close(fig)
Example #22
Source File: helpers.py    From NeMo with Apache License 2.0    5 votes
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram with a colorbar and return the figure as an image.

    Args:
        spectrogram: matrix drawn with ``imshow`` (origin lower-left);
            columns run along the "Frames" axis, rows along "Channels".

    Returns:
        Whatever ``save_figure_to_numpy(fig)`` returns for the drawn figure
        (presumably an image array; confirm against that helper).
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures in pyplot.
        plt.close(fig)
Example #23
Source File: utils.py    From models with Apache License 2.0    5 votes
def visualize_voxel_scatter(points, vis_size=128):
  """Render the non-zero voxels of a grid as a 3-D scatter-plot image.

  Args:
    points: 3-D array of voxel values. Values are rounded with ``np.rint``
      and axes 0 and 2 are swapped; every voxel that is non-zero afterwards
      is drawn as one point at its integer index.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    ax = fig.add_subplot(111, projection='3d')
    # Indices of all non-zero voxels in row-major (i, j, k) order, computed
    # in one vectorized call instead of a Python-level triple loop.
    x, y, z = np.nonzero(points)
    ax.scatter3D(x, y, z)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #24
Source File: prod_basis.py    From pyscf with Apache License 2.0    5 votes
def generate_png_spy_dp_vertex(self):
    """Produce black-and-white sparsity-pattern (``spy``) pictures of the dominant product vertices.

    For every matrix returned by ``self.get_dp_vertex_doubly_sparse()`` a
    file ``spy-v-NNNNNN.png`` (NNNNNN is the zero-based index) is written to
    the current working directory, and its name is printed. Interactive mode
    is switched off (and left off).

    Returns:
      0
    """
    import matplotlib.pyplot as plt
    plt.ioff()
    dab2v = self.get_dp_vertex_doubly_sparse()
    for i, ab2v in enumerate(dab2v):
      # Accept dense arrays as well as objects providing toarray()
      # (presumably scipy.sparse matrices).
      matrix = ab2v.toarray() if hasattr(ab2v, 'toarray') else ab2v
      fname = "spy-v-{:06d}.png".format(i)
      print(fname)
      # Draw on a fresh figure so the caller's current figure is neither
      # drawn on nor closed.
      fig = plt.figure()
      try:
        fig.add_subplot(1, 1, 1).spy(matrix)
        fig.savefig(fname, bbox_inches='tight')
      finally:
        plt.close(fig)
    return 0
Example #25
Source File: utils.py    From Hands-On-Generative-Adversarial-Networks-with-Keras with MIT License    5 votes
def plot_losses(losses_d, losses_g, filename):
    """Plot discriminator and generator loss curves and save them to a file.

    Args:
        losses_d: per-step discriminator losses convertible to a 2-D array;
            columns 0-2 are plotted as "losses_d", "losses_d_real" and
            "losses_d_fake".
        losses_g: per-step generator losses, plotted as "losses_g".
        filename: destination passed to ``savefig``.
    """
    losses_d = np.array(losses_d)
    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    try:
        curves = [
            ("losses_d", losses_d[:, 0]),
            ("losses_d_real", losses_d[:, 1]),
            ("losses_d_fake", losses_d[:, 2]),
            ("losses_g", losses_g),
        ]
        for ax, (title, values) in zip(axes.flatten(), curves):
            ax.plot(values)
            ax.set_title(title)
        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Close this figure explicitly, even if plotting or saving fails.
        plt.close(fig)
Example #26
Source File: utils.py    From object_detection_kitti with Apache License 2.0    5 votes
def visualize_voxel_scatter(points, vis_size=128):
  """Render the non-zero voxels of a grid as a 3-D scatter-plot image.

  Args:
    points: 3-D array of voxel values. Values are rounded with ``np.rint``
      and axes 0 and 2 are swapped; every voxel that is non-zero afterwards
      is drawn as one point at its integer index.
    vis_size: side length in pixels of the square output image.

  Returns:
    A ``(vis_size, vis_size, 3)`` uint8 RGB array.
  """
  points = np.rint(points)
  points = np.swapaxes(points, 0, 2)
  fig = p.figure(figsize=(1, 1), dpi=vis_size)
  try:
    ax = fig.add_subplot(111, projection='3d')
    # Indices of all non-zero voxels in row-major (i, j, k) order, computed
    # in one vectorized call instead of a Python-level triple loop.
    x, y, z = np.nonzero(points)
    ax.scatter3D(x, y, z)
    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # canvas.tostring_rgb() (removed in matplotlib 3.10) and binary-mode
    # np.fromstring are deprecated; read the RGBA buffer and drop alpha.
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(
        vis_size, vis_size, 4)
    data = rgba[..., :3].copy()  # contiguous, writable copy
  finally:
    # Close only this function's figure (close('all') would also destroy
    # the caller's figures), even when rendering fails.
    p.close(fig)
  return data
Example #27
Source File: speech_utils.py    From python-dlpy with Apache License 2.0    5 votes
def convert_one_audio_file_to_specgram(local_audio_file, converted_local_png_file):
    '''
    Convert a local audio file into a png format with spectrogram.

    The spectrogram fills the whole image (no margins, axes hidden) and is
    saved at 300 dpi.

    Parameters
    ----------
    local_audio_file : string
        Local location to the audio file to be converted.

    converted_local_png_file : string
        Local location to store the converted audio file

    Returns
    -------
    None

    Raises
    ------
    DLPyError
        If soundfile or matplotlib cannot be imported.

    '''

    try:
        import soundfile as sf
    except ImportError:
        raise DLPyError('cannot import soundfile')
    try:
        import matplotlib.pylab as plt
    except ImportError:
        raise DLPyError('cannot import matplotlib')

    data, sampling_rate = sf.read(local_audio_file)

    fig, ax = plt.subplots(1)
    try:
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        ax.axis('off')
        ax.specgram(x=data, Fs=sampling_rate)
        ax.axis('off')
        # savefig's ``frameon`` keyword was removed in matplotlib 3.3. The
        # previously passed string 'false' was truthy and so equal to the
        # default; omitting it keeps the same output.
        fig.savefig(converted_local_png_file, dpi=300)
    finally:
        # this is the key to avoid mem leaking in notebook; done in
        # ``finally`` so the figure is released even if plotting fails
        plt.ioff()
        plt.close(fig)
Example #28
Source File: plot.py    From Tacotron2-PyTorch with MIT License    5 votes
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram with a colorbar and return the figure as an image.

    Args:
        spectrogram: matrix drawn with ``imshow`` (origin lower-left);
            columns run along the "Frames" axis, rows along "Channels".

    Returns:
        Whatever ``save_figure_to_numpy(fig)`` returns for the drawn figure
        (presumably an image array; confirm against that helper).
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures in pyplot.
        plt.close(fig)
Example #29
Source File: helpers.py    From NeMo with Apache License 2.0    5 votes
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix and return the figure as an image.

    Args:
        alignment: matrix drawn with ``imshow`` (origin lower-left); columns
            run along the 'Decoder timestep' axis, rows along 'Encoder timestep'.
        info: optional extra text appended below the x-axis label.

    Returns:
        Whatever ``save_figure_to_numpy(fig)`` returns for the drawn figure
        (presumably an image array; confirm against that helper).
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
        fig.colorbar(im, ax=ax)
        xlabel = 'Decoder timestep'
        if info is not None:
            xlabel += '\n\n' + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Encoder timestep')
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so repeated calls do not
        # accumulate open figures in pyplot.
        plt.close(fig)
Example #30
Source File: show_leaned_filters.py    From nnabla-examples with Apache License 2.0    5 votes
def show():
    """Interactively display a heat-map for every convolution weight.

    Parameters are loaded from ``args.model_load_path``. For each parameter
    whose name contains "conv/W", the 4-D weight (n, m, k0, k1) is reshaped
    to an ``n x (m*k0*k1)`` matrix (filters by flattened channel/kernel),
    shown with ``pcolor``, and the function waits for the user to press
    Enter before closing the figure and moving on.
    """
    args = get_args()

    # Load model
    nn.load_parameters(args.model_load_path)
    params = nn.get_parameters()

    # ``raw_input`` only exists on Python 2; Python 3 renamed it ``input``.
    try:
        wait_for_key = raw_input
    except NameError:
        wait_for_key = input

    # Show heatmap
    for name, param in params.items():
        # SSL only on convolution weights
        if "conv/W" not in name:
            continue
        print(name)
        n, m, k0, k1 = param.d.shape
        w_matrix = param.d.reshape((n, m * k0 * k1))
        # Filter x Channel heatmap

        fig, ax = plt.subplots()
        try:
            ax.set_title("{} with shape {} \n Filter x (Channel x Heigh x Width)".format(
                name, (n, m, k0, k1)))
            heatmap = ax.pcolor(w_matrix)
            fig.colorbar(heatmap)

            plt.pause(0.5)
            wait_for_key("Press Key")
        finally:
            # Close this heat-map's figure even if the user interrupts.
            plt.close(fig)