import breeze.linalg.DenseVector
import com.github.nearbydelta.deepspark.data._
import com.github.nearbydelta.deepspark.layer.{BasicLayer, NetworkConcatLayer}
import com.github.nearbydelta.deepspark.network.{GeneralNetwork, SimpleNetwork}
import com.github.nearbydelta.deepspark.train.{TrainerBuilder, TrainingParam}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{SparkConf, SparkContext}

import scala.reflect.{ClassTag, classTag}

/**
 * Manual test that trains a network built from two concatenated sub-networks on a
 * local Spark context and prints a few predictions.
 *
 * Every sample consists of two 2-dimensional points whose coordinates are taken from
 * {0.0, 0.1, 0.2, 0.8, 0.9, 1.0}. A sample is labelled `true` only when, for BOTH points,
 * the two coordinates lie on the same side of the grid (both low or both high).
 *
 * Created by bydelta on 15. 10. 16.
 */
object TestConcat {
  /** Grid positions (in tenths) with the middle band 3..7 removed. */
  private val Levels: IndexedSeq[Int] = (0 to 10).filter(v => v > 7 || v < 3)

  /** True when both grid positions are in the high band (> 7) or both in the low band (< 3). */
  private def sameSide(a: Int, b: Int): Boolean =
    (a > 7 && b > 7) || (a < 3 && b < 3)

  /**
   * Builds the full data set: one sample for every combination of four grid positions,
   * enumerated in `i`, `j`, `k`, `l` nesting order.
   *
   * @return pairs of (two input vectors, expected label)
   */
  private def buildData(): IndexedSeq[(Array[DenseVector[Double]], Boolean)] =
    for {
      i <- Levels
      j <- Levels
      k <- Levels
      l <- Levels
    } yield {
      val first = DenseVector(i / 10.0, j / 10.0)
      val second = DenseVector(k / 10.0, l / 10.0)
      // The second term must depend on (k, l); deriving it from (i, j) again would make
      // the second input vector irrelevant to the label.
      (Array(first, second), sameSide(i, j) && sameSide(k, l))
    }

  /**
   * Entry point: builds the data, trains the network and prints ten randomly chosen
   * samples together with the expected label and the network output.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Pure data construction happens before Spark is started, so a failure here
    // cannot leave a running context behind.
    val data = buildData()

    val conf = new SparkConf().setMaster("local[5]").setAppName("TestXOR")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.broadcast.blockSize", "40960")
      .set("spark.akka.frameSize", "50")
    val sc = new SparkContext(conf)

    try {
      // The same RDD is used for training and evaluation.
      val train = sc.makeRDD(data)
      val test = train

      val builder = new AdaGrad(l2decay = 0.00001, rate = 0.01)
      val input1 = new SimpleNetwork[DataVec]()
        .add(new BasicLayer withInput 2 withOutput 4)
        .add(new BasicLayer withInput 4 withOutput 1)
      val input2 = new SimpleNetwork[DataVec]()
        .add(new BasicLayer withInput 2 withOutput 4)
        .add(new BasicLayer withInput 4 withOutput 1)
      // NOTE(review): presumably each sub-network consumes one element of the input array
      // and their single outputs are joined into the 2-wide input of the next layer — confirm
      // against NetworkConcatLayer.
      val concat = new ConcatLayer().addNetwork(input1).addNetwork(input2)
      val network = new GeneralNetwork[Array[DataVec], Boolean](concat)
        .add(new BasicLayer withInput 2 withOutput 4)
        .add(new BasicLayer withInput 4 withOutput 1)
        .initiateBy(builder)

      require(network.NOut == 1)

      val param = TrainingParam(miniBatch = 10, maxIter = 1000, storageLevel = StorageLevel.MEMORY_ONLY)
      // Labels are encoded as a 1-element vector: 1.0 for true, 0.0 for false.
      val encodeLabel = (x: Boolean) => if (x) DenseVector(1.0) else DenseVector(0.0)
      val trained = new TrainerBuilder(param)
        .build(network, train, test, SquaredErr, encodeLabel, "XORTest")
        .getTrainedNetwork

      (0 until 10).foreach { _ =>
        val (in, exp) = data(scala.util.Random.nextInt(data.length))
        val out = trained.predictSoft(in)
        // Arrays print as a JVM identity string by default; render the vectors explicitly.
        val shownInput = in.mkString("[", ", ", "]")
        println(s"IN : $shownInput, EXPECTED: $exp, OUTPUT $out")
      }
    } finally {
      sc.stop()
    }
  }

  /**
   * Concrete [[NetworkConcatLayer]] over `DataVec` inputs; its only job here is to supply
   * the `ClassTag` for `Array[DataVec]` that the parent class requires.
   */
  class ConcatLayer extends NetworkConcatLayer[DataVec] {
    override implicit protected val evidenceI: ClassTag[Array[DataVec]] = classTag[Array[DataVec]]
  }

}