from logging import getLogger
import os

from beat import config as bconfig
from beat.models import hyper_normal
from beat import sampler
from beat.backend import SampleStage, thin_buffer

from pymc3 import Deterministic

from collections import OrderedDict

import numpy as num

from pyrocko.util import ensuredir


logger = getLogger('models.base')


__all__ = [
    'ConfigInconsistentError',
    'Composite',
    'sample',
    'Stage',
    'load_stage',
    'estimate_hypers']


class ConfigInconsistentError(Exception):
    """
    Error raised when the configuration is inconsistent.

    Its string representation is the given message followed by a hint
    on how to update the hierarchical parameters of the configuration.
    """

    def __init__(self, errmess=''):
        # The whole suggested command is wrapped in double quotes, so the
        # inner value is single-quoted to keep the quoting balanced.
        self.default = \
            '\n Please run: ' \
            '"beat update <project_dir> --parameters=\'hierarchicals\'"'
        self.errmess = errmess

    def __str__(self):
        return self.errmess + self.default


class Composite(object):
    """
    Class that comprises the rules to formulate the problem. Has to be
    used by an overarching problem object.
    """

    def __init__(self):
        self.input_rvs = OrderedDict()
        self.fixed_rvs = OrderedDict()
        self.hierarchicals = OrderedDict()
        self.hyperparams = OrderedDict()
        self.name = None
        # name of the Deterministic under which get_hyper_formula records
        # the likelihood
        self._like_name = None
        # when not None, get_hypernames delegates to it and
        # get_hyper_formula reads its
        # dataset_specific_residual_noise_estimation flag
        self.config = None
        self.slip_varnames = []

    def set_slip_varnames(self, varnames):
        """
        Set slip components for GFs.

        Only those names in ``varnames`` that are contained in
        ``bconfig.static_dist_vars`` are kept, in the given order.
        """
        self.slip_varnames = [
            var for var in varnames if var in bconfig.static_dist_vars]

    def get_hyper_formula(self, hyperparams):
        """
        Get likelihood formula for the hyper model built.

        Has to be called within a ``with model`` context, because a
        :class:`pymc3.Deterministic` is created.

        Parameters
        ----------
        hyperparams : mapping of hyperparameter random variables
            (presumably name -> RV; confirm against callers)

        Returns
        -------
        The summed log-likelihood. The per-element values are recorded
        as a Deterministic named ``self._like_name``.

        Notes
        -----
        ``self.datasets`` and ``self._llks`` are not set in ``__init__``;
        presumably subclasses provide them.
        """
        hp_specific = self.config.dataset_specific_residual_noise_estimation

        logpts = hyper_normal(
            self.datasets, hyperparams, self._llks,
            hp_specific=hp_specific)
        llk = Deterministic(self._like_name, logpts)
        return llk.sum()

    def apply(self, composite):
        """
        Update composite weight matrices (in place) with weights in given
        composite.

        Parameters
        ----------
        composite : :class:`Composite`
            containing weight matrices to use for updates
        """
        for i, weight in enumerate(composite.weights):
            A = weight.get_value(borrow=True)
            self.weights[i].set_value(A)

    def get_hypernames(self):
        """
        Return the hyperparameter names.

        They are taken from the configuration if one is set, otherwise
        from the keys of ``self.hyperparams``.
        """
        if self.config is not None:
            return self.config.get_hypernames()
        else:
            return list(self.hyperparams.keys())


def sample(step, problem):
    """
    Sample solution space with the previously initialised algorithm.

    Parameters
    ----------
    step : :class:`SMC` or :class:`pymc3.metropolis.Metropolis`
        from problem.init_sampler()
    problem : :class:`Problem` with characteristics of problem to solve
    """

    pc = problem.config.problem_config
    sc = problem.config.sampler_config
    pa = sc.parameters

    # Not every sampler configuration has "update_covariances"; default to
    # no updates so that ``update`` is always bound.
    update = problem if getattr(pa, 'update_covariances', False) else None

    # Default: no explicit starting point(s).
    start = None
    if pc.mode == bconfig.ffi_mode_str:
        logger.info('Chain initialization with:')
        initialization = pc.mode_config.initialization
        if initialization == 'random':
            logger.info('Random starting point.\n')
        elif initialization == 'lsq':
            logger.info('Least-squares-solution including "uparr" only.\n')
            if 'seismic' in pc.datatypes:
                logger.warning(
                    'Least-squares initialization is not'
                    ' supported (yet) for seismic data, only!')

            # one least-squares starting point per chain, each computed
            # from a freshly drawn random point
            start = [
                problem.lsq_solution(problem.get_random_point())
                for _ in range(step.n_chains)]
        else:
            logger.warning(
                'Initialization "%s" not recognised, using random'
                ' starting point.', initialization)

    if sc.name == 'Metropolis':
        logger.info('... Starting Metropolis ...\n')

        ensuredir(problem.outfolder)

        sampler.metropolis_sample(
            n_steps=pa.n_steps,
            step=step,
            progressbar=sc.progressbar,
            buffer_size=sc.buffer_size,
            buffer_thinning=sc.buffer_thinning,
            homepath=problem.outfolder,
            start=start,
            burn=pa.burn,
            thin=pa.thin,
            model=problem.model,
            n_jobs=pa.n_jobs,
            rm_flag=pa.rm_flag)

    elif sc.name == 'SMC':
        logger.info('... Starting SMC ...\n')

        sampler.smc_sample(
            pa.n_steps,
            step=step,
            progressbar=sc.progressbar,
            model=problem.model,
            start=start,
            n_jobs=pa.n_jobs,
            stage=pa.stage,
            update=update,
            buffer_thinning=sc.buffer_thinning,
            homepath=problem.outfolder,
            buffer_size=sc.buffer_size,
            rm_flag=pa.rm_flag)

    elif sc.name == 'PT':
        logger.info('... Starting Parallel Tempering ...\n')

        sampler.pt_sample(
            step=step,
            n_chains=pa.n_chains + 1,  # add master
            n_samples=pa.n_samples,
            start=start,
            swap_interval=pa.swap_interval,
            beta_tune_interval=pa.beta_tune_interval,
            n_workers_posterior=pa.n_chains_posterior,
            homepath=problem.outfolder,
            progressbar=sc.progressbar,
            buffer_size=sc.buffer_size,
            buffer_thinning=sc.buffer_thinning,
            model=problem.model,
            resample=pa.resample,
            rm_flag=pa.rm_flag,
            record_worker_chains=pa.record_worker_chains)

    else:
        logger.error('Sampler "%s" not implemented.' % sc.name)


def estimate_hypers(step, problem):
    """
    Get initial estimates of the hyperparameters.

    The model is sampled with the hyper sampler configuration. For every
    hyperparameter, new bounds are derived from the sampled range: floor
    of the minimum minus 2 and ceil of the maximum plus 2. The testvalue
    is set to their midpoint, and the updated configuration is written
    to ``config_<mode>.yaml`` in the project directory.

    Parameters
    ----------
    step : sampler step object, passed on to ``init_stage``
    problem : :class:`Problem` with characteristics of problem to solve

    Raises
    ------
    ValueError
        if ``n_chains`` is not a whole multiple of ``n_jobs``
    """
    # Imported here rather than at module level, presumably to avoid a
    # circular import with beat.sampler -- confirm before moving.
    from beat.sampler.base import iter_parallel_chains, init_stage, \
        init_chain_hypers

    logger.info('... Estimating hyperparameters ...')

    pc = problem.config.problem_config
    sc = problem.config.hyper_sampler_config
    pa = sc.parameters

    if not (pa.n_chains / pa.n_jobs).is_integer():
        raise ValueError('n_chains / n_jobs has to be a whole number!')

    ensuredir(problem.outfolder)

    stage_handler = SampleStage(problem.outfolder, backend=sc.backend)
    chains, step, _ = init_stage(
        stage_handler=stage_handler,
        step=step,
        stage=0,
        progressbar=sc.progressbar,
        model=problem.model,
        rm_flag=pa.rm_flag)

    # setting stage to 1 otherwise only one sample
    step.stage = 1
    step.n_steps = pa.n_steps

    with problem.model:
        mtrace = iter_parallel_chains(
            draws=pa.n_steps,
            chains=chains,
            step=step,
            stage_path=stage_handler.stage_path(1),
            progressbar=sc.progressbar,
            model=problem.model,
            n_jobs=pa.n_jobs,
            initializer=init_chain_hypers,
            initargs=(problem,),
            buffer_size=sc.buffer_size,
            buffer_thinning=sc.buffer_thinning,
            chunksize=int(pa.n_chains / pa.n_jobs))

    # Length of range(n_steps) after thin_buffer. Presumably this is the
    # number of stored samples per chain; the burn-in fraction is
    # applied to it.
    thinned_chain_length = len(thin_buffer(
        list(range(pa.n_steps)), sc.buffer_thinning, ensure_last=True))

    for hypername in problem.hypernames:
        hyperparam = pc.hyperparameters[hypername]
        d = mtrace.get_values(
            hypername, combine=True,
            burn=int(thinned_chain_length * pa.burn),
            thin=pa.thin, squeeze=True)

        lower = num.floor(d.min()) - 2.
        upper = num.ceil(d.max()) + 2.
        # The previous bounds may already be 1-d arrays (they are stored
        # via atleast_1d below), so they are formatted with %s; "%f" on an
        # array with ndim > 0 is deprecated in NumPy.
        logger.info(
            'Updating hyperparameter %s from %s, %s to %f, %f' % (
                hypername, hyperparam.lower, hyperparam.upper,
                lower, upper))

        hyperparam.lower = num.atleast_1d(lower)
        hyperparam.upper = num.atleast_1d(upper)
        hyperparam.testvalue = num.atleast_1d((upper + lower) / 2.)

    config_file_name = 'config_' + pc.mode + '.yaml'
    conf_out = os.path.join(problem.config.project_dir, config_file_name)

    problem.config.problem_config = pc
    bconfig.dump(problem.config, filename=conf_out)


class Stage(object):
    """
    Stage, containing sampling results and intermediate sampler parameters.
    """
    number = None
    path = None
    step = None
    updates = None
    mtrace = None

    def __init__(self, handler=None, homepath=None, stage_number=-1,
                 backend='csv'):
        """
        Parameters
        ----------
        handler : :class:`beat.backend.SampleStage`, optional
            used directly if given
        homepath : str, optional
            used to create a :class:`SampleStage` if no handler is given
        stage_number : int
            stage to load by default in :meth:`load_results`
        backend : str
            backend passed to the created :class:`SampleStage`

        Raises
        ------
        TypeError
            if both ``handler`` and ``homepath`` are None
        """
        if handler is not None:
            self.handler = handler
        elif homepath is not None:
            self.handler = SampleStage(homepath, backend=backend)
        else:
            raise TypeError('Either handler or homepath have to be not None')

        self.backend = backend
        self.number = stage_number

    def load_results(
            self, varnames=None, model=None, stage_number=None,
            chains=None, load='trace'):
        """
        Load stage results from sampling.

        If the requested stage path does not exist, the highest sampled
        stage of the handler is loaded instead.

        Parameters
        ----------
        varnames : list of str, optional
            variables to load; defaults to the names of the model's
            unobserved random variables
        model : :class:`pymc3.model.Model`
            required if ``varnames`` is None or if params are loaded
        stage_number : int
            Number of stage to load; defaults to ``self.number``
        chains : list, optional
            of result chains to load
        load : str
            what to load and return 'full', 'trace', 'params'

        Raises
        ------
        ValueError
            if neither ``varnames`` nor ``model`` is given, or if params
            are requested without a model
        """
        if varnames is None and model is not None:
            varnames = [var.name for var in model.unobserved_RVs]
        elif varnames is None and model is None:
            raise ValueError(
                'Either "varnames" or "model" need to be not None!')

        if stage_number is None:
            stage_number = self.number

        self.path = self.handler.stage_path(stage_number)

        if not os.path.exists(self.path):
            stage_number = self.handler.highest_sampled_stage()

            logger.info(
                'Stage results %s do not exist! Loading last completed'
                ' stage %s' % (self.path, stage_number))
            self.path = self.handler.stage_path(stage_number)

        self.number = stage_number

        if load == 'full':
            to_load = ['params', 'trace']
        else:
            to_load = [load]

        if 'trace' in to_load:
            self.mtrace = self.handler.load_multitrace(
                stage_number, varnames=varnames, chains=chains)

        if 'params' in to_load:
            if model is not None:
                with model:
                    self.step, self.updates = \
                        self.handler.load_sampler_params(stage_number)
            else:
                raise ValueError('To load sampler params model is required!')


def load_stage(problem, stage_number, load='trace', chains=None):
    """
    Create a :class:`Stage` for the problem's output folder and load the
    requested results into it.

    Parameters
    ----------
    problem : :class:`Problem`
    stage_number : int
        number of stage to load
    load : str
        what to load: 'full', 'trace' or 'params'
    chains : list, optional
        result chains to load; defaults to ``[-1]`` (the previous
        default value). A fresh list is created on every call instead of
        sharing one mutable default.

    Returns
    -------
    :class:`Stage`
    """
    if chains is None:
        chains = [-1]

    stage = Stage(
        homepath=problem.outfolder, stage_number=stage_number,
        backend=problem.config.sampler_config.backend)
    stage.load_results(
        varnames=problem.varnames,
        chains=chains,
        model=problem.model,
        stage_number=stage_number,
        load=load)
    return stage