Python matplotlib.use() Examples

The following are 30 code examples of matplotlib.use(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib , or try the search function .
Example #1
Source File: utils.py    From DeepLung with GNU General Public License v3.0 7 votes vote down vote up
def getFreeId():
    """Return the indices of lightly loaded GPUs as a comma-separated string.

    A device is listed when the mean of its GPU and memory utilization
    rates, as reported through pynvml, is below 70.  The result is an empty
    string when no device qualifies.
    """
    import pynvml

    pynvml.nvmlInit()

    def mean_utilization(index):
        # Average of the GPU and memory utilization rates of one device.
        rates = pynvml.nvmlDeviceGetUtilizationRates(
            pynvml.nvmlDeviceGetHandleByIndex(index))
        return 0.5 * float(rates.gpu + float(rates.memory))

    free_ids = [
        index
        for index in range(pynvml.nvmlDeviceGetCount())
        if mean_utilization(index) < 70
    ]
    return ','.join(str(index) for index in free_ids)
Example #2
Source File: trajectory_sampling.py    From reinforcement-learning-an-introduction with MIT License 6 votes vote down vote up
def evaluate_pi(q, task):
    """Monte Carlo estimate of the return obtained from state 0.

    Runs 1000 episodes starting in state 0.  In every step the action is
    ``argmax(q[state])`` and ``task.step`` supplies the next state and the
    reward.  An episode ends once the state index reaches ``task.n_states``.
    Returns the mean of the summed episode rewards.
    """
    n_episodes = 1000
    episode_returns = []
    for _ in range(n_episodes):
        total_reward = 0
        state = 0
        while state < task.n_states:
            state, reward = task.step(state, argmax(q[state]))
            total_reward += reward
        episode_returns.append(total_reward)
    return np.mean(episode_returns)

# perform expected update from a uniform state-action distribution of the MDP @task
# evaluate the learned q value every @eval_interval steps 
Example #3
Source File: pygeoipmap.py    From PyGeoIpMap with MIT License 6 votes vote down vote up
def get_lat_lon_from_csv(csv_file, lats=None, lons=None):
    """
    Retrieves the last two columns of a CSV formatted file to use as latitude
    and longitude.
    Returns two lists (latitudes and longitudes).

    The values are returned exactly as the csv module yields them (strings,
    including any whitespace that follows the separator).  Blank lines are
    skipped.  ``csv_file`` is closed before the function returns.

    If ``lats`` / ``lons`` are given, the new values are appended to those
    lists and the same list objects are returned; otherwise fresh lists are
    created for every call.

    Example CSV file:
    119.80.39.54, Beijing, China, 39.9289, 116.3883
    101.44.1.135, Shanghai, China, 31.0456, 121.3997
    219.144.17.74, Xian, China, 34.2583, 108.9286
    64.27.26.7, Los Angeles, United States, 34.053, -118.2642
    """
    # A list used as a default value is created once and shared by all calls,
    # so results of earlier calls would leak into later ones.
    lats = [] if lats is None else lats
    lons = [] if lons is None else lons

    with contextlib.closing(csv_file):
        for row in csv.reader(csv_file):
            if not row:
                # csv.reader yields [] for blank lines
                continue
            lats.append(row[-2])
            lons.append(row[-1])

    return lats, lons
Example #4
Source File: loading_utils.py    From Dropout_BBalpha with MIT License 6 votes vote down vote up
def plot_images(ax, images, shape, color = False):
    """Show a row of tiled, intensity-inverted images and hide the axis of ``ax``.

    ``images`` is inverted (``1.0 - images``) and tiled by
    ``reshape_and_tile_images`` with one column per image.  The tiled image
    is shown with ``cm.Greys_r`` and nearest-neighbour interpolation when
    ``color`` is true, otherwise with the ``'Greys'`` colormap.
    """
    # the Agg backend is selected before pyplot is imported
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # flip 0 to 1
    inverted = 1.0 - images
    tiled = reshape_and_tile_images(inverted, shape, n_cols=len(inverted))

    # NOTE(review): imshow is issued through pyplot, i.e. on pyplot's current
    # axes, which is presumably ``ax`` -- confirm against the callers.
    if not color:
        plt.imshow(tiled, cmap='Greys')
    else:
        from matplotlib import cm
        plt.imshow(tiled, cmap=cm.Greys_r, interpolation='nearest')
    ax.axis('off')
Example #5
Source File: matplotlib.py    From pmdarima with MIT License 6 votes vote down vote up
def _mpl_version_tuple(version):
    """Return the leading ``(major, minor)`` integers of a version string.

    Returns ``None`` when the string does not start with a number.
    """
    import re

    match = re.match(r'\s*(\d+)(?:\.(\d+))?', version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2) or 0)


def mpl_hist_arg(value=True):
    """Find the appropriate `density` kwarg for our given matplotlib version.

    This will determine if we should use `normed` or `density`. Additionally,
    since this is a kwarg, the user can supply a value (True or False) that
    they would like in the output dictionary.

    Parameters
    ----------
    value : bool, optional (default=True)
        The boolean value of density/normed

    Returns
    -------
    density_kwarg : dict
        A dictionary containing the appropriate density kwarg for the
        installed  matplotlib version, mapped to the provided or default
        value
    """
    import matplotlib

    # Compare numerically: as plain strings '10.0.0' sorts before '2.1.0'.
    version = _mpl_version_tuple(matplotlib.__version__)
    # An unparseable version string is treated as a current release.
    supports_density = version is None or version >= (2, 1)
    density_kwarg = 'density' if supports_density else 'normed'
    return {density_kwarg: value}
Example #6
Source File: transforms.py    From pase with MIT License 6 votes vote down vote up
def __init__(self, kaldi_root, hop=160, win=400, sr=16000,
                 num_mel_bins=20, num_ceps=20, lpc_order=20,
                 name='kaldiplp'):
        """Store the PLP settings and build the Kaldi pipe command.

        ``kaldi_root``, ``hop``, ``win`` and ``sr`` are forwarded to the
        parent constructor.  ``self.cmd`` is the ``compute-plp-feats``
        command line filled in with the instance's settings.
        """
        # NOTE(review): kaldi_root, frame_length, frame_shift and sr are read
        # back from self below; presumably the parent constructor sets them.
        super(KaldiPLP, self).__init__(kaldi_root=kaldi_root,
                                        hop=hop, win=win, sr=sr)

        self.num_mel_bins = num_mel_bins
        self.num_ceps = num_ceps
        self.lpc_order = lpc_order

        # The placeholders are positional: ``fields`` must stay in the same
        # order as the ``{}`` slots of the template.
        template = (
            "ark:| {}/src/featbin/compute-plp-feats "
            "--print-args=false --snip-edges=false --use-energy=false "
            "--num-ceps={} --lpc-order={} "
            "--frame-length={} --frame-shift={} "
            "--num-mel-bins={} --sample-frequency={} "
            "ark:- ark:- |"
        )
        fields = (
            self.kaldi_root,
            self.num_ceps,
            self.lpc_order,
            self.frame_length,
            self.frame_shift,
            self.num_mel_bins,
            self.sr,
        )
        self.cmd = template.format(*fields)
        self.name = name
Example #7
Source File: transforms.py    From pase with MIT License 6 votes vote down vote up
def __init__(self, kaldi_root, hop=160, win=400, sr=16000,
                    num_mel_bins=40, num_ceps=13, der_order=2,
                    name='kaldimfcc'):
        """Store the MFCC settings and build the Kaldi pipe command.

        ``kaldi_root``, ``hop``, ``win`` and ``sr`` are forwarded to the
        parent constructor.  ``self.cmd`` pipes ``compute-mfcc-feats`` into
        ``add-deltas`` with ``--delta-order`` set to ``der_order``.
        """
        # NOTE(review): kaldi_root, frame_length, frame_shift and sr are read
        # back from self below; presumably the parent constructor sets them.
        super(KaldiMFCC, self).__init__(kaldi_root=kaldi_root,
                                        hop=hop, win=win, sr=sr)

        self.num_mel_bins = num_mel_bins
        self.num_ceps = num_ceps
        self.der_order = der_order

        # The placeholders are positional: ``fields`` must stay in the same
        # order as the ``{}`` slots of the template (kaldi_root occurs twice).
        template = (
            "ark:| {}/src/featbin/compute-mfcc-feats --print-args=false "
            "--use-energy=false --snip-edges=false --num-ceps={} "
            "--frame-length={} --frame-shift={} "
            "--num-mel-bins={} --sample-frequency={} ark:- ark:- |"
            " {}/src/featbin/add-deltas --print-args=false "
            "--delta-order={} ark:- ark:- |"
        )
        fields = (
            self.kaldi_root,
            self.num_ceps,
            self.frame_length,
            self.frame_shift,
            self.num_mel_bins,
            self.sr,
            self.kaldi_root,
            self.der_order,
        )
        self.cmd = template.format(*fields)
        self.name = name
Example #8
Source File: images.py    From nsf with MIT License 6 votes vote down vote up
def set_device(use_gpu, multi_gpu, _log):
    """Pick the torch device and validate the multi-GPU request.

    With ``use_gpu`` the CUDA device is returned and the default tensor type
    is switched to ``torch.cuda.FloatTensor``; otherwise the CPU device is
    returned.  With ``multi_gpu`` the number of visible GPUs is logged
    through ``_log.info``.

    Raises
    ------
    RuntimeError
        If ``use_gpu`` is set but CUDA is not available, or if ``multi_gpu``
        is set while exactly one GPU is visible.
    """
    if use_gpu:
        if not torch.cuda.is_available():
            raise RuntimeError('use_gpu is True but CUDA is not available')
        device = torch.device('cuda')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        device = torch.device('cpu')

    if multi_gpu:
        # only a count of exactly one is rejected
        if torch.cuda.device_count() == 1:
            raise RuntimeError('Multiple GPU training requested, but only one GPU is available.')
        _log.info('Using all {} GPUs available'.format(torch.cuda.device_count()))

    return device
Example #9
Source File: plot.py    From westpa with MIT License 6 votes vote down vote up
def __generic_histo__(self, vector, labels):
        """Plot ``vector`` as a histogram with the available interface.

        With the ``'text'`` interface the terminal renderer is used.
        Otherwise a matplotlib bar chart (TkAgg backend) is shown; if that
        fails for any reason the terminal renderer is used as a fallback.

        Returns 1 when the fallback was taken, otherwise None.
        """
        if self.interface == 'text':
            self.__terminal_histo__(vector, labels)
        else:
            try:
                import matplotlib
                matplotlib.use('TkAgg')
                from matplotlib import pyplot as plt
                plt.bar(list(range(0, np.array(vector).shape[0])), vector, linewidth=0, align='center', color='gold', tick_label=labels)
                plt.show()
            except Exception:
                # Best-effort fallback; KeyboardInterrupt/SystemExit still propagate.
                print('Unable to import plotting interface.  An X server ($DISPLAY) is required.')
                # Same call as the text branch above; this method has no h5file.
                self.__terminal_histo__(vector, labels)
                return 1
Example #10
Source File: tsne_visualizer.py    From linguistic-style-transfer with Apache License 2.0 6 votes vote down vote up
def plot_coordinates(coordinates, plot_path, markers, label_names, fig_num):
    """Scatter-plot groups of 2-D coordinates and save the figure as SVG.

    Consecutive entries of ``markers`` delimit the row range of each group
    in ``coordinates``; group ``i`` is labelled ``label_names[i]`` and drawn
    with the module-level ``plot_markers`` / ``colors`` (cycled).  The figure
    is written to ``plot_path`` and closed.
    """
    matplotlib.use('svg')
    import matplotlib.pyplot as plt

    plt.figure(fig_num)
    n_groups = len(markers) - 1
    for group in range(n_groups):
        begin, end = markers[group], markers[group + 1]
        plt.scatter(
            x=coordinates[begin:end, 0],
            y=coordinates[begin:end, 1],
            marker=plot_markers[group % len(plot_markers)],
            c=colors[group % len(colors)],
            label=label_names[group],
            alpha=0.75,
        )

    plt.legend(loc='upper right', fontsize='x-large')
    plt.axis('off')
    plt.savefig(fname=plot_path, format="svg", bbox_inches='tight', transparent=True)
    plt.close()
Example #11
Source File: mpplot.py    From magpy with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def hzfunc(self,label):
        """Make the axes registered under ``label`` the selection target.

        Looks up the axes in ``self.hzdict``, stores the number that remains
        after removing ``"plot "`` from the label in ``self.mainnum`` and
        installs a new box-shaped ``RectangleSelector`` on those axes as
        ``toggle_selector.RS``, then redraws.
        """
        ax = self.hzdict[label]
        self.mainnum = int(label.replace("plot ",""))

        # appearance of the selection rectangle
        box_style = {'facecolor': 'red', 'edgecolor': 'black', 'alpha': 0.2, 'fill': True}
        # drawtype is 'box' or 'line' or 'none'
        toggle_selector.RS = RectangleSelector(
            ax, self.line_select_callback,
            drawtype='box', useblit=True,
            button=[1,3], # don't use middle button
            minspanx=5, minspany=5,
            spancoords='pixels',
            rectprops=box_style)

        plt.draw()
Example #12
Source File: LSDMap_HillslopeMorphology.py    From LSDMappingTools with MIT License 6 votes vote down vote up
def chunkIt(seq, num):
    """
    Split a sequence into ``num`` consecutive chunks of roughly equal length.

    Based on https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
    It is used to create a bunch of lists for sequential clustering.

    The chunk boundaries are computed with integer arithmetic.  The original
    recipe accumulated a float step, and the rounding error could produce one
    chunk too many (e.g. 8 items split into 6 parts gave 7 chunks).

    Args:
        seq: The initial list for chunking
        num: The number of chunks to create (a positive integer)

    Return:
        A list of ``num`` chunks with roughly equal numbers of elements
        (some may be empty when ``num`` exceeds ``len(seq)``).  An empty
        ``seq`` gives an empty list.

    Raises:
        ValueError: if ``num`` is smaller than 1.

    Author: Max Shawabkeh
    """
    if num < 1:
        # 0 used to fail with ZeroDivisionError, negative values never returned
        raise ValueError("num must be a positive integer, got %r" % (num,))

    total = len(seq)
    if total == 0:
        return []

    # boundary i is floor(i * total / num), computed exactly
    bounds = [i * total // num for i in range(num + 1)]
    return [seq[start:stop] for start, stop in zip(bounds, bounds[1:])]
Example #13
Source File: analysis.py    From assaytools with GNU Lesser General Public License v2.1 6 votes vote down vote up
def _create_solutions_model(self):
        """
        Create pymc model components for true concentrations of source receptor and ligand solutions.

        Restricts ``self.solutions`` to the solutions that appear in some well
        and adds one ``LogNormalWrapper`` entry to ``self.model`` per retained
        solution that has a species.

        Populates the following fields:
        * parameter_names['solution concentrations'] : parameters associated with true concentrations of receptor and ligand solutions
        """
        # Collect the short names of every solution present in any well
        solutions_in_use = {
            shortname
            for well in self.wells
            for shortname in well.properties['contents']
        }
        print('Solutions in use: %s' % str(solutions_in_use))

        # Drop solutions that do not appear in the plate
        self.solutions = { shortname : self.solutions[shortname] for shortname in solutions_in_use }

        self.parameter_names['solution concentrations'] = list()
        for solution in self.solutions.values():
            if solution.species is not None: # buffers or pure solvents are skipped
                name = 'log concentration of %s' % solution.shortname
                mean = solution.concentration.to_base_units().m
                stddev = solution.uncertainty.to_base_units().m
                self.model[name] = LogNormalWrapper(name, mean=mean, stddev=stddev)
                self.parameter_names['solution concentrations'].append(name)
Example #14
Source File: scGAN.py    From scGAN with MIT License 6 votes vote down vote up
def validation(self, sess, cells_no, exp_folder, train_step):
        """
        Run the validation step of the current model.

        Prints a status line and delegates to ``generate_tSNE_image`` with
        the arguments unchanged.

        Parameters
        ----------
        sess : Session
            The TF Session in use.
        cells_no : int
            Number of cells to use for the validation step.
        exp_folder : str
            Path to the job folder in which the outputs will be saved.
        train_step : int
            Index of the current training step.

        Returns
        -------
        None
        """
        status_message = "Find tSNE embedding for the generated and the validation cells"
        print(status_message)
        self.generate_tSNE_image(sess, cells_no, exp_folder, train_step)
Example #15
Source File: validation.py    From stdpopsim with GNU General Public License v3.0 6 votes vote down vote up
def _twopop_IM(
        engine_id, out_dir, seed,
        NA=1000, N1=500, N2=5000, T=1000, M12=0, M21=0, pulse=None, samples=None,
        **sim_kwargs):
    """Simulate a two-population isolation-with-migration model and save it.

    Builds an ``IsolationWithMigration`` model for a shortened AraTha chr5
    contig (passed through ``irradiate``), optionally adds ``pulse`` to the
    demographic events (kept sorted by time), runs the engine named by
    ``engine_id`` and dumps the tree sequence to ``out_dir / "<seed>.trees"``.

    Returns the output path and the wall-clock duration of the simulate call
    in seconds.
    """
    species = stdpopsim.get_species("AraTha")
    contig = irradiate(
        species.get_contig("chr5", length_multiplier=0.01))  # ~270 kb
    model = stdpopsim.IsolationWithMigration(
            NA=NA, N1=N1, N2=N2, T=T, M12=M12, M21=M21)
    if pulse is not None:
        model.demographic_events.append(pulse)
        model.demographic_events.sort(key=lambda event: event.time)
    # XXX: AraTha has species.generation_time == 1, which could hide bugs in
    # generation_time scaling, so a value of 3 is used here instead.
    model.generation_time = 3
    if samples is None:
        samples = model.get_samples(50, 50, 0)
    engine = stdpopsim.get_engine(engine_id)

    started = time.perf_counter()
    ts = engine.simulate(model, contig, samples, seed=seed, **sim_kwargs)
    elapsed = time.perf_counter() - started

    out_file = out_dir / f"{seed}.trees"
    ts.dump(out_file)
    return out_file, elapsed
Example #16
Source File: core.py    From ffn with MIT License 5 votes vote down vote up
def calc_max_drawdown(prices):
    """
    Return the max drawdown of a price series: the smallest ratio of a price
    to the running maximum up to that point, minus 1 (so 0 or negative).

    For the full drawdown series use to_drawdown_series instead.
    """
    running_peak = prices.expanding(min_periods=1).max()
    lowest_ratio = (prices / running_peak).min()
    return lowest_ratio - 1
Example #17
Source File: graph.py    From PyEveLiveDPS with GNU General Public License v3.0 5 votes vote down vote up
def readjust(self, highestAverage):
        """
        Re-fit the graph to the window; used during the animation cycle and
        when the user resizes the window.

        The left margin grows in steps with the magnitude of highestAverage
        so the y-axis numbers fit, and the y-axis range is adjusted as well.
        Margins have to be given as fractions of the figure, not as pixel
        counts, so the pixel values are divided by the window width.
        """
        self.windowWidth = self.winfo_width()

        # room for the y-axis numbers, in pixels
        if (highestAverage < 900):
            leftPixels = 33
        elif (highestAverage < 9000):
            leftPixels = 44
        elif (highestAverage < 90000):
            leftPixels = 55
        else:
            leftPixels = 66
        self.graphFigure.subplots_adjust(left=(leftPixels/self.windowWidth), top=(1-15/self.windowWidth),
                                         bottom=(15/self.windowWidth), wspace=0, hspace=0)

        # y-axis tops out at 100, or 10% above highestAverage from 100 upwards
        if (highestAverage < 100):
            yTop = 100
        else:
            yTop = (highestAverage+highestAverage*0.1)
        mainAxes = self.graphFigure.axes[0]
        mainAxes.set_ylim(bottom=0, top=yTop)
        mainAxes.get_yaxis().grid(True, linestyle="-", color="grey", alpha=0.2)
        self.canvas.draw()
Example #18
Source File: check_dataset.py    From signaltrain with GNU General Public License v3.0 5 votes vote down vote up
def estimate_time_shift(x, y):
    """ Cross-correlates time series y with x and returns the lag, in samples,
        at which the correlation is largest.  The lag axis is centred on zero
        by subtracting len(x)//2 from the sample index.

        When DEBUG is set, progress messages are printed and x, y and the
        correlation are plotted.
    """
    if DEBUG:
        print("computing cross-correlation")
    corr = signal.correlate(y, x, mode='same', method='fft')
    if DEBUG:
        print("finished computing cross-correlation")

    nx, ny = len(x), len(y)
    half_nx = nx//2
    t_samples = np.arange(nx)
    ct_samples = t_samples - half_nx   # lag axis, centred on zero
    cmax_ind = np.argmax(corr)         # position of the correlation peak
    dt = ct_samples[cmax_ind]          # lag that belongs to the peak

    if not DEBUG:
        return dt

    print("cmax_ind, nx//2, ny//2, dt =",cmax_ind, half_nx, ny//2, dt)
    fig, (ax_x, ax_y, ax_corr) = plt.subplots(3, 1)
    ax_x.get_shared_x_axes().join(ax_x, ax_y)
    ax_x.plot(t_samples, x)
    ax_y.plot(t_samples, y)
    ax_corr.plot(ct_samples, corr)
    plt.show()
    return dt


#  for use in filtering filenames 
Example #19
Source File: read_log.py    From image-compression-cnn with MIT License 5 votes vote down vote up
def plot(values, metric_name):
    """Save a bar chart of the per-item metric difference 'our' minus 'jpeg'.

    Parameters
    ----------
    values : iterable
        Items ``v`` for which ``v[1]['our']`` and ``v[1]['jpeg']`` are numbers.
    metric_name : str
        Used (upper-cased) as the title and as part of the file name.

    Differences >= 0 are drawn in the default colour, differences <= 0 in
    red.  The chart is written to ``plot_kodak_<metric_name>.png`` in the
    current working directory and the figure is closed afterwards.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    plt.style.use('ggplot')

    fig, ax = plt.subplots(1, 1, figsize=(25, 3))
    ax.margins(0)

    diffs = [v[1]['our'] - v[1]['jpeg'] for v in values]
    x = list(range(len(diffs)))
    gains = [max(0, d) for d in diffs]
    losses = [min(0, d) for d in diffs]

    plt.bar(x, gains)
    plt.bar(x, losses, color='r')
    # Booleans, not 'off' strings: a non-empty string counts as true.
    plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

    plt.title(metric_name.upper(), x=0.5, y=0.8, fontsize=14)
    # No legend: none of the bars is labelled, and the former
    # plt.legend(loc='') is rejected by current matplotlib ('' is not a
    # valid location).
    ax.get_xaxis().set_visible(False)
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    fig.tight_layout()
    plt.savefig('plot_kodak_' + metric_name + '.png')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
Example #20
Source File: train_op.py    From RelativePose with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def import_matplotlib():
    """Import matplotlib and pyplot, selecting the Agg backend when
    ``env_display()`` is false.

    The imported names are local to this function; callers only observe the
    backend selection and the module imports as side effects.
    """
    # NOTE(review): env_display() is defined elsewhere; presumably it reports
    # whether a display is available -- confirm.
    has_display = env_display()
    import matplotlib
    if not has_display:
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
Example #21
Source File: plot.py    From westpa with MIT License 5 votes vote down vote up
def __generic_ci__(self, h5file, iteration, i, j, tau, h5key='rate_evolution', dim=0, interface=None):
        """Plot expected values and confidence bounds with the available interface.

        ``interface`` overrides ``self.interface`` when it is not None.  With
        the ``'text'`` interface the terminal renderers are used.  Otherwise
        the 'expected', 'ci_ubound' and 'ci_lbound' series of
        ``h5file[h5key]`` (first ``iteration`` entries, divided by ``tau``)
        are plotted with matplotlib (TkAgg backend); if that fails for any
        reason the terminal renderers are used as a fallback.

        Returns 1 when the fallback was taken, otherwise None.
        """
        def text_plot():
            # Terminal rendering, shared by the text interface and the fallback.
            if self.dim > 1:
                self.__terminal_ci__(h5file, iteration, i, j, tau, h5key)
            else:
                self.__terminal_expected__(h5file, iteration, i, j, tau, h5key, dim)

        if interface is None:
            interface = self.interface

        if interface == 'text':
            text_plot()
        else:
            try:
                import matplotlib
                matplotlib.use('TkAgg')
                from matplotlib import pyplot as plt
                if self.dim == 3:
                    index = (slice(None, iteration), i, j)
                else:
                    index = (slice(None, iteration), i)
                for series, color in (('expected', 'black'),
                                      ('ci_ubound', 'grey'),
                                      ('ci_lbound', 'grey')):
                    plt.plot(h5file[h5key][series][index] / tau, color=color)
                plt.show()
            except Exception:
                # Best-effort fallback; KeyboardInterrupt/SystemExit still propagate.
                print('Unable to import plotting interface.  An X server ($DISPLAY) is required.')
                # h5key is passed here as well, exactly as for the text interface.
                text_plot()
                return 1
Example #22
Source File: test_image_regression.py    From pytest-regressions with MIT License 5 votes vote down vote up
def test_image_regression(image_regression, datadir):
    """Render a sine-wave plot to ``datadir / "test.png"`` and hand its bytes
    to the ``image_regression`` fixture with a diff threshold of 1.0."""
    import matplotlib

    # select the Agg backend so that no GUI backend (such as Tk) is used
    matplotlib.use("Agg")

    import matplotlib.pyplot as plt
    import numpy as np

    time_s = np.arange(0.0, 2.0, 0.01)
    voltage = 1 + np.sin(2 * np.pi * time_s)

    fig, ax = plt.subplots()
    ax.plot(time_s, voltage)
    ax.set(
        xlabel="time (s)",
        ylabel="voltage (mV)",
        title="About as simple as it gets, folks",
    )
    ax.grid()

    image_filename = datadir / "test.png"
    fig.savefig(str(image_filename))

    image_bytes = image_filename.read_bytes()
    image_regression.check(image_bytes, diff_threshold=1.0)
Example #23
Source File: bam_cov.py    From basenji with Apache License 2.0 5 votes vote down vote up
def set_clips(self, coverage):
    """ Hash indexes to clip at various thresholds.

        Run this before clip_multi, which reads self.multi_clip_indexes.
        The idea is to estimate coverage conservatively w/ clip_max and
        smoothing before asking whether the raw coverage count is
        compelling.

        In:
         coverage (np.array): Pre-clipped genome coverage.

        Out:
          self.adaptive_t (int->float): Clip values mapped to coverage thresholds
                                        above which to apply them.
          self.multi_clip_indexes (int->np.array): Clip values mapped to genomic
                                                   indexes to clip.
        """

    # thresholds are only fitted while self.adaptive_t is still empty
    if len(self.adaptive_t) == 0:
      for clip_value in range(2, self.clip_max + 1):
        # squared gap between the target density (self.adaptive_cdf) and the
        # Poisson tail mass above clip_value for mean u
        def squared_tail_gap(u):
          tail_mass = 1 - poisson.cdf(clip_value, u)
          return (self.adaptive_cdf - tail_mass)**2
        self.adaptive_t[clip_value] = minimize(squared_tail_gap, clip_value)['x'][0]

    # indexes whose coverage lies between this clip threshold and the next
    self.multi_clip_indexes = {}
    for clip_value in range(2, self.clip_max):
      lower = self.adaptive_t[clip_value]
      upper = self.adaptive_t[clip_value + 1]
      in_band = (coverage > lower) & (coverage <= upper)
      mci = np.where(in_band)[0]
      if len(mci) > 0:
        self.multi_clip_indexes[clip_value] = mci
      print('Sites clipped to %d: %d' % (clip_value, len(mci)))

    # everything above the last threshold gets clip_max
    top_clip = self.clip_max
    mci = np.where(coverage > self.adaptive_t[top_clip])[0]
    if len(mci) > 0:
      self.multi_clip_indexes[top_clip] = mci
    print('Sites clipped to %d: %d' % (top_clip, len(mci)))
Example #24
Source File: matplotlib.py    From pmdarima with MIT License 5 votes vote down vote up
def get_compatible_pyplot(backend=None, debug=True):
    """Make the backend of MPL compatible.

    In Travis Mac distributions, python is not installed as a framework, so
    the TkAgg backend is the best choice there (it keeps matplotlib from
    trying the mac OS backend by default).

    Parameters
    ----------
    backend : str or None, optional (default=None)
        The backend to switch to. If None, the backend that is already
        in use is kept.

    debug : bool, optional (default=True)
        Whether to log the existing backend to stderr.

    Returns
    -------
    plt : module
        The ``matplotlib.pyplot`` module.
    """
    import matplotlib

    existing_backend = matplotlib.get_backend()
    if backend is None:
        # Nothing to switch; optionally report what is in use.
        if debug:
            sys.stderr.write("Using '%s' MPL backend%s"
                             % (existing_backend, os.linesep))
    else:
        # Can this raise?...
        matplotlib.use(backend)

        # Report the switch (written after the switch has been requested)
        if debug:
            sys.stderr.write("Currently using '%s' MPL backend, "
                             "switching to '%s' backend%s"
                             % (existing_backend, backend, os.linesep))

    from matplotlib import pyplot as plt
    return plt
Example #25
Source File: maze.py    From reinforcement-learning-an-introduction with MIT License 5 votes vote down vote up
def insert(self, priority, state, action):
        """Queue the (state, action) pair under the negated priority."""
        # the priority queue is a minimum heap, hence the sign flip
        key = (tuple(state), action)
        self.priority_queue.add_item(key, -priority)

    # @return: whether the priority queue is empty 
Example #26
Source File: trajectory_sampling.py    From reinforcement-learning-an-introduction with MIT License 5 votes vote down vote up
def __init__(self, n_states, b):
        """Create a random task with ``n_states`` states and branching factor ``b``.

        Both tables have shape ``(n_states, len(ACTIONS), b)``.
        """
        self.n_states = n_states
        self.b = b

        table_shape = (n_states, len(ACTIONS), b)

        # transition table: every state-action pair has b possible successor states
        self.transition = np.random.randint(n_states, size=table_shape)

        # rewards depend on (s, a, s'); how to choose them is not clear, so
        # they are drawn from a unit normal distribution
        self.reward = np.random.randn(*table_shape)
Example #27
Source File: blackjack.py    From reinforcement-learning-an-introduction with MIT License 5 votes vote down vote up
def monte_carlo_es(episodes):
    """Monte Carlo control with exploring starts.

    Plays ``episodes`` episodes from random initial states and actions and
    averages the episode reward over the first visit of every state-action
    pair.  Returns the array of average returns, indexed by
    (playerSum - 12, dealerCard - 1, usableAce, action).
    """
    # (playerSum, dealerCard, usableAce, action)
    state_action_values = np.zeros((10, 10, 2, 2))
    # counts start at 1 so that the averages never divide by 0
    state_action_pair_count = np.ones((10, 10, 2, 2))

    def behavior_policy(usable_ace, player_sum, dealer_card):
        # greedy with respect to the current average returns, ties broken at random
        index = (player_sum - 12, dealer_card - 1, int(usable_ace))
        values_ = state_action_values[index] / state_action_pair_count[index]
        best = np.max(values_)
        return np.random.choice([action_ for action_, value_ in enumerate(values_) if value_ == best])

    for episode in tqdm(range(episodes)):
        # exploring start: random initial state and action
        initial_state = [bool(np.random.choice([0, 1])),
                         np.random.choice(range(12, 22)),
                         np.random.choice(range(1, 11))]
        initial_action = np.random.choice(ACTIONS)
        # the very first episode follows the target policy, later ones the greedy policy
        current_policy = target_policy_player if episode == 0 else behavior_policy
        _, reward, trajectory = play(current_policy, initial_state, initial_action)

        seen = set()
        for (usable_ace, player_sum, dealer_card), action in trajectory:
            ace = int(usable_ace)
            sum_index = player_sum - 12
            card_index = dealer_card - 1
            visit = (ace, sum_index, card_index, action)
            if visit in seen:
                continue
            seen.add(visit)
            # first visit in this episode: accumulate the reward and the count
            state_action_values[sum_index, card_index, ace, action] += reward
            state_action_pair_count[sum_index, card_index, ace, action] += 1

    return state_action_values / state_action_pair_count

# Monte Carlo Sample with Off-Policy 
Example #28
Source File: cliff_walking.py    From reinforcement-learning-an-introduction with MIT License 5 votes vote down vote up
def choose_action(state, q_value):
    """Epsilon-greedy action choice for ``state``.

    With probability EPSILON a random action from ACTIONS is returned,
    otherwise an action with the highest value in
    ``q_value[state[0], state[1], :]`` (ties broken at random).
    """
    explore = np.random.binomial(1, EPSILON) == 1
    if explore:
        return np.random.choice(ACTIONS)

    values_ = q_value[state[0], state[1], :]
    best = np.max(values_)
    greedy_actions = [action_ for action_, value_ in enumerate(values_) if value_ == best]
    return np.random.choice(greedy_actions)

# an episode with Sarsa
# @q_value: values for state action pair, will be updated
# @expected: if True, will use expected Sarsa algorithm
# @step_size: step size for updating
# @return: total rewards within this episode 
Example #29
Source File: nucleus.py    From PanopticSegmentation with MIT License 5 votes vote down vote up
def load_nucleus(self, dataset_dir, subset):
        """Register the images of one subset of the nuclei dataset.

        dataset_dir: Root directory of the dataset
        subset: Which images to load. One of:
                * train: stage1_train without the VAL_IMAGE_IDS images
                * val: exactly the VAL_IMAGE_IDS images (from stage1_train)
                * stage1_train, stage1_test, stage2_test: every image
                  directory found in the sub-directory of that name
        """
        # One class only; both the dataset source and the class are called nucleus
        self.add_class("nucleus", 1, "nucleus")

        assert subset in ["train", "val", "stage1_train", "stage1_test", "stage2_test"]

        # "train" and "val" both live in stage1_train; the other subsets are
        # named after their own sub-directory
        reads_train_dir = subset in ["train", "val"]
        dataset_dir = os.path.join(dataset_dir, "stage1_train" if reads_train_dir else subset)

        if subset != "val":
            # image ids are the names of the immediate sub-directories
            image_ids = next(os.walk(dataset_dir))[1]
            if subset == "train":
                image_ids = list(set(image_ids) - set(VAL_IMAGE_IDS))
        else:
            image_ids = VAL_IMAGE_IDS

        for image_id in image_ids:
            image_path = os.path.join(dataset_dir, image_id, "images/{}.png".format(image_id))
            self.add_image("nucleus", image_id=image_id, path=image_path)
Example #30
Source File: mountain_car.py    From reinforcement-learning-an-introduction with MIT License 5 votes vote down vote up
def figure_12_10():
    """Plot the average steps per episode against alpha for several lambdas.

    For every (lambda, alpha) pair, 30 independent runs of 50 episodes are
    played with a fresh ``Sarsa`` evaluator per run.  The step counts are
    averaged over episodes and then over runs, and one curve per lambda is
    saved to ../images/figure_12_10.png.
    """
    runs = 30
    episodes = 50
    alphas = np.arange(1, 8) / 4.0
    lams = [0.99, 0.95, 0.5, 0]

    steps = np.zeros((len(lams), len(alphas), runs, episodes))
    for lam_index, lam in enumerate(lams):
        for alpha_index, alpha in enumerate(alphas):
            for run in tqdm(range(runs)):
                evaluator = Sarsa(alpha, lam, replacing_trace)
                for ep in range(episodes):
                    steps[lam_index, alpha_index, run, ep] = play(evaluator)

    # mean over episodes (axis 3) first, then over runs (axis 2)
    mean_steps = np.mean(np.mean(steps, axis=3), axis=2)

    for lam, curve in zip(lams, mean_steps):
        plt.plot(alphas, curve, label='lambda = %s' % (str(lam)))
    plt.xlabel('alpha * # of tilings (8)')
    plt.ylabel('averaged steps per episode')
    plt.ylim([180, 300])
    plt.legend()

    plt.savefig('../images/figure_12_10.png')
    plt.close()

# figure 12.11, summary comparison of Sarsa(lambda) algorithms
# I use 8 tilings rather than 10 tilings