import matplotlib matplotlib.use('Agg') import matplotlib.pylab as plt import numpy as np import pickle import os from pySDC.implementations.problem_classes.Van_der_Pol_implicit import vanderpol from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI from pySDC.projects.RDC.equidistant_RDC import Equidistant_RDC def compute_RDC_errors(): """ Van der Pol's oscillator with RDC """ # initialize level parameters level_params = dict() level_params['restol'] = 0 level_params['dt'] = 10.0 / 40.0 # initialize sweeper parameters sweeper_params = dict() sweeper_params['collocation_class'] = Equidistant_RDC sweeper_params['num_nodes'] = 41 sweeper_params['QI'] = 'IE' # initialize problem parameters problem_params = dict() problem_params['newton_tol'] = 1E-14 problem_params['newton_maxiter'] = 50 problem_params['mu'] = 10 problem_params['u0'] = (2.0, 0) # initialize step parameters step_params = dict() step_params['maxiter'] = None # initialize controller parameters controller_params = dict() controller_params['logger_level'] = 30 # Fill description dictionary for easy hierarchy creation description = dict() description['problem_class'] = vanderpol description['problem_params'] = problem_params description['sweeper_class'] = generic_implicit description['sweeper_params'] = sweeper_params description['level_params'] = level_params description['step_params'] = step_params # instantiate the controller controller_rdc = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) # set time parameters t0 = 0.0 Tend = 10.0 # get initial values on finest level P = controller_rdc.MS[0].levels[0].prob uinit = P.u_exact(t0) ref_sol = np.load('data/vdp_ref.npy') maxiter_list = range(1, 11) results = dict() results['maxiter_list'] = maxiter_list for maxiter in maxiter_list: # ugly, but much faster than re-initializing the controller over and over again controller_rdc.MS[0].params.maxiter = maxiter # call main function to get things done... uend_rdc, stats_rdc = controller_rdc.run(u0=uinit, t0=t0, Tend=Tend) err = np.linalg.norm(uend_rdc.values - ref_sol, np.inf) / np.linalg.norm(ref_sol, np.inf) print('Maxiter = %2i -- Error: %8.4e' % (controller_rdc.MS[0].params.maxiter, err)) results[maxiter] = err fname = 'data/vdp_results.pkl' file = open(fname, 'wb') pickle.dump(results, file) file.close() assert os.path.isfile(fname), 'ERROR: pickle did not create file' def plot_RDC_results(cwd=''): """ Routine to visualize the errors Args: cwd (string): current working directory """ file = open(cwd + 'data/vdp_results.pkl', 'rb') results = pickle.load(file, encoding='latin-1') file.close() # retrieve the list of nvars from results assert 'maxiter_list' in results, 'ERROR: expecting the list of maxiters in the results dictionary' maxiter_list = sorted(results['maxiter_list']) # Set up plotting parameters params = {'legend.fontsize': 20, 'figure.figsize': (12, 8), 'axes.labelsize': 20, 'axes.titlesize': 20, 'xtick.labelsize': 16, 'ytick.labelsize': 16, 'lines.linewidth': 3 } plt.rcParams.update(params) # create new figure plt.figure() # take x-axis limits from nvars_list + some spacning left and right plt.xlim([min(maxiter_list) - 1, max(maxiter_list) + 1]) plt.xlabel('maxiter') plt.ylabel('rel. error') plt.grid() min_err = 1E99 max_err = 0E00 err_list = [] # loop over nvars, get errors and find min/max error for y-axis limits for maxiter in maxiter_list: err = results[maxiter] min_err = min(err, min_err) max_err = max(err, max_err) err_list.append(err) plt.semilogy(maxiter_list, err_list, ls='-', marker='o', markersize=10, label='RDC') # adjust y-axis limits, add legend plt.ylim([min_err / 10, max_err * 10]) plt.legend(loc=1, ncol=1, numpoints=1) # plt.show() # save plot as PNG, beautify fname = 'data/RDC_errors_vdp.png' plt.savefig(fname, rasterized=True, bbox_inches='tight') assert os.path.isfile(fname), 'ERROR: plot was not created' return None if __name__ == "__main__": compute_RDC_errors() plot_RDC_results()