package org.apache.spark.ml.regression.examples

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.GaussianProcessRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Common scaffolding for the Gaussian process regression examples: it provides a local
 * Spark session and a cross-validation helper that checks the achieved RMSE.
 */
trait GPExample {
  /**
   * Application name handed to the Spark session builder.
   *
   * It is read while this trait is being initialised (see `spark`), so implement it as a
   * `def`; a plain `val` in the implementing class would not be assigned yet at that point.
   */
  def name: String

  /** Spark session for the example, built (or reused) with the `local[4]` master. */
  val spark: SparkSession = SparkSession.builder().appName(name).master("local[4]").getOrCreate()

  /**
   * Runs 10-fold cross-validation of `gp` on `instances` and checks the resulting metric
   * against `expectedRMSE`.
   *
   * @param gp           the Gaussian process regression estimator to evaluate
   * @param instances    the data set the folds are drawn from
   * @param expectedRMSE exclusive upper bound the cross-validated RMSE has to stay below
   * @return the cross-validated RMSE
   * @throws AssertionError if the RMSE is not below `expectedRMSE`
   */
  def cv(gp: GaussianProcessRegression, instances: DataFrame, expectedRMSE: Double): Double = {
    // A default RegressionEvaluator is used; its default metric is RMSE.
    val validator = new CrossValidator()
      .setEstimator(gp)
      .setEvaluator(new RegressionEvaluator())
      .setEstimatorParamMaps(new ParamGridBuilder().build())
      .setNumFolds(10)

    // The empty grid builds a single param map, so the first average metric is the only one.
    val rmse = validator.fit(instances).avgMetrics.head
    println("RMSE: " + rmse)
    assert(rmse < expectedRMSE, s"RMSE $rmse is not below the expected bound $expectedRMSE")
    rmse
  }
}