package com.wealthfront.thompsonsampling; import cern.jet.random.engine.MersenneTwister; import cern.jet.random.engine.RandomEngine; import com.google.common.collect.Lists; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class BatchedABTestTest { @Test public void testCorrectArmChosen() { int correct = 0; for (int i = 0; i<= 10000; i++) { RandomEngine engine = new MersenneTwister(i); BanditPerformance performance = new BanditPerformance(2); BatchedABTest batchedBandit = new BatchedABTest(); batchedBandit.setRandomEngine(engine); BatchedBanditTester tester = new BatchedBanditTester(batchedBandit, engine); if (i % 100 == 0) { System.out.println("Batches complete " + i); } correct += tester.getWinningArm(); } assertTrue(correct > 9500); } @Test public void testChiSquareComputation() { BanditPerformance performance = new BanditPerformance(Lists.newArrayList(new ObservedArmPerformance(100L, 0L), new ObservedArmPerformance(0L, 100L))); BatchedABTest batchedABTest = new BatchedABTest(); batchedABTest.setRequiresMinSamples(false); assertEquals(new Integer(0), batchedABTest.getBanditStatistics(performance).getVictoriousArm().get()); performance = new BanditPerformance(Lists.newArrayList(new ObservedArmPerformance(0L, 100L), new ObservedArmPerformance(100L, 0L))); batchedABTest = new BatchedABTest(); batchedABTest.setRequiresMinSamples(false); assertEquals(new Integer(1), batchedABTest.getBanditStatistics(performance).getVictoriousArm().get()); performance = new BanditPerformance(Lists.newArrayList(new ObservedArmPerformance(5L, 5L), new ObservedArmPerformance(5L, 5L))); batchedABTest = new BatchedABTest(); batchedABTest.setRequiresMinSamples(false); assertFalse(batchedABTest.getBanditStatistics(performance).getVictoriousArm().isPresent()); } }