/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.example.imageclassification

import java.nio.file.Paths

import com.intel.analytics.bigdl.dataset.image._
import com.intel.analytics.bigdl.dlframes.DLClassifierModel
import com.intel.analytics.bigdl.example.imageclassification.MlUtils._
import com.intel.analytics.bigdl.numeric.NumericFloat
import com.intel.analytics.bigdl.utils.{Engine, LoggerFilter}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

/**
 * An example showing how to run image classification with a trained BigDL model
 * wrapped in a [[DLClassifierModel]] Spark ML transformer.
 *
 * The flow has four steps:
 *  1. Images are read from HDFS (when `isHdfs` is set) or from a local folder.
 *  2. They are passed through a BGR image pipeline (crop, normalize, vectorize).
 *  3. They are classified.
 *  4. The first `showNum` `(imageName, predict)` rows are printed to stdout.
 */
object ImagePredictor {
  // Object initialization: quiet Spark INFO logs via LoggerFilter and keep INFO-level
  // logging for the example package.
  LoggerFilter.redirectSparkInfoLogs()
  Logger.getLogger("com.intel.analytics.bigdl.example").setLevel(Level.INFO)

  /**
   * Entry point.
   *
   * @param args command-line arguments, parsed by `predictParser` into a [[PredictParams]].
   *             If parsing yields no parameters, nothing is run. Presumably the parser
   *             reports the problem itself (TODO confirm).
   */
  def main(args: Array[String]): Unit = {
    predictParser.parse(args, new PredictParams()).foreach { param =>
      val conf = Engine.createSparkConf()
      conf.setAppName("Predict with trained model")
      val sc = new SparkContext(conf)

      // Always release the SparkContext, even if loading or prediction fails.
      try {
        Engine.init
        val sqlContext = new SQLContext(sc)

        // One partition per core across all nodes.
        val partitionNum = Engine.nodeNumber() * Engine.coreNumber()
        val model = loadModel(param)

        // Feature size is 3 x imageSize x imageSize. The 3 is presumably the BGR channels
        // produced by the pipeline below (TODO confirm).
        val valTrans = new DLClassifierModel(model, Array(3, imageSize, imageSize))
          .setBatchSize(param.batchSize)
          .setFeaturesCol("features")
          .setPredictionCol("predict")

        val valRDD = if (param.isHdfs) {
          // load image set from hdfs
          imagesLoadSeq(param.folder, sc, param.classNum).coalesce(partitionNum, true)
        } else {
          // load image set from local
          val paths = LocalImageFiles.readPaths(Paths.get(param.folder), hasLabel = false)
          // NOTE(review): 256 looks like a resize target applied before cropping to
          // imageSize. Confirm against MlUtils.imagesLoad.
          sc.parallelize(imagesLoad(paths, 256), partitionNum)
        }

        val transf = RowToByteRecords() ->
          BytesToBGRImg() ->
          BGRImgCropper(imageSize, imageSize) ->
          BGRImgNormalizer(testMean, testStd) ->
          BGRImgToImageVector()

        val valDF = transformDF(sqlContext.createDataFrame(valRDD), transf)

        // Take only the rows we print instead of collecting every prediction to the
        // driver. The count is clamped at 0 so a non-positive showNum prints nothing.
        valTrans.transform(valDF)
          .select("imageName", "predict")
          .take(math.max(param.showNum, 0))
          .foreach(println)
      } finally {
        sc.stop()
      }
    }
  }
}