// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.flaky

import com.microsoft.ml.spark.core.test.base.TimeLimitedFlaky
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.ml.spark.io.http.PartitionConsolidator
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.scalatest.Assertion

/** Tests for [[PartitionConsolidator]], a transformer from the io.http package
  * that moves rows onto fewer of a DataFrame's partitions (e.g. so that
  * rate-limited HTTP clients can be shared). Mixes in [[TimeLimitedFlaky]]
  * because the timing-based "overheads" test varies across build machines.
  */
class PartitionConsolidatorSuite extends TransformerFuzzing[PartitionConsolidator] with TimeLimitedFlaky {

  import session.implicits._

  // Run the test session on two cores so that the "more partitions than cores"
  // scenarios below are meaningful.
  override val numCores: Option[Int] = Some(2)

  lazy val df: DataFrame = (1 to 1000).toDF("values")

  override val sortInDataframeEquality: Boolean = true

  override def testObjects(): Seq[TestObject[PartitionConsolidator]] = Seq(
    new TestObject(new PartitionConsolidator(), df))

  override def reader: MLReadable[_] = PartitionConsolidator

  // Returns the number of rows in each partition of the given DataFrame.
  def getPartitionDist(df: DataFrame): List[Int] = {
    df.rdd.mapPartitions(it => Iterator(it.length)).collect().toList
  }

  // TODO: figure out what is causing the issue on the build server;
  // the serialization and experiment fuzzing tests are disabled until then.
  override def testSerialization(): Unit = {}

  override def testExperiments(): Unit = {}

  // Core invariants of consolidation: the total row count is preserved, the
  // largest partition can only grow, and the number of (possibly empty)
  // partitions stays the same.
  def basicTest(df: DataFrame): Assertion = {
    val pd1 = getPartitionDist(df)
    val newDF = new PartitionConsolidator().transform(df)
    val pd2 = getPartitionDist(newDF)
    assert(pd1.sum === pd2.sum)
    assert(pd2.max >= pd1.max)
    assert(pd1.length === pd2.length)
  }

  test("basic functionality") {
    basicTest(df)
  }

  test("works with more partitions than cores") {
    basicTest(df.repartition(12))
  }

  test("overheads") {
    val baseDF = (1 to 1000).toDF("values").cache()
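    // Count once to materialize the cache so the timings below exclude data creation.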
    println(baseDF.count())

    // Each row sleeps 10ms to simulate per-row work such as a service call.
    // Note the encoder schema uses IntegerType to match the Int column built
    // by toDF above.
    def getDF: Dataset[Row] = baseDF.map { x => Thread.sleep(10); x }(
      RowEncoder(new StructType().add("values", IntegerType)))

    // Time three runs of the slow pipeline, without and then with consolidation.
    val t1 = getTime(3)(
      getDF.foreach(_ => ()))._2
    val t2 = getTime(3)(
      new PartitionConsolidator().transform(getDF).foreach(_ => ()))._2

    // Consolidating the sleepy pipeline should cost less than 3x the baseline.
    println(t2.toDouble / t1.toDouble)
    assert(t2.toDouble / t1.toDouble < 3.0)
  }

  test("works with more partitions than cores2") {
    basicTest(df.repartition(100))
  }

  test("work with 1 partition") {
    basicTest(df.repartition(1))
  }

}