/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

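/**
 * Test for [[OpRandomForestClassifier]]: fits a random forest on a small three-class
 * dataset, checks the resulting predictions, and verifies that Spark parameters set on
 * the OP wrapper are propagated to the underlying Spark [[RandomForestClassifier]].
 */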
@RunWith(classOf[JUnitRunner])
class OpRandomForestClassifierTest extends
  OpEstimatorSpec[Prediction, OpPredictorWrapperModel[RandomForestClassificationModel],
    OpPredictorWrapper[RandomForestClassifier, RandomForestClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpRandomForestClassifier]

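  // Small three-class dataset: a non-nullable real label (classes 0.0, 1.0, 2.0)
  // paired with a dense three-dimensional feature vector per row.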
  lazy val (inputData, rawLabelMulti, featuresMulti) =
    TestFeatureBuilder[RealNN, OPVector]("labelMulti", "featuresMulti",
      Seq(
        (1.0.toRealNN, Vectors.dense(12.0, 4.3, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.0, 0.3, 0.1).toOPVector),
        (2.0.toRealNN, Vectors.dense(1.0, 3.9, 4.3).toOPVector),
        (2.0.toRealNN, Vectors.dense(10.0, 1.3, 0.9).toOPVector),
        (1.0.toRealNN, Vectors.dense(15.0, 4.7, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.5, 0.9, 10.1).toOPVector),
        (1.0.toRealNN, Vectors.dense(11.5, 2.3, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.1, 3.3, 0.1).toOPVector),
        (2.0.toRealNN, Vectors.dense(1.0, 4.0, 4.5).toOPVector),
        (2.0.toRealNN, Vectors.dense(10.0, 1.5, 1.0).toOPVector)
      )
    )

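  // Mark the label as a response so the estimator treats it as the prediction target.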
  val labelMulti = rawLabelMulti.copy(isResponse = true)

  val estimator = new OpRandomForestClassifier().setInput(labelMulti, featuresMulti)

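  // Expected output per input row: the predicted class, the raw per-class vote scores
  // accumulated over the trees in the forest, and the corresponding class probabilities.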
  val expectedResult = Seq(
    Prediction(1.0, Array(0.0, 17.0, 3.0), Array(0.0, 0.85, 0.15)),
    Prediction(0.0, Array(19.0, 0.0, 1.0), Array(0.95, 0.0, 0.05)),
    Prediction(2.0, Array(0.0, 1.0, 19.0), Array(0.0, 0.05, 0.95)),
    Prediction(2.0, Array(1.0, 2.0, 17.0), Array(0.05, 0.1, 0.85)),
    Prediction(1.0, Array(0.0, 17.0, 3.0), Array(0.0, 0.85, 0.15)),
    Prediction(0.0, Array(16.0, 0.0, 4.0), Array(0.8, 0.0, 0.2)),
    Prediction(1.0, Array(1.0, 17.0, 2.0), Array(0.05, 0.85, 0.1)),
    Prediction(0.0, Array(17.0, 0.0, 3.0), Array(0.85, 0.0, 0.15)),
    Prediction(2.0, Array(2.0, 1.0, 17.0), Array(0.1, 0.05, 0.85)),
    Prediction(2.0, Array(1.0, 2.0, 17.0), Array(0.05, 0.1, 0.85))
  )

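  // Parameters set on the OP wrapper should be passed through to the wrapped Spark
  // RandomForestClassifier, accessible here via estimator.predictor.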
  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxDepth(10)
      .setImpurity(Impurity.Gini.sparkName)
      .setMaxBins(33)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.2)
      .setSubsamplingRate(0.9)
      .setNumTrees(21)
      .setSeed(2L)
    estimator.fit(inputData)

    estimator.predictor.getMaxDepth shouldBe 10
    estimator.predictor.getMaxBins shouldBe 33
    estimator.predictor.getImpurity shouldBe Impurity.Gini.sparkName
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.2
    estimator.predictor.getSubsamplingRate shouldBe 0.9
    estimator.predictor.getNumTrees shouldBe 21
    estimator.predictor.getSeed shouldBe 2L
  }

}