/* * Copyright 2018 org.LTR4L * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.ltr4l.trainers; import java.io.*; import java.util.*; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.ObjectMapper; import org.ltr4l.Ranker; import org.ltr4l.boosting.Ensemble; import org.ltr4l.boosting.RankBoost; import org.ltr4l.evaluation.DCG; import org.ltr4l.evaluation.RankEval; import org.ltr4l.evaluation.RankEval.RankEvalFactory; import org.ltr4l.query.Query; import org.ltr4l.query.QuerySet; import org.ltr4l.svm.AbstractSVM; import org.ltr4l.tools.Config; import org.ltr4l.tools.Error; import org.ltr4l.tools.LossCalculator; import org.ltr4l.tools.Report; import static org.ltr4l.tools.LossCalculator.DataSet.TRAINING; import static org.ltr4l.tools.LossCalculator.DataSet.VALIDATION; /** * Abstract class used for training the model held by Rankers. * This class is also the parameter holder. * * train() must be implemented based on algorithm used. */ public abstract class AbstractTrainer<R extends Ranker, C extends Config> { protected final int epochNum; protected final List<Query> trainingSet; protected final List<Query> validationSet; protected double maxScore; protected final Report report; protected final R ranker; protected final C config; protected final Error errorFunc; protected final int batchSize; protected final int evalK; protected final String modelFile; protected final RankEval eval; protected final LossCalculator lossCalc; AbstractTrainer(List<Query> training, List<Query> validation, C config, R ranker, Error errorFunc, LossCalculator lossCalc) { this.config = config; epochNum = config.numIterations; trainingSet = training; validationSet = validation; maxScore = 0d; this.ranker = ranker; //TODO: ranker, errorFunc, and lossCalc assignments are done in child classes by implementing methods... this.errorFunc = errorFunc; this.lossCalc = lossCalc; //TODO: In child classes, requires that ranker and errorFunc be created already... assert(config.batchSize >= 0); batchSize = config.batchSize; eval = getEvaluator(config); evalK = getEvaluatorAtK(config); modelFile = getModelFile(config); report = Report.getReport(config); } private static RankEval getEvaluator(Config config){ if (config.evaluation == null || config.evaluation.evaluator == null || config.evaluation.evaluator.equals("")) return new DCG.NDCG(); final String evaluator = config.evaluation.evaluator; return RankEvalFactory.get(evaluator); } private static int getEvaluatorAtK(Config config){ final int K_DEFAULT = 10; if(config.evaluation == null || config.evaluation.params == null) return K_DEFAULT; return Config.getInt(config.evaluation.params, "k", K_DEFAULT); } private static String getModelFile(Config config){ if(config.model == null || config.model.file == null || config.model.file.isEmpty()) return Config.Model.DEFAULT_MODEL_FILE; return config.model.file; } private static String getReportFile(Config config){ return (config.report == null) ? null : config.report.file; } public R getRanker() { return ranker; } public double[] calculateLoss() { return new double[]{lossCalc.calculateLoss(TRAINING, ranker), lossCalc.calculateLoss(VALIDATION, ranker)}; } public void validate(int iter, int pos) { double newScore = eval.calculateAvgAllQueries(ranker, validationSet, pos); if (newScore > maxScore) { maxScore = newScore; } double[] losses = calculateLoss(); report.log(iter, newScore, losses[0], losses[1]); } public abstract void train(); public void trainAndValidate() { for (int i = 1; i <= epochNum; i++) { train(); validate(i, evalK); } report.close(); try { if(!config.nomodel) ranker.writeModel(config, modelFile); } catch (IOException e) { throw new RuntimeException(e); } } public static class TrainerFactory { /** * This returns the appropriate implementation of Trainer depending on the algorithm. * @param trainingSet The QuerySet containing the data to be used for training. * @param validationSet The QuerySet containing the data to be used for validation. * @param configFile The Config file containing parameters needed for Ranker class. * @param override Set another Config that overrides configFile. * @return new class which implements trainer. */ public static AbstractTrainer getTrainer(QuerySet trainingSet, QuerySet validationSet, String configFile, Config override) { String algorithm; ObjectMapper mapper = new ObjectMapper(); mapper.disable(JsonParser.Feature.AUTO_CLOSE_SOURCE); try{ Reader reader = new BufferedReader(new InputStreamReader(new FileInputStream(configFile))); reader.mark(8192); //default buffer size... Map model = mapper.readValue(reader, Map.class); algorithm = ((String)model.get("algorithm")).toLowerCase(); reader.reset(); //Note: if the config file exceeds buffer size, this will not work. return getTrainer(algorithm, trainingSet, validationSet, reader, override); } catch (IOException e){ throw new IllegalArgumentException(e); } } /** * This returns the appropriate implementation of Trainer depending on the algorithm. * @param trainingSet The QuerySet containing the data to be used for training. * @param validationSet The QuerySet containing the data to be used for validation. * @param reader The Config Reader containing parameters needed for Ranker class. * @param override Set another Config that overrides reader Config. * @return new class which implements trainer. */ public static AbstractTrainer getTrainer(QuerySet trainingSet, QuerySet validationSet, Reader reader, Config override) { String algorithm; ObjectMapper mapper = new ObjectMapper(); mapper.disable(JsonParser.Feature.AUTO_CLOSE_SOURCE); try{ Map model = mapper.readValue(reader, Map.class); algorithm = ((String)model.get("algorithm")).toLowerCase(); reader.reset(); return getTrainer(algorithm, trainingSet, validationSet, reader, override); } catch (IOException e){ throw new IllegalArgumentException(e); } } /** * This returns the appropriate implementation of Trainer depending on the algorithm. * @param algorithm Algorithm/implementation to be used. * @param trainingSet The QuerySet containing the data to be used for training. * @param validationSet The QuerySet containing the data to be used for validation. * @param reader The Config Reader containing parameters needed for Ranker class. * @param override Set another Config that overrides reader Config. * @return new class which implements trainer. */ public static AbstractTrainer getTrainer(String algorithm, QuerySet trainingSet, QuerySet validationSet, Reader reader, Config override) { List<Query> training = trainingSet.getQueries(); List<Query> validation = validationSet.getQueries(); try{ switch (algorithm.toLowerCase()) { case "prank": { Config config = Config.getConfig(reader, Config.ConfigType.BASIC); config.overrideBy(override); return new PRankTrainer(training, validation, config); } case "oap": { OAPBPMTrainer.OAPBPMConfig config = Config.getConfig(reader, Config.ConfigType.OAP); config.overrideBy(override); return new OAPBPMTrainer(training, validation, config); } case "ranknet": { MLPTrainer.MLPConfig config = Config.getConfig(reader, Config.ConfigType.MLP); config.overrideBy(override); return new RankNetTrainer(training, validation, config); } case "franknet": { MLPTrainer.MLPConfig config = Config.getConfig(reader, Config.ConfigType.MLP); config.overrideBy(override); return new FRankTrainer(training, validation, config); } case "lambdarank": { MLPTrainer.MLPConfig config = Config.getConfig(reader, Config.ConfigType.MLP); config.overrideBy(override); return new LambdaRankTrainer(training, validation, config); } case "nnrank": { MLPTrainer.MLPConfig config = Config.getConfig(reader, Config.ConfigType.MLP); List<Map<String, Object>> layers = config.getReqArrayParams(config.params, "layers"); Map<String, Object> outputLayer = new HashMap<>(); outputLayer.put("num", QuerySet.findMaxLabel(training)); outputLayer.put("activator", "sigmoid"); layers.add(outputLayer); config.overrideBy(override); return new NNRankTrainer(training, validation, config); } case "sortnet": { MLPTrainer.MLPConfig config = Config.getConfig(reader, Config.ConfigType.MLP); config.overrideBy(override); return new SortNetTrainer(training, validation, config); } case "listnet": { MLPTrainer.MLPConfig config = Config.getConfig(reader, Config.ConfigType.MLP); config.overrideBy(override); return new ListNetTrainer(training, validation, config); } case "lambdamart": { Ensemble.TreeConfig config = Config.getConfig(reader, Config.ConfigType.TREE); config.overrideBy(override); return new LambdaMartTrainer(training, validation, config); } case "rankboost": { RankBoost.RankBoostConfig config = Config.getConfig(reader, Config.ConfigType.BOOSTING); config.overrideBy(override); return new RankBoostTrainer(training, validation, config); } case "adaboost": { RankBoost.RankBoostConfig config = Config.getConfig(reader, Config.ConfigType.BOOSTING); config.overrideBy(override); return new AdaBoostTrainer(training, validation, config); } case "ranksvm": { AbstractSVM.SVMConfig config = Config.getConfig(reader, Config.ConfigType.SVM); config.overrideBy(override); return new RankSVMTrainer(training, validation, config); } default: throw new IllegalArgumentException(); } } finally { try { if(reader != null) reader.close(); } catch (IOException ignored) { } } } } }