import org.apache.log4j._
Logger.getRootLogger.setLevel(Level.OFF)
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// Scaling benchmark for MultilayerPerceptronClassifier, meant to be run inside
// spark-shell (relies on the shell-provided `sc` and `sqlContext`).
// For every partition count from 1 to numNodes it prints one tab-separated line:
//   <partitions> <batchSize> <elapsed seconds / numIterations>

// maximum number of worker nodes in cluster
val numNodes = 5
// batch size, ~10K is good for GPU
val batchSize = 1000
// number of iterations to run
val numIterations = 5
val train = MLUtils.loadLibSVMFile(sc, "file:///data/mnist/mnist.scale")
//val layers = Array[Int](780, 2500, 2000, 1500, 1000, 500, 10)
val layers = Array[Int](780, 10)
// The block size and iteration cap are taken from the constants above so that
// the printed batch size and the divisor used for the timing below always
// describe the run that was actually executed.
val trainer = new MultilayerPerceptronClassifier()
  .setLayers(layers)
  .setBlockSize(batchSize)
  .setSeed(1234L)
  .setMaxIter(numIterations)

for (i <- 1 to numNodes) {
  // One seed element per partition; each partition is then expanded into a
  // full copy of the collected sample by the flatMap below.
  val dataPartitions = sc.parallelize(1 to i, i)
  // Sample a 1/i fraction (with replacement) so that replicating it across
  // i partitions keeps the expected total number of rows roughly constant.
  val sample = train.sample(withReplacement = true, fraction = 1.0 / i, seed = 11L).collect()
  val parallelData = sqlContext.createDataFrame(dataPartitions.flatMap(_ => sample))
  parallelData.persist()
  try {
    // Materialize the cache up front so that loading is not part of the timing.
    parallelData.count()
    val start = System.nanoTime()
    trainer.fit(parallelData)
    // Average seconds per iteration, assuming the optimizer uses all
    // numIterations iterations (it may stop earlier if it converges).
    val secondsPerIteration = (System.nanoTime() - start) / (numIterations * 1e9)
    println(s"$i\t$batchSize\t$secondsPerIteration")
  } finally {
    // Release the cached data even when training fails.
    parallelData.unpersist()
  }
}