package com.holdenkarau.spark.testing

import org.apache.spark.ml.linalg.SQLDataTypes.{MatrixType, VectorType}
import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
import org.apache.spark.sql.types.DataType
import org.scalacheck.{Arbitrary, Gen}

/**
 * Extractor that matches the user-defined types (UDTs) exposed by Spark ML
 * (`VectorType` and `MatrixType`) and, on a match, yields a ScalaCheck
 * generator producing arbitrary values of that type (see the usage sketch
 * at the end of this file).
 */
object MLUserDefinedType {
  def unapply(dataType: DataType): Option[Gen[Any]] =
    dataType match {
      case MatrixType => {
        // Dense matrices up to 20 x 20; `values` is stored column-major and
        // must contain exactly rows * cols elements.
        val dense = for {
          rows <- Gen.choose(0, 20)
          cols <- Gen.choose(0, 20)
          values <- Gen.containerOfN[Array, Double](rows * cols, Arbitrary.arbitrary[Double])
        } yield new DenseMatrix(rows, cols, values)
        // Cover both layouts by also converting the dense matrices to sparse.
        val sparse = dense.map(_.toSparse)
        Some(Gen.oneOf(dense, sparse))
      }
      case VectorType => {
        val dense = Arbitrary.arbitrary[Array[Double]].map(Vectors.dense)
        // Sparse vectors: draw a non-empty set of distinct indices, pair each
        // index with an arbitrary value, and size the vector just large enough
        // to hold the largest index.
        val sparse = for {
          indices <- Gen.nonEmptyContainerOf[Set, Int](Gen.choose(0, Int.MaxValue - 1))
          values <- Gen.listOfN(indices.size, Arbitrary.arbitrary[Double])
        } yield Vectors.sparse(indices.max + 1, indices.toSeq.zip(values))
        Some(Gen.oneOf(dense, sparse))
      }
      case _ => None
    }
}
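
// A minimal usage sketch, not part of this file's API: the extractor can back a
// per-field generator when producing arbitrary test data. The helper name
// `genForField` and the "features" field below are illustrative assumptions.
//
//   import org.apache.spark.sql.types.StructField
//
//   def genForField(field: StructField): Gen[Any] = field.dataType match {
//     case MLUserDefinedType(gen) => gen // Vector / Matrix columns
//     case other => sys.error(s"No ML generator for $other")
//   }
//
//   genForField(StructField("features", VectorType)).sample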