package org.broadinstitute.hellbender.utils.mcmc; import com.google.common.primitives.Doubles; import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.RandomGeneratorFactory; import org.apache.commons.math3.stat.descriptive.moment.Mean; import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; import org.broadinstitute.hellbender.GATKBaseTest; import org.testng.Assert; import org.testng.annotations.Test; import java.util.*; import java.util.function.Function; /** * Unit test for {@link GibbsSampler}. Demonstrates application of {@link GibbsSampler} to a {@link ParameterizedModel} * that is specified using {@link ParameterizedState} and a class that implements {@link DataCollection}. * <p> * Test performs Bayesian inference of a Gaussian model with 2 global parameters specifying the variance and the mean. * </p> * <p> * Data consists of a list of 10000 datapoints drawn from a normal distribution with unity variance and mean. * </p> * <p> * Success of the test is determined by recovery of the input variance and mean, * as well as agreement of the standard deviations of the parameter posteriors with those given by both the * python package emcee (see http://dan.iel.fm/emcee for details) and numerical evaluation in Mathematica * of the analytic forms of the posteriors. * </p> * * @author Samuel Lee <[email protected]> */ public final class GibbsSamplerSingleGaussianUnitTest extends GATKBaseTest { private static final int NUM_DATAPOINTS = 10000; private static final double VARIANCE_MIN = 0.; private static final double VARIANCE_MAX = Double.POSITIVE_INFINITY; private static final double VARIANCE_WIDTH = 0.1; private static final double VARIANCE_INITIAL = 5.; private static final double VARIANCE_TRUTH = 1.; private static final double VARIANCE_POSTERIOR_STANDARD_DEVIATION_TRUTH = 0.014; private static final double MEAN_WIDTH = 0.1; private static final double MEAN_INITIAL = 5.; private static final double MEAN_TRUTH = 1.; private static final double MEAN_POSTERIOR_STANDARD_DEVIATION_TRUTH = 0.01; private static final int NUM_SAMPLES = 500; private static final int NUM_BURN_IN = 250; //test specifications private static final double RELATIVE_ERROR_THRESHOLD_FOR_CENTERS = 0.01; private static final double RELATIVE_ERROR_THRESHOLD_FOR_STANDARD_DEVIATIONS = 0.1; //Create dataset of 10000 datapoints drawn from a normal distribution Normal(MEAN_TRUTH, VARIANCE_TRUTH) private static final int RANDOM_SEED = 42; private static final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(RANDOM_SEED)); private static final NormalDistribution normalDistribution = new NormalDistribution(rng, MEAN_TRUTH, Math.sqrt(VARIANCE_TRUTH)); private static final List<Double> datapointsList = Doubles.asList(normalDistribution.sample(NUM_DATAPOINTS)); //Calculates the exponent for a normal distribution; used in log-likelihood calculation below. private static double normalTerm(final double quantity, final double mean, final double variance) { return (quantity - mean) * (quantity - mean) / (2 * variance); } //Calculates relative error between x and xTrue, with respect to xTrue; used for checking statistics of //posterior samples below. private static double relativeError(final double x, final double xTrue) { return Math.abs((x - xTrue) / xTrue); } //The datapoints used by the samplers must be contained in a class that implements DataCollection. private final class GaussianDataCollection implements DataCollection { private final List<Double> datapoints; public GaussianDataCollection(final List<Double> datapoints) { this.datapoints = new ArrayList<>(datapoints); } public List<Double> getDatapoints() { return Collections.unmodifiableList(datapoints); } } //We enumerate the parameters of the model using an enum that implements the ParameterEnum interface. private enum GaussianParameter implements ParameterEnum { VARIANCE, MEAN } //We create a Modeller helper class to initialize the model state and specify the parameter samplers. private final class GaussianModeller { //Create fields in the Modeller for the model and samplers. private final ParameterizedModel<GaussianParameter, ParameterizedState<GaussianParameter>, GaussianDataCollection> model; private final ParameterSampler<Double, GaussianParameter, ParameterizedState<GaussianParameter>, GaussianDataCollection> varianceSampler; private final ParameterSampler<Double, GaussianParameter, ParameterizedState<GaussianParameter>, GaussianDataCollection> meanSampler; //Constructor for the Modeller takes as parameters all quantities needed to construct the ParameterizedState //(here, the initial variance and the initial mean) and the DataCollection (here, the list of datapoints). private GaussianModeller(final double varianceInitial, final double meanInitial, final List<Double> datapoints) { //Construct the initial ParameterizedState by passing a list of Parameters of mixed type to the constructor. //Initial values (and, implicitly, types) for each of the parameters are set here. final List<Parameter<GaussianParameter, ?>> initialParameters = Arrays.asList( new Parameter<>(GaussianParameter.VARIANCE, varianceInitial), new Parameter<>(GaussianParameter.MEAN, meanInitial)); final ParameterizedState<GaussianParameter> initialState = new ParameterizedState<>(initialParameters); //Construct the GaussianDataCollection by passing a list of datapoints to the constructor. //Here, we pass 10000 datapoints, which were generated above, final GaussianDataCollection dataset = new GaussianDataCollection(datapoints); //Implement ParameterSamplers for each parameter by overriding sample(). This can be done via a lambda that takes //(rng, state, dataCollection) and returns a new sample of the parameter with type identical to that //specified during initialization above. //Sampler for the variance global parameter. Assuming a uniform prior, the relevant log conditional PDF //is given by the log of the product of Gaussian likelihoods for each datapoint c_t: // log[product_t variance^(-1/2) * exp(-(c_t - mean)^2 / (2 * variance))] + constant //which reduces to the form in code below. Slice sampling is used here to generate a new sample, //but, in general, any method can be used; e.g., if the conditional PDF is from the exponential family, //one can simply sample directly from the corresponding Distribution from Apache Commons. varianceSampler = (rng, state, dataCollection) -> { final Function<Double, Double> logConditionalPDF = newVariance -> -0.5 * Math.log(newVariance) * dataCollection.getDatapoints().size() + dataCollection.getDatapoints().stream() .mapToDouble(c -> -normalTerm(c, state.get(GaussianParameter.MEAN, Double.class), newVariance)) .sum(); final SliceSampler sampler = new SliceSampler(rng, logConditionalPDF, VARIANCE_MIN, VARIANCE_MAX, VARIANCE_WIDTH); return sampler.sample(state.get(GaussianParameter.VARIANCE, Double.class)); }; //Sampler for the mean global parameter. Assuming a uniform prior, the relevant log conditional PDF //is given by the log of the product of Gaussian likelihoods for each datapoint c_t: // log[product_t exp(-(c_t - mean)^2 / (2 * variance))] + constant //which reduces to the form in code below. meanSampler = (rng, state, dataCollection) -> { final Function<Double, Double> logConditionalPDF = newMean -> dataCollection.getDatapoints().stream() .mapToDouble(c -> -normalTerm(c, newMean, state.get(GaussianParameter.VARIANCE, Double.class))) .sum(); final SliceSampler sampler = new SliceSampler(rng, logConditionalPDF, MEAN_WIDTH); return sampler.sample(state.get(GaussianParameter.MEAN, Double.class)); }; //Build the ParameterizedModel using the GibbsBuilder pattern. //Pass in the initial ParameterizedState and DataCollection, and specify the class of the ParameterizedState. //Add samplers for each of the parameters, with names matching those used in initialization. model = new ParameterizedModel.GibbsBuilder<>(initialState, dataset) .addParameterSampler(GaussianParameter.VARIANCE, varianceSampler, Double.class) .addParameterSampler(GaussianParameter.MEAN, meanSampler, Double.class) .build(); } } /** * Tests Bayesian inference of a Gaussian model via MCMC. Recovery of input values for the variance and mean * global parameters is checked. In particular, the mean and standard deviation of the posteriors for * both parameters must be recovered to within a relative error of 1% and 10%, respectively, in 250 samples * (after 250 burn-in samples have been discarded). */ @Test public void testRunMCMCOnSingleGaussianModel() { //Create new instance of the Modeller helper class, passing all quantities needed to initialize state and data. final GaussianModeller modeller = new GaussianModeller(VARIANCE_INITIAL, MEAN_INITIAL, datapointsList); //Create a GibbsSampler, passing the total number of samples (including burn-in samples) //and the model held by the Modeller. final GibbsSampler<GaussianParameter, ParameterizedState<GaussianParameter>, GaussianDataCollection> gibbsSampler = new GibbsSampler<>(NUM_SAMPLES, modeller.model); //Run the MCMC. gibbsSampler.runMCMC(); //Get the samples of each of the parameter posteriors (discarding burn-in samples) by passing the //parameter name, type, and burn-in number to the getSamples method. final double[] varianceSamples = Doubles.toArray(gibbsSampler.getSamples(GaussianParameter.VARIANCE, Double.class, NUM_BURN_IN)); final double[] meanSamples = Doubles.toArray(gibbsSampler.getSamples(GaussianParameter.MEAN, Double.class, NUM_BURN_IN)); //Check that the statistics---i.e., the means and standard deviations---of the posteriors //agree with those found by emcee/analytically to a relative error of 1% and 10%, respectively. final double variancePosteriorCenter = new Mean().evaluate(varianceSamples); final double variancePosteriorStandardDeviation = new StandardDeviation().evaluate(varianceSamples); Assert.assertEquals(relativeError(variancePosteriorCenter, VARIANCE_TRUTH), 0., RELATIVE_ERROR_THRESHOLD_FOR_CENTERS); Assert.assertEquals( relativeError(variancePosteriorStandardDeviation, VARIANCE_POSTERIOR_STANDARD_DEVIATION_TRUTH), 0., RELATIVE_ERROR_THRESHOLD_FOR_STANDARD_DEVIATIONS); final double meanPosteriorCenter = new Mean().evaluate(meanSamples); final double meanPosteriorStandardDeviation = new StandardDeviation().evaluate(meanSamples); Assert.assertEquals(relativeError(meanPosteriorCenter, MEAN_TRUTH), 0., RELATIVE_ERROR_THRESHOLD_FOR_CENTERS); Assert.assertEquals( relativeError(meanPosteriorStandardDeviation, MEAN_POSTERIOR_STANDARD_DEVIATION_TRUTH), 0., RELATIVE_ERROR_THRESHOLD_FOR_STANDARD_DEVIATIONS); } }