package com.wealthfront.thompsonsampling; import cern.jet.random.engine.MersenneTwister; import cern.jet.random.engine.RandomEngine; import com.google.common.collect.Lists; import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static java.lang.Math.max; import static java.lang.Math.min; public class BatchedThompsonSamplingTest { @Test public void testUpdate() { BanditPerformance performance = new BanditPerformance(2); BatchedThompsonSampling bandit = new BatchedThompsonSampling(); performance.update(Lists.newArrayList(new ObservedArmPerformance(1, 2), new ObservedArmPerformance(3, 4))); assertEquals(Lists.newArrayList(new ObservedArmPerformance(1, 2), new ObservedArmPerformance(3, 4)), performance.getPerformances()); performance.update(Lists.newArrayList(new ObservedArmPerformance(1, 2), new ObservedArmPerformance(3, 4))); assertEquals(Lists.newArrayList(new ObservedArmPerformance(2, 4), new ObservedArmPerformance(6, 8)), performance.getPerformances()); try { performance.update(Lists.newArrayList(new ObservedArmPerformance(1, 2), new ObservedArmPerformance(3, 4), new ObservedArmPerformance(5, 6))); fail("Expecting IllegalArgumentException"); } catch (IllegalArgumentException e) { } } @Test @Ignore public void testCorrectArmChosen() { int correct = 0; for (int i = 0; i< 10000; i++) { RandomEngine engine = new MersenneTwister(i); BatchedThompsonSampling batchedBandit = new BatchedThompsonSampling(); batchedBandit.setRandomEngine(engine); BatchedBanditTester tester = new BatchedBanditTester(batchedBandit, engine); if (i % 100 == 0) { System.out.println("Batches complete " + i); } correct += tester.getWinningArm(); } System.out.println(correct); assertTrue(correct > 9500); } @Test public void testPerformance() { int maxBanditIterations = 0; double maxBanditRegret = 0.0; for (int i = 51; i<= 60; i++) { RandomEngine engine = new MersenneTwister(i); BanditPerformance performance = new BanditPerformance(2); BatchedThompsonSampling batchedBandit = new BatchedThompsonSampling(); batchedBandit.setRandomEngine(engine); BatchedBanditTester tester = new BatchedBanditTester(batchedBandit, engine); double regret = performance.cumulativeRegret(0.015, Lists.newArrayList(0.01, 0.015)); maxBanditIterations = max(maxBanditIterations, tester.getIterations()); maxBanditRegret = max(maxBanditRegret, regret); assertEquals(1, tester.getWinningArm()); } int minAbIterations = Integer.MAX_VALUE; double minAbRegret = Double.MAX_VALUE; for (int i = 51; i<= 60; i++) { RandomEngine engine = new MersenneTwister(i); BanditPerformance performance = new BanditPerformance(2); BatchedABTest batchedBandit = new BatchedABTest(); batchedBandit.setRandomEngine(engine); BatchedBanditTester tester = new BatchedBanditTester(batchedBandit, engine); double regret = performance.cumulativeRegret(0.015, Lists.newArrayList(0.01, 0.015)); minAbIterations = min(minAbIterations, tester.getIterations()); minAbRegret = min(minAbRegret, regret); assertEquals(1, tester.getWinningArm()); } } @Test public void testPerformance2() { int maxBanditIterations = 0; double maxBanditRegret = 0.0; for (int i = 51; i<= 60; i++) { RandomEngine engine = new MersenneTwister(i); BanditPerformance performance = new BanditPerformance(6); BatchedThompsonSampling batchedBandit = new BatchedThompsonSampling(); batchedBandit.setRandomEngine(engine); BatchedBanditTester tester = new BatchedBanditTester(batchedBandit, engine, Lists.newArrayList(new BernouliArm(0.04, engine), new BernouliArm(0.05, engine), new BernouliArm(0.045, engine), new BernouliArm(0.03, engine), new BernouliArm(0.02, engine), new BernouliArm(0.035, engine))); double regret = performance.cumulativeRegret(0.05, Lists.newArrayList(0.04, 0.05, 0.045, 0.03, 0.02, 0.035)); maxBanditIterations = max(maxBanditIterations, tester.getIterations()); maxBanditRegret = max(maxBanditRegret, regret); } int minAbIterations = Integer.MAX_VALUE; double minAbRegret = Double.MAX_VALUE; for (int i = 51; i<= 60; i++) { RandomEngine engine = new MersenneTwister(i); BanditPerformance performance = new BanditPerformance(6); BatchedABTest batchedBandit = new BatchedABTest(); batchedBandit.setRandomEngine(engine); BatchedBanditTester tester = new BatchedBanditTester(batchedBandit, engine, Lists.newArrayList(new BernouliArm(0.04, engine), new BernouliArm(0.05, engine), new BernouliArm(0.045, engine), new BernouliArm(0.03, engine), new BernouliArm(0.02, engine), new BernouliArm(0.035, engine))); double regret = performance.cumulativeRegret(0.05, Lists.newArrayList(0.04, 0.05, 0.045, 0.03, 0.02, 0.35)); minAbIterations = min(minAbIterations, tester.getIterations()); minAbRegret = min(minAbRegret, regret); } System.out.println("Min A/B regret: " + minAbRegret); System.out.println("Max Bandit regret: " + maxBanditRegret); System.out.println("Min A/B # batches (batch size = 100 samples): " + minAbIterations); System.out.println("Max Bandit # batches (batch size = 100 samples): " + maxBanditIterations); } }