package com.databricks.spark.sql.perf.mllib.classification

import org.apache.spark.ml.{Estimator, PipelineStage}
import org.apache.spark.ml.classification.RandomForestClassifier

import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._


object RandomForestClassification extends BenchmarkAlgorithm with TreeOrForestClassifier {

  /**
   * Creates the Spark ML random-forest classifier benchmarked by this algorithm.
   *
   * The estimator is configured from the benchmark parameters:
   *  - maximum tree depth comes from `params.depth`;
   *  - number of trees comes from `params.maxIter` (reused as the tree count —
   *    NOTE(review): presumably there is no dedicated num-trees parameter; confirm);
   *  - the random seed comes from `ctx.seed()`.
   *
   * @param ctx benchmark context supplying the parameters and seed
   * @return an unfitted `RandomForestClassifier`, returned as a `PipelineStage`
   */
  override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
    val p = ctx.params
    // TODO: subsamplingRate, featureSubsetStrategy
    // TODO: cacheNodeIds, checkpoint?
    val forest = new RandomForestClassifier()
    forest.setMaxDepth(p.depth)
    forest.setNumTrees(p.maxIter)
    forest.setSeed(ctx.seed())
    forest
  }
}