package com.lenovo.ml /** * Created by YangChenguang on 2017/9/15. */ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.types.StructType import DataPreprocess.segWords import org.apache.spark.ml.PipelineModel object XGBoostInference { def main(args:Array[String]): Unit = { // 1、创建Spark程序入口 val sparkSession = SparkSession.builder().appName("XGBoostInference").enableHiveSupport().getOrCreate() // 2、读取训练数据,对文本预处理后分词 val tableName = args(0) val matrix = sparkSession.sql("SELECT * FROM " + tableName) val words = segWords(sparkSession, args(1), args(2), args(3), args(4), matrix.select("text")) // 3、将原数据与分词结果关联起来 val rows = matrix.rdd.zip(words.rdd).map{ case (rowLeft, rowRight) => Row.fromSeq(rowLeft.toSeq ++ rowRight.toSeq) } val schema = StructType(matrix.schema.fields ++ words.schema.fields) val matrixMerge = sparkSession.createDataFrame(rows, schema) // 4、构建特征向量 val featuredModelTrained = sparkSession.sparkContext.broadcast(PipelineModel.read.load(args(5))) val dataPrepared = featuredModelTrained.value.transform(matrixMerge).repartition(18).cache() // 5、加载分类模型,产出故障预测结果 val xgbModelTrained = sparkSession.sparkContext.broadcast(PipelineModel.read.load(args(6))) val prediction = xgbModelTrained.value.transform(dataPrepared) // 6、将预测结果写到HDFS prediction.select("text", "predictedLabel", "probabilities").rdd.coalesce(1).saveAsTextFile(args(7)) sparkSession.stop() } }