import breeze.linalg.DenseVector
import com.github.nearbydelta.deepspark.data._
import com.github.nearbydelta.deepspark.layer.{BasicLayer, VectorRBFLayer}
import com.github.nearbydelta.deepspark.network.SimpleNetwork
import com.github.nearbydelta.deepspark.train.{TrainerBuilder, TrainingParam}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Created by bydelta on 15. 10. 16.
 */
object TestXOR {
  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster("local[5]").setAppName("TestXOR")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.broadcast.blockSize", "40960")
      .set("spark.akka.frameSize", "50")
    val sc = new SparkContext(conf)

    val data = (0 to 100).collect {
      case i if i > 75 || i < 25 ⇒
        (0 to 100).collect {
          case j if j > 75 || j < 25 ⇒
            val xor =
              if (i > 75 && j < 25) true
              else if (i < 25 && j > 75) true
              else false
            (DenseVector[Double](i / 100.0, j / 100.0), xor)
        }
    }.flatMap(x ⇒ x)

    val train = sc.makeRDD(data)
    val test = train

    try {
      Weight.scalingDownBy(10.0)
      val builder = new AdaGrad(l2decay = 0.001, rate = 0.01)
      val rbf = new VectorRBFLayer withActivation GaussianRBF withCenters Seq(DenseVector(1.0, 1.0), DenseVector(0.0, 0.0), DenseVector(1.0, 0.0), DenseVector(0.0, 1.0))
      val network = new SimpleNetwork[Boolean]()
        //        .add(new BasicLayer withInput 2 withOutput 4)
        .add(rbf)
        //        .add(new BasicLayer withActivation LeakyReLU withOutput 4)
        .add(new BasicLayer withActivation SoftmaxCEE withOutput 2)
        .initiateBy(builder)

      println(rbf.epsilon.value)
      require(network.NOut == 2)
      //      require(network.layers.head.asInstanceOf[BasicLayer].bias != null)
      //      require(network.layers.head.asInstanceOf[BasicLayer].weight.value != null)
      //      require(network.layers.head.asInstanceOf[BasicLayer].bias.value.length > 0)

      val trained = new TrainerBuilder(TrainingParam(miniBatch = 10, maxIter = 100, dataOnLocal = true,
        reuseSaveData = true, storageLevel = StorageLevel.MEMORY_ONLY))
        .build(network, train, test, CrossEntropyErr,
          (x: Boolean) ⇒ if (x) DenseVector(1.0, 0.0) else DenseVector(0.0, 1.0), "XORTest")
        .getTrainedNetwork
      println(rbf.epsilon.value)

      (0 until 10).foreach { _ ⇒
        val (in, exp) = data(Math.floor(Math.random() * data.length).toInt)
        val out = trained.predictSoft(in)
        println(s"IN : $in, EXPECTED: $exp, OUTPUT ${out(0) > out(1)} $out")
      }
    } finally {
      sc.stop()
    }
  }
}