package com.ibm.aardpfark.spark.ml.regression

import com.ibm.aardpfark.pfa.PredictorResult
import org.apache.spark.ml.regression.LinearRegression

/**
 * Parity suite for Spark ML [[org.apache.spark.ml.regression.LinearRegression]].
 *
 * A model is fitted on a libsvm dataset loaded from `inputPath` (provided by the base suite)
 * and its predictions on that same dataset are used as the expected output. The base suite is
 * presumably responsible for exporting the fitted model to PFA and comparing results.
 */
class LinearRegressionSuite extends SparkRegressorPFASuiteBase[PredictorResult] {

  // Training data; also reused as the scoring data for the parity checks.
  val dataset = spark.read.format("libsvm").load(inputPath)
  // Estimator with default parameters. Treated as read-only: tests must not call setters on it.
  val lr = new LinearRegression()
  override val sparkTransformer = lr.fit(dataset)
  val result = sparkTransformer.transform(dataset)

  // JSON rows of the scored data after `withColumnAsArray` has been applied to the features column.
  override val input = withColumnAsArray(result, lr.getFeaturesCol).toJSON.collect()
  // JSON rows holding only the prediction column of the default model.
  override val expectedOutput = result.select(lr.getPredictionCol).toJSON.collect()

  // Additional tests
  test("LinearRegression w/o fitIntercept") {
    // Spark ML setters mutate the estimator they are called on, so configure a fresh instance
    // here instead of the shared `lr`; otherwise `lr` would keep fitIntercept = false after
    // this test and later uses of it would depend on test execution order.
    val noInterceptLr = new LinearRegression().setFitIntercept(false)
    val noInterceptModel = noInterceptLr.fit(dataset)
    val noInterceptResult = noInterceptModel.transform(dataset)
    val noInterceptExpected =
      noInterceptResult.select(noInterceptLr.getPredictionCol).toJSON.collect()

    // The suite-level `input` is reused: it carries the same feature rows as `dataset`.
    parityTest(noInterceptModel, input, noInterceptExpected)
  }

}