SparkXGBoost

SparkXGBoost is a Spark implementation of gradient boosting trees that uses a second-order approximation of an arbitrary user-defined loss function. SparkXGBoost is inspired by the XGBoost project.

SparkXGBoost is distributed under Apache License 2.0.

What is Gradient Boosting Tree?

The XGBoost team has written a fantastic introduction to gradient boosting trees.
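
As a rough illustration of the second-order idea (not SparkXGBoost's actual code), the sketch below computes the weight of a single leaf from the first and second derivatives of the loss, using the closed form from the XGBoost formulation, w = -G / (H + lambda). The names `leafWeight` and `lambda`, and the choice of half squared error as the loss, are assumptions made for this example.

```scala
// Hypothetical helper: optimal weight of one leaf under a 2nd-order
// approximation of the loss, with an L2 penalty `lambda` on the weight.
//   G = sum of 1st derivatives, H = sum of 2nd derivatives
//   w = -G / (H + lambda)
def leafWeight(
    labels: Seq[Double],
    scores: Seq[Double],
    diff1: (Double, Double) => Double,
    diff2: (Double, Double) => Double,
    lambda: Double): Double = {
  val pairs = labels.zip(scores)
  val g = pairs.map { case (y, f) => diff1(y, f) }.sum
  val h = pairs.map { case (y, f) => diff2(y, f) }.sum
  -g / (h + lambda)
}

// Assumed loss for the example: half squared error, l(y, f) = (y - f)^2 / 2
val squareDiff1 = (y: Double, f: Double) => f - y
val squareDiff2 = (y: Double, f: Double) => 1.0

// Starting from a score of 0 with no regularization, the leaf weight
// is the mean of the labels.
val w = leafWeight(Seq(1.0, 2.0, 3.0), Seq(0.0, 0.0, 0.0), squareDiff1, squareDiff2, lambda = 0.0)
```

Because only `diff1` and `diff2` enter the computation, any loss with well-defined first and second derivatives can be plugged in without changing the tree-building logic.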

Features

The current version of SparkXGBoost supports supervised learning with gradient boosting trees, using a second-order approximation of an arbitrary user-defined loss function. SparkXGBoost ships with the following Loss classes: SquareLoss, LogisticLoss and PoissonLoss.

To avoid overfitting, SparkXGBoost employs the following regularization methods:

SparkXGBoost can process multiple tree nodes in a single pass over the training data, which improves efficiency.
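
One way to picture this: if every training instance carries the ID of the tree node it currently falls in, the derivative sums for all open nodes can be accumulated in a single sweep. The sketch below uses plain Scala collections rather than Spark, and ignores the per-feature statistics that real split finding needs. `Instance`, `NodeStat` and `aggregateByNode` are hypothetical names, not part of SparkXGBoost.

```scala
// Hypothetical per-instance record: the node the instance currently sits in,
// plus the 1st and 2nd derivatives of the loss at its current score.
case class Instance(nodeId: Int, grad: Double, hess: Double)

// Running sums of derivatives for one node.
case class NodeStat(gradSum: Double, hessSum: Double) {
  def add(i: Instance): NodeStat = NodeStat(gradSum + i.grad, hessSum + i.hess)
}

// Single pass over the data: update the statistics of whichever node
// each instance belongs to, so all nodes are handled together.
def aggregateByNode(data: Seq[Instance]): Map[Int, NodeStat] =
  data.foldLeft(Map.empty[Int, NodeStat]) { (acc, i) =>
    val current = acc.getOrElse(i.nodeId, NodeStat(0.0, 0.0))
    acc.updated(i.nodeId, current.add(i))
  }

val stats = aggregateByNode(Seq(
  Instance(nodeId = 1, grad = -1.0, hess = 1.0),
  Instance(nodeId = 2, grad = 0.5, hess = 1.0),
  Instance(nodeId = 1, grad = -2.0, hess = 1.0)
))
```

On a cluster, the same fold would typically be expressed as a keyed aggregation over an RDD, so the data is still scanned only once per tree level.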

Design

SparkXGBoost implements the Spark ML Pipeline API, allowing you to easily run a sequence of algorithms to process and learn from data.

The constructors of SparkXGBoostRegressor and SparkXGBoostClassifier take an instance of the Loss class, which defines the loss function and its derivatives. SparkXGBoost currently comes with SquareLoss for linear (normal) regression, LogisticLoss for binary classification and PoissonLoss for Poisson regression of count data. Users can specify additional loss functions by subclassing Loss.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

abstract class Loss {
  // The 1st derivative of the loss with respect to the score f
  def diff1(label: Double, f: Double): Double
  // The 2nd derivative of the loss with respect to the score f
  def diff2(label: Double, f: Double): Double
  // Generate prediction from the score suggested by the tree ensemble
  // For regression, prediction is the label
  // For classification, prediction is the probability in each class
  def toPrediction(score: Double): Double
  // Calculate the initial bias from the training data
  def getInitialBias(input: RDD[LabeledPoint]): Double
}
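
A custom loss only has to fill in these methods. The sketch below defines a log-cosh regression loss, l(y, f) = log(cosh(f - y)), against a stand-in `LossSketch` class that mirrors the derivative and prediction methods above but omits `getInitialBias` so that it runs without Spark. Both `LossSketch` and `LogCoshLoss` are hypothetical names; a real subclass would extend `Loss` and implement `getInitialBias` as well.

```scala
// Stand-in for the library's Loss class, reduced to its Spark-free methods.
abstract class LossSketch {
  def diff1(label: Double, f: Double): Double
  def diff2(label: Double, f: Double): Double
  def toPrediction(score: Double): Double
}

// Hypothetical custom loss: l(y, f) = log(cosh(f - y))
class LogCoshLoss extends LossSketch {
  // d/df log(cosh(f - y)) = tanh(f - y)
  override def diff1(label: Double, f: Double): Double = math.tanh(f - label)

  // d^2/df^2 log(cosh(f - y)) = 1 - tanh(f - y)^2
  override def diff2(label: Double, f: Double): Double = {
    val t = math.tanh(f - label)
    1.0 - t * t
  }

  // Regression: the score is used directly as the prediction
  override def toPrediction(score: Double): Double = score
}

val loss = new LogCoshLoss
```

The second derivative stays positive for every finite residual, which keeps a second-order update well defined, although it becomes very small for large residuals.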

Please see the Example section for typical usage.

Example

trainingData is a DataFrame with the labels stored in a column named "label" and the feature vectors stored in a column named "features". Similarly, testData is a DataFrame with the feature vectors stored in a column named "features".

Please note that the feature vectors have to be indexed before being fed to the pipeline, so that the categorical variables are correctly encoded with metadata.

Currently, all categorical variables are assumed to be ordered. Unordered categorical variables can be used for training after being encoded with OneHotEncoder.

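  import org.apache.spark.ml.Pipeline
  import org.apache.spark.ml.feature.VectorIndexer
  // SparkXGBoostRegressor and SquareLoss must also be imported from SparkXGBoost.

  // Index the feature vectors so that categorical features carry metadata
  val featureIndexer = new VectorIndexer()
    .setInputCol("features")
    .setOutputCol("indexedFeatures")
    .setMaxCategories(2)
    .fit(trainingData)

  // Gradient boosting regressor with square loss
  val sparkXGBoostRegressor = new SparkXGBoostRegressor(new SquareLoss)
    .setFeaturesCol("indexedFeatures")
    .setMaxDepth(2)
    .setNumTrees(5)

  // Chain indexing and training into one pipeline
  val pipeline = new Pipeline()
    .setStages(Array(featureIndexer, sparkXGBoostRegressor))

  // Fit on the training data, then predict on the test data
  val model = pipeline.fit(trainingData)

  val prediction = model.transform(testData)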

Parameters

The following parameters can be specified by the setters.

The following parameters can be specified by the setters in SXGBoostModel.

Compatibility

SparkXGBoost has been tested with Spark 1.5.1 and Scala 2.10.

Use SparkXGBoost in Your Project

Option 1: spark-packages.org

Releases of SparkXGBoost are available on spark-packages.org. You can follow the "How to" for spark-shell, sbt or maven.

As SparkXGBoost is currently under active development, the spark-packages.org release might not always include the latest updates.

Option 2: Compile

To get the latest code, you can build SparkXGBoost from source.

Step 1: clone the project from GitHub

git clone https://github.com/rotationsymmetry/sparkxgboost.git

Step 2: compile and package the jar using sbt

cd sparkxgboost
sbt clean package

You should be able to find the jar file at target/scala-2.10/sparkxgboost_2.10-x.y.z.jar

Step 3: load it in your Spark project

./spark-shell --jars path/to/sparkxgboost_2.10-x.y.z.jar

Roadmap

I have the following tentative roadmap for the upcoming releases:

0.3

0.4

0.5

0.6

Bugs and Improvements

Many thanks for testing SparkXGBoost!

You can file a bug report or make suggestions using GitHub Issues.

If you would like to improve the codebase, please don't hesitate to submit a pull request.