package com.wealthfront.thompsonsampling; import cern.jet.random.engine.MersenneTwister; import cern.jet.random.engine.RandomEngine; import com.google.common.base.Charsets; import com.google.common.base.Function; import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.io.Files; import org.junit.Ignore; import org.junit.Test; import java.io.File; import java.io.IOException; import java.util.List; import static java.lang.String.format; public class MainTest { private final RandomEngine engine = new MersenneTwister(-1); private static interface BanditCreator { public BatchedBandit bandit(); } @Test @Ignore public void testAll() throws IOException { ImmutableList<Double> mainArmWeights = ImmutableList.<Double>builder().add(0.04, 0.05, 0.045, 0.03, 0.02, 0.035).build(); for (int i = 2; i <= 6; i++) { List<Double> armWeights = mainArmWeights.subList(0, i); BanditCreator creator = new BanditCreator() { @Override public BatchedBandit bandit() { BatchedABTest bandit = new BatchedABTest(); return bandit; } }; String name = String.format("full_ab_%d", i); banditTest(armWeights, creator, name); } for (int i = 2; i <= 6; i++) { List<Double> armWeights = mainArmWeights.subList(0, i); BanditCreator creator = new BanditCreator() { @Override public BatchedBandit bandit() { PairwiseAbTest bandit = new PairwiseAbTest(); return bandit; } }; String name = String.format("pairwise_ab_%d", i); banditTest(armWeights, creator, name); } for (int i = 2; i <= 6; i++) { List<Double> armWeights = mainArmWeights.subList(0, i); BanditCreator creator = new BanditCreator() { @Override public BatchedBandit bandit() { BatchedThompsonSampling bandit = new BatchedThompsonSampling(); return bandit; } }; String name = String.format("thompson_arms_%d", i); banditTest(armWeights, creator, name); } } @Test @Ignore public void testFriday() throws IOException { ImmutableList<Double> mainArmWeights = ImmutableList.<Double>builder().add(0.01, 0.015).build(); for (int i = 2; i <= 2; i++) { List<Double> armWeights = mainArmWeights.subList(0, i); BanditCreator creator = new BanditCreator() { @Override public BatchedBandit bandit() { BatchedABTest bandit = new BatchedABTest(); return bandit; } }; String name = String.format("full_ab_%d", i); banditTest(armWeights, creator, name); } for (int i = 2; i <= 2; i++) { List<Double> armWeights = mainArmWeights.subList(0, i); final int iSafe = i; BanditCreator creator = new BanditCreator() { @Override public BatchedBandit bandit() { BatchedThompsonSampling bandit = new BatchedThompsonSampling(); return bandit; } }; String name = String.format("thompson_arms_%d", i); banditTest(armWeights, creator, name); } } private void banditTest(List<Double> armWeights, BanditCreator creator, String name) throws IOException { File file = new File(format("/tmp/bandit-results/%s.csv", name)); Files.createParentDirs(file); Files.touch(file); List<BernouliArm> arms = FluentIterable.from(armWeights).transform(new Function<Double, BernouliArm>() { @Override public BernouliArm apply(Double aDouble) { return new BernouliArm(aDouble, engine); } }).toList(); for (int i = 0; i < 10000; i++) { BatchedBanditTester tester = new BatchedBanditTester(creator.bandit(), engine, arms); String l = format("%d,%d,%f\n", tester.getWinningArm(), tester.getIterations(), tester.getCumulativeRegret()); Files.append(l, file, Charsets.UTF_8); } } @Test public void computeWeights() { long uniqueExitToInvite = 465; long uniqueControl = 242 + 214; long sentInviteExitToInvite = 48; long sentInviteControl = 16 + 18; ObservedArmPerformance exitToInvite = new ObservedArmPerformance(sentInviteExitToInvite, uniqueExitToInvite - sentInviteExitToInvite); ObservedArmPerformance control = new ObservedArmPerformance(sentInviteControl, uniqueControl - sentInviteControl); BatchedThompsonSampling bandit = new BatchedThompsonSampling(); BanditPerformance performance = new BanditPerformance(Lists.newArrayList(exitToInvite, control)); System.out.println(bandit.getBanditStatistics(performance)); uniqueExitToInvite = 474; uniqueControl = 243 + 218; sentInviteExitToInvite = 48; sentInviteControl = 16 + 20; exitToInvite = new ObservedArmPerformance(sentInviteExitToInvite, uniqueExitToInvite - sentInviteExitToInvite); control = new ObservedArmPerformance(sentInviteControl, uniqueControl - sentInviteControl); performance = new BanditPerformance(Lists.newArrayList(exitToInvite, control)); bandit = new BatchedThompsonSampling(); System.out.println(bandit.getBanditStatistics(performance)); } }