/*
 * (c) Copyright 2016 Hewlett Packard Enterprise Development LP
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package toolkit.neuralnetwork.performance

import com.typesafe.scalalogging.StrictLogging
import toolkit.neuralnetwork.examples.networks.CIFAR

import scala.collection.mutable.ListBuffer
import libcog._


/** Command-line benchmark that times `ComputeGraph` steps of the CIFAR example network.
  *
  * Usage: `Benchmark [net [batchSize]]`
  *   - `net` defaults to `"cifar10_quick"`, which is also the only value accepted.
  *   - `batchSize` defaults to 256 and must be an integer.
  *
  * Two graphs are built and timed one after the other: first with
  * `learningEnabled = false` (reported as "forward"), then with
  * `learningEnabled = true` (reported as "forward-backward"). Each per-step
  * time is logged in milliseconds, followed by the average for each graph.
  */
object Benchmark extends App with StrictLogging {
  /** Number of timed steps per graph; the step taken before them is not timed. */
  private final val TimedIterations = 50

  /** Batch size used when none is supplied on the command line. */
  private final val DefaultBatchSize = 256

  val (net, batchSize) = args.length match {
    case 0 => ("cifar10_quick", DefaultBatchSize)
    case 1 => (args(0), DefaultBatchSize)
    case 2 => (args(0), parseBatchSize(args(1)))
    case _ => throw new RuntimeException(s"illegal arguments (${args.toList})")
  }

  require(net == "cifar10_quick", s"network $net isn't supported")

  logger.info(s"net: $net")
  logger.info(s"batch size: $batchSize")

  // The network is instantiated inside the ComputeGraph body, exactly as in the
  // original layout; only the timing loop has been factored out.
  val cg1 = new ComputeGraph {
    val net = new CIFAR(useRandomData = true, learningEnabled = false, batchSize = batchSize)
  }

  val forward = timeSteps(cg1, phase = "inference", label = "forward")

  // The learning graph is only constructed once the inference graph's
  // `withRelease` block has completed, preserving the original ordering.
  val cg2 = new ComputeGraph {
    val net = new CIFAR(useRandomData = true, learningEnabled = true, batchSize = batchSize)
  }

  val backward = timeSteps(cg2, phase = "learning", label = "forward-backward")

  logger.info(s"Average Forward pass: ${forward.sum / forward.length} ms.")
  logger.info(s"Average Forward-Backward: ${backward.sum / backward.length} ms.")

  /** Parses the batch-size argument.
    *
    * @param raw the command-line argument as typed by the user
    * @return the parsed integer
    * @throws RuntimeException if `raw` is not a valid integer; the underlying
    *                          `NumberFormatException` is attached as the cause
    */
  private def parseBatchSize(raw: String): Int =
    try raw.toInt
    catch {
      case e: NumberFormatException =>
        throw new RuntimeException(s"illegal batch size '$raw' (expected an integer)", e)
    }

  /** Steps `graph` once untimed, then [[TimedIterations]] more times, timing each step.
    *
    * The first step is bracketed by "compilation" log messages and excluded from
    * the timings. All stepping happens inside `graph withRelease { ... }`
    * (NOTE(review): presumably this releases the graph's resources when the
    * block finishes — confirm against libcog).
    *
    * @param graph the compute graph to step
    * @param phase text placed in the compilation log messages (e.g. "inference")
    * @param label text placed in the per-iteration log message (e.g. "forward")
    * @return per-step wall-clock times in milliseconds, in iteration order
    */
  private def timeSteps(graph: ComputeGraph, phase: String, label: String): ListBuffer[Double] = {
    val timings = new ListBuffer[Double]()

    graph withRelease {
      logger.info(s"starting compilation ($phase)")
      graph.step
      logger.info(s"compilation finished ($phase)")

      for (i <- 1 to TimedIterations) {
        val start = System.nanoTime()
        graph.step
        val stop = System.nanoTime()
        // nanoseconds -> milliseconds
        val elapsed = (stop - start).toDouble / 1e6
        logger.info(s"Iteration: $i $label time: $elapsed ms.")
        timings += elapsed
      }
    }

    timings
  }
}