package org.broadinstitute.hellbender.utils.mcmc;

import com.google.common.primitives.Doubles;
import org.apache.commons.math3.distribution.BetaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.RandomGeneratorFactory;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.Random;
import java.util.function.Function;


/**
 * Unit tests for {@link SliceSampler}.
 *
 * @author Samuel Lee <[email protected]>
 */
public final class SliceSamplerUnitTest {
    private static final int RANDOM_SEED = 42;
    private static final RandomGenerator rng =
            RandomGeneratorFactory.createRandomGenerator(new Random(RANDOM_SEED));

    private static double relativeError(final double x, final double xTrue) {
        return Math.abs((x - xTrue) / xTrue);
    }

    /**
     * Test slice sampling of a normal distribution.  Checks that input mean and standard deviation are recovered
     * by 10000 samples to a relative error of 0.5% and 2%, respectively.
     */
    @Test
    public void testSliceSamplingOfNormalDistribution() {
        rng.setSeed(RANDOM_SEED);

        final double mean = 5.;
        final double standardDeviation = 0.75;
        final NormalDistribution normalDistribution = new NormalDistribution(mean, standardDeviation);
        final Function<Double, Double> normalLogPDF = normalDistribution::logDensity;

        final double xInitial = 1.;
        final double xMin = Double.NEGATIVE_INFINITY;
        final double xMax = Double.POSITIVE_INFINITY;
        final double width = 0.5;
        final int numSamples = 10000;
        final SliceSampler normalSampler = new SliceSampler(rng, normalLogPDF, xMin, xMax, width);
        final double[] samples = Doubles.toArray(normalSampler.sample(xInitial, numSamples));

        final double sampleMean = new Mean().evaluate(samples);
        final double sampleStandardDeviation = new StandardDeviation().evaluate(samples);
        Assert.assertEquals(relativeError(sampleMean, mean), 0., 0.005);
        Assert.assertEquals(relativeError(sampleStandardDeviation, standardDeviation), 0., 0.02);
    }

    /**
     * Test slice sampling of a monotonic beta distribution as an example of sampling of a bounded random variable.
     * Checks that input mean and variance are recovered by 10000 samples to a relative error of 0.5% and 2%,
     * respectively.
     */
    @Test
    public void testSliceSamplingOfMonotonicBetaDistribution() {
        rng.setSeed(RANDOM_SEED);

        final double alpha = 10.;
        final double beta = 1.;
        final BetaDistribution betaDistribution = new BetaDistribution(alpha, beta);
        final Function<Double, Double> betaLogPDF = betaDistribution::logDensity;

        final double xInitial = 0.5;
        final double xMin = 0.;
        final double xMax = 1.;
        final double width = 0.1;
        final int numSamples = 10000;
        final SliceSampler betaSampler = new SliceSampler(rng, betaLogPDF, xMin, xMax, width);
        final double[] samples = Doubles.toArray(betaSampler.sample(xInitial, numSamples));

        final double mean = betaDistribution.getNumericalMean();
        final double variance = betaDistribution.getNumericalVariance();
        final double sampleMean = new Mean().evaluate(samples);
        final double sampleVariance = new Variance().evaluate(samples);
        Assert.assertEquals(relativeError(sampleMean, mean), 0., 0.005);
        Assert.assertEquals(relativeError(sampleVariance, variance), 0., 0.02);
    }

    /**
     * Test slice sampling of a peaked beta distribution as an example of sampling of a bounded random variable.
     * Checks that input mean and variance are recovered by 10000 samples to a relative error of 0.5% and 2%,
     * respectively.
     */
    @Test
    public void testSliceSamplingOfPeakedBetaDistribution() {
        rng.setSeed(RANDOM_SEED);

        final double alpha = 10.;
        final double beta = 4.;
        final BetaDistribution betaDistribution = new BetaDistribution(alpha, beta);
        final Function<Double, Double> betaLogPDF = betaDistribution::logDensity;

        final double xInitial = 0.5;
        final double xMin = 0.;
        final double xMax = 1.;
        final double width = 0.1;
        final int numSamples = 10000;
        final SliceSampler betaSampler = new SliceSampler(rng, betaLogPDF, xMin, xMax, width);
        final double[] samples = Doubles.toArray(betaSampler.sample(xInitial, numSamples));

        final double mean = betaDistribution.getNumericalMean();
        final double variance = betaDistribution.getNumericalVariance();
        final double sampleMean = new Mean().evaluate(samples);
        final double sampleVariance = new Variance().evaluate(samples);
        Assert.assertEquals(relativeError(sampleMean, mean), 0., 0.005);
        Assert.assertEquals(relativeError(sampleVariance, variance), 0., 0.02);
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInitialPointOutOfRange() {
        rng.setSeed(RANDOM_SEED);

        final double mean = 5.;
        final double standardDeviation = 0.75;
        final NormalDistribution normalDistribution = new NormalDistribution(mean, standardDeviation);
        final Function<Double, Double> normalLogPDF = normalDistribution::logDensity;

        final double xInitial = -10.;
        final double xMin = 0.;
        final double xMax = 1.;
        final double width = 0.5;
        final SliceSampler normalSampler = new SliceSampler(rng, normalLogPDF, xMin, xMax, width);
        normalSampler.sample(xInitial);
    }
}