Python matplotlib.use() Examples
The following are 30 code examples of matplotlib.use(), drawn from open-source projects. You can go to the original project or source file by following the links above each example, or browse the other available functions and classes of the matplotlib module.
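As a quick orientation before the examples: matplotlib.use() selects the rendering backend and is conventionally called before the first pyplot import, a pattern most of the snippets below follow. A minimal sketch:

import matplotlib
matplotlib.use('Agg')          # non-interactive backend; safe on headless servers
import matplotlib.pyplot as plt

plt.plot([1, 2, 3], [1, 4, 9])
plt.savefig('figure.png')      # Agg renders to files rather than opening a window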
Example #1
Source File: utils.py From DeepLung with GNU General Public License v3.0 | 7 votes |
def getFreeId():
    import pynvml

    pynvml.nvmlInit()

    def getFreeRatio(id):
        handle = pynvml.nvmlDeviceGetHandleByIndex(id)
        use = pynvml.nvmlDeviceGetUtilizationRates(handle)
        ratio = 0.5 * (float(use.gpu + float(use.memory)))
        return ratio

    deviceCount = pynvml.nvmlDeviceGetCount()
    available = []
    for i in range(deviceCount):
        if getFreeRatio(i) < 70:
            available.append(i)
    gpus = ''
    for g in available:
        gpus = gpus + str(g) + ','
    gpus = gpus[:-1]
    return gpus
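A hedged usage sketch (illustrative, not taken verbatim from DeepLung): the returned comma-separated string is typically used to pin the process to the idle GPUs via CUDA_VISIBLE_DEVICES.

import os

# Illustrative: restrict this process to the GPUs getFreeId() reports as
# idle, e.g. "0,2". Must be set before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = getFreeId()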
Example #2
Source File: trajectory_sampling.py From reinforcement-learning-an-introduction with MIT License | 6 votes |
def evaluate_pi(q, task):
    # use Monte Carlo method to estimate the state value
    runs = 1000
    returns = []
    for r in range(runs):
        rewards = 0
        state = 0
        while state < task.n_states:
            action = argmax(q[state])
            state, r = task.step(state, action)
            rewards += r
        returns.append(rewards)
    return np.mean(returns)

# perform expected update from a uniform state-action distribution of the MDP @task
# evaluate the learned q value every @eval_interval steps
Example #3
Source File: pygeoipmap.py From PyGeoIpMap with MIT License | 6 votes |
def get_lat_lon_from_csv(csv_file, lats=[], lons=[]):
    """
    Retrieves the last two columns of a CSV formatted file to use as
    latitude and longitude.
    Returns two lists (latitudes and longitudes).

    Example CSV file:
    119.80.39.54, Beijing, China, 39.9289, 116.3883
    101.44.1.135, Shanghai, China, 31.0456, 121.3997
    219.144.17.74, Xian, China, 34.2583, 108.9286
    64.27.26.7, Los Angeles, United States, 34.053, -118.2642
    """
    # Note: the mutable default arguments ([]) are created once and reused
    # across calls, so results accumulate; pass fresh lists explicitly if
    # calling this more than once.
    with contextlib.closing(csv_file):
        reader = csv.reader(csv_file)
        for row in reader:
            lats.append(row[-2])
            lons.append(row[-1])
    return lats, lons
Example #4
Source File: loading_utils.py From Dropout_BBalpha with MIT License | 6 votes |
def plot_images(ax, images, shape, color=False):
    # finally save to file
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # flip 0 to 1
    images = 1.0 - images
    images = reshape_and_tile_images(images, shape, n_cols=len(images))
    if color:
        from matplotlib import cm
        plt.imshow(images, cmap=cm.Greys_r, interpolation='nearest')
    else:
        plt.imshow(images, cmap='Greys')
    ax.axis('off')
Example #5
Source File: matplotlib.py From pmdarima with MIT License | 6 votes |
def mpl_hist_arg(value=True):
    """Find the appropriate `density` kwarg for our given matplotlib version.

    This will determine if we should use `normed` or `density`. Additionally,
    since this is a kwarg, the user can supply a value (True or False) that
    they would like in the output dictionary.

    Parameters
    ----------
    value : bool, optional (default=True)
        The boolean value of density/normed

    Returns
    -------
    density_kwarg : dict
        A dictionary containing the appropriate density kwarg for the
        installed matplotlib version, mapped to the provided or default value
    """
    import matplotlib
    density_kwarg = 'density' if matplotlib.__version__ >= '2.1.0' \
        else 'normed'
    return {density_kwarg: value}
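A minimal usage sketch (not part of pmdarima itself): unpack the returned dictionary into plt.hist so the same call works on either side of matplotlib's 2.1.0 normed-to-density rename.

import numpy as np
import matplotlib.pyplot as plt

data = np.random.randn(1000)
# Expands to density=True on matplotlib >= 2.1.0, normed=True before it.
plt.hist(data, bins=30, **mpl_hist_arg(True))
plt.show()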
Example #6
Source File: transforms.py From pase with MIT License | 6 votes |
def __init__(self, kaldi_root, hop=160, win=400, sr=16000,
             num_mel_bins=20, num_ceps=20, lpc_order=20,
             name='kaldiplp'):
    super(KaldiPLP, self).__init__(kaldi_root=kaldi_root,
                                   hop=hop, win=win, sr=sr)
    self.num_mel_bins = num_mel_bins
    self.num_ceps = num_ceps
    self.lpc_order = lpc_order
    cmd = "ark:| {}/src/featbin/compute-plp-feats "\
          "--print-args=false --snip-edges=false --use-energy=false "\
          "--num-ceps={} --lpc-order={} "\
          "--frame-length={} --frame-shift={} "\
          "--num-mel-bins={} --sample-frequency={} "\
          "ark:- ark:- |"
    self.cmd = cmd.format(self.kaldi_root, self.num_ceps, self.lpc_order,
                          self.frame_length, self.frame_shift,
                          self.num_mel_bins, self.sr)
    self.name = name
Example #7
Source File: transforms.py From pase with MIT License | 6 votes |
def __init__(self, kaldi_root, hop=160, win=400, sr=16000,
             num_mel_bins=40, num_ceps=13, der_order=2,
             name='kaldimfcc'):
    super(KaldiMFCC, self).__init__(kaldi_root=kaldi_root,
                                    hop=hop, win=win, sr=sr)
    self.num_mel_bins = num_mel_bins
    self.num_ceps = num_ceps
    self.der_order = der_order
    cmd = "ark:| {}/src/featbin/compute-mfcc-feats --print-args=false "\
          "--use-energy=false --snip-edges=false --num-ceps={} "\
          "--frame-length={} --frame-shift={} "\
          "--num-mel-bins={} --sample-frequency={} ark:- ark:- |"\
          " {}/src/featbin/add-deltas --print-args=false "\
          "--delta-order={} ark:- ark:- |"
    self.cmd = cmd.format(self.kaldi_root, self.num_ceps,
                          self.frame_length, self.frame_shift,
                          self.num_mel_bins, self.sr,
                          self.kaldi_root, self.der_order)
    self.name = name
Example #8
Source File: images.py From nsf with MIT License | 6 votes |
def set_device(use_gpu, multi_gpu, _log):
    # Decide which device to use.
    if use_gpu and not torch.cuda.is_available():
        raise RuntimeError('use_gpu is True but CUDA is not available')

    if use_gpu:
        device = torch.device('cuda')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        device = torch.device('cpu')

    if multi_gpu and torch.cuda.device_count() == 1:
        raise RuntimeError('Multiple GPU training requested, but only one GPU is available.')

    if multi_gpu:
        _log.info('Using all {} GPUs available'.format(torch.cuda.device_count()))

    return device
Example #9
Source File: plot.py From westpa with MIT License | 6 votes |
def __generic_histo__(self, vector, labels):
    # This function just calls the appropriate plot function for our available
    # interface. Same thing as generic_ci, but for a histogram.
    if self.interface == 'text':
        self.__terminal_histo__(vector, labels)
    else:
        try:
            import matplotlib
            matplotlib.use('TkAgg')
            from matplotlib import pyplot as plt
            plt.bar(list(range(0, np.array(vector).shape[0])), vector,
                    linewidth=0, align='center', color='gold',
                    tick_label=labels)
            plt.show()
        except:
            print('Unable to import plotting interface. An X server ($DISPLAY) is required.')
            # Fall back to the text interface. (The original passed a stray
            # h5file argument here, which is undefined in this scope and
            # would raise a NameError.)
            self.__terminal_histo__(vector, labels)
            return 1
Example #10
Source File: tsne_visualizer.py From linguistic-style-transfer with Apache License 2.0 | 6 votes |
def plot_coordinates(coordinates, plot_path, markers, label_names, fig_num):
    matplotlib.use('svg')
    import matplotlib.pyplot as plt

    plt.figure(fig_num)
    for i in range(len(markers) - 1):
        plt.scatter(x=coordinates[markers[i]:markers[i + 1], 0],
                    y=coordinates[markers[i]:markers[i + 1], 1],
                    marker=plot_markers[i % len(plot_markers)],
                    c=colors[i % len(colors)],
                    label=label_names[i], alpha=0.75)
    plt.legend(loc='upper right', fontsize='x-large')
    plt.axis('off')
    plt.savefig(fname=plot_path, format="svg", bbox_inches='tight', transparent=True)
    plt.close()
Example #11
Source File: mpplot.py From magpy with BSD 3-Clause "New" or "Revised" License | 6 votes |
def hzfunc(self, label):
    ax = self.hzdict[label]
    num = int(label.replace("plot ", ""))
    #print "Selected axis number:", num
    #global mainnum
    self.mainnum = num
    # drawtype is 'box' or 'line' or 'none'
    toggle_selector.RS = RectangleSelector(ax, self.line_select_callback,
                                           drawtype='box', useblit=True,
                                           button=[1, 3],  # don't use middle button
                                           minspanx=5, minspany=5,
                                           spancoords='pixels',
                                           rectprops=dict(facecolor='red',
                                                          edgecolor='black',
                                                          alpha=0.2, fill=True))
    #plt.connect('key_press_event', toggle_selector)
    plt.draw()
Example #12
Source File: LSDMap_HillslopeMorphology.py From LSDMappingTools with MIT License | 6 votes |
def chunkIt(seq, num):
    """
    This comes from
    https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
    I will use it to create a bunch of lists for sequential clustering

    Args:
        seq: The initial list for chunking
        num: The number of chunks to split the list into

    Return:
        A chunked list with roughly equal numbers of elements

    Author: Max Shawabkeh
    """
    avg = len(seq) / float(num)
    out = []
    last = 0.0
    while last < len(seq):
        out.append(seq[int(last):int(last + avg)])
        last += avg
    return out
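A quick illustrative check (not from the repo): nine items split into four chunks come out with sizes as equal as the arithmetic allows.

print(chunkIt(list(range(9)), 4))
# [[0, 1], [2, 3], [4, 5], [6, 7, 8]]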
Example #13
Source File: analysis.py From assaytools with GNU Lesser General Public License v2.1 | 6 votes |
def _create_solutions_model(self):
    """
    Create pymc model components for true concentrations of source receptor
    and ligand solutions.

    Populates the following fields:
    * parameter_names['concentrations'] : parameters associated with true
      concentrations of receptor and ligand solutions
    """
    # Determine solutions in use in plate
    solutions_in_use = set()
    for well in self.wells:
        for shortname in well.properties['contents']:
            solutions_in_use.add(shortname)
    print('Solutions in use: %s' % str(solutions_in_use))

    # Retain only solutions that appear in the plate
    self.solutions = { shortname : self.solutions[shortname] for shortname in solutions_in_use }

    self.parameter_names['solution concentrations'] = list()
    for solution in self.solutions.values():
        if solution.species is None:
            continue  # skip buffers or pure solvents
        name = 'log concentration of %s' % solution.shortname
        self.model[name] = LogNormalWrapper(name,
                                            mean=solution.concentration.to_base_units().m,
                                            stddev=solution.uncertainty.to_base_units().m)
        self.parameter_names['solution concentrations'].append(name)
Example #14
Source File: scGAN.py From scGAN with MIT License | 6 votes |
def validation(self, sess, cells_no, exp_folder, train_step):
    """
    Method that initiates some validation steps of the current model.

    Parameters
    ----------
    sess : Session
        The TF Session in use.
    cells_no : int
        Number of cells to use for the validation step.
    exp_folder : str
        Path to the job folder in which the outputs will be saved.
    train_step : int
        Index of the current training step.
    """
    print("Find tSNE embedding for the generated and the validation cells")
    self.generate_tSNE_image(sess, cells_no, exp_folder, train_step)
Example #15
Source File: validation.py From stdpopsim with GNU General Public License v3.0 | 6 votes |
def _twopop_IM(engine_id, out_dir, seed,
               NA=1000, N1=500, N2=5000, T=1000, M12=0, M21=0,
               pulse=None, samples=None, **sim_kwargs):
    species = stdpopsim.get_species("AraTha")
    contig = species.get_contig("chr5", length_multiplier=0.01)  # ~270 kb
    contig = irradiate(contig)
    model = stdpopsim.IsolationWithMigration(
        NA=NA, N1=N1, N2=N2, T=T, M12=M12, M21=M21)
    if pulse is not None:
        model.demographic_events.append(pulse)
        model.demographic_events.sort(key=lambda x: x.time)
    # XXX: AraTha has species.generation_time == 1, but there is the potential
    # for this to mask bugs related to generation_time scaling, so we use 3 here.
    model.generation_time = 3
    if samples is None:
        samples = model.get_samples(50, 50, 0)
    engine = stdpopsim.get_engine(engine_id)
    t0 = time.perf_counter()
    ts = engine.simulate(model, contig, samples, seed=seed, **sim_kwargs)
    t1 = time.perf_counter()
    out_file = out_dir / f"{seed}.trees"
    ts.dump(out_file)
    return out_file, t1 - t0
Example #16
Source File: core.py From ffn with MIT License | 5 votes |
def calc_max_drawdown(prices):
    """
    Calculates the max drawdown of a price series. If you want the
    actual drawdown series, please use to_drawdown_series.
    """
    return (prices / prices.expanding(min_periods=1).max()).min() - 1
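A quick illustrative check (not from ffn's docs): with a pandas Series that peaks at 120 and then falls to 90, the max drawdown is 90/120 - 1 = -25%.

import pandas as pd

prices = pd.Series([100, 120, 90, 110])
print(calc_max_drawdown(prices))  # -0.25: the 120 -> 90 drop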
Example #17
Source File: graph.py From PyEveLiveDPS with GNU General Public License v3.0 | 5 votes |
def readjust(self, highestAverage):
    """
    This is for use during the animation cycle, or when a user resizes the
    window. We must change how much room we have to draw numbers on the
    left-hand side, as well as adjust the y-axis values. Annoyingly, we
    have to use a %, not a number of pixels.
    """
    self.windowWidth = self.winfo_width()
    if (highestAverage < 900):
        self.graphFigure.subplots_adjust(left=(33/self.windowWidth), top=(1-15/self.windowWidth),
                                         bottom=(15/self.windowWidth), wspace=0, hspace=0)
    elif (highestAverage < 9000):
        self.graphFigure.subplots_adjust(left=(44/self.windowWidth), top=(1-15/self.windowWidth),
                                         bottom=(15/self.windowWidth), wspace=0, hspace=0)
    elif (highestAverage < 90000):
        self.graphFigure.subplots_adjust(left=(55/self.windowWidth), top=(1-15/self.windowWidth),
                                         bottom=(15/self.windowWidth), wspace=0, hspace=0)
    else:
        self.graphFigure.subplots_adjust(left=(66/self.windowWidth), top=(1-15/self.windowWidth),
                                         bottom=(15/self.windowWidth), wspace=0, hspace=0)
    if (highestAverage < 100):
        self.graphFigure.axes[0].set_ylim(bottom=0, top=100)
    else:
        self.graphFigure.axes[0].set_ylim(bottom=0, top=(highestAverage + highestAverage*0.1))
    self.graphFigure.axes[0].get_yaxis().grid(True, linestyle="-", color="grey", alpha=0.2)
    self.canvas.draw()
Example #18
Source File: check_dataset.py From signaltrain with GNU General Public License v3.0 | 5 votes |
def estimate_time_shift(x, y):
    """
    Computes the cross-correlation between time series x and y, grabs the
    index of where it's a maximum. This yields the time difference in
    samples between x and y.
    """
    if DEBUG:
        print("computing cross-correlation")
    corr = signal.correlate(y, x, mode='same', method='fft')
    if DEBUG:
        print("finished computing cross-correlation")
    nx, ny = len(x), len(y)
    t_samples = np.arange(nx)
    ct_samples = t_samples - nx//2   # try to center time shift (x axis) on zero
    cmax_ind = np.argmax(corr)       # where is the max of the cross-correlation?
    dt = ct_samples[cmax_ind]        # grab the time shift value corresponding to the max c-corr
    if DEBUG:
        print("cmax_ind, nx//2, ny//2, dt =", cmax_ind, nx//2, ny//2, dt)
        fig, (ax_x, ax_y, ax_corr) = plt.subplots(3, 1)
        ax_x.get_shared_x_axes().join(ax_x, ax_y)
        ax_x.plot(t_samples, x)
        ax_y.plot(t_samples, y)
        ax_corr.plot(ct_samples, corr)
        plt.show()
    return dt

# for use in filtering filenames
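A hedged usage sketch, assuming the module-level imports (numpy as np, scipy's signal) and DEBUG flag the function relies on: a pulse delayed by a known number of samples should be recovered by the correlation peak.

import numpy as np
from scipy import signal

DEBUG = False  # the module-level flag the function checks

x = np.zeros(1000)
x[200:300] = 1.0              # a simple rectangular pulse
y = np.roll(x, 37)            # the same pulse, delayed by 37 samples
print(estimate_time_shift(x, y))  # expected: ~37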
Example #19
Source File: read_log.py From image-compression-cnn with MIT License | 5 votes |
def plot(values, metric_name):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import sys
    plt.style.use('ggplot')

    fig, ax = plt.subplots(1, 1, figsize=(25, 3))
    ax.margins(0)
    x = []
    y = []
    for index, v in enumerate(values):
        # if not index: continue
        # plt.plot(x, new_recall, linewidth=2, label='Condensed Mem Network')
        x.append(index)
        y.append(v[1]['our'] - v[1]['jpeg'])
    # plt.plot(x, y, 'o')
    # plt.semilogy(x, y)
    y_neg = [max(0, i) for i in y]
    y_pos = [min(0, i) for i in y]
    plt.bar(x, y_neg)
    plt.bar(x, y_pos, color='r')
    plt.tick_params(axis='x', which='both', bottom='off', top='off', labelbottom='off')
    plt.title(metric_name.upper(), x=0.5, y=0.8, fontsize=14)
    plt.legend(loc='')
    ax.get_xaxis().set_visible(False)
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    fig.tight_layout()
    # plt.savefig('plot_size_' + metric_name + '.png', bbox_inches='tight_layout', pad_inches=0)
    plt.savefig('plot_kodak_' + metric_name + '.png')
Example #20
Source File: train_op.py From RelativePose with BSD 3-Clause "New" or "Revised" License | 5 votes |
def import_matplotlib():
    # Select the backend before the first pyplot import: fall back to the
    # non-interactive 'Agg' backend when no display is available. Because
    # modules are cached, this configures matplotlib globally for later
    # `import matplotlib.pyplot` calls elsewhere in the program.
    if env_display():
        import matplotlib
    else:
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
Example #21
Source File: plot.py From westpa with MIT License | 5 votes |
def __generic_ci__(self, h5file, iteration, i, j, tau, h5key='rate_evolution', dim=0, interface=None):
    # This function just calls the appropriate plot function for our available
    # interface.
    if (interface is None and self.interface == 'text') or interface == 'text':
        if self.dim > 1:
            self.__terminal_ci__(h5file, iteration, i, j, tau, h5key)
        else:
            self.__terminal_expected__(h5file, iteration, i, j, tau, h5key, dim)
    else:
        try:
            import matplotlib
            matplotlib.use('TkAgg')
            from matplotlib import pyplot as plt
            if self.dim == 3:
                plt.plot(h5file[h5key]['expected'][:iteration, i, j] / tau, color='black')
                plt.plot(h5file[h5key]['ci_ubound'][:iteration, i, j] / tau, color='grey')
                plt.plot(h5file[h5key]['ci_lbound'][:iteration, i, j] / tau, color='grey')
            else:
                plt.plot(h5file[h5key]['expected'][:iteration, i] / tau, color='black')
                plt.plot(h5file[h5key]['ci_ubound'][:iteration, i] / tau, color='grey')
                plt.plot(h5file[h5key]['ci_lbound'][:iteration, i] / tau, color='grey')
            plt.show()
        except:
            print('Unable to import plotting interface. An X server ($DISPLAY) is required.')
            if self.dim > 1:
                self.__terminal_ci__(h5file, iteration, i, j, tau)
            else:
                self.__terminal_expected__(h5file, iteration, i, j, tau, h5key, dim)
            return 1
Example #22
Source File: test_image_regression.py From pytest-regressions with MIT License | 5 votes |
def test_image_regression(image_regression, datadir):
    import matplotlib

    # this ensures matplotlib does not use a GUI backend (such as Tk)
    matplotlib.use("Agg")

    import matplotlib.pyplot as plt
    import numpy as np

    t = np.arange(0.0, 2.0, 0.01)
    s = 1 + np.sin(2 * np.pi * t)

    fig, ax = plt.subplots()
    ax.plot(t, s)
    ax.set(
        xlabel="time (s)",
        ylabel="voltage (mV)",
        title="About as simple as it gets, folks",
    )
    ax.grid()

    image_filename = datadir / "test.png"
    fig.savefig(str(image_filename))

    image_regression.check(image_filename.read_bytes(), diff_threshold=1.0)
Example #23
Source File: bam_cov.py From basenji with Apache License 2.0 | 5 votes |
def set_clips(self, coverage):
    """ Hash indexes to clip at various thresholds.

    Must run this before running clip_multi, which will use
    self.multi_clip_indexes. The objective is to estimate coverage
    conservatively w/ clip_max and smoothing before asking whether the raw
    coverage count is compelling.

    In:
      coverage (np.array): Pre-clipped genome coverage.

    Out:
      self.adaptive_t (int->float): Clip values mapped to coverage
        thresholds above which to apply them.
      self.multi_clip_indexes (int->np.array): Clip values mapped to
        genomic indexes to clip.
    """
    # choose clip thresholds
    if len(self.adaptive_t) == 0:
        for clip_value in range(2, self.clip_max + 1):
            # aiming for .01 cumulative density above the threshold.
            # decreasing the density increases the thresholds.
            cdf_matcher = lambda u: (self.adaptive_cdf - (1 - poisson.cdf(clip_value, u)))**2
            self.adaptive_t[clip_value] = minimize(cdf_matcher, clip_value)['x'][0]

    # take indexes with coverage between this clip threshold and the next
    self.multi_clip_indexes = {}
    for clip_value in range(2, self.clip_max):
        mci = np.where((coverage > self.adaptive_t[clip_value]) &
                       (coverage <= self.adaptive_t[clip_value + 1]))[0]
        if len(mci) > 0:
            self.multi_clip_indexes[clip_value] = mci
        print('Sites clipped to %d: %d' % (clip_value, len(mci)))

    # set the last clip_value
    mci = np.where(coverage > self.adaptive_t[self.clip_max])[0]
    if len(mci) > 0:
        self.multi_clip_indexes[self.clip_max] = mci
    print('Sites clipped to %d: %d' % (self.clip_max, len(mci)))
Example #24
Source File: matplotlib.py From pmdarima with MIT License | 5 votes |
def get_compatible_pyplot(backend=None, debug=True):
    """Make the backend of MPL compatible.

    In Travis Mac distributions, python is not installed as a framework.
    This means that using the TkAgg backend is the best solution (so it
    doesn't try to use the mac OS backend by default).

    Parameters
    ----------
    backend : str, optional (default="TkAgg")
        The backend to default to.

    debug : bool, optional (default=True)
        Whether to log the existing backend to stderr.
    """
    import matplotlib

    # If the backend provided is None, just default to
    # what's already being used.
    existing_backend = matplotlib.get_backend()
    if backend is not None:
        # Can this raise?...
        matplotlib.use(backend)

        # Print out the new backend
        if debug:
            sys.stderr.write("Currently using '%s' MPL backend, "
                             "switching to '%s' backend%s"
                             % (existing_backend, backend, os.linesep))

    # If backend is not set via env variable, but debug is
    elif debug:
        sys.stderr.write("Using '%s' MPL backend%s"
                         % (existing_backend, os.linesep))

    from matplotlib import pyplot as plt
    return plt
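A minimal usage sketch (illustrative, assuming the function is importable from this module): force the non-interactive Agg backend in a headless environment and draw with the returned pyplot module.

plt = get_compatible_pyplot(backend="Agg", debug=False)

plt.plot([0, 1, 2], [0, 1, 4])
plt.savefig("demo.png")  # works without an X server under Agg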
Example #25
Source File: maze.py From reinforcement-learning-an-introduction with MIT License | 5 votes |
def insert(self, priority, state, action):
    # note the priority queue is a minimum heap, so we use -priority
    self.priority_queue.add_item((tuple(state), action), -priority)

# @return: whether the priority queue is empty
Example #26
Source File: trajectory_sampling.py From reinforcement-learning-an-introduction with MIT License | 5 votes |
def __init__(self, n_states, b):
    self.n_states = n_states
    self.b = b

    # transition matrix, each state-action pair leads to b possible states
    self.transition = np.random.randint(n_states, size=(n_states, len(ACTIONS), b))

    # it is not clear how to set the reward, I use a unit normal distribution here
    # reward is determined by (s, a, s')
    self.reward = np.random.randn(n_states, len(ACTIONS), b)
Example #27
Source File: blackjack.py From reinforcement-learning-an-introduction with MIT License | 5 votes |
def monte_carlo_es(episodes):
    # (playerSum, dealerCard, usableAce, action)
    state_action_values = np.zeros((10, 10, 2, 2))
    # initialize counts to 1 to avoid division by 0
    state_action_pair_count = np.ones((10, 10, 2, 2))

    # behavior policy is greedy
    def behavior_policy(usable_ace, player_sum, dealer_card):
        usable_ace = int(usable_ace)
        player_sum -= 12
        dealer_card -= 1
        # get argmax of the average returns(s, a)
        values_ = state_action_values[player_sum, dealer_card, usable_ace, :] / \
                  state_action_pair_count[player_sum, dealer_card, usable_ace, :]
        return np.random.choice([action_ for action_, value_ in enumerate(values_)
                                 if value_ == np.max(values_)])

    # play for several episodes
    for episode in tqdm(range(episodes)):
        # for each episode, use a randomly initialized state and action
        initial_state = [bool(np.random.choice([0, 1])),
                         np.random.choice(range(12, 22)),
                         np.random.choice(range(1, 11))]
        initial_action = np.random.choice(ACTIONS)
        current_policy = behavior_policy if episode else target_policy_player
        _, reward, trajectory = play(current_policy, initial_state, initial_action)
        first_visit_check = set()
        for (usable_ace, player_sum, dealer_card), action in trajectory:
            usable_ace = int(usable_ace)
            player_sum -= 12
            dealer_card -= 1
            state_action = (usable_ace, player_sum, dealer_card, action)
            if state_action in first_visit_check:
                continue
            first_visit_check.add(state_action)
            # update values of state-action pairs
            state_action_values[player_sum, dealer_card, usable_ace, action] += reward
            state_action_pair_count[player_sum, dealer_card, usable_ace, action] += 1

    return state_action_values / state_action_pair_count

# Monte Carlo Sample with Off-Policy
Example #28
Source File: cliff_walking.py From reinforcement-learning-an-introduction with MIT License | 5 votes |
def choose_action(state, q_value):
    if np.random.binomial(1, EPSILON) == 1:
        return np.random.choice(ACTIONS)
    else:
        values_ = q_value[state[0], state[1], :]
        return np.random.choice([action_ for action_, value_ in enumerate(values_)
                                 if value_ == np.max(values_)])

# an episode with Sarsa
# @q_value: values for state action pair, will be updated
# @expected: if True, will use expected Sarsa algorithm
# @step_size: step size for updating
# @return: total rewards within this episode
Example #29
Source File: nucleus.py From PanopticSegmentation with MIT License | 5 votes |
def load_nucleus(self, dataset_dir, subset):
    """Load a subset of the nuclei dataset.

    dataset_dir: Root directory of the dataset
    subset: Subset to load. Either the name of the sub-directory,
            such as stage1_train, stage1_test, ...etc. or, one of:
            * train: stage1_train excluding validation images
            * val: validation images from VAL_IMAGE_IDS
    """
    # Add classes. We have one class.
    # Naming the dataset nucleus, and the class nucleus
    self.add_class("nucleus", 1, "nucleus")

    # Which subset?
    # "val": use hard-coded list above
    # "train": use data from stage1_train minus the hard-coded list above
    # else: use the data from the specified sub-directory
    assert subset in ["train", "val", "stage1_train", "stage1_test", "stage2_test"]
    subset_dir = "stage1_train" if subset in ["train", "val"] else subset
    dataset_dir = os.path.join(dataset_dir, subset_dir)
    if subset == "val":
        image_ids = VAL_IMAGE_IDS
    else:
        # Get image ids from directory names
        image_ids = next(os.walk(dataset_dir))[1]
        if subset == "train":
            image_ids = list(set(image_ids) - set(VAL_IMAGE_IDS))

    # Add images
    for image_id in image_ids:
        self.add_image(
            "nucleus",
            image_id=image_id,
            path=os.path.join(dataset_dir, image_id, "images/{}.png".format(image_id)))
Example #30
Source File: mountain_car.py From reinforcement-learning-an-introduction with MIT License | 5 votes |
def figure_12_10():
    runs = 30
    episodes = 50
    alphas = np.arange(1, 8) / 4.0
    lams = [0.99, 0.95, 0.5, 0]

    steps = np.zeros((len(lams), len(alphas), runs, episodes))
    for lamInd, lam in enumerate(lams):
        for alphaInd, alpha in enumerate(alphas):
            for run in tqdm(range(runs)):
                evaluator = Sarsa(alpha, lam, replacing_trace)
                for ep in range(episodes):
                    step = play(evaluator)
                    steps[lamInd, alphaInd, run, ep] = step

    # average over episodes
    steps = np.mean(steps, axis=3)
    # average over runs
    steps = np.mean(steps, axis=2)

    for lamInd, lam in enumerate(lams):
        plt.plot(alphas, steps[lamInd, :], label='lambda = %s' % (str(lam)))
    plt.xlabel('alpha * # of tilings (8)')
    plt.ylabel('averaged steps per episode')
    plt.ylim([180, 300])
    plt.legend()

    plt.savefig('../images/figure_12_10.png')
    plt.close()

# figure 12.11, summary comparison of Sarsa(lambda) algorithms
# I use 8 tilings rather than 10 tilings