package au.csiro.data61.randomwalk import au.csiro.data61.randomwalk.algorithm.{UniformRandomWalk, VCutRandomWalk} import au.csiro.data61.randomwalk.common.CommandParser.TaskName import au.csiro.data61.randomwalk.common.{CommandParser, Params, Property} import com.typesafe.config.Config import org.apache.log4j.LogManager import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} import org.scalactic.{Every, Good, Or} import spark.jobserver.SparkJobInvalid import spark.jobserver.api._ object Main extends SparkJob { lazy val logger = LogManager.getLogger("myLogger") def main(args: Array[String]) { CommandParser.parse(args) match { case Some(params) => val conf = new SparkConf().setAppName("stellar-random-walk") val context: SparkContext = new SparkContext(conf) runJob(context, null, params) case None => sys.exit(1) } } /** * Saves the word2vec model and features in separate files. * * @param model * @param context * @param config */ private def saveModelAndFeatures(model: Word2VecModel, context: SparkContext, config: Params) : Unit = { model.save(context, s"${config.output}/${Property.modelSuffix}") val numPartitions = getNumOutputPartition(config) context.parallelize(model.getVectors.toList, config.rddPartitions).map { case (nodeId, vector) => s"$nodeId\t${vector.mkString("\t")}" }.repartition(numPartitions).saveAsTextFile(s"${config.output}/${Property.vectorSuffix}") } /** * Runs random-walk configured based on the input parameters. * * @param context * @param param input parameters. * @return */ def doRandomWalk(context: SparkContext, param: Params): RDD[Array[Int]] = { val rw = param.partitioned match { case true => VCutRandomWalk(context, param) case false => UniformRandomWalk(context, param) } val paths = rw.execute() val numPartitions = getNumOutputPartition(param) rw.save(paths, numPartitions, param.output) paths } private def getNumOutputPartition(param: Params): Int = { param.singleOutput match { case true => 1 case false => param.rddPartitions } } /** * Converts sequences of vertex ids to the format accepted by Word2vec. * * @param paths * @return */ def convertPathsToIterables(paths: RDD[Array[Int]]) = { paths.map { p => p.map(_.toString).toList } } /** * Setups an instance of MLlib's Word2vec object. * * @param param * @return */ private def configureWord2Vec(param: Params): Word2Vec = { val word2vec = new Word2Vec() word2vec.setLearningRate(param.w2vLr) .setNumIterations(param.w2vIter) .setNumPartitions(param.w2vPartitions) .setMinCount(0) .setVectorSize(param.w2vDim) .setWindowSize(param.w2vWindow) } override type JobData = Params override type JobOutput = String /** * * @param context * @param runtime * @param params input parameters. * @return */ override def runJob(context: SparkContext, runtime: JobEnvironment, params: JobData): JobOutput = { params.cmd match { case TaskName.node2vec => val paths = doRandomWalk(context, params) val word2Vec = configureWord2Vec(params) val model = word2Vec.fit(convertPathsToIterables(paths)) saveModelAndFeatures(model, context, params) case TaskName.randomwalk => doRandomWalk(context, params) case TaskName.embedding => val paths = context.textFile(params.input).repartition(params.rddPartitions). map(_.split("\\s+").toSeq) val word2Vec = configureWord2Vec(params) val model = word2Vec.fit(paths) saveModelAndFeatures(model, context, params) } params.output } /** * Validates the given config (required as a Job-server app). * * @param sc * @param runtime * @param config input parameters. * @return */ override def validate(sc: SparkContext, runtime: JobEnvironment, config: Config): JobData Or Every[SparkJobInvalid] = { val args = config.getString("rw.input").split("\\s+") CommandParser.parse(args) match { case Some(params) => Good(params) } } }