Python matplotlib.pyplot.figure() Examples

The following are 30 code examples of matplotlib.pyplot.figure(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: __init__.py    From EDeN with MIT License 11 votes vote down vote up
def plot_confusion_matrix(y_true, y_pred, size=None, normalize=False):
    """Plot a confusion matrix heatmap of ``y_true`` against ``y_pred``.

    Rows are true labels and columns are predicted labels. Both axes are
    ordered by the sorted union of labels seen in either input, so the tick
    labels always line up with the matrix rows/columns.

    :param y_true: ground-truth labels.
    :param y_pred: predicted labels.
    :param size: if given, open a new ``size`` x ``size`` inch figure;
        otherwise draw on the current figure.
    :param normalize: if True, divide each row by its total (cells become
        per-true-class fractions, shown with two decimals). Rows whose total
        is 0 are left at 0 instead of becoming NaN.
    """
    # confusion_matrix indexes rows AND columns by the union of both label
    # sets. Pass that union explicitly and reuse it for both axes, so the
    # ticks cannot drift out of sync when one side is missing a class.
    labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fmt = "%d"
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class that only appears in y_pred has an all-zero row.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        fmt = "%.2f"
    if size is not None:
        plt.figure(figsize=(size, size))
    heatmap(cm, xlabel='Predicted label', ylabel='True label',
            xticklabels=labels, yticklabels=labels,
            cmap=plt.cm.Blues, fmt=fmt)
    if normalize:
        plt.title("Confusion matrix (norm.)")
    else:
        plt.title("Confusion matrix")
    # NOTE(review): flips the y-axis of the current axes; presumably so row 0
    # is drawn at the top given how `heatmap` renders -- confirm.
    plt.gca().invert_yaxis()
Example #2
Source File: dataset.py    From neural-combinatorial-optimization-rl-tensorflow with MIT License 8 votes vote down vote up
def visualize_2D_trip(self, trip):
        """Draw the closed tour that visits ``trip``'s cities in row order.

        ``trip`` is read as ``trip[:, 0]`` / ``trip[:, 1]`` (x / y per city).
        The tour returns to city 0 at the end, and each city is annotated with
        its visit index. The global ``rcParams`` font size is updated as a
        side effect.
        """
        plt.figure(figsize=(30,30))
        rcParams.update({'font.size': 22})

        # Cities as large dots.
        plt.scatter(trip[:, 0], trip[:, 1], s=200)

        # Visit order 0, 1, ..., n-1 and then back to 0 to close the loop.
        n_cities = len(trip)
        order = np.array(list(range(n_cities)) + [0])
        xs = trip[order, 0]
        ys = trip[order, 1]
        plt.plot(xs, ys, "--", markersize=100)

        # zip stops at n_cities, so the repeated closing point gets no label.
        for idx, x, y in zip(range(n_cities), xs, ys):
            plt.annotate(idx, xy=(x, y))

        plt.xlim(0, 100)
        plt.ylim(0, 100)
        plt.show()


    # Heatmap of permutations (x=cities; y=steps) 
Example #3
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 8 votes vote down vote up
def plot_wh_methods():  # from utils.utils import *; plot_wh_methods()
    """Compare width-height anchor multiplication methods; save comparison.png.

    Plots ``exp(x)`` (labelled 'yolo method') against ``(2*sigmoid(x))**2``
    and ``(2*sigmoid(x))**2.5`` for x in [-4, 4) with step 0.1.
    See https://github.com/ultralytics/yolov3/issues/168
    """
    x = np.arange(-4.0, 4.0, .1)
    ya = np.exp(x)
    yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2

    fig = plt.figure(figsize=(6, 3), dpi=150)
    plt.plot(x, ya, '.-', label='yolo method')
    plt.plot(x, yb ** 2, '.-', label='^2 power method')
    plt.plot(x, yb ** 2.5, '.-', label='^2.5 power method')
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.legend()
    fig.tight_layout()
    fig.savefig('comparison.png', dpi=200)
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(fig)
Example #4
Source File: util.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 8 votes vote down vote up
def compute_roc(y_true, y_pred, plot=False):
    """Compute the ROC curve and its area, optionally plotting the curve.

    :param y_true: ground-truth labels, passed to ``roc_curve``.
    :param y_pred: predicted scores, passed to ``roc_curve``.
    :param plot: if True, draw the curve in a new 7x6 figure and show it.
    :return: tuple ``(fpr, tpr, auc_score)``.
    """
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    auc_score = auc(fpr, tpr)
    if not plot:
        return fpr, tpr, auc_score

    plt.figure(figsize=(7, 6))
    plt.plot(fpr, tpr, color='blue',
             label='ROC (AUC = %0.4f)' % auc_score)
    plt.legend(loc='lower right')
    plt.title("ROC Curve")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.show()
    return fpr, tpr, auc_score
Example #5
Source File: dataset.py    From neural-combinatorial-optimization-rl-tensorflow with MIT License 7 votes vote down vote up
def visualize_2D_trip(self,trip,tw_open,tw_close):
        """Draw a closed tour annotated with rounded time windows and show it.

        The first city (the depot) is drawn red and the rest blue. Each city
        is annotated with its rounded ``tw_open``/``tw_close`` values, which
        are concatenated along axis 1 (so both must be at least 2-D). The
        global ``rcParams`` font size is updated as a side effect.
        """
        plt.figure(figsize=(30,30))
        rcParams.update({'font.size': 22})
        # Depot first in red, one blue entry for every remaining city.
        colors = ['red'] + ['blue'] * (len(tw_open) - 1)
        plt.scatter(trip[:, 0], trip[:, 1], color=colors, s=200)
        # Visit cities in row order and return to the depot.
        order = np.array(list(range(len(trip))) + [0])
        xs = trip[order, 0]
        ys = trip[order, 1]
        plt.plot(xs, ys, "--", markersize=100)
        windows = np.concatenate((np.rint(tw_open), np.rint(tw_close)), axis=1)
        for window, x, y in zip(windows, xs, ys):
            plt.annotate(window, xy=(x, y))
        plt.xlim(0, 60)
        plt.ylim(0, 60)
        plt.show()


    # Heatmap of permutations (x=cities; y=steps) 
Example #6
Source File: __init__.py    From EDeN with MIT License 7 votes vote down vote up
def plot_roc_curve(y_true, y_score, size=None):
    """Plot the ROC curve of ``y_score`` against ``y_true``.

    Draws the curve in navy plus a dashed grey chance diagonal, and puts the
    ROC AUC in the title. When ``size`` is given, a new square figure with an
    equal axis aspect is opened first; otherwise the current axes are used.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    if size is not None:
        plt.figure(figsize=(size, size))
        plt.axis('equal')
    plt.plot(fpr, tpr, lw=2, color='navy')
    # Chance-level reference line.
    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.ylim([-0.05, 1.05])
    plt.xlim([-0.05, 1.05])
    plt.grid()
    area = roc_auc_score(y_true, y_score)
    plt.title('Receiver operating characteristic AUC={0:0.2f}'.format(area))
Example #7
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 7 votes vote down vote up
def plot_evolution_results(hyp):  # from utils.utils import *; plot_evolution_results(hyp)
    """Plot hyperparameter evolution results from evolve.txt into evolve.png.

    Draws one subplot per key of ``hyp`` (in iteration order), showing fitness
    against that hyperparameter's values. The value from the fittest row is
    marked with a large dot, shown in the title and printed.

    Column ``i + 5`` of evolve.txt is read as the i-th hyperparameter.
    NOTE(review): the first 5 columns are presumably the metrics consumed by
    ``fitness`` -- confirm.
    """
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    best = f.argmax()  # row index of the best single result
    ncols = 5
    # Keep the original 4x5 grid for up to 20 hyperparameters, but add rows
    # instead of raising when there are more.
    nrows = max(4, int(np.ceil(len(hyp) / ncols)))
    fig = plt.figure(figsize=(12, 10))
    matplotlib.rc('font', **{'size': 8})
    for i, k in enumerate(hyp):
        y = x[:, i + 5]
        mu = y[best]
        plt.subplot(nrows, ncols, i + 1)
        plt.plot(mu, f.max(), 'o', markersize=10)
        plt.plot(y, f, '.')
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
        print('%15s: %.3g' % (k, mu))
    fig.tight_layout()
    fig.savefig('evolve.png', dpi=200)
    # Release the figure; pyplot otherwise keeps it alive after saving.
    plt.close(fig)
Example #8
Source File: test_bayestar.py    From dustmaps with GNU General Public License v2.0 6 votes vote down vote up
def atest_plot_samples(self):
        """Plot interpolated reddening curves for the first test record.

        Evaluates ``self._interp_ebv(self._test_data[0], d)`` at 1001
        distance moduli between 4 and 19. Each row of the transposed result
        is drawn as a separate semi-transparent curve, and the figure is shown.
        NOTE(review): the "atest" prefix presumably keeps the test runner
        from collecting this method -- confirm.
        """
        dist_mod = np.linspace(4., 19., 1001)
        # d = 10**(dm/5 - 2); NOTE(review): kpc under the usual distance-modulus
        # convention -- confirm this is the unit _interp_ebv expects.
        curves = np.array([
            self._interp_ebv(self._test_data[0], 10.**(dm_k/5.-2.))
            for dm_k in dist_mod
        ]).T

        import matplotlib.pyplot as plt
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        for curve in curves:
            ax.plot(dist_mod, curve, lw=2., alpha=0.5)

        plt.show()
Example #9
Source File: plotting.py    From medicaldetectiontoolkit with Apache License 2.0 6 votes vote down vote up
def __init__(self, cf):
        """Store monitoring settings from config ``cf`` and create its figures.

        Creates ``cf.n_monitoring_figures`` 10x6-inch figures. Each gets a
        single axes stored as the ad-hoc attribute ``fig.ax1``, labelled
        'epochs' / 'loss / metrics', with x-limits ``[0, cf.num_epochs]`` and
        a grid. The first figure's y-range is fixed to [0, 1.5].
        """
        self.file_name = cf.plot_dir + '/monitor_{}'.format(cf.fold)
        self.exp_name = cf.fold_dir
        self.do_validation = cf.do_validation
        self.separate_values_dict = cf.assign_values_to_extra_figure
        self.figure_list = []
        for _ in range(cf.n_monitoring_figures):
            fig = plt.figure(figsize=(10, 6))
            self.figure_list.append(fig)
            # plt.subplot targets the figure that was just created (current).
            axes = plt.subplot(111)
            fig.ax1 = axes
            axes.set_xlabel('epochs')
            axes.set_ylabel('loss / metrics')
            axes.set_xlim(0, cf.num_epochs)
            axes.grid()

        self.figure_list[0].ax1.set_ylim(0, 1.5)
        self.color_palette = ['b', 'c', 'r', 'purple', 'm', 'y', 'k', 'tab:gray']
Example #10
Source File: zipf_law.py    From pyhanlp with Apache License 2.0 6 votes vote down vote up
def plot(token_counts, title='MSR语料库词频统计', ylabel='词频'):
    """Draw a word-frequency line chart with Chinese labels and show it.

    :param token_counts: iterable of ``(word, count)`` pairs. The words
        become x tick labels and the counts the y values.
    :param title: chart title (default means "MSR corpus word-frequency
        statistics").
    :param ylabel: y-axis label (default means "word frequency").
    """
    from matplotlib import pyplot as plt
    plt.rcParams['font.sans-serif'] = ['SimHei']  # CJK font so Chinese labels render
    plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with that font
    fig = plt.figure()
    ax = fig.add_subplot(111)
    columns = list(zip(*token_counts))
    words, counts = columns[0], columns[1]
    positions = np.arange(len(words))
    peak = max(counts)
    # Head-room above the tallest point: the peak plus its number of digits.
    top_offset = peak + len(str(peak))
    ax.set_title(title)
    ax.set_xlabel('词语')  # x-axis label meaning "words"
    ax.set_ylabel(ylabel)
    ax.xaxis.set_label_coords(1.05, 0.015)
    ax.set_xticks(positions)
    ax.set_xticklabels(words, rotation=55, verticalalignment='top')
    ax.set_ylim([0, top_offset])
    ax.set_xlim([-1, len(words)])
    ax.plot(positions, counts, linewidth=1.5)
    plt.show()
Example #11
Source File: plot_part1.py    From cs294-112_hws with MIT License 6 votes vote down vote up
def plot_13(data):
    """Compare learning curves for three discount factors; save results/p13.png.

    ``data`` must unpack into exactly four runs. The 3rd, 2nd and 4th runs
    are plotted, in that order, as 'gamma = 0.9', 'gamma = 0.99' and
    'gamma = 0.999'; the 1st run is unused. For each run, 'MeanReward100Episodes'
    is drawn without an explicit label and 'BestMeanReward' carries the
    legend label.
    """
    _, r2, r3, r4 = data
    plt.figure()
    runs = ((r3, 'gamma = 0.9'), (r2, 'gamma = 0.99'), (r4, 'gamma = 0.999'))
    for run, gamma_label in runs:
        add_plot(run, 'MeanReward100Episodes')
        add_plot(run, 'BestMeanReward', gamma_label)
    plt.legend()
    plt.xlabel('Time step')
    plt.ylabel('Reward')
    plt.savefig(
        os.path.join('results', 'p13.png'),
        bbox_inches='tight',
        transparent=True,
        pad_inches=0.1,
    )
Example #12
Source File: massachusetts_road_segm.py    From Recipes with MIT License 6 votes vote down vote up
def plot_some_results(pred_fn, test_generator, n_images=10):
    """Save side-by-side input / ground-truth / prediction images as PNGs.

    Runs ``pred_fn`` on each ``(data, seg)`` batch from ``test_generator``
    and writes one file per sample (``road_segmentation_result_000.png``
    onwards). It stops once ``n_images`` files are written or the generator
    is exhausted, whichever comes first.

    :param pred_fn: callable mapping a batch of inputs to per-sample outputs.
    :param test_generator: iterable yielding ``(data, seg)`` batches.
    :param n_images: maximum number of images to save.
    """
    if n_images <= 0:
        return
    fig_ctr = 0
    for data, seg in test_generator:
        res = pred_fn(data)
        for d, s, r in zip(data, seg, res):
            plt.figure(figsize=(12, 6))
            plt.subplot(1, 3, 1)
            # NOTE(review): inputs presumably channel-first (C, H, W); imshow
            # wants (H, W, C).
            plt.imshow(d.transpose(1,2,0))
            plt.title("input patch")
            plt.subplot(1, 3, 2)
            plt.imshow(s[0])
            plt.title("ground truth")
            plt.subplot(1, 3, 3)
            plt.imshow(r)
            plt.title("segmentation")
            plt.savefig("road_segmentation_result_%03.0f.png"%fig_ctr)
            plt.close()
            fig_ctr += 1
            # Return instead of break: a bare break only left this batch, so
            # the outer loop kept pulling (possibly endless) batches. Also use
            # >= so exactly n_images files are written, not n_images + 1.
            if fig_ctr >= n_images:
                return
Example #13
Source File: helper.py    From Stock-Price-Prediction with MIT License 6 votes vote down vote up
def plot_mul(Y_hat, Y, pred_len):
    """
    Plot the predicted data versus the true data and show the figure.

    Input: Predicted data (a sequence of prediction windows), True data,
           Length of each prediction window.
    Output: None (the figure is displayed with ``plt.show()``).

    The i-th prediction window is shifted right by ``i * pred_len`` points so
    it lines up with the part of ``Y`` it predicts.

    Note: Run from timeSeriesPredict.py
    """
    fig = plt.figure(facecolor='white')
    ax = fig.add_subplot(111)
    ax.plot(Y, label='Y')
    for i, window in enumerate(Y_hat):
        # None-padding leaves the first i*pred_len x positions empty.
        shift = [None] * (i * pred_len)
        # list() makes numpy windows concatenate (list + ndarray would attempt
        # element-wise addition). Only the first window is labelled, so the
        # legend gets a single 'Y_hat' entry.
        ax.plot(shift + list(window),
                label='Y_hat' if i == 0 else '_nolegend_')
    # Build the legend once, after every line has been added.
    ax.legend()
    plt.show()
Example #14
Source File: test.py    From MomentumContrast.pytorch with MIT License 6 votes vote down vote up
def show(mnist, targets, ret):
    """Scatter a 2-D embedding coloured by class, overlay thumbnails, and show it.

    :param mnist: indexable dataset; ``mnist[i][0]`` is an image tensor that
        is un-normalised with mean 0.1307 / std 0.3081 before display.
    :param targets: class label per point. Labels index a 10-entry colour list.
    :param ret: 2-D coordinates, one row per point (indexed ``ret[idx, 0/1]``).
    """
    classes = set(targets)
    palette = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'violet', 'orange', 'purple']

    plt.figure(figsize=(12, 10))
    ax = plt.subplot(aspect='equal')

    label_array = np.array(targets)
    for label in classes:
        members = np.where(label_array == label)[0]
        plt.scatter(ret[members, 0], ret[members, 1], c=palette[label], label=label)

    # Every 250th sample gets its image drawn at its embedded position.
    for i in range(0, len(targets), 250):
        thumb = (mnist[i][0] * 0.3081 + 0.1307).numpy()[0]
        box = OffsetImage(thumb, cmap=plt.cm.gray_r, zoom=0.5)
        ax.add_artist(AnnotationBbox(box, ret[i]))

    plt.legend()
    plt.show()
Example #15
Source File: recall.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def plot_iou_recall(recalls, iou_thrs):
    """Plot IoU-Recalls curve.

    The curve is closed with an extra point at (IoU=1.0, recall=0.0), and the
    x-range starts at the smallest threshold.

    Args:
        recalls(ndarray or list): shape (k,)
        iou_thrs(ndarray or list): same shape as `recalls`
    """
    # Normalise both inputs to plain lists so list concatenation works.
    if isinstance(iou_thrs, np.ndarray):
        _iou_thrs = iou_thrs.tolist()
    else:
        _iou_thrs = list(iou_thrs)
    if isinstance(recalls, np.ndarray):
        _recalls = recalls.tolist()
    else:
        _recalls = list(recalls)

    import matplotlib.pyplot as plt
    f = plt.figure()
    plt.plot(_iou_thrs + [1.0], _recalls + [0.])
    plt.xlabel('IoU')
    plt.ylabel('Recall')
    # Use the converted list: ``iou_thrs.min()`` raised AttributeError for
    # plain-list input, which the docstring explicitly allows.
    plt.axis([min(_iou_thrs), 1, 0, 1])
    f.show()
Example #16
Source File: inference.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def show_result_pyplot(model, img, result, score_thr=0.3, fig_size=(15, 10)):
    """Render detection results onto an image and display it with pyplot.

    Args:
        model (nn.Module): The loaded detector. If it has a ``module``
            attribute (a wrapper), that inner module is used instead.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        fig_size (tuple): Figure size of the pyplot figure.
    """
    detector = getattr(model, 'module', model)
    rendered = detector.show_result(img, result, score_thr=score_thr, show=False)
    plt.figure(figsize=fig_size)
    # Convert BGR to RGB for pyplot (the rendered image is presumably BGR,
    # as mmcv images are).
    plt.imshow(mmcv.bgr2rgb(rendered))
    plt.show()
Example #17
Source File: recall.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def plot_num_recall(recalls, proposal_nums):
    """Plot Proposal_num-Recalls curve.

    The curve starts from an extra point at (0, 0), and the x-range ends at
    the largest proposal number.

    Args:
        recalls(ndarray or list): shape (k,)
        proposal_nums(ndarray or list): same shape as `recalls`
    """
    # Normalise both inputs to plain lists so list concatenation works.
    if isinstance(proposal_nums, np.ndarray):
        _proposal_nums = proposal_nums.tolist()
    else:
        _proposal_nums = list(proposal_nums)
    if isinstance(recalls, np.ndarray):
        _recalls = recalls.tolist()
    else:
        _recalls = list(recalls)

    import matplotlib.pyplot as plt
    f = plt.figure()
    plt.plot([0] + _proposal_nums, [0] + _recalls)
    plt.xlabel('Proposal num')
    plt.ylabel('Recall')
    # Use the converted list: ``proposal_nums.max()`` raised AttributeError
    # for plain-list input, which the docstring explicitly allows.
    plt.axis([0, max(_proposal_nums), 0, 1])
    f.show()
Example #18
Source File: plot_part1.py    From cs294-112_hws with MIT License 6 votes vote down vote up
def plot_12(data):
    """Compare vanilla and double DQN learning curves; save results/p12.png.

    ``data`` must unpack into exactly four runs; only the first two are
    plotted, as 'vanilla DQN' and 'double DQN'. For each run,
    'MeanReward100Episodes' is drawn without an explicit label and
    'BestMeanReward' carries the legend label.
    """
    r1, r2, _, _ = data
    plt.figure()
    for run, variant in ((r1, 'vanilla DQN'), (r2, 'double DQN')):
        add_plot(run, 'MeanReward100Episodes')
        add_plot(run, 'BestMeanReward', variant)
    plt.xlabel('Time step')
    plt.ylabel('Reward')
    plt.legend()
    plt.savefig(
        os.path.join('results', 'p12.png'),
        bbox_inches='tight',
        transparent=True,
        pad_inches=0.1,
    )
Example #19
Source File: util.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def compute_roc_rfeinman(probs_neg, probs_pos, plot=False):
    """Compute the ROC curve and AUC for separating two score populations.

    Scores in ``probs_neg`` are labelled 0 and scores in ``probs_pos``
    labelled 1 before being passed to ``roc_curve``.

    :param probs_neg: scores of the negative samples.
    :param probs_pos: scores of the positive samples.
    :param plot: if True, draw the curve in a new 7x6 figure and show it.
    :return: tuple ``(fpr, tpr, auc_score)``.
    """
    scores = np.concatenate((probs_neg, probs_pos))
    truth = np.concatenate((np.zeros_like(probs_neg), np.ones_like(probs_pos)))
    fpr, tpr, _ = roc_curve(truth, scores)
    auc_score = auc(fpr, tpr)
    if not plot:
        return fpr, tpr, auc_score

    plt.figure(figsize=(7, 6))
    plt.plot(fpr, tpr, color='blue',
             label='ROC (AUC = %0.4f)' % auc_score)
    plt.legend(loc='lower right')
    plt.title("ROC Curve")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.show()
    return fpr, tpr, auc_score
Example #20
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 6 votes vote down vote up
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
    """Save a square grid of up to 16 training images with their boxes drawn.

    :param imgs: tensor unpacked as (batch, channels, height, width).
    :param targets: tensor with one row per box. Column 0 selects the image
        within the batch and columns 2:6 are passed to ``xywh2xyxy`` and then
        scaled by width/height. NOTE(review): so presumably normalised xywh.
    :param paths: optional per-image paths; file names (first 40 characters)
        become subplot titles.
    :param fname: output image path.
    """
    imgs = imgs.cpu().numpy()
    targets = targets.cpu().numpy()

    fig = plt.figure(figsize=(10, 10))
    bs, _, h, w = imgs.shape  # batch size, _, height, width
    bs = min(bs, 16)  # limit plot to 16 images
    # subplot() requires integer grid sizes; np.ceil returns a float, which
    # recent matplotlib versions reject.
    ns = int(np.ceil(bs ** 0.5))

    for i in range(bs):
        boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).T
        boxes[[0, 2]] *= w
        boxes[[1, 3]] *= h
        plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
        # Trace each box outline: (x1,y1) -> (x2,y1) -> (x2,y2) -> (x1,y2) -> back.
        plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
        plt.axis('off')
        if paths is not None:
            s = Path(paths[i]).name
            plt.title(s[:40], fontdict={'size': 8})  # limit to 40 characters
    fig.tight_layout()
    fig.savefig(fname, dpi=200)
    # Close this figure explicitly rather than whichever figure is current.
    plt.close(fig)
Example #21
Source File: simulate_sin.py    From deep-learning-note with MIT License 6 votes vote down vote up
def run_eval(sess, test_X, test_y):
    """Evaluate the LSTM on the test set, print the RMSE and plot the results.

    Feeds ``(test_X, test_y)`` one example per batch through ``lstm_model``,
    reusing the variables under the "model" scope, for ``TESTING_EXAMPLES``
    steps. It then prints the root mean square error and shows predictions
    against the true values.
    """
    ds = tf.data.Dataset.from_tensor_slices((test_X, test_y))
    ds = ds.batch(1)
    X, y = ds.make_one_shot_iterator().get_next()

    with tf.variable_scope("model", reuse=True):
        prediction, _, _ = lstm_model(X, [0.0], False)
        predictions = []
        labels = []
        for _ in range(TESTING_EXAMPLES):
            pred_value, label_value = sess.run([prediction, y])
            predictions.append(pred_value)
            labels.append(label_value)

    predictions = np.array(predictions).squeeze()
    labels = np.array(labels).squeeze()
    rmse = np.sqrt(((predictions - labels) ** 2).mean(axis=0))
    # sqrt(mean squared error) is the *root* mean square error.
    print("Root Mean Square Error is: %f" % rmse)

    plt.figure()
    plt.plot(predictions, label='predictions')
    plt.plot(labels, label='real_sin')
    plt.legend()
    plt.show()
Example #22
Source File: dataset.py    From neural-combinatorial-optimization-rl-tensorflow with MIT License 6 votes vote down vote up
def visualize_sampling(self, permutations):
        """Show a heatmap of how often each city is picked at each step.

        ``permutations`` holds one tour per row. Cell ``[t, i]`` counts how
        many tours chose city ``i`` at step ``t``. The global ``rcParams``
        font size is updated as a side effect.
        """
        max_length = len(permutations[0])
        grid = np.zeros([max_length, max_length])  # rows: step t, columns: city i

        for t, cities_t in enumerate(np.transpose(permutations)):
            chosen, counts = np.unique(cities_t, return_counts=True, axis=0)
            # ``chosen`` has no duplicates, so fancy-index += adds each count once.
            grid[t, chosen] += counts

        fig = plt.figure()
        rcParams.update({'font.size': 22})
        ax = fig.add_subplot(1, 1, 1)
        ax.set_aspect('equal')
        plt.imshow(grid, interpolation='nearest', cmap='gray')
        plt.colorbar()
        plt.title('Sampled permutations')
        plt.ylabel('Time t')
        plt.xlabel('City i')
        plt.show()
Example #23
Source File: plot_lfads.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def plot_lfads(train_bxtxd, train_model_vals,
               train_ext_input_bxtxi=None, train_truth_bxtxd=None,
               valid_bxtxd=None, valid_model_vals=None,
               valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
               bidx=None, cf=1.0, output_dist='poisson'):
  """Render train/valid LFADS time-series plots and return them as pixels.

  Draws a 'Train' column and a 'Valid' column (``subplot_cidx=1``) with
  ``plot_lfads_timeseries`` on one 18x20 figure, rasterizes it, closes the
  figure and returns the image.

  Returns:
    uint8 numpy array of shape (height, width, 3) with RGB pixels.
  """
  # Plotting
  f = plt.figure(figsize=(18,20), tight_layout=True)
  plot_lfads_timeseries(train_bxtxd, train_model_vals,
                        train_ext_input_bxtxi,
                        truth_bxtxn=train_truth_bxtxd,
                        conversion_factor=cf, bidx=bidx,
                        output_dist=output_dist, col_title='Train')
  plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
                        valid_ext_input_bxtxi,
                        truth_bxtxn=valid_truth_bxtxd,
                        conversion_factor=cf, bidx=bidx,
                        output_dist=output_dist,
                        subplot_cidx=1, col_title='Valid')

  # Rasterize the figure. get_width_height() is (width, height); reversing it
  # gives the row-major (height, width, 3) layout of the RGB byte string.
  # NOTE(review): canvas.tostring_rgb is deprecated in newer matplotlib.
  f.canvas.draw()
  # Binary-mode np.fromstring is deprecated; frombuffer reads the same bytes,
  # and .copy() keeps the result writable like fromstring's copy was.
  data = np.frombuffer(f.canvas.tostring_rgb(), dtype=np.uint8).copy()
  data_hxwx3 = data.reshape(f.canvas.get_width_height()[::-1] + (3,))
  # Close this figure specifically, not whichever figure is current.
  plt.close(f)

  return data_hxwx3
Example #24
Source File: competition_model_class.py    From Deep_Learning_Weather_Forecasting with Apache License 2.0 5 votes vote down vote up
def plot_prediction(self, x, y_true, y_pred, input_ruitu=None):
        """Plot past values, true future, predictions and optional Ruitu values.

        Every feature column ``j`` (``x.shape[-1]`` of them) is drawn on one
        12x3 figure. Future curves start right after the past values, and only
        the first feature's curves get legend labels.

        Arguments
        ---------
        x: Past (input) sequence, indexed as ``x[:, j]``.
        y_true: True future sequence, indexed like ``x``.
        y_pred: Predicted future sequence, indexed like ``x``.
        input_ruitu: Optional Ruitu output sequence, indexed like ``x``.
        """
        plt.figure(figsize=(12, 3))

        n_features = x.shape[-1]
        for j in range(n_features):
            past = x[:, j]
            true = y_true[:, j]
            pred = y_pred[:, j]
            ruitu = input_ruitu[:, j] if input_ruitu is not None else None

            first = j == 0
            n_past = len(past)

            plt.plot(range(n_past), past, "o-g",
                     label="Seen (past) values" if first else "_nolegend_")
            plt.plot(range(n_past, n_past + len(true)), true, "x--g",
                     label="True future values" if first else "_nolegend_")
            plt.plot(range(n_past, n_past + len(pred)), pred, "o--y",
                     label="Predictions" if first else "_nolegend_")
            if ruitu is not None:
                plt.plot(range(n_past, n_past + len(ruitu)), ruitu, "o--r",
                         label="Ruitu values" if first else "_nolegend_")
        plt.legend(loc='best')
        plt.title("Predictions v.s. true values v.s. Ruitu")
        plt.show()
Example #25
Source File: plot_lfads.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def plot_priors():
  """Plot prior vs. posterior statistics of g0 from ``train_modelvals``.

  Reads the module-level ``train_modelvals`` mapping (not a parameter) and
  opens three figures:
    1. overlaid histograms (posterior blue, prior green) of means and variances;
    2. a 2x2 grid of the transposed prior/posterior mean and variance matrices;
    3. a stem plot of the sorted log of per-column std of the posterior means.
  NOTE(review): the second histogram title says "Log Variance" and the last
  title says "h0", while the raw '*_g0_var' / g0 values are plotted -- confirm.
  """
  prior_mean = train_modelvals['prior_g0_mean']
  prior_var = train_modelvals['prior_g0_var']
  post_mean = train_modelvals['posterior_g0_mean']
  post_var = train_modelvals['posterior_g0_var']

  # Figure 1: posterior histogram drawn first, prior on top.
  plt.figure(figsize=(10,4), tight_layout=True)
  hist_panels = (
      (post_mean, prior_mean, 'Histogram of Prior/Posterior Mean Values'),
      (post_var, prior_var, 'Histogram of Prior/Posterior Log Variance Values'),
  )
  for idx, (post, prior, title) in enumerate(hist_panels, start=1):
    plt.subplot(1,2,idx)
    plt.hist(post.flatten(), bins=20, color='b')
    plt.hist(prior.flatten(), bins=20, color='g')
    plt.title(title)

  # Figure 2: each matrix transposed, with its own colorbar.
  plt.figure(figsize=(10,10), tight_layout=True)
  image_panels = (
      (prior_mean, 'Prior g0 means'),
      (post_mean, 'Posterior g0 means'),
      (prior_var, 'Prior g0 variance Values'),
      (post_var, 'Posterior g0 variance Values'),
  )
  for idx, (values, title) in enumerate(image_panels, start=1):
    plt.subplot(2,2,idx)
    plt.imshow(values.T, interpolation='nearest', cmap='jet')
    plt.colorbar(fraction=0.025, pad=0.04)
    plt.title(title)

  # Figure 3: sorted log std (over axis 0) of the posterior means.
  plt.figure(figsize=(10,5))
  plt.stem(np.sort(np.log(post_mean.std(axis=0))))
  plt.title('Log standard deviation of h0 means')
Example #26
Source File: plot_utils.py    From keras-anomaly-detection with MIT License 5 votes vote down vote up
def plot_confusion_matrix(y_true, y_pred):
    """Show an annotated seaborn heatmap of the confusion matrix.

    Both axes use the module-level ``LABELS`` as tick labels. Rows are true
    classes and columns are predicted classes.
    """
    matrix = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(12, 12))
    sns.heatmap(matrix, xticklabels=LABELS, yticklabels=LABELS,
                annot=True, fmt="d")
    plt.title("Confusion matrix")
    plt.ylabel('True class')
    plt.xlabel('Predicted class')
    plt.show()
Example #27
Source File: plotting.py    From medicaldetectiontoolkit with Apache License 2.0 5 votes vote down vote up
def plot_stat_curves(stats, outfile):
    """Save one ROC plot and one PRC plot overlaying every entry in ``stats``.

    Each entry is a mapping with a ``'name'`` and ``'roc'``/``'prc'`` keys.
    A non-None value is plotted as ``value[0]`` (x) against ``value[1]`` (y).
    Figures are written to ``outfile + '_roc'`` and ``outfile + '_prc'``.
    """
    for curve in ('roc', 'prc'):
        is_prc = curve == 'prc'
        plt.figure()
        for entry in stats:
            points = entry[curve]
            if points is None:
                continue
            plt.plot(points[0], points[1], label=entry['name'] + '_' + curve)
        plt.title(outfile.split('/')[-1] + '_' + curve)
        plt.legend(loc=3 if is_prc else 4)
        plt.xlabel('precision' if is_prc else '1-spec.')
        plt.ylabel('recall')
        plt.savefig(outfile + '_' + curve)
        plt.close()
Example #28
Source File: plotting.py    From medicaldetectiontoolkit with Apache License 2.0 5 votes vote down vote up
def plot_prediction_hist(label_list, pred_list, type_list, outfile):
    """
    Plot a log-scale histogram of prediction scores for a specific class and save it.
    :param label_list: list of 1s and 0s specifying whether a prediction is a true positive match (1) or a false positive (0).
    False negatives (missed ground-truth objects) are artificially added as predictions with score 0 and label 1.
    :param pred_list: list of prediction scores.
    :param type_list: list of prediction types ('det_tp', 'det_fp', 'det_fn') whose counts are added to the title, or None.
    :param outfile: path of the saved figure; its last '/'-separated component starts the title.
    """
    preds = np.array(pred_list)
    labels = np.array(label_list)
    title = outfile.split('/')[-1] + ' count:{}'.format(len(label_list))
    plt.figure()
    plt.yscale('log')
    histograms = (
        (0, 'g', 'false pos.'),
        (1, 'b', 'true pos. (false neg. @ score=0)'),
    )
    for value, colour, legend_name in histograms:
        if value in labels:
            plt.hist(preds[labels == value], alpha=0.3, color=colour, range=(0, 1), bins=50, label=legend_name)

    if type_list is not None:
        tp_count, fp_count, fn_count = (
            type_list.count(kind) for kind in ('det_tp', 'det_fp', 'det_fn'))
        # Positives = matched ground truth (tp) + missed ground truth (fn).
        title += ' tp:{} fp:{} fn:{} pos:{}'.format(
            tp_count, fp_count, fn_count, fn_count + tp_count)

    plt.legend()
    plt.title(title)
    plt.xlabel('confidence score')
    plt.ylabel('log n')
    plt.savefig(outfile)
    plt.close()
Example #29
Source File: plot_lfads.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def _plot_item(W, name, full_name, nspaces):
  """Visualise one array ``W`` in a new figure titled ``full_name``.

  Scalars (shape ``()``) are printed as ``name : value`` and not plotted.
  Row vectors, column vectors and 1-D arrays are drawn as stem plots; any
  other shape is shown as a heatmap of ``abs(W)`` with a colorbar.
  ``nspaces`` is not used by this function.
  """
  if W.shape == ():
    # Nothing to plot: print the value without leaving an empty figure open.
    print(name, ": ", W)
    return
  plt.figure()
  if W.shape[0] == 1:
    plt.stem(W.T)
  elif len(W.shape) == 1 or W.shape[1] == 1:
    # 1-D arrays have no shape[1] (previously an IndexError); treat them
    # like column vectors.
    plt.stem(W)
  else:
    plt.imshow(np.abs(W), interpolation='nearest', cmap='jet')
    plt.colorbar()
  plt.title(full_name)
Example #30
Source File: cli.py    From tmhmm.py with MIT License 5 votes vote down vote up
def plot(posterior_file, outputfile):
    """Plot the posterior probabilities loaded from ``posterior_file``.

    Draws the 'inside', 'transmembrane' (also shaded) and 'outside' curves,
    places a three-column legend centred below the axes, and saves the
    figure to ``outputfile``.
    """
    inside, membrane, outside = load_posterior_file(posterior_file)

    plt.figure(figsize=(16, 8))
    plt.title('Posterior probabilities')
    plt.suptitle('tmhmm.py')
    series = (
        (inside, 'inside', 'blue'),
        (membrane, 'transmembrane', 'red'),
        (outside, 'outside', 'black'),
    )
    for values, label, colour in series:
        plt.plot(values, label=label, color=colour)
        if label == 'transmembrane':
            # Shade the area under the transmembrane curve.
            plt.fill_between(range(len(inside)), membrane, color='red')
    plt.legend(frameon=False, bbox_to_anchor=[0.5, 0],
               loc='upper center', ncol=3, borderaxespad=1.5)
    plt.tight_layout(pad=3)
    plt.savefig(outputfile)