ml.dmlc.xgboost4j.java.Booster Java Examples

The following examples show how to use ml.dmlc.xgboost4j.java.Booster. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: XGBoostModel.java    From samantha with MIT License 6 votes vote down vote up
/**
 * Stores the given booster on this model and rebuilds the feature-importance table from it.
 *
 * <p>Every key reported by the booster has its first character dropped; the remainder is parsed
 * as an integer index and resolved to a feature name through {@code indexSpace}.
 *
 * @param booster the booster to install on this model
 * @throws BadRequestException wrapping any {@link XGBoostError} raised inside the block
 */
public void setXGBooster(Booster booster) {
    this.booster = booster;
    try {
        final Map<String, Integer> rawScores = booster.getFeatureScore(null);
        featureScores = new HashMap<>();
        for (Map.Entry<String, Integer> scored : rawScores.entrySet()) {
            // Skip the one-character prefix in front of the numeric index.
            final String indexText = scored.getKey().substring(1);
            final int featureIndex = Integer.parseInt(indexText);
            final String featureName =
                    (String)indexSpace.getKeyForIndex(TreeKey.TREE.get(), featureIndex);
            featureScores.put(featureName, scored.getValue());
        }
        logger.info("Number of non-zero importance features: {}", featureScores.size());
        logger.info("Feature importance: {}", Json.toJson(featureScores).toString());
    } catch (XGBoostError cause) {
        throw new BadRequestException(cause);
    }
}
 
Example #2
Source File: XGBoostTrainUDTF.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Trains a booster on {@code dtrain} for exactly {@code round} iterations.
 *
 * <p>Before each iteration, progress is reported and the iteration counter (if a reporter is
 * available) is set to the 1-based iteration number.
 *
 * @param dtrain training matrix
 * @param round number of boosting iterations to run
 * @param params parameters handed to {@code XGBoostUtils.createBooster}
 * @param reporter progress reporter; may be {@code null}
 * @return the booster after the last update
 */
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnegative final int round,
        @Nonnull final Map<String, Object> params, @Nullable final Reporter reporter)
        throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
        InstantiationException, XGBoostError {
    final Counters.Counter iterCounter;
    if (reporter == null) {
        iterCounter = null;
    } else {
        iterCounter = reporter.getCounter("hivemall.XGBoostTrainUDTF$Counter", "iteration");
    }

    final Booster trained = XGBoostUtils.createBooster(dtrain, params);
    int iter = 0;
    while (iter < round) {
        reportProgress(reporter);
        setCounterValue(iterCounter, iter + 1);

        trained.update(dtrain, iter);
        iter++;
    }
    return trained;
}
 
Example #3
Source File: XGBoostBatchPredictUDTF.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Flushes every non-empty row buffer through its model, then releases all cached models.
 *
 * <p>All boosters held in {@code mapToModel} are disposed in a {@code finally} block, so they are
 * released even when their buffer is already empty or when flushing an earlier model fails.
 * If {@code process} never ran (both maps still {@code null}) this is a no-op.
 *
 * @throws HiveException if flushing a buffered batch fails
 */
@Override
public void close() throws HiveException {
    if (rowBuffer == null || mapToModel == null) {
        // process() lazily creates both maps; nothing was buffered or loaded.
        this.rowBuffer = null;
        this.mapToModel = null;
        return;
    }
    try {
        for (Entry<String, List<LabeledPointWithRowId>> e : rowBuffer.entrySet()) {
            List<LabeledPointWithRowId> rowBatch = e.getValue();
            if (rowBatch.isEmpty()) {
                continue;
            }
            final Booster model = Objects.requireNonNull(mapToModel.get(e.getKey()));
            predictAndFlush(model, rowBatch);
        }
    } finally {
        // Dispose every cached model, including those with nothing left to flush.
        for (Booster model : mapToModel.values()) {
            XGBoostUtils.close(model);
        }
        this.rowBuffer = null;
        this.mapToModel = null;
    }
}
 
Example #4
Source File: MLXGBoost.java    From RecSys2018 with Apache License 2.0 6 votes vote down vote up
/**
 * Wraps loading of an XGBoost model file in an {@code Async} holder.
 *
 * <p>The supplier loads the model and, when {@code nthread} is positive, sets the
 * {@code "nthread"} parameter on it. On an {@link XGBoostError} the stack trace is printed, any
 * booster that was already loaded is disposed, and the supplier yields {@code null}.
 * {@code Booster::dispose} is passed as the second constructor argument of {@code Async}.
 *
 * @param modelFile path of the model file to load
 * @param nthread thread count to set on the booster; values {@code <= 0} leave it untouched
 * @return the asynchronous model holder
 */
public static Async<Booster> asyncModel(final String modelFile,
		final int nthread) {
	// load xgboost model
	final Async<Booster> modelAsync = new Async<Booster>(() -> {
		Booster bst = null;
		try {
			bst = XGBoost.loadModel(modelFile);
			if (nthread > 0) {
				bst.setParam("nthread", nthread);
			}
			return bst;
		} catch (XGBoostError e) {
			e.printStackTrace();
			if (bst != null) {
				// setParam failed after a successful load: free the native handle
				// instead of leaking it, since the caller only ever sees null.
				bst.dispose();
			}
			return null;
		}
	}, Booster::dispose);
	return modelAsync;
}
 
Example #5
Source File: XGBoostUtils.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Converts a booster into a compressed {@code Text} value.
 *
 * @param booster the booster to serialize
 * @return a {@code Text} built from {@code IOUtils.toCompressedText} applied to the booster bytes
 * @throws HiveException wrapping any {@link Throwable} raised during serialization
 */
@Nonnull
public static Text serializeBooster(@Nonnull final Booster booster) throws HiveException {
    try {
        final byte[] rawModel = booster.toByteArray();
        final byte[] compressed = IOUtils.toCompressedText(rawModel);
        return new Text(compressed);
    } catch (Throwable cause) {
        throw new HiveException("Failed to serialize a booster", cause);
    }
}
 
Example #6
Source File: XGBoostModel.java    From samantha with MIT License 5 votes vote down vote up
/**
 * Reads a Java-serialized {@link Booster} from {@code modelFile} and installs it on this model.
 *
 * <p>Both the file stream and the object stream are closed when the read finishes or fails.
 *
 * <p>NOTE(review): this uses Java-native deserialization; it should only ever be pointed at
 * model files from a trusted source.
 *
 * @param modelFile path of the serialized booster
 * @throws BadRequestException wrapping any I/O or class-resolution failure
 */
public void loadModel(String modelFile) {
    try (FileInputStream fileIn = new FileInputStream(modelFile);
            ObjectInputStream inputStream = new ObjectInputStream(fileIn)) {
        this.booster = (Booster) inputStream.readUnshared();
    } catch (IOException | ClassNotFoundException e) {
        throw new BadRequestException(e);
    }
}
 
Example #7
Source File: XGBoostMethod.java    From samantha with MIT License 5 votes vote down vote up
/**
 * Trains a booster on {@code learningData} and installs it on the supplied model.
 *
 * <p>When {@code validData} is non-null it is registered as the {@code "Validation"} watch set.
 *
 * @param model the model to receive the trained booster; cast to {@code XGBoostModel}
 * @param learningData training data
 * @param validData optional validation data; may be {@code null}
 * @throws BadRequestException wrapping any {@link XGBoostError} raised inside the block
 */
public void learn(PredictiveModel model, LearningData learningData, LearningData validData) {
    final Map<String, DMatrix> evalSets = new HashMap<>();
    try {
        final DMatrix trainMatrix = new DMatrix(new XGBoostIterator(learningData), null);
        if (validData != null) {
            final DMatrix validMatrix = new DMatrix(new XGBoostIterator(validData), null);
            evalSets.put("Validation", validMatrix);
        }
        final Booster trained = XGBoost.train(trainMatrix, params, round, evalSets, null, null);
        ((XGBoostModel) model).setXGBooster(trained);
    } catch (XGBoostError cause) {
        throw new BadRequestException(cause);
    }
}
 
Example #8
Source File: XGBoostBatchPredictUDTF.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Runs the model over the buffered rows, forwards the predictions, and empties the buffer.
 *
 * @param model the booster used for prediction
 * @param rowBatch buffered rows; cleared after the predictions are forwarded
 * @throws HiveException if prediction raises an {@link XGBoostError}
 */
private void predictAndFlush(@Nonnull final Booster model,
        @Nonnull final List<LabeledPointWithRowId> rowBatch) throws HiveException {
    final float[][] scores = predictRows(model, rowBatch);
    forwardPredicted(rowBatch, scores);
    rowBatch.clear();
}

/** Builds a temporary matrix from the rows, predicts on it, and always closes the matrix. */
private float[][] predictRows(@Nonnull final Booster model,
        @Nonnull final List<LabeledPointWithRowId> rows) throws HiveException {
    DMatrix batchMatrix = null;
    try {
        batchMatrix = XGBoostUtils.createDMatrix(rows);
        return model.predict(batchMatrix);
    } catch (XGBoostError cause) {
        throw new HiveException("Exception caused at prediction", cause);
    } finally {
        XGBoostUtils.close(batchMatrix);
    }
}
 
Example #9
Source File: XGBoostBatchPredictUDTF.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Buffers one input row under its model id and flushes the buffer once it reaches the batch size.
 *
 * <p>The model cache and row buffers are created lazily on the first call. Rows whose second
 * argument ({@code args[1]}) is {@code null} are ignored. A model is deserialized from the
 * fourth argument the first time its id is seen and then cached.
 *
 * @param args the UDTF arguments for one row
 * @throws HiveException if argument access, model deserialization, or a flush fails
 */
@Override
public void process(Object[] args) throws HiveException {
    if (mapToModel == null) {
        this.mapToModel = new HashMap<String, Booster>();
        this.rowBuffer = new HashMap<String, List<LabeledPointWithRowId>>();
    }
    if (args[1] == null) {
        return;
    }

    final Object modelIdArg = nonNullArgument(args, 2);
    final String modelId = PrimitiveObjectInspectorUtils.getString(modelIdArg, modelIdOI);

    // Resolve the booster for this id, loading it on first use.
    Booster booster = mapToModel.get(modelId);
    if (booster == null) {
        final Text serialized = modelOI.getPrimitiveWritableObject(nonNullArgument(args, 3));
        booster = XGBoostUtils.deserializeBooster(serialized);
        mapToModel.put(modelId, booster);
    }

    // Resolve the per-model buffer, creating it on first use.
    List<LabeledPointWithRowId> pending = rowBuffer.get(modelId);
    if (pending == null) {
        pending = new ArrayList<LabeledPointWithRowId>(_batchSize);
        rowBuffer.put(modelId, pending);
    }

    pending.add(parseRow(args));
    final boolean batchFull = pending.size() >= _batchSize;
    if (batchFull) {
        predictAndFlush(booster, pending);
    }
}
 
Example #10
Source File: XGBoostUtils.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Rebuilds a booster from a compressed {@code Text} value.
 *
 * @param model the serialized model; only its first {@code getLength()} bytes are read
 * @return the booster loaded from the decompressed bytes
 * @throws HiveException wrapping any {@link Throwable} raised during deserialization
 */
@Nonnull
public static Booster deserializeBooster(@Nonnull final Text model) throws HiveException {
    try {
        final byte[] decompressed =
                IOUtils.fromCompressedText(model.getBytes(), model.getLength());
        final FastByteArrayInputStream modelStream = new FastByteArrayInputStream(decompressed);
        return XGBoost.loadModel(modelStream);
    } catch (Throwable cause) {
        throw new HiveException("Failed to deserialize a booster", cause);
    }
}
 
Example #11
Source File: XGBoostUtils.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Disposes the given booster, ignoring a {@code null} argument and any failure from
 * {@code dispose()}.
 *
 * @param booster the booster to dispose; may be {@code null}
 */
public static void close(@Nullable final Booster booster) {
    if (booster != null) {
        try {
            booster.dispose();
        } catch (Throwable ignored) {
            // Deliberately best-effort: a failed dispose must not mask the caller's outcome.
        }
    }
}
 
Example #12
Source File: XGBoostUtils.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a {@link Booster} by reflectively invoking its declared
 * {@code (Map, DMatrix[])} constructor after making it accessible.
 *
 * @param matrix the single matrix passed to the constructor inside a one-element array
 * @param params parameters passed to the constructor
 * @return the newly constructed booster
 */
@Nonnull
public static Booster createBooster(@Nonnull DMatrix matrix,
        @Nonnull Map<String, Object> params) throws NoSuchMethodException, XGBoostError,
        IllegalAccessException, InvocationTargetException, InstantiationException {
    final Constructor<Booster> hiddenCtor =
            Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
    hiddenCtor.setAccessible(true);
    final DMatrix[] matrices = {matrix};
    return hiddenCtor.newInstance(params, matrices);
}
 
Example #13
Source File: MLXGBoost.java    From RecSys2018 with Apache License 2.0 5 votes vote down vote up
/**
 * Returns the model's feature scores as an array aligned with {@code featNames}.
 *
 * <p>Each key from the booster's score map has its first character dropped and the rest parsed
 * as the array index. Features absent from the map keep a score of 0.
 *
 * @param model the booster to query
 * @param featNames feature names; only the length is used, to size the result
 * @return one score per entry of {@code featNames}
 * @throws XGBoostError if the booster fails to report its scores
 */
public static int[] getFeatureImportance(final Booster model,
		final String[] featNames) throws XGBoostError {

	final int[] scores = new int[featNames.length];
	// NOTE: features the model never used are missing from the map
	for (Map.Entry<String, Integer> scored : model.getFeatureScore(null).entrySet()) {
		// keys look like f0, f1, ...: strip the prefix to get the index
		final String indexText = scored.getKey().substring(1);
		scores[Integer.parseInt(indexText)] = scored.getValue();
	}
	return scores;
}
 
Example #14
Source File: MLXGBoost.java    From RecSys2018 with Apache License 2.0 5 votes vote down vote up
/**
 * Loads a model and a feature-name file, then returns the features paired with their importance
 * scores, sorted by {@code MLXGBoostFeature.ScoreComparator(true)}.
 *
 * <p>The feature file is read one name per line. The loaded booster is disposed before the
 * method returns, whether it succeeds or fails, so its native memory is not leaked.
 *
 * @param modelFile path of the XGBoost model file
 * @param featureFile path of the text file holding one feature name per line
 * @return the features with their scores, in comparator order
 * @throws Exception if the model or the feature file cannot be read
 */
public static MLXGBoostFeature[] analyzeFeatures(final String modelFile,
		final String featureFile) throws Exception {

	Booster model = XGBoost.loadModel(modelFile);
	try {
		List<String> temp = new LinkedList<String>();
		try (BufferedReader reader = new BufferedReader(
				new FileReader(featureFile))) {
			String line;
			while ((line = reader.readLine()) != null) {
				temp.add(line);
			}
		}

		// get feature importance scores
		String[] featureNames = temp.toArray(new String[0]);
		int[] importances = MLXGBoost.getFeatureImportance(model, featureNames);

		// sort features by their importance
		MLXGBoostFeature[] sortedFeatures = new MLXGBoostFeature[featureNames.length];
		for (int i = 0; i < featureNames.length; i++) {
			sortedFeatures[i] = new MLXGBoostFeature(featureNames[i],
					importances[i]);
		}
		Arrays.sort(sortedFeatures, new MLXGBoostFeature.ScoreComparator(true));

		return sortedFeatures;
	} finally {
		// The booster is only needed for the scores above; release it here.
		model.dispose();
	}
}
 
Example #15
Source File: UtilFns.java    From SmoothNLP with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Loads an XGBoost model from the given address via {@code SmoothNLP.IOAdaptor}.
 *
 * <p>The input stream is closed on every path, including when loading fails.
 *
 * @param modelAddr address of the model, as understood by {@code SmoothNLP.IOAdaptor.open}
 * @return the loaded booster, or {@code null} if anything goes wrong (the exception is printed)
 */
public static Booster loadXgbModel(String modelAddr) {
    try (InputStream modelIS = SmoothNLP.IOAdaptor.open(modelAddr)) {
        return XGBoost.loadModel(modelIS);
    } catch (Exception e) {
        // add proper warnings later
        System.out.println(e);
        return null;
    }
}
 
Example #16
Source File: DependencyGraghEdgeCostTrain.java    From SmoothNLP with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Trains a binary-logistic XGBoost model on CoNLL data and writes it to {@code modelAddr}.
 *
 * <p>The training and dev matrices are both registered as watches. The output stream is closed
 * on every path, and the native booster and matrices are disposed before returning. An
 * {@link XGBoostError} is printed rather than propagated.
 *
 * @param trainFile CoNLL file used for training
 * @param devFile CoNLL file used as the dev watch set
 * @param modelAddr destination of the saved model, as understood by {@code SmoothNLP.IOAdaptor}
 * @param nround number of boosting rounds
 * @param negSampleRate passed to the matrix reader and used as {@code scale_pos_weight}
 * @param earlyStop early-stopping argument forwarded to {@code XGBoost.train}
 * @param nthreads value of the {@code nthread} parameter
 * @throws IOException if reading the data or writing the model fails
 */
public static void trainXgbModel(String trainFile, String devFile, String modelAddr, int nround, int negSampleRate, int earlyStop, int nthreads) throws IOException{
    DMatrix trainMatrix = null;
    DMatrix devMatrix = null;
    Booster booster = null;
    try {
        trainMatrix = readCoNLL2DMatrix(trainFile,negSampleRate);
        devMatrix = readCoNLL2DMatrix(devFile,negSampleRate);

        // Plain maps instead of double-brace initialisation (no anonymous subclasses).
        Map<String, Object> params = new HashMap<String, Object>();
        params.put("nthread", nthreads);
        params.put("max_depth", 16);
        params.put("silent", 0);
        params.put("objective", "binary:logistic");
        params.put("colsample_bytree",0.95);
        params.put("colsample_bylevel",0.95);
        params.put("eta",0.2);
        params.put("subsample",0.95);
        params.put("lambda",0.2);

        params.put("min_child_weight",5);
        params.put("scale_pos_weight",negSampleRate);

        // other parameters
        // "objective" -> "multi:softmax", "num_class" -> "6"

        params.put("eval_metric", "logloss");
        params.put("tree_method","approx");

        Map<String, DMatrix> watches = new HashMap<String, DMatrix>();
        watches.put("train", trainMatrix);
        watches.put("dev",devMatrix);

        booster = XGBoost.train(trainMatrix, params, nround, watches, null, null,null,earlyStop);
        try (OutputStream outstream = SmoothNLP.IOAdaptor.create(modelAddr)) {
            booster.saveModel(outstream);
        }
    }catch(XGBoostError e){
        System.out.println(e);
    } finally {
        // Release native memory held by the booster and both matrices.
        if (booster != null) {
            booster.dispose();
        }
        if (devMatrix != null) {
            devMatrix.dispose();
        }
        if (trainMatrix != null) {
            trainMatrix.dispose();
        }
    }
}
 
Example #17
Source File: MaxEdgeScoreDependencyParser.java    From SmoothNLP with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Loads an XGBoost model from the given address via {@code SmoothNLP.IOAdaptor}.
 *
 * <p>The input stream is closed on every path, including when loading fails.
 *
 * @param modelAddr address of the model, as understood by {@code SmoothNLP.IOAdaptor.open}
 * @return the loaded booster, or {@code null} if anything goes wrong (the exception is printed)
 */
public static Booster loadXgbModel(String modelAddr) {
    try (InputStream modelIS = SmoothNLP.IOAdaptor.open(modelAddr)) {
        return XGBoost.loadModel(modelIS);
    } catch (Exception e) {
        // add proper warnings later
        System.out.println(e);
        return null;
    }
}
 
Example #18
Source File: DependencyGraphRelationshipTagTrain.java    From SmoothNLP with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Trains a multi-class ({@code multi:softprob}) XGBoost model on CoNLL data and writes it to
 * {@code modelAddr}.
 *
 * <p>The number of classes is taken from {@code tag2float.size()}. The training and dev matrices
 * are both registered as watches. The output stream is closed on every path, and the native
 * booster and matrices are disposed before returning. An {@link XGBoostError} is printed rather
 * than propagated.
 *
 * @param trainFile CoNLL file used for training
 * @param devFile CoNLL file used as the dev watch set
 * @param modelAddr destination of the saved model, as understood by {@code SmoothNLP.IOAdaptor}
 * @param nround number of boosting rounds
 * @param earlyStop early-stopping argument forwarded to {@code XGBoost.train}
 * @param nthreads value of the {@code nthread} parameter
 * @throws IOException if reading the data or writing the model fails
 */
public static void trainXgbModel(String trainFile, String devFile, String modelAddr, int nround, int earlyStop,int nthreads ) throws IOException{
    DMatrix trainMatrix = null;
    DMatrix devMatrix = null;
    Booster booster = null;
    try {
        trainMatrix = readCoNLL2DMatrix(trainFile);
        devMatrix = readCoNLL2DMatrix(devFile);

        // Plain maps instead of double-brace initialisation (no anonymous subclasses).
        Map<String, Object> params = new HashMap<String, Object>();
        params.put("nthread", nthreads);
        params.put("max_depth", 12);
        params.put("silent", 0);
        params.put("objective", "multi:softprob");
        params.put("colsample_bytree",0.90);
        params.put("colsample_bylevel",0.90);
        params.put("eta",0.2);
        params.put("subsample",0.95);
        params.put("lambda",1.0);

        // tree methods for regulation
        // (min_child_weight was previously set twice with the same value; once is enough)
        params.put("min_child_weight",5);
        params.put("max_leaves",128);

        // other parameters
        // "objective" -> "multi:softmax", "num_class" -> "6"

        params.put("eval_metric", "merror");
        params.put("tree_method","approx");
        params.put("num_class",tag2float.size());

        Map<String, DMatrix> watches = new HashMap<String, DMatrix>();
        watches.put("train", trainMatrix);
        watches.put("dev",devMatrix);

        booster = XGBoost.train(trainMatrix, params, nround, watches, null, null,null,earlyStop);
        try (OutputStream outstream = SmoothNLP.IOAdaptor.create(modelAddr)) {
            booster.saveModel(outstream);
        }
    }catch(XGBoostError e){
        System.out.println(e);
    } finally {
        // Release native memory held by the booster and both matrices.
        if (booster != null) {
            booster.dispose();
        }
        if (devMatrix != null) {
            devMatrix.dispose();
        }
        if (trainMatrix != null) {
            trainMatrix.dispose();
        }
    }
}
 
Example #19
Source File: MLXGBoost.java    From RecSys2018 with Apache License 2.0 4 votes vote down vote up
/**
 * Convenience overload of {@code asyncModel(String, int)} that passes 0 as the thread count.
 *
 * @param modelFile path of the model file to load
 * @return the asynchronous model holder
 */
public static Async<Booster> asyncModel(final String modelFile) {
	final int unspecifiedThreads = 0;
	return asyncModel(modelFile, unspecifiedThreads);
}
 
Example #20
Source File: XGBoostTrainUDTF.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Builds the training matrix from the accumulated rows, trains a booster, and forwards the
 * serialized model together with a generated model id.
 *
 * <p>When {@code num_early_stopping_rounds} is positive, the rows are shuffled with the
 * configured seed and split into a test slice ({@code validation_ratio} of the rows) and a
 * training slice; otherwise the whole matrix is used for training.
 *
 * @throws HiveException wrapping any {@link Throwable} raised while building, training,
 *         serializing, or forwarding
 */
@Override
public void close() throws HiveException {
    final Reporter reporter = getReporter();

    DMatrix dmatrix = null;
    Booster booster = null;
    try {
        dmatrix = matrixBuilder.buildMatrix(labels.toArray(true));
        // The builder and labels are not used again in this method; drop the references.
        this.matrixBuilder = null;
        this.labels = null;

        final int round = OptionUtils.getInt(params, "num_round");
        final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
        if (earlyStoppingRounds > 0) {
            double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
            long seed = OptionUtils.getLong(params, "seed");

            // Shuffle the row indexes with the configured seed before splitting.
            int numRows = (int) dmatrix.rowNum();
            int[] rows = MathUtils.permutation(numRows);
            ArrayUtils.shuffle(rows, new Random(seed));

            // The first numTest shuffled rows form the test slice; the rest are for training.
            int numTest = (int) (numRows * validationRatio);
            DMatrix dtrain = null, dtest = null;
            try {
                dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
                dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
                booster = train(dtrain, dtest, round, earlyStoppingRounds, params, reporter);
            } finally {
                // The slices are only needed during training.
                XGBoostUtils.close(dtrain);
                XGBoostUtils.close(dtest);
            }
        } else {
            booster = train(dmatrix, round, params, reporter);
        }
        onFinishTraining(booster);

        // Output the built model
        String modelId = generateUniqueModelId();
        Text predModel = XGBoostUtils.serializeBooster(booster);

        logger.info("model_id:" + modelId.toString() + ", size:" + predModel.getLength());
        forward(new Object[] {modelId, predModel});
    } catch (Throwable e) {
        throw new HiveException(e);
    } finally {
        // Always release the full matrix and the booster, on success and on failure.
        XGBoostUtils.close(dmatrix);
        XGBoostUtils.close(booster);
    }
}
 
Example #21
Source File: XGBoostTrainUDTF.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Hook that receives the trained booster; this implementation does nothing.
 *
 * <p>It is {@code protected} and marked {@code @VisibleForTesting}, so subclasses (presumably
 * tests) can override it to inspect the booster.
 *
 * @param booster the trained booster
 */
@VisibleForTesting
protected void onFinishTraining(@Nonnull Booster booster) {}
 
Example #22
Source File: XGBoostModel.java    From zoltar with Apache License 2.0 4 votes vote down vote up
/**
 * Returns XGBoost's {@link Booster}.
 *
 * @return the booster supplied by the concrete implementation
 */
public abstract Booster instance();
 
Example #23
Source File: XGBoostTrainUDTF.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a booster on {@code dtrain}, evaluating on {@code dtest} after every iteration and
 * stopping early once {@code shouldEarlyStop} says so.
 *
 * <p>The best score starts at the worst possible value for the chosen direction
 * ({@code maximize_evaluation_metrics}) and is replaced only by a strictly better score.
 * The booster is returned as it stands after the last executed iteration.
 *
 * @param dtrain training matrix
 * @param dtest matrix evaluated under the name {@code "test"}
 * @param round maximum number of boosting iterations
 * @param earlyStoppingRounds early-stopping threshold passed to {@code shouldEarlyStop}
 * @param params training parameters; also read for {@code maximize_evaluation_metrics}
 * @param reporter progress reporter; may be {@code null}
 * @return the trained booster
 */
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnull final DMatrix dtest,
        @Nonnegative final int round, @Nonnegative final int earlyStoppingRounds,
        @Nonnull final Map<String, Object> params, @Nullable final Reporter reporter)
        throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
        InstantiationException, XGBoostError {
    final Counters.Counter iterCounter;
    if (reporter == null) {
        iterCounter = null;
    } else {
        iterCounter = reporter.getCounter("hivemall.XGBoostTrainUDTF$Counter", "iteration");
    }

    final Booster trained = XGBoostUtils.createBooster(dtrain, params);

    final boolean higherIsBetter =
            OptionUtils.getBoolean(params, "maximize_evaluation_metrics");
    float bestScore = higherIsBetter ? -Float.MAX_VALUE : Float.MAX_VALUE;
    int bestIteration = 0;

    final float[] metricsOut = new float[1];
    for (int iter = 0; iter < round; iter++) {
        reportProgress(reporter);
        setCounterValue(iterCounter, iter + 1);

        trained.update(dtrain, iter);

        final String evalInfo =
                trained.evalSet(new DMatrix[] {dtest}, new String[] {"test"}, iter, metricsOut);
        logger.info(evalInfo);

        // Strict comparison: an equal score does not move the best iteration.
        final float score = metricsOut[0];
        final boolean improved = higherIsBetter ? (score > bestScore) : (score < bestScore);
        if (improved) {
            bestScore = score;
            bestIteration = iter;
        }

        if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
            logger.info(
                String.format("early stopping after %d rounds away from the best iteration",
                    earlyStoppingRounds));
            break;
        }
    }

    return trained;
}