/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.salesforce.op.test

import java.io.File

import com.salesforce.op.features.types._
import com.salesforce.op.stages._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.sql.Dataset
import org.scalactic.Equality
import org.scalatest.events.{Event, TestFailed}
import org.scalatest.{Args, Reporter}

import scala.collection.mutable.ArrayBuffer
import scala.reflect._
import scala.reflect.runtime.universe._

/**
 * Base test class for OP estimator instances.
 *
 * Fits [[estimator]] on [[inputData]], checks the fitted model's relationship to its
 * estimator, and reuses [[OpTransformerSpec]] to run the standard transformer tests
 * against the fitted model (registered in this suite under their original names).
 *
 * @tparam O             output feature type
 * @tparam ModelType     model type produced by this estimator
 * @tparam EstimatorType type of the estimator being tested
 */
abstract class OpEstimatorSpec[O <: FeatureType : WeakTypeTag : ClassTag,
ModelType <: Model[ModelType] with OpPipelineStage[O] with OpTransformer : ClassTag,
EstimatorType <: Estimator[ModelType] with OpPipelineStage[O] : ClassTag]
  extends OpPipelineStageSpec[O, EstimatorType] {

  /**
   * Dataset the estimator is fitted on; the fitted model is also applied to it.
   */
  val inputData: Dataset[_]

  /**
   * The estimator under test.
   */
  val estimator: EstimatorType

  /**
   * Values expected from applying the fitted model to [[inputData]].
   */
  val expectedResult: Seq[O]

  // The generic pipeline-stage checks from the parent spec operate on the estimator itself
  final override lazy val stage = estimator

  /**
   * Model produced by fitting [[estimator]] on [[inputData]] (fitted lazily on first access).
   */
  final lazy val model: ModelType = estimator.fit(inputData)

  it should "fit a model" in {
    val fitted = model
    fitted should not be null
    fitted shouldBe a[ModelType]
  }

  it should behave like modelSpec()

  it should "have fitted a model that matches the estimator" in {
    val fitted = model
    withClue("Model doesn't have a parent:") {
      fitted.hasParent shouldBe true
    }
    withClue("Model parent should be the original estimator instance:") {
      fitted.parent shouldBe estimator
    }
    withClue("Model and estimator output feature names don't match:") {
      fitted.getOutputFeatureName shouldBe estimator.getOutputFeatureName
    }
    assert(fitted.asInstanceOf[OpPipelineStageBase], estimator, expectSameClass = false)
  }

  // TODO: test metadata


  /**
   * Wraps the fitted model in an [[OpTransformerSpec]] and registers each of that spec's
   * tests in this suite. Each registered test runs the corresponding nested test and
   * fails with the first `TestFailed` event the nested run reports, if any.
   */
  private def modelSpec(): Unit = {
    val nested = new OpTransformerSpec[O, ModelType] {
      // Equality instances, Spark session and temp dir are shared with the enclosing spec
      override implicit val featureTypeEquality: Equality[O] = OpEstimatorSpec.this.featureTypeEquality
      override implicit val seqEquality: Equality[Seq[O]] = OpEstimatorSpec.this.seqEquality
      override implicit lazy val spark = OpEstimatorSpec.this.spark
      override def tempDir: File = OpEstimatorSpec.this.tempDir
      override def specName: String = "model"
      // The transformer under test is the model fitted by the enclosing spec, checked
      // against the same inputs and expected results
      lazy val transformer: ModelType = OpEstimatorSpec.this.model
      lazy val inputData: Dataset[_] = OpEstimatorSpec.this.inputData
      lazy val expectedResult: Seq[O] = OpEstimatorSpec.this.expectedResult
    }

    // Runs one nested test by name and rethrows its first reported failure (with cause, if present)
    def runNested(name: String): Unit = {
      val collected = ArrayBuffer.empty[TestFailed]
      val collector = new Reporter {
        def apply(event: Event): Unit = event match {
          case failed: TestFailed => collected += failed
          case _ =>
        }
      }
      // 'runTestInNewInstance = true' is passed so the Spark context is not restarted on every test run
      nested.run(testName = Some(name), args = Args(collector, runTestInNewInstance = true))
      collected.headOption.foreach { failed =>
        failed.throwable.fold(fail(failed.message))(cause => fail(failed.message, cause))
      }
    }

    nested.testNames.foreach(name => registerTest(name)(runNested(name)))
  }

}