/**
 * Copyright 2015, deepsense.io
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.deepsense.deeplang.doperables.spark.wrappers.estimators

import scala.language.reflectiveCalls

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml
import org.apache.spark.ml.param.{ParamMap, Param => SparkParam}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

import io.deepsense.deeplang.ExecutionContext
import io.deepsense.deeplang.doperables.report.Report
import io.deepsense.deeplang.doperables.serialization.SerializableSparkModel
import io.deepsense.deeplang.doperables.{SparkEstimatorWrapper, SparkModelWrapper}
import io.deepsense.deeplang.params.wrappers.spark.SingleColumnCreatorParamWrapper
import io.deepsense.deeplang.params.{Param, Params}
import io.deepsense.sparkutils.ML

object EstimatorModelWrapperFixtures {

  class SimpleSparkModel private[EstimatorModelWrapperFixtures]()
    extends ML.Model[SimpleSparkModel] {

    def this(x: String) = this()

    override val uid: String = "modelId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    def setPredictionCol(value: String): this.type = set(predictionCol, value)

    override def copy(extra: ParamMap): this.type = defaultCopy(extra)

    override def transformDF(dataset: DataFrame): DataFrame = {
      dataset.selectExpr("*", "1 as " + $(predictionCol))
    }

    @DeveloperApi
    override def transformSchema(schema: StructType): StructType = ???
  }

  class SimpleSparkEstimator extends ML.Estimator[SimpleSparkModel] {

    def this(x: String) = this()

    override val uid: String = "estimatorId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    override def fitDF(dataset: DataFrame): SimpleSparkModel =
      new SimpleSparkModel().setPredictionCol($(predictionCol))

    override def copy(extra: ParamMap): ML.Estimator[SimpleSparkModel] = defaultCopy(extra)

    @DeveloperApi
    override def transformSchema(schema: StructType): StructType = {
      schema.add(StructField($(predictionCol), IntegerType, nullable = false))
    }
  }

  trait HasPredictionColumn extends Params {
    val predictionColumn = new SingleColumnCreatorParamWrapper[
        ml.param.Params { val predictionCol: SparkParam[String] }](
      "prediction column",
      None,
      _.predictionCol)
    setDefault(predictionColumn, "abcdefg")

    def getPredictionColumn(): String = $(predictionColumn)
    def setPredictionColumn(value: String): this.type = set(predictionColumn, value)
  }

  class SimpleSparkModelWrapper
    extends SparkModelWrapper[SimpleSparkModel, SimpleSparkEstimator]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report: Report = ???

    override protected def loadModel(
      ctx: ExecutionContext,
      path: String): SerializableSparkModel[SimpleSparkModel] = ???
  }

  class SimpleSparkEstimatorWrapper
    extends SparkEstimatorWrapper[SimpleSparkModel, SimpleSparkEstimator, SimpleSparkModelWrapper]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report: Report = ???
  }
}