package org.apache.spark

import java.util.concurrent.{Callable, Executors}

import com.sap.spark.dsmock.DefaultSource
import org.apache.spark.sql.sources.HashPartitioningFunction
import org.apache.spark.sql.{GlobalSapSQLContext, Row, SQLContext}
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.FunSuite

import scala.concurrent.duration._

/**
  * Test suite to verify behavior of the mockable [[DefaultSource]].
  *
  * The test bodies are executed concurrently on several threads, to check
  * that mocks obtained and stubbed on different threads do not interfere
  * with each other.
  */
class MockedDefaultSourceSuite
  extends FunSuite
  with GlobalSapSQLContext {

  /** Maximum time to wait for each concurrently running task. */
  val testTimeout = 10 // seconds

  /**
    * Number of threads used by [[runMultiThreaded]].
    *
    * This is the number of available processors, but at least 2. The tests
    * should always be multi-threaded, even when only one processor is
    * available (pseudo-multithreading).
    */
  private def numberOfThreads: Int = {
    val noOfCores = Runtime.getRuntime.availableProcessors()
    assert(noOfCores > 0)
    math.max(2, noOfCores)
  }

  /**
    * Runs `op` concurrently on a dedicated fixed-size thread pool, one task
    * per pool thread. It collects the results in task order.
    *
    * @param op The operation to run; it receives the 1-based index of its task.
    * @tparam A The result type of `op`.
    * @return The results of all invocations, ordered by task index.
    * @throws java.util.concurrent.ExecutionException if an invocation of `op`
    *         throws. This is how assertion failures inside `op` surface.
    * @throws java.util.concurrent.TimeoutException if an invocation does not
    *         finish within `testTimeout` seconds.
    */
  def runMultiThreaded[A](op: Int => A): Seq[A] = {
    // Read the thread count only once. availableProcessors() may change while
    // the JVM runs, and the logged value, the pool size and the number of
    // submitted tasks must agree so that every task runs on its own thread.
    val threads = numberOfThreads
    info(s"Running with $threads threads")
    val pool = Executors.newFixedThreadPool(threads)
    try {
      val futures = 1 to threads map { i =>
        pool.submit(new Callable[A] {
          override def call(): A = op(i)
        })
      }
      futures.map(_.get(testTimeout, SECONDS))
    } finally {
      // Always release the pool's threads; previously each call leaked them.
      // shutdownNow() also interrupts tasks still running after an earlier
      // task failed or timed out.
      pool.shutdownNow()
    }
  }

  test("Underlying mocks of multiple threads are distinct") {
    // withMock(identity) hands back whatever withMock passes to its function
    // (presumably the calling thread's mock -- confirm against
    // DefaultSource.withMock).
    val dataSources = runMultiThreaded { _ =>
      DefaultSource.withMock(identity)
    }

    // Compare every result against the results of all other threads by
    // reference.
    dataSources foreach { current =>
      val sourcesWithoutCurrent = dataSources.filter(_.ne(current))
      assert(sourcesWithoutCurrent.forall(_.underlying ne current))
    }
  }

  test("Mocking works as expected") {
    // Each thread stubs getAllPartitioningFunctions with a thread-specific
    // result. It then checks that a SQL query on that thread sees exactly
    // its own stub.
    runMultiThreaded { i =>
      DefaultSource.withMock { defaultSource =>
        when(defaultSource.getAllPartitioningFunctions(
          anyObject[SQLContext],
          anyObject[Map[String, String]]))
          .thenReturn(Seq(HashPartitioningFunction(s"foo$i", Seq.empty, None)))

        // Fails with a MatchError unless exactly one row is returned.
        val Array(Row(name)) = sqlc
          .sql("SHOW PARTITION FUNCTIONS USING com.sap.spark.dsmock")
          .select("name")
          .collect()

        assertResult(s"foo$i")(name)
      }
    }
  }
}