/* * Copyright 2016 The BigDL Authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.intel.analytics.bigdl.models.vgg import java.text.SimpleDateFormat import java.util.Date import com.intel.analytics.bigdl._ import com.intel.analytics.bigdl.dataset.DataSet import com.intel.analytics.bigdl.dataset.image._ import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, Module} import com.intel.analytics.bigdl.optim._ import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric._ import com.intel.analytics.bigdl.utils.{Engine, LoggerFilter, OptimizerV1, OptimizerV2, T, Table} import com.intel.analytics.bigdl.visualization.{TrainSummary, ValidationSummary} import org.apache.log4j.{Level, Logger} import org.apache.spark.SparkContext object Train { LoggerFilter.redirectSparkInfoLogs() import Utils._ def main(args: Array[String]): Unit = { trainParser.parse(args, new TrainParams()).map(param => { val conf = Engine.createSparkConf().setAppName("Train Vgg on Cifar10") // Will throw exception without this config when has only one executor .set("spark.rpc.message.maxSize", "200") val sc = new SparkContext(conf) Engine.init val trainDataSet = DataSet.array(Utils.loadTrain(param.folder), sc) -> BytesToBGRImg() -> BGRImgNormalizer(trainMean, trainStd) -> BGRImgToBatch(param.batchSize) val model = if (param.modelSnapshot.isDefined) { Module.load[Float](param.modelSnapshot.get) } else { if (param.graphModel) VggForCifar10.graph(classNum = 10) else VggForCifar10(classNum = 10) } if (param.optimizerVersion.isDefined) { param.optimizerVersion.get.toLowerCase match { case "optimizerv1" => Engine.setOptimizerVersion(OptimizerV1) case "optimizerv2" => Engine.setOptimizerVersion(OptimizerV2) } } val optimMethod = if (param.stateSnapshot.isDefined) { OptimMethod.load[Float](param.stateSnapshot.get) } else { new SGD[Float](learningRate = param.learningRate, learningRateDecay = 0.0, weightDecay = param.weightDecay, momentum = 0.9, dampening = 0.0, nesterov = false, learningRateSchedule = SGD.EpochStep(25, 0.5)) } val optimizer = Optimizer( model = model, dataset = trainDataSet, criterion = new ClassNLLCriterion[Float]() ) val validateSet = DataSet.array(Utils.loadTest(param.folder), sc) -> BytesToBGRImg() -> BGRImgNormalizer(testMean, testStd) -> BGRImgToBatch(param.batchSize) if (param.checkpoint.isDefined) { optimizer.setCheckpoint(param.checkpoint.get, Trigger.everyEpoch) } if (param.overWriteCheckpoint) { optimizer.overWriteCheckpoint() } if (param.summaryPath.isDefined) { val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val timeStamp = sdf.format(new Date()) val trainSummry = new TrainSummary(param.summaryPath.get, s"vgg-on-cifar10-train-$timeStamp") optimizer.setTrainSummary(trainSummry) val validationSummary = new ValidationSummary(param.summaryPath.get, s"vgg-on-cifar10-val-$timeStamp") optimizer.setValidationSummary(validationSummary) } optimizer .setValidation(Trigger.everyEpoch, validateSet, Array(new Top1Accuracy[Float])) .setOptimMethod(optimMethod) .setEndWhen(Trigger.maxEpoch(param.maxEpoch)) .optimize() sc.stop() }) } }