Python matplotlib.pyplot.close() Examples

The following code examples show how to use matplotlib.pyplot.close(). They are taken from open-source Python projects. You can vote up the examples you find useful or vote down the ones you don't.

Example 1
Project: dc_tts   Author: Kyubyong   File: utils.py    Apache License 2.0 6 votes vote down vote up
def plot_alignment(alignment, gs, dir=hp.logdir):
    """Plot an alignment matrix and save it as a PNG.

    Args:
      alignment: A numpy array with shape of (encoder_steps, decoder_steps).
      gs: (int) global step; used in the title and the output file name.
      dir: Output directory; created (with any missing parents) if needed.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and, unlike
    # os.mkdir, also creates missing parent directories.
    os.makedirs(dir, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        im = ax.imshow(alignment)
        fig.colorbar(im)
        # Work on this figure's own Axes/Figure instead of pyplot's implicit
        # "current" figure.
        ax.set_title('{} Steps'.format(gs))
        fig.savefig(os.path.join(dir, 'alignment_{}.png'.format(gs)), format='png')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
Example 2
Project: Kaggle-Statoil-Challenge   Author: adodd202   File: utils.py    MIT License 6 votes vote down vote up
def __init__(self, fpath, title=None, resume=False):
        """Open (or resume) a tab-separated log file.

        Args:
            fpath: Path of the log file, or None to disable file logging.
            title: Optional title; stored as '' when None.
            resume: If True, read the header row into ``self.names`` and every
                following row into ``self.numbers`` (values kept as strings),
                then reopen the file for appending. Otherwise the file is
                truncated and opened for writing.
        """
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                # Context manager: the read handle is closed even if parsing raises.
                with open(fpath, 'r') as existing:
                    self.names = existing.readline().rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for line in existing:
                        values = line.rstrip().split('\t')
                        # Index into the header so a row with more columns than
                        # the header still raises IndexError, as before.
                        for i, value in enumerate(values):
                            self.numbers[self.names[i]].append(value)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')
Example 3
Project: tensorflow-DeepFM   Author: ChenglongChen   File: main.py    MIT License 6 votes vote down vote up
def _plot_fig(train_results, valid_results, model_name):
    """Plot per-run train/valid curves over epochs and save ./fig/<model_name>.png.

    Args:
        train_results: 2-D array; row ``i`` holds the per-epoch scores of run ``i``.
        valid_results: 2-D array indexed the same way as ``train_results``.
        model_name: Plot title and output file stem.
    """
    import os

    colors = ["red", "blue", "green"]
    xs = np.arange(1, train_results.shape[1] + 1)
    fig = plt.figure()
    legends = []
    for i in range(train_results.shape[0]):
        # Cycle the palette so more than len(colors) runs no longer raises
        # IndexError; train and valid of one run share a color.
        color = colors[i % len(colors)]
        plt.plot(xs, train_results[i], color=color, linestyle="solid", marker="o")
        plt.plot(xs, valid_results[i], color=color, linestyle="dashed", marker="o")
        legends.append("train-%d" % (i + 1))
        legends.append("valid-%d" % (i + 1))
    plt.xlabel("Epoch")
    plt.ylabel("Normalized Gini")
    plt.title("%s" % model_name)
    plt.legend(legends)
    # savefig does not create directories; make sure ./fig exists.
    os.makedirs("./fig", exist_ok=True)
    plt.savefig("./fig/%s.png" % model_name)
    plt.close(fig)


# load data 
Example 4
Project: synthetic-data-tutorial   Author: theodi   File: ModelInspector.py    MIT License 6 votes vote down vote up
def mutual_information_heatmap(self, figure_filepath, attributes: List = None):
        """Save side-by-side heatmaps of pairwise mutual information.

        The left panel is computed from ``self.private_df`` and the right panel
        from ``self.synthetic_df``.

        Args:
            figure_filepath: Destination passed to ``savefig``.
            attributes: Optional column subset; when falsy, all columns are used.
        """
        private_df, synthetic_df = self.private_df, self.synthetic_df
        if attributes:
            private_df = private_df[attributes]
            synthetic_df = synthetic_df[attributes]

        matrices = (
            pairwise_attributes_mutual_information(private_df),
            pairwise_attributes_mutual_information(synthetic_df),
        )
        titles = ('Private, max=1', 'Synthetic, max=1')

        fig = plt.figure(figsize=(15, 6), dpi=120)
        fig.suptitle('Pairwise Mutual Information Comparison (Private vs Synthetic)', fontsize=20)
        panel_axes = (fig.add_subplot(121), fig.add_subplot(122))
        for panel_ax, matrix in zip(panel_axes, matrices):
            sns.heatmap(matrix, ax=panel_ax, cmap="GnBu")
        for panel_ax, panel_title in zip(panel_axes, titles):
            panel_ax.set_title(panel_title, fontsize=15)
        fig.autofmt_xdate()
        fig.tight_layout()
        # Leave headroom for the suptitle after tight_layout packed the axes.
        fig.subplots_adjust(top=0.83)

        fig.savefig(figure_filepath, bbox_inches='tight')
        plt.close(fig)
Example 5
Project: DOTA_models   Author: ringringyi   File: nav_utils.py    Apache License 2.0 6 votes vote down vote up
def save_d_at_t(outputs, global_step, output_dir, metric_summary, N):
  """Save distance to goal at all time steps.

  Concatenates ``x[0][:, :, 0]`` of every output along axis 0, plots the
  mean over axis 0 against the time-step index, saves the figure as
  ``dist_at_t_<global_step>.png`` and the raw array as
  ``dist_at_t_<global_step>.pkl`` in ``output_dir``.

  Args:
    outputs        : [gt_dist_to_goal].
    global_step    : number of iterations.
    output_dir     : output directory.
    metric_summary : to append scalars to summary (not used here).
    N              : number of outputs to process (not used here).
  """
  # np.concatenate needs a sequence; on Python 3 ``map`` is a lazy iterator,
  # so build a list. ``* 1`` keeps the original copy/upcast of each slice.
  d_at_t = np.concatenate([x[0][:, :, 0] * 1 for x in outputs], axis=0)
  fig, axes = utils.subplot(plt, (1, 1), (5, 5))
  axes.plot(np.arange(d_at_t.shape[1]), np.mean(d_at_t, axis=0), 'r.')
  axes.set_xlabel('time step')
  axes.set_ylabel('dist to next goal')
  axes.grid(True)
  file_name = os.path.join(output_dir, 'dist_at_t_{:d}.png'.format(global_step))
  with fu.fopen(file_name, 'w') as f:
    fig.savefig(f, bbox_inches='tight', transparent=True, pad_inches=0)
  file_name = os.path.join(output_dir, 'dist_at_t_{:d}.pkl'.format(global_step))
  utils.save_variables(file_name, [d_at_t], ['d_at_t'], overwrite=True)
  plt.close(fig)
  return None
Example 6
Project: smach_based_introspection_framework   Author: birlrobotics   File: visualize_dataset.py    BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_list_of_df(save_folder, list_of_df, list_of_labels):
    """Save one PNG per column, overlaying that column from every dataframe.

    Args:
        save_folder: Directory the images are written to.
        list_of_df: Dataframes; the columns of ``list_of_df[0]`` are plotted.
        list_of_labels: Legend label per dataframe, index-aligned with
            ``list_of_df``.
    """
    # print() with a single argument behaves the same on Python 2 and 3,
    # unlike the Python-2-only print statement.
    print('plot_list_of_df with save_folder %s' % save_folder)
    if len(list_of_df) == 0:
        return
    for dim in list_of_df[0].columns:
        fig, ax = plt.subplots(nrows=1, ncols=1)

        for idx, df in enumerate(list_of_df):
            ax.plot(df[dim], label=list_of_labels[idx])

        # Keep only the last 70 characters of "<folder>, <column>" in the title.
        ax.set_title('...' + (save_folder + ', ' + dim)[-70:])

        handles, labels = ax.get_legend_handles_labels()
        # Legend sits below the axes; passing it as an extra artist makes the
        # tight bounding box include it.
        lgd = ax.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, -0.1))

        fig.savefig(os.path.join(save_folder, dim.replace('.', '>')), format='png',
                    bbox_extra_artists=(lgd,), bbox_inches='tight')

        plt.close(fig)
Example 7
Project: smach_based_introspection_framework   Author: birlrobotics   File: redis_based_anomaly_classification.py    BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def plot_resampled_anomaly_df(resampled_anomaly_df):
    """Write one line plot per column of ``resampled_anomaly_df``.

    The images go into a new sub-directory of ``realtime_anomaly_plot_folder``
    named after the current timestamp. The x values are the dataframe index
    minus its first entry.
    """
    import datetime

    plot_dir = os.path.join(realtime_anomaly_plot_folder, str(datetime.datetime.now()))
    if not os.path.isdir(plot_dir):
        os.makedirs(plot_dir)

    index = resampled_anomaly_df.index
    elapsed = (index - index[0]).tolist()
    for column in resampled_anomaly_df.columns:
        rospy.loginfo("plotting %s" % column)
        fig, ax = plt.subplots(nrows=1, ncols=1)
        ax.plot(elapsed, resampled_anomaly_df[column].tolist())
        ax.set_title(column)
        # strip('.') trims any leading/trailing dots from "<column>.png".
        image_name = (column + '.png').strip('.')
        fig.savefig(os.path.join(plot_dir, image_name))
        plt.close(fig)
Example 8
Project: beta3_IRT   Author: yc14600   File: plots.py    MIT License 6 votes vote down vote up
def vis_performance(gather_prec, gather_recal, path, asd='[email protected]', vtype='nfrac'):
    """Plot row-wise mean precision/recall with std error bars and save a PDF.

    Args:
        gather_prec: DataFrame; the index gives the x values, and the mean/std
            are taken across columns (axis=1).
        gather_recal: DataFrame laid out like ``gather_prec``.
        path: Prefix prepended to the output file name.
        asd: Tag used in the 'nfrac' file name. NOTE(review): the default value
            looks like an e-mail-obfuscation artifact from the page this code
            was copied from; confirm the intended default.
        vtype: 'nfrac' or 'astd' selects the labels and the file name; any
            other value closes the figure without labelling or saving it.
    """
    fig = plt.figure()
    recall_mean = gather_recal.mean(axis=1)
    prec_mean = gather_prec.mean(axis=1)
    plt.plot(gather_recal.index, recall_mean, marker='o')
    plt.plot(gather_prec.index, prec_mean, marker='^')

    plt.errorbar(gather_recal.index, recall_mean, gather_recal.std(axis=1), linestyle='None')
    plt.errorbar(gather_prec.index, prec_mean, gather_prec.std(axis=1), linestyle='None')

    if vtype == 'nfrac':
        title = 'Precision and recall under different noise fractions'
        xlabel = 'Noise fraction (percentile)'
        xlim = None
        out_file = path + 'gathered_dnoise_performance_nfrac_' + asd + '.pdf'
    elif vtype == 'astd':
        title = 'Precision and recall under different prior SD'
        xlabel = 'Prior standard deviation of discrimination'
        xlim = (0.5, 3.25)
        out_file = path + 'gathered_dnoise_performance_asd_nfrac20.pdf'
    else:
        plt.close(fig)
        return

    plt.title(title)
    plt.xlabel(xlabel)
    if xlim is not None:
        plt.xlim(*xlim)
    plt.ylim(-0.05, 1.1)
    plt.yticks(np.arange(0, 1.2, 0.2))
    plt.legend(['Recall', 'Precision'], loc=0)
    plt.savefig(out_file)
    plt.close(fig)
Example 9
Project: multi-dimensional-topic-model   Author: LaoWang-Lab   File: plot.py    MIT License 6 votes vote down vote up
def plot_fig(json_file, save_fig=False):
    """Draw a grid of topic-count bar charts from a JSON dump.

    Reads ``H``, ``E``, ``M``, ``wot``, ``iter``, ``T`` and ``topic`` from
    ``json_file``. ``topic`` is indexed as ``[h, e, :]``. Subplot row ``e``,
    column ``h`` shows ``topic[h, e, :]`` against ``range(T)``.

    Args:
        json_file: Path of the JSON file to plot.
        save_fig: If True, save ``iter<NNN>.png`` next to ``json_file``;
            otherwise show the figure.
    """
    with open(json_file) as f:
        data = json.load(f)
    H, E, M, wot, n_iter = data['H'], data['E'], data['M'], data['wot'], data['iter']
    # squeeze=False keeps ``axs`` 2-D even when E == 1 or H == 1; otherwise
    # plt.subplots returns a 1-D array or a bare Axes and axs[e][h] fails.
    # NOTE(review): the width scales with E (rows) and the height with H
    # (columns), which looks transposed; kept as-is to preserve output.
    fig, axs = plt.subplots(E, H, figsize=(2.5 * E, 3.5 * H), sharex=True, squeeze=False)
    fig.suptitle('H:%d E:%d M:%d wot:%d iter:%d' % (H, E, M, wot, n_iter),
                 fontsize=20, fontweight='bold')
    n_het = np.array(data['topic'])
    y_limits_max = 1.05 * n_het.sum() / (E * H * wot)
    x = np.arange(data['T'])
    for e in range(E):
        for h in range(H):
            ax = axs[e][h]
            # Pass x/y by keyword: newer seaborn treats the first positional
            # argument as ``data``.
            sns.barplot(x=x, y=n_het[h, e, :], palette="Set3", ax=ax)
            ax.set_ylabel("counts")
            ax.set_title("h:%d e:%d" % (h, e))
            ax.set_ylim([0, y_limits_max])
    fig.tight_layout()
    # Leave room for the suptitle.
    fig.subplots_adjust(top=0.9)
    if save_fig:
        fig.savefig('%s/iter%03d.png' % (os.path.dirname(json_file), n_iter), dpi=72, format='png')
    else:
        plt.show()
    plt.close(fig)
Example 10
Project: reportengine   Author: NNPDF   File: figure.py    GNU General Public License v2.0 6 votes vote down vote up
def savefig(fig, *, paths, output ,suffix=''):
    """Final action to save figures, with a nice filename.

    Writes ``fig`` to every path in ``paths``, inserting ``_<suffix>`` before
    the extension when ``suffix`` is non-empty, then closes the figure.

    Returns:
        ``Figure`` wrapping the written paths, made relative to ``output``.
    """
    #Import here to avoid problems with use()
    import matplotlib.pyplot as plt

    # Normalize once, not once per path: previously the already-normalized
    # suffix was normalized again on every later iteration.
    if suffix:
        suffix = normalize_name(suffix)

    outpaths = []
    try:
        for path in paths:
            if suffix:
                path = path.with_name('_'.join((path.stem, suffix)) + path.suffix)
            # Lazy %-args: only formatted if DEBUG logging is enabled.
            log.debug("Writing figure file %s", path)

            #Numpy can produce a lot of warnings while working on producing figures
            with np.errstate(invalid='ignore'):
                fig.savefig(str(path), bbox_inches='tight')
            outpaths.append(path.relative_to(output))
    finally:
        # Release the figure even if one of the saves fails.
        plt.close(fig)
    return Figure(outpaths)
Example 11
Project: trunklucator   Author: Dumbris   File: plot_perf.py    Apache License 2.0 6 votes vote down vote up
def plot_performance(performance_history):
    """Render an accuracy-per-query-iteration chart as an HTML ``<img>`` tag.

    Args:
        performance_history: Sequence of accuracies, one per query iteration;
            the y axis is fixed to [0, 1] and formatted as a percentage.

    Returns:
        An HTML ``<img>`` snippet embedding the chart as a base64 PNG data URI.
    """
    fig, ax = plt.subplots(figsize=(8.5, 6), dpi=130)
    try:
        ax.plot(performance_history)
        ax.scatter(range(len(performance_history)), performance_history, s=13)

        ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=5, integer=True))
        ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=10))
        ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))

        ax.set_ylim(bottom=0, top=1)
        ax.grid(True)

        ax.set_title('Incremental classification accuracy')
        ax.set_xlabel('Query iteration')
        ax.set_ylabel('Classification Accuracy')

        # Save this figure explicitly. The former argument-less plt.plot()
        # added nothing, and plt.cla() right before closing was redundant.
        image = BytesIO()
        fig.savefig(image, format='png')
    finally:
        plt.close(fig)
    # b64encode emits one line; encodebytes inserts a newline every 76
    # characters, which RFC 2397 data URIs do not allow (browsers merely
    # tolerate it).
    encoded = base64.b64encode(image.getvalue()).decode()
    return ''' <img src="data:image/png;base64,{}" border="0" /> '''.format(encoded)
Example 12
Project: nn_framework   Author: brohrer   File: framework.py    MIT License 6 votes vote down vote up
def report(self):
        """Plot the binned log10 error history and save it to the reports path.

        ``self.error_history`` is averaged over consecutive bins of
        ``self.reporting_bin_size`` entries; a trailing partial bin is ignored.
        The y range always covers at least ``[self.report_min, self.report_max]``.
        Nothing is written until at least one full bin exists.
        """
        bin_size = self.reporting_bin_size
        n_bins = int(len(self.error_history) // bin_size)
        if n_bins <= 0:
            # No complete bin yet; np.min below would raise ValueError on an
            # empty array.
            return
        smoothed_history = [
            np.mean(self.error_history[i_bin * bin_size:(i_bin + 1) * bin_size])
            for i_bin in range(n_bins)
        ]
        # The 1e-10 offset keeps log10 finite when a bin's mean error is 0.
        error_history = np.log10(np.array(smoothed_history) + 1e-10)
        ymin = np.minimum(self.report_min, np.min(error_history))
        ymax = np.maximum(self.report_max, np.max(error_history))
        fig, ax = plt.subplots()
        ax.plot(error_history)
        ax.set_xlabel(f"x{bin_size} iterations")
        ax.set_ylabel("log error")
        ax.set_ylim(ymin, ymax)
        ax.grid()
        fig.savefig(os.path.join(self.reports_path, self.report_name))
        # Close this figure specifically, not whichever one pyplot considers current.
        plt.close(fig)
Example 13
Project: nn_framework   Author: brohrer   File: autoencoder_viz.py    MIT License 6 votes vote down vote up
def render(self, nn, inputs, name=""):
        """
        Build a visualization of an image autoencoder neural network,
        piece by piece, save it with ``self.save_nn_viz`` and release the
        figure.

        Args:
            nn: The network to visualize.
            inputs: Inputs passed to each image-drawing helper.
            name: Passed through to ``self.save_nn_viz``.
        """
        fig, ax_boss = self.create_background()
        try:
            # Layout geometry.
            self.find_nn_size(nn)
            self.find_node_image_size()
            self.find_between_layer_gap()
            self.find_between_node_gap()
            self.find_error_image_position()

            # Images: input, one set per layer, output, error; then connections.
            image_axes = []
            self.add_input_image(fig, image_axes, nn, inputs)
            for i_layer in range(self.n_layers):
                self.add_node_images(fig, i_layer, image_axes, nn, inputs)
            self.add_output_image(fig, image_axes, nn, inputs)
            self.add_error_image(fig, image_axes, nn, inputs)
            self.add_layer_connections(ax_boss, image_axes)
            self.save_nn_viz(fig, name)
        finally:
            # Close the figure this call created (not pyplot's "current" one),
            # even if a drawing step raises.
            plt.close(fig)
Example 14
Project: Recipes   Author: Lasagne   File: massachusetts_road_segm.py    MIT License 6 votes vote down vote up
def plot_some_results(pred_fn, test_generator, n_images=10):
    """Save input / ground-truth / prediction panels for a few test samples.

    Writes ``road_segmentation_result_<NNN>.png`` for at most ``n_images``
    samples and then stops consuming ``test_generator``.

    Args:
        pred_fn: Callable mapping a batch of inputs to a batch of predictions.
        test_generator: Iterable of ``(data, seg)`` batches.
        n_images: Maximum number of images to write.
    """
    fig_ctr = 0
    if n_images <= 0:
        return
    for data, seg in test_generator:
        res = pred_fn(data)
        for d, s, r in zip(data, seg, res):
            fig = plt.figure(figsize=(12, 6))
            plt.subplot(1, 3, 1)
            # transpose(1, 2, 0): presumably channels-first -> channels-last for imshow.
            plt.imshow(d.transpose(1, 2, 0))
            plt.title("input patch")
            plt.subplot(1, 3, 2)
            plt.imshow(s[0])
            plt.title("ground truth")
            plt.subplot(1, 3, 3)
            plt.imshow(r)
            plt.title("segmentation")
            plt.savefig("road_segmentation_result_%03.0f.png" % fig_ctr)
            plt.close(fig)
            fig_ctr += 1
            # Return rather than break: the old inner-only break let the outer
            # loop pull more batches (never ending on an infinite generator),
            # and '>' wrote n_images + 1 files.
            if fig_ctr >= n_images:
                return
Example 15
Project: Parallel.GAMIT   Author: demiangomez   File: pyETM.py    GNU General Public License v3.0 6 votes vote down vote up
def onpick(self, event):
        """Pick-event callback that prompts for a jump at the clicked epoch.

        Disconnects the handler registered as ``self.cid`` from
        ``self.f.canvas`` and clears ``self.picking``. Prints the epoch for
        ``event.xdata`` (passed to ``pyDate.Date`` as ``fyear``). Then prompts
        on stdin for the jump type, a relaxation (only when the type is 1) and
        the operation. Finally opens a DB connection and closes ``self.plt``.

        NOTE(review): this is Python 2 code (print statements, ``raw_input``).
        Under Python 2, ``input()`` evaluates the typed text as a Python
        expression, so the jump-type and relaxation prompts execute arbitrary
        user input; ``raw_input`` plus explicit parsing would be safer.
        NOTE(review): ``jtype``, ``relx``, ``operation`` and ``cnn`` are unused
        in the visible code, and nothing is written to the database despite the
        ">> Jump inserted" message; the snippet appears truncated.
        """

        import dbConnection

        # Stop listening for further picks while this one is handled.
        self.f.canvas.mpl_disconnect(self.cid)
        self.picking = False
        print 'Epoch: %s' % pyDate.Date(fyear=event.xdata).yyyyddd()
        jtype = int(input(' -- Enter type of jump (0 = mechanic; 1 = geophysical): '))
        if jtype == 1:
            # Only defined for geophysical jumps.
            relx = input(' -- Enter relaxation (e.g. 0.5, 0.5,0.01): ')
        operation = str(raw_input(' -- Enter operation (+, -): '))
        print ' >> Jump inserted'

        # now insert the jump into the db
        cnn = dbConnection.Cnn('gnss_data.cfg')

        self.plt.close()

        # reinitialize ETM

        # wait for 'keep' or 'undo' command
Example 16
Project: PheKnowLator   Author: callahantiff   File: KGEmbeddingVisualizer.py    Apache License 2.0 6 votes vote down vote up
def plots_embeddings(colors, names, groups, legend_arg, label_size, tsne_size, title, title_size):
    """Scatter-plot grouped 2-D embeddings and show the figure.

    Args:
        colors: Mapping from group key to marker color.
        names: Mapping from group key to legend label.
        groups: Iterable of ``(key, frame)`` pairs; each frame has ``x`` and
            ``y`` attributes.
        legend_arg: Sequence ``(handles, fontsize, loc, ncol)`` for the legend.
        label_size: Tick-label font size.
        tsne_size: Half-width of the view. The y range is
            ``[-(tsne_size + 5), tsne_size]``.
        title: Plot title.
        title_size: Title font size.
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.margins(0.05)

    # One marker-only line per group so each gets its own color and label.
    for key, frame in groups:
        ax.plot(frame.x, frame.y, marker='o', linestyle='', ms=6, label=names[key],
                color=colors[key], mec='none', alpha=0.8)

    ax.legend(handles=legend_arg[0], fontsize=legend_arg[1], frameon=False,
              loc=legend_arg[2], ncol=legend_arg[3])

    ax.tick_params(labelsize=label_size)
    ax.set_ylim(-(tsne_size + 5), tsne_size)
    ax.set_xlim(-tsne_size, tsne_size)
    ax.set_title(title, fontsize=title_size)
    plt.show()
    plt.close()
Example 17
Project: Deep_Neural_Networks   Author: sarthak268   File: GAN_.py    BSD 2-Clause "Simplified" License 6 votes vote down vote up
def show_train_hist(hist, show=False, save=False, path='Train_hist.png'):
    """Plot discriminator and generator losses per epoch on the current figure.

    Args:
        hist: Mapping with ``'D_losses'`` and ``'G_losses'`` sequences; both
            are plotted against ``range(len(hist['D_losses']))``.
        show: If True, show the figure; otherwise close it.
        save: If True, save the figure to ``path`` first.
        path: Output file used when ``save`` is True.
    """
    d_losses = hist['D_losses']
    epochs = range(len(d_losses))
    g_losses = hist['G_losses']

    for losses, label in ((d_losses, 'D_loss'), (g_losses, 'G_loss')):
        plt.plot(epochs, losses, label=label)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    # When showing, the figure is left open for display; otherwise it is closed.
    finish = plt.show if show else plt.close
    finish()

# training parameters 
Example 18
Project: Deep_Neural_Networks   Author: sarthak268   File: GAN_cuda.py    BSD 2-Clause "Simplified" License 6 votes vote down vote up
def show_train_hist(hist, show=False, save=False, path='Train_hist.png'):
    """Plot discriminator and generator losses per epoch on the current figure.

    Args:
        hist: Mapping with ``'D_losses'`` and ``'G_losses'`` sequences; both
            are plotted against ``range(len(hist['D_losses']))``.
        show: If True, show the figure; otherwise close it.
        save: If True, save the figure to ``path`` first.
        path: Output file used when ``save`` is True.
    """
    d_losses = hist['D_losses']
    epochs = range(len(d_losses))
    g_losses = hist['G_losses']

    for losses, label in ((d_losses, 'D_loss'), (g_losses, 'G_loss')):
        plt.plot(epochs, losses, label=label)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    # When showing, the figure is left open for display; otherwise it is closed.
    finish = plt.show if show else plt.close
    finish()

# training parameters 
Example 19
Project: rtreelib   Author: sergkr   File: diagram.py    MIT License 6 votes vote down vote up
def plot_rtree(tree: RTreeBase, filename=None, show=True, highlight_node=None, highlight_entry=None):
    """
    Draw the R-Tree's node and leaf-entry bounding rectangles on a cartesian
    plot (matplotlib). Node rectangles are tan with dashed edges; leaf entries
    are blue. A node and/or a leaf entry can be highlighted.

    :param tree: R-Tree instance to plot
    :param filename: If given, the plot is saved to this file
    :param show: If True, show the plot
    :param highlight_node: R-Tree node to highlight
    :param highlight_entry: R-Tree leaf entry to highlight
    """
    fig, ax = plt.subplots(1)
    root_box = tree.root.get_bounding_rect()
    # Pad the view by 10% of the root bounding box on each axis.
    margin_x = 0.1 * root_box.width
    margin_y = 0.1 * root_box.height
    ax.set_xlim(left=root_box.min_x - margin_x, right=root_box.max_x + margin_x)
    ax.set_ylim(bottom=root_box.min_y - margin_y, top=root_box.max_y + margin_y)

    # Leaves are drawn before nodes.
    _plot_rtree_leaves(ax, tree, highlight_entry)
    _plot_rtree_nodes(ax, tree, highlight_node)

    if filename:
        plt.savefig(filename, bbox_inches='tight')
    if show:
        plt.show()
    plt.close(fig)
Example 20
Project: xia2   Author: xia2   File: plot_multiplicity.py    BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, scene, settings=None):
        """Create an off-screen Agg figure and render ``scene`` onto it.

        Args:
            scene: Passed to ``render_2d.__init__``.
            settings: Passed to ``render_2d.__init__``. The figure size is read
                from ``self.settings.size_inches`` (presumably set by that
                base-class initializer — confirm).
        """
        import matplotlib

        # Select the non-interactive Agg backend before pyplot is imported.
        # NOTE: matplotlib.use changes the backend for the whole process.
        matplotlib.use("Agg")
        from matplotlib import pyplot

        render_2d.__init__(self, scene, settings)

        # Accumulators for open/filled circles (positions, radii, colors);
        # presumably filled by drawing callbacks invoked from self.render — confirm.
        self._open_circle_points = flex.vec2_double()
        self._open_circle_radii = []
        self._open_circle_colors = []
        self._filled_circle_points = flex.vec2_double()
        self._filled_circle_radii = []
        self._filled_circle_colors = []

        self.fig, self.ax = pyplot.subplots(figsize=self.settings.size_inches)
        self.render(self.ax)
        # Deregister the current figure (presumably self.fig) from pyplot;
        # the Figure object itself stays referenced by self.fig.
        pyplot.close()
Example 21
Project: SyNEThesia   Author: RunOrVeith   File: live_viewer.py    MIT License 5 votes vote down vote up
def handle_close(event):
        """Close every open matplotlib figure; ``event`` is not used."""
        plt.close(fig="all")
Example 22
Project: SyNEThesia   Author: RunOrVeith   File: live_viewer.py    MIT License 5 votes vote down vote up
def __exit__(self, *args, **kwargs):
        """Stop and close the audio stream (if any), then terminate the controller.

        Returns None, so any exception raised inside the ``with`` block propagates.
        """
        stream = self.stream
        if stream is not None:
            stream.stop_stream()
            stream.close()
        self.audio_controller.terminate()
Example 23
Project: PEAKachu   Author: tbischler   File: window.py    ISC License 5 votes vote down vote up
def _plot_initial_windows(self, unsig_base_means, unsig_fcs,
                              sig_base_means, sig_fcs):
        """Save an MA plot (PNG) and a hexbin plot (PDF) of the initial windows.

        Non-significant windows use the default color and significant ones are
        red. Both plots draw a horizontal line at the median log2 fold-change
        and a vertical line at the median log10 base mean over all windows.
        Files go to ``<self._output_folder>/plots``, which is created if missing.
        """
        # create plot folder if it does not exist
        plot_folder = "{}/plots".format(self._output_folder)
        if not exists(plot_folder):
            makedirs(plot_folder)
        # Combine non-significant and significant values once.
        # Series.append was removed in pandas 2.0; np.concatenate works for
        # Series, Index and arrays and avoids index alignment below.
        log_all_base_means = np.log10(np.concatenate(
            [np.asarray(unsig_base_means), np.asarray(sig_base_means)]))
        log_all_fcs = np.log2(np.concatenate(
            [np.asarray(unsig_fcs), np.asarray(sig_fcs)]))
        median_log_fc = np.median(log_all_fcs)
        median_log_base_mean = np.median(log_all_base_means)
        # MA plot
        plt.plot(np.log10(unsig_base_means),
                 np.log2(unsig_fcs), ".",
                 markersize=2.0, alpha=0.3)
        plt.plot(np.log10(sig_base_means),
                 np.log2(sig_fcs), ".",
                 markersize=2.0, color="red", alpha=0.3)
        plt.axhline(y=median_log_fc)
        plt.axvline(x=median_log_base_mean)
        plt.title("Initial_windows_MA_plot")
        plt.xlabel("log10 base mean")
        plt.ylabel("log2 fold-change")
        plt.savefig("{}/Initial_windows_MA_plot.png".format(plot_folder),
                    dpi=600)
        plt.close()
        # HexBin plot
        df = pd.DataFrame({'log10 base mean': log_all_base_means,
                           'log2 fold-change': log_all_fcs})
        df.plot(kind='hexbin', x='log10 base mean',
                y='log2 fold-change', gridsize=50, bins='log')
        plt.axhline(y=median_log_fc)
        plt.axvline(x=median_log_base_mean)
        plt.title("Initial_windows_HexBin_plot")
        plt.savefig("{}/Initial_windows_HexBin_plot.pdf".format(plot_folder))
        plt.close()
Example 24
Project: PEAKachu   Author: tbischler   File: adaptive.py    ISC License 5 votes vote down vote up
def _plot_initial_peaks(self, unsig_base_means, unsig_fcs,
                            sig_base_means, sig_fcs):
        """Save an MA plot (PNG) and a hexbin plot (PDF) of the initial peaks.

        Non-significant peaks use the default color and significant ones are
        red. Both plots draw a horizontal line at the median log2 fold-change
        and a vertical line at the median log10 base mean over all peaks.
        Files go to ``<self._output_folder>/plots``, which is created if missing.
        """
        # create plot folder if it does not exist
        plot_folder = "{}/plots".format(self._output_folder)
        if not exists(plot_folder):
            makedirs(plot_folder)
        # Combine non-significant and significant values once.
        # Series.append was removed in pandas 2.0; np.concatenate works for
        # Series, Index and arrays and avoids index alignment below.
        log_all_base_means = np.log10(np.concatenate(
            [np.asarray(unsig_base_means), np.asarray(sig_base_means)]))
        log_all_fcs = np.log2(np.concatenate(
            [np.asarray(unsig_fcs), np.asarray(sig_fcs)]))
        median_log_fc = np.median(log_all_fcs)
        median_log_base_mean = np.median(log_all_base_means)
        # MA plot
        plt.plot(np.log10(unsig_base_means),
                 np.log2(unsig_fcs), ".",
                 markersize=2.0, alpha=0.3)
        plt.plot(np.log10(sig_base_means),
                 np.log2(sig_fcs), ".",
                 markersize=2.0, color="red", alpha=0.3)
        plt.axhline(y=median_log_fc)
        plt.axvline(x=median_log_base_mean)
        plt.title("Initial_peaks_MA_plot")
        plt.xlabel("log10 base mean")
        plt.ylabel("log2 fold-change")
        plt.savefig("{}/Initial_peaks_MA_plot.png".format(plot_folder),
                    dpi=600)
        plt.close()
        # HexBin plot
        df = pd.DataFrame({'log10 base mean': log_all_base_means,
                           'log2 fold-change': log_all_fcs})
        df.plot(kind='hexbin', x='log10 base mean',
                y='log2 fold-change', gridsize=50, bins='log')
        plt.axhline(y=median_log_fc)
        plt.axvline(x=median_log_base_mean)
        plt.title("Initial_peaks_HexBin_plot")
        plt.savefig("{}/Initial_peaks_HexBin_plot.pdf".format(plot_folder))
        plt.close()
Example 25
Project: mmdetection   Author: open-mmlab   File: coco_error_analysis.py    Apache License 2.0 5 votes vote down vote up
def makeplot(rs, ps, outDir, class_name, iou_type):
    """Save one stacked precision-recall breakdown plot per area range.

    For area range ``i`` the curves ``ps[..., i, 0]`` are reduced to one
    curve each (mean over axis 1 when 2-D). The band between consecutive
    curves is filled and labelled ``[<mean precision>]<type>``.

    Args:
        rs: Recall values (x axis).
        ps: Precision array; ``ps[..., i, 0]`` holds the curves for area ``i``.
        outDir: Output directory; files are ``<iou_type>-<class_name>-<area>.png``.
        class_name: Used in the title and file name.
        iou_type: Used in the title and file name.
    """
    cs = np.vstack([
        np.ones((2, 3)),
        np.array([.31, .51, .74]),
        np.array([.75, .31, .30]),
        np.array([.36, .90, .38]),
        np.array([.50, .39, .64]),
        np.array([1, .6, 0])
    ])
    area_names = ['allarea', 'small', 'medium', 'large']
    error_types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
    for area_idx, area_name in enumerate(area_names):
        area_ps = ps[..., area_idx, 0]
        figure_title = iou_type + '-' + class_name + '-' + area_name
        aps = [curve.mean() for curve in area_ps]
        reduced = [curve.mean(axis=1) if curve.ndim > 1 else curve for curve in area_ps]
        # A zero baseline in front makes band k span curve k-1 .. curve k.
        ps_curve = [np.zeros(reduced[0].shape)] + reduced

        fig = plt.figure()
        ax = fig.add_subplot(111)
        for k, type_name in enumerate(error_types):
            lower, upper = ps_curve[k], ps_curve[k + 1]
            ax.plot(rs, upper, color=[0, 0, 0], linewidth=0.5)
            ax.fill_between(rs, lower, upper, color=cs[k],
                            label='[{:.3f}]'.format(aps[k]) + type_name)
        ax.set_xlabel('recall')
        ax.set_ylabel('precision')
        ax.set_xlim(0, 1.)
        ax.set_ylim(0, 1.)
        ax.set_title(figure_title)
        ax.legend()
        fig.savefig(outDir + '/{}.png'.format(figure_title))
        plt.close(fig)
Example 26
Project: Kaggle-Statoil-Challenge   Author: adodd202   File: utils.py    MIT License 5 votes vote down vote up
def close(self):
        """Close the underlying log file, if one was opened."""
        handle = self.file
        if handle is None:
            return
        handle.close()
Example 27
Project: neural-fingerprinting   Author: StephanZheng   File: fp_eval.py    BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def get_pr_auc(pr_results, args, plot=False, plot_name=""):
    """Compute the area under the precision-recall curve.

    The curve points are ``(recall, prec)`` for every threshold ``tau`` in
    ``pr_results``, plus an anchor at (0.0, 1.0), sorted by recall.

    Args:
        pr_results: Mapping ``tau -> {"recall": ..., "prec": ...}``.
        args: Object with a ``log_dir`` attribute (used only when plotting).
        plot: If True, also save an annotated SVG of the curve.
        plot_name: Suffix of the SVG file name ``pr-<plot_name>.svg``.

    Returns:
        The value returned by ``auc`` for the sorted points.
    """
    # Unique marker for the anchor so it is never annotated, even if a real
    # threshold equals 0.
    anchor = object()
    xys = [(0.0, 1.0, anchor)]
    for tau, result in pr_results.items():
        xys.append((result["recall"], result["prec"], tau))

    # Stable sort on x only, so the anchor stays ahead of ties at x == 0.
    xys.sort(key=lambda point: point[0])
    xs = [point[0] for point in xys]
    ys = [point[1] for point in xys]
    _auc = auc(xs, ys)

    if plot:
        fig, ax = plt.subplots(nrows=1, ncols=1)
        ax.plot(xs, ys, 'go-')

        # Label from the sorted triples. Zipping the unsorted tau list (which
        # lacks the anchor) with the sorted coordinates mislabeled the points.
        for x, y, tau in xys:
            if tau is anchor:
                continue
            ax.annotate(
                tau,
                xy=(x, y), xytext=(-20, 20),
                textcoords='offset points', ha='right', va='bottom',
                bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

        path = os.path.join(args.log_dir, "pr-{}.svg".format(plot_name))
        print("Storing PR plot in", path)
        fig.savefig(path)
        plt.close(fig)

    return _auc
Example 28
Project: neural-fingerprinting   Author: StephanZheng   File: fp_eval.py    BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def get_roc_auc(pr_results, args, plot=False, plot_name=""):
    """Compute the area under the ROC curve.

    The curve points are ``(fpr, tpr)`` for every threshold ``tau`` in
    ``pr_results``, plus an anchor at (0.0, 0.0), sorted by false-positive rate.

    Args:
        pr_results: Mapping ``tau -> {"fpr": ..., "tpr": ...}``.
        args: Object with a ``log_dir`` attribute (used only when plotting).
        plot: If True, also save an annotated SVG of the curve.
        plot_name: Suffix of the SVG file name ``roc-<plot_name>.svg``.

    Returns:
        The value returned by ``auc`` for the sorted points.
    """
    # Unique marker for the anchor so it is never annotated, even if a real
    # threshold equals 0.
    anchor = object()
    xys = [(0.0, 0.0, anchor)]
    for tau, result in pr_results.items():
        xys.append((result["fpr"], result["tpr"], tau))

    # Stable sort on x only, so the anchor stays ahead of ties at x == 0.
    xys.sort(key=lambda point: point[0])
    xs = [point[0] for point in xys]
    ys = [point[1] for point in xys]
    _auc = auc(xs, ys)

    if plot:
        fig, ax = plt.subplots(nrows=1, ncols=1)
        ax.plot(xs, ys, 'go-')

        # Label from the sorted triples. Zipping the unsorted tau list (which
        # lacks the anchor) with the sorted coordinates mislabeled the points.
        for x, y, tau in xys:
            if tau is anchor:
                continue
            ax.annotate(
                tau,
                xy=(x, y), xytext=(-20, 20),
                textcoords='offset points', ha='right', va='bottom',
                bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

        path = os.path.join(args.log_dir, "roc-{}.svg".format(plot_name))
        print("Storing ROC plot in", path)
        fig.savefig(path)
        plt.close(fig)

    return _auc
Example 29
Project: programsynthesishunting   Author: flexgp   File: save_plots.py    GNU General Public License v3.0 5 votes vote down vote up
def save_plot_from_data(data, name):
    """
    Plot ``data`` against generation and save it as ``<name>.pdf``.

    The file is written into ``params['FILE_PATH']``.

    :param data: the data to be plotted
    :param name: y-axis label, plot title and output file stem.
    :return: Nothing.
    """

    from algorithm.parameters import params

    fig = plt.figure()
    axis = fig.add_subplot(1, 1, 1)
    axis.plot(data)
    axis.set_ylabel(name, fontsize=14)
    axis.set_xlabel('Generation', fontsize=14)
    axis.set_title(name)

    out_file = path.join(params['FILE_PATH'], (name + '.pdf'))
    plt.savefig(out_file)
    plt.close()
Example 30
Project: programsynthesishunting   Author: flexgp   File: save_plots.py    GNU General Public License v3.0 5 votes vote down vote up
def save_plot_from_file(filename, stat_name):
    """
    Saves a plot of a given stat from the stats file.

    The plot is written as ``<stat_name>.pdf`` in the directory containing
    ``filename``.

    :param filename: a full specified path to a tab-separated stats file.
    :param stat_name: the stat (column) of interest for plotting.
    :raises Exception: if ``stat_name`` is not a column of the file; the
        original KeyError is chained as ``__cause__``.
    :return: Nothing.
    """

    # Read in the data
    data = pd.read_csv(filename, sep="\t")
    try:
        stat = list(data[stat_name])
    except KeyError as err:
        s = "utilities.stats.save_plots.save_plot_from_file\n" \
            "Error: stat %s does not exist" % stat_name
        raise Exception(s) from err

    # Set up the figure.
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)

    # Plot the data.
    ax1.plot(stat)

    # Plot title.
    plt.title(stat_name)

    # Directory containing the stats file. If ``pathsep`` is os.pathsep, it is
    # the PATH-list separator (':' / ';'), not the directory separator, so
    # splitting on it lost the directory; dirname is correct either way.
    save_path = path.dirname(filename)

    # Save plot and close.
    plt.savefig(path.join(save_path, (stat_name + '.pdf')))
    plt.close(fig)
Example 31
Project: programsynthesishunting   Author: flexgp   File: save_plots.py    GNU General Public License v3.0 5 votes vote down vote up
def save_box_plot(data, names, title):
    """
    Given an array of some data, and a list of names of that data, generate
    and save a box plot of that data.

    One notched box is drawn per row of ``data`` (rows are transposed into
    the columns ``boxplot`` expects). The file is ``<title>.pdf`` in
    ``params['FILE_PATH']``.

    :param data: An array of some data to be plotted.
    :param names: A list of names of that data, one per row.
    :param title: The title of the plot.
    :return: Nothing
    """

    from algorithm.parameters import params

    import matplotlib.pyplot as plt
    # NOTE: plt.rc updates global rcParams, so this font also applies to
    # figures created later in the process.
    plt.rc('font', family='Times New Roman')

    # Set up the figure.
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)

    # Plot the data (boxplot draws one box per column).
    ax1.boxplot(np.transpose(data), notch=True)

    # Plot title.
    plt.title(title)

    # Boxes sit at x = 1..len(data); label each with its name.
    nums = list(range(1, len(data) + 1))
    plt.xticks(nums, names, rotation='vertical', fontsize=8)

    # Lay out once everything is drawn. Calling tight_layout before plotting
    # (as before) had nothing to fit, so rotated labels could be clipped.
    fig.tight_layout()

    # Save plot.
    plt.savefig(path.join(params['FILE_PATH'], (title + '.pdf')))

    # Close plot.
    plt.close(fig)
Example 32
Project: black-widow   Author: BLQ-Software   File: run_interactive.py    MIT License 5 votes vote down vote up
def do_close(self, line):
        """Close the current matplotlib figure (the graph).

        Parameters
        ----------
        line : string
            Arguments typed after the command; not used.
        """
        del line  # unused
        plt.close(None)
Example 33
Project: black-widow   Author: BLQ-Software   File: run_interactive.py    MIT License 5 votes vote down vote up
def help_close(self):
        """Print the help message for the ``close`` command."""
        # print() with one argument works on Python 2 and 3; the print
        # statement was a SyntaxError on Python 3.
        print("Close the graph")
Example 34
Project: DOTA_models   Author: ringringyi   File: cmp_summary.py    Apache License 2.0 5 votes vote down vote up
def _vis_readout_maps(outputs, global_step, output_dir, metric_summary, N):
  """Save a grid of ground-truth vs. predicted readout maps as a PNG.

  For each output and each channel ``k`` (last axis), shows
  ``gt_map[0, mid, :, :, k]`` next to ``pred_map[0, mid, :, :, k]``, where
  ``mid`` is the middle index of axis 1. Writes
  ``readout_map_<global_step>.png`` into ``output_dir``.

  Args:
    outputs: list of ``(gt_map, pred_map)`` pairs of 5-D arrays.
    global_step: step number used in the file name.
    output_dir: output directory.
    metric_summary: not used here.
    N: if >= 0, only the first ``N`` outputs are drawn; otherwise all.
  """
  if N >= 0:
    outputs = outputs[:N]
  N = len(outputs)

  plt.set_cmap('jet')
  n_channels = outputs[0][0].shape[4]
  fig, axes = utils.subplot(plt, (N, n_channels * 2), (5, 5))
  # Reverse so that pop() hands out axes in ravel() order.
  axes = axes.ravel()[::-1].tolist()
  for i in range(N):
    gt_map, pred_map = outputs[i]
    j = 0  # only the first batch element is shown
    for k in range(gt_map.shape[4]):
      # Display something like the midpoint of the trajectory.
      # np.int was removed in NumPy 1.24; the builtin int is equivalent.
      mid = int(gt_map.shape[1] / 2)
      for panel_title, volume in (('gt_map', gt_map), ('pred_map', pred_map)):
        ax = axes.pop()
        ax.imshow(volume[j, mid, :, :, k], origin='lower', interpolation='none',
                  vmin=0., vmax=1.)
        ax.set_axis_off()
        if i == 0:
          ax.set_title(panel_title)

  file_name = os.path.join(output_dir, 'readout_map_{:d}.png'.format(global_step))
  with fu.fopen(file_name, 'w') as f:
    fig.savefig(f, bbox_inches='tight', transparent=True, pad_inches=0)
  plt.close(fig)
Example 35
Project: DOTA_models   Author: ringringyi   File: plot_lfads.py    Apache License 2.0 5 votes vote down vote up
def plot_lfads(train_bxtxd, train_model_vals,
               train_ext_input_bxtxi=None, train_truth_bxtxd=None,
               valid_bxtxd=None, valid_model_vals=None,
               valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
               bidx=None, cf=1.0, output_dist='poisson'):
  """Plot train and validation LFADS time series and return the raster image.

  A new 18x20-inch figure is filled by two plot_lfads_timeseries calls:
  training data (col_title 'Train', default subplot column) and validation
  data (subplot_cidx=1, col_title 'Valid'). The figure is then drawn,
  copied into a NumPy array and closed.

  Args:
    train_bxtxd, valid_bxtxd: data to plot (the suffix suggests
      batch x time x dim -- confirm in plot_lfads_timeseries).
    train_model_vals, valid_model_vals: model values for each split.
    train_ext_input_bxtxi, valid_ext_input_bxtxi: optional external inputs.
    train_truth_bxtxd, valid_truth_bxtxd: optional ground truth.
    bidx: forwarded as `bidx` to plot_lfads_timeseries.
    cf: forwarded as `conversion_factor`.
    output_dist: forwarded as `output_dist` (default 'poisson').

  Returns:
    uint8 array of shape (height, width, 3) holding the figure's RGB pixels.
  """
  # Plotting
  f = plt.figure(figsize=(18,20), tight_layout=True)
  try:
    plot_lfads_timeseries(train_bxtxd, train_model_vals,
                          train_ext_input_bxtxi,
                          truth_bxtxn=train_truth_bxtxd,
                          conversion_factor=cf, bidx=bidx,
                          output_dist=output_dist, col_title='Train')
    plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
                          valid_ext_input_bxtxi,
                          truth_bxtxn=valid_truth_bxtxd,
                          conversion_factor=cf, bidx=bidx,
                          output_dist=output_dist,
                          subplot_cidx=1, col_title='Valid')

    # Rasterize the figure. get_width_height() is (width, height); it is
    # reversed to match the row-major (height, width, 3) pixel buffer.
    f.canvas.draw()
    # np.frombuffer replaces np.fromstring(..., sep=''), which NumPy
    # deprecated for binary input; .copy() keeps the result writable, as
    # fromstring's copy was.
    data = np.frombuffer(f.canvas.tostring_rgb(), dtype=np.uint8).copy()
    data_hxwx3 = data.reshape(f.canvas.get_width_height()[::-1] + (3,))
  finally:
    # Close this figure explicitly (a bare plt.close() closes whichever
    # figure is current), and do so even if plotting fails.
    plt.close(f)

  return data_hxwx3
Example 36
Project: ortholotree   Author: oxpeter   File: internal.py    GNU General Public License v2.0    (5 votes)
def make_phylip(fastaalignment, logfile):
    """Convert a fasta file alignment to phylip format.

    Args:
        fastaalignment: path of the input FASTA alignment.
        logfile: path whose last three characters are replaced by 'phylip'
            to build the output path (e.g. 'run.log' -> 'run.phylip').

    Returns:
        The path of the PHYLIP file that was written.
    """
    phylip_alignment = logfile[:-3] + 'phylip'

    # Bio.AlignIO expects a text-mode handle (a binary 'rb' handle is
    # rejected under Python 3). `with` guarantees both files are closed even
    # if parsing or writing raises.
    with open(fastaalignment, 'r') as input_handle, \
            open(phylip_alignment, 'w') as output_handle:
        alignment = AlignIO.read(input_handle, "fasta")
        AlignIO.write(alignment, output_handle, "phylip")

    return phylip_alignment
Example 37
Project: ortholotree   Author: oxpeter   File: internal.py    GNU General Public License v2.0    (5 votes)
def count_genes(genes=[], fastafile=None):
    """Evaluate the number of genes provided between a gene list and a fasta file.

    Args:
        genes: a list of gene names, or a single gene name. Empty strings are
            ignored. The default list is never mutated, so sharing it is safe.
        fastafile: optional path to a FASTA file. Its records are counted as
            the lines that start with '>'.

    Returns:
        A tuple (number of non-empty genes, number of FASTA records). The
        second value is 0 when no fastafile is given, and falls back to 2
        when the file cannot be read.
    """
    if not isinstance(genes, list):
        genes = [genes]
    genes = [g for g in genes if g != '']

    genenum = 0
    if fastafile:
        # Count FASTA header lines directly instead of shelling out to
        # `grep -c '^>' <file>`. This avoids shell injection and quoting
        # problems with the file name, and no external grep binary is needed.
        # Binary mode avoids decoding errors on unusual encodings.
        try:
            with open(fastafile, 'rb') as handle:
                genenum = sum(1 for line in handle if line.startswith(b'>'))
        except (IOError, OSError):
            # Falling back to a small non-zero count is meant to ensure the
            # hmmer model is still built when the genes in the fasta file
            # cannot be counted. The message is kept from the grep-based
            # version, which reached this path via int('') raising ValueError.
            genenum = 2
            print("ValueError calculating genenum")
    return len(genes), genenum
Example 38
Project: ortholotree   Author: oxpeter   File: internal.py    GNU General Public License v2.0    (5 votes)
def rank_scores(homologlist, thresh1=0, thresh2=None, genename=None, outfile=None, showplot=False):
    """Plot phmmer scores in descending rank order with cut-off guide lines.

    Args:
        homologlist: mapping whose values are indexable; val[1] is taken as
            the phmmer score.
        thresh1: fraction of the top score used as the cut-off (red
            horizontal line). A green vertical line marks how many scores
            reach it, and the point is annotated with its coordinates.
        thresh2: optional second fraction of the top score, drawn as another
            red horizontal line when truthy.
        genename: name interpolated into the plot title.
        outfile: if given, the current figure is saved there as PNG.
        showplot: if true, call plt.show(); otherwise close the current figure.

    Raises:
        ValueError: if homologlist is empty (max() of no scores).
    """
    yvalues = sorted([val[1] for val in homologlist.values()], reverse=True)
    plt.plot(yvalues)
    # Compute the top score once. Previously max(yvalues) was re-evaluated
    # for every element inside the generator below, which made it O(n^2).
    top_score = max(yvalues)
    score_cutoff = thresh1 * top_score
    sample_cutoff = sum(1 for s in yvalues if s >= score_cutoff)
    plt.axhline(score_cutoff, color='r')
    if thresh2:
        plt.axhline(thresh2 * top_score, color='r')
    plt.axvline(sample_cutoff - 1, color='g')
    plt.text(sample_cutoff + 1, score_cutoff + 10, "(%d,%d)" % (sample_cutoff, score_cutoff))
    plt.xlabel("Gene rank")
    plt.ylabel("Phmmer score")
    plt.title("Ranking of phmmer scores for alignment with %s" % genename)
    if outfile:
        plt.savefig(outfile, format='png')
    if showplot:
        plt.show()
    else:
        plt.close()
Example 39
Project: ortholotree   Author: oxpeter   File: internal.py    GNU General Public License v2.0    (5 votes)
def display_alignment(fastafile, conversiondic={}, outfile=None, showplot=True,
                        gapthresh=0.05, domain_prb=None, domain_stats=None):
    """Build an alignment figure, then optionally save it and show or close it.

    Args:
        fastafile: alignment file passed to build_alignment.
        conversiondic: passed through to build_alignment. NOTE(review): the
            shared `{}` default is only safe if build_alignment never mutates
            it -- confirm.
        outfile: if given, the figure is saved there as PNG.
        showplot: if true, call fig.show(); otherwise close the figure.
        gapthresh, domain_prb, domain_stats: passed through to build_alignment.
    """
    fig = build_alignment(fastafile, conversiondic, gapthresh=gapthresh,
                          domain_prb=domain_prb, domain_stats=domain_stats)
    if outfile:
        fig.savefig(outfile, format='png')
    if showplot:
        fig.show()
    else:
        # Close the figure that was built. A bare plt.close() would close
        # whatever pyplot considers current, which may be a different figure.
        plt.close(fig)
Example 40
Project: ortholotree   Author: oxpeter   File: internal.py    GNU General Public License v2.0    (5 votes)
def parse_the_hmmer(handle):
    """
    ####### DEPRECATED ######
    parses the protein matches from a hmmer search and returns a dictionary of peptides
    and their associated score and p-value.

    Args:
        handle: iterable of text lines with a close() method (e.g. an open
            file). It is always closed before returning.

    Returns:
        dict mapping the 9th whitespace-separated field of each hit line
        (presumably the peptide/sequence name -- confirm against the hmmer
        output format) to a (pvalue, score) tuple of floats. Parsing stops
        at the first line starting with '>' or whose second field is
        'hits', 'inclusion' or 'annotation'.
    """
    parse_dic = {}
    try:
        for line in handle:
            # Skip comment/header lines and (near-)empty lines.
            if line[0] in ['#', 'Q', 'D', 'S']:
                continue
            if len(line) < 2:
                continue
            if line[0] == '>':
                break
            fields = line.split()
            if len(fields) < 2:
                # Whitespace-only or single-token lines (e.g. '//') used to
                # crash with IndexError here.
                continue
            if fields[1] in ['hits', 'inclusion', 'annotation']:
                break
            try:
                score = float(fields[1])
                # float() instead of eval(): the p-value comes from file
                # content, and eval() would execute arbitrary expressions.
                pvalue = float(fields[0])
                name = fields[8]
            except (ValueError, IndexError):
                # Column headers, separator rows, or short lines.
                continue
            parse_dic[name] = (pvalue, score)
    finally:
        handle.close()
    return parse_dic

####### Phylogeny creation/manipulation ######## 
Example 41
Project: CAFA_assessment_tool   Author: ashleyzhou972   File: PrettyIO.py    GNU General Public License v3.0    (5 votes)
def print_enrichment_chart(file_handle, vals, title):
    """Save a line chart of `vals` topped by a one-row "increased" strip.

    Args:
        file_handle: destination passed to savefig (path or file object).
        vals: sequence of numbers, plotted against their index.
        title: title drawn above the strip.

    If matplotlib cannot be imported, an error is printed to stderr and
    nothing is written.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("Error while printing. To use this functionality you need to have matplotlib installed.", file=sys.stderr)
    else:
        fig, ax1 = plt.subplots()
        try:
            xs = list(range(len(vals)))
            ys = vals

            ax1.plot(xs, ys)

            # One cell per point: 1 where the value is greater than the
            # previous one (the first point is compared with 0), else 0.
            bar_row = [int(ys[0] > 0)]
            bar_row += [int(cur > prev) for prev, cur in zip(ys[:-1], ys[1:])]
            bar_ys = [bar_row]

            # Add a 0.1-high strip directly above the main axes.
            pos = ax1.axes.get_position()
            ax0 = fig.add_axes([pos.x0, pos.y1, pos.width, 0.1])

            ax0.imshow(bar_ys, cmap=plt.cm.Blues, interpolation='nearest')
            ax0.axes.get_yaxis().set_visible(False)
            ax0.axes.get_xaxis().set_visible(False)
            ax0.set_title(title)

            fig.savefig(file_handle, bbox_inches=0)
        finally:
            # Close this figure even when drawing or saving fails, rather
            # than only pyplot's current figure on success.
            plt.close(fig)
Example 42
Project: CAFA_assessment_tool   Author: ashleyzhou972   File: plot.py    GNU General Public License v3.0    (5 votes)
def plotMultiple(title,listofResults,smooth):
    '''
    supply lists of precision+recall+name lists

    Plot one precision-recall curve per result on shared axes and save the
    figure to ./plots/<title> at 200 dpi.

    Args:
        title: plot title, also used as the file name under ./plots/.
        listofResults: result objects providing .method, .opt, .coverage,
            .thres, .precision and .recall (and accepted by curveSmooth).
        smooth: 'Y' plots the curveSmooth() curve, 'N' plots the raw
            precision/recall; any other value plots no curves.
    '''
    fontP = FontProperties()
    fontP.set_size('small')  # NOTE(review): fontP is never used below; confirm intent
    num = len(listofResults)
    pal = sns.color_palette("Paired", num)
    colors = pal.as_hex()
    # Create the axes once, before the loop. Previously `ax` was only bound
    # inside the loop, so an empty listofResults or an unexpected `smooth`
    # value raised NameError at ax.get_position() below.
    ax = plt.subplot()
    for j, res in enumerate(listofResults):
        if smooth == 'Y':
            # Compute the smoothed curve once (it was computed twice).
            smoothed = curveSmooth(res)
            precision = smoothed[0][1:]
            recall = smoothed[1][1:]
        elif smooth == 'N':
            precision = res.precision
            recall = res.recall
        else:
            continue
        label = res.method + ':\nF=%s C=%s' % (res.opt, res.coverage)
        ax.plot(recall, precision, '-', color=colors[j], label=label)
        # Marker at index int(thres*100) of the raw (unsmoothed) curve.
        t = int(res.thres * 100)
        ax.plot(res.recall[t], res.precision[t], 'o', color=colors[j])
    plt.axis([0,1,0,1])
    # Shrink the axes to 80% width so the legend fits to their right.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    plt.yticks(numpy.arange(0,1,0.1))
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.title(title)
    figurename = os.path.join('./plots/',title)
    plt.savefig(figurename,dpi=200)
    plt.close()
Example 43
Project: whatsapp-stats   Author: nielsrolf   File: analyze.py    Apache License 2.0    (5 votes)
def plot(self, stat_name, keys, values):
        """Plot one statistic per category on a shared figure and save it.

        `keys` holds (category, time label) pairs aligned with `values`. Each
        category's series is drawn by self.time_group._plot onto one new
        figure, which is saved as PLOT_PATH + "/<name> <stat_name>.png".
        Afterwards every pyplot figure is closed.
        """
        plt.clf()
        figure = plt.figure()

        # dict() keeps categories in first-seen order (Python 3.7+).
        for category in list(dict(keys)):
            series = [val for pair, val in zip(keys, values) if pair[0] == category]
            labels = [pair[1] for pair in keys if pair[0] == category]
            self.time_group._plot(category, stat_name, labels, series, figure)

        plt.legend()
        plt.title(stat_name + " " + self.name)
        out_file = PLOT_PATH + "/" + self.name + " " + stat_name + ".png"
        plt.savefig(out_file)
        plt.close('all')
Example 44
Project: RNASEqTool   Author: armell   File: content_representations.py    MIT License    (5 votes)
def to_html(self):
        """Return self.content rendered to HTML by fig_to_html.

        The generated element gets the figure id "generatedchart". Once the
        HTML is produced, pyplot's current figure is closed.
        """
        import matplotlib.pyplot as plt
        snippet = fig_to_html(self.content, figid="generatedchart")
        # Close the current figure ("refresh") now that its HTML exists.
        plt.close()
        return snippet
Example 45
Project: RNASEqTool   Author: armell   File: content_representations.py    MIT License    (5 votes)
def to_html(self):
        """Return the script and div produced by components() as one string.

        components() and CDN come from elsewhere in the module (presumably
        bokeh.embed / bokeh.resources -- confirm).
        """
        pieces = components(self.content, CDN)
        script, div = pieces
        return "".join([str(script), str(div)])
Example 46
Project: cs294-112_hws   Author: xuwd11   File: pointmass.py    MIT License    (5 votes)
def reset(self):
        """Close the current figure and reset self.state.

        self.state becomes (goal_padding, goal_padding). The returned value is
        that state divided by self.scale.
        """
        plt.close()
        pad = self.goal_padding
        self.state = np.array([pad, pad])
        return self.state / self.scale
Example 47
Project: cs294-112_hws   Author: xuwd11   File: pointmass.py    MIT License    (5 votes)
def visualize(self, states, itr, dirname):
        """Save a heatmap of how often each grid cell was visited.

        Args:
            states: iterable of states, or None to load them from
                <dirname>/<itr>.npy.
            itr: iteration number, used in the .npy and .png file names.
            dirname: directory holding the .npy input and the .png output.
        """
        if states is None:
            states = np.load(os.path.join(dirname, '{}.npy'.format(itr)))
        # Map each state to its flat cell index. dtype=int keeps an empty
        # result usable as an index array.
        indices = np.array([int(self.preprocess(s)) for s in states], dtype=int)
        a = np.zeros(int(self.grid_size))
        # Unbuffered add: repeated indices are each counted, exactly like a
        # per-element `a[i] += 1` loop, but in a single call.
        np.add.at(a, indices, 1)
        max_freq = np.max(a)
        if max_freq > 0:
            # Normalize to [0, 1]. Skipped for an all-zero histogram (e.g. no
            # states), where dividing by zero would fill the map with NaN.
            a /= float(max_freq)
        a = np.reshape(a, (self.scale, self.scale))
        sns.heatmap(a)
        plt.savefig(os.path.join(dirname, '{}.png'.format(itr)))
        plt.close()
Example 48
Project: End-to-end-ASR-Pytorch   Author: Alexander-H-Liu   File: util.py    MIT License    (5 votes)
def _save_canvas(data, meta=None):
    """Render `data` with matplotlib and return the image as a float array.

    Args:
        data: with meta None, an image-like array shown via imshow. Otherwise
            a pair of bar-height sequences drawn as overlapping blue and red
            bars.
        meta: optional (x positions, tick labels) for the bar chart.

    Returns:
        numpy array of shape (height, width, 3) with RGB values in [0, 1];
        the alpha channel is dropped.
    """
    fig, ax = plt.subplots(figsize=(16, 8))
    try:
        if meta is None:
            ax.imshow(data, aspect="auto", origin="lower")
        else:
            ax.bar(meta[0], data[0], tick_label=meta[1], fc=(0, 0, 1, 0.5))
            ax.bar(meta[0], data[1], tick_label=meta[1], fc=(1, 0, 0, 0.5))
        fig.canvas.draw()
        # Note : torch tb add_image takes color as [0,1]
        # buffer_rgba() is the public accessor for the Agg RGBA pixel buffer
        # that was previously read through the private renderer._renderer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[:, :, :-1] / 255.0
    finally:
        # Release the figure even if drawing fails.
        plt.close(fig)

# Reference : https://stackoverflow.com/questions/579310/formatting-long-numbers-as-strings-in-python 
Example 49
Project: relay-bench   Author: uwsampl   File: plot_util.py    Apache License 2.0    (5 votes)
def save(self, dirname, filename):
        """Save pyplot's current figure and then close it.

        The output path comes from prepare_out_file(dirname, filename). The
        figure is written at 500 dpi with a tight bounding box.
        """
        destination = prepare_out_file(dirname, filename)
        plt.savefig(destination, bbox_inches='tight', dpi=500)
        plt.close()
Example 50
Project: lirpg   Author: Hwhitetooth   File: mujoco_dset.py    MIT License    (5 votes)
def plot(self, path="histogram_rets.png"):
        """Save a histogram of self.rets.

        Args:
            path: output image path; defaults to "histogram_rets.png" in the
                working directory, as before.
        """
        import matplotlib.pyplot as plt
        # Use a dedicated figure. Calling plt.hist alone draws onto whatever
        # figure is current, mixing in unrelated artists, and the bare
        # plt.close() then closes that other figure.
        fig = plt.figure()
        try:
            plt.hist(self.rets)
            fig.savefig(path)
        finally:
            plt.close(fig)