package slp.core.modeling.runners; import java.io.File; import java.io.IOException; import java.io.Reader; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; import java.util.DoubleSummaryStatistics; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; import slp.core.lexing.Lexer; import slp.core.lexing.runners.LexerRunner; import slp.core.modeling.Model; import slp.core.modeling.ngram.NGramModel; import slp.core.translating.Vocabulary; import slp.core.util.Pair; /** * This class can be used to run {@link Model}-related functions over bodies of code. * It provides the lexing and translation steps necessary to allow immediate learning and modeling from directories or files. * As such, it wraps the pipeline stages {@link Reader} --> {@link Lexer} --> Translate ({@link Vocabulary}) --> {@link Model}. * <br /> * This class uses a {@link LexerRunner}, which differentiates between file and line data and provides a some additional utilities. * It also provides easier access to self-testing (in which each line is forgotten before modeling it and re-learned after), * which is helpful for count-based models such as {@link NGramModel}s. * * @author Vincent Hellendoorn * */ public class ModelRunner { private static final double INV_NEG_LOG_2 = -1.0/Math.log(2); public static final int DEFAULT_NGRAM_ORDER = 6; public static int GLOBAL_PREDICTION_CUTOFF = 10; protected final LexerRunner lexerRunner; protected final Vocabulary vocabulary; protected final Model model; private boolean selfTesting = false; public ModelRunner(Model model, LexerRunner lexerRunner, Vocabulary vocabulary) { this.lexerRunner = lexerRunner; this.vocabulary = vocabulary; this.model = model; } /** * Convenience function that creates a new {@link ModelRunner} instance for the provided {@link Model} * that is backed by the same {@link LexerRunner} and {@link Vocabulary}. * * @param model The model to provide a {@link ModelRunner} for. * @return A new {@link ModelRunner} for this {@link Model}, * with the current {@link ModelRunner}'s {@link LexerRunner} and {@link Vocabulary} */ public ModelRunner copyForModel(Model model) { return new ModelRunner(model, this.lexerRunner, this.vocabulary); } public LexerRunner getLexerRunner() { return this.lexerRunner; } public Model getModel() { return this.model; } public Vocabulary getVocabulary() { return this.vocabulary; } /** * Enables self testing: if we are testing on data that we also trained on, and our models are able to forget events, * we can simulated training on all-but one sequence (the one we are modeling) by temporarily forgetting * an event, modeling it and re-learning it afterwards. This maximizes use of context information and can be used * to simulate full cross-validation. * * @param selfTesting If true, will temporarily "forget" every sequence before modeling it and "re-learn" it afterwards */ public void setSelfTesting(boolean selfTesting) { this.selfTesting = selfTesting; } public static int getPredictionCutoff() { return GLOBAL_PREDICTION_CUTOFF; } public static void setPredictionCutoff(int cutoff) { GLOBAL_PREDICTION_CUTOFF = cutoff; } private final long LEARN_PRINT_INTERVAL = 1000000; private long[] learnStats = new long[2]; public void learnDirectory(File file) { this.learnStats = new long[] { 0, -System.currentTimeMillis() }; this.lexerRunner.lexDirectory(file) .forEach(p -> { this.model.notify(p.left); this.learnTokens(p.right); }); if (this.learnStats[0] > LEARN_PRINT_INTERVAL && this.learnStats[1] != 0) { System.out.printf("Counting complete: %d tokens processed in %ds\n", this.learnStats[0], (System.currentTimeMillis() + this.learnStats[1])/1000); } } public void learnFile(File f) { if (!this.lexerRunner.willLexFile(f)) return; this.model.notify(f); learnTokens(this.lexerRunner.lexFile(f)); } public void learnContent(String content) { learnTokens(this.lexerRunner.lexText(content)); } public void learnTokens(Stream<Stream<String>> lexed) { if (this.lexerRunner.isPerLine()) { lexed.map(this.getVocabulary()::toIndices) .map(l -> l.peek(l2 -> logLearningProgress())) .map(l -> l.collect(Collectors.toList())) .forEach(this.model::learn); } else { this.model.learn(lexed.map(l -> l.peek(l2 -> logLearningProgress())) .flatMap(this.getVocabulary()::toIndices) .collect(Collectors.toList())); } } private void logLearningProgress() { if (++this.learnStats[0] % this.LEARN_PRINT_INTERVAL == 0 && this.learnStats[1] != 0) { System.out.printf("Counting: %dM tokens processed in %ds\n", Math.round(this.learnStats[0]/1e6), (System.currentTimeMillis() + this.learnStats[1])/1000); } } public void forgetDirectory(File file) { try { Files.walk(file.toPath()) .map(Path::toFile) .filter(File::isFile) .forEach(f -> forgetFile(f)); } catch (IOException e) { e.printStackTrace(); } } public void forgetFile(File f) { if (!this.lexerRunner.willLexFile(f)) return; this.model.notify(f); forgetTokens(this.lexerRunner.lexFile(f)); } public void forgetContent(String content) { forgetTokens(this.lexerRunner.lexText(content)); } public void forgetTokens(Stream<Stream<String>> lexed) { if (this.lexerRunner.isPerLine()) { lexed.map(this.getVocabulary()::toIndices) .map(l -> l.collect(Collectors.toList())) .forEach(this.model::forget); } else { this.model.forget(lexed.flatMap(this.getVocabulary()::toIndices).collect(Collectors.toList())); } } private final int MODEL_PRINT_INTERVAL = 100000; private long[] modelStats = new long[2]; private double ent = 0.0; private double mrr = 0.0; public Stream<Pair<File, List<List<Double>>>> modelDirectory(File file) { this.modelStats = new long[] { 0, -System.currentTimeMillis() }; this.ent = 0.0; return this.lexerRunner.lexDirectory(file) .map(p -> { this.model.notify(p.left); return Pair.of(p.left, this.modelTokens(p.right)); }); } public List<List<Double>> modelFile(File f) { if (!this.lexerRunner.willLexFile(f)) return null; this.model.notify(f); return modelTokens(this.lexerRunner.lexFile(f)); } public List<List<Double>> modelContent(String content) { return modelTokens(this.lexerRunner.lexText(content)); } public List<List<Double>> modelTokens(Stream<Stream<String>> lexed) { List<List<Double>> lineProbs; if (this.lexerRunner.isPerLine()) { lineProbs = lexed.map(this.getVocabulary()::toIndices) .map(l -> l.collect(Collectors.toList())) .map(l -> modelSequence(l)) .peek(this::logModelingProgress) .collect(Collectors.toList()); } else { List<Integer> lineLengths = new ArrayList<>(); List<Double> modeled = modelSequence(lexed .map(this.getVocabulary()::toIndices) .map(l -> l.collect(Collectors.toList())) .peek(l -> lineLengths.add(l.size())) .flatMap(l -> l.stream()).collect(Collectors.toList())); lineProbs = toLines(modeled, lineLengths); logModelingProgress(modeled); } return lineProbs; } protected List<Double> modelSequence(List<Integer> tokens) { if (this.selfTesting) this.model.forget(tokens); List<Double> entropies = this.model.model(tokens).stream() .map(this::toProb) .map(ModelRunner::toEntropy) .collect(Collectors.toList()); if (this.selfTesting) this.model.learn(tokens); return entropies; } private void logModelingProgress(List<Double> modeled) { DoubleSummaryStatistics stats = modeled.stream().skip(1) .mapToDouble(Double::doubleValue).summaryStatistics(); long prevCount = this.modelStats[0]; this.modelStats[0] += stats.getCount(); this.ent += stats.getSum(); if (this.modelStats[0] / this.MODEL_PRINT_INTERVAL > prevCount / this.MODEL_PRINT_INTERVAL && this.modelStats[1] != 0) { System.out.printf("Modeling: %dK tokens processed in %ds, avg. entropy: %.4f\n", Math.round(this.modelStats[0]/1e3), (System.currentTimeMillis() + this.modelStats[1])/1000, this.ent/this.modelStats[0]); } } public Stream<Pair<File, List<List<Double>>>> predictDirectory(File file) { this.modelStats = new long[] { 0, -System.currentTimeMillis() }; this.mrr = 0.0; return this.lexerRunner.lexDirectory(file) .map(p -> { this.model.notify(p.left); return Pair.of(p.left, this.predictTokens(p.right)); }); } public List<List<Double>> predictFile(File f) { if (!this.lexerRunner.willLexFile(f)) return null; this.model.notify(f); return predictTokens(this.lexerRunner.lexFile(f)); } public List<List<Double>> predictContent(String content) { return predictTokens(this.lexerRunner.lexText(content)); } public List<List<Double>> predictTokens(Stream<Stream<String>> lexed) { List<List<Double>> lineProbs; if (this.lexerRunner.isPerLine()) { lineProbs = lexed .map(this.getVocabulary()::toIndices) .map(l -> l.collect(Collectors.toList())) .map(l -> predictSequence(l)) .peek(this::logPredictionProgress) .collect(Collectors.toList()); } else { List<Integer> lineLengths = new ArrayList<>(); List<Double> modeled = predictSequence(lexed .map(this.getVocabulary()::toIndices) .map(l -> l.collect(Collectors.toList())) .peek(l -> lineLengths.add(l.size())) .flatMap(l -> l.stream()).collect(Collectors.toList())); lineProbs = toLines(modeled, lineLengths); logPredictionProgress(modeled); } return lineProbs; } protected List<Double> predictSequence(List<Integer> tokens) { if (this.selfTesting) this.model.forget(tokens); List<List<Integer>> preds = toPredictions(this.model.predict(tokens)); List<Double> mrrs = IntStream.range(0, tokens.size()) .mapToObj(i -> preds.get(i).indexOf(tokens.get(i))) .map(ModelRunner::toMRR) .collect(Collectors.toList()); if (this.selfTesting) this.model.learn(tokens); return mrrs; } private void logPredictionProgress(List<Double> modeled) { DoubleSummaryStatistics stats = modeled.stream().skip(1) .mapToDouble(Double::doubleValue).summaryStatistics(); long prevCount = this.modelStats[0]; this.modelStats[0] += stats.getCount(); this.mrr += stats.getSum(); if (this.modelStats[0] / this.MODEL_PRINT_INTERVAL > prevCount / this.MODEL_PRINT_INTERVAL && this.modelStats[1] != 0) { System.out.printf("Predicting: %dK tokens processed in %ds, avg. MRR: %.4f\n", Math.round(this.modelStats[0]/1e3), (System.currentTimeMillis() + this.modelStats[1])/1000, this.mrr/this.modelStats[0]); } } public Stream<Pair<File, List<List<Completion>>>> completeDirectory(File file) { this.modelStats = new long[] { 0, -System.currentTimeMillis() }; this.mrr = 0.0; return this.lexerRunner.lexDirectory(file) .map(p -> { this.model.notify(p.left); return Pair.of(p.left, this.completeTokens(p.right)); }); } public List<List<Completion>> completeFile(File f) { if (!this.lexerRunner.willLexFile(f)) return null; this.model.notify(f); return completeTokens(this.lexerRunner.lexFile(f)); } public List<List<Completion>> completeContent(String content) { return completeTokens(this.lexerRunner.lexText(content)); } public List<List<Completion>> completeTokens(Stream<Stream<String>> lexed) { List<List<Completion>> lineCompletions; if (this.lexerRunner.isPerLine()) { lineCompletions = lexed .map(this.getVocabulary()::toIndices) .map(l -> l.collect(Collectors.toList())) .map(l -> completeSequence(l)) .peek(this::logCompletionProgress) .collect(Collectors.toList()); } else { List<Integer> lineLengths = new ArrayList<>(); List<Completion> commpletions = completeSequence(lexed .map(this.getVocabulary()::toIndices) .map(l -> l.collect(Collectors.toList())) .peek(l -> lineLengths.add(l.size())) .flatMap(l -> l.stream()).collect(Collectors.toList())); lineCompletions = toLines(commpletions, lineLengths); logCompletionProgress(commpletions); } return lineCompletions; } protected List<Completion> completeSequence(List<Integer> tokens) { if (this.selfTesting) this.model.forget(tokens); List<Map<Integer, Pair<Double, Double>>> preds = this.model.predict(tokens); if (this.selfTesting) this.model.learn(tokens); List<Completion> rankings = IntStream.range(0, preds.size()) .mapToObj(i -> { List<Pair<Integer, Double>> completions = preds.get(i).entrySet().stream() .map(e -> Pair.of(e.getKey(), toProb(e.getValue()))) .sorted((p1, p2) -> -Double.compare(p1.right, p2.right)) .limit(GLOBAL_PREDICTION_CUTOFF) .collect(Collectors.toList()); return new Completion(tokens.get(i), completions); }).collect(Collectors.toList()); return rankings; } private void logCompletionProgress(List<Completion> completions) { DoubleSummaryStatistics stats = completions.stream().skip(1) .map(Completion::getRank) .mapToDouble(ModelRunner::toMRR) .summaryStatistics(); long prevCount = this.modelStats[0]; this.modelStats[0] += stats.getCount(); this.mrr += stats.getSum(); if (this.modelStats[0] / this.MODEL_PRINT_INTERVAL > prevCount / this.MODEL_PRINT_INTERVAL && this.modelStats[1] != 0) { System.out.printf("Predicting: %dK tokens processed in %ds, avg. MRR: %.4f\n", Math.round(this.modelStats[0]/1e3), (System.currentTimeMillis() + this.modelStats[1])/1000, this.mrr/this.modelStats[0]); } } public List<Double> toProb(List<Pair<Double, Double>> probConfs) { return probConfs.stream().map(this::toProb).collect(Collectors.toList()); } public double toProb(Pair<Double, Double> probConf) { double prob = probConf.left; double conf = probConf.right; return prob*conf + (1 - conf)/this.getVocabulary().size(); } public static double toEntropy(double probability) { return Math.log(probability) * INV_NEG_LOG_2; } public static double toMRR(Integer ix) { return ix >= 0 ? 1.0 / (ix + 1) : 0.0; } public List<List<Integer>> toPredictions(List<Map<Integer, Pair<Double, Double>>> probConfs) { return probConfs.stream().map(this::toPredictions).collect(Collectors.toList()); } public List<Integer> toPredictions(Map<Integer, Pair<Double, Double>> probConf) { return probConf.entrySet().stream() .map(e -> Pair.of(e.getKey(), toProb(e.getValue()))) .sorted((p1, p2) -> -Double.compare(p1.right, p2.right)) .limit(GLOBAL_PREDICTION_CUTOFF) .map(p -> p.left) .collect(Collectors.toList()); } private <K> List<List<K>> toLines(List<K> modeled, List<Integer> lineLengths) { List<List<K>> perLine = new ArrayList<>(); int ix = 0; for (int i = 0; i < lineLengths.size(); i++) { List<K> line = new ArrayList<>(); for (int j = 0; j < lineLengths.get(i); j++) { line.add(modeled.get(ix++)); } perLine.add(line); } return perLine; } public DoubleSummaryStatistics getStats(Map<File, List<List<Double>>> fileProbs) { return getStats(fileProbs.entrySet().stream().map(e -> Pair.of(e.getKey(), e.getValue()))); } public DoubleSummaryStatistics getStats(Stream<Pair<File, List<List<Double>>>> fileProbs) { return getFileStats(fileProbs.map(p -> p.right)); } public DoubleSummaryStatistics getStats(List<List<Double>> fileProbs) { return getFileStats(Stream.of(fileProbs)); } private DoubleSummaryStatistics getFileStats(Stream<List<List<Double>>> fileProbs) { if (this.lexerRunner.isPerLine()) { return fileProbs.flatMap(List::stream) .flatMap(l -> l.stream().skip(1)) .mapToDouble(p -> p).summaryStatistics(); } else { return fileProbs.flatMap(f -> f.stream() .flatMap(l -> l.stream()) .skip(1)) .mapToDouble(p -> p).summaryStatistics(); } } public DoubleSummaryStatistics getCompletionStats(List<List<Completion>> completions) { List<List<Double>> MRRs = completions.stream() .map(l -> l.stream().map(c -> toMRR(c.getRank()))) .map(l -> l.collect(Collectors.toList())).collect(Collectors.toList()); return getFileStats(Stream.of(MRRs)); } }