package org.broadinstitute.hellbender.utils.mcmc;

import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Metropolis MCMC sampler using an adaptive step size that increases / decreases in order to decrease / increase
 * the acceptance rate to some desired value.  (A general property of MCMC is that too-low acceptance rates are bad
 * for obvious reasons, but too-high acceptance rates are also undesirable because they imply that steps are too small.)
 * <p>
 * In order for the Markov chain to converge to the correct posterior distribution, adaptations to the step size must
 * vanish as sampling proceeds.  Here each adaptation is scaled by
 * {@code adjustmentRate * timeScale / (timeScale + iteration)}, which shrinks as the iteration count grows.
 * <p>
 * This sampling method is a very good black-box algorithm when we are reasonably confident that the sampled
 * conditional distribution is close to unimodal but otherwise unknown (i.e. not necessarily log-concave and with
 * unknown shape and width).
 * <p>
 * Instances are stateful: the current position, the step size and the iteration counter are all updated by every draw.
 *
 * @author David Benjamin <[email protected]>
 */
public final class AdaptiveMetropolisSampler {
    private static final double DEFAULT_OPTIMAL_ACCEPTANCE_RATE = 0.4;
    private static final double DEFAULT_TIME_SCALE = 20;
    private static final double DEFAULT_ADJUSTMENT_RATE = 1.0;

    private int iteration = 1;  // the amount of step size adjustment decreases with the iteration number

    private final double lowerBound;
    private final double upperBound;

    private final double optimalAcceptanceRate = DEFAULT_OPTIMAL_ACCEPTANCE_RATE;

    //adjustments to the step size are scaled by adjustmentRate * timeScale / (timeScale + iteration)
    private final double adjustmentRate;
    private final double timeScale;

    private double stepSize;

    private double xCurrent;    //the current sampled value

    /**
     * Create a sampler with explicit adaptation parameters.
     *
     * @param xInitial        starting position of the chain
     * @param initialStepSize initial scale of the Gaussian proposal; validated as positive
     * @param lowerBound      smallest value a proposal may take; proposals below it are always rejected
     * @param upperBound      largest value a proposal may take; proposals above it are always rejected
     * @param adjustmentRate  overall multiplier applied to each step-size adaptation
     * @param timeScale       controls how quickly adaptations die away with the iteration count; validated as positive
     */
    public AdaptiveMetropolisSampler(final double xInitial, final double initialStepSize,
                                     final double lowerBound, final double upperBound,
                                     final double adjustmentRate, final double timeScale) {
        Utils.validateArg(lowerBound <= upperBound, "Maximum bound must be greater than or equal to minimum bound.");
        ParamUtils.isPositive(initialStepSize, "Step size must be positive.");
        ParamUtils.isPositive(timeScale, "Time scale must be positive.");
        xCurrent = xInitial;
        stepSize = initialStepSize;
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
        this.adjustmentRate = adjustmentRate;
        this.timeScale = timeScale;
    }

    /**
     * Create a sampler that uses the default adjustment rate and time scale.
     *
     * @param xInitial        starting position of the chain
     * @param initialStepSize initial scale of the Gaussian proposal; validated as positive
     * @param lowerBound      smallest value a proposal may take
     * @param upperBound      largest value a proposal may take
     */
    public AdaptiveMetropolisSampler(final double xInitial, final double initialStepSize,
                                     final double lowerBound, final double upperBound) {
        this(xInitial, initialStepSize, lowerBound, upperBound, DEFAULT_ADJUSTMENT_RATE, DEFAULT_TIME_SCALE);
    }

    /**
     * Perform one Metropolis step: propose a Gaussian move from the current position, accept or reject it, adapt the
     * step size, and advance the chain.
     * <p>
     * The accepted value is stored as the sampler's current position, so that successive calls form a Markov chain
     * rather than repeatedly proposing around the initial value.
     *
     * @param rng    source of randomness for both the proposal and the accept/reject decision; must not be null
     * @param logPDF log of the (possibly unnormalized) target density; must not be null
     * @return the new current position of the chain (the proposal if accepted, otherwise the previous position)
     */
    public double sample(final RandomGenerator rng, final Function<Double, Double> logPDF) {
        Utils.nonNull(rng);
        Utils.nonNull(logPDF);

        final AbstractRealDistribution normal = new NormalDistribution(rng, 0, 1);
        final double proposal = xCurrent + stepSize * normal.sample();

        // proposals outside [lowerBound, upperBound] are rejected without evaluating the density
        final boolean outOfBounds = proposal < lowerBound || upperBound < proposal;
        final double acceptanceProbability = outOfBounds
                ? 0
                : Math.min(1, Math.exp(logPDF.apply(proposal) - logPDF.apply(xCurrent)));

        //adjust stepSize larger/smaller to decrease/increase the acceptance rate
        final double correctionFactor = (acceptanceProbability - optimalAcceptanceRate)
                * adjustmentRate * (timeScale / (timeScale + iteration));
        stepSize *= Math.exp(correctionFactor);
        iteration++;

        // advance the chain here so that single-draw callers do not stay pinned at the starting position
        if (rng.nextDouble() < acceptanceProbability) {
            xCurrent = proposal;
        }
        return xCurrent;
    }

    /**
     * Generate multiple samples from the probability density function.
     * <p>
     * A total of {@code numSamples} Metropolis steps are performed; the first {@code numBurnIn} of them are discarded
     * and the remaining {@code numSamples - numBurnIn} are returned in the order they were drawn.
     *
     * @param rng        source of randomness; must not be null
     * @param logPDF     log of the (possibly unnormalized) target density; must not be null
     * @param numSamples number of samples to generate, including burn-in; validated as positive
     * @param numBurnIn  number of samples to discard; validated as non-negative and less than {@code numSamples}
     * @return samples drawn from the probability density function
     */
    public List<Double> sample(final RandomGenerator rng, final Function<Double, Double> logPDF,
                               final int numSamples, final int numBurnIn) {
        Utils.nonNull(rng);
        Utils.nonNull(logPDF);
        ParamUtils.isPositive(numSamples, "Number of samples must be positive.");
        ParamUtils.isPositiveOrZero(numBurnIn, "Number of burn-in samples must be non-negative.");
        Utils.validateArg(numBurnIn < numSamples, "Number of samples must be greater than number of burn-in samples.");

        final List<Double> samples = new ArrayList<>(numSamples - numBurnIn);
        for (int i = 0; i < numSamples; i++) {
            xCurrent = sample(rng, logPDF);
            // i counts from zero, so draws 0 .. numBurnIn - 1 are exactly the numBurnIn discarded ones
            if (i >= numBurnIn) {
                samples.add(xCurrent);
            }
        }
        return samples;
    }
}