Python matplotlib.use() Examples

The following are 30 code examples showing how to use matplotlib.use(). These examples are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module matplotlib, or try the search function.

Example 1
Project: nsf   Author: bayesiains   File: images.py    License: MIT License 7 votes vote down vote up
def set_device(use_gpu, multi_gpu, _log):
    """Select the torch device to run on and validate the GPU configuration.

    Args:
        use_gpu: If True, run on CUDA and make ``torch.cuda.FloatTensor`` the
            default tensor type.
        multi_gpu: If True, require at least two visible CUDA devices.
        _log: Logger; ``_log.info`` is called with the GPU count when
            ``multi_gpu`` is True.

    Returns:
        torch.device: ``cuda`` when ``use_gpu`` is True, otherwise ``cpu``.

    Raises:
        RuntimeError: If CUDA is requested but unavailable, or if multi-GPU
            training is requested with fewer than two GPUs.
    """
    if use_gpu and not torch.cuda.is_available():
        raise RuntimeError('use_gpu is True but CUDA is not available')

    if use_gpu:
        device = torch.device('cuda')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        device = torch.device('cpu')

    if multi_gpu:
        gpu_count = torch.cuda.device_count()
        # Multi-GPU training needs at least two devices; checking only for
        # exactly one let a GPU-less machine through ("Using all 0 GPUs").
        if gpu_count < 2:
            raise RuntimeError(
                'Multiple GPU training requested, but only {} GPU(s) '
                'available.'.format(gpu_count))
        _log.info('Using all {} GPUs available'.format(gpu_count))

    return device
Example 2
Project: DeepLung   Author: uci-cbcl   File: utils.py    License: GNU General Public License v3.0 6 votes vote down vote up
def getFreeId():
    """Return a comma-separated string of indices of lightly used GPUs.

    For each NVML device, the mean of the ``gpu`` and ``memory`` fields
    returned by ``nvmlDeviceGetUtilizationRates`` is computed. Devices whose
    mean is below 70 are reported, e.g. ``'0,2'``. Returns ``''`` when no
    device qualifies.
    """
    import pynvml

    pynvml.nvmlInit()
    try:
        def getFreeRatio(index):
            # Despite the name this is a *usage* score: the average of the
            # two utilisation figures reported by NVML.
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            use = pynvml.nvmlDeviceGetUtilizationRates(handle)
            return 0.5 * (float(use.gpu) + float(use.memory))

        available = [str(i) for i in range(pynvml.nvmlDeviceGetCount())
                     if getFreeRatio(i) < 70]
    finally:
        # Balance nvmlInit() so the NVML session is not leaked.
        pynvml.nvmlShutdown()
    return ','.join(available)
Example 3
Project: linguistic-style-transfer   Author: vineetjohn   File: tsne_visualizer.py    License: Apache License 2.0 6 votes vote down vote up
def plot_coordinates(coordinates, plot_path, markers, label_names, fig_num):
    """Draw labelled groups of 2-D points as a scatter plot and save it as SVG.

    Args:
        coordinates: 2-D array-like indexed as ``coordinates[rows, column]``;
            columns 0 and 1 are used as x and y.
        plot_path: Destination passed to ``plt.savefig``.
        markers: Group boundaries; group ``k`` is rows
            ``markers[k]:markers[k + 1]``.
        label_names: Legend label for each group.
        fig_num: Number of the matplotlib figure to draw into.

    Group ``k`` cycles through the module-level ``plot_markers`` and
    ``colors`` sequences. The current figure is closed after saving.
    """
    matplotlib.use('svg')
    import matplotlib.pyplot as plt

    plt.figure(fig_num)
    group_count = len(markers) - 1
    for group in range(group_count):
        start, stop = markers[group], markers[group + 1]
        style = {
            'marker': plot_markers[group % len(plot_markers)],
            'c': colors[group % len(colors)],
        }
        plt.scatter(x=coordinates[start:stop, 0],
                    y=coordinates[start:stop, 1],
                    label=label_names[group], alpha=0.75, **style)

    plt.legend(loc='upper right', fontsize='x-large')
    plt.axis('off')
    plt.savefig(fname=plot_path, format="svg", bbox_inches='tight', transparent=True)
    plt.close()
Example 4
Project: magpy   Author: geomagpy   File: mpplot.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def hzfunc(self, label):
        """Activate the subplot selected by ``label`` and attach a box selector.

        ``label`` has the form ``"plot <n>"``. The axes are looked up in
        ``self.hzdict`` and ``n`` is stored in ``self.mainnum``. A
        ``RectangleSelector`` reporting to ``self.line_select_callback`` is
        attached to those axes and stored on ``toggle_selector.RS``, then the
        figure is redrawn.
        """
        selected_ax = self.hzdict[label]
        self.mainnum = int(label.replace("plot ", ""))
        # Semi-transparent red box with a black edge while dragging.
        box_style = dict(facecolor='red', edgecolor='black', alpha=0.2, fill=True)
        selector = RectangleSelector(
            selected_ax, self.line_select_callback,
            drawtype='box', useblit=True,
            button=[1, 3],  # left and right buttons only; skip the middle one
            minspanx=5, minspany=5, spancoords='pixels',
            rectprops=box_style)
        toggle_selector.RS = selector
        plt.draw()
Example 5
Project: scGAN   Author: imsb-uke   File: scGAN.py    License: MIT License 6 votes vote down vote up
def validation(self, sess, cells_no, exp_folder, train_step):
        """
        Run the validation steps for the model in its current state.

        Currently this announces the step on stdout and delegates to
        ``self.generate_tSNE_image`` with the same arguments.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of cells to use for the validation step.
        exp_folder : str
            Path to the job folder in which the outputs will be saved.
        train_step : int
            Index of the current training step.

        Returns
        -------
        None
        """
        banner = "Find tSNE embedding for the generated and the validation cells"
        print(banner)
        self.generate_tSNE_image(sess, cells_no, exp_folder, train_step)
Example 6
Project: reinforcement-learning-an-introduction   Author: ShangtongZhang   File: trajectory_sampling.py    License: MIT License 6 votes vote down vote up
def evaluate_pi(q, task, runs=1000):
    """Monte Carlo estimate of the start-state value under the greedy policy of ``q``.

    Each episode starts in state 0 and repeatedly takes ``argmax(q[state])``.
    It stops once ``task.step`` returns a state ``>= task.n_states``. The
    undiscounted rewards of each episode are summed, and the mean over all
    episodes is returned.

    Args:
        q: Action-value table indexed as ``q[state]``.
        task: Object exposing ``n_states`` and ``step(state, action)``
            returning ``(next_state, reward)``.
        runs: Number of episodes to average (default 1000).

    Returns:
        The mean episode return, as computed by ``np.mean``.
    """
    returns = []
    # The episode index is unused; the reward gets its own name so it no
    # longer shadows the loop variable.
    for _ in range(runs):
        total_reward = 0
        state = 0
        while state < task.n_states:
            action = argmax(q[state])
            state, reward = task.step(state, action)
            total_reward += reward
        returns.append(total_reward)
    return np.mean(returns)

# perform expected update from a uniform state-action distribution of the MDP @task
# evaluate the learned q value every @eval_interval steps 
Example 7
Project: stdpopsim   Author: popsim-consortium   File: validation.py    License: GNU General Public License v3.0 6 votes vote down vote up
def _twopop_IM(
        engine_id, out_dir, seed,
        NA=1000, N1=500, N2=5000, T=1000, M12=0, M21=0, pulse=None, samples=None,
        **sim_kwargs):
    """Simulate a two-population isolation-with-migration model and dump it.

    The contig is AraTha ``chr5`` with ``length_multiplier=0.01``, passed
    through ``irradiate``. The model is ``stdpopsim.IsolationWithMigration``
    with the given sizes, split time and migration rates. An optional
    ``pulse`` is appended to the model's demographic events, which are then
    sorted by ``time``.

    Returns:
        ``(out_file, seconds)``: the path ``out_dir / f"{seed}.trees"`` the
        tree sequence was dumped to, and the wall-clock duration of
        ``engine.simulate``.
    """
    contig = stdpopsim.get_species("AraTha").get_contig(
        "chr5", length_multiplier=0.01)  # ~270 kb
    contig = irradiate(contig)
    model = stdpopsim.IsolationWithMigration(
        NA=NA, N1=N1, N2=N2, T=T, M12=M12, M21=M21)
    if pulse is not None:
        model.demographic_events.append(pulse)
        model.demographic_events.sort(key=lambda event: event.time)
    # XXX: AraTha has species.generation_time == 1, but that could hide bugs
    # in generation_time scaling, so a different value (3) is used here.
    model.generation_time = 3
    if samples is None:
        samples = model.get_samples(50, 50, 0)
    engine = stdpopsim.get_engine(engine_id)
    started = time.perf_counter()
    ts = engine.simulate(model, contig, samples, seed=seed, **sim_kwargs)
    elapsed = time.perf_counter() - started
    out_file = out_dir / f"{seed}.trees"
    ts.dump(out_file)
    return out_file, elapsed
Example 8
Project: assaytools   Author: choderalab   File: analysis.py    License: GNU Lesser General Public License v2.1 6 votes vote down vote up
def _create_solutions_model(self):
        """
        Create pymc model components for the true concentrations of the source
        receptor and ligand solutions.

        Side effects:
        * ``self.solutions`` is filtered down to the shortnames listed in some
          well's ``properties['contents']``.
        * ``self.parameter_names['solution concentrations']`` is set to the
          names of the parameters created here.
        * ``self.model[name]`` receives a ``LogNormalWrapper`` for every kept
          solution whose ``species`` is not None.
        """
        # Every solution shortname referenced by at least one well.
        solutions_in_use = {shortname
                            for well in self.wells
                            for shortname in well.properties['contents']}
        print('Solutions in use: %s' % str(solutions_in_use))

        # Drop the solutions the plate never uses.
        self.solutions = {shortname: self.solutions[shortname]
                          for shortname in solutions_in_use}

        concentration_names = []
        self.parameter_names['solution concentrations'] = concentration_names
        for solution in self.solutions.values():
            if solution.species is None:
                # Buffers / pure solvents get no concentration parameter.
                continue
            name = 'log concentration of %s' % solution.shortname
            self.model[name] = LogNormalWrapper(
                name,
                mean=solution.concentration.to_base_units().m,
                stddev=solution.uncertainty.to_base_units().m)
            concentration_names.append(name)
Example 9
Project: LSDMappingTools   Author: LSDtopotools   File: LSDMap_HillslopeMorphology.py    License: MIT License 6 votes vote down vote up
def chunkIt(seq, num):
    """
    Split ``seq`` into ``num`` consecutive chunks of roughly equal length.

    This comes from https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
    I will use it to create a bunch of lists for sequential clustering

    Chunk boundaries are computed with integer arithmetic
    (``i * len(seq) // num``). Repeatedly adding a float step could drift
    below ``len(seq)`` and produce an extra chunk (e.g. ``chunkIt([x], 10)``
    used to return 11 chunks).

    Args:
        seq: The initial sequence for chunking (must support ``len`` and slicing)
        num: The number of chunks to produce (a positive integer)

    Return:
        A list of ``num`` slices of ``seq`` whose lengths differ by at most
        one; some are empty when ``len(seq) < num``. An empty ``seq`` gives ``[]``.

    Raises:
        ValueError: if ``num`` is not positive.

    Author: Max Shawabkeh
    """
    if num <= 0:
        raise ValueError('num must be a positive integer, got {!r}'.format(num))
    length = len(seq)
    if length == 0:
        # An empty input has always produced no chunks at all.
        return []
    return [seq[i * length // num:(i + 1) * length // num] for i in range(num)]
Example 10
Project: westpa   Author: westpa   File: plot.py    License: MIT License 6 votes vote down vote up
def __generic_histo__(self, vector, labels):
        """Show a bar chart of ``vector`` using the interface configured on ``self``.

        With ``self.interface == 'text'`` the chart is delegated to
        ``self.__terminal_histo__``. Otherwise a TkAgg matplotlib window is
        attempted. If anything in that path fails (import error, no X display,
        bad input), a message is printed and the terminal renderer is used.

        Returns:
            1 if the graphical path failed and the text fallback was used,
            otherwise None.
        """
        if self.interface == 'text':
            self.__terminal_histo__(vector, labels)
        else:
            try:
                import matplotlib
                matplotlib.use('TkAgg')
                from matplotlib import pyplot as plt
                plt.bar(list(range(0, np.array(vector).shape[0])), vector, linewidth=0, align='center', color='gold', tick_label=labels)
                plt.show()
            except Exception:
                # `except Exception` instead of a bare `except:` so Ctrl-C /
                # SystemExit while the window is open still propagate.
                print('Unable to import plotting interface.  An X server ($DISPLAY) is required.')
                # Same arguments as the text branch; the fallback used to pass
                # an undefined `h5file` and crash with NameError.
                self.__terminal_histo__(vector, labels)
                return 1
Example 11
Project: pase   Author: santi-pdp   File: transforms.py    License: MIT License 6 votes vote down vote up
def __init__(self, kaldi_root, hop=160, win=400, sr=16000,
                    num_mel_bins=40, num_ceps=13, der_order=2,
                    name='kaldimfcc'):
        """Configure a Kaldi MFCC + deltas extraction pipe.

        ``kaldi_root``, ``hop``, ``win`` and ``sr`` go to the base class.
        ``self.frame_length``/``self.frame_shift`` used below are expected to
        be set there (presumably from ``win``/``hop`` -- TODO confirm).
        ``self.cmd`` becomes a Kaldi pipe: ``compute-mfcc-feats`` feeding
        ``add-deltas`` with ``--delta-order=der_order``.
        """
        super(KaldiMFCC, self).__init__(kaldi_root=kaldi_root, hop=hop,
                                        win=win, sr=sr)

        self.num_mel_bins = num_mel_bins
        self.num_ceps = num_ceps
        self.der_order = der_order

        mfcc_stage = ("ark:| {}/src/featbin/compute-mfcc-feats --print-args=false "
                      "--use-energy=false --snip-edges=false --num-ceps={} "
                      "--frame-length={} --frame-shift={} "
                      "--num-mel-bins={} --sample-frequency={} ark:- ark:- |")
        deltas_stage = (" {}/src/featbin/add-deltas --print-args=false "
                        "--delta-order={} ark:- ark:- |")
        mfcc_part = mfcc_stage.format(self.kaldi_root, self.num_ceps,
                                      self.frame_length, self.frame_shift,
                                      self.num_mel_bins, self.sr)
        deltas_part = deltas_stage.format(self.kaldi_root, self.der_order)
        self.cmd = mfcc_part + deltas_part
        self.name = name
Example 12
Project: pase   Author: santi-pdp   File: transforms.py    License: MIT License 6 votes vote down vote up
def __init__(self, kaldi_root, hop=160, win=400, sr=16000,
                 num_mel_bins=20, num_ceps=20, lpc_order=20,
                 name='kaldiplp'):
        """Configure a Kaldi PLP extraction pipe.

        ``kaldi_root``, ``hop``, ``win`` and ``sr`` go to the base class.
        ``self.frame_length``/``self.frame_shift`` used below are expected to
        be set there (presumably from ``win``/``hop`` -- TODO confirm).
        ``self.cmd`` becomes a ``compute-plp-feats`` Kaldi pipe.
        """
        super(KaldiPLP, self).__init__(kaldi_root=kaldi_root, hop=hop,
                                       win=win, sr=sr)

        self.num_mel_bins = num_mel_bins
        self.num_ceps = num_ceps
        self.lpc_order = lpc_order

        options = [
            "--print-args=false",
            "--snip-edges=false",
            "--use-energy=false",
            "--num-ceps={}".format(self.num_ceps),
            "--lpc-order={}".format(self.lpc_order),
            "--frame-length={}".format(self.frame_length),
            "--frame-shift={}".format(self.frame_shift),
            "--num-mel-bins={}".format(self.num_mel_bins),
            "--sample-frequency={}".format(self.sr),
        ]
        self.cmd = "ark:| {}/src/featbin/compute-plp-feats {} ark:- ark:- |".format(
            self.kaldi_root, " ".join(options))
        self.name = name
Example 13
Project: Dropout_BBalpha   Author: YingzhenLi   File: loading_utils.py    License: MIT License 6 votes vote down vote up
def plot_images(ax, images, shape, color = False):
    """Tile ``images`` into a mosaic and draw it on ``ax``.

    Pixel values are inverted (``1.0 - images``) first, so inputs presumably
    lie in [0, 1] -- TODO confirm with callers.

    Args:
        ax: matplotlib Axes to draw on; its axis decorations are hidden.
        images: Batch of images; ``len(images)`` is passed as ``n_cols`` to
            ``reshape_and_tile_images``.
        shape: Per-image shape forwarded to ``reshape_and_tile_images``.
        color: If True, use ``cm.Greys_r`` with nearest interpolation;
            otherwise the ``'Greys'`` colormap.
    """
    import matplotlib
    # Kept for backward compatibility: forces the non-interactive Agg
    # backend, a side effect callers may rely on when saving figures.
    matplotlib.use('Agg')

    # flip 0 to 1
    images = 1.0 - images

    images = reshape_and_tile_images(images, shape, n_cols=len(images))
    # Draw on the supplied axes. plt.imshow targeted pyplot's *current* axes,
    # which need not be `ax`, while axis('off') below always targeted `ax`.
    if color:
        from matplotlib import cm
        ax.imshow(images, cmap=cm.Greys_r, interpolation='nearest')
    else:
        ax.imshow(images, cmap='Greys')
    ax.axis('off')
Example 14
Project: PyGeoIpMap   Author: pieqq   File: pygeoipmap.py    License: MIT License 6 votes vote down vote up
def get_lat_lon_from_csv(csv_file, lats=None, lons=None):
    """
    Retrieves the last two columns of each row of a CSV formatted file to use
    as latitude and longitude.
    Returns two lists (latitudes and longitudes).

    Values are the raw strings produced by ``csv.reader``: no numeric
    conversion, and leading spaces are kept. Blank lines are skipped.
    ``csv_file`` is closed when reading finishes. Values are appended to
    ``lats``/``lons`` when they are given; otherwise fresh lists are created
    on every call. The former ``[]`` defaults were shared between calls, so
    results accumulated across calls.

    Example CSV file:
    119.80.39.54, Beijing, China, 39.9289, 116.3883
    101.44.1.135, Shanghai, China, 31.0456, 121.3997
    219.144.17.74, Xian, China, 34.2583, 108.9286
    64.27.26.7, Los Angeles, United States, 34.053, -118.2642
    """
    if lats is None:
        lats = []
    if lons is None:
        lons = []
    with contextlib.closing(csv_file):
        for row in csv.reader(csv_file):
            if not row:
                # csv.reader yields [] for blank lines; row[-2] would raise.
                continue
            lats.append(row[-2])
            lons.append(row[-1])

    return lats, lons
Example 15
Project: pmdarima   Author: alkaline-ml   File: matplotlib.py    License: MIT License 6 votes vote down vote up
def mpl_hist_arg(value=True):
    """Find the appropriate `density` kwarg for our given matplotlib version.

    This will determine if we should use `normed` or `density`. Additionally,
    since this is a kwarg, the user can supply a value (True or False) that
    they would like in the output dictionary.

    The installed version is compared numerically. A plain string comparison
    of ``matplotlib.__version__`` misorders versions such as ``'10.0'`` vs
    ``'2.1.0'``.

    Parameters
    ----------
    value : bool, optional (default=True)
        The boolean value of density/normed

    Returns
    -------
    density_kwarg : dict
        A dictionary containing the appropriate density kwarg for the
        installed  matplotlib version, mapped to the provided or default
        value
    """
    import re
    import matplotlib

    version = []
    for part in matplotlib.__version__.split('.')[:3]:
        # Keep the leading digits of each component ("0rc1" -> 0).
        match = re.match(r'\d+', part)
        if match is None:
            break
        version.append(int(match.group()))
    # Pad so e.g. "2.1" compares equal to (2, 1, 0).
    version += [0] * (3 - len(version))

    density_kwarg = 'density' if tuple(version) >= (2, 1, 0) else 'normed'
    return {density_kwarg: value}
Example 16
Project: neural-fingerprinting   Author: StephanZheng   File: utils.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def pair_visual(original, adversarial, figure=None):
    """
    This function displays three images side by side: the original, the
    perturbation (adversarial - original) and the adversarial sample
    :param original: the original input
    :param adversarial: the input after perturbations have been applied
    :param figure: if we've already displayed images, use the same plot
    :return: the matplot figure to reuse for future samples
    """
    import matplotlib.pyplot as plt

    # Squeeze the image to remove single-dimensional entries from array shape
    original = np.squeeze(original)
    adversarial = np.squeeze(adversarial)

    # Ensure our inputs are of proper shape
    assert(len(original.shape) == 2 or len(original.shape) == 3)

    # To avoid creating figures per input sample, reuse the sample plot
    if figure is None:
        plt.ion()
        figure = plt.figure()
        title = 'Cleverhans: Pair Visualization'
        # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4
        # and removed in 3.6; the figure manager owns the window title.
        manager = getattr(figure.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(title)
        elif hasattr(figure.canvas, 'set_window_title'):
            figure.canvas.set_window_title(title)

    # Add the images to the plot
    perturbation = adversarial - original
    for index, image in enumerate((original, perturbation, adversarial)):
        figure.add_subplot(1, 3, index + 1)
        plt.axis('off')

        # If the image is 2D, then we have 1 color channel
        if len(image.shape) == 2:
            plt.imshow(image, cmap='gray')
        else:
            plt.imshow(image)

        # Give the plot some time to update
        plt.pause(0.01)

    # Draw the plot and return
    plt.show()
    return figure
Example 17
Project: tmhmm.py   Author: dansondergaard   File: cli.py    License: MIT License 5 votes vote down vote up
def cli():
    """Command-line entry point: annotate every sequence in a FASTA file.

    For each entry, ``predict`` (using the ``--model`` file) returns a path and
    posterior. Three files named after the entry id are then written:
    ``<id>.summary`` (one ``start end state`` line per ``summarize`` segment),
    ``<id>.annotation`` (a ``>`` header line plus the path wrapped at 79
    columns) and ``<id>.plot`` (the posterior dump). When plotting support is
    available and ``--plot`` is given, ``<id>.pdf`` is rendered from the
    ``.plot`` file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--file', dest='sequence_file',
                        type=argparse.FileType('r'), required=True,
                        help='path to file in fasta format with sequences')
    parser.add_argument('-m', '--model', dest='model_file',
                        type=argparse.FileType('r'), default=DEFAULT_MODEL,
                        help='path to the model to use')
    if has_matplotlib:
        # Only offer --plot when plotting is possible.
        parser.add_argument('-p', '--plot', dest='plot_posterior',
                            action='store_true',
                            help='plot posterior probabilies')

    args = parser.parse_args()
    # The attribute is absent when --plot was never registered.
    plot_requested = getattr(args, 'plot_posterior', False)

    header, model = parse(args.model_file)
    for entry in load_fasta_file(args.sequence_file):
        path, posterior = predict(entry.sequence, header, model)
        stem = entry.id

        with open(stem + '.summary', 'w') as summary_file:
            for start, end, state in summarize(path):
                line = "{} {} {}".format(start, end, PRETTY_NAMES[state])
                print(line, file=summary_file)

        with open(stem + '.annotation', 'w') as ann_file:
            print('>', stem, ' ', entry.description, sep='', file=ann_file)
            for chunk in textwrap.wrap(path, 79):
                print(chunk, file=ann_file)

        plot_filename = stem + '.plot'
        with open(plot_filename, 'w') as plot_file:
            dump_posterior_file(plot_file, posterior)

        if plot_requested:
            with open(plot_filename, 'r') as fileobj:
                plot(fileobj, stem + '.pdf')
Example 18
Project: bioservices   Author: cokelaer   File: conf.py    License: GNU General Public License v3.0 5 votes vote down vote up
def setup(app):
    """Sphinx extension hook: add copybutton.js and the autodoc docstring handler."""
    # Sphinx 1.8 renamed add_javascript to add_js_file and Sphinx 4.0 removed
    # the old name; use the new API and fall back for older Sphinx.
    add_js = getattr(app, 'add_js_file', None) or app.add_javascript
    add_js('copybutton.js')
    app.connect('autodoc-process-docstring', touch_example_backreferences)

# -- Options for HTML output ---------------------------------------------------

# The theme to use for HTML and HTML Help pages.  Major themes that come with
# Sphinx are currently 'default' and 'sphinxdoc'. 
Example 19
Project: EXOSIMS   Author: dsavransky   File: plotC0vsT0andCvsT.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def __init__(self, args=None):
        """Remember the analysis arguments for later use.

        Args:
            args (dict) - 'file' keyword specifies specific pkl file to use
        """
        self.args = args
Example 20
def __init__(self, args=None):
        """Keep the user-supplied arguments on the instance.

        Args:
            args (dict) - 'file' keyword specifies specific pkl file to use;
                stored unchanged as ``self.args`` (None when omitted)
        """
        self.args = args
Example 21
Project: EXOSIMS   Author: dsavransky   File: kopparapuPlot.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def is_earthlike(self, specs, plan_id, star_ind):
        """Deprecated: decide whether a planet is Earth-like from specs and indices.

        Reads the planet radius ``specs['Rp']`` and semi-major axis
        ``specs['a']`` at ``plan_id`` (units stripped), and the stellar
        luminosity ``specs['L']`` at ``star_ind``. Returns True when the
        radius is in [0.90, 1.4] and ``L / a**2`` is in [0.3586, 1.1080].
        """
        radius = strip_units(specs['Rp'][plan_id])
        semi_major_axis = strip_units(specs['a'][plan_id])
        star_luminosity = specs['L'][star_ind]
        # Scale the star's luminosity by 1/distance^2 (distance in AU).
        flux = star_luminosity / (semi_major_axis**2)
        # The true lower radius bound is not axis-parallel; 0.90 is the best
        # axis-parallel approximation of it.
        return 0.90 <= radius <= 1.4 and 0.3586 <= flux <= 1.1080
Example 22
Project: EXOSIMS   Author: dsavransky   File: kopparapuPlot.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def is_earthlike2(self, Rp, L_plan):
        """Determine whether a planet is Earth-like from its radius and flux.

        NEED CITATION ON THESE RANGES

        Args:
            Rp (float) - planet radius in Earth-Radii
            L_plan (float) - adjusted stellar flux on planet
        Returns:
            earthLike (boolean) - True if 0.90 <= Rp <= 1.4 and
                0.3586 <= L_plan <= 1.1080, False otherwise
        """
        # The lower Rp bound is not truly axis-parallel; 0.90 is the best
        # axis-parallel approximation. The body previously referenced an
        # undefined `Rp_plan`, so every call raised NameError.
        return (0.90 <= Rp <= 1.4) and (0.3586 <= L_plan <= 1.1080)
Example 23
Project: sopt   Author: Lyrichu   File: SA.py    License: MIT License 5 votes vote down vote up
def run(self):
        '''
        Run simulated annealing.

        Starting from ``self.T_start``, each temperature level performs
        ``self.L`` perturbation steps, then the temperature is multiplied by
        ``self.q``. This repeats while the temperature exceeds ``self.T_end``.

        A perturbed position replaces ``self.init_pos`` when it improves
        ``self.func`` in the configured direction. Otherwise it is accepted
        with probability ``exp(-sign * delta / T)``. After every step the
        current objective value and position are recorded.

        Finally ``global_best_index``, ``global_best_target`` and
        ``global_best_point`` are set from the recorded history.
        :return: None
        '''
        is_min = self.func_type == sa_config.func_type_min
        is_max = self.func_type == sa_config.func_type_max
        sign = 1 if is_min else -1

        temperature = self.T_start
        while temperature > self.T_end:
            for _ in range(self.L):
                candidate = self.random_disturb()
                delta = self.func(candidate) - self.func(self.init_pos)
                improved = (is_min and delta < 0) or (is_max and delta > 0)
                if improved:
                    self.init_pos = candidate
                elif np.exp(-sign * delta / temperature) > np.random.random():
                    # Occasionally accept a worse solution to escape local optima.
                    self.init_pos = candidate
                self.steps += 1
                self.generations_best_targets.append(self.func(self.init_pos))
                self.generations_best_points.append(self.init_pos)
            temperature *= self.q

        if is_min:
            best_of, arg_best_of = np.min, np.argmin
        else:
            best_of, arg_best_of = np.max, np.argmax
        targets = self.generations_best_targets
        self.global_best_index = arg_best_of(targets)
        self.global_best_target = best_of(targets)
        self.global_best_point = self.generations_best_points[int(arg_best_of(np.array(targets)))]
Example 24
Project: ciftify   Author: edickie   File: cifti_vis_PINT.py    License: MIT License 5 votes vote down vote up
def main():
    """Entry point for cifti_vis_PINT: build QC snapshots or index pages.

    Arguments are parsed with ``docopt(__doc__)``. When a subject (or the
    deprecated ``snaps`` argument) is given, ``run_snaps`` runs inside a
    ``ciftify.utils.TempDir`` and its result is returned. Otherwise, if
    ``index`` is given, ``write_all_index_pages``'s result is returned. When
    neither is given, returns None.
    """
    global DEBUG
    arguments  = docopt(__doc__)
    snaps      = arguments['subject'] or arguments['snaps']
    index      = arguments['index']
    verbose    = arguments['--verbose']
    DEBUG      = arguments['--debug']

    if arguments['snaps']:
        logger.warning("The 'snaps' argument has been deprecated. Please use 'subject' in the future.")

    if verbose:
        logger.setLevel(logging.INFO)
        # Also set level for all loggers in ciftify module (or else will be
        # logging.WARN by default)
        logging.getLogger('ciftify').setLevel(logging.INFO)
    if DEBUG:
        logger.setLevel(logging.DEBUG)
        logging.getLogger('ciftify').setLevel(logging.DEBUG)

    ciftify.utils.log_arguments(arguments)

    settings = UserSettings(arguments)
    qc_config = ciftify.qc_config.Config(settings.qc_mode)

    ## make pics and qcpage for each subject
    if snaps:
        with ciftify.utils.TempDir() as temp_dir:
            logger.debug('Created tempdir {} on host {}'.format(temp_dir,
                    os.uname()[1]))
            logger.info("Making snaps for subject: {}".format(
                    settings.subject))
            ret = run_snaps(settings, qc_config, temp_dir, temp_dir)
        return ret

    # Start the index html file
    if index:
        logger.info("Writing Index pages to: {}".format(settings.qc_dir))
        ret = write_all_index_pages(settings, qc_config)
        return ret
Example 25
Project: garden.matplotlib   Author: kivy-garden   File: backend_kivy.py    License: MIT License 5 votes vote down vote up
def weight_as_number(self, weight):
        ''' Replaces the deprecated matplotlib function of the same name

        Args:
            weight: a number (returned unchanged) or one of the matplotlib
                2.2 weight names such as 'normal' or 'bold'.

        Returns:
            The numeric weight (100-900 for named weights).

        Raises:
            ValueError: if ``weight`` is neither a number nor a known name.
        '''
        # Return if number
        if isinstance(weight, numbers.Number):
            return weight
        # else use the mapping of matplotlib 2.2
        named_weights = {
            'ultralight': 100,
            'light': 200,
            'normal': 400,
            'regular': 400,
            'book': 500,
            'medium': 500,
            'roman': 500,
            'semibold': 600,
            'demibold': 600,
            'demi': 600,
            'bold': 700,
            'heavy': 800,
            'extra bold': 800,
            'black': 900,
        }
        # The isinstance check also keeps unhashable inputs out of the dict lookup.
        if isinstance(weight, str) and weight in named_weights:
            return named_weights[weight]
        # str.format rather than `+`, so non-string inputs (e.g. None) get the
        # intended ValueError instead of a TypeError.
        raise ValueError('weight {} not valid'.format(weight))
Example 26
Project: prefactor   Author: lofar-astron   File: getStructure_from_phases.py    License: GNU General Public License v3.0 5 votes vote down vote up
def mad(arr):
    """Median absolute deviation, a robust alternative to the standard deviation.

    Indicates the variability of the sample. Masked entries of a masked-array
    input are discarded before computing.
    https://en.wikipedia.org/wiki/Median_absolute_deviation
    """
    # Drop masked entries and continue with a plain ndarray, which should be
    # faster than operating on a masked array.
    values = np.ma.array(arr).compressed()
    deviations = np.abs(values - np.median(values))
    return np.median(deviations)
Example 27
Project: qb   Author: Pinafore   File: new_performance.py    License: MIT License 5 votes vote down vote up
def report(variables, save_dir, folds):
    """Render the ``new_performance.md`` template to a PDF in ``save_dir``.

    The file is ``report_<fold>.pdf`` when exactly one fold is given,
    otherwise ``report_all.pdf``.
    """
    # Seed every key the template references with an empty dict so jinja can
    # skip missing features; values from `variables` take precedence.
    jinja_keys = ['his_lines', 'his_stacked', 'rush_late_plot', 'choice_plot',
                  'hype_configs', 'protobowl_plot', 'protobowl_stats']
    template_vars = {key: {} for key in jinja_keys}
    template_vars.update(variables)

    suffix = folds[0] if len(folds) == 1 else 'all'
    output = os.path.join(save_dir, 'report_{}.pdf'.format(suffix))

    ReportGenerator('new_performance.md').create(template_vars, output)
Example 28
Project: qb   Author: Pinafore   File: abstract.py    License: MIT License 5 votes vote down vote up
def __init__(self, config_num: Optional[int]):
        """
        Base constructor for every guesser; it only records which configuration to use.

        Subclasses must implement all abstract methods. Keep construction
        lightweight: loading data is the job of ``AbstractGuesser.load``.

        :param config_num: Which guesser configuration to use. Pass None to
            explicitly request none; implementors must then not read the
            guesser config, and otherwise read the matching configuration.
            The parameter is positional so that implementors fail fast
            instead of silently falling back to a default.
        """
        self.config_num = config_num
Example 29
Project: scGAN   Author: imsb-uke   File: scGAN.py    License: MIT License 5 votes vote down vote up
def read_valid_cells(self, sess, cells_no):
        """
        Method that reads a given number of cells from the validation set.

        ``self.test_cells`` is run ``ceil(cells_no / self.batch_size)``
        times. Assuming each run yields one batch of ``self.batch_size``
        cells, at least ``cells_no`` cells are returned, with up to
        ``batch_size - 1`` extra from the last batch.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of validation cells to read.

        Returns
        -------
        real_cells : numpy array
            Matrix with shape ``(-1, self.test_cells.shape[1])`` (presumably
            one row per cell), rescaled with ``rescale`` using
            ``self.scaling`` / ``self.scale_value``.
        """
        # True division: with `cells_no // self.batch_size` the quotient was
        # already floored, making np.ceil a no-op, so fewer cells were read
        # than requested (none at all when cells_no < batch_size).
        batches_no = int(np.ceil(float(cells_no) / self.batch_size))
        real_cells = []
        for _ in range(batches_no):
            test_inputs = sess.run([self.test_cells])
            real_cells.append(test_inputs)

        real_cells = np.array(real_cells, dtype=np.float32)
        # Flatten the per-batch results into a single 2-D matrix.
        real_cells = real_cells.reshape((-1, self.test_cells.shape[1]))

        real_cells = rescale(real_cells,
                             scaling=self.scaling,
                             scale_value=self.scale_value)

        return real_cells
Example 30
Project: scGAN   Author: imsb-uke   File: cscGAN.py    License: MIT License 5 votes vote down vote up
def read_valid_cells(self, sess, cells_no):
        """
        Method that reads a given number of cells from the validation set.

        ``self.test_cells`` and ``self.test_cells_clusters`` are run together
        ``ceil(cells_no / self.batch_size)`` times. Assuming each run yields
        one batch of ``self.batch_size`` cells, at least ``cells_no`` cells
        are returned, with up to ``batch_size - 1`` extra from the last batch.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of validation cells to read.

        Returns
        -------
        real_cells : numpy array
            Matrix with shape ``(-1, self.test_cells.shape[1])`` (presumably
            one row per cell), rescaled with ``rescale`` using
            ``self.scaling`` / ``self.scale_value``.
        real_clusters : list
            One entry per ``sess.run`` call, holding the cluster indexes
            fetched alongside that batch of cells.
        """
        # True division: with `cells_no // self.batch_size` the quotient was
        # already floored, making np.ceil a no-op, so fewer cells were read
        # than requested (none at all when cells_no < batch_size).
        batches_no = int(np.ceil(float(cells_no) / self.batch_size))

        real_cells = []
        real_clusters = []
        for _ in range(batches_no):
            test_inputs, test_clusters = sess.run(
                [self.test_cells, self.test_cells_clusters])
            real_cells.append(test_inputs)
            real_clusters.append(test_clusters)

        real_cells = np.array(real_cells, dtype=np.float32)
        # Flatten the per-batch results into a single 2-D matrix.
        real_cells = real_cells.reshape((-1, self.test_cells.shape[1]))

        real_cells = rescale(real_cells, scaling=self.scaling,
                             scale_value=self.scale_value)

        return real_cells, real_clusters