package com.tdunning.tdigest.quality;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.tdunning.math.stats.*;
import org.apache.mahout.math.jet.random.AbstractContinousDistribution;
import org.junit.Assert;
import org.junit.Test;

import java.io.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

/**
 * Produces measurements of accuracy versus compression factor for a fixed data size.
 */
public class AccuracyTest {
    private static final int N = 1_000_000;

    private final Random gen = new Random();

    /**
     * Generates information that demonstrates that t-digests can be merged without major loss of
     * accuracy.
     */
    @Test
    public void merge() {
        final Random seedGenerator = new Random();
        try (PrintWriter out = new PrintWriter(new File("merge.csv"))) {
            out.printf("type,parts,q,e0,e1,e2,e2.rel,e3\n");

            List<Callable<String>> tasks = Lists.newArrayList();
            for (int k = 0; k < 20; k++) {
                final int currentK = k;
                tasks.add(new Callable<String>() {
                    final Random gen = new Random(seedGenerator.nextLong());

                    @Override
                    public String call() {
                        StringWriter s = new StringWriter();
                        PrintWriter out = new PrintWriter(s);
                        System.out.printf("Starting %d\n", currentK);

                        for (int parts : new int[]{2, 5, 10, 20, 50, 100}) {
                            ArrayList<Double> data = Lists.newArrayList();

                            TDigest dist = new MergingDigest(100);
                            dist.recordAllData();

                            // we accumulate the data into multiple sub-digests
                            List<TDigest> subs = Lists.newArrayList();
                            for (int i = 0; i < parts; i++) {
                                subs.add(new MergingDigest(100).recordAllData());
                            }

                            List<TDigest> highRes = Lists.newArrayList();
                            for (int i = 0; i < parts; i++) {
                                highRes.add(new MergingDigest(200));
                            }

                            int[] cnt = new int[parts];
                            for (int i = 0; i < 100000; i++) {
                                double x = gen.nextDouble();
                                data.add(x);
                                dist.add(x);
                                subs.get(i % parts).add(x);
                                highRes.get(i % parts).add(x);
                                cnt[i % parts]++;
                            }
                            dist.compress();
                            Collections.sort(data);

                            // collect the raw data from the sub-digests
                            List<Double> data2 = Lists.newArrayList();
                            int i = 0;
                            int k = 0;
                            int totalByCount = 0;
                            for (TDigest digest : subs) {
                                assertEquals("Sub-digest size check", cnt[i], digest.size());
                                int k2 = 0;
                                for (Centroid centroid : digest.centroids()) {
                                    Iterables.addAll(data2, centroid.data());
                                    assertEquals("Centroid consistency", centroid.count(), centroid.data().size());
                                    k2 += centroid.data().size();
                                }
                                totalByCount += cnt[i];
                                k += k2;
                                assertEquals("Sub-digest centroid sum check", cnt[i], k2);
                                assertEquals("Sub-digest size check", cnt[i], subs.get(i).size());
                                i++;
                            }
                            assertEquals("Sub-digests don't add up to the right size", data.size(), k);
                            assertEquals("Counts don't match up", data.size(), totalByCount);

                            // verify that the raw data all got recorded
                            Collections.sort(data2);
                            assertEquals(data.size(), data2.size());
                            Iterator<Double> ix = data.iterator();
                            for (Double x : data2) {
                                assertEquals(ix.next(), x);
                            }

                            // now merge the sub-digests
                            TDigest dist2 = new MergingDigest(100).recordAllData();
                            dist2.add(subs);
                            assertEquals(String.format("Digest count is wrong %d vs %d", totalByCount, dist2.size()),
                                    totalByCount, dist2.size());

                            // verify the merged result has the right data
                            List<Double> data3 = Lists.newArrayList();
                            for (Centroid centroid : dist2.centroids()) {
                                Iterables.addAll(data3, centroid.data());
                            }
                            Collections.sort(data3);
                            assertEquals(String.format("Total data size %d vs %d", data.size(), data3.size()),
                                    data.size(), data3.size());
                            ix = data.iterator();
                            for (Double x : data3) {
                                assertEquals(ix.next(), x);
                            }
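
                            // merge the high-resolution sub-digests into a single digest at
                            // normal resolution; this should lose little accuracy compared
                            // with accumulating everything in one digest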
                            TDigest dist3 = new MergingDigest(100);
                            dist3.add(highRes);

                            final double[] allData = new double[data.size()];
                            int iz = 0;
                            for (double x : data) {
                                allData[iz++] = x;
                            }

                            for (double q : new double[]{0.001, 0.01, 0.1, 0.2, 0.3, 0.5}) {
                                double z = Dist.quantile(q, allData);
                                double e1 = dist.quantile(q) - z;
                                double e2 = dist2.quantile(q) - z;
                                double e2Relative = Math.abs(e2) / q;
                                double e3 = dist3.quantile(q) - z;
                                out.printf("quantile,%d,%.6f,%.6f,%.6f,%.6f,%.6f,%.6f\n",
                                        parts, q, z - q, e1, e2, e2Relative, e3);
                                Assert.assertTrue(String.format("Relative error: parts=%d, q=%.4f, e1=%.5f, e2=%.5f, rel=%.4f, e3=%.4f",
                                        parts, q, e1, e2, e2Relative, e3), e2Relative < 0.4);
                                Assert.assertTrue(String.format("Absolute error: parts=%d, q=%.4f, e1=%.5f, e2=%.5f, rel=%.4f, e3=%.4f",
                                        parts, q, e1, e2, e2Relative, e3), Math.abs(e2) < 0.015);
                            }

                            for (double x : new double[]{0.001, 0.01, 0.1, 0.2, 0.3, 0.5}) {
                                double z = Dist.cdf(x, allData);
                                double e1 = dist.cdf(x) - z;
                                double e2 = dist2.cdf(x) - z;
                                double e3 = dist3.cdf(x) - z;
                                out.printf("cdf,%d,%.6f,%.6f,%.6f,%.6f,%.6f,%.6f\n",
                                        parts, x, z - x, e1, e2, Math.abs(e2) / x, e3);
                                Assert.assertTrue(String.format("Absolute cdf: parts=%d, x=%.4f, e1=%.5f, e2=%.5f",
                                        parts, x, e1, e2), Math.abs(e2) < 0.015);
                                Assert.assertTrue(String.format("Relative cdf: parts=%d, x=%.4f, e1=%.5f, e2=%.5f, rel=%.3f",
                                        parts, x, e1, e2, Math.abs(e2) / x), Math.abs(e2) / x < 0.4);
                            }
                            out.flush();
                        }
                        System.out.printf(" Finishing %d\n", currentK);
                        out.close();
                        return s.toString();
                    }
                });
            }

            ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() + 2);
            try {
                for (Future<String> result : executor.invokeAll(tasks)) {
                    out.write(result.get());
                }
            } catch (Throwable e) {
                fail(e.getMessage());
            } finally {
                executor.shutdownNow();
                executor.awaitTermination(10, TimeUnit.SECONDS);
            }
        } catch (InterruptedException e) {
            fail("Tasks interrupted");
        } catch (FileNotFoundException e) {
            fail("Couldn't write to data output file merge.csv");
        }
    }

    @Test
    public void testTreeAccuracy() throws IOException, InterruptedException {
        // TODO there is a fair bit of duplicated code here
        String head = Git.getHash(true).substring(0, 10);
        String experiment = "tree-digest";
        new File("tests").mkdirs();
        PrintWriter quantiles = new PrintWriter(String.format("tests/accuracy-%s-%s.csv", experiment, head));
        PrintWriter sizes = new PrintWriter(String.format("tests/accuracy-sizes-%s-%s.csv", experiment, head));
        PrintWriter cdf = new PrintWriter(String.format("tests/accuracy-cdf-%s-%s.csv", experiment, head));
        quantiles.printf("digest, dist, sort, q.digest, q.raw, error, compression, q, x, k, clusters\n");
        cdf.printf("digest, dist, sort, x.digest, x.raw, error, compression, q, k, clusters\n");
        sizes.printf("digest, dist, sort, q.0, q.1, dk, mean, compression, count, k, clusters\n");

        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() + 4);
        Collection<Callable<Integer>> tasks = new ArrayList<>();
        AtomicInteger lines = new AtomicInteger();
        long t0 = System.nanoTime();
        for (int k = 0; k < 20; k++) {
            int finalK = k;
            tasks.add(() -> {
                for (Util.Distribution dist : Collections.singleton(Util.Distribution.UNIFORM)) {
                    // for (Util.Distribution dist : Util.Distribution.values()) {
                    AbstractContinousDistribution dx = dist.create(gen);
                    double[] raw = new double[N];
                    for (int i = 0; i < N; i++) {
                        raw[i] = dx.nextDouble();
                    }
                    double[] sorted = Arrays.copyOf(raw, raw.length);
                    Arrays.sort(sorted);
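
                    // sweep the compression factor; larger values allow more centroids,
                    // which improves accuracy at the cost of a larger digest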
                    for (double compression : new double[]{20, 50, 100, 200, 500}) {
                        for (Util.Factory factory : Collections.singleton(Util.Factory.TREE)) {
                            // for (Util.Factory factory : Util.Factory.values()) {
                            TDigest digest = factory.create(compression);
                            for (double datum : raw) {
                                digest.add(datum);
                            }
                            evaluate(finalK, quantiles, sizes, cdf, dist, "unsorted", sorted, compression, digest);

                            digest = factory.create(compression);
                            for (double datum : sorted) {
                                digest.add(datum);
                            }
                            evaluate(finalK, quantiles, sizes, cdf, dist, "sorted", sorted, compression, digest);
                        }
                    }
                }
                int count = lines.incrementAndGet();
                long t = System.nanoTime();
                double duration = (t - t0) * 1e-9;
                System.out.printf("%d, %d, %.2f, %.3f\n", finalK, count, duration, count / duration);
                return finalK;
            });
        }
        pool.invokeAll(tasks);
        sizes.close();
        quantiles.close();
        cdf.close();
    }
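
    /**
     * Measures how accuracy varies with compression for merging digests. Sweeps distributions,
     * scale functions, the weight-limit setting, and compression factors, writing results to
     * CSV files tagged with the current git hash.
     */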
    @Test
    public void testAccuracyVersusCompression() throws IOException, InterruptedException {
        String head = Git.getHash(true).substring(0, 10);
        String experiment = "digest";
        new File("tests").mkdirs();
        try (PrintWriter out = new PrintWriter(String.format("tests/accuracy-%s-%s.csv", experiment, head));
             PrintWriter cdf = new PrintWriter(String.format("tests/accuracy-cdf-%s-%s.csv", experiment, head));
             PrintWriter sizes = new PrintWriter(String.format("tests/accuracy-sizes-%s-%s.csv", experiment, head))) {
            out.printf("digest, dist, sort, q.digest, q.raw, error, compression, q, x, k, clusters\n");
            cdf.printf("digest, dist, sort, x.digest, x.raw, error, compression, q, k, clusters\n");
            sizes.printf("digest, dist, sort, q.0, q.1, dk, mean, compression, count, k, clusters\n");

            ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() + 4);
            Collection<Callable<Integer>> tasks = new ArrayList<>();
            AtomicInteger lines = new AtomicInteger();
            long t0 = System.nanoTime();
            for (int k = 0; k < 50; k++) {
                int finalK = k;
                tasks.add(() -> {
                    try {
                        // for (Util.Distribution dist : Collections.singleton(Util.Distribution.UNIFORM)) {
                        for (Util.Distribution dist : Util.Distribution.values()) {
                            AbstractContinousDistribution dx = dist.create(gen);
                            int size = (int) (N + new Random().nextGaussian() * 1000);
                            double[] raw = new double[size];
                            for (int i = 0; i < size; i++) {
                                raw[i] = dx.nextDouble();
                            }
                            double[] sorted = Arrays.copyOf(raw, raw.length);
                            Arrays.sort(sorted);

                            for (boolean useWeightLimit : new boolean[]{true, false}) {
                                for (ScaleFunction scale : ScaleFunction.values()) {
                                    if (scale.toString().contains("_NO_NORM")
                                            || scale.toString().equals("K_0")
                                            || scale.toString().contains("FAST")
                                            || scale.toString().contains("kSize")) {
                                        continue;
                                    }
                                    for (double compression : new double[]{50, 100, 200, 500, 1000}) {
                                        // for (double compression : new double[]{100, 200, 500}) {
                                        for (Util.Factory factory : Collections.singleton(Util.Factory.MERGE)) {
                                            // for (Util.Factory factory : Util.Factory.values()) {
                                            TDigest digest = factory.create(compression);
                                            MergingDigest.useWeightLimit = useWeightLimit;
                                            digest.setScaleFunction(scale);
                                            for (double datum : raw) {
                                                digest.add(datum);
                                            }
                                            digest.compress();
                                            evaluate(finalK, out, sizes, cdf, dist, "unsorted", sorted, compression, digest);
                                            // digest = factory.create(compression);
                                            // for (double datum : sorted) {
                                            //     digest.add(datum);
                                            // }
                                            // evaluate(finalK, out, sizes, cdf, dist, "sorted", sorted, compression, digest);
                                        }
                                    }
                                }
                            }
                        }
                    } catch (Throwable e) {
                        e.printStackTrace();
                    }
                    int count = lines.incrementAndGet();
                    long t = System.nanoTime();
                    double duration = (t - t0) * 1e-9;
                    System.out.printf("%d, %d, %.2f, %.3f\n", finalK, count, duration, count / duration);
                    return finalK;
                });
            }
            pool.invokeAll(tasks);
        }
    }
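
    /**
     * Writes accuracy measurements for a digest that has already been filled with the data
     * summarized by {@code sorted}. Each centroid's extent in q and in k-space goes to
     * {@code sizes}, cdf accuracy at the true quantiles goes to {@code quantiles}, and
     * quantile accuracy goes to {@code cdf}.
     */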
    private void evaluate(int k, PrintWriter quantiles, PrintWriter sizes, PrintWriter cdf,
                          Util.Distribution dist, String sort, double[] sorted,
                          double compression, TDigest digest) {
        int clusters = digest.centroidCount();
        double qx = 0;
        for (Centroid centroid : digest.centroids()) {
            double dq = (double) centroid.count() / sorted.length;
            double k0 = 0;
            double k1 = 0;
            if (digest instanceof MergingDigest) {
                // only merging digests expose a scale function; for other digests the
                // k-space size of the cluster is recorded as zero
                ScaleFunction scale = ((MergingDigest) digest).getScaleFunction();
                k0 = scale.k(qx, compression, digest.size());
                k1 = scale.k(qx + dq, compression, digest.size());
            }
            //noinspection SynchronizationOnLocalVariableOrMethodParameter
            synchronized (sizes) {
                sizes.printf("%s,%s,%s,%.8f,%.8f,%.8f,%.8g,%.0f,%d,%d,%d\n",
                        digest, dist, sort, qx, qx + dq, k1 - k0, centroid.mean(),
                        compression, centroid.count(), k, clusters);
            }
            qx += dq;
        }

        // estimate q = cdf(x) at the true quantiles x and compare with the true cdf
        for (double q : new double[]{1e-6, 1e-5, 0.0001, 0.001, 0.01, 0.1, 0.5,
                0.9, 0.99, 0.999, 0.9999, 1 - 1e-5, 1 - 1e-6}) {
            double x = Dist.quantile(q, sorted);
            double q1 = digest.cdf(x);
            double q0 = Dist.cdf(x, sorted);
            // relative error, normalized by the smaller tail
            double error = (q1 - q0) / Math.min(q1, 1 - q1);
            //noinspection SynchronizationOnLocalVariableOrMethodParameter
            synchronized (quantiles) {
                quantiles.printf("%s,%s,%s,%.8f,%.8f,%.8g,%.0f,%.8g,%.8g,%d,%d\n",
                        digest, dist, sort, q1, q0, error, compression, q, x, k, clusters);
            }
        }

        // estimate x = quantile(q) and compare with the true quantiles
        for (double q : new double[]{1e-6, 1e-5, 0.0001, 0.001, 0.01, 0.1, 0.5,
                0.9, 0.99, 0.999, 0.9999, 1 - 1e-5, 1 - 1e-6}) {
            double x1 = digest.quantile(q);
            double x0 = Dist.quantile(q, sorted);
            // relative error, normalized by the distance to the nearer of 0 and 1
            double error = (x1 - x0) / Math.min(x1, 1 - x1);
            //noinspection SynchronizationOnLocalVariableOrMethodParameter
            synchronized (cdf) {
                cdf.printf("%s,%s,%s,%.8f,%.8f,%.8g,%.0f,%.8g,%d,%d\n",
                        digest, dist, sort, x1, x0, error, compression, q, k, clusters);
            }
        }
    }

    /**
     * Prints the actual samples that went into a few clusters near the tails and near the median.
     * <p>
     * This is important for testing how close to ideal a real-world t-digest might be. In
     * particular, it lets us visualize how clusters are shaped in sample space to look for smear
     * or skew.
     * <p>
     * The accuracy.r script produces a visualization of the data produced by this test.
     *
     * @throws FileNotFoundException If the output file can't be opened.
     * @throws InterruptedException  If threads are interrupted (we don't ever expect that to happen).
     */
    @Test
    public void testBucketFill() throws FileNotFoundException, InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() + 2);
        Collection<Callable<Integer>> tasks = new ArrayList<>();
        AtomicInteger lines = new AtomicInteger();
        long t0 = System.nanoTime();
        PrintWriter samples = new PrintWriter("accuracy-samples.csv");
        samples.printf("digest, dist, sort, compression, k, centroid, centroid.down, i, x, mean, q0, q1\n");
        for (int k = 0; k < 20; k++) {
            int finalK = k;
            tasks.add(() -> {
                for (double compression : new double[]{100}) {
                    for (Util.Distribution dist : Util.Distribution.values()) {
                        AbstractContinousDistribution dx = dist.create(gen);
                        double[] raw = new double[N];
                        for (int i = 0; i < N; i++) {
                            raw[i] = dx.nextDouble();
                        }
                        double[] sorted = Arrays.copyOf(raw, raw.length);
                        Arrays.sort(sorted);
                        for (ScaleFunction scale : new ScaleFunction[]{ScaleFunction.K_2, ScaleFunction.K_3}) {
                            // for (Util.Factory factory : Collections.singletonList(Util.Factory.MERGE)) {
                            MergingDigest digest = new MergingDigest(compression);
                            digest.recordAllData();
                            digest.setScaleFunction(scale);
                            evaluate2(finalK, dist, samples, raw, compression, digest);
                            // evaluate2(finalK, dist, samples, "sorted", factory, sorted, compression);
                            // }
                        }
                    }
                }
                int count = lines.incrementAndGet();
                long t = System.nanoTime();
                double duration = (t - t0) * 1e-9;
                System.out.printf("%d, %d, %.2f, %.3f\n", finalK, count, duration, count / duration);
                return finalK;
            });
        }
        pool.invokeAll(tasks);
        samples.close();
    }

    /**
     * Dumps the individual samples recorded in clusters near the tails and near the median of a
     * single digest, one CSV row per sample.
     */
    private void evaluate2(int k, Util.Distribution dist, PrintWriter samples,
                           double[] data, double compression, TDigest digest) {
        for (double datum : data) {
            digest.add(datum);
        }
        double qx = 0;
        int cx = 0;
        Collection<Centroid> centroids = digest.centroids();
        for (Centroid centroid : centroids) {
            double dq = (double) centroid.count() / N;
            if (qx < 0.05 || Math.abs(qx - 0.5) < 0.025 || qx > 0.95) {
                int sx = 0;
                //noinspection SynchronizationOnLocalVariableOrMethodParameter
                synchronized (samples) {
                    for (Double x : centroid.data()) {
                        samples.printf("%s,%s,%s,%.0f,%d,%d,%d,%d,%.8f,%.8f,%.8f,%.8f\n",
                                digest, dist, "unsorted", compression, k, cx,
                                centroids.size() - cx - 1, sx, x, centroid.mean(), qx, qx + dq);
                        sx++;
                    }
                }
            }
            qx += dq;
            cx++;
        }
    }
}