org.apache.spark.mllib.tree.configuration.Algo Scala Examples

The following examples show how to use org.apache.spark.mllib.tree.configuration.Algo. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: MLLibRandomForest.scala    From reforest   with Apache License 2.0 5 votes vote down vote up
package reforest.example

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.{Algo, QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Entropy
import org.apache.spark.mllib.util.MLUtils
import reforest.rf.feature.RFStrategyFeatureSQRT
import reforest.rf.parameter._
import reforest.util.CCUtil

import scala.util.Random

/**
 * Trains an MLlib random forest classifier on a LibSVM dataset and prints the
 * elapsed time and the test error.
 *
 * The run configuration (dataset path, number of features/classes/trees, depth,
 * bin number and Spark settings) is assembled in code through `RFParameterBuilder`.
 */
object MLLibRandomForest {

  /**
   * Entry point.
   *
   * @param args command-line arguments; not used by this example
   */
  def main(args: Array[String]): Unit = {

    val property = RFParameterBuilder.apply
      .addParameter(RFParameterType.Dataset, "data/sample-covtype.libsvm")
      .addParameter(RFParameterType.NumFeatures, 54)
      .addParameter(RFParameterType.NumClasses, 10)
      .addParameter(RFParameterType.NumTrees, 100)
      .addParameter(RFParameterType.Depth, Array(10))
      .addParameter(RFParameterType.BinNumber, Array(8))
      .addParameter(RFParameterType.SparkMaster, "local[4]")
      .addParameter(RFParameterType.SparkCoresMax, 4)
      .addParameter(RFParameterType.SparkPartition, 4 * 4)
      .addParameter(RFParameterType.SparkExecutorMemory, "4096m")
      .addParameter(RFParameterType.SparkExecutorInstances, 1)
      .build

    val sc = CCUtil.getSparkContext(property)
    // Always release the SparkContext, even when loading or training fails.
    try {
      sc.setLogLevel("error")

      // The measured interval spans from before the dataset is loaded until training returns.
      val timeStart = System.currentTimeMillis()
      val data = MLUtils.loadLibSVMFile(sc, property.dataset, property.numFeatures, property.sparkCoresMax * 2)

      // Three-way split with fixed seed 0; only the first (60%) and the third (20%) parts are used.
      val splits = data.randomSplit(Array(0.6, 0.2, 0.2), 0)
      val (trainingData, testData) = (splits(0), splits(2))

      // Empty map: no feature is declared categorical.
      // To declare some, map feature index -> number of categories,
      // e.g. Array.tabulate(200)(i => (i, 5)).toMap
      val categoricalFeaturesInfo = Map[Int, Int]()
      val featureSubsetStrategy = "sqrt"

      val strategy = new Strategy(
        Algo.Classification,
        Entropy,
        property.getMaxDepth,
        property.numClasses,
        property.getMaxBinNumber,
        QuantileStrategy.Sort,
        categoricalFeaturesInfo,
        1) // minInstancesPerNode

      // The seed is drawn at random, so the trained forest differs between runs.
      val model = RandomForest.trainClassifier(
        trainingData, strategy, property.getMaxNumTrees, featureSubsetStrategy, Random.nextInt())
      val timeEnd = System.currentTimeMillis()

      val labelAndPreds = testData.map { point =>
        (point.label, model.predict(point.features))
      }

      // Test error = fraction of test points whose predicted label differs from the true label.
      val misclassified = labelAndPreds.filter { case (label, prediction) => label != prediction }.count()
      val testErr = misclassified.toDouble / testData.count()

      println(s"Time: ${timeEnd - timeStart}")
      println(s"Test Error = $testErr")
      if (property.outputTree) {
        println(s"Learned classification forest model:\n${model.toDebugString}")
      }
    } finally {
      sc.stop()
    }
  }
}
Example 2
Source File: MLLibRandomForestFromFile.scala    From reforest   with Apache License 2.0 5 votes vote down vote up
package reforest.example

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.{Algo, QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Entropy
import org.apache.spark.mllib.util.MLUtils
import reforest.rf.feature.RFStrategyFeatureSQRT
import reforest.rf.parameter._
import reforest.util.{CCUtil, CCUtilIO}

import scala.util.Random

/**
 * Trains an MLlib random forest classifier using a run configuration read through
 * `RFParameterFromFile`, then logs the accuracy and prints the elapsed time and test error.
 */
object MLLibRandomForestFromFile {

  /**
   * Entry point.
   *
   * @param args command-line arguments; `args(0)` is handed to `RFParameterFromFile`
   * @throws IllegalArgumentException if no argument is supplied
   */
  def main(args: Array[String]): Unit = {
    // Fail with a readable message instead of a bare ArrayIndexOutOfBoundsException.
    require(args.nonEmpty, "Usage: MLLibRandomForestFromFile <configuration-file>")

    val property = RFParameterFromFile(args(0)).applyAppName("MLLib")

    val sc = CCUtil.getSparkContext(property)
    // Always release the SparkContext, even when loading or training fails.
    try {
      sc.setLogLevel("error")

      // The measured interval spans from before the dataset is loaded until training returns.
      val timeStart = System.currentTimeMillis()
      val data = MLUtils.loadLibSVMFile(sc, property.dataset, property.numFeatures, property.sparkCoresMax * 2)

      // 70/30 train/test split with fixed seed 0.
      val splits = data.randomSplit(Array(0.7, 0.3), 0)
      val (trainingData, testData) = (splits(0), splits(1))

      // Empty map: no feature is declared categorical.
      // To declare some, map feature index -> number of categories,
      // e.g. Array.tabulate(200)(i => (i, 5)).toMap
      val categoricalFeaturesInfo = Map[Int, Int]()
      val featureSubsetStrategy = "sqrt"

      val strategy = new Strategy(
        Algo.Classification,
        Entropy,
        property.getMaxDepth,
        property.numClasses,
        property.getMaxBinNumber,
        QuantileStrategy.Sort,
        categoricalFeaturesInfo,
        1) // minInstancesPerNode

      // The seed is drawn at random, so the trained forest differs between runs.
      val model = RandomForest.trainClassifier(
        trainingData, strategy, property.getMaxNumTrees, featureSubsetStrategy, Random.nextInt())
      val timeEnd = System.currentTimeMillis()

      val labelAndPreds = testData.map { point =>
        (point.label, model.predict(point.features))
      }

      // Test error = fraction of test points whose predicted label differs from the true label.
      val misclassified = labelAndPreds.filter { case (label, prediction) => label != prediction }.count()
      val testErr = misclassified.toDouble / testData.count()

      CCUtilIO.logACCURACY(property, 1 - testErr, timeEnd - timeStart)
      println(s"Time: ${timeEnd - timeStart}")
      println(s"Test Error = $testErr")
      if (property.outputTree) {
        println(s"Learned classification forest model:\n${model.toDebugString}")
      }
    } finally {
      sc.stop()
    }
  }
}
Example 3
Source File: GradientBoostedTreesExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.examples.mllib

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.mllib.util.MLUtils

    // NOTE(review): `sc` is not defined in this excerpt; presumably a SparkContext
    // created earlier in the enclosing main method — confirm against the original file.

    // Load the data.
    val data = MLUtils.loadLibSVMFile(sc, "../data/mllib/rf_libsvm_data.txt")
    // Randomly split the data into two parts: 70% for training, 30% for testing.
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))
    // Create a boosting strategy for classification and set the number of iterations to 3.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 3
    // Gradient-boosted trees combine several decision trees into one model.
    val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
    // Evaluate the model on the test instances. The test error is the fraction of
    // misclassified points, so a wrong prediction contributes 1.0 and a correct one 0.0.
    val testErr = testData.map { point =>
      val prediction = model.predict(point.features)
      if (point.label == prediction) 0.0 else 1.0
    }.mean()
    // Inspect the model.
    println("Test Error = " + testErr)
    println("Learned classification GBT model:\n" + model.toDebugString)
  }
}