package com.giorgioinf.twtml.spark

import org.apache.spark.{Logging, SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.twitter.TwitterUtils

/**
 * Driver for a streaming linear-regression job over a Twitter stream.
 *
 * The stream opened by `TwitterUtils.createStream(ssc, None)` is filtered and
 * featurized through `MllibHelper`. For every non-empty micro-batch the current
 * model is used to predict each point, batch statistics are pushed to the
 * `SessionStats` session, and only afterwards is the model trained on that
 * same batch.
 */
object LinearRegression extends Logging {

  /**
   * Configures the job from the command line, wires the prediction and
   * training pipelines and blocks until the streaming context terminates.
   *
   * @param args command-line arguments, handed as a list to `ConfArguments.parse`
   */
  def main(args: Array[String]): Unit = {

    log.info("Parsing applications arguments")

    val conf = new ConfArguments()
      .setAppName("twitter-stream-ml-linear-regression")
      .parse(args.toList)

    log.info("Initializing session stats...")

    val session = new SessionStats(conf).open

    log.info("Initializing Spark Machine Learning Model...")

    MllibHelper.reset(conf)

    // The model starts from an all-zero weight vector sized by MllibHelper.
    val model = new StreamingLinearRegressionWithSGD()
      .setNumIterations(conf.numIterations)
      .setStepSize(conf.stepSize)
      .setMiniBatchFraction(conf.miniBatchFraction)
      .setInitialWeights(Vectors.zeros(MllibHelper.numFeatures))

    log.info("Initializing Spark Context...")

    val sc = new SparkContext(conf.sparkConf)

    log.info("Initializing Streaming Spark Context... {} sec/batch", conf.seconds)

    val ssc = new StreamingContext(sc, Seconds(conf.seconds))

    log.info("Initializing Twitter stream...")

    // Cached because each batch is read by several actions below and by trainOn.
    val stream = TwitterUtils.createStream(ssc, None)
      .filter(MllibHelper.filtrate)
      .map(MllibHelper.featurize)
      .cache()

    log.info("Initializing prediction model...")

    // Running total of points seen across all batches.
    val count = sc.accumulator(0L, "count")

    stream.foreachRDD { rdd =>
      if (rdd.isEmpty()) log.debug("batch: 0")
      else {
        // Take the model snapshot once on the driver so the closure ships only
        // the fitted model, not the whole streaming trainer.
        val latest = model.latestModel()
        val realPred = rdd.map { lb =>
          (lb.label, Utils.round(latest.predict(lb.features)))
        }
        val batch = rdd.count()
        count += batch
        val real = realPred.map(_._1)
        val pred = realPred.map(_._2)
        val realStdev = Utils.round(real.stdev())
        val predStdev = Utils.round(pred.stdev())
        val mse = Utils.round(realPred.map { case (v, p) => math.pow(v - p, 2) }.mean())

        // Single collection job; both arrays handed to the session are derived
        // from it, so they are guaranteed to be aligned element by element.
        val pairs = realPred.collect()

        if (log.isDebugEnabled) {
          log.debug("count: {}", count)
          // batch, mse (training mean squared error)
          log.debug("batch: {},  mse: {}", batch, mse)
          log.debug("stdev (real, pred): ({}, {})", realStdev.toLong,
            predStdev.toLong)
          log.debug("value (real, pred): {} ...", pairs.take(10))
        }

        session.update(count.value, batch, mse, realStdev, predStdev,
          pairs.map(_._1), pairs.map(_._2))

      }

    }

    log.info("Initializing training model...")

    // Registered after the prediction output operation, so each batch is
    // scored with the model as it was before training on that batch.
    model.trainOn(stream)

    // Start the streaming computation
    ssc.start()
    log.info("Initialization complete.")
    ssc.awaitTermination()
  }

}