package org.sgdtk.exec; import com.lmax.disruptor.*; import com.lmax.disruptor.dsl.Disruptor; import com.lmax.disruptor.dsl.ProducerType; import org.sgdtk.FeatureVector; import org.sgdtk.Learner; import org.sgdtk.Model; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; /** * Use LMAX Disruptor ring buffer to implement an executor * * LMAX Disruptor is one of my favorite toys for processing. It is simple to use and very effective performance-wise. * Here we can operate contention-free -- unlike VW we dont even have condition variables bounding our buffer! * * @author dpressel */ public class RingBufferTrainingExecutor implements TrainingExecutor { private static final Logger log = LoggerFactory.getLogger(RingBufferTrainingExecutor.class); ExecutorService executor; Disruptor<MessageEvent> disruptor; MessageEventHandler handler; int numEpochs; private File cacheFile; private Strategy strategy; public enum Strategy { YIELD, BUSY }; @Override public int getNumEpochs() { return numEpochs; } @Override public File getCacheFile() { return cacheFile; } /** * Create one */ public RingBufferTrainingExecutor() { this(Strategy.YIELD); } /** * Create one */ public RingBufferTrainingExecutor(Strategy strategy) { this.strategy = strategy; } /** * Class that holds our feature vector */ public static class MessageEvent { private FeatureVector fv; public void set(FeatureVector fv) { this.fv = fv; } } /** * Class that produces our FV holder */ public static class MessageEventFactory implements EventFactory<MessageEvent> { public MessageEvent newInstance() { return new MessageEvent(); } } /** * This is our processor. It is triggered when an event is placed onto the RingBuffer. * */ public static class MessageEventHandler implements EventHandler<MessageEvent> { Learner learner; Model model; private long lastTime; private AtomicInteger currentEpoch = new AtomicInteger(); private List<TrainingEventListener> listeners; /** * Take in the learner and model and train * @param learner The learner * @param model The initialized but empty model */ public MessageEventHandler(Learner learner, Model model, List<TrainingEventListener> listeners) { this.learner = learner; this.model = model; lastTime = System.currentTimeMillis(); this.listeners = listeners; } /** * On a message, check if it is a null FV. If so, we are at the end of an epoch. * Update book-keeping. * @param messageEvent An FV holder * @param l Sequence number (which is increasing) * @param b not used * @throws Exception */ @Override public void onEvent(MessageEvent messageEvent, long l, boolean b) throws Exception { // get the message off the buffer and train on it if (messageEvent.fv == null) { long tNow = System.currentTimeMillis(); double diff = (tNow - lastTime)/1000.; lastTime = tNow; int currentEpoch1Based = currentEpoch.incrementAndGet(); for (TrainingEventListener listener : listeners) { listener.onEpochEnd(learner, model, diff); } log.info("Epoch " + currentEpoch1Based + " completed in " + diff + "s"); return; } learner.trainOne(model, messageEvent.fv); } /** * Get the current epoch * @return */ public int getCurrentEpoch() { return currentEpoch.get(); } } /** * Initialize the Disruptor. The buffer size must be a power of 2 or the RingBuffer will complain * * @param learner The learner * @param model The initialized but untrained model * @param numEpochs The number of epochs * @param cacheFile The cache file to use * @param bufferSize The size of the internal buffer to train from */ @Override public void initialize(Learner learner, Model model, int numEpochs, File cacheFile, int bufferSize, List<TrainingEventListener> listeners) { this.numEpochs = numEpochs; executor = Executors.newSingleThreadExecutor(); MessageEventFactory factory = new MessageEventFactory(); WaitStrategy waitStrategy = (strategy == Strategy.YIELD) ? new YieldingWaitStrategy(): new BusySpinWaitStrategy(); disruptor = new Disruptor<MessageEvent>(factory, ExecUtils.nextPowerOf2(bufferSize), executor, ProducerType.SINGLE, waitStrategy); handler = new MessageEventHandler(learner, model, listeners); disruptor.handleEventsWith(handler); this.cacheFile = cacheFile; } /** * Start the disruptor */ @Override public void start() { disruptor.start(); } /** * Add a feature vector onto the RingBuffer * @param fv feature vector */ @Override public void add(FeatureVector fv) { RingBuffer<MessageEvent> ringBuffer = disruptor.getRingBuffer(); long sequence = ringBuffer.next(); try { MessageEvent event = ringBuffer.get(sequence); event.fv = fv; } finally { ringBuffer.publish(sequence); } } @Override public void kill() { disruptor.shutdown(); executor.shutdownNow(); try { executor.awaitTermination(100, TimeUnit.MICROSECONDS); } catch (InterruptedException intEx) { } } /** * Pretty much busy-wait our way through this check seeing if all epochs have passed yet * Then shutdown the disruptor and our ExecutorService. */ @Override public void join() { while (handler.getCurrentEpoch() < this.numEpochs) { try { Thread.sleep(10); } catch (InterruptedException intEx) { } } kill(); } }