package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap

/**
 * A label indexer that maps a string column of labels to an ML column of label indices.
 * If the input column is numeric, we cast it to string and index the string values.
 * The indices are in [0, numLabels), ordered by label frequencies.
 * So the most frequent label gets index 0.
 *
 * In contrast to Spark's [[StringIndexer]], the label indices are stored as Short (instead of Double).
 */
class StringToShortIndexer(override val uid: String) extends Estimator[StringToShortIndexerModel]
with StringIndexerBase {

  /** Creates an indexer whose uid is randomly generated with the "strShortIdx" prefix. */
  def this() = this(Identifiable.randomUID("strShortIdx"))

  /** Sets the name of the column whose values are counted and indexed. */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** Sets the name of the column that receives the indices. */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Counts the distinct values of the input column (cast to string) and builds a model whose
   * labels are ordered by descending frequency. Labels with equal frequency are ordered by the
   * label itself (a null label first), so that fitting the same data always yields the same
   * indices.
   *
   * @param dataset data containing the input column
   * @return a model holding the ordered labels, with this estimator as its parent
   * @throws IllegalArgumentException if there are more than Short.MaxValue distinct labels
   */
  override def fit(dataset: DataFrame): StringToShortIndexerModel = {
    val counts = dataset.select(col($(inputCol)).cast(StringType))
      .map(_.getString(0))
      .countByValue()
    // Check before sorting so that an oversized label set fails without doing the sort.
    require(counts.size <= Short.MaxValue,
      s"Unique labels count (${counts.size}) must not exceed Short.MaxValue (${Short.MaxValue})")
    // The secondary key makes the order independent of the map's iteration order;
    // Option(label) keeps the comparison safe when the column contains nulls.
    val labels = counts.toSeq
      .sortBy { case (label, count) => (-count, Option(label)) }
      .map(_._1)
      .toArray
    copyValues(new StringToShortIndexerModel(uid, labels).setParent(this))
  }

  /** Delegates schema validation and derivation to `validateAndTransformSchema`. */
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): StringToShortIndexer = defaultCopy(extra)
}

/**
 * Model produced by [[StringToShortIndexer]]: maps each known label to its position in `labels`,
 * stored as a Short.
 *
 * @param uid unique identifier of this model instance
 * @param labels labels ordered by index; the label at position `i` is mapped to index `i`
 */
class StringToShortIndexerModel (
  override val uid: String,
  val labels: Array[String]) extends Model[StringToShortIndexerModel] with StringIndexerBase {

  /** Creates a model with a random uid, using the same "strShortIdx" prefix as the estimator. */
  def this(labels: Array[String]) = this(Identifiable.randomUID("strShortIdx"), labels)

  // Indices run from 0 to labels.length - 1 and must all be representable as Short.
  require(labels.length <= Short.MaxValue,
    s"Unique labels count (${labels.length}) must not exceed Short.MaxValue (${Short.MaxValue})")

  /** Lookup table from label to its index in `labels`. */
  private val labelToIndex: OpenHashMap[String, Short] = {
    // Clamp the initial capacity to at least 1 so that an empty label array is accepted.
    // NOTE(review): some Spark versions appear to reject an initial capacity of 0 - confirm.
    val map = new OpenHashMap[String, Short](math.max(labels.length, 1))
    labels.zipWithIndex.foreach { case (label, index) =>
      // Safe narrowing: the require above bounds index below Short.MaxValue.
      map.update(label, index.toShort)
    }
    map
  }

  /** Sets the name of the column whose values are looked up. */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** Sets the name of the column that receives the indices. */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Appends the output column holding the Short index of every input value (cast to string).
   * If the input column is absent from the dataset, the dataset is returned unchanged.
   *
   * @param dataset data to transform
   * @return the dataset with the index column appended, or the dataset itself when the input
   *         column does not exist
   * @throws SparkException at execution time for a value that is not among `labels`
   */
  override def transform(dataset: DataFrame): DataFrame = {
    if (!dataset.schema.fieldNames.contains($(inputCol))) {
      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
        "Skip StringToShortIndexerModel.")
      dataset
    } else {
      // Local alias so the UDF closure references only the map, not the whole model (`this`).
      val lookup = labelToIndex
      val indexer = udf { label: String =>
        if (lookup.contains(label)) {
          lookup(label)
        } else {
          // TODO: handle unseen labels
          throw new SparkException(s"Unseen label: $label.")
        }
      }
      val outputColName = $(outputCol)
      val metadata = NominalAttribute.defaultAttr
        .withName(outputColName).withValues(labels).toMetadata()
      dataset.select(col("*"),
        indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
    }
  }

  /**
   * Derives the output schema, or returns the schema unchanged when the input column is absent
   * (mirroring the skip in `transform`).
   */
  override def transformSchema(schema: StructType): StructType = {
    if (schema.fieldNames.contains($(inputCol))) {
      // NOTE(review): validateAndTransformSchema is defined in StringIndexerBase (not visible
      // here); confirm the output field type it declares matches the Short column that
      // transform actually produces.
      validateAndTransformSchema(schema)
    } else {
      // If the input column does not exist during transformation, we skip StringToShortIndexerModel.
      schema
    }
  }

  /** Creates a copy with the same uid, labels and parent, applying the extra params. */
  override def copy(extra: ParamMap): StringToShortIndexerModel = {
    val copied = new StringToShortIndexerModel(uid, labels)
    copyValues(copied, extra).setParent(parent)
  }
}