SparkXGBoost is a Spark implementation of gradient boosted trees that uses a second-order approximation of an arbitrary user-defined loss function. It is inspired by the XGBoost project.
SparkXGBoost is distributed under the Apache License 2.0.
The XGBoost team has a fantastic introduction to gradient boosted trees.
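To show what the second-order approximation buys, here is a small sketch that follows the XGBoost formulation SparkXGBoost is inspired by. It is not SparkXGBoost's own code. For a fixed tree structure, each leaf weight has a closed form built from the sum G of the first derivatives and the sum H of the second derivatives of the loss over the examples in that leaf: w* = -G / (H + lambda). Here lambda is an L2 penalty on leaf weights, which is an assumption of this sketch, and the helper names are hypothetical.

```scala
object SecondOrderSketch {
  // Hypothetical helper: closed-form leaf weight from the second-order
  // Taylor expansion of the loss, w* = -G / (H + lambda).
  def leafWeight(g: Seq[Double], h: Seq[Double], lambda: Double): Double =
    -g.sum / (h.sum + lambda)

  // Hypothetical helper: reduction of the approximated objective
  // G * w + 0.5 * (H + lambda) * w^2 at w = w*, i.e. 0.5 * G^2 / (H + lambda).
  def objectiveReduction(g: Seq[Double], h: Seq[Double], lambda: Double): Double = {
    val gSum = g.sum
    0.5 * gSum * gSum / (h.sum + lambda)
  }

  def main(args: Array[String]): Unit = {
    // Square loss 0.5 * (y - f)^2 evaluated at f = 0: g = f - y = -y, h = 1.
    val labels = Seq(1.0, 2.0, 3.0)
    val g = labels.map(y => 0.0 - y)
    val h = labels.map(_ => 1.0)
    println(leafWeight(g, h, lambda = 0.0)) // 2.0, the mean label
    println(leafWeight(g, h, lambda = 1.0)) // 1.5, shrunk toward 0
  }
}
```

Only the two derivative sums per leaf are needed. This is why a user-supplied loss only has to provide its first and second derivatives.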
The current version of SparkXGBoost supports supervised learning with gradient boosted trees using a second-order approximation of an arbitrary user-defined loss function. SparkXGBoost ships with the following Loss classes:

- SquareLoss for linear (normal) regression
- LogisticLoss for binary classification
- PoissonLoss for Poisson regression of count data

To avoid overfitting, SparkXGBoost employs the following regularization methods:
SparkXGBoost can learn multiple tree nodes in a single pass over the training data, which improves efficiency.
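The single-pass idea can be sketched as follows. This is a simplified, hypothetical illustration that uses plain Scala collections instead of RDDs, and it is not SparkXGBoost's implementation. Each example records the node it currently belongs to, and one traversal accumulates the gradient and Hessian sums for every (node, feature, bin) combination. The split statistics for all open nodes therefore come out of a single scan. The Example class and the pre-binned feature values are assumptions of the sketch.

```scala
object OnePassStatsSketch {
  // Hypothetical training example: the node it currently sits in,
  // pre-binned feature values, and its first/second loss derivatives.
  final case class Example(node: Int, bins: Vector[Int], g: Double, h: Double)

  // Key: (node id, feature index, bin index); value: (sum of g, sum of h).
  type Stats = Map[(Int, Int, Int), (Double, Double)]

  // A single traversal of the data collects statistics for all nodes at once.
  def aggregate(data: Seq[Example]): Stats =
    data.foldLeft(Map.empty[(Int, Int, Int), (Double, Double)]) { (acc, ex) =>
      ex.bins.zipWithIndex.foldLeft(acc) { case (m, (bin, feature)) =>
        val key = (ex.node, feature, bin)
        val (gSum, hSum) = m.getOrElse(key, (0.0, 0.0))
        m.updated(key, (gSum + ex.g, hSum + ex.h))
      }
    }

  def main(args: Array[String]): Unit = {
    val data = Seq(
      Example(node = 1, bins = Vector(0, 1), g = -1.0, h = 1.0),
      Example(node = 1, bins = Vector(0, 0), g = -2.0, h = 1.0),
      Example(node = 2, bins = Vector(1, 1), g = 0.5, h = 1.0)
    )
    val stats = aggregate(data)
    println(stats((1, 0, 0))) // (-3.0,2.0): node 1, feature 0, bin 0
    println(stats((2, 1, 1))) // (0.5,1.0): node 2, feature 1, bin 1
  }
}
```

On a cluster, the same fold would typically run per partition and the partial maps would then be merged. This sketch does not model that step.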
SparkXGBoost implements the Spark ML Pipeline API, allowing you to easily run a sequence of algorithms to process and learn from data.
SparkXGBoostRegressor and SparkXGBoostRegressionModel are the predictor and model for continuous labels. SparkXGBoostClassifier and SparkXGBoostClassificationModel are the predictor and model for categorical labels.

In the constructors of SparkXGBoostRegressor and SparkXGBoostClassifier, users need to supply an instance of the Loss class to define the loss function and its derivatives. SparkXGBoost currently comes with SquareLoss for linear (normal) regression, LogisticLoss for binary classification and PoissonLoss for Poisson regression of count data. Additional loss functions can be defined by subclassing Loss.
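```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

abstract class Loss {
  // First derivative of the loss with respect to the score f
  def diff1(label: Double, f: Double): Double
  // Second derivative of the loss with respect to the score f
  def diff2(label: Double, f: Double): Double
  // Generate the prediction from the score produced by the tree ensemble:
  // for regression, the prediction is the predicted label;
  // for classification, the prediction is the class probability
  def toPrediction(score: Double): Double
  // Calculate the initial bias from the training data
  def getInitialBias(input: RDD[LabeledPoint]): Double
}
```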
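The following sketch shows what a custom loss has to provide. SparkXGBoost's Loss takes an RDD[LabeledPoint] in getInitialBias. To keep the sketch runnable without Spark, it uses a hypothetical LocalLoss trait with the same four methods, except that getInitialBias takes a plain Seq[Double] of labels. A real subclass of Loss would compute the same quantities over the RDD. The loss shown is a Poisson loss with a log link, exp(f) - y * f. This is one common choice and is not necessarily how the bundled PoissonLoss is written. The sketch also assumes that the ensemble score is the initial bias plus the sum of the tree outputs.

```scala
object CustomLossSketch {
  // Hypothetical stand-in for SparkXGBoost's Loss: same methods, but
  // getInitialBias takes plain labels instead of RDD[LabeledPoint].
  trait LocalLoss {
    def diff1(label: Double, f: Double): Double
    def diff2(label: Double, f: Double): Double
    def toPrediction(score: Double): Double
    def getInitialBias(labels: Seq[Double]): Double
  }

  // Poisson loss with log link: l(y, f) = exp(f) - y * f.
  object LogLinkPoissonLoss extends LocalLoss {
    def diff1(label: Double, f: Double): Double = math.exp(f) - label
    def diff2(label: Double, f: Double): Double = math.exp(f)
    // The score is on the log scale, so the predicted mean count is exp(score).
    def toPrediction(score: Double): Double = math.exp(score)
    // Constant score minimizing the total loss: f = log(mean label).
    def getInitialBias(labels: Seq[Double]): Double =
      math.log(labels.sum / labels.size)
  }

  def main(args: Array[String]): Unit = {
    val labels = Seq(1.0, 2.0, 3.0)
    val bias = LogLinkPoissonLoss.getInitialBias(labels)
    // At the initial bias the first derivatives sum to (approximately) zero.
    val gradSum = labels.map(y => LogLinkPoissonLoss.diff1(y, bias)).sum
    println(f"bias = $bias%.4f, gradient sum = $gradSum%.4f")
    // Assumed ensemble score: initial bias plus the sum of hypothetical tree outputs.
    val treeOutputs = Seq(0.1, -0.05)
    println(LogLinkPoissonLoss.toPrediction(bias + treeOutputs.sum))
  }
}
```

diff2 is strictly positive here, which keeps the H + lambda denominators of the second-order leaf weights well behaved.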
Please see the example below for typical usage.
trainingData is a DataFrame with the labels stored in a column named "label" and the feature vectors stored in a column named "features". Similarly, testData is a DataFrame with the feature vectors stored in a column named "features".
Note that the feature vectors have to be indexed (here with VectorIndexer) before they reach SparkXGBoost, so that categorical variables are correctly encoded with metadata.
Currently, all categorical variables are assumed to be ordered. Unordered categorical variables can be used for training after being encoded with OneHotEncoder.
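```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorIndexer
// SparkXGBoostRegressor and SquareLoss also need to be imported from the SparkXGBoost package.

// Index the features: those with at most 2 distinct values are treated as categorical.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(2)
  .fit(trainingData)

// Gradient boosted regression trees with the square loss.
val sparkXGBoostRegressor = new SparkXGBoostRegressor(new SquareLoss)
  .setFeaturesCol("indexedFeatures")
  .setMaxDepth(2)
  .setNumTrees(5)

val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, sparkXGBoostRegressor))

// Train on trainingData, then score testData.
val model = pipeline.fit(trainingData)
val prediction = model.transform(testData)
```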
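To see why the assumption that categorical variables are ordered matters, the hypothetical sketch below lists the category groups that a single split can separate. If category indices are treated as ordered, a threshold split of the form index <= k can only separate a prefix of the categories from the rest. After one-hot encoding, a split on one indicator column can isolate any single category. The function names are illustrative and not part of SparkXGBoost.

```scala
object CategoricalSplitSketch {
  // Left-branch category sets reachable with ordered splits "index <= k".
  def orderedSplits(numCategories: Int): Seq[Set[Int]] =
    (0 until numCategories - 1).map(k => (0 to k).toSet)

  // Category sets isolated by one split on a single one-hot indicator column.
  def oneHotSplits(numCategories: Int): Seq[Set[Int]] =
    (0 until numCategories).map(c => Set(c))

  def main(args: Array[String]): Unit = {
    // Three categories with indices 0, 1, 2.
    println(orderedSplits(3)) // left branches {0} and {0, 1}
    println(oneHotSplits(3))  // {0}, {1} and {2}
    // The middle category cannot be isolated by an ordered split.
    println(orderedSplits(3).contains(Set(1))) // false
  }
}
```

If the category codes carry no real order, isolating a category such as {1} requires the one-hot encoding. This is why unordered variables should be encoded with OneHotEncoder first.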
The following parameters can be specified by the setters.
The following parameters can be specified by the setters in SXGBoostModel.
SparkXGBoost has been tested with Spark 1.5.1 and Scala 2.10.
Releases of SparkXGBoost are available on spark-packages.org. Follow the "How to" instructions there for spark-shell, sbt or Maven.
Because SparkXGBoost is under active development, the spark-packages.org release might not always include the latest updates.
To use the latest code, compile it from source.
Step 1: clone the project from GitHub
git clone https://github.com/rotationsymmetry/sparkxgboost.git
Step 2: compile and package the jar using sbt
cd sparkxgboost
sbt clean package
The jar file should then be at target/scala-2.10/sparkxgboost_2.10-x.y.z.jar
Step 3: load it in your Spark project
./spark-shell --jars path/to/sparkxgboost_2.10-x.y.z.jar
For an sbt project, you can instead copy the jar into a lib folder next to src. sbt will then put SparkXGBoost on your class path.

I have the following tentative roadmap for upcoming releases:
0.3
0.4
0.5
0.6
Many thanks for testing SparkXGBoost!
You can file bug reports or make suggestions using GitHub Issues.
If you would like to improve the codebase, please don't hesitate to submit a pull request.