org.apache.spark.sql.SparkSession Scala Examples

The following examples show how to use org.apache.spark.sql.SparkSession in Scala. Each example is taken from an open-source project; the source file, project, and license are listed above its code.
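
Before the project examples, the snippet below is a minimal, hypothetical sketch of the pattern they all share: build (or reuse) a SparkSession with its builder, use it to create and query a DataFrame, and stop it when finished. The application name, master URL, and in-memory data are placeholders, not taken from any of the listed projects.

import org.apache.spark.sql.SparkSession

object MinimalSparkSessionExample {
  def main(args: Array[String]): Unit = {
    // getOrCreate() returns the active session if one already exists, otherwise builds a new one.
    val spark = SparkSession.builder
      .appName("MinimalSparkSessionExample") // placeholder application name
      .master("local[*]")                    // placeholder master; cluster jobs usually set this via spark-submit
      .getOrCreate()

    import spark.implicits._

    // A tiny in-memory DataFrame, just to exercise the session.
    val df = Seq(("a", 1), ("b", 2)).toDF("key", "value")
    df.show()

    spark.stop()
  }
}
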
Example 1
Source File: TextFileFormat.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.text

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark.TaskContext
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration


// Excerpt: only this helper from the original file is shown here; the enclosing
// TextFileFormat / TextOutputWriter definitions are omitted. The wrapper object
// below is a placeholder added so the excerpt is self-contained.
object TextFileFormatExcerpt {

  def getCompressionExtension(context: TaskAttemptContext): String = {
    // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile
    if (FileOutputFormat.getCompressOutput(context)) {
      val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
      ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
    } else {
      ""
    }
  }
} 
Example 2
Source File: StreamingConsumer.scala    From Scala-Programming-Projects   with MIT License
package coinyser

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._

object StreamingConsumer {
  def fromJson(df: DataFrame): Dataset[Transaction] = {
    import df.sparkSession.implicits._
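    // Derive the expected schema from the Transaction case class via its implicit encoder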
    val schema = Seq.empty[Transaction].toDS().schema
    df.select(from_json(col("value").cast("string"), schema).alias("v"))
      .select("v.*").as[Transaction]
  }

  def transactionStream(implicit spark: SparkSession, config: KafkaConfig): Dataset[Transaction] =
    fromJson(spark.readStream.format("kafka")
      .option("kafka.bootstrap.servers", config.bootStrapServers)
      .option("startingoffsets", "earliest")
      .option("subscribe", config.transactionsTopic)
      .load()
    )

} 
Example 3
Source File: SparkNarrowTest.scala    From spark-tools   with Apache License 2.0
package io.univalence

import java.net.URLClassLoader
import java.sql.Date

import io.univalence.centrifuge.Sparknarrow
import org.apache.spark.SparkConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

case class Person(name: String, age: Int, date: Date)

class SparknarrowTest extends FunSuite {

  val conf: SparkConf = new SparkConf()
  conf.setAppName("yo")
  conf.set("spark.sql.caseSensitive", "true")
  conf.setMaster("local[2]")

  implicit val ss: SparkSession = SparkSession.builder.config(conf).getOrCreate
  import ss.implicits._

  test("testBasicCC") {

    val classDef = Sparknarrow.basicCC(Encoders.product[Person].schema).classDef
    checkDefinition(classDef)

  }

  def checkDefinition(scalaCode: String): Unit = {
    //TODO do a version for 2.11 and 2.12
    
  }

  test("play with scala eval") {

    val code =
      """
        case class Tata(str: String)
        case class Toto(age: Int, tata: Tata)
      """

    checkDefinition(code)
    checkDefinition(code)

  }

  ignore("printSchema StructType") {
    val yo = StructType(
      Seq(
        StructField("name", StringType),
        StructField("tel", ArrayType(StringType))
      )
    )

    yo.printTreeString()
  }

} 
Example 4
Source File: StreamingKafka10.scala    From BigData-News   with Apache License 2.0
package com.vita.spark

import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe

object StreamingKafka10 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName("streaming")
      .getOrCreate()


    val sc = spark.sparkContext
    val ssc = new StreamingContext(sc, Seconds(5))

    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> "node6:9092",
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "0001",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )

    val topics = Array("weblogs")
    val stream = KafkaUtils.createDirectStream[String, String](
      ssc,
      PreferConsistent,
      Subscribe[String, String](topics, kafkaParams)
    )

    val lines = stream.map(x => x.value())
    val words = lines.flatMap(_.split(" "))
    val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _)
    wordCounts.print()

    ssc.start()
    ssc.awaitTermination()
  }
} 
Example 5
Source File: DataFrameExample.scala    From drizzle-spark   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

import java.io.File

import scopt.OptionParser

import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.Utils


object DataFrameExample {

  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
    extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DataFrameExample") {
      head("DataFrameExample: an example app using DataFrame for ML.")
      opt[String]("input")
        .text(s"input path to dataframe")
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        success
      }
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case _ => sys.exit(1)
    }
  }

  def run(params: Params): Unit = {
    val spark = SparkSession
      .builder
      .appName(s"DataFrameExample with $params")
      .getOrCreate()

    // Load input data
    println(s"Loading LIBSVM file with UDT from ${params.input}.")
    val df: DataFrame = spark.read.format("libsvm").load(params.input).cache()
    println("Schema from LIBSVM:")
    df.printSchema()
    println(s"Loaded training data as a DataFrame with ${df.count()} records.")

    // Show statistical summary of labels.
    val labelSummary = df.describe("label")
    labelSummary.show()

    // Convert features column to an RDD of vectors.
    val features = df.select("features").rdd.map { case Row(v: Vector) => v }
    val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
      (summary, feat) => summary.add(Vectors.fromML(feat)),
      (sum1, sum2) => sum1.merge(sum2))
    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

    // Save the records in a parquet file.
    val tmpDir = Utils.createTempDir()
    val outputDir = new File(tmpDir, "dataframe").toString
    println(s"Saving to $outputDir as Parquet file.")
    df.write.parquet(outputDir)

    // Load the records back.
    println(s"Loading Parquet file with UDT from $outputDir.")
    val newDF = spark.read.parquet(outputDir)
    println(s"Schema from Parquet:")
    newDF.printSchema()

    spark.stop()
  }
}
// scalastyle:on println 
Example 6
Source File: TokenizerExample.scala    From drizzle-spark   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer}
import org.apache.spark.sql.functions._
// $example off$
import org.apache.spark.sql.SparkSession

object TokenizerExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("TokenizerExample")
      .getOrCreate()

    // $example on$
    val sentenceDataFrame = spark.createDataFrame(Seq(
      (0, "Hi I heard about Spark"),
      (1, "I wish Java could use case classes"),
      (2, "Logistic,regression,models,are,neat")
    )).toDF("id", "sentence")

    val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
    val regexTokenizer = new RegexTokenizer()
      .setInputCol("sentence")
      .setOutputCol("words")
      .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false)

    val countTokens = udf { (words: Seq[String]) => words.length }

    val tokenized = tokenizer.transform(sentenceDataFrame)
    tokenized.select("sentence", "words")
        .withColumn("tokens", countTokens(col("words"))).show(false)

    val regexTokenized = regexTokenizer.transform(sentenceDataFrame)
    regexTokenized.select("sentence", "words")
        .withColumn("tokens", countTokens(col("words"))).show(false)
    // $example off$

    spark.stop()
  }
}
// scalastyle:on println 
Example 7
Source File: SnowflakeConnectorUtils.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake

import java.nio.file.Paths
import java.security.InvalidKeyException

import net.snowflake.spark.snowflake.pushdowns.SnowflakeStrategy
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}


// Excerpt: the enclosing object declaration is omitted in the original listing, and it
// references a logger and an enablePushdownSession method defined in the full source file.
// The declarations below are minimal placeholders added so the excerpt is self-contained.
object SnowflakeConnectorUtils {

  private val log: Logger = LoggerFactory.getLogger(getClass.getName)

  def enablePushdownSession(session: SparkSession): Unit = {
    // In the full source this registers SnowflakeStrategy in session.experimental.extraStrategies
    // when it is not already present; the body is elided in this excerpt.
  }

  def disablePushdownSession(session: SparkSession): Unit = {
    session.experimental.extraStrategies = session.experimental.extraStrategies
      .filterNot(strategy => strategy.isInstanceOf[SnowflakeStrategy])
  }

  def setPushdownSession(session: SparkSession, enabled: Boolean): Unit = {
    if (enabled) {
      enablePushdownSession(session)
    } else {
      disablePushdownSession(session)
    }
  }

  // TODO: Improve error handling with retries, etc.

  @throws[SnowflakeConnectorException]
  def handleS3Exception(ex: Exception): Unit = {
    if (ex.getCause.isInstanceOf[InvalidKeyException]) {
      // Most likely cause: Unlimited strength policy files not installed
      var msg: String = "Strong encryption with Java JRE requires JCE " +
        "Unlimited Strength Jurisdiction Policy " +
        "files. " +
        "Follow JDBC client installation instructions " +
        "provided by Snowflake or contact Snowflake " +
        "Support. This needs to be installed in the Java runtime for all Spark executor nodes."

      log.error(
        "JCE Unlimited Strength policy files missing: {}. {}.",
        ex.getMessage: Any,
        ex.getCause.getMessage: Any
      )

      val bootLib: String =
        java.lang.System.getProperty("sun.boot.library.path")

      if (bootLib != null) {
        msg += " The target directory on your system is: " + Paths
          .get(bootLib, "security")
          .toString
        log.error(msg)
      }

      throw new SnowflakeConnectorException(msg)
    } else {
      throw ex
    }
  }
}

class SnowflakeConnectorException(message: String) extends Exception(message)
class SnowflakePushdownException(message: String)
  extends SnowflakeConnectorException(message)
class SnowflakeConnectorFeatureNotSupportException(message: String)
  extends Exception(message)

class SnowflakePushdownUnsupportedException(message: String,
                                            val unsupportedOperation: String,
                                            val details: String,
                                            val isKnownUnsupportedOperation: Boolean)
  extends Exception(message) 
Example 8
Source File: LinearRegressionWithElasticNetExample.scala    From drizzle-spark   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.regression.LinearRegression
// $example off$
import org.apache.spark.sql.SparkSession

object LinearRegressionWithElasticNetExample {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("LinearRegressionWithElasticNetExample")
      .getOrCreate()

    // $example on$
    // Load training data
    val training = spark.read.format("libsvm")
      .load("data/mllib/sample_linear_regression_data.txt")

    val lr = new LinearRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)

    // Fit the model
    val lrModel = lr.fit(training)

    // Print the coefficients and intercept for linear regression
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

    // Summarize the model over the training set and print out some metrics
    val trainingSummary = lrModel.summary
    println(s"numIterations: ${trainingSummary.totalIterations}")
    println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")
    trainingSummary.residuals.show()
    println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
    println(s"r2: ${trainingSummary.r2}")
    // $example off$

    spark.stop()
  }
}
// scalastyle:on println 
Example 9
Source File: SqlNetworkWordCount.scala    From drizzle-spark   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.streaming

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext, Time}


object SparkSessionSingleton {
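  // Holds a lazily created, process-wide SparkSession that is reused across calls.
  // This is a common pattern for obtaining a session inside streaming foreachRDD blocks.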

  @transient  private var instance: SparkSession = _

  def getInstance(sparkConf: SparkConf): SparkSession = {
    if (instance == null) {
      instance = SparkSession
        .builder
        .config(sparkConf)
        .getOrCreate()
    }
    instance
  }
}
// scalastyle:on println 
Example 10
Source File: DNSstat.scala    From jdbcsink   with Apache License 2.0
import org.apache.spark.sql.SparkSession
import java.util.Properties
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.{from_json,window}
import java.sql.{Connection,Statement,DriverManager}
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.Row

class JDBCSink() extends ForeachWriter[Row] {
  val driver = "com.mysql.jdbc.Driver"
  var connection: Connection = _
  var statement: Statement = _

  def open(partitionId: Long, version: Long): Boolean = {
    Class.forName(driver)
    connection = DriverManager.getConnection("jdbc:mysql://10.88.1.102:3306/aptwebservice", "root", "mysqladmin")
    statement = connection.createStatement
    true
  }

  def process(value: Row): Unit = {
    statement.executeUpdate("replace into DNSStat(ip,domain,time,count) values("
      + "'" + value.getString(0) + "'" + "," // ip
      + "'" + value.getString(1) + "'" + "," // domain
      + "'" + value.getTimestamp(2) + "'" + "," // time
      + value.getLong(3) // count
      + ")")
  }

  def close(errorOrNull: Throwable): Unit = {
    connection.close
  }
}

object DNSstatJob{

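// Schema of the JSON records consumed from the Kafka "xdr" topic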
val schema: StructType = StructType(
        Seq(StructField("Vendor", StringType,true),
         StructField("Id", IntegerType,true),
         StructField("Time", LongType,true),
         StructField("Conn", StructType(Seq(
                                        StructField("Proto", IntegerType, true), 
                                        StructField("Sport", IntegerType, true), 
                                        StructField("Dport", IntegerType, true), 
                                        StructField("Sip", StringType, true), 
                                        StructField("Dip", StringType, true)
                                        )), true),
        StructField("Dns", StructType(Seq(
                                        StructField("Domain", StringType, true), 
                                        StructField("IpCount", IntegerType, true), 
                                        StructField("Ip", StringType, true) 
                                        )), true)))

    def main(args: Array[String]) {
    val spark = SparkSession
      .builder
      .appName("DNSJob")
      .config("spark.some.config.option", "some-value")
      .getOrCreate()
    import spark.implicits._
    val connectionProperties = new Properties()
    connectionProperties.put("user", "root")
    connectionProperties.put("password", "mysqladmin")
    val bruteForceTab = spark.read
                .jdbc("jdbc:mysql://10.88.1.102:3306/aptwebservice", "DNSTab",connectionProperties)
    bruteForceTab.registerTempTable("DNSTab")
    val lines = spark
          .readStream
          .format("kafka")
          .option("kafka.bootstrap.servers", "10.94.1.110:9092")
          .option("subscribe","xdr")
          //.option("startingOffsets","earliest")
          .option("startingOffsets","latest")
          .load()
          .select(from_json($"value".cast(StringType),schema).as("jsonData"))
    lines.registerTempTable("xdr")
    val filterDNS = spark.sql("select CAST(from_unixtime(xdr.jsonData.Time DIV 1000000) as timestamp) as time,xdr.jsonData.Conn.Sip as sip, xdr.jsonData.Dns.Domain from xdr inner join DNSTab on xdr.jsonData.Dns.domain = DNSTab.domain")
    
    val windowedCounts = filterDNS
                        .withWatermark("time","5 minutes")
                        .groupBy(window($"time", "1 minutes", "1 minutes"),$"sip",$"domain")
                        .count()
                        .select($"sip",$"domain",$"window.start",$"count")

    val writer = new JDBCSink()
    val query = windowedCounts
      .writeStream
      .foreach(writer)
      .outputMode("update")
      .option("checkpointLocation", "/checkpoint/")
      .start()

    query.awaitTermination()
    }
} 
Example 11
Source File: SparkPFASuiteBase.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.pfa

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.SparkConf
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.scalactic.Equality
import org.scalatest.FunSuite

abstract class SparkPFASuiteBase extends FunSuite with DataFrameSuiteBase with PFATestUtils {

  val sparkTransformer: Transformer
  val input: Array[String]
  val expectedOutput: Array[String]

  val sparkConf =  new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID).
    set("spark.driver.host", "localhost")
  override lazy val spark = SparkSession.builder().config(sparkConf).getOrCreate()
  override val reuseContextIfPossible = true

  // Converts column containing a vector to an array
  def withColumnAsArray(df: DataFrame, colName: String) = {
    val vecToArray = udf { v: Vector => v.toArray }
    df.withColumn(colName, vecToArray(df(colName)))
  }

  def withColumnAsArray(df: DataFrame, first: String, others: String*) = {
    val vecToArray = udf { v: Vector => v.toArray }
    var result = df.withColumn(first, vecToArray(df(first)))
    others.foreach(c => result = result.withColumn(c, vecToArray(df(c))))
    result
  }

  // Converts column containing a vector to a sparse vector represented as a map
  def getColumnAsSparseVectorMap(df: DataFrame, colName: String) = {
    val vecToMap = udf { v: Vector => v.toSparse.indices.map(i => (i.toString, v(i))).toMap }
    df.withColumn(colName, vecToMap(df(colName)))
  }

}

abstract class Result

object ApproxEquality extends ApproxEquality

trait ApproxEquality {

  import org.scalactic.Tolerance._
  import org.scalactic.TripleEquals._

  implicit val seqApproxEq: Equality[Seq[Double]] = new Equality[Seq[Double]] {
    override def areEqual(a: Seq[Double], b: Any): Boolean = {
      b match {
        case d: Seq[Double] =>
          a.zip(d).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }

  implicit val vectorApproxEq: Equality[Vector] = new Equality[Vector] {
    override def areEqual(a: Vector, b: Any): Boolean = {
      b match {
        case v: Vector =>
          a.toArray.zip(v.toArray).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }
} 
Example 12
Source File: MNISTBenchmark.scala    From spark-knn   with Apache License 2.0
package com.github.saurfang.spark.ml.knn.examples

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.classification.{KNNClassifier, NaiveKNNClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.tuning.{Benchmarker, ParamGridBuilder}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.log4j

import scala.collection.mutable


object MNISTBenchmark {

  val logger = log4j.Logger.getLogger(getClass)

  def main(args: Array[String]) {
    val ns = if(args.isEmpty) (2500 to 10000 by 2500).toArray else args(0).split(',').map(_.toInt)
    val path = if(args.length >= 2) args(1) else "data/mnist/mnist.bz2"
    val numPartitions = if(args.length >= 3) args(2).toInt else 10
    val models = if(args.length >=4) args(3).split(',') else Array("tree","naive")

    val spark = SparkSession.builder().getOrCreate()
    val sc = spark.sparkContext
    import spark.implicits._

    //read in raw label and features
    val rawDataset = MLUtils.loadLibSVMFile(sc, path)
      .zipWithIndex()
      .filter(_._2 < ns.max)
      .sortBy(_._2, numPartitions = numPartitions)
      .keys
      .toDF()

    // convert "features" from mllib.linalg.Vector to ml.linalg.Vector
    val dataset =  MLUtils.convertVectorColumnsToML(rawDataset)
      .cache()
    dataset.count() //force persist

    val limiter = new Limiter()
    val knn = new KNNClassifier()
      .setTopTreeSize(numPartitions * 10)
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
      .setK(1)
    val naiveKNN = new NaiveKNNClassifier()

    val pipeline = new Pipeline()
      .setStages(Array(limiter, knn))
    val naivePipeline = new Pipeline()
      .setStages(Array(limiter, naiveKNN))

    val paramGrid = new ParamGridBuilder()
      .addGrid(limiter.n, ns)
      .build()

    val bm = new Benchmarker()
      .setEvaluator(new MulticlassClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumTimes(3)

    val metrics = mutable.ArrayBuffer[String]()
    if(models.contains("tree")) {
      val bmModel = bm.setEstimator(pipeline).fit(dataset)
      metrics += s"knn: ${bmModel.avgTrainingRuntimes.toSeq} / ${bmModel.avgEvaluationRuntimes.toSeq}"
    }
    if(models.contains("naive")) {
      val naiveBMModel = bm.setEstimator(naivePipeline).fit(dataset)
      metrics += s"naive: ${naiveBMModel.avgTrainingRuntimes.toSeq} / ${naiveBMModel.avgEvaluationRuntimes.toSeq}"
    }
    logger.info(metrics.mkString("\n"))
  }
}

class Limiter(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("limiter"))

  val n: IntParam = new IntParam(this, "n", "number of rows to limit")

  def setN(value: Int): this.type = set(n, value)

  // hack to maintain number of partitions (otherwise it collapses to 1 which is unfair for naiveKNN)
  override def transform(dataset: Dataset[_]): DataFrame = dataset.limit($(n)).repartition(dataset.rdd.partitions.length).toDF()

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = schema
} 
Example 13
Source File: SparkIntroduction.scala    From reactive-machine-learning-systems   with MIT License
package com.reactivemachinelearning

import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.mllib.linalg.Vectors

object SparkIntroduction {

  def main(args: Array[String]) {
    // handle args

    // setup
    val session = SparkSession.builder.appName("Simple ModelExample").getOrCreate()
    import session.implicits._

    // Load and parse the train and test data
    val inputBasePath = "example_data"
    val outputBasePath = "."
    val trainingDataPath = inputBasePath + "/training.txt"
    val testingDataPath = inputBasePath + "/testing.txt"
    val currentOutputPath = outputBasePath + System.currentTimeMillis()

    val trainingData = session.read.textFile(trainingDataPath)
    val trainingParsed = trainingData.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()

    val testingData = session.read.textFile(testingDataPath)
    val testingParsed = testingData.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()

    // Building the model
    val numIterations = 100
    val model = LinearRegressionWithSGD.train(trainingParsed.rdd, numIterations)

    // Evaluate model on testing examples
    val predictionsAndLabels = testingParsed.map { case LabeledPoint(label, features) =>
      val prediction = model.predict(features)
      (prediction, label)
    }

    // Report performance statistics
    val metrics = new MulticlassMetrics(predictionsAndLabels.rdd)
    val precision = metrics.precision
    val recall = metrics.recall
    println(s"Precision: $precision Recall: $recall")

    // Save model
    model.save(session.sparkContext, currentOutputPath)
  }

} 
Example 14
Source File: CSVTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.bundle.csv.CsvParser
import cn.piflow.bundle.json.JsonSave
import cn.piflow.{FlowImpl, Path, Runner}
import org.apache.spark.sql.SparkSession
import org.junit.Test

class CSVTest {

  @Test
  def testCSVHeaderRead(): Unit ={

    val csvParserParameters  = Map(
      "csvPath" -> "hdfs://10.0.86.89:9000/xjzhu/student.csv",
      "header" -> "true",
      "delimiter" -> ",",
      "schema" -> "")
    val jsonSaveParameters = Map(
      "jsonPath" -> "hdfs://10.0.86.89:9000/xjzhu/student_csv2json")

    val csvParserStop = new CsvParser
    csvParserStop.setProperties(csvParserParameters)

    val jsonPathStop =new JsonSave
    jsonPathStop.setProperties(jsonSaveParameters)

    val flow = new FlowImpl();

    flow.addStop("CSVParser", csvParserStop);
    flow.addStop("JsonSave", jsonPathStop);
    flow.addPath(Path.from("CSVParser").to("JsonSave"));

    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/opt/project/piflow-jar-bundle/out/artifacts/piflow-jar-bundle/piflow-jar-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .start(flow);

    process.awaitTermination();
    spark.close();
  }

  @Test
  def testCSVSchemaRead(): Unit ={

    val csvParserParameters : Map[String, String] = Map(
      "csvPath" -> "hdfs://10.0.86.89:9000/xjzhu/student_schema.csv",
      "header" -> "false",
      "delimiter" -> ",",
      "schema" -> "id,name,gender,age"
    )
    val jsonSaveParameters = Map(
      "jsonPath" -> "hdfs://10.0.86.89:9000/xjzhu/student_schema_csv2json")


    val csvParserStop = new CsvParser
    csvParserStop.setProperties(csvParserParameters)

    val jsonSaveStop = new JsonSave
    jsonSaveStop.setProperties(jsonSaveParameters)

    val flow = new FlowImpl();

    flow.addStop("CSVParser", csvParserStop);
    flow.addStop("JsonSave", jsonSaveStop);
    flow.addPath(Path.from("CSVParser").to("JsonSave"));

    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/opt/project/piflow-jar-bundle/out/artifacts/piflow-jar-bundle/piflow-jar-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .start(flow);

    process.awaitTermination();
    spark.close();
  }

} 
Example 15
Source File: UrlTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.junit.Test

import scala.util.parsing.json.JSON

class UrlTest {
  @Test
  def testGetHttp(): Unit = {

    // parse flow json
    val file = "src/main/resources/url.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    //    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("DblpParserTest")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "3")
      .config("spark.jars", "/opt/work/111/piflow-master/out/artifacts/piflow_bundle/piflow-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .start(flow);

    process.awaitTermination();
    spark.close();
  }

} 
Example 16
Source File: CsvStringTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.junit.Test

import scala.util.parsing.json.JSON

class CsvStringTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/CsvStringTest.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    //execute flow
    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/opt/project/gitwork/out/artifacts/piflow_bundle/piflow_bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
  @Test
  def testFlow2json() = {

    //parse flow json
    val file = "src/main/resources/flow.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]

    //create flow
    val flowBean = FlowBean(map)
    val flowJson = flowBean.toJson()
    println(flowJson)
  }

} 
Example 17
Source File: ShellFlowTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.junit.Test

import scala.util.parsing.json.JSON

class ShellFlowTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/shellflow.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    //execute flow
    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/opt/project/piflow/out/artifacts/piflow_bundle/piflow-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .start(flow);

    process.awaitTermination();
    spark.close();
  }
} 
Example 18
Source File: FlattenXmlParserTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.junit.Test

import scala.util.parsing.json.JSON

class FlattenXmlParserTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/FlattenXmlParser.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    //execute flow
    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/root/Desktop/gitWORK/out/artifacts/piflow_bundle/piflow_bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
  @Test
  def testFlow2json() = {

    //parse flow json
    val file = "src/main/resources/flow.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]

    //create flow
    val flowBean = FlowBean(map)
    val flowJson = flowBean.toJson()
    println(flowJson)
  }

} 
Example 19
Source File: HdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class HdfsTest {
  @Test
  def testHdfs(): Unit = {

    // parse flow json
    val file = "src/main/resources/hdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    //    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()

    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("DblpParserTest")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "3")
      .config("spark.jars", "/opt/work/111/piflow-master/out/artifacts/piflow_bundle/piflow-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .start(flow);

    process.awaitTermination();
    spark.close();
  }

} 
Example 20
Source File: IncrementTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class IncrementTest {

  @Test
  def testIncrmentMysql(): Unit ={

    //parse flow json
    val file = "src/main/resources/increment/mysql.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("mysql_increment")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/opt/project/piflow/out/artifacts/piflow_bundle/piflow-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .bind("debug.path","hdfs://10.0.86.89:9000/xjzhu/piflow/debug/")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }

} 
Example 21
Source File: FtpNewTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.junit.Test

import scala.util.parsing.json.JSON

class FtpNewTest {
  @Test
  def ftpNew(): Unit = {

    // parse flow json
    val file = "src/main/resources/ftpNew.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    //    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("DblpParserTest")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "3")
      .config("spark.jars", "/opt/work/111/piflow-master/out/artifacts/piflow_bundle/piflow-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .start(flow);

    process.awaitTermination();
    spark.close();
  }

} 
Example 22
Source File: JsonFolderTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.test

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.junit.Test

import scala.util.parsing.json.JSON

class JsonFolderTest {

  @Test
  def testFlow(): Unit ={

    // test data

    //parse flow json
    val file = "src/main/resources/JsonFolderTest.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    //execute flow
    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/opt/project/gitwork/out/artifacts/piflow_bundle/piflow_bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
  @Test
  def testFlow2json() = {

    //parse flow json
    val file = "src/main/resources/flow.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]

    //create flow
    val flowBean = FlowBean(map)
    val flowJson = flowBean.toJson()
    println(flowJson)
  }

} 
Example 23
Source File: getMongoDBTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.mongodb

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class getMongoDBTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/mongoDB/getMongoDB.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp","-tcpAllowOthers","-tcpPort","50001").start()


    //execute flow
    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/root/Desktop/gitWORK/out/artifacts/piflow_bundle/piflow-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
  @Test
  def testFlow2json() = {

    //parse flow json
    val file = "src/main/resources/flow.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]

    //create flow
    val flowBean = FlowBean(map)
    val flowJson = flowBean.toJson()
    println(flowJson)
  }

} 
Example 24
Source File: putMongoDBTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.mongodb

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class putMongoDBTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/mongoDB/putMongoDB.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp","-tcpAllowOthers","-tcpPort","50001").start()


    //execute flow
    val spark = SparkSession.builder()
      .master("spark://10.0.86.89:7077")
      .appName("piflow-hive-bundle")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("spark.jars","/root/Desktop/gitWORK/out/artifacts/piflow_bundle/piflow-bundle.jar")
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
  @Test
  def testFlow2json() = {

    //parse flow json
    val file = "src/main/resources/flow.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]

    //create flow
    val flowBean = FlowBean(map)
    val flowJson = flowBean.toJson()
    println(flowJson)
  }

} 
Example 25
Source File: ProvinceCleanTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.clean

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.PropertyUtil
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class ProvinceCleanTest {

  @Test
  def ProvinceCleanFlow(): Unit = {

    //parse flow json
    val file = "src/main/resources/flow/clean/ProvinceClean.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort", "50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("ProvinceCleanTest")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
} 
Example 26
Source File: TitleCleanTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.clean

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.PropertyUtil
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class TitleCleanTest {

  @Test
  def TitleCleanFlow(): Unit = {

    //parse flow json
    val file = "src/main/resources/flow/clean/TitleClean.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort", "50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("TitleCleanTest")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
} 
Example 27
Source File: EmailCleanTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.clean

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.PropertyUtil
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class EmailCleanTest {

  @Test
  def EmailCleanFlow(): Unit = {

    //parse flow json
    val file = "src/main/resources/flow/clean/EmailClean.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort", "50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("EmailCleanTest")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
} 
Example 28
Source File: PhoneNumberCleanTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.clean

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.PropertyUtil
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class PhoneNumberCleanTest {

  @Test
  def PhoneNumberCleanFlow(): Unit = {

    //parse flow json
    val file = "src/main/resources/flow/clean/PhoneNumberClean.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort", "50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("PhoneNumberCleanTest")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
} 
Example 29
Source File: IdentityNumberCleanTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.clean

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.PropertyUtil
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class IdentityNumberCleanTest {

  @Test
  def IdentityNumberCleanFlow(): Unit = {

    //parse flow json
    val file = "src/main/resources/flow/clean/IdentityNumberClean.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort", "50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("IdentityNumberCleanTest")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }
} 
Example 30
Source File: DeleteHdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class DeleteHdfsTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/deleteHdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 31
Source File: GetHdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class GetHdfsTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/getHdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 32
Source File: SelectFilesByNameTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class SelectFilesByNameTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/selectFileByName.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 33
Source File: FileDownhdfsHdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class FileDownhdfsHdfsTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/fileDownHdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 34
Source File: UnzipFilesonHdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class UnzipFilesonHdfsTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/unzipFilesOnHdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 35
Source File: SaveToHdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class SaveToHdfsTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/saveToHdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 36
Source File: ListHdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class ListHdfsTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/listHdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 37
Source File: PutHdfsTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.hdfs

import java.net.InetAddress

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.{PropertyUtil, ServerIpUtil}
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class PutHdfsTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/hdfs/putHdfs.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()


    val ip = InetAddress.getLocalHost.getHostAddress
    cn.piflow.util.FileUtil.writeFile("server.ip=" + ip, ServerIpUtil.getServerIpFile())
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()
    //execute flow
    val spark = SparkSession.builder()
      .master("local[12]")
      .appName("hive")
      .config("spark.driver.memory", "4g")
      .config("spark.executor.memory", "8g")
      .config("spark.cores.max", "8")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }


} 
Example 38
Source File: CsvParserTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.csv

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.PropertyUtil
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class CsvParserTest {

  @Test
  def testFlow(): Unit ={

    //parse flow json
    val file = "src/main/resources/flow/csv/CsvParser.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort", "50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("CsvParserTest")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("hive.metastore.uris",PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }

} 
Example 39
Source File: CsvStringParserTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.csv

import cn.piflow.Runner
import cn.piflow.conf.bean.FlowBean
import cn.piflow.conf.util.{FileUtil, OptionUtil}
import cn.piflow.util.PropertyUtil
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test

import scala.util.parsing.json.JSON

class CsvStringParserTest {

  @Test
  def testFlow(): Unit ={
    //parse flow json
    val file = "src/main/resources/flow/csv/CsvStringParser.json"
    val flowJsonStr = FileUtil.fileReader(file)
    val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
    println(map)

    //create flow
    val flowBean = FlowBean(map)
    val flow = flowBean.constructFlow()
    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort", "50001").start()

    //execute flow
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("CsvStringParserTest")
      .config("spark.driver.memory", "1g")
      .config("spark.executor.memory", "2g")
      .config("spark.cores.max", "2")
      .config("hive.metastore.uris", PropertyUtil.getPropertyValue("hive.metastore.uris"))
      .enableHiveSupport()
      .getOrCreate()

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "")
      .bind("debug.path","")
      .start(flow);

    process.awaitTermination();
    val pid = process.pid();
    println(pid + "!!!!!!!!!!!!!!!!!!!!!")
    spark.close();
  }

} 
Example 40
Source File: GroupTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
import cn.piflow._
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test


class GroupTest {
  @Test
  def testProject() {
    val flow1 = new FlowImpl();

    flow1.addStop("CleanHouse", new CleanHouse());
    flow1.addStop("CopyTextFile", new CopyTextFile());
    flow1.addStop("CountWords", new CountWords());
    flow1.addPath(Path.of("CleanHouse" -> "CopyTextFile" -> "CountWords"));

    val flow2 = new FlowImpl();
    flow2.addStop("PrintCount", new PrintCount());

    val fg = new GroupImpl();
    fg.addGroupEntry("flow1", flow1);
    fg.addGroupEntry("flow2", flow2, Condition.after("flow1"));


    val flow3 = new FlowImpl();
    flow3.addStop("TestStop", new TestStop());

    val flow4 = new FlowImpl();
    flow4.addStop("TestStop", new TestStop());

    val project = new GroupImpl();

    project.addGroupEntry("flow3",flow3)
    project.addGroupEntry("flowGroup",fg,Condition.after("flow3"))
    project.addGroupEntry("flow4",flow4, Condition.after("flowGroup"))

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()

    val spark = SparkSession.builder.master("local[4]")
      .getOrCreate();

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .bind("debug.path","hdfs://10.0.86.89:9000/xjzhu/piflow/debug/")
      .start(project);

    process.awaitTermination();
    spark.close();
  }
} 
Example 41
Source File: FlowGroupTest.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
import cn.piflow._
import org.apache.spark.sql.SparkSession
import org.h2.tools.Server
import org.junit.Test


class FlowGroupTest {
  @Test
  def testProcessGroup() {
    val flow1 = new FlowImpl();

    flow1.addStop("CleanHouse", new CleanHouse());
    flow1.addStop("CopyTextFile", new CopyTextFile());
    flow1.addStop("CountWords", new CountWords());
    flow1.addPath(Path.of("CleanHouse" -> "CopyTextFile" -> "CountWords"));

    val flow2 = new FlowImpl();

    flow2.addStop("PrintCount", new PrintCount());

    val fg = new GroupImpl();
    fg.addGroupEntry("flow1", flow1);
    fg.addGroupEntry("flow2", flow2, Condition.after("flow1"));

    val h2Server = Server.createTcpServer("-tcp", "-tcpAllowOthers", "-tcpPort","50001").start()

    val spark = SparkSession.builder.master("local[4]")
      .getOrCreate();

    val process = Runner.create()
      .bind(classOf[SparkSession].getName, spark)
      .bind("checkpoint.path", "hdfs://10.0.86.89:9000/xjzhu/piflow/checkpoints/")
      .bind("debug.path","hdfs://10.0.86.89:9000/xjzhu/piflow/debug/")
      .start(fg);

    process.awaitTermination();
    spark.close();
  }
} 
Example 42
Source File: TestBroadCast.scala    From asyspark   with MIT License 5 votes vote down vote up
package org.apache.spark.examples

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

import scala.collection.mutable


object TestBroadCast extends Logging{
  val sparkSession = SparkSession.builder().appName("test BoradCast").getOrCreate()
  val sc = sparkSession.sparkContext
  def main(args: Array[String]): Unit = {

    //    val data = sc.parallelize(Seq(1 until 10000000))
    val num = args(args.length - 2).toInt
    val times = args(args.length -1).toInt
    println(num)
    val start = System.nanoTime()
    val seq = Seq(1 until num) // note: a single-element Seq containing one Range, not num individual values
    for(i <- 0 until times) {
      val start2 = System.nanoTime()
      val bc = sc.broadcast(seq)
      val rdd = sc.parallelize(1 until 10, 5)
      rdd.map(_ => bc.value.take(1)).collect()
      println((System.nanoTime() - start2)/ 1e6 + "ms")
    }
    logInfo((System.nanoTime() - start) / 1e6 + "ms")
  }

  def testMap(): Unit ={

    val smallRDD = sc.parallelize(Seq(1,2,3))
    val bigRDD = sc.parallelize(Seq(1 until 20))

    // note: smallRDD is referenced inside a transformation of bigRDD; Spark does not
    // support nested RDD operations, so this fails at runtime. Collect or broadcast
    // smallRDD instead (see the sketch after this example).
    bigRDD.mapPartitions {
      partition =>
        val hashMap = new mutable.HashMap[Int,Int]()
        for(ele <- smallRDD) {
          hashMap(ele) = ele
        }
        // some operation here
        partition

    }
  }
} 
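
The testMap method above touches smallRDD from inside a transformation of bigRDD, which Spark rejects at runtime. A minimal sketch of the broadcast-based alternative, assuming the same sc as in the object above; the method name and the filter operation are illustrative.

// Collect the small RDD on the driver, broadcast it, and use the broadcast
// value inside the transformation of the big RDD.
def testMapWithBroadcast(): Unit = {
  val smallRDD = sc.parallelize(Seq(1, 2, 3))
  val bigRDD = sc.parallelize(1 until 20)

  val smallSet = sc.broadcast(smallRDD.collect().toSet)

  val filtered = bigRDD.mapPartitions { partition =>
    partition.filter(smallSet.value.contains)
  }
  filtered.collect().foreach(println)
}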
Example 43
Source File: PailDataSourceSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.spark.pail

import java.util

import com.backtype.hadoop.pail.{PailFormatFactory, PailSpec, PailStructure}
import com.backtype.support.{Utils => PailUtils}
import com.google.common.io.Files
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FlatSpec}
import org.scalatest.Matchers._

import scala.collection.JavaConverters._
import scala.util.Random

case class User(name: String, age: Int)

class UserPailStructure extends PailStructure[User] {
  override def isValidTarget(dirs: String*): Boolean = true

  override def getType: Class[_] = classOf[User]

  override def serialize(user: User): Array[Byte] = PailUtils.serialize(user)

  override def getTarget(user: User): util.List[String] = List(user.age % 10).map(_.toString).asJava

  override def deserialize(serialized: Array[Byte]): User = PailUtils.deserialize(serialized).asInstanceOf[User]
}

class PailDataSourceSpec extends FlatSpec with BeforeAndAfterAll with PailDataSource {
  private var spark: SparkSession = _

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    spark = SparkSession.builder().master("local[2]").appName("PailDataSource").getOrCreate()
  }

  val userPailSpec = new PailSpec(PailFormatFactory.SEQUENCE_FILE, new UserPailStructure)

  "PailBasedReaderWriter" should "read/write user records from/into pail" in {
    val output = Files.createTempDir()
    val users = (1 to 100).map { index => User(s"foo$index", Random.nextInt(40))}
    spark.sparkContext.parallelize(users)
      .saveAsPail(output.getAbsolutePath, userPailSpec)

    val input = output.getAbsolutePath
    val total = spark.sparkContext.pailFile[User](input)
      .map(u => u.name)
      .count()

    total should be(100)
    FileUtils.deleteDirectory(output)
  }
} 
Example 44
Source File: ParquetAvroDataSourceSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.spark.parquet

import java.io.File

import com.google.common.io.Files
import com.indix.utils.spark.parquet.avro.ParquetAvroDataSource
import org.apache.commons.io.FileUtils
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.spark.sql.SparkSession
import org.scalactic.Equality
import org.scalatest.Matchers.{be, convertToAnyShouldWrapper, equal}
import org.scalatest.{BeforeAndAfterAll, FlatSpec}
import java.util.{Arrays => JArrays}

case class SampleAvroRecord(a: Int, b: String, c: Seq[String], d: Boolean, e: Double, f: collection.Map[String, String], g: Array[Byte])

class ParquetAvroDataSourceSpec extends FlatSpec with BeforeAndAfterAll with ParquetAvroDataSource {
  private var spark: SparkSession = _
  implicit val sampleAvroRecordEq = new Equality[SampleAvroRecord] {
    override def areEqual(left: SampleAvroRecord, b: Any): Boolean = b match {
      case right: SampleAvroRecord =>
        left.a == right.a &&
          left.b == right.b &&
          Equality.default[Seq[String]].areEqual(left.c, right.c) &&
          left.d == right.d &&
          left.e == right.e &&
          Equality.default[collection.Map[String, String]].areEqual(left.f, right.f) &&
          JArrays.equals(left.g, right.g)
      case _ => false
    }
  }

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    spark = SparkSession.builder().master("local[2]").appName("ParquetAvroDataSource").getOrCreate()
  }

  override protected def afterAll(): Unit = {
    try {
      spark.sparkContext.stop()
    } finally {
      super.afterAll()
    }
  }

  "AvroBasedParquetDataSource" should "read/write avro records as ParquetData" in {

    val outputLocation = Files.createTempDir().getAbsolutePath + "/output"

    val sampleRecords: Seq[SampleAvroRecord] = Seq(
      SampleAvroRecord(1, "1", List("a1"), true, 1.0d, Map("a1" -> "b1"), "1".getBytes),
      SampleAvroRecord(2, "2", List("a2"), false, 2.0d, Map("a2" -> "b2"), "2".getBytes),
      SampleAvroRecord(3, "3", List("a3"), true, 3.0d, Map("a3" -> "b3"), "3".getBytes),
      SampleAvroRecord(4, "4", List("a4"), true, 4.0d, Map("a4" -> "b4"), "4".getBytes),
      SampleAvroRecord(5, "5", List("a5"), false, 5.0d, Map("a5" -> "b5"), "5".getBytes)
    )

    val sampleDf = spark.createDataFrame(sampleRecords)

    sampleDf.rdd.saveAvroInParquet(outputLocation, sampleDf.schema, CompressionCodecName.GZIP)

    val sparkVal = spark

    import sparkVal.implicits._

    val records: Array[SampleAvroRecord] = spark.read.parquet(outputLocation).as[SampleAvroRecord].collect()

    records.length should be(5)
    // We use === to use the custom Equality defined above for comparing Array[Byte]
    // Ref - https://github.com/scalatest/scalatest/issues/491
    records.sortBy(_.a) === sampleRecords.sortBy(_.a)

    FileUtils.deleteDirectory(new File(outputLocation))
  }

} 
Example 45
Source File: WordCount.scala    From spark-solr   with Apache License 2.0 5 votes vote down vote up
package com.lucidworks.spark.example.query

import com.lucidworks.spark.SparkApp.RDDProcessor
import com.lucidworks.spark.rdd.{SelectSolrRDD, SolrRDD}
import com.lucidworks.spark.util.ConfigurationConstants._
import org.apache.commons.cli.{CommandLine, Option}
import org.apache.solr.common.SolrDocument
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.immutable.HashMap


class WordCount extends RDDProcessor{
  def getName: String = "word-count"

  def getOptions: Array[Option] = {
    Array(
    Option.builder()
          .argName("QUERY")
          .longOpt("query")
          .hasArg
          .required(false)
          .desc("URL encoded Solr query to send to Solr")
          .build()
    )
  }

  def run(conf: SparkConf, cli: CommandLine): Int = {
    val zkHost = cli.getOptionValue("zkHost", "localhost:9983")
    val collection = cli.getOptionValue("collection", "collection1")
    val queryStr = cli.getOptionValue("query", "*:*")

    val sc = SparkContext.getOrCreate(conf)
    val solrRDD: SelectSolrRDD = new SelectSolrRDD(zkHost, collection, sc)
    val rdd: RDD[SolrDocument]  = solrRDD.query(queryStr)

    val words: RDD[String] = rdd.map(doc => if (doc.containsKey("text_t")) doc.get("text_t").toString else "")
    val pWords: RDD[String] = words.flatMap(s => s.toLowerCase.replaceAll("[.,!?\n]", " ").trim().split(" "))

    val wordsCountPairs: RDD[(String, Int)] = pWords.map(s => (s, 1))
                                                    .reduceByKey((a,b) => a+b)
                                                    .map(item => item.swap)
                                                    .sortByKey(false)
                                                    .map(item => item.swap)

    wordsCountPairs.take(20).iterator.foreach(println)

    val sparkSession: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    // Now use schema information in Solr to build a queryable SchemaRDD

    // Pro Tip: SolrRDD will figure out the schema if you don't supply a list of field names in your query
    val options = HashMap[String, String](
      SOLR_ZK_HOST_PARAM -> zkHost,
      SOLR_COLLECTION_PARAM -> collection,
      SOLR_QUERY_PARAM -> queryStr
      )

    val df: DataFrame = sparkSession.read.format("solr").options(options).load()
    val numEchos = df.filter(df.col("type_s").equalTo("echo")).count()
    println("numEchos >> " + numEchos)

    sc.stop()
    0
  }
} 
Example 46
Source File: SolrStreamWriter.scala    From spark-solr   with Apache License 2.0 5 votes vote down vote up
package com.lucidworks.spark

import com.lucidworks.spark.util.{SolrQuerySupport, SolrSupport}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.OutputMode
import com.lucidworks.spark.util.ConfigurationConstants._
import org.apache.spark.sql.types.StructType

import scala.collection.mutable


class SolrStreamWriter(
    val sparkSession: SparkSession,
    parameters: Map[String, String],
    val partitionColumns: Seq[String],
    val outputMode: OutputMode)(
  implicit val solrConf : SolrConf = new SolrConf(parameters))
  extends Sink with LazyLogging {

  require(solrConf.getZkHost.isDefined, s"Parameter ${SOLR_ZK_HOST_PARAM} not defined")
  require(solrConf.getCollection.isDefined, s"Parameter ${SOLR_COLLECTION_PARAM} not defined")

  val collection : String = solrConf.getCollection.get
  val zkhost: String = solrConf.getZkHost.get

  lazy val solrVersion : String = SolrSupport.getSolrVersion(solrConf.getZkHost.get)
  lazy val uniqueKey: String = SolrQuerySupport.getUniqueKey(zkhost, collection.split(",")(0))

  lazy val dynamicSuffixes: Set[String] = SolrQuerySupport.getFieldTypes(
      Set.empty,
      SolrSupport.getSolrBaseUrl(zkhost),
      SolrSupport.getCachedCloudClient(zkhost),
      collection,
      skipDynamicExtensions = false)
    .keySet
    .filter(f => f.startsWith("*_") || f.endsWith("_*"))
    .map(f => if (f.startsWith("*_")) f.substring(1) else f.substring(0, f.length-1))

  @volatile private var latestBatchId: Long = -1L
  val acc: SparkSolrAccumulator = new SparkSolrAccumulator
  val accName = if (solrConf.getAccumulatorName.isDefined) solrConf.getAccumulatorName.get else "Records Written"
  sparkSession.sparkContext.register(acc, accName)
  SparkSolrAccumulatorContext.add(accName, acc.id)

  override def addBatch(batchId: Long, df: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
      logger.info(s"Skipping already processed batch $batchId")
    } else {
      val rows = df.collect()
      if (rows.nonEmpty) {
        val schema: StructType = df.schema
        val solrClient = SolrSupport.getCachedCloudClient(zkhost)

        // build up a list of updates to send to the Solr Schema API
        val fieldsToAddToSolr = SolrRelation.getFieldsToAdd(schema, solrConf, solrVersion, dynamicSuffixes)

        if (fieldsToAddToSolr.nonEmpty) {
          SolrRelation.addFieldsForInsert(fieldsToAddToSolr, collection, solrClient)
        }

        val solrDocs = rows.toStream.map(row => SolrRelation.convertRowToSolrInputDocument(row, solrConf, uniqueKey))
        acc.add(solrDocs.length.toLong)
        SolrSupport.sendBatchToSolrWithRetry(zkhost, solrClient, collection, solrDocs, solrConf.commitWithin)
        logger.info(s"Written ${solrDocs.length} documents to Solr collection $collection from batch $batchId")
        latestBatchId = batchId
      }
    }
  }
} 
Example 47
Source File: KuduController.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package controllers

import org.apache.kudu.spark.kudu._
import org.apache.spark.sql.{ DataFrame, SparkSession }
import org.slf4j.{ Logger, LoggerFactory }

import scala.util.{ Failure, Try }

class KuduController(sparkSession: SparkSession, master: String) {

  val alogger: Logger = LoggerFactory.getLogger(this.getClass)

  def readData(table: String): Try[DataFrame] =  Try{
    sparkSession
      .sqlContext
      .read
      .options(Map("kudu.master" -> master, "kudu.table" -> table)).kudu
  }.recoverWith {
    case ex =>
      alogger.error(s"Exception ${ex.getMessage}\n ${ex.getStackTrace.mkString("\n")} ")
      Failure(ex)
  }
} 
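
A brief, illustrative way to consume KuduController.readData, which returns a Try[DataFrame]. The sparkSession, master address and table name below are placeholders, not values from the original project.

import scala.util.{Failure, Success}

// Assumes a SparkSession named sparkSession is already in scope.
val controller = new KuduController(sparkSession, "kudu-master.example.com:7051")

controller.readData("impala::default.sample_table") match {
  case Success(df) => df.show(10)
  case Failure(ex) => println(s"Could not read Kudu table: ${ex.getMessage}")
}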
Example 48
Source File: PhysicalDatasetController.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package controllers

import cats.syntax.show.toShow
import com.typesafe.config.Config
import daf.dataset.{ DatasetParams, FileDatasetParams, KuduDatasetParams }
import daf.filesystem.fileFormatShow
import org.apache.spark.sql.{ DataFrame, SparkSession }
import org.apache.spark.SparkConf
import org.slf4j.{ Logger, LoggerFactory }

class PhysicalDatasetController(sparkSession: SparkSession,
                                kuduMaster: String,
                                defaultLimit: Option[Int] = None,
                                defaultChunkSize: Int = 0) {

  lazy val kuduController = new KuduController(sparkSession, kuduMaster)
  lazy val hdfsController = new HDFSController(sparkSession)

  val logger: Logger = LoggerFactory.getLogger(this.getClass)

  private def addLimit(dataframe: DataFrame, limit: Option[Int]) = (limit, defaultLimit) match {
    case (None, None)                 => dataframe
    case (None, Some(value))          => dataframe.limit { value }
    case (Some(value), None)          => dataframe.limit { value }
    case (Some(value), Some(default)) => dataframe.limit { math.min(value, default) }
  }

  def kudu(params: KuduDatasetParams, limit: Option[Int] = None) = {
    logger.debug { s"Reading data from kudu table [${params.table}]" }
    kuduController.readData(params.table).map { addLimit(_, limit) }
  }

  def hdfs(params: FileDatasetParams, limit: Option[Int] = None) = {
    logger.debug { s"Reading data from hdfs at path [${params.path}]" }
    hdfsController.readData(params.path, params.format.show, params.param("separator")).map { addLimit(_, limit) }
  }

  def get(params: DatasetParams, limit: Option[Int]= None) = params match {
    case kuduParams: KuduDatasetParams => kudu(kuduParams, limit)
    case hdfsParams: FileDatasetParams => hdfs(hdfsParams, limit)
  }

}

object PhysicalDatasetController {

  private def getOptionalString(path: String, underlying: Config) = {
    if (underlying.hasPath(path)) {
      Some(underlying.getString(path))
    } else {
      None
    }
  }

  private def getOptionalInt(path: String, underlying: Config) = {
    if (underlying.hasPath(path)) {
      Some(underlying.getInt(path))
    } else {
      None
    }
  }

  val logger: Logger = LoggerFactory.getLogger(this.getClass)

  def apply(configuration: Config): PhysicalDatasetController = {

    val sparkConfig = new SparkConf()
    sparkConfig.set("spark.driver.memory", configuration.getString("spark.driver.memory"))

    val sparkSession = SparkSession.builder().master("local").config(sparkConfig).getOrCreate()

    val kuduMaster = configuration.getString("kudu.master")

    val defaultLimit = if (configuration hasPath "daf.row_limit") Some {
      configuration.getInt("daf.row_limit")
    } else None

    System.setProperty("sun.security.krb5.debug", "true")

    new PhysicalDatasetController(sparkSession, kuduMaster, defaultLimit)
  }
} 
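
The companion apply above only reads a handful of configuration keys (spark.driver.memory, kudu.master and the optional daf.row_limit). A hedged sketch of wiring it up with Typesafe Config; the values are placeholders.

import com.typesafe.config.ConfigFactory

// Placeholder values; only the key names come from the apply method above.
val config = ConfigFactory.parseString(
  """
    |spark.driver.memory = "1g"
    |kudu.master = "kudu-master.example.com:7051"
    |daf.row_limit = 1000
  """.stripMargin)

val controller = PhysicalDatasetController(config)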
Example 49
Source File: HDFSController.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package controllers

import com.databricks.spark.avro._
import org.apache.spark.sql.{ DataFrame, SparkSession }
import org.slf4j.{Logger, LoggerFactory}

import scala.util.{Failure, Try}

class HDFSController(sparkSession: SparkSession) {

  val alogger: Logger = LoggerFactory.getLogger(this.getClass)

  def readData(path: String, format: String, separator: Option[String]): Try[DataFrame] =  format match {
    case "csv" => Try {
      val pathFixAle = path + "/" + path.split("/").last + ".csv"
      alogger.debug(s"questo e' il path $pathFixAle")
      separator match {
        case None => sparkSession.read.csv(pathFixAle)
        case Some(sep) => sparkSession.read.format("csv")
          .option("sep", sep)
          .option("inferSchema", "true")
          .option("header", "true")
          .load(pathFixAle)
      }
    }
    case "parquet" => Try { sparkSession.read.parquet(path) }
    case "avro"    => Try { sparkSession.read.avro(path) }
    case unknown   => Failure { new IllegalArgumentException(s"Unsupported format [$unknown]") }
  }
} 
Example 50
Source File: HDFSBase.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package daf.util

import better.files.{ File, _ }
import daf.util.DataFrameClasses.{ Address, Person }
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.hdfs.{ HdfsConfiguration, MiniDFSCluster }
import org.apache.hadoop.test.PathUtils
import org.apache.spark.sql.{ SaveMode, SparkSession }
import org.scalatest.{ BeforeAndAfterAll, FlatSpec, Matchers }
import org.slf4j.LoggerFactory

import scala.util.{ Failure, Random, Try }

abstract class HDFSBase extends FlatSpec with Matchers with BeforeAndAfterAll {

  var miniCluster: Try[MiniDFSCluster] = Failure[MiniDFSCluster](new Exception)

  var fileSystem: Try[FileSystem] = Failure[FileSystem](new Exception)

  val sparkSession: SparkSession = SparkSession.builder().master("local").getOrCreate()

  val alogger = LoggerFactory.getLogger(this.getClass)

  val (testDataPath, confPath) = {
    val testDataPath = s"${PathUtils.getTestDir(this.getClass).getCanonicalPath}/MiniCluster"
    val confPath = s"$testDataPath/conf"
    (
      testDataPath.toFile.createIfNotExists(asDirectory = true, createParents = false),
      confPath.toFile.createIfNotExists(asDirectory = true, createParents = false)
    )
  }

  def pathAvro = "opendata/test.avro"
  def pathParquet = "opendata/test.parquet"
  def pathCsv = "opendata/test.csv"

  def getSparkSession = sparkSession

  override def beforeAll(): Unit = {

    val conf = new HdfsConfiguration()
    conf.setBoolean("dfs.permissions", true)
    System.clearProperty(MiniDFSCluster.PROP_TEST_BUILD_DATA)

    conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, testDataPath.pathAsString)
    //FileUtil.fullyDelete(testDataPath.toJava)

    conf.set(s"hadoop.proxyuser.${System.getProperties.get("user.name")}.groups", "*")
    conf.set(s"hadoop.proxyuser.${System.getProperties.get("user.name")}.hosts", "*")

    val builder = new MiniDFSCluster.Builder(conf)
    miniCluster = Try(builder.build())
    fileSystem = miniCluster.map(_.getFileSystem)
    fileSystem.foreach(fs => {
      val confFile: File = confPath / "hdfs-site.xml"
      for { os <- confFile.newOutputStream.autoClosed } fs.getConf.writeXml(os)
    })

    writeDf()
  }

  override def afterAll(): Unit = {
    miniCluster.foreach(_.shutdown(true))
    val _ = testDataPath.parent.parent.delete(true)
    sparkSession.stop()
  }

  
  private def writeDf(): Unit = {
    import sparkSession.implicits._

    alogger.info(s"TestDataPath ${testDataPath.toJava.getAbsolutePath}")
    alogger.info(s"ConfPath ${confPath.toJava.getAbsolutePath}")
    val persons = (1 to 10).map(i => Person(s"Andy$i", Random.nextInt(85), Address("Via Ciccio Cappuccio")))
    val caseClassDS = persons.toDS()
    caseClassDS.write.format("parquet").mode(SaveMode.Overwrite).save(pathParquet)
    caseClassDS.write.format("com.databricks.spark.avro").mode(SaveMode.Overwrite).save(pathAvro)
    //writing directly the Person dataframe generates an exception
    caseClassDS.toDF.select("name", "age").write.format("csv").mode(SaveMode.Overwrite).option("header", "true").save(pathCsv)
  }
}

object DataFrameClasses {

  final case class Address(street: String)

  final case class Person(name: String, age: Int, address: Address)
} 
Example 51
Source File: KuduMiniCluster.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package it.teamdigitale.miniclusters

import org.apache.kudu.client.{KuduClient, MiniKuduCluster}
import org.apache.kudu.client.KuduClient.KuduClientBuilder
import org.apache.kudu.client.MiniKuduCluster.MiniKuduClusterBuilder
import org.apache.kudu.spark.kudu.KuduContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
import org.apache.logging.log4j.LogManager

class KuduMiniCluster extends AutoCloseable {
  val alogger = LogManager.getLogger(this.getClass)

  var kuduMiniCluster: MiniKuduCluster = _
  var kuduClient: KuduClient = _
  var kuduContext: KuduContext = _
  var sparkSession = SparkSession.builder().appName(s"test-${System.currentTimeMillis()}").master("local[*]").getOrCreate()

  def start() {
    alogger.info("Starting KUDU mini cluster")

    System.setProperty(
      "binDir",
      s"${System.getProperty("user.dir")}/src/test/kudu_executables/${sun.awt.OSInfo.getOSType().toString.toLowerCase}"
    )


    kuduMiniCluster = new MiniKuduClusterBuilder()
      .numMasters(1)
      .numTservers(3)
      .build()

    val envMap = Map[String, String](("Xmx", "512m"))

    kuduClient = new KuduClientBuilder(kuduMiniCluster.getMasterAddresses).build()
    assert(kuduMiniCluster.waitForTabletServers(1))

    kuduContext = new KuduContext(kuduMiniCluster.getMasterAddresses, sparkSession.sparkContext)

  }

  override def close() {
    alogger.info("Ending KUDU mini cluster")
    kuduClient.shutdown()
    kuduMiniCluster.shutdown()
    sparkSession.close()
  }
}

object KuduMiniCluster {

  def main(args: Array[String]): Unit = {
    try {
      val kudu = new KuduMiniCluster()
      kudu.start()

      println(s"MASTER KUDU ${kudu.kuduMiniCluster.getMasterAddresses}")
      // keep the mini cluster alive until the process is killed
      while (true) {}
    }
  }
}
Example 52
Source File: SparkSqlUtils.scala    From HadoopLearning   with MIT License 5 votes vote down vote up
package com.c503.utils

import java.io.{BufferedInputStream, BufferedReader, FileInputStream, InputStreamReader}
import java.nio.file.Path

import com.google.common.io.Resources
import org.apache.log4j.{Level, Logger}
import org.apache.mesos.Protos.Resource
import org.apache.spark.sql.SparkSession

import scala.io.Source


  def readSqlByPath(sqlPath: String) = {
    val buf = new StringBuilder
    val path = this.getPathByName(sqlPath)
    val file = Source.fromFile(path)
    for (line <- file.getLines) {
      buf ++= line + "\n"
    }
    file.close
    buf.toString()
  }


} 
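
This listing is truncated: Example 53 below calls SparkSqlUtils.offLogger, SparkSqlUtils.newSparkSession and SparkSqlUtils.getPathByName, none of which appear above. The following is a purely illustrative sketch of what such helpers could look like, based only on the imports shown here; it is not the project's actual implementation.

import com.google.common.io.Resources
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession

// Hypothetical helpers, sketched only from the imports of the truncated object above.
object SparkSqlUtilsSketch {

  // Quiet Spark's console logging for local experiments.
  def offLogger(): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
    Logger.getLogger("org.apache.hadoop").setLevel(Level.OFF)
  }

  // Build a local SparkSession with the given application name.
  def newSparkSession(appName: String): SparkSession =
    SparkSession.builder().appName(appName).master("local[*]").getOrCreate()

  // Resolve a classpath resource (e.g. a file under src/main/resources) to a path string.
  def getPathByName(name: String): String =
    Resources.getResource(name).getPath
}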
Example 53
Source File: Spark_SQL_1.scala    From HadoopLearning   with MIT License 5 votes vote down vote up
package com.c503.sparksql

import com.c503.utils.SparkSqlUtils
import org.apache.spark.sql.SparkSession


object Spark_SQL_1 {

  def main(args: Array[String]): Unit = {

    SparkSqlUtils.offLogger()

    val spark = SparkSqlUtils.newSparkSession("spark_sql_1")

    val df = spark.read.json(SparkSqlUtils.getPathByName("person.json"))

    df.show()

    // aggregation: sort, sum, and average
    df.createOrReplaceTempView("person")
    val result_normal = spark.sql("" +
      "select distinct(a.sex), a.avg_age, a.min_age, a.max_age, a.max_subtract_min_age, a.count_name, b.count_name_than_30 from " +
      "(select " +
      "avg(age) as avg_age, " +
      "min(age) as min_age, " +
      "max(age) as max_age, " +
      "max(age) - min(age) as max_subtract_min_age, " +
      "count(name) as count_name, " +
      "sex " +
      "from person " +
      "group by sex) as a, " +
      "(select " +
      "count(name) as count_name_than_30 " +
      "from person " +
      "where age >= 30 " +
      "group by sex) b "
    )
    result_normal.show()

    val result_than_30 = spark.sql("" +
      "select " +
      "count(name) as count_name_than_30 " +
      "from person " +
      "where age >= 30 " +
      "group by sex"
    )
    result_than_30.show()


  }


  def baseHandler(): Unit = {
    val spark = SparkSqlUtils.newSparkSession("spark_sql_1")

    val df = spark.read.json(SparkSqlUtils.getPathByName("person.json"))

    // show the data
    df.show()

    // inspect the schema
    df.printSchema()

    // select multiple columns
    df.select(df("name"), df("age") + 1).show()

    // filter by condition
    df.filter(df("age") > 30).show()

    // group-by aggregation
    df.groupBy(df("age")).count().show()

    // sort
    df.sort(df("name").desc).show()

    // sort by multiple columns
    df.sort(df("name").desc, df("age").asc).show()

    // rename a column
    df.select(df("name").as("username"), df("age")).show()
  }

} 
Example 54
Source File: DolphinToSpark.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.csv

import java.util

import com.webank.wedatasphere.linkis.engine.configuration.SparkConfiguration
import com.webank.wedatasphere.linkis.storage.{domain => wds}
import com.webank.wedatasphere.linkis.storage.resultset.table.{TableMetaData, TableRecord}
import com.webank.wedatasphere.linkis.storage.resultset.ResultSetReader
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types._



object DolphinToSpark {

  val bigDecimalPrecision = 20
  val bigDecimalScale = 10

  def createTempView(spark: SparkSession, tableName: String, res: String): Unit = {
    createTempView(spark, tableName, res, false)
  }

  def createTempView(spark: SparkSession, tableName: String, res: String, forceReplace: Boolean): Unit = {
    if (forceReplace || spark.sessionState.catalog.getTempView(tableName).isEmpty) {
      val reader = ResultSetReader.getTableResultReader(res)
      val metadata = reader.getMetaData.asInstanceOf[TableMetaData]
      val rowList = new util.ArrayList[Row]()
      var len = SparkConfiguration.DOLPHIN_LIMIT_LEN.getValue
      while (reader.hasNext && len > 0){
        rowList.add(Row.fromSeq(reader.getRecord.asInstanceOf[TableRecord].row))
        len = len -1
      }
      val df: DataFrame = spark.createDataFrame(rowList,metadataToSchema(metadata))
      df.createOrReplaceTempView(tableName)
    }
  }

  def metadataToSchema(metaData: TableMetaData):StructType = {
    new StructType(metaData.columns.map(field => StructField(field.columnName,toSparkType(field.dataType))))
  }

  def toSparkType(dataType:wds.DataType):DataType = dataType match {
    case wds.NullType => NullType
    //case wds.StringType | wds.CharType | wds.VarcharType | wds.StructType | wds.ListType | wds.ArrayType | wds.MapType => StringType
    case wds.BooleanType =>  BooleanType
    case wds.ShortIntType => ShortType
    case wds.IntType => IntegerType
    case wds.LongType => LongType
    case wds.FloatType => FloatType
    case wds.DoubleType  => DoubleType
    case wds.DecimalType => DecimalType(bigDecimalPrecision,bigDecimalScale)
    case wds.DateType => DateType
    //case wds.TimestampType => TimestampType
    case wds.BinaryType => BinaryType
    case _ => StringType
  }

} 
Example 55
Source File: CSTableParser.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.cs

import java.util.regex.Pattern

import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.cs.client.service.CSTableService
import com.webank.wedatasphere.linkis.cs.common.entity.metadata.CSTable
import com.webank.wedatasphere.linkis.cs.common.utils.CSCommonUtils
import com.webank.wedatasphere.linkis.engine.exception.ExecuteError
import com.webank.wedatasphere.linkis.engine.execute.EngineExecutorContext
import org.apache.commons.lang.StringUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.csv.DolphinToSpark

import scala.collection.mutable.ArrayBuffer


  def getCSTable(csTempTable:String,  contextIDValueStr: String, nodeNameStr: String):CSTable = {
    CSTableService.getInstance().getUpstreamSuitableTable(contextIDValueStr, nodeNameStr, csTempTable)
  }

  def registerTempTable(csTable: CSTable):Unit = {
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    info(s"Start to create  tempView to sparkSession viewName(${csTable.getName}) location(${csTable.getLocation})")
    DolphinToSpark.createTempView(spark, csTable.getName, csTable.getLocation, true)
    info(s"Finished to create  tempView to sparkSession viewName(${csTable.getName}) location(${csTable.getLocation})")
  }
} 
Example 56
Source File: ExportData.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.imexport

import java.io.File

import com.webank.wedatasphere.linkis.common.utils.Logging
import org.apache.spark.sql.SparkSession
import org.json4s._
import org.json4s.jackson.JsonMethods._

import scala.io.Source


object ExportData extends Logging {
  implicit val formats = DefaultFormats

  def exportData(spark: SparkSession, dataInfo: String, destination: String): Unit = {
    exportDataFromFile(spark, parse(dataInfo).extract[Map[String, Any]], parse(destination).extract[Map[String, Any]])
  }

  def exportDataByFile(spark: SparkSession, dataInfoPath: String, destination: String): Unit = {
    val fileSource = Source.fromFile(dataInfoPath)
    val dataInfo = fileSource.mkString
    exportDataFromFile(spark, parse(dataInfo).extract[Map[String, Any]], parse(destination).extract[Map[String, Any]])
    fileSource.close()
    val file = new File(dataInfoPath)
    if (file.exists()) {
      file.delete()
    }
  }

  def exportDataFromFile(spark: SparkSession, dataInfo: Map[String, Any], dest: Map[String, Any]): Unit = {

    //Export dataFrame
    val df = spark.sql(getExportSql(dataInfo))
    //dest

    val pathType = LoadData.getMapValue[String](dest, "pathType", "share")
    val path = if ("share".equals(pathType))
      "file://" + LoadData.getMapValue[String](dest, "path")
    else
      "hdfs://" + LoadData.getMapValue[String](dest, "path")

    val hasHeader = LoadData.getMapValue[Boolean](dest, "hasHeader", false)
    val isCsv = LoadData.getMapValue[Boolean](dest, "isCsv", true)
    val isOverwrite = LoadData.getMapValue[Boolean](dest, "isOverwrite", true)
    val sheetName = LoadData.getMapValue[String](dest, "sheetName", "Sheet1")
    val fieldDelimiter = LoadData.getMapValue[String](dest, "fieldDelimiter", ",")

    if (isCsv) {
      CsvRelation.saveDFToCsv(spark, df, path, hasHeader, isOverwrite,option = Map("fieldDelimiter" -> fieldDelimiter))
    } else {
      df.write.format("com.webank.wedatasphere.spark.excel")
        .option("sheetName", sheetName)
        .option("useHeader", hasHeader)
        .mode("overwrite").save(path)
    }
    warn(s"Succeed to export data  to path:$path")
  }

  def getExportSql(dataInfo: Map[String, Any]): String = {
    val sql = new StringBuilder
    //dataInfo
    val database = LoadData.getMapValue[String](dataInfo, "database")
    val tableName = LoadData.getMapValue[String](dataInfo, "tableName")
    val isPartition = LoadData.getMapValue[Boolean](dataInfo, "isPartition", false)
    val partition = LoadData.getMapValue[String](dataInfo, "partition", "ds")
    val partitionValue = LoadData.getMapValue[String](dataInfo, "partitionValue", "1993-01-02")
    val columns = LoadData.getMapValue[String](dataInfo, "columns", "*")
    sql.append("select ").append(columns).append(" from ").append(s"$database.$tableName")
    if (isPartition) sql.append(" where ").append(s"$partition=$partitionValue")
    val sqlString = sql.toString()
    warn(s"export sql:$sqlString")
    sqlString
  }

} 
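
exportData takes its source and destination descriptions as JSON strings whose keys match the ones read above. A hedged example call, assuming a SparkSession named spark is in scope; every value below is a placeholder, only the key names come from the code.

// Every value below is a placeholder; only the key names come from the code above.
val dataInfo =
  """{"database": "test_db", "tableName": "t_person",
    | "isPartition": true, "partition": "ds", "partitionValue": "2020-01-01",
    | "columns": "name, age"}""".stripMargin

val destination =
  """{"pathType": "hdfs", "path": "/tmp/export/person",
    | "hasHeader": true, "isCsv": true, "isOverwrite": true,
    | "fieldDelimiter": ","}""".stripMargin

ExportData.exportData(spark, dataInfo, destination)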
Example 57
Source File: ModelPersistence.scala    From reactive-machine-learning-systems   with MIT License 5 votes vote down vote up
package com.reactivemachinelearning

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{QuantileDiscretizer, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

object ModelPersistence extends App {

  val session = SparkSession.builder.appName("ModelPersistence").getOrCreate()

  val data = Seq(
    (0, 18.0, 0),
    (1, 20.0, 0),
    (2, 8.0, 1),
    (3, 5.0, 1),
    (4, 2.0, 0),
    (5, 21.0, 0),
    (6, 7.0, 1),
    (7, 18.0, 0),
    (8, 3.0, 1),
    (9, 22.0, 0),
    (10, 8.0, 1),
    (11, 2.0, 0),
    (12, 5.0, 1),
    (13, 4.0, 1),
    (14, 1.0, 0),
    (15, 11.0, 0),
    (16, 7.0, 1),
    (17, 15.0, 0),
    (18, 3.0, 1),
    (19, 20.0, 0))

  val instances = session.createDataFrame(data)
    .toDF("id", "seeds", "label")

  val discretizer = new QuantileDiscretizer()
    .setInputCol("seeds")
    .setOutputCol("discretized")
    .setNumBuckets(3)

  val assembler = new VectorAssembler()
    .setInputCols(Array("discretized"))
    .setOutputCol("features")

  val classifier = new LogisticRegression()
    .setMaxIter(5)

  val pipeline = new Pipeline()
    .setStages(Array(discretizer, assembler, classifier))

  val paramMaps = new ParamGridBuilder()
    .addGrid(classifier.regParam, Array(0.0, 0.1))
    .build()

  val evaluator = new BinaryClassificationEvaluator()

  val crossValidator = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setNumFolds(2)
    .setEstimatorParamMaps(paramMaps)

  val model = crossValidator.fit(instances)

  model.write.overwrite().save("my-model")

  val persistedModel = CrossValidatorModel.load("./my-model")
  println(s"UID: ${persistedModel.uid}")

} 
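
Continuing the example, the reloaded CrossValidatorModel can be applied like any other fitted model; a short illustrative follow-up using the probability and prediction columns produced by LogisticRegression.

// Score the training data with the reloaded model and inspect the predictions.
persistedModel.transform(instances)
  .select("id", "seeds", "probability", "prediction")
  .show()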
Example 58
Source File: SparkTest.scala    From Spark-Scala-Maven-Example   with MIT License 5 votes vote down vote up
package net.martinprobson.spark

import java.io.InputStream

import grizzled.slf4j.Logging
import org.apache.spark.sql.SparkSession
import org.scalatest.{Outcome, fixture}

class SparkTest extends fixture.FunSuite with Logging {

  type FixtureParam = SparkSession

  def withFixture(test: OneArgTest): Outcome = {
    val sparkSession = SparkSession.builder
      .appName("Test-Spark-Local")
      .master("local[2]")
      .getOrCreate()
    try {
      withFixture(test.toNoArgTest(sparkSession))
    } finally sparkSession.stop
  }

  test("empsRDD rowcount") { spark =>
    val empsRDD = spark.sparkContext.parallelize(getInputData("/data/employees.json"), 5)
    assert(empsRDD.count === 1000)
  }

  test("titlesRDD rowcount") { spark =>
    val titlesRDD = spark.sparkContext.parallelize(getInputData("/data/titles.json"), 5)
    assert(titlesRDD.count === 1470)
  }

  private def getInputData(name: String): Seq[String] = {
    val is: InputStream = getClass.getResourceAsStream(name)
    scala.io.Source.fromInputStream(is).getLines.toSeq
  }
} 
Example 59
Source File: Main.scala    From spark-ml-serving   with Apache License 2.0 5 votes vote down vote up
import io.hydrosphere.spark_ml_serving.LocalPipelineModel
import io.hydrosphere.spark_ml_serving.common.{LocalData, LocalDataColumn}
import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg._
import org.apache.spark.sql.SparkSession

object Train extends App {

  val conf = new SparkConf()
    .setMaster("local[2]")
    .setAppName("example")
    .set("spark.ui.enabled", "false")

  val session: SparkSession = SparkSession.builder().config(conf).getOrCreate()

  val df = session.createDataFrame(Seq(
            (0, Array("a", "b", "c")),
            (1, Array("a", "b", "b", "c", "a"))
         )).toDF("id", "words")

   val cv = new CountVectorizer()
     .setInputCol("words")
     .setOutputCol("features")
     .setVocabSize(3)
     .setMinDF(2)

   val pipeline = new Pipeline().setStages(Array(cv))

   val model = pipeline.fit(df)
   model.write.overwrite().save("../target/test_models/2.0.2/countVectorizer")
}

object Serve extends App {

  import LocalPipelineModel._

  val model = LocalPipelineModel.load("../target/test_models/2.0.2/countVectorizer")

  val data = LocalData(List(LocalDataColumn("words", List(
    List("a", "b", "d"),
    List("a", "b", "b", "b")

  ))))
  val result = model.transform(data)

  println(result)
} 
Example 60
Source File: LocalLDAModel.scala    From spark-ml-serving   with Apache License 2.0 5 votes vote down vote up
package io.hydrosphere.spark_ml_serving.clustering

import io.hydrosphere.spark_ml_serving.TypedTransformerConverter
import io.hydrosphere.spark_ml_serving.common._
import io.hydrosphere.spark_ml_serving.common.utils.{DataUtils, ParamUtils}
import org.apache.spark.ml.clustering.{LocalLDAModel => SparkLocalLDA}
import org.apache.spark.mllib.clustering.{LocalLDAModel => OldSparkLocalLDA}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.sql.SparkSession
import DataUtils._
import scala.reflect.runtime.universe

class LocalLDAModel(override val sparkTransformer: SparkLocalLDA)
  extends LocalTransformer[SparkLocalLDA] {

  lazy val oldModel: OldSparkLocalLDA = {
    val mirror     = universe.runtimeMirror(sparkTransformer.getClass.getClassLoader)
    val parentTerm = universe.typeOf[SparkLocalLDA].decl(universe.TermName("oldLocalModel")).asTerm
    mirror.reflect(sparkTransformer).reflectField(parentTerm).get.asInstanceOf[OldSparkLocalLDA]
  }

  override def transform(localData: LocalData): LocalData = {
    localData.column(sparkTransformer.getFeaturesCol) match {
      case Some(column) =>
        val newData = column.data.mapToMlLibVectors.map(oldModel.topicDistribution(_).toList)
        localData.withColumn(
          LocalDataColumn(
            sparkTransformer.getTopicDistributionCol,
            newData
          )
        )
      case None => localData
    }
  }
}

object LocalLDAModel
  extends SimpleModelLoader[SparkLocalLDA]
  with TypedTransformerConverter[SparkLocalLDA] {

  override def build(metadata: Metadata, data: LocalData): SparkLocalLDA = {
    val topics = DataUtils.constructMatrix(
      data.column("topicsMatrix").get.data.head.asInstanceOf[Map[String, Any]]
    )
    val gammaShape = data.column("gammaShape").get.data.head.asInstanceOf[java.lang.Double]
    val topicConcentration =
      data.column("topicConcentration").get.data.head.asInstanceOf[java.lang.Double]
    val docConcentration = DataUtils.constructVector(
      data.column("docConcentration").get.data.head.asInstanceOf[Map[String, Any]]
    )
    val vocabSize = data.column("vocabSize").get.data.head.asInstanceOf[java.lang.Integer]

    val oldLdaCtor = classOf[OldSparkLocalLDA].getDeclaredConstructor(
      classOf[Matrix],
      classOf[Vector],
      classOf[Double],
      classOf[Double]
    )
    val oldLDA = oldLdaCtor.newInstance(
      Matrices.fromML(topics),
      Vectors.fromML(docConcentration),
      topicConcentration,
      gammaShape
    )

    val ldaCtor = classOf[SparkLocalLDA].getDeclaredConstructor(
      classOf[String],
      classOf[Int],
      classOf[OldSparkLocalLDA],
      classOf[SparkSession]
    )

    val lda = ldaCtor.newInstance(metadata.uid, vocabSize, oldLDA, null)

    ParamUtils.set(lda, lda.optimizer, metadata)
    ParamUtils.set(lda, lda.keepLastCheckpoint, metadata)
    ParamUtils.set(lda, lda.seed, metadata)
    ParamUtils.set(lda, lda.featuresCol, metadata)
    ParamUtils.set(lda, lda.learningDecay, metadata)
    ParamUtils.set(lda, lda.checkpointInterval, metadata)
    ParamUtils.set(lda, lda.learningOffset, metadata)
    ParamUtils.set(lda, lda.maxIter, metadata)
    ParamUtils.set(lda, lda.k, metadata)
    lda
  }

  override implicit def toLocal(sparkTransformer: SparkLocalLDA): LocalTransformer[SparkLocalLDA] =
    new LocalLDAModel(sparkTransformer)

} 
Example 61
Source File: GenericTestSpec.scala    From spark-ml-serving   with Apache License 2.0 5 votes vote down vote up
package io.hydrosphere.spark_ml_serving

import io.hydrosphere.spark_ml_serving.common.LocalData
import org.apache.spark.SparkConf
import org.apache.spark.ml.linalg.{Matrix, Vector}
import org.apache.spark.mllib.linalg.{Matrix => OldMatrix, Vector => OldVector}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.scalatest.{BeforeAndAfterAll, FunSpec}

trait GenericTestSpec extends FunSpec with BeforeAndAfterAll {
  val conf = new SparkConf()
    .setMaster("local[2]")
    .setAppName("test")
    .set("spark.ui.enabled", "false")

  val session: SparkSession = SparkSession.builder().config(conf).getOrCreate()

  def modelPath(modelName: String): String = s"./target/test_models/${session.version}/$modelName"

  def test(
    name: String,
    data: => DataFrame,
    steps: => Seq[PipelineStage],
    columns: => Seq[String],
    accuracy: Double = 0.01
  ) = {
    val path = modelPath(name.toLowerCase())
    var validation = LocalData.empty
    var localPipelineModel = Option.empty[LocalPipelineModel]

    it("should train") {
      val pipeline = new Pipeline().setStages(steps.toArray)
      val pipelineModel = pipeline.fit(data)
      validation = LocalData.fromDataFrame(pipelineModel.transform(data))
      pipelineModel.write.overwrite().save(path)
    }

    it("should load local version") {
      localPipelineModel = Some(LocalPipelineModel.load(path))
      assert(localPipelineModel.isDefined)
    }

    it("should transform LocalData") {
      val localData = LocalData.fromDataFrame(data)
      val model = localPipelineModel.get
      val result = model.transform(localData)
      columns.foreach { col =>
        val resCol = result
          .column(col)
          .getOrElse(throw new IllegalArgumentException("Result column is absent"))
        val valCol = validation
          .column(col)
          .getOrElse(throw new IllegalArgumentException("Validation column is absent"))
        resCol.data.zip(valCol.data).foreach {
          case (r: Seq[Number @unchecked], v: Seq[Number @unchecked])
              if r.head.isInstanceOf[Number] && v.head.isInstanceOf[Number] =>
            r.zip(v).foreach {
              case (ri, vi) =>
                assert(ri.doubleValue() - vi.doubleValue() <= accuracy, s"$ri - $vi > $accuracy")
            }
          case (r: Number, v: Number) =>
            assert(r.doubleValue() - v.doubleValue() <= accuracy, s"$r - $v > $accuracy")
          case (r, n) =>
            assert(r === n)
        }
        result.column(col).foreach { resData =>
          resData.data.foreach { resRow =>
            if (resRow.isInstanceOf[Seq[_]]) {
              assert(resRow.isInstanceOf[List[_]], resRow)
            } else if (resRow.isInstanceOf[Vector] || resRow.isInstanceOf[OldVector] || resRow
              .isInstanceOf[Matrix] || resRow.isInstanceOf[OldMatrix]) {
              assert(false, s"SparkML type detected. Column: $col, value: $resRow")
            }
          }
        }
      }
    }
  }

  def modelTest(
    data: => DataFrame,
    steps: => Seq[PipelineStage],
    columns: => Seq[String],
    accuracy: Double = 0.01
  ): Unit = {
    lazy val name = steps.map(_.getClass.getSimpleName).foldLeft("") {
      case ("", b) => b
      case (a, b) => a + "-" + b
    }

    describe(name) {
      test(name, data, steps, columns, accuracy)
    }
  }
} 
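The trait above only supplies the harness; as a hedged sketch (not part of the spark-ml-serving sources, with an assumed Tokenizer stage, column names and toy data), a concrete suite could call modelTest like this:

package io.hydrosphere.spark_ml_serving

import org.apache.spark.ml.feature.Tokenizer

class TokenizerPipelineSpec extends GenericTestSpec {

  import session.implicits._

  // Hypothetical data, stage and output column, purely to show how modelTest is wired.
  modelTest(
    data = Seq("hello world", "spark ml serving").toDF("text"),
    steps = Seq(new Tokenizer().setInputCol("text").setOutputCol("words")),
    columns = Seq("words")
  )
}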
Example 62
Source File: SparkSessionConfiguration.scala    From spark-structured-streaming-examples   with Apache License 2.0 5 votes vote down vote up
package com.phylosoft.spark.learning

import java.io.File

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

trait SparkSessionConfiguration {

  val settings: Traversable[(String, String)]

  private val warehouseLocation = "file:///" + new File("spark-warehouse").getAbsolutePath.toString

  private lazy val conf = new SparkConf()
    .set("spark.sql.warehouse.dir", warehouseLocation)
    .set("spark.sql.session.timeZone", "UTC")
    .set("spark.sql.shuffle.partitions", "4") // keep the size of shuffles small
    .set("spark.sql.cbo.enabled", "true")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .set("spark.kryoserializer.buffer", "24")
    .setAll(settings)

  implicit lazy val spark: SparkSession = SparkSession.builder
    .config(conf)
    .enableHiveSupport()
    .getOrCreate()

} 
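A hedged sketch (not from the repository, and assuming spark-hive is on the classpath since the trait enables Hive support) of an object mixing in the trait and supplying its own settings:

package com.phylosoft.spark.learning

object SessionConfigurationDemo extends SparkSessionConfiguration {

  // Extra settings layered on top of the defaults defined by the trait.
  val settings: Traversable[(String, String)] = Map(
    "spark.app.name" -> "session-configuration-demo",
    "spark.master" -> "local[2]"
  )

  def main(args: Array[String]): Unit = {
    println(s"Spark ${spark.version} with spark.sql.shuffle.partitions = " +
      spark.conf.get("spark.sql.shuffle.partitions"))
    spark.stop()
  }
}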
Example 63
Source File: UserActionsRateSource.scala    From spark-structured-streaming-examples   with Apache License 2.0 5 votes vote down vote up
package com.phylosoft.spark.learning.sql.streaming.source.rate

import org.apache.spark.sql.functions.{col, lit, pmod, rand}
import org.apache.spark.sql.{DataFrame, SparkSession}


class UserActionsRateSource(val spark: SparkSession,
                            val rowsPerSecond: String = "5",
                            val numPartitions: String = "1")
  extends RateSource {

  def loadUserActions(): DataFrame = {
    readStream()
      .where((rand() * 100).cast("integer") < 30) // 30 out of every 100 user actions
      .select(pmod(col("value"), lit(9)).as("userId"), col("timestamp").as("actionTime"))
  }

} 
Example 64
Source File: AdRateSources.scala    From spark-structured-streaming-examples   with Apache License 2.0 5 votes vote down vote up
package com.phylosoft.spark.learning.sql.streaming.source.rate

import org.apache.spark.sql.functions.{col, rand}
import org.apache.spark.sql.{DataFrame, SparkSession}

class AdRateSources(val spark: SparkSession,
                    val rowsPerSecond: String = "5",
                    val numPartitions: String = "1")
  extends RateSource {

  def loadImpressions(): DataFrame = {
    readStream()
      .select(
        col("value").as("adId"),
        col("timestamp").as("impressionTime"))
  }

  def loadClicks(): DataFrame = {
    readStream()
      .where((rand() * 100).cast("integer") < 10) // 10 out of every 100 impressions result in a click
      .select((col("value") - 50).as("adId"), col("timestamp").as("clickTime")) // -50 so that a click with same id as impression is generated much later (i.e. delayed data).
      .where("adId > 0")
  }

} 
Example 65
Source File: RateSource.scala    From spark-structured-streaming-examples   with Apache License 2.0 5 votes vote down vote up
package com.phylosoft.spark.learning.sql.streaming.source.rate

import com.phylosoft.spark.learning.sql.streaming.source.StreamingSource
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, SparkSession}

trait RateSource
  extends StreamingSource {

  val spark: SparkSession
  val rowsPerSecond: String
  val numPartitions: String

  override def readStream(): DataFrame = {
    spark.readStream
      .format("rate")
      .option("rowsPerSecond", rowsPerSecond)
      .option("numPartitions", numPartitions)
      .load()
      .select(col("*"))
  }

} 
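A hedged sketch (not part of the project; the console sink, output mode and trigger interval are assumptions) of wiring the user-actions rate source above into a streaming query:

package com.phylosoft.spark.learning.sql.streaming.source.rate

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object RateSourceDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("rate-source-demo")
      .getOrCreate()

    val userActions = new UserActionsRateSource(spark).loadUserActions()

    // Print a few micro-batches to the console, then stop.
    val query = userActions.writeStream
      .format("console")
      .outputMode("append")
      .trigger(Trigger.ProcessingTime("5 seconds"))
      .start()

    query.awaitTermination(30000)
    spark.stop()
  }
}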
Example 66
Source File: TimeSeriesGenerator.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.clock.{ RandomClock, UniformClock }
import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.SparkContext
import org.apache.spark.sql.{ DFConverter, SparkSession }
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

import scala.util
import scala.util.Random


class TimeSeriesGenerator(
  @transient val sc: SparkContext,
  begin: Long,
  end: Long,
  frequency: Long
)(
  uniform: Boolean = true,
  ids: Seq[Int] = Seq(1),
  ratioOfCycleSize: Double = 1.0,
  columns: Seq[(String, (Long, Int, util.Random) => Double)] = Seq.empty,
  numSlices: Int = sc.defaultParallelism,
  seed: Long = System.currentTimeMillis()
) extends Serializable {
  require(ids.nonEmpty, s"ids must be non-empty.")

  private val schema = {
    var _schema = Schema(
      "time" -> LongType,
      "id" -> IntegerType
    )
    columns.foreach {
      case (columnName, _) => _schema = _schema.add(columnName, DoubleType)
    }
    _schema
  }

  def generate(): TimeSeriesRDD = {
    val cycles = if (uniform) {
      new UniformClock(sc, begin = begin, end = end, frequency = frequency, offset = 0L, endInclusive = true)
    } else {
      new RandomClock(sc, begin = begin, end = end, frequency = frequency, offset = 0L,
        seed = seed, endInclusive = true)
    }
    val cycleSize = math.max(math.ceil(ids.size * ratioOfCycleSize), 1).toInt

    val orderedRdd = cycles.asOrderedRDD(numSlices).mapPartitionsWithIndexOrdered {
      case (partIndex, iter) =>
        val rand = new Random(seed + partIndex)
        def getCycle(time: Long): Seq[InternalRow] = {
          val randIds = rand.shuffle(ids).take(cycleSize)
          randIds.map {
            id =>
              val values = columns.map {
                case (_, fn) =>
                  fn(time, id, rand)
              }
              InternalRow.fromSeq(time +: id +: values)
          }
        }
        iter.map(_._2).flatMap{ case t => getCycle(t).map((t, _)) }
    }

    val df = DFConverter.toDataFrame(orderedRdd, schema)
    TimeSeriesRDD.fromDFWithRanges(df, orderedRdd.getPartitionRanges.toArray)
  }
} 
Example 67
Source File: TimeType.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.time.types

import com.twosigma.flint.FlintConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.{ SQLContext, SparkSession, types }

trait TimeType {
  // Conversions between nanoseconds since the epoch and the internal Spark SQL
  // representation of the time column, plus rounding to the representable precision.
  def internalToNanos(value: Long): Long
  def nanosToInternal(nanos: Long): Long
  def roundDownPrecision(nanosSinceEpoch: Long): Long
}

object TimeType {
  case object LongType extends TimeType {
    override def internalToNanos(value: Long): Long = value
    override def nanosToInternal(nanos: Long): Long = nanos
    override def roundDownPrecision(nanos: Long): Long = nanos
  }

  // Spark sql represents timestamp as microseconds internally
  case object TimestampType extends TimeType {
    override def internalToNanos(value: Long): Long = value * 1000
    override def nanosToInternal(nanos: Long): Long = nanos / 1000
    override def roundDownPrecision(nanos: Long): Long = nanos - nanos % 1000
  }

  def apply(timeType: String): TimeType = {
    timeType match {
      case "long" => LongType
      case "timestamp" => TimestampType
      case _ => throw new IllegalAccessException(s"Unsupported time type: ${timeType}. " +
        s"Only `long` and `timestamp` are supported.")
    }
  }

  def apply(sqlType: types.DataType): TimeType = {
    sqlType match {
      case types.LongType => LongType
      case types.TimestampType => TimestampType
      case _ => throw new IllegalArgumentException(s"Unsupported time type: ${sqlType}")
    }
  }

  def get(sparkSession: SparkSession): TimeType = {
    TimeType(sparkSession.conf.get(
      FlintConf.TIME_TYPE_CONF, FlintConf.TIME_TYPE_DEFAULT
    ))
  }
} 
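A small hedged sanity check (not from flint) of the conversions defined above:

package com.twosigma.flint.timeseries.time.types

object TimeTypeDemo {
  def main(args: Array[String]): Unit = {
    val nanos = 1500000000123456789L // nanoseconds since epoch

    // TimestampType stores microseconds internally, so a round trip truncates to
    // microsecond precision, which is exactly what roundDownPrecision returns.
    val ts = TimeType("timestamp")
    assert(ts.internalToNanos(ts.nanosToInternal(nanos)) == ts.roundDownPrecision(nanos))

    // LongType keeps nanosecond precision unchanged.
    assert(TimeType("long").roundDownPrecision(nanos) == nanos)
  }
}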
Example 68
Source File: TimeTypeSuite.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.FlintConf
import org.apache.spark.sql.SparkSession

trait TimeTypeSuite {
  def withTimeType[T](conf: String*)(block: => Unit): Unit = {
    val spark = SparkSession.builder().getOrCreate()

    conf.foreach {
      conf =>
        val savedConf = spark.conf.getOption(FlintConf.TIME_TYPE_CONF)
        spark.conf.set(FlintConf.TIME_TYPE_CONF, conf)
        try {
          block
        } finally {
          savedConf match {
            case None => spark.conf.unset(FlintConf.TIME_TYPE_CONF)
            case Some(oldConf) => spark.conf.set(FlintConf.TIME_TYPE_CONF, oldConf)
          }
        }
    }
  }

  def withAllTimeType(block: => Unit): Unit = {
    withTimeType("long", "timestamp")(block)
  }
} 
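A hedged sketch (not from flint) of a suite using withAllTimeType; the session setup and the assertion are assumptions:

package com.twosigma.flint.timeseries

import com.twosigma.flint.FlintConf
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

class TimeTypeConfSpec extends FunSuite with TimeTypeSuite {

  // A session with a master must already exist, because withTimeType calls
  // SparkSession.builder().getOrCreate() without setting one.
  private val spark =
    SparkSession.builder().master("local[2]").appName("time-type-conf").getOrCreate()

  test("the block sees each configured time type") {
    withAllTimeType {
      val configured = spark.conf.get(FlintConf.TIME_TYPE_CONF)
      assert(configured == "long" || configured == "timestamp")
    }
  }
}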
Example 69
Source File: SKRSpec.scala    From spark-kafka-writer   with Apache License 2.0 5 votes vote down vote up
package com.github.benfradet.spark.kafka.writer

import java.util.concurrent.atomic.AtomicInteger

import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.kafka010.{ConsumerStrategies, KafkaUtils, LocationStrategies}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

case class Foo(a: Int, b: String)

trait SKRSpec
  extends AnyWordSpec
  with Matchers
  with BeforeAndAfterEach
  with BeforeAndAfterAll
  with Eventually {

  val sparkConf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(getClass.getSimpleName)

  var ktu: KafkaTestUtils = _
  override def beforeAll(): Unit = {
    ktu = new KafkaTestUtils
    ktu.setup()
  }
  override def afterAll(): Unit = {
    SKRSpec.callbackTriggerCount.set(0)
    if (ktu != null) {
      ktu.tearDown()
      ktu = null
    }
  }

  var topic: String = _
  var ssc: StreamingContext = _
  var spark: SparkSession = _
  override def afterEach(): Unit = {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (spark != null) {
      spark.stop()
      spark = null
    }
  }
  override def beforeEach(): Unit = {
    ssc = new StreamingContext(sparkConf, Seconds(1))
    spark = SparkSession.builder
      .config(sparkConf)
      .getOrCreate()
    topic = s"topic-${Random.nextInt()}"
    ktu.createTopics(topic)
  }

  def collect(ssc: StreamingContext, topic: String): ArrayBuffer[String] = {
    val kafkaParams = Map(
      "bootstrap.servers" -> ktu.brokerAddress,
      "auto.offset.reset" -> "earliest",
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "test-collect"
    )
    val results = new ArrayBuffer[String]
    KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](Set(topic), kafkaParams)
    ).map(_.value())
      .foreachRDD { rdd =>
        results ++= rdd.collect()
        ()
      }
    results
  }

  val producerConfig = Map(
    "bootstrap.servers" -> "127.0.0.1:9092",
    "key.serializer" -> classOf[StringSerializer].getName,
    "value.serializer" -> classOf[StringSerializer].getName
  )
}

object SKRSpec {
  val callbackTriggerCount = new AtomicInteger()
} 
Example 70
Source File: TokenizerSuite.scala    From spark-nkp   with Apache License 2.0 5 votes vote down vote up
package com.github.uosdmlab.nkp

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, IDF}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfter, FunSuite}


class TokenizerSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {

  private var tokenizer: Tokenizer = _

  private val spark: SparkSession =
    SparkSession.builder()
      .master("local[2]")
      .appName("Tokenizer Suite")
      .getOrCreate

  spark.sparkContext.setLogLevel("WARN")

  import spark.implicits._

  override protected def afterAll(): Unit = {
    try {
      spark.stop
    } finally {
      super.afterAll()
    }
  }

  before {
    tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
  }

  private val df = spark.createDataset(
    Seq(
      "아버지가방에들어가신다.",
      "사랑해요 제플린!",
      "스파크는 재밌어",
      "나는야 데이터과학자",
      "데이터야~ 놀자~"
    )
  ).toDF("text")

  test("Default parameters") {
    assert(tokenizer.getFilter sameElements Array.empty[String])
  }

  test("Basic operation") {
    val words = tokenizer.transform(df)

    assert(df.count == words.count)
    assert(words.schema.fieldNames.contains(tokenizer.getOutputCol))
  }

  test("POS filter") {
    val nvTokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("nvWords")
      .setFilter("N", "V")

    val words = tokenizer.transform(df).join(nvTokenizer.transform(df), "text")

    assert(df.count == words.count)
    assert(words.schema.fieldNames.contains(nvTokenizer.getOutputCol))
    assert(words.where(s"SIZE(${tokenizer.getOutputCol}) < SIZE(${nvTokenizer.getOutputCol})").count == 0)
  }

  test("TF-IDF pipeline") {
    tokenizer.setFilter("N")

    val cntVec = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("tf")

    val idf = new IDF()
      .setInputCol("tf")
      .setOutputCol("tfidf")

    val pipe = new Pipeline()
      .setStages(Array(tokenizer, cntVec, idf))

    val pipeModel = pipe.fit(df)

    val result = pipeModel.transform(df)

    assert(result.count == df.count)

    val fields = result.schema.fieldNames
    assert(fields.contains(tokenizer.getOutputCol))
    assert(fields.contains(cntVec.getOutputCol))
    assert(fields.contains(idf.getOutputCol))

    result.show
  }
} 
Example 71
Source File: FeatureCrossSelectorExample.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl.feature.examples

import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.operator.{VarianceSelector, VectorCartesian}
import org.apache.spark.sql.SparkSession

object FeatureCrossSelectorExample {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf()

    val input = conf.get("spark.input.path", "data/a9a/a9a_123d_train_trans.libsvm")
    val numFeatures = conf.get("spark.num.feature", "123")
    val twoOrderNumFeatures = conf.getInt("spark.two.order.num.feature", 123)
    val threeOrderNumFeatures = conf.getInt("spark.three.order.num.feature", 123)

    val spark = SparkSession.builder().master("local").config(conf).getOrCreate()

    val data = spark.read.format("libsvm")
      .option("numFeatures", numFeatures)
      .load(input)
      .persist()

    val cartesian = new VectorCartesian()
      .setInputCols(Array("features", "features"))
      .setOutputCol("f_f")

    val selector = new VarianceSelector()
      .setFeaturesCol("f_f")
      .setOutputCol("selected_f_f")
      .setNumTopFeatures(twoOrderNumFeatures)

    val cartesian2 = new VectorCartesian()
      .setInputCols(Array("features", "selected_f_f"))
      .setOutputCol("f_f_f")

    val selector2 = new VarianceSelector()
      .setFeaturesCol("f_f_f")
      .setOutputCol("selected_f_f_f")
      .setNumTopFeatures(threeOrderNumFeatures)

    val assembler = new VectorAssembler()
      .setInputCols(Array("features", "selected_f_f", "selected_f_f_f"))
      .setOutputCol("assembled_features")

    val pipeline = new Pipeline()
      .setStages(Array(cartesian, selector, cartesian2, selector2, assembler))

    val crossDF = pipeline.fit(data).transform(data).persist()
    data.unpersist()
    // drop returns a new DataFrame, so keep the pruned result instead of discarding it
    val prunedDF = crossDF.drop("f_f", "f_f_f", "selected_f_f", "selected_f_f_f")
    prunedDF.show(1)

    val splitDF = prunedDF.randomSplit(Array(0.9, 0.1))

    val trainDF = splitDF(0).persist()
    val testDF = splitDF(1).persist()

    val originalLR = new LogisticRegression()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setMaxIter(20)
      .setRegParam(0.01)

    val originalPredictions = originalLR.fit(trainDF).transform(testDF)
    originalPredictions.show(1)
    val originalEvaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")
    val originalAUC = originalEvaluator.evaluate(originalPredictions)
    println(s"original features auc: $originalAUC")

    val crossLR = new LogisticRegression()
      .setFeaturesCol("assembled_features")
      .setLabelCol("label")
      .setMaxIter(20)
      .setRegParam(0.01)

    val crossPredictions = crossLR.fit(trainDF).transform(testDF)
    crossPredictions.show(1)
    val crossEvaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")
    val crossAUC = crossEvaluator.evaluate(crossPredictions)
    println(s"cross features auc: $crossAUC")

    spark.close()
  }
} 
Example 72
Source File: DataLoader.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl.feature

import org.apache.spark.sql.{DataFrame, SparkSession}

abstract class DataLoader(ss: SparkSession) {
  def load(input: String, separator: String): DataFrame

  def load(input: String): DataFrame = load(input, " ")
}

case class LibSVMDataLoader(ss: SparkSession) extends DataLoader(ss) {
  override def load(input: String, separator: String): DataFrame = {
    ss.read.format("libsvm").load(input)
  }
}

case class CSVDataLoader(ss: SparkSession) extends DataLoader(ss) {
  override def load(input: String, separator: String): DataFrame = {
    ss.read.csv(input)
  }
}

case class JSONDataLoader(ss: SparkSession) extends DataLoader(ss) {
  override def load(input: String, separator: String): DataFrame = {
    ss.read.json(input)
  }
}

case class DocumentDataLoader(ss: SparkSession) extends DataLoader(ss) {
  override def load(input: String, separator: String): DataFrame = {
    ss.createDataFrame(
      ss.sparkContext.textFile(input).map(Tuple1.apply)
    ).toDF("sentence")
  }
}

case class LabeledDocumentDataLoader(ss: SparkSession) extends DataLoader(ss) {
  override def load(input: String, separator: String): DataFrame = {
    require(separator.equals(","),
      "the label and sentence should be separated by comma")
    ss.createDataFrame(
      ss.sparkContext.textFile(input)
        .map { line =>
          val splits = line.split(separator)
          (splits(0), splits(1))
        })
      .toDF("label", "sentence")
  }

  override def load(input: String): DataFrame = load(input, ",")
}

case class SimpleDataLoader(ss: SparkSession) extends DataLoader(ss) {
  override def load(input: String, separator: String): DataFrame = {
    ss.createDataFrame(
      ss.sparkContext.textFile(input)
        .map(_.split(separator)).map(Tuple1.apply)
    ).toDF("features")
  }
}

case class LabeledSimpleDataLoader(ss: SparkSession) extends DataLoader(ss) {
  override def load(input: String, separator: String): DataFrame = {
    ss.createDataFrame(
      ss.sparkContext.textFile(input)
        .map { line =>
          val splits = line.split(separator)
          (splits.head, splits.tail)
        }
    ).toDF("label", "features")
  }
}


object DataLoader {

  def load(ss: SparkSession,
           format: String,
           input: String,
           separator: String = " "): DataFrame = {
    format match {
      case "libsvm" => LibSVMDataLoader(ss).load(input)
      case "csv" => CSVDataLoader(ss).load(input)
      case "json" => JSONDataLoader(ss).load(input)
      case "document" => DocumentDataLoader(ss).load(input, separator)
      case "label-document" => LabeledDocumentDataLoader(ss).load(input, separator)
      case "simple" => SimpleDataLoader(ss).load(input, separator)
      case "label-simple" => LabeledSimpleDataLoader(ss).load(input, separator)
      case _ => SimpleDataLoader(ss).load(input, separator)
    }
  }

} 
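A hedged usage sketch for the loader factory above (the object and the input path are placeholders, not part of the project):

package com.tencent.angel.spark.automl.feature

import org.apache.spark.sql.SparkSession

object DataLoaderExample {
  def main(args: Array[String]): Unit = {
    val ss = SparkSession.builder()
      .master("local[2]")
      .appName("data-loader-example")
      .getOrCreate()

    // "document" wraps each raw text line into a single "sentence" column.
    val df = DataLoader.load(ss, format = "document", input = "data/sample_text.txt")
    df.show(5, truncate = false)

    ss.stop()
  }
}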
Example 73
Source File: Sampler.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl.feature.preprocess

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

import scala.util.Random


class Sampler(fraction: Double,
              override val uid: String,
              seed: Int = Random.nextInt)
  extends Transformer {

  def this(fraction: Double) = this(fraction, Identifiable.randomUID("sampler"))

  
  // The inputCol Param and its setter were stripped from this snippet along with the
  // original doc comment; they are reconstructed here (signatures inferred from their
  // use below) so that setInputCol("features") in the main method compiles.
  final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")

  final def setInputCol(value: String): this.type = set(inputCol, value)

  final def getOutputCol: String = $(inputCol)

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.sample(false, fraction, seed).toDF
  }

  override def transformSchema(schema: StructType): StructType = {
    schema
  }

  override def copy(extra: ParamMap): Sampler = defaultCopy(extra)
}

object Sampler {

  def main(args: Array[String]): Unit = {
    val ss = SparkSession
      .builder
      .master("local")
      .appName("preprocess")
      .getOrCreate()

    val training = ss.read.format("libsvm")
      .load("/Users/jiangjiawei/dev-tools/spark-2.2.0/data/mllib/sample_libsvm_data.txt")

    println(training.count)

    val sampler = new Sampler(0.5)
      .setInputCol("features")

    val pipeline = new Pipeline()
      .setStages(Array(sampler))

    val model = pipeline.fit(training)

    val test = ss.read.format("libsvm")
      .load("/Users/jiangjiawei/dev-tools/spark-2.2.0/data/mllib/sample_libsvm_data.txt")

    model.transform(test).select("*")
      .collect()
      .foreach { case Row(label: Double, vector: Vector) =>
        println(s"($label, " +
          s"${vector.toSparse.indices.mkString("[", ",", "]")}, " +
          s"${vector.toSparse.values.mkString("[", ",", "]")}")
      }

    ss.stop()
  }
} 
Example 74
Source File: FPreprocess.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl.feature.preprocess

import com.tencent.angel.spark.automl.AutoConf
import com.tencent.angel.spark.automl.feature.DataLoader
import com.tencent.angel.spark.automl.utils.ArgsUtil
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.SparkSession

import scala.collection.mutable.ArrayBuffer


object FPreprocess {

  def main(args: Array[String]): Unit = {

    val params = ArgsUtil.parse(args)
    val master = params.getOrElse("master", "yarn")
    val deploy = params.getOrElse("deploy-mode", "cluster")
    val input = params.getOrElse("input", "")
    val inputSeparator = params.getOrElse(AutoConf.Preprocess.ML_DATA_SPLITOR,
      AutoConf.Preprocess.DEFAULT_ML_DATA_SPLITOR)
    val inputFormat = params.getOrElse(AutoConf.Preprocess.ML_DATA_INPUT_FORMAT,
      AutoConf.Preprocess.DEFAULT_ML_DATA_INPUT_FORMAT)
    val inputType = params.getOrElse(AutoConf.Preprocess.INPUT_TYPE,
      AutoConf.Preprocess.DEFAULT_INPUT_TYPE)
    val sampleRate = params.getOrElse(AutoConf.Preprocess.SAMPLE_RATE,
      AutoConf.Preprocess.DEFAULT_SAMPLE_RATE).toDouble
    val imbalanceSampleRate = params.getOrElse(AutoConf.Preprocess.IMBALANCE_SAMPLE,
      AutoConf.Preprocess.DEFAULT_IMBALANCE_SAMPLE)
    val hasTokenizer = inputFormat.equals("document")
    val hasStopWordsRemover = inputFormat.equals("document")

    val ss = SparkSession
      .builder
      .master(master + "-" + deploy)
      .appName("preprocess")
      .getOrCreate()

    var training = DataLoader.load(ss, inputFormat, input, inputSeparator)

    var components = new ArrayBuffer[PipelineStage]

    if (sampleRate > 0 && sampleRate < 1.0)
      Components.addSampler(components,
        "features", sampleRate)

    if (hasTokenizer)
      Components.addTokenizer(components,
        "sentence", "words")

    if (hasStopWordsRemover)
      Components.addStopWordsRemover(components,
        "words", "filterWords")

    val pipeline = new Pipeline()
      .setStages(components.toArray)

    val model = pipeline.fit(training)

    ss.stop()
  }

} 
Example 75
Source File: DataUtils.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl.utils

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SparkSession}

object DataUtils {

  def parse(ss: SparkSession,
            schema: StructType,
            X: Array[Vector],
            Y: Array[Double]): DataFrame = {
    require(X.size == Y.size,
      "The size of configurations should be equal to the size of rewards.")
    ss.createDataFrame(
      Y.zip(X)).toDF("label", "features")
  }

  def parse(ss: SparkSession,
            schema: StructType,
            X: Vector): DataFrame = {
    parse(ss, schema, Array(X), Array(0))
  }

  def toBreeze(values: Array[Double]): BDV[Double] = {
    new BDV[Double](values)
  }

  def toBreeze(vector: Vector): BDV[Double] = vector match {
    case sv: SparseVector => new BDV[Double](sv.toDense.values)
    case dv: DenseVector => new BDV[Double](dv.values)
  }

  def toBreeze(X: Array[Vector]): BDM[Double] = {
    val mat = BDM.zeros[Double](X.size, X(0).size)
    for (i <- 0 until X.size) {
      for (j <- 0 until X(0).size) {
        mat(i, j) = X(i)(j)
      }
    }
    mat
  }

} 
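A small hedged check (not part of the project) of the Breeze conversions above:

package com.tencent.angel.spark.automl.utils

import org.apache.spark.ml.linalg.Vectors

object DataUtilsDemo {
  def main(args: Array[String]): Unit = {
    val dense = Vectors.dense(1.0, 2.0, 3.0)
    val sparse = Vectors.sparse(3, Array(0, 2), Array(4.0, 5.0))

    // Both vector flavours become breeze dense vectors of the same length.
    assert(DataUtils.toBreeze(dense).length == 3)
    assert(DataUtils.toBreeze(sparse).length == 3)

    // Stacking vectors yields a (numRows x dimension) breeze matrix.
    val mat = DataUtils.toBreeze(Array(dense, sparse))
    assert(mat.rows == 2 && mat.cols == 3)
  }
}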
Example 76
Source File: MetadataTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.operator.{MetadataTransformUtils, VectorCartesian}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}

class MetadataTest extends FunSuite with BeforeAndAfter {

  var spark: SparkSession = _

  before {
    spark = SparkSession.builder().master("local").getOrCreate()
  }

  after {
    spark.close()
  }

  test("test_vector_cartesian") {
    val data = spark.read.format("libsvm")
      .option("numFeatures", "123")
      .load("data/a9a/a9a_123d_train_trans.libsvm")
      .persist()

    val cartesian = new VectorCartesian()
      .setInputCols(Array("features", "features"))
      .setOutputCol("cartesian_features")

    val assembler = new VectorAssembler()
      .setInputCols(Array("features", "cartesian_features"))
      .setOutputCol("assemble_features")

    val pipeline = new Pipeline()
      .setStages(Array(cartesian, assembler))

    val featureModel = pipeline.fit(data)
    val crossDF = featureModel.transform(data)

    crossDF.schema.fields.foreach { field =>
      println("name: " + field.name)
      println("metadata: " + field.metadata.toString())
    }
  }

  test("test_three_order_cartesian") {
    val data = spark.read.format("libsvm")
      .option("numFeatures", 8)
      .load("data/abalone/abalone_8d_train.libsvm")
      .persist()

    val cartesian = new VectorCartesian()
      .setInputCols(Array("features", "features"))
      .setOutputCol("f_f")

    val cartesian2 = new VectorCartesian()
      .setInputCols(Array("features", "f_f"))
      .setOutputCol("f_f_f")

    val pipeline = new Pipeline()
      .setStages(Array(cartesian, cartesian2))

    val crossDF = pipeline.fit(data).transform(data).persist()

    // first cartesian, the number of dimensions is 64
    println("first cartesian dimension = " + crossDF.select("f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).length)
    println(crossDF.select("f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).mkString(","))

    println()

    // second cartesian, the number of dimensions is 512
    println("second cartesian dimension = " + crossDF.select("f_f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).length)
    println(crossDF.select("f_f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).mkString(","))
  }
} 
Example 77
Source File: PipelineTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import com.tencent.angel.spark.automl.feature.preprocess.{HashingTFWrapper, IDFWrapper, TokenizerWrapper}
import com.tencent.angel.spark.automl.feature.{PipelineBuilder, PipelineWrapper, TransformerWrapper}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}

class PipelineTest extends FunSuite with BeforeAndAfter {

  var spark: SparkSession = _

  before {
    spark = SparkSession.builder().master("local").getOrCreate()
  }

  after {
    spark.close()
  }

  test("test_tfidf") {
    val sentenceData = spark.createDataFrame(Seq(
      (0.0, "Hi I heard about Spark"),
      (0.0, "I wish Java could use case classes"),
      (1.0, "Logistic regression models are neat")
    )).toDF("label", "sentence")

    val pipelineWrapper = new PipelineWrapper()

    val transformers = Array[TransformerWrapper](
      new TokenizerWrapper(),
      new HashingTFWrapper(20),
      new IDFWrapper()
    )

    val stages = PipelineBuilder.build(transformers)

    transformers.foreach { transformer =>
      val inputCols = transformer.getInputCols
      val outputCols = transformer.getOutputCols
      inputCols.foreach(print)
      print("    ")
      outputCols.foreach(print)
      println()
    }

    pipelineWrapper.setStages(stages)

    val model = pipelineWrapper.fit(sentenceData)

    val outputDF = model.transform(sentenceData)
    outputDF.select("outIDF").show()
    outputDF.select("outIDF").foreach { row =>
      println(row.get(0).getClass.getSimpleName)
      val arr = row.get(0)
      println(arr.toString)
    }
    outputDF.rdd.map(row => row.toString()).repartition(1)
      .saveAsTextFile("tmp/output/tfidf")
  }
} 
Example 78
Source File: GoogleAuthentication.scala    From amadou   with Apache License 2.0 5 votes vote down vote up
package com.mediative.amadou.bigquery

import java.io.{File, FileReader}
import scala.collection.JavaConversions._
import com.google.api.client.extensions.java6.auth.oauth2.AuthorizationCodeInstalledApp
import com.google.api.client.extensions.jetty.auth.oauth2.LocalServerReceiver
import com.google.api.client.googleapis.auth.oauth2.{
  GoogleAuthorizationCodeFlow,
  GoogleClientSecrets
}
import com.google.api.client.http.{HttpRequest, HttpRequestInitializer}
import com.google.api.client.http.javanet.NetHttpTransport
import com.google.api.client.json.jackson2.JacksonFactory
import com.google.api.client.util.store.FileDataStoreFactory
import org.apache.spark.sql.SparkSession

sealed abstract class GoogleAuthentication(val scopes: String*)

object GoogleAuthentication {
  lazy val HTTP_TRANSPORT = new NetHttpTransport()
  lazy val JSON_FACTORY   = new JacksonFactory()

  case object Dbm
      extends GoogleAuthentication("https://www.googleapis.com/auth/doubleclickbidmanager")

  def apply(auth: GoogleAuthentication, spark: SparkSession): HttpRequestInitializer = auth match {
    case Dbm =>
      val clientFilePath = spark.conf.get("spark.google.cloud.auth.client.file")
      require(clientFilePath != null, "'google.cloud.auth.client.file' not configured")

      val clientFile = new File(clientFilePath)
      require(clientFile.exists, s"$clientFilePath does not exists")

      val clientSecrets    = GoogleClientSecrets.load(JSON_FACTORY, new FileReader(clientFile))
      val dataStoreFactory = new FileDataStoreFactory(clientFile.getParentFile)

      val flow = new GoogleAuthorizationCodeFlow.Builder(
        HTTP_TRANSPORT,
        JSON_FACTORY,
        clientSecrets,
        auth.scopes)
        .setDataStoreFactory(dataStoreFactory)
        .build()

      val cred = new AuthorizationCodeInstalledApp(flow, new LocalServerReceiver())
        .authorize("user")
      new CustomHttpRequestInitializer(cred)
  }

  class CustomHttpRequestInitializer(wrapped: HttpRequestInitializer)
      extends HttpRequestInitializer {
    override def initialize(httpRequest: HttpRequest) = {
      wrapped.initialize(httpRequest)
      httpRequest.setConnectTimeout(10 * 60000) // 10 minutes connect timeout
      httpRequest.setReadTimeout(10 * 60000)    // 10 minutes read timeout
      ()
    }
  }
} 
Example 79
Source File: package.scala    From amadou   with Apache License 2.0 5 votes vote down vote up
package com.mediative.amadou

import com.google.api.services.bigquery.model._
import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem
import com.google.cloud.hadoop.io.bigquery._
import org.apache.hadoop.fs.{FileSystem, Path}
import net.ceedubs.ficus.readers.ValueReader
import net.ceedubs.ficus.FicusInstances

import org.apache.spark.sql.{Dataset, SparkSession, Encoder}
import java.util.concurrent.ThreadLocalRandom
import scala.collection.JavaConversions._

package object bigquery extends FicusInstances {

  object CreateDisposition extends Enumeration {
    val CREATE_IF_NEEDED, CREATE_NEVER = Value
  }

  object WriteDisposition extends Enumeration {
    val WRITE_TRUNCATE, WRITE_APPEND, WRITE_EMPTY = Value
  }

  val BQ_CSV_DATE_FORMAT = "yyyy-MM-dd HH:mm:ss zzz"

  object TableNotFound {
    import com.google.api.client.googleapis.json.GoogleJsonResponseException
    import com.google.api.client.googleapis.json.GoogleJsonError
    import scala.collection.JavaConverters._

    def unapply(error: Throwable): Option[GoogleJsonError.ErrorInfo] = error match {
      case error: GoogleJsonResponseException =>
        Some(error.getDetails)
          .filter(_.getCode == 404)
          .flatMap(_.getErrors.asScala.find(_.getReason == "notFound"))
      case _ => None
    }
  }

  def tableHasDataForDate(
      spark: SparkSession,
      table: TableReference,
      date: java.sql.Date,
      column: String): Boolean = {
    val bq = BigQueryClient.getInstance(spark.sparkContext.hadoopConfiguration)
    bq.hasDataForDate(table, date, column)
  }

  
    def saveAsBigQueryTable(
        tableRef: TableReference,
        writeDisposition: WriteDisposition.Value,
        createDisposition: CreateDisposition.Value): Unit = {
      val bucket = conf.get(BigQueryConfiguration.GCS_BUCKET_KEY)
      val temp =
        s"spark-bigquery-${System.currentTimeMillis()}=${ThreadLocalRandom.current.nextInt(Int.MaxValue)}"
      val gcsPath = s"gs://$bucket/spark-bigquery-tmp/$temp"
      self.write.json(gcsPath)

      val schemaFields = self.schema.fields.map { field =>
        import org.apache.spark.sql.types._

        val fieldType = field.dataType match {
          case BooleanType    => "BOOLEAN"
          case LongType       => "INTEGER"
          case IntegerType    => "INTEGER"
          case StringType     => "STRING"
          case DoubleType     => "FLOAT"
          case TimestampType  => "TIMESTAMP"
          case _: DecimalType => "INTEGER"
        }
        new TableFieldSchema().setName(field.name).setType(fieldType)
      }.toList

      val tableSchema = new TableSchema().setFields(schemaFields)

      bq.load(gcsPath, tableRef, tableSchema, writeDisposition, createDisposition)
      delete(new Path(gcsPath))
    }

    private def delete(path: Path): Unit = {
      val fs = FileSystem.get(path.toUri, conf)
      fs.delete(path, true)
      ()
    }

  }

  implicit val valueReader: ValueReader[BigQueryTable.PartitionStrategy] =
    ValueReader[String].map {
      _ match {
        case "month" => BigQueryTable.PartitionByMonth
        case "day"   => BigQueryTable.PartitionByDay
        case other   => sys.error(s"Unknown partition strategy: $other")
      }
    }
} 
Example 80
Source File: HdfsUrl.scala    From amadou   with Apache License 2.0 5 votes vote down vote up
package com.mediative.amadou

import org.apache.hadoop.fs.{Path, FSDataOutputStream}
import org.apache.spark.sql.SparkSession


case class HdfsUrl(url: String, dateFormat: Option[String] = None) {
  def path = new Path(url)

  def /(subPath: String): HdfsUrl =
    copy(url = new Path(path, subPath).toString)

  def /(date: DateInterval): HdfsUrl = {
    val datePath = dateFormat.fold(date.toString)(date.format)
    this./(datePath)
  }

  def exists(spark: SparkSession) = fileSystem(spark).exists(path)

  def open[T](spark: SparkSession)(f: FSDataOutputStream => T): T = {
    val stream = fileSystem(spark).create(path)
    try {
      f(stream)
    } finally {
      stream.close()
    }
  }

  def fileSystem(spark: SparkSession) =
    path.getFileSystem(spark.sparkContext.hadoopConfiguration)

  override def toString = path.toString
} 
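A hedged sketch (not from amadou; the paths are placeholders) showing how the helper composes paths and checks for existence:

package com.mediative.amadou

import org.apache.spark.sql.SparkSession

object HdfsUrlDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("hdfs-url-demo")
      .getOrCreate()

    // Compose a base URL with sub-directories; dateFormat keeps its default of None.
    val reports = HdfsUrl("file:///tmp/demo-warehouse") / "reports" / "daily"
    println(s"target: $reports, exists: ${reports.exists(spark)}")

    spark.stop()
  }
}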
Example 81
Source File: BugDemonstrationTest.scala    From spark-tsne   with Apache License 2.0 5 votes vote down vote up
package com.github.saurfang.spark.tsne

import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}


class BugDemonstrationTest extends FunSuite with Matchers with BeforeAndAfterAll {
  private var sparkSession : SparkSession = _
  override def beforeAll(): Unit = {
    super.beforeAll()
    sparkSession = SparkSession.builder().appName("BugTests").master("local[2]").getOrCreate()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    sparkSession.stop()
  }

  test("This demonstrates a bug was fixed in tsne-spark 2.1") {
    val sc = sparkSession.sparkContext

    val observations = sc.parallelize(
      Seq(
        Vectors.dense(1.0, 10.0, 100.0),
        Vectors.dense(2.0, 20.0, 200.0),
        Vectors.dense(3.0, 30.0, 300.0)
      )
    )

    // Compute column summary statistics.
    val summary: MultivariateStatisticalSummary = Statistics.colStats(observations)
    val expectedMean = Vectors.dense(2.0,20.0,200.0)
    val resultMean = summary.mean
    assertEqualEnough(resultMean, expectedMean)
    val expectedVariance = Vectors.dense(1.0,100.0,10000.0)
    assertEqualEnough(summary.variance, expectedVariance)
    val expectedNumNonZeros = Vectors.dense(3.0, 3.0, 3.0)
    assertEqualEnough(summary.numNonzeros, expectedNumNonZeros)
  }

  private def assertEqualEnough(sample: Vector, expected: Vector): Unit = {
    expected.toArray.zipWithIndex.foreach{ case(d: Double, i: Int) =>
      sample(i) should be (d +- 1E-12)
    }
  }
} 
Example 82
Source File: TiHandleRDD.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.tispark

import com.pingcap.tikv.meta.TiDAGRequest
import com.pingcap.tikv.util.RangeSplitter
import com.pingcap.tikv.{TiConfiguration, TiSession}
import com.pingcap.tispark.utils.TiUtil
import com.pingcap.tispark.{TiPartition, TiTableReference}
import gnu.trove.list.array.TLongArrayList
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.{Partition, TaskContext, TaskKilledException}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._


class TiHandleRDD(
    override val dagRequest: TiDAGRequest,
    override val physicalId: Long,
    val output: Seq[Attribute],
    override val tiConf: TiConfiguration,
    override val tableRef: TiTableReference,
    @transient private val session: TiSession,
    @transient private val sparkSession: SparkSession)
    extends TiRDD(dagRequest, physicalId, tiConf, tableRef, session, sparkSession) {

  private val outputTypes = output.map(_.dataType)
  private val converters =
    outputTypes.map(CatalystTypeConverters.createToCatalystConverter)

  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[InternalRow] {
      checkTimezone()

      private val tiPartition = split.asInstanceOf[TiPartition]
      private val session = TiSession.getInstance(tiConf)
      private val snapshot = session.createSnapshot(dagRequest.getStartTs)
      private[this] val tasks = tiPartition.tasks

      private val handleIterator = snapshot.indexHandleRead(dagRequest, tasks)
      private val regionManager = session.getRegionManager
      private lazy val handleList = {
        val lst = new TLongArrayList()
        handleIterator.asScala.foreach { handle =>
          // Kill the task in case it has been marked as killed. This logic is from
          // InterruptedIterator, but we inline it here instead of wrapping the iterator in order
          // to avoid performance overhead. Binding the element explicitly makes the check run
          // on every iteration rather than only once when the closure is constructed.
          if (context.isInterrupted()) {
            throw new TaskKilledException
          }
          lst.add(handle)
        }
        lst
      }
      // Fetch all handles and group by region id
      private val regionHandleMap = RangeSplitter
        .newSplitter(regionManager)
        .groupByAndSortHandlesByRegionId(physicalId, handleList)
        .map(x => (x._1.first.getId, x._2))

      private val iterator = regionHandleMap.iterator

      override def hasNext: Boolean = {
        // Kill the task in case it has been marked as killed.
        if (context.isInterrupted()) {
          throw new TaskKilledException
        }
        iterator.hasNext
      }

      override def next(): InternalRow = {
        val next = iterator.next
        val regionId = next._1
        val handleList = next._2

        // Returns RegionId:[handle1, handle2, handle3...] K-V pair
        val sparkRow = Row.apply(regionId, handleList.toArray())
        TiUtil.rowToInternalRow(sparkRow, outputTypes, converters)
      }
    }
} 
Example 83
Source File: TiRowRDD.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.tispark

import com.pingcap.tikv._
import com.pingcap.tikv.columnar.TiColumnarBatchHelper
import com.pingcap.tikv.meta.TiDAGRequest
import com.pingcap.tispark.listener.CacheInvalidateListener
import com.pingcap.tispark.{TiPartition, TiTableReference}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.{Partition, TaskContext, TaskKilledException}
import org.slf4j.Logger

import scala.collection.JavaConversions._

class TiRowRDD(
    override val dagRequest: TiDAGRequest,
    override val physicalId: Long,
    val chunkBatchSize: Int,
    override val tiConf: TiConfiguration,
    val output: Seq[Attribute],
    override val tableRef: TiTableReference,
    @transient private val session: TiSession,
    @transient private val sparkSession: SparkSession)
    extends TiRDD(dagRequest, physicalId, tiConf, tableRef, session, sparkSession) {

  protected val logger: Logger = log

  // cache invalidation call back function
  // used for driver to update PD cache
  private val callBackFunc = CacheInvalidateListener.getInstance()

  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[ColumnarBatch] {
      checkTimezone()

      private val tiPartition = split.asInstanceOf[TiPartition]
      private val session = TiSession.getInstance(tiConf)
      session.injectCallBackFunc(callBackFunc)
      private val snapshot = session.createSnapshot(dagRequest.getStartTs)
      private[this] val tasks = tiPartition.tasks

      private val iterator =
        snapshot.tableReadChunk(dagRequest, tasks, chunkBatchSize)

      override def hasNext: Boolean = {
        // Kill the task in case it has been marked as killed. This logic is from
        // Interrupted Iterator, but we inline it here instead of wrapping the iterator in order
        // to avoid performance overhead.
        if (context.isInterrupted()) {
          throw new TaskKilledException
        }
        iterator.hasNext
      }

      override def next(): ColumnarBatch = {
        TiColumnarBatchHelper.createColumnarBatch(iterator.next)
      }
    }.asInstanceOf[Iterator[InternalRow]]

} 
Example 84
Source File: TiRDD.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.tispark

import com.pingcap.tikv._
import com.pingcap.tikv.exception.TiInternalException
import com.pingcap.tikv.meta.TiDAGRequest
import com.pingcap.tikv.types.Converter
import com.pingcap.tikv.util.RangeSplitter
import com.pingcap.tikv.util.RangeSplitter.RegionTask
import com.pingcap.tispark.{TiPartition, TiTableReference}
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

abstract class TiRDD(
    val dagRequest: TiDAGRequest,
    val physicalId: Long,
    val tiConf: TiConfiguration,
    val tableRef: TiTableReference,
    @transient private val session: TiSession,
    @transient private val sparkSession: SparkSession)
    extends RDD[InternalRow](sparkSession.sparkContext, Nil) {

  private lazy val partitionPerSplit = tiConf.getPartitionPerSplit

  protected def checkTimezone(): Unit = {
    if (!tiConf.getLocalTimeZone.equals(Converter.getLocalTimezone)) {
      throw new TiInternalException(
        "timezone are different! driver: " + tiConf.getLocalTimeZone + " executor:" + Converter.getLocalTimezone +
          " please set user.timezone in spark.driver.extraJavaOptions and spark.executor.extraJavaOptions")
    }
  }

  override protected def getPartitions: Array[Partition] = {
    val keyWithRegionTasks = RangeSplitter
      .newSplitter(session.getRegionManager)
      .splitRangeByRegion(dagRequest.getRangesByPhysicalId(physicalId), dagRequest.getStoreType)

    val hostTasksMap = new mutable.HashMap[String, mutable.Set[RegionTask]]
      with mutable.MultiMap[String, RegionTask]

    var index = 0
    val result = new ListBuffer[TiPartition]
    for (task <- keyWithRegionTasks) {
      hostTasksMap.addBinding(task.getHost, task)
      val tasks = hostTasksMap(task.getHost)
      if (tasks.size >= partitionPerSplit) {
        result.append(new TiPartition(index, tasks.toSeq, sparkContext.applicationId))
        index += 1
        hostTasksMap.remove(task.getHost)
      }

    }
    // add rest
    for (tasks <- hostTasksMap.values) {
      result.append(new TiPartition(index, tasks.toSeq, sparkContext.applicationId))
      index += 1
    }
    result.toArray
  }

  override protected def getPreferredLocations(split: Partition): Seq[String] =
    split.asInstanceOf[TiPartition].tasks.head.getHost :: Nil
} 
Example 85
Source File: databases.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Row, SparkSession, TiContext}


case class TiShowDatabasesCommand(tiContext: TiContext, delegate: ShowDatabasesCommand)
    extends TiCommand(delegate) {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val databases =
      // Not leveraging catalog-specific db pattern, at least Hive and Spark behave different than each other.
      delegate.databasePattern.fold(tiCatalog.listDatabases())(tiCatalog.listDatabases)
    databases.map { d =>
      Row(d)
    }
  }
} 
Example 86
Source File: parser.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.extensions

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.execution.command.{
  CacheTableCommand,
  CreateViewCommand,
  ExplainCommand,
  UncacheTableCommand
}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{SparkSession, TiContext}

case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(
    sparkSession: SparkSession,
    delegate: ParserInterface)
    extends ParserInterface {
  private lazy val tiContext = getOrCreateTiContext(sparkSession)
  private lazy val internal = new SparkSqlParser(sparkSession.sqlContext.conf)

  
  private def needQualify(tableIdentifier: TableIdentifier) =
    tableIdentifier.database.isEmpty && tiContext.sessionCatalog
      .getTempView(tableIdentifier.table)
      .isEmpty
} 
Example 87
Source File: TestSparkSession.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession


private[spark] class TestSparkSession(sparkContext: SparkContext) {
  self =>
  private val spark = SparkSession
    .builder()
    .sparkContext(sparkContext)
    .getOrCreate()
  SparkSession.setDefaultSession(spark)
  SparkSession.setActiveSession(spark)

  def session: SparkSession = spark
} 
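A hedged usage sketch (not from tispark) kept in the same package because the wrapper is private[spark]:

package org.apache.spark.sql.test

import org.apache.spark.{SparkConf, SparkContext}

object TestSparkSessionDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("test-spark-session-demo"))

    // The wrapper registers the session as both the default and the active session.
    val spark = new TestSparkSession(sc).session
    println(spark.range(5).count()) // prints 5

    spark.stop()
  }
}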
Example 88
Source File: TestParquet.scala    From spark-dev   with GNU General Public License v3.0 5 votes vote down vote up
package examples.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.desc


case class Purchase(date: String, time: String,
	city: String, category: String, amount: Double, method: String)

object TestParquet {
	def main(args: Array[String]): Unit = {

		if (args.length == 0) {
			println("Error: Please mention the parquet file...")
			sys.exit(-1)
		}

		val spark = SparkSession
			.builder()
			.master("local[4]")
			.appName("TestParquet")
			.config(conf = new SparkConf())
			.getOrCreate()

		import spark.implicits._

		val ds = spark.read.parquet(args(0)).as[Purchase]

		ds.printSchema()

		println("Number of Method=\"Discover\": "
			+ ds.filter(data => data.method == "Discover").count())

		// Top 3 methods...
		println(">>> Top 3 Methods using Dataframe API >>>")

		ds.groupBy("Method")
			.sum("Amount")
			.withColumnRenamed("sum(Amount)", "Total")
			.orderBy(desc("Total"))
			.take(3)
			.foreach(println)

		println(">>> Top 3 Methods using SQL >>>")

		ds.createOrReplaceTempView("temp")
		val sqlStr = "select Method, sum(Amount) as Total from temp group by Method order by Total desc"

		spark
			.sql(sqlStr)
			.take(3)
			.foreach(println)
	}
} 
Example 89
Source File: TestDataset.scala    From spark-dev   with GNU General Public License v3.0 5 votes vote down vote up
package examples.sql

import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkConf

// Model for people.json
case class People(name: String, age: Option[Long])
// Model for users.parquet
case class Users(name: String, favorite_color: Option[String], favorite_numbers: Option[Array[Int]])


object TestDataset {
	def main(args: Array[String]): Unit = {

		val spark = SparkSession.builder()
			.master("local[4]")
			.appName("TestDataset")
			.config(conf = new SparkConf())
			.getOrCreate()

		readJson(spark)

		readParquet(spark)

	}

	def readParquet(spark: SparkSession): Unit = {
		import spark.implicits._
		val parquetFile = "/media/linux-1/spark-2.0.0-bin-hadoop2.7/examples/src/main/resources/users.parquet"
		val usersDS = spark.read.parquet(parquetFile).as[Users]

		usersDS.printSchema()

		usersDS.show()

	}

	def readJson(spark: SparkSession): Unit = {
		import spark.implicits._
		val jsonFile = "/media/linux-1/spark-2.0.0-bin-hadoop2.7/examples/src/main/resources/people.json"
		val peopleDS = spark.read.json(jsonFile).as[People]

		val partialFilterAge = filterAge(20, _: People)
		peopleDS.filter(partialFilterAge).show()
	}

	def filterAge(condition: Long, data: People): Boolean = {
		val ret: Boolean = data.age match {
			case Some(value) =>
				if (value > condition) true
				else false
			case None => false
		}
		ret
	}
} 
Example 90
Source File: SparkStreamingRedisSuite.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.redislabs.provider.redis

import com.redislabs.provider.redis.env.Env
import com.redislabs.provider.redis.util.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.scalatest.{BeforeAndAfterEach, FunSuite}


trait SparkStreamingRedisSuite extends FunSuite with Env with BeforeAndAfterEach with Logging {

  override protected def beforeEach(): Unit = {
    super.beforeEach()
    spark = SparkSession.builder().config(conf).getOrCreate()
    sc = spark.sparkContext
    ssc = new StreamingContext(sc, Seconds(1))
  }

  override protected def afterEach(): Unit = {
    ssc.stop()
    spark.stop
    System.clearProperty("spark.driver.port")
    super.afterEach()
  }

} 
Example 91
Source File: Env.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.redislabs.provider.redis.env

import com.redislabs.provider.redis.RedisConfig
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.{SparkConf, SparkContext}

trait Env {

  val conf: SparkConf
  var spark: SparkSession = _
  var sc: SparkContext = _
  var ssc: StreamingContext = _

  val redisHost = "127.0.0.1"
  val redisPort = 6379
  val redisAuth = "passwd"
  val redisConfig: RedisConfig
} 
Example 92
Source File: ModelSerialization.scala    From CTRmodel   with Apache License 2.0 5 votes vote down vote up
package com.ggstar.example

import com.ggstar.ctrmodel._
import com.ggstar.features.FeatureEngineering
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}

object ModelSerialization {
  def main(args: Array[String]): Unit = {

    Logger.getLogger("org").setLevel(Level.ERROR)

    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("ctrModel")
      .set("spark.submit.deployMode", "client")

    val spark = SparkSession.builder.config(conf).getOrCreate()

    val resourcesPath = this.getClass.getResource("/samples.snappy.orc")
    val rawSamples = spark.read.format("orc").option("compression", "snappy").load(resourcesPath.getPath)


    //transform array to vector for the following VectorAssembler
    val samples = FeatureEngineering.transferArray2Vector(rawSamples)

    samples.printSchema()
    samples.show(5, false)


    //model training
    println("Neural Network Ctr Prediction Model:")
    val innModel = new InnerProductNNCtrModel()
    innModel.train(samples)
    val transformedData = innModel.transform(samples)

    transformedData.show(1,false)

    //model serialization by mleap
    val mleapModelSerializer = new com.ggstar.serving.mleap.serialization.ModelSerializer()
    mleapModelSerializer.serializeModel(innModel._pipelineModel, "jar:file:/Users/zhwang/Workspace/CTRmodel/model/inn.model.mleap.zip", transformedData)

    //model serialization by JPMML
    val jpmmlModelSerializer = new com.ggstar.serving.jpmml.serialization.ModelSerializer()
    jpmmlModelSerializer.serializeModel(innModel._pipelineModel, "model/inn.model.jpmml.xml", transformedData)
  }
} 
Example 93
Source File: ModelSelection.scala    From CTRmodel   with Apache License 2.0 5 votes vote down vote up
package com.ggstar.example

import com.ggstar.ctrmodel._
import com.ggstar.evaluation.Evaluator
import com.ggstar.features.FeatureEngineering
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.log4j.{Level, Logger}

object ModelSelection {
  def main(args: Array[String]): Unit = {

    Logger.getLogger("org").setLevel(Level.ERROR)

    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("ctrModel")
      .set("spark.submit.deployMode", "client")

    val spark = SparkSession.builder.config(conf).getOrCreate()

    val resourcesPath = this.getClass.getResource("/samples.snappy.orc")
    val rawSamples = spark.read.format("orc").option("compression", "snappy").load(resourcesPath.getPath)
    rawSamples.printSchema()
    rawSamples.show(10)

    //transform array to vector for the following VectorAssembler
    val samples = FeatureEngineering.transferArray2Vector(rawSamples)

    //split samples into training samples and validation samples
    val Array(trainingSamples, validationSamples) = samples.randomSplit(Array(0.7, 0.3))
    val evaluator = new Evaluator

    
  }
} 
Example 94
Source File: RecommendationModelReuse.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.MovieRecommendation

import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.recommendation.Rating
import scala.Tuple2
import org.apache.spark.rdd.RDD

object RecommendationModelReuse {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("JavaLDAExample")
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/").
      getOrCreate()

    val ratigsFile = "data/ratings.csv"
    val ratingDF = spark.read.format("com.databricks.spark.csv").option("header", true).load(ratigsFile)
    val selectedRatingsDF = ratingDF.select(ratingDF.col("userId"), ratingDF.col("movieId"), ratingDF.col("rating"), ratingDF.col("timestamp"))

    // Randomly split ratings RDD into training data RDD (75%) and test data RDD (25%)
    val splits = selectedRatingsDF.randomSplit(Array(0.75, 0.25), seed = 12345L)
    val testData = splits(1)

    val testRDD = testData.rdd.map(row => {
      val userId = row.getString(0)
      val movieId = row.getString(1)
      val ratings = row.getString(2)
      Rating(userId.toInt, movieId.toInt, ratings.toDouble)
    })

    //Load the saved model back
    val same_model = MatrixFactorizationModel.load(spark.sparkContext, "model/MovieRecomModel/")

    // Making predictions. Get the top 10 movie predictions for user 458
    println("Rating:(UserID, MovieID, Rating)")
    println("----------------------------------")
    val topRecsForUser = same_model.recommendProducts(458, 10)
    for (rating <- topRecsForUser) {
      println(rating.toString())
    }
    println("----------------------------------")

    val rmseTest = MovieRecommendation.computeRmse(same_model, testRDD, true)
    println("Test RMSE: = " + rmseTest) //Less is better

    //Movie recommendation for a specific user. Get the top 10 movie predictions for user 458
    println("Recommendations: (MovieId => Rating)")
    println("----------------------------------")
    val recommendationsUser = same_model.recommendProducts(458, 10)
    recommendationsUser.map(rating => (rating.product, rating.rating)).foreach(println)
    println("----------------------------------")

    spark.stop()
  }
} 
Example 95
Source File: MovieRecommendation.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.MovieRecommendation

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SQLImplicits
import org.apache.spark.sql._
import org.apache.spark.sql.Dataset
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.recommendation.Rating
import scala.Tuple2
import org.apache.spark.rdd.RDD

object MovieRecommendation {  
  //Compute the RMSE to evaluate the model. The lower the RMSE, the better the model and its prediction capability.
  def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean): Double = {
    val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product)))
    val predictionsAndRatings = predictions.map { x => ((x.user, x.product), x.rating)
    }.join(data.map(x => ((x.user, x.product), x.rating))).values
    if (implicitPrefs) {
      println("(Prediction, Rating)")
      println(predictionsAndRatings.take(5).mkString("\n"))
    }
    math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean())
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("JavaLDAExample")
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/").
      getOrCreate()

    val ratigsFile = "data/ratings.csv"
    val df1 = spark.read.format("com.databricks.spark.csv").option("header", true).load(ratigsFile)

    val ratingsDF = df1.select(df1.col("userId"), df1.col("movieId"), df1.col("rating"), df1.col("timestamp"))
    ratingsDF.show(false)

    val moviesFile = "data/movies.csv"
    val df2 = spark.read.format("com.databricks.spark.csv").option("header", "true").load(moviesFile)

    val moviesDF = df2.select(df2.col("movieId"), df2.col("title"), df2.col("genres"))
    moviesDF.show(false)

    ratingsDF.createOrReplaceTempView("ratings")
    moviesDF.createOrReplaceTempView("movies")

    

    var rmseTest = computeRmse(model, testRDD, true)
    println("Test RMSE: = " + rmseTest) //Less is better

    //Movie recommendation for a specific user. Get the top 6 movie predictions for user 668
    println("Recommendations: (MovieId => Rating)")
    println("----------------------------------")
    val recommendationsUser = model.recommendProducts(668, 6)
    recommendationsUser.map(rating => (rating.product, rating.rating)).foreach(println)
    println("----------------------------------")

    spark.stop()
  }
} 
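The listing above omits the split and ALS training step between the temp-view registration and the RMSE computation, so model and testRDD appear undefined. A sketch of what that step typically looks like follows, meant to slot inside main(); the 75/25 split and the rank/iterations/lambda values are illustrative assumptions, not the original settings.

    // Assumed reconstruction of the omitted step inside main():
    val splits = ratingsDF.randomSplit(Array(0.75, 0.25), seed = 12345L)
    val (trainingData, testData) = (splits(0), splits(1))

    def toRatingRDD(df: DataFrame): RDD[Rating] = df.rdd.map { row =>
      Rating(row.getString(0).toInt, row.getString(1).toInt, row.getString(2).toDouble)
    }

    val trainRDD = toRatingRDD(trainingData)
    val testRDD = toRatingRDD(testData)

    // rank = 20, iterations = 10, lambda = 0.01 are illustrative values only
    val model: MatrixFactorizationModel = ALS.train(trainRDD, 20, 10, 0.01)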
Example 96
Source File: LDAModelReuse.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.TopicModelling

import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.clustering.{ DistributedLDAModel, LDA }

object LDAModelReuse {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "data/")
      .appName(s"OneVsRestExample")
      .getOrCreate()

    //Restoring the model for reuse
    val savedLDAModel = DistributedLDAModel.load(spark.sparkContext, "model/LDATrainedModel/")

    val lda = new LDAforTM() // actual computations are done here
    val defaultParams = Params().copy(input = "data/4UK1UkTX.csv", savedLDAModel) // Loading the parameters to train the LDA model
    lda.run(defaultParams, false) // Training the LDA model with the default parameters but don't save the trained model again 
    spark.stop()
  }
} 
Example 97
Source File: Preprocess.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.BitCoin

import java.io.{ BufferedWriter, File, FileWriter }
import org.apache.spark.sql.types.{ DoubleType, IntegerType, StructField, StructType }
import org.apache.spark.sql.{ DataFrame, Row, SparkSession }
import scala.collection.mutable.ListBuffer

object Preprocess {
  //how many of the first rows are omitted
    val dropFirstCount: Int = 612000

    def rollingWindow(data: DataFrame, window: Int, xFilename: String, yFilename: String): Unit = {
      var i = 0
      val xWriter = new BufferedWriter(new FileWriter(new File(xFilename)))
      val yWriter = new BufferedWriter(new FileWriter(new File(yFilename)))

      val zippedData = data.rdd.zipWithIndex().collect()
      System.gc()
      val dataStratified = zippedData.drop(dropFirstCount) //todo slice first 614K
      while (i < (dataStratified.length - window)) {
        val x = dataStratified
          .slice(i, i + window)
          .map(r => r._1.getAs[Double]("Delta")).toList
        val y = dataStratified.apply(i + window)._1.getAs[Integer]("label")
        val stringToWrite = x.mkString(",")
        xWriter.write(stringToWrite + "\n")
        yWriter.write(y + "\n")

        i += 1
        if (i % 10 == 0) {
          xWriter.flush()
          yWriter.flush()
        }
      }

      xWriter.close()
      yWriter.close()
    }
    
  def main(args: Array[String]): Unit = {
    //todo modify these variables to match desirable files
    val priceDataFileName: String = "C:/Users/admin-karim/Desktop/bitstampUSD_1-min_data_2012-01-01_to_2017-10-20.csv/bitstampUSD_1-min_data_2012-01-01_to_2017-10-20.csv"
    val outputDataFilePath: String = "output/scala_test_x.csv"
    val outputLabelFilePath: String = "output/scala_test_y.csv"

    val spark = SparkSession
      .builder()
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName("Bitcoin Preprocessing")
      .getOrCreate()

    val data = spark.read.format("com.databricks.spark.csv").option("header", "true").load(priceDataFileName)
    data.show(10)
    println((data.count(), data.columns.size))

    val dataWithDelta = data.withColumn("Delta", data("Close") - data("Open"))

    import org.apache.spark.sql.functions._
    import spark.sqlContext.implicits._

    val dataWithLabels = dataWithDelta.withColumn("label", when($"Close" - $"Open" > 0, 1).otherwise(0))
    rollingWindow(dataWithLabels, 22, outputDataFilePath, outputLabelFilePath)    
    spark.stop()
  }
} 
Example 98
Source File: SparkSessionCreate.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML

import org.apache.spark.sql.SparkSession

object SparkSessionCreate {
  def createSession(): SparkSession = {
    val spark = SparkSession
      .builder
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName(s"OneVsRestExample")
      .getOrCreate()

    return spark
  }
} 
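A minimal usage sketch for the factory above; the DataFrame contents are made up.

package com.packt.ScalaML

object SparkSessionCreateDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSessionCreate.createSession()
    import spark.implicits._

    // Any job code can now share the same configured session
    Seq(("a", 1), ("b", 2)).toDF("key", "value").show()
    spark.stop()
  }
}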
Example 99
Source File: ChurnPredictionSVM.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.ChrunPrediction

import org.apache.spark._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.classification.{LinearSVC, LinearSVCModel}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

object ChurnPredictionSVM {
  def main(args: Array[String]) {
    val spark: SparkSession = SparkSessionCreate.createSession("ChurnPredictionSVM")
    import spark.implicits._

    val numFolds = 10
    val MaxIter: Seq[Int] = Seq(1000)
    val RegParam: Seq[Double] = Seq(0.10) // L2 regularization param, set 0.10 with L1 regularization
    val Tol: Seq[Double] = Seq(1e-4)
    val ElasticNetParam: Seq[Double] = Seq(0.00001) // Combination of L1 and L2

    val svm = new LinearSVC()

    // Chain indexers and tree in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(PipelineConstruction.ipindexer,
        PipelineConstruction.labelindexer,
        PipelineConstruction.assembler,
        svm))

    // Search through the SVM's parameter grid for the best model
    val paramGrid = new ParamGridBuilder()
      .addGrid(svm.maxIter, MaxIter)
      .addGrid(svm.regParam, RegParam)
      .addGrid(svm.tol, Tol)
      .build()

    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("prediction")

    // Set up 10-fold cross validation
    val crossval = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(numFolds)

    val cvModel = crossval.fit(Preprocessing.trainDF)

    val predictions = cvModel.transform(Preprocessing.testSet) 
    val selectPrediction = predictions.select("label", "features", "rawPrediction","prediction")
    selectPrediction.show(10)
    
    val accuracy = evaluator.evaluate(predictions)
    println("Classification accuracy: " + accuracy)    

    // Compute other performance metrics
    val predictionAndLabels = predictions
      .select("prediction", "label")
      .rdd.map(x => (x(0).asInstanceOf[Double], x(1)
        .asInstanceOf[Double]))

    val metrics = new BinaryClassificationMetrics(predictionAndLabels)
   
    val areaUnderPR = metrics.areaUnderPR
    println("Area under the precision-recall curve: " + areaUnderPR)
    
    val areaUnderROC = metrics.areaUnderROC
    println("Area under the receiver operating characteristic (ROC) curve: " + areaUnderROC)

    

    val lp = predictions.select("label", "prediction")
    val counttotal = predictions.count()
    val correct = lp.filter($"label" === $"prediction").count()
    val wrong = lp.filter(not($"label" === $"prediction")).count()
    val ratioWrong = wrong.toDouble / counttotal.toDouble
    val ratioCorrect = correct.toDouble / counttotal.toDouble
    val truep = lp.filter($"prediction" === 0.0).filter($"label" === $"prediction").count() / counttotal.toDouble
    val truen = lp.filter($"prediction" === 1.0).filter($"label" === $"prediction").count() / counttotal.toDouble
    val falsep = lp.filter($"prediction" === 1.0).filter(not($"label" === $"prediction")).count() / counttotal.toDouble
    val falsen = lp.filter($"prediction" === 0.0).filter(not($"label" === $"prediction")).count() / counttotal.toDouble

    println("Total Count: " + counttotal)
    println("Correct: " + correct)
    println("Wrong: " + wrong)
    println("Ratio wrong: " + ratioWrong)
    println("Ratio correct: " + ratioCorrect)
    println("Ratio true positive: " + truep)
    println("Ratio false positive: " + falsep)
    println("Ratio true negative: " + truen)
    println("Ratio false negative: " + falsen)
  }
} 
Example 100
Source File: Describe.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.ChrunPrediction

import org.apache.spark._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.classification.{ BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel }
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

import org.apache.spark._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.Dataset

import org.apache.spark.ml.linalg.{ Matrix, Vectors }
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.sql.Row

object Describe {
  case class CustomerAccount(state_code: String, account_length: Integer, area_code: String,
    international_plan: String, voice_mail_plan: String, num_voice_mail: Double,
    total_day_mins: Double, total_day_calls: Double, total_day_charge: Double,
    total_evening_mins: Double, total_evening_calls: Double, total_evening_charge: Double,
    total_night_mins: Double, total_night_calls: Double, total_night_charge: Double,
    total_international_mins: Double, total_international_calls: Double, total_international_charge: Double,
    total_international_num_calls: Double, churn: String)

  val schema = StructType(Array(
    StructField("state_code", StringType, true),
    StructField("account_length", IntegerType, true),
    StructField("area_code", StringType, true),
    StructField("international_plan", StringType, true),
    StructField("voice_mail_plan", StringType, true),
    StructField("num_voice_mail", DoubleType, true),
    StructField("total_day_mins", DoubleType, true),
    StructField("total_day_calls", DoubleType, true),
    StructField("total_day_charge", DoubleType, true),
    StructField("total_evening_mins", DoubleType, true),
    StructField("total_evening_calls", DoubleType, true),
    StructField("total_evening_charge", DoubleType, true),
    StructField("total_night_mins", DoubleType, true),
    StructField("total_night_calls", DoubleType, true),
    StructField("total_night_charge", DoubleType, true),
    StructField("total_international_mins", DoubleType, true),
    StructField("total_international_calls", DoubleType, true),
    StructField("total_international_charge", DoubleType, true),
    StructField("total_international_num_calls", DoubleType, true),
    StructField("churn", StringType, true)))

  def main(args: Array[String]) {
    val spark = SparkSession
      .builder
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName("Desribe")
      .getOrCreate()

    spark.conf.set("spark.debug.maxToStringFields", 10000)
    val DEFAULT_MAX_TO_STRING_FIELDS = 2500
    if (SparkEnv.get != null) {
      SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
    } else {
      DEFAULT_MAX_TO_STRING_FIELDS
    }
    import spark.implicits._

    val trainSet: Dataset[CustomerAccount] = spark.read.
      option("inferSchema", "false")
      .format("com.databricks.spark.csv")
      .schema(schema)
      .load("data/churn-bigml-80.csv")
      .as[CustomerAccount]

    val statsDF = trainSet.describe()   
    statsDF.show()

    trainSet.createOrReplaceTempView("UserAccount")
    spark.catalog.cacheTable("UserAccount")
    
    spark.sqlContext.sql("SELECT churn, SUM(total_day_mins) + SUM(total_evening_mins) + SUM(total_night_mins) + SUM(total_international_mins) as Total_minutes FROM UserAccount GROUP BY churn").show()
    spark.sqlContext.sql("SELECT churn, SUM(total_day_charge) as TDC, SUM(total_evening_charge) as TEC, SUM(total_night_charge) as TNC, SUM(total_international_charge) as TIC, SUM(total_day_charge) + SUM(total_evening_charge) + SUM(total_night_charge) + SUM(total_international_charge) as Total_charge FROM UserAccount GROUP BY churn ORDER BY Total_charge DESC").show()
    trainSet.groupBy("churn").count.show()
    spark.sqlContext.sql("SELECT churn,SUM(total_international_num_calls) FROM UserAccount GROUP BY churn")
    
  }
} 
Example 101
Source File: SparkSessionCreate.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.ChrunPrediction

import org.apache.spark.sql.SparkSession

object SparkSessionCreate {
  def createSession(appName:String): SparkSession = {
    val spark = SparkSession
      .builder
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName(appName)
      .getOrCreate()

    return spark
  }
} 
Example 102
Source File: ChurnPredictionLR.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.ChrunPrediction

import org.apache.spark._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

object ChurnPredictionLR {
  def main(args: Array[String]) {
    val spark: SparkSession = SparkSessionCreate.createSession("ChurnPredictionLogisticRegression")
    import spark.implicits._

    val numFolds = 10
    val MaxIter: Seq[Int] = Seq(100)
    val RegParam: Seq[Double] = Seq(1.0) // L2 regularization param, set 0.10 with L1 regularization
    val Tol: Seq[Double] = Seq(1e-8)
    val ElasticNetParam: Seq[Double] = Seq(1.0) // Combination of L1 and L2

    val lr = new LogisticRegression()
                    .setLabelCol("label")
                    .setFeaturesCol("features")

    // Chain indexers and tree in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(PipelineConstruction.ipindexer,
        PipelineConstruction.labelindexer,
        PipelineConstruction.assembler,
        lr))

    // Search through the logistic regression's parameter grid for the best model
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.maxIter, MaxIter)
      .addGrid(lr.regParam, RegParam)
      .addGrid(lr.tol, Tol)
      .addGrid(lr.elasticNetParam, ElasticNetParam)
      .build()

    val evaluator = new BinaryClassificationEvaluator()
                  .setLabelCol("label")
                  .setRawPredictionCol("prediction")

    // Set up 10-fold cross validation
    val crossval = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(numFolds)

    val cvModel = crossval.fit(Preprocessing.trainDF)   

    val predictions = cvModel.transform(Preprocessing.testSet)
    val result = predictions.select("label", "prediction", "probability")
    val resutDF = result.withColumnRenamed("prediction", "Predicted_label")
    resutDF.show(10)
    
    val accuracy = evaluator.evaluate(predictions)
    println("Classification accuracy: " + accuracy)    

    // Compute other performance metrics
    val predictionAndLabels = predictions
      .select("prediction", "label")
      .rdd.map(x => (x(0).asInstanceOf[Double], x(1)
        .asInstanceOf[Double]))

    val metrics = new BinaryClassificationMetrics(predictionAndLabels)
    val areaUnderPR = metrics.areaUnderPR
    println("Area under the precision-recall curve: " + areaUnderPR)
    
    val areaUnderROC = metrics.areaUnderROC
    println("Area under the receiver operating characteristic (ROC) curve: " + areaUnderROC)

    

    val lp = predictions.select("label", "prediction")
    val counttotal = predictions.count()
    val correct = lp.filter($"label" === $"prediction").count()
    val wrong = lp.filter(not($"label" === $"prediction")).count()
    val ratioWrong = wrong.toDouble / counttotal.toDouble
    val ratioCorrect = correct.toDouble / counttotal.toDouble
    val truep = lp.filter($"prediction" === 0.0).filter($"label" === $"prediction").count() / counttotal.toDouble
    val truen = lp.filter($"prediction" === 1.0).filter($"label" === $"prediction").count() / counttotal.toDouble
    val falsep = lp.filter($"prediction" === 1.0).filter(not($"label" === $"prediction")).count() / counttotal.toDouble
    val falsen = lp.filter($"prediction" === 0.0).filter(not($"label" === $"prediction")).count() / counttotal.toDouble

    println("Total Count: " + counttotal)
    println("Correct: " + correct)
    println("Wrong: " + wrong)
    println("Ratio wrong: " + ratioWrong)
    println("Ratio correct: " + ratioCorrect)
    println("Ratio true positive: " + truep)
    println("Ratio false positive: " + falsep)
    println("Ratio true negative: " + truen)
    println("Ratio false negative: " + falsen)
  }
} 
Example 103
Source File: XmlReader.scala    From spark-xml   with Apache License 2.0 5 votes vote down vote up
package com.databricks.spark.xml

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, SQLContext, SparkSession}
import org.apache.spark.sql.types.StructType
import com.databricks.spark.xml.util.XmlFile
import com.databricks.spark.xml.util.FailFastMode


  @deprecated("Use xmlFile(SparkSession, ...)", "0.5.0")
  def xmlFile(sqlContext: SQLContext, path: String): DataFrame = {
    // We need the `charset` and `rowTag` before creating the relation.
    val (charset, rowTag) = {
      val options = XmlOptions(parameters.toMap)
      (options.charset, options.rowTag)
    }
    val relation = XmlRelation(
      () => XmlFile.withCharset(sqlContext.sparkContext, path, charset, rowTag),
      Some(path),
      parameters.toMap,
      schema)(sqlContext)
    sqlContext.baseRelationToDataFrame(relation)
  }

  @deprecated("Use xmlRdd(SparkSession, ...)", "0.5.0")
  def xmlRdd(sqlContext: SQLContext, xmlRDD: RDD[String]): DataFrame = {
    val relation = XmlRelation(
      () => xmlRDD,
      None,
      parameters.toMap,
      schema)(sqlContext)
    sqlContext.baseRelationToDataFrame(relation)
  }

} 
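Both methods above are deprecated in favor of the SparkSession-based reader. A sketch of the newer path is below, using the documented rowTag option; the file name and row tag are placeholders.

import org.apache.spark.sql.SparkSession

object XmlReadDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("XmlReadDemo").getOrCreate()

    // Each <book> element becomes one row
    val df = spark.read
      .format("com.databricks.spark.xml")
      .option("rowTag", "book")
      .load("books.xml")

    df.printSchema()
    df.show()
    spark.stop()
  }
}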
Example 104
Source File: XmlPartitioningSuite.scala    From spark-xml   with Apache License 2.0 5 votes vote down vote up
package com.databricks.spark.xml

import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers


final class XmlPartitioningSuite extends AnyFunSuite with Matchers with BeforeAndAfterAll {

  private def doPartitionTest(suffix: String, blockSize: Long, large: Boolean): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("XmlPartitioningSuite")
      .config("spark.ui.enabled", false)
      .config("spark.hadoop.fs.local.block.size", blockSize)
      .getOrCreate()
    try {
      val fileName = s"fias_house${if (large) ".large" else ""}.xml$suffix"
      val xmlFile = getClass.getClassLoader.getResource(fileName).getFile
      val results = spark.read.option("rowTag", "House").option("mode", "FAILFAST").xml(xmlFile)
      // Test file has 37 records; large file is 20x the records
      assert(results.count() === (if (large) 740 else 37))
    } finally {
      spark.stop()
    }
  }

  test("Uncompressed small file with specially chosen block size") {
    doPartitionTest("", 8342, false)
  }

  test("Uncompressed small file with small block size") {
    doPartitionTest("", 500, false)
  }

  test("bzip2 small file with small block size") {
    doPartitionTest(".bz2", 500, false)
  }

  test("bzip2 large file with small block size") {
    // Note, the large bzip2 test file was compressed such that there are several blocks
    // in the compressed input (e.g. bzip2 -1 on a file with much more than 100k data)
    doPartitionTest(".bz2", 500, true)
  }

  test("gzip small file") {
    // Block size won't matter
    doPartitionTest(".gz", 500, false)
  }

  test("gzip large file") {
    // Block size won't matter
    doPartitionTest(".gz", 500, true)
  }

} 
Example 105
Source File: DataGenerator.scala    From iterative-broadcast-join   with Apache License 2.0 5 votes vote down vote up
package com.godatadriven.generator

import com.godatadriven.common.Config
import com.godatadriven.generator.UniformDataGenerator.KeyLabel
import org.apache.spark.sql.{SaveMode, SparkSession}

import scala.util.Random

trait DataGenerator {


  def numberOfRows(numberOfKeys: Int = Config.numberOfKeys,
                   keysMultiplier: Int = Config.keysMultiplier): Long =
    generateSkewedSequence(numberOfKeys).map(_._2).sum * keysMultiplier.toLong

  
  def generateSkewedSequence(numberOfKeys: Int): List[(Int, Int)] =
    (0 to numberOfKeys).par.map(i =>
      (i, Math.ceil(
        (numberOfKeys.toDouble - i.toDouble) / (i.toDouble + 1.0)
      ).toInt)
    ).toList

  def createMediumTable(spark: SparkSession, tableName: String, numberOfPartitions: Int): Unit = {

    import spark.implicits._

    val df = spark
      .read
      .parquet("table_large.parquet")
      .as[Int]
      .distinct()
      .mapPartitions(rows => {
        val r = new Random()
        rows.map(key =>
          KeyLabel(
            key,
            s"Description for entry $key, that can be anything",
            // Preallocate which broadcast-iteration pass this key will be joined in
            Math.floor(r.nextDouble() * Config.numberOfBroadcastPasses).toInt
          )
        )
      })
      .repartition(numberOfPartitions)

    assert(df.count() == Config.numberOfKeys)

    df
      .write
      .mode(SaveMode.Overwrite)
      .parquet(tableName)
  }

  def buildTestset(spark: SparkSession,
                   numberOfKeys: Int = Config.numberOfKeys,
                   keysMultiplier: Int = Config.keysMultiplier,
                   numberOfPartitions: Int = Config.numberOfPartitions): Unit

  def getName: String

  def getMediumTableName: String

  def getLargeTableName: String

} 
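generateSkewedSequence pairs key i with ceil((numberOfKeys - i) / (i + 1)) occurrences, so the lowest keys dominate the generated table. A small plain-Scala sketch makes the skew visible for ten keys.

object SkewPreview {
  def main(args: Array[String]): Unit = {
    val numberOfKeys = 10
    val skew = (0 to numberOfKeys).map { i =>
      (i, math.ceil((numberOfKeys.toDouble - i.toDouble) / (i.toDouble + 1.0)).toInt)
    }
    // key=0 appears 10 times, key=1 five times, while the tail keys appear once (or not at all)
    skew.foreach { case (key, count) => println(s"key=$key count=$count") }
    println(s"rows per multiplier unit: ${skew.map(_._2).sum}")
  }
}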
Example 106
Source File: NormalJoin.scala    From iterative-broadcast-join   with Apache License 2.0 5 votes vote down vote up
package com.godatadriven.join

import org.apache.spark.sql.{DataFrame, SparkSession}

object NormalJoin extends JoinStrategy {

  override def join(spark: SparkSession, dfLarge: DataFrame, dfMedium: DataFrame): DataFrame = {
    // Explicitly disable the broadcast join
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

    dfLarge
      .join(
        dfMedium,
        Seq("key"),
        "left_outer"
      )
      .select(
        dfLarge("key"),
        dfMedium("label")
      )
  }
} 
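A hedged usage sketch for the baseline strategy above; the key and label columns match what the join selects, while the sizes and label values are made up.

import com.godatadriven.join.NormalJoin
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object NormalJoinDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("NormalJoinDemo").getOrCreate()

    // Large side: many rows sharing 100 distinct keys; medium side: one label per key
    val dfLarge = spark.range(0, 1000000).select((col("id") % 100).as("key"))
    val dfMedium = spark.range(0, 100).select(
      col("id").as("key"),
      concat(lit("label_"), col("id").cast("string")).as("label"))

    NormalJoin.join(spark, dfLarge, dfMedium).show(5)
    spark.stop()
  }
}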
Example 107
Source File: IterativeBroadcastJoin.scala    From iterative-broadcast-join   with Apache License 2.0 5 votes vote down vote up
package com.godatadriven.join

import com.godatadriven.SparkUtil
import com.godatadriven.common.Config
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.annotation.tailrec

object IterativeBroadcastJoin extends JoinStrategy {

  @tailrec
  private def iterativeBroadcastJoin(spark: SparkSession,
                                     result: DataFrame,
                                     broadcast: DataFrame,
                                     iteration: Int = 0): DataFrame =
    if (iteration < Config.numberOfBroadcastPasses) {
      val tableName = s"tmp_broadcast_table_itr_$iteration.parquet"

      val out = result.join(
        broadcast.filter(col("pass") === lit(iteration)),
        Seq("key"),
        "left_outer"
      ).select(
        result("key"),

        // Join in the label
        coalesce(
          result("label"),
          broadcast("label")
        ).as("label")
      )

      SparkUtil.dfWrite(out, tableName)

      iterativeBroadcastJoin(
        spark,
        SparkUtil.dfRead(spark, tableName),
        broadcast,
        iteration + 1
      )
    } else result

  override def join(spark: SparkSession,
                    dfLarge: DataFrame,
                    dfMedium: DataFrame): DataFrame = {
    broadcast(dfMedium)
    iterativeBroadcastJoin(
      spark,
      dfLarge
        .select("key")
        .withColumn("label", lit(null)),
      dfMedium
    )
  }

} 
Example 108
Source File: SparkUtil.scala    From iterative-broadcast-join   with Apache License 2.0 5 votes vote down vote up
package com.godatadriven

import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

object SparkUtil {

  def dfWrite(df: DataFrame, name: String): Unit =
    df
      .write
      .mode(SaveMode.Overwrite)
      .parquet(name)

  def dfRead(spark: SparkSession, name: String): DataFrame =
    spark
      .read
      .load(name)

} 
Example 109
Source File: SparkSuite.scala    From spark-sorted   with Apache License 2.0 5 votes vote down vote up
package com.tresata.spark.sorted

import org.scalactic.Equality
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.{ Dataset, SparkSession }

object SparkSuite {
  lazy val spark: SparkSession = {
    val session = SparkSession.builder
      .master("local[*]")
      .appName("test")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.ui.enabled", false)
      .config("spark.sql.shuffle.partitions", 4)
      .getOrCreate()
    session
  }
  lazy val sc: SparkContext = spark.sparkContext

  lazy val jsc = new JavaSparkContext(sc)
  def javaSparkContext() = jsc
}

trait SparkSuite {
  implicit lazy val spark: SparkSession = SparkSuite.spark
  implicit lazy val sc: SparkContext = SparkSuite.spark.sparkContext

  implicit def rddEq[X]: Equality[RDD[X]] = new Equality[RDD[X]] {
    private def toCounts[Y](s: Seq[Y]): Map[Y, Int] = s.groupBy(identity).mapValues(_.size)

    def areEqual(a: RDD[X], b: Any): Boolean = b match {
      case s: Seq[_] => toCounts(a.collect) == toCounts(s)
      case rdd: RDD[_] => toCounts(a.collect) == toCounts(rdd.collect)
    }
  }

  implicit def gsEq[K, V](implicit rddEq: Equality[RDD[(K, V)]]): Equality[GroupSorted[K, V]] = new Equality[GroupSorted[K, V]] {
    def areEqual(a: GroupSorted[K, V], b: Any): Boolean = rddEq.areEqual(a, b)
  }
  
  implicit def dsEq[X](implicit rddEq: Equality[RDD[X]]): Equality[Dataset[X]] = new Equality[Dataset[X]] {
    def areEqual(a: Dataset[X], b: Any): Boolean = b match {
      case ds: Dataset[_] => rddEq.areEqual(a.rdd, ds.rdd)
      case x => rddEq.areEqual(a.rdd, x)
    }
  }
} 
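A sketch of a test that mixes in the trait above and uses the provided RDD Equality; the suite name and data are made up.

import com.tresata.spark.sorted.SparkSuite
import org.scalatest.FunSuite

class WordCountSpec extends FunSuite with SparkSuite {
  test("reduceByKey counts words regardless of partition ordering") {
    val counts = sc.parallelize(Seq("a", "b", "a")).map((_, 1)).reduceByKey(_ + _)
    // rddEq compares element multisets, so the physical ordering of results is irrelevant
    assert(rddEq[(String, Int)].areEqual(counts, Seq(("a", 2), ("b", 1))))
  }
}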
Example 110
Source File: TestUtils.scala    From odsc-east-realish-predictions   with Apache License 2.0 5 votes vote down vote up
package com.twilio.open.odsc.realish

import com.holdenkarau.spark.testing.{LocalSparkContext, SparkContextProvider}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, Suite}

object TestUtils {

}

@SerialVersionUID(1L)
case class UserPersonality(uuid: String, name: String, tags: Array[String])
  extends Serializable

@SerialVersionUID(1L)
case class Author(uuid: String, name: String, age: Int) extends Serializable

@SerialVersionUID(1L)
case class LibraryBook(uuid: String, name: String, author: Author) extends Serializable

case class MockKafkaDataFrame(key: Array[Byte], value: Array[Byte])

trait SharedSparkSql extends BeforeAndAfterAll with SparkContextProvider {
  self: Suite =>

  @transient var _sparkSql: SparkSession = _
  @transient private var _sc: SparkContext = _

  override def sc: SparkContext = _sc

  def conf: SparkConf

  def sparkSql: SparkSession = _sparkSql

  override def beforeAll() {
    _sparkSql = SparkSession.builder().config(conf).getOrCreate()

    _sc = _sparkSql.sparkContext
    setup(_sc)
    super.beforeAll()
  }

  override def afterAll() {
    try {
      _sparkSql.close()
      _sparkSql = null
      LocalSparkContext.stop(_sc)
      _sc = null
    } finally {
      super.afterAll()
    }
  }

} 
Example 111
Source File: StreamingApp.scala    From odsc-east-realish-predictions   with Apache License 2.0 5 votes vote down vote up
package com.twilio.open.odsc.realish.utils

import com.twilio.open.odsc.realish.listeners.InsightsQueryListener
import org.apache.spark.sql.SparkSession
import org.slf4j.Logger

trait StreamingApp {
  val logger: Logger
  def run(): Unit
}

trait Restartable {
  def restart(): Unit
}

trait RestartableStreamingApp extends StreamingApp with Restartable {
  val spark: SparkSession

  val streamingQueryListener: InsightsQueryListener = {
    new InsightsQueryListener(spark, restart)
  }

  def monitoredRun(): Unit = {
    run()
    monitorStreams()
  }

  
  def restart(): Unit = {
    logger.info(s"restarting the application. cleaning up old stream listener and streams")

    val streams = spark.streams
    streams.removeListener(streamingQueryListener)
    streams.active.foreach { stream =>
      logger.info(s"stream_name=${stream.name} state=active status=${stream.status} action=stop_stream")
      stream.stop()
    }
    logger.info(s"attempting to restart the application")
    monitoredRun()
  }
} 
Example 112
Source File: InsightsQueryListener.scala    From odsc-east-realish-predictions   with Apache License 2.0 5 votes vote down vote up
package com.twilio.open.odsc.realish.listeners

import kamon.Kamon
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConverters._

object InsightsQueryListener {
  val log: Logger = LoggerFactory.getLogger(classOf[InsightsQueryListener])

  def apply(spark: SparkSession, restart: () => Unit): InsightsQueryListener = {
    new InsightsQueryListener(spark, restart)
  }

}

class InsightsQueryListener(sparkSession: SparkSession, restart: () => Unit) extends StreamingQueryListener {
  import InsightsQueryListener._
  private val streams = sparkSession.streams
  private val defaultTag = Map("app_name" -> sparkSession.sparkContext.appName)

  def doubleToLong(value: Double): Long = {
    value match {
      case a if a.isInfinite => 0L
      case b if b == Math.floor(b) => b.toLong
      case c => Math.rint(c).toLong
    }
  }

  override def onQueryStarted(event: QueryStartedEvent): Unit = {
    if (log.isDebugEnabled) log.debug(s"onQueryStarted queryName=${event.name} id=${event.id} runId=${event.runId}")
  }

  //https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
  override def onQueryProgress(progressEvent: QueryProgressEvent): Unit = {
    val progress = progressEvent.progress
    val inputRowsPerSecond = progress.inputRowsPerSecond
    val processedRowsPerSecond = progress.processedRowsPerSecond

    // note: leaving this here to remind that we can do fancy things with this for metrics sake
    

    val sources = progress.sources.map { source =>
      val description = source.description
      val startOffset = source.startOffset
      val endOffset = source.endOffset
      val inputRows = source.numInputRows

      s"topic=$description startOffset=$startOffset endOffset=$endOffset numRows=$inputRows"
    }
    val tags = defaultTag + ( "stream_name" -> progress.name )
    Kamon.metrics.histogram("spark.query.progress.processed.rows.rate", tags).record(doubleToLong(processedRowsPerSecond))
    Kamon.metrics.histogram("spark.query.progress.input.rows.rate", tags).record(doubleToLong(inputRowsPerSecond))

    // todo - could take num.rows.total and report the percentage of records that will be watermarked going forward (a simple loss_percentage metric due to the watermark)

    // should give min, avg, max, watermark
    val eventTime = progress.eventTime
    if (eventTime != null) {

      log.info(s"event.time=${eventTime.asScala.mkString(",")}")
    }

    log.info(s"query.progress query=${progress.name} kafka=${sources.mkString(",")} inputRows/s=$inputRowsPerSecond processedRows/s=$processedRowsPerSecond durationMs=${progress.durationMs} sink=${progress.sink.json}")
  }

  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {
    log.warn(s"queryTerminated: $event")
    val possibleStreamingQuery = streams.get(event.id)
    if (possibleStreamingQuery != null) {
      val progress = possibleStreamingQuery.lastProgress
      val sources = progress.sources
      log.warn(s"last.progress.sources sources=$sources")
    }

    event.exception match {
      case Some(exception) =>
        log.warn(s"queryEndedWithException exception=$exception resetting.all.streams")
        restart()
      case None =>
    }
  }
} 
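A sketch of wiring the listener into a session; the rate source, console sink, and no-op restart callback are stand-ins for the real application's streams and restart logic.

import com.twilio.open.odsc.realish.listeners.InsightsQueryListener
import org.apache.spark.sql.SparkSession

object ListenerWiringDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("ListenerWiringDemo").getOrCreate()

    // Stand-in restart callback; a real app would stop and resubmit its streams here
    spark.streams.addListener(InsightsQueryListener(spark, () => println("restart requested")))

    val query = spark.readStream.format("rate").option("rowsPerSecond", "5").load()
      .writeStream.format("console").start()
    query.awaitTermination(10000L)
    spark.stop()
  }
}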
Example 114
Source File: GenericMainClass.scala    From darwin   with Apache License 2.0 5 votes vote down vote up
package it.agilelab.darwin.app.spark

import java.text.SimpleDateFormat
import java.util.Date

import com.typesafe.config.{Config, ConfigFactory}
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.JavaConverters._

trait GenericMainClass {
  self: SparkManager =>

  val genericMainClassLogger: Logger = LoggerFactory.getLogger("GenericMainClass")

  private def makeFileSystem(session: SparkSession): FileSystem = {
    if (session.sparkContext.isLocal) {
      FileSystem.getLocal(session.sparkContext.hadoopConfiguration)
    }
    else {
      FileSystem.get(session.sparkContext.hadoopConfiguration)
    }
  }


  
  // scalastyle:off
  private def getGlobalConfig: Config = {
    genericMainClassLogger.debug("system environment vars")
    for ((k, v) <- System.getenv().asScala.toSeq.sortBy(_._1)) genericMainClassLogger.debug(s"$k -> $v")

    genericMainClassLogger.debug("system properties")
    for ((k, v) <- System.getProperties.asScala.toSeq.sortBy(_._1)) genericMainClassLogger.debug(s"$k -> $v")

    ConfigFactory.load()
  }

  // scalastyle:on

} 
Example 115
Source File: SparkManager.scala    From darwin   with Apache License 2.0 5 votes vote down vote up
package it.agilelab.darwin.app.spark

import com.typesafe.config.Config
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConverters._

trait SparkManager {

  val sparkManagerLogger: Logger = LoggerFactory.getLogger("SparkManager")

  
  protected def defaultParallelism(implicit sparkSession: SparkSession, config: Config): Int = {
    sparkSession.conf.getOption(SparkConfigurationKeys.SPARK_EXECUTOR_INSTANCES) match {
      case Some(instances) =>
        sparkSession.conf.getOption(SparkConfigurationKeys.SPARK_CORES).getOrElse("1").toInt * instances.toInt
      case None =>
        sparkManagerLogger.info("Spark is configured with dynamic allocation, default parallelism will be gathered from app " +
          "conf: " +
          "next.process.parallelism")
        if (config.hasPath(SparkConfigurationKeys.PARALLELISM)) {
          config.getInt(SparkConfigurationKeys.PARALLELISM)
        } else {
          sparkManagerLogger.info("next.process.parallelism was not set fallback to sparkSession.defaultParallelism")
          sparkSession.sparkContext.defaultParallelism
        }
    }
  }
} 
Example 116
Source File: SchemaManagerSparkApp.scala    From darwin   with Apache License 2.0 5 votes vote down vote up
package it.agilelab.darwin.app.spark

import java.nio.ByteOrder

import com.typesafe.config.{Config, ConfigFactory}
import it.agilelab.darwin.app.spark.classes._
import it.agilelab.darwin.manager.AvroSchemaManagerFactory
import org.apache.avro.reflect.ReflectData
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}

object SchemaManagerSparkApp extends GenericMainClass with SparkManager {

  val mainLogger: Logger = LoggerFactory.getLogger("SchemaManagerSparkApp")

  val endianness: ByteOrder = ByteOrder.BIG_ENDIAN

  override protected def runJob(settings: Config)(implicit fs: FileSystem, sparkSession: SparkSession): Int = {
    import sparkSession.implicits._

    val ds = sparkSession.createDataset(sparkSession.sparkContext.parallelize(1 to 1000, 20))
    mainLogger.info("Registering schemas")
    //    val reflections = new Reflections("it.agilelab.darwin.app.spark.classes")
    //    val annotationClass: Class[AvroSerde] = classOf[AvroSerde]
    //    val classes = reflections.getTypesAnnotatedWith(annotationClass).asScala.toSeq
    //      .filter(c => !c.isInterface && !Modifier.isAbstract(c.getModifiers))
    //    val schemas = classes.map(c => ReflectData.get().getSchema(Class.forName(c.getName)))
    val schemas = Seq(ReflectData.get().getSchema(classOf[Menu]), ReflectData.get().getSchema(classOf[MenuItem]),
      ReflectData.get().getSchema(classOf[Food]), ReflectData.get().getSchema(classOf[Order]),
      ReflectData.get().getSchema(classOf[Price]))
    val conf = ConfigFactory.load()
    val manager = AvroSchemaManagerFactory.initialize(conf)
    val registeredIDs: Seq[Long] = manager.registerAll(schemas).map(_._1)
    mainLogger.info("Schemas registered")

    mainLogger.info("Getting ID for a schema")
    manager.getId(ReflectData.get().getSchema(classOf[Menu]))
    mainLogger.info("ID retrieved for the schema")

    mainLogger.info("Get Schema from ID")
    val d2 = ds.map { x =>
      AvroSchemaManagerFactory.initialize(conf).getSchema(registeredIDs(x % registeredIDs.size))
      x
    }
    d2.count()
    mainLogger.info("All schemas obtained")
    10
  }

  override protected def handleException(exception: Throwable, applicationSettings: Config): Unit = {
    mainLogger.error(exception.getMessage)
  }
} 
Example 117
Source File: AmazonReviewsPipeline.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.pipelines.text

import breeze.linalg.SparseVector
import keystoneml.evaluation.BinaryClassifierEvaluator
import keystoneml.loaders.{AmazonReviewsDataLoader, LabeledData}
import keystoneml.nodes.learning.LogisticRegressionEstimator
import keystoneml.nodes.nlp._
import keystoneml.nodes.stats.TermFrequency
import keystoneml.nodes.util.CommonSparseFeatures
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import keystoneml.pipelines.Logging
import scopt.OptionParser
import keystoneml.workflow.Pipeline

object AmazonReviewsPipeline extends Logging {
  val appName = "AmazonReviewsPipeline"

  def run(spark: SparkSession, conf: AmazonReviewsConfig): Pipeline[String, Double] = {
    val amazonTrainData = AmazonReviewsDataLoader(spark, conf.trainLocation, conf.threshold).labeledData
    val trainData = LabeledData(amazonTrainData.repartition(conf.numParts).cache())

    val training = trainData.data
    val labels = trainData.labels

    // Build the classifier estimator
    val predictor = Trim andThen
        LowerCase() andThen
        Tokenizer() andThen
        NGramsFeaturizer(1 to conf.nGrams) andThen
        TermFrequency(x => 1) andThen
        (CommonSparseFeatures[Seq[String]](conf.commonFeatures), training) andThen
        (LogisticRegressionEstimator[SparseVector[Double]](numClasses = 2, numIters = conf.numIters), training, labels)

    // Evaluate the classifier
    val amazonTestData = AmazonReviewsDataLoader(spark, conf.testLocation, conf.threshold).labeledData
    val testData = LabeledData(amazonTestData.repartition(conf.numParts).cache())
    val testLabels = testData.labels
    val testResults = predictor(testData.data)
    val eval = BinaryClassifierEvaluator.evaluate(testResults.get.map(_ > 0), testLabels.map(_ > 0))

    logInfo("\n" + eval.summary())
    predictor
  }

  case class AmazonReviewsConfig(
    trainLocation: String = "",
    testLocation: String = "",
    threshold: Double = 3.5,
    nGrams: Int = 2,
    commonFeatures: Int = 100000,
    numIters: Int = 20,
    numParts: Int = 512)

  def parse(args: Array[String]): AmazonReviewsConfig = new OptionParser[AmazonReviewsConfig](appName) {
    head(appName, "0.1")
    opt[String]("trainLocation") required() action { (x,c) => c.copy(trainLocation=x) }
    opt[String]("testLocation") required() action { (x,c) => c.copy(testLocation=x) }
    opt[Double]("threshold") action { (x,c) => c.copy(threshold=x)}
    opt[Int]("nGrams") action { (x,c) => c.copy(nGrams=x) }
    opt[Int]("commonFeatures") action { (x,c) => c.copy(commonFeatures=x) }
    opt[Int]("numIters") action { (x,c) => c.copy(numParts=x) }
    opt[Int]("numParts") action { (x,c) => c.copy(numParts=x) }
  }.parse(args, AmazonReviewsConfig()).get

  
  def main(args: Array[String]) = {
    val conf = new SparkConf().setAppName(appName)
    conf.setIfMissing("spark.master", "local[2]") // This is a fallback if things aren't set via spark submit.

    val spark = SparkSession.builder.config(conf).getOrCreate()

    val appConfig = parse(args)
    run(spark, appConfig)

    spark.stop()
  }
} 
Example 118
Source File: VLORRealDataExample.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.example

import org.apache.spark.ml.classification.{LogisticRegression, VLogisticRegression}
import org.apache.spark.sql.{Dataset, SparkSession}

object VLORRealDataExample {

  // https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#a9a
  def main(args: Array[String]) = {
    val spark = SparkSession
      .builder()
      .appName("VLogistic Regression real data example")
      .getOrCreate()

    val sc = spark.sparkContext

    val dataset1: Dataset[_] = spark.read.format("libsvm").load("data/a9a")

    val trainer = new LogisticRegression()
      .setFitIntercept(false)
      .setRegParam(0.5)
    val model = trainer.fit(dataset1)

    val vtrainer = new VLogisticRegression()
      .setColsPerBlock(100)
      .setRowsPerBlock(10)
      .setColPartitions(3)
      .setRowPartitions(3)
      .setRegParam(0.5)
    val vmodel = vtrainer.fit(dataset1)

    println(s"VLogistic regression coefficients: ${vmodel.coefficients}")
    println(s"Logistic regression coefficients: ${model.coefficients}")

    sc.stop()
  }
} 
Example 119
Source File: LORExample2.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.example

import org.apache.spark.ml.classification.MyLogisticRegression
import org.apache.spark.sql.{Dataset, SparkSession}

object LORExample2 {

  def main(args: Array[String]) = {

    var maxIter: Int = 100

    var dimension: Int = 780
    var regParam: Double = 0.5
    var fitIntercept: Boolean = true
    var elasticNetParam = 1.0

    var dataPath: String = null

    try {
      maxIter = args(0).toInt

      dimension = args(1).toInt

      regParam = args(2).toDouble
      fitIntercept = args(3).toBoolean
      elasticNetParam = args(4).toDouble

      dataPath = args(5)
    } catch {
      case _: Throwable =>
        println("Param list: "
          + "maxIter dimension"
          + " regParam fitIntercept elasticNetParam dataPath")
        println("parameter description:" +
          "\nmaxIter          max iteration number for VLogisticRegression" +
          "\ndimension        training data dimension number" +
          "\nregParam         regularization parameter" +
          "\nfitIntercept     whether to train intercept, true or false" +
          "\nelasticNetParam  elastic net parameter for regulization" +
          "\ndataPath         training data path on HDFS")

        System.exit(-1)
    }

    val spark = SparkSession
      .builder()
      .appName("LOR for testing")
      .getOrCreate()

    val sc = spark.sparkContext

    try {
      println(s"begin load data from $dataPath")
      val dataset: Dataset[_] = spark.read.format("libsvm")
        .option("numFeatures", dimension.toString)
        .load(dataPath)

      val trainer = new MyLogisticRegression()
        .setMaxIter(maxIter)
        .setRegParam(regParam)
        .setFitIntercept(fitIntercept)
        .setElasticNetParam(elasticNetParam)

      val model = trainer.fit(dataset)

      println(s"LOR done, coeffs non zeros: ${model.coefficients.numNonzeros}")
    } catch {
      case e: Exception =>
        e.printStackTrace()
    } finally {
      // println("Press ENTER to exit.")
      // System.in.read()
    }
    sc.stop()
  }

} 
Example 120
Source File: CreateHiveTableAsSelectCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand



case class CreateHiveTableAsSelectCommand(
    tableDesc: CatalogTable,
    query: LogicalPlan,
    outputColumnNames: Seq[String],
    mode: SaveMode)
  extends DataWritingCommand {

  private val tableIdentifier = tableDesc.identifier

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog
    if (catalog.tableExists(tableIdentifier)) {
      assert(mode != SaveMode.Overwrite,
        s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite")

      if (mode == SaveMode.ErrorIfExists) {
        throw new AnalysisException(s"$tableIdentifier already exists.")
      }
      if (mode == SaveMode.Ignore) {
        // Since the table already exists and the save mode is Ignore, we will just return.
        return Seq.empty
      }

      // For CTAS, there is no static partition values to insert.
      val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap
      InsertIntoHiveTable(
        tableDesc,
        partition,
        query,
        overwrite = false,
        ifPartitionNotExists = false,
        outputColumnNames = outputColumnNames).run(sparkSession, child)
    } else {
      // TODO ideally, we should get the output data ready first and then
      // add the relation into catalog, just in case of failure occurs while data
      // processing.
      assert(tableDesc.schema.isEmpty)
      catalog.createTable(
        tableDesc.copy(schema = outputColumns.toStructType), ignoreIfExists = false)

      try {
        // Read back the metadata of the table which was created just now.
        val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier)
        // For CTAS, there is no static partition values to insert.
        val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap
        InsertIntoHiveTable(
          createdTableMeta,
          partition,
          query,
          overwrite = true,
          ifPartitionNotExists = false,
          outputColumnNames = outputColumnNames).run(sparkSession, child)
      } catch {
        case NonFatal(e) =>
          // drop the created table.
          catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false)
          throw e
      }
    }

    Seq.empty[Row]
  }

  override def argString: String = {
    s"[Database:${tableDesc.database}, " +
    s"TableName: ${tableDesc.identifier.table}, " +
    s"InsertIntoHiveTable]"
  }
} 
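This command is what Spark plans for a Hive CTAS statement. A hedged sketch of the user-facing trigger, assuming a Hive-enabled session and hypothetical table names:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("ctas-example")
  .enableHiveSupport()
  .getOrCreate()

// Planned as a CreateHiveTableAsSelectCommand (here with the default SaveMode.ErrorIfExists).
spark.sql("CREATE TABLE sales_summary AS SELECT region, sum(amount) AS total FROM sales GROUP BY region")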
Example 121
Source File: Main.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
import org.apache.spark.sql.SparkSession


// TODO: actually rebuild this jar with the new changes.
object Main {
  def main(args: Array[String]) {
    // scalastyle:off println
    println("Running regression test for SPARK-8489.")
    val spark = SparkSession.builder
      .master("local")
      .appName("testing")
      .enableHiveSupport()
      .getOrCreate()
    // This line should not throw scala.reflect.internal.MissingRequirementError.
    // See SPARK-8470 for more detail.
    val df = spark.createDataFrame(Seq(MyCoolClass("1", "2", "3")))
    df.collect()
    println("Regression test for SPARK-8489 success!")
    // scalastyle:on println
    spark.stop()
  }
} 
Example 122
Source File: TestHiveSingleton.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.test

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.HiveExternalCatalog
import org.apache.spark.sql.hive.client.HiveClient


trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
  override protected val enableAutoThreadAudit = false
  protected val spark: SparkSession = TestHive.sparkSession
  protected val hiveContext: TestHiveContext = TestHive
  protected val hiveClient: HiveClient =
    spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client

  protected override def afterAll(): Unit = {
    try {
      hiveContext.reset()
    } finally {
      super.afterAll()
    }
  }

} 
Example 123
Source File: CommitFailureTestSource.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

class CommitFailureTestSource extends SimpleTextSource {
  
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new SimpleTextOutputWriter(path, dataSchema, context) {
          var failed = false
          TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
            failed = true
            SimpleTextRelation.callbackCalled = true
          }

          override def write(row: InternalRow): Unit = {
            if (SimpleTextRelation.failWriter) {
              sys.error("Intentional task writer failure for testing purpose.")

            }
            super.write(row)
          }

          override def close(): Unit = {
            super.close()
            sys.error("Intentional task commitment failure for testing purpose.")
          }
        }
      }

      override def getFileExtension(context: TaskAttemptContext): String = ""
    }

  override def shortName(): String = "commit-failure-test"
} 
Example 124
Source File: SparkSQLEnv.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.io.PrintStream

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils


private[hive] object SparkSQLEnv extends Logging {

  // Holders for the shared SparkContext and SQLContext managed by this environment.
  var sqlContext: SQLContext = _
  var sparkContext: SparkContext = _

  def stop() {
    logDebug("Shutting down Spark SQL Environment")
    // Stop the SparkContext
    if (SparkSQLEnv.sparkContext != null) {
      sparkContext.stop()
      sparkContext = null
      sqlContext = null
    }
  }
} 
Example 125
Source File: HiveMetastoreLazyInitializationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.util.Utils

class HiveMetastoreLazyInitializationSuite extends SparkFunSuite {

  test("lazily initialize Hive client") {
    val spark = SparkSession.builder()
      .appName("HiveMetastoreLazyInitializationSuite")
      .master("local[2]")
      .enableHiveSupport()
      .config("spark.hadoop.hive.metastore.uris", "thrift://127.0.0.1:11111")
      .getOrCreate()
    val originalLevel = org.apache.log4j.Logger.getRootLogger().getLevel
    try {
      // Avoid outputting a lot of expected warning logs
      spark.sparkContext.setLogLevel("error")

      // We should be able to run Spark jobs without Hive client.
      assert(spark.sparkContext.range(0, 1).count() === 1)

      // We should be able to use Spark SQL if no table references.
      assert(spark.sql("select 1 + 1").count() === 1)
      assert(spark.range(0, 1).count() === 1)

      // We should be able to use fs
      val path = Utils.createTempDir()
      path.delete()
      try {
        spark.range(0, 1).write.parquet(path.getAbsolutePath)
        assert(spark.read.parquet(path.getAbsolutePath).count() === 1)
      } finally {
        Utils.deleteRecursively(path)
      }

      // Make sure that we are not using the local derby metastore.
      val exceptionString = Utils.exceptionString(intercept[AnalysisException] {
        spark.sql("show tables")
      })
      for (msg <- Seq(
        "show tables",
        "Could not connect to meta store",
        "org.apache.thrift.transport.TTransportException",
        "Connection refused")) {
        assert(exceptionString.contains(msg))
      }
    } finally {
      spark.sparkContext.setLogLevel(originalLevel.toString)
      spark.stop()
    }
  }
} 
Example 126
Source File: datasources.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.xsql.XSQLSessionCatalog

case class XSQLShowDatasourcesCommand(datasourcePattern: Option[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("dataSourceName", StringType, nullable = false)() :: Nil
  }

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    val datasources =
      datasourcePattern
        .map { pattern =>
          catalog.listDatasources(pattern)
        }
        .getOrElse(catalog.listDatasources())
    datasources.map { d =>
      Row(d)
    }
  }
}

case class XSQLAddDatasourceCommand(dataSourceName: String, properties: Map[String, String])
  extends RunnableCommand {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    catalog.addDataSource(dataSourceName, properties)
    Seq.empty[Row]
  }
}

case class XSQLRemoveDatasourceCommand(dataSourceName: String, ifExists: Boolean)
  extends RunnableCommand {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    catalog.removeDataSource(dataSourceName, ifExists)
    Seq.empty[Row]
  }
}

case class XSQLRefreshDatasourceCommand(dataSourceName: String) extends RunnableCommand {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    catalog.refreshDataSource(dataSourceName)
    Seq.empty[Row]
  }
} 
Example 127
Source File: databases.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.xsql.XSQLSessionCatalog


case class XSQLSetDatabaseCommand(dataSourceName: Option[String], databaseName: String)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    if (dataSourceName.isEmpty) {
      catalog.setCurrentDatabase(databaseName)
    } else {
      catalog.setCurrentDatabase(dataSourceName.get, databaseName)
    }
    Seq.empty[Row]
  }
} 
Example 128
Source File: XSQLAnalyzeTableCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.execution.command.{CommandUtils, RunnableCommand}
import org.apache.spark.sql.xsql.XSQLSessionCatalog


case class XSQLAnalyzeTableCommand(tableIdent: TableIdentifier, noscan: Boolean = true)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val sessionState = sparkSession.sessionState
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    val catalogDB = catalog.getUsedCatalogDatabase(tableIdent.dataSource, tableIdent.database)
    if (catalogDB == None) {
      return Seq.empty[Row]
    }
    val ds = catalogDB.get.dataSourceName
    val db = catalogDB.get.name
    val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db), Some(ds))
    val tableMeta = catalog.getRawTable(tableIdentWithDB)
    if (tableMeta.tableType == CatalogTableType.VIEW) {
      throw new AnalysisException("ANALYZE TABLE is not supported on views.")
    }

    // Compute stats for the whole table
    val newTotalSize = CommandUtils.calculateTotalSize(sparkSession, tableMeta)
    val newRowCount =
      if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count()))

    // Update the metastore if the above statistics of the table are different from those
    // recorded in the metastore.
    val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount)
    if (newStats.isDefined) {
      catalog.alterTableStats(tableIdentWithDB, newStats)
    }

    Seq.empty[Row]
  }
} 
Example 129
Source File: StreamingIncrementCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import java.util.Locale

import org.apache.spark.SparkException
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.StreamingRelationV2
import org.apache.spark.sql.sources.v2.StreamWriteSupport
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.xsql.DataSourceManager._
import org.apache.spark.sql.xsql.StreamingSinkType


case class StreamingIncrementCommand(plan: LogicalPlan) extends RunnableCommand {

  private var outputMode: OutputMode = OutputMode.Append
  // dummy
  override def output: Seq[AttributeReference] = Seq.empty
  // dummy
  override def producedAttributes: AttributeSet = plan.producedAttributes

  override def run(sparkSession: SparkSession): Seq[Row] = {
    import StreamingSinkType._
    val qe = new QueryExecution(sparkSession, new ConstructedStreaming(plan))
    val df = new Dataset(sparkSession, qe, RowEncoder(qe.analyzed.schema))
    plan.collectLeaves.head match {
      case StreamingRelationV2(_, _, extraOptions, _, _) =>
        val source = extraOptions.getOrElse(STREAMING_SINK_TYPE, DEFAULT_STREAMING_SINK)
        val sinkOptions = extraOptions.filter(_._1.startsWith(STREAMING_SINK_PREFIX)).map { kv =>
          val key = kv._1.substring(STREAMING_SINK_PREFIX.length)
          (key, kv._2)
        }
        StreamingSinkType.withName(source.toUpperCase(Locale.ROOT)) match {
          case CONSOLE =>
          case TEXT | PARQUET | ORC | JSON | CSV =>
            if (sinkOptions.get(STREAMING_SINK_PATH) == None) {
              throw new SparkException("Sink type is file, must config path")
            }
          case KAFKA =>
            if (sinkOptions.get(STREAMING_SINK_BOOTSTRAP_SERVERS) == None) {
              throw new SparkException("Sink type is kafka, must config bootstrap servers")
            }
            if (sinkOptions.get(STREAMING_SINK_TOPIC) == None) {
              throw new SparkException("Sink type is kafka, must config kafka topic")
            }
          case _ =>
            throw new SparkException(
              "Sink type is invalid, " +
                s"select from ${StreamingSinkType.values}")
        }
        val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
        val disabledSources = sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
        val sink = ds.newInstance() match {
          case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) =>
            w
          case _ =>
            val ds = DataSource(
              sparkSession,
              className = source,
              options = sinkOptions.toMap,
              partitionColumns = Nil)
            ds.createSink(InternalOutputModes.Append)
        }
        val outputMode = InternalOutputModes(
          extraOptions.getOrElse(STREAMING_OUTPUT_MODE, DEFAULT_STREAMING_OUTPUT_MODE))
        val duration =
          extraOptions.getOrElse(STREAMING_TRIGGER_DURATION, DEFAULT_STREAMING_TRIGGER_DURATION)
        val trigger =
          extraOptions.getOrElse(STREAMING_TRIGGER_TYPE, DEFAULT_STREAMING_TRIGGER_TYPE) match {
            case STREAMING_MICRO_BATCH_TRIGGER => Trigger.ProcessingTime(duration)
            case STREAMING_ONCE_TRIGGER => Trigger.Once()
            case STREAMING_CONTINUOUS_TRIGGER => Trigger.Continuous(duration)
          }
        val query = sparkSession.sessionState.streamingQueryManager.startQuery(
          extraOptions.get("queryName"),
          extraOptions.get(STREAMING_CHECKPOINT_LOCATION),
          df,
          sinkOptions.toMap,
          sink,
          outputMode,
          useTempCheckpointLocation = source == DEFAULT_STREAMING_SINK,
          recoverFromCheckpointLocation = true,
          trigger = trigger)
        query.awaitTermination()
    }
    // dummy
    Seq.empty
  }
}

case class ConstructedStreaming(child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
} 
Example 130
Source File: XSQLCreateHiveTableAsSelectCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.xsql.XSQLSessionCatalog


case class XSQLCreateHiveTableAsSelectCommand(
    tableDesc: CatalogTable,
    query: LogicalPlan,
    outputColumnNames: Seq[String],
    mode: SaveMode)
  extends DataWritingCommand {

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    val tableIdentifier = catalog.getUsedTableIdentifier(tableDesc.identifier)
    val newTableDesc = tableDesc.copy(identifier = tableIdentifier)
    if (catalog.tableExists(tableIdentifier)) {
      assert(
        mode != SaveMode.Overwrite,
        s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite")

      if (mode == SaveMode.ErrorIfExists) {
        throw new AnalysisException(s"$tableIdentifier already exists.")
      }
      if (mode == SaveMode.Ignore) {
        // Since the table already exists and the save mode is Ignore, we will just return.
        return Seq.empty
      }

      XSQLInsertIntoHiveTable(
        newTableDesc,
        Map.empty,
        query,
        overwrite = false,
        ifPartitionNotExists = false,
        outputColumnNames = outputColumnNames).run(sparkSession, child)
    } else {
      // TODO ideally, we should get the output data ready first and then
      // add the relation into catalog, just in case of failure occurs while data
      // processing.
      assert(newTableDesc.schema.isEmpty)
      catalog.createTable(newTableDesc.copy(schema = query.schema), ignoreIfExists = false)

      try {
        // Read back the metadata of the table which was created just now.
        val createdTableMeta = catalog.getTableMetadata(newTableDesc.identifier)
        // For CTAS, there is no static partition values to insert.
        val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap
        XSQLInsertIntoHiveTable(
          createdTableMeta,
          partition,
          query,
          overwrite = true,
          ifPartitionNotExists = false,
          outputColumnNames = outputColumnNames).run(sparkSession, child)
      } catch {
        case NonFatal(e) =>
          // drop the created table.
          catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false)
          throw e
      }
    }

    Seq.empty[Row]
  }

  override def argString: String = {
    s"[TableName: ${tableDesc.identifier.table}, " +
      s"InsertIntoHiveTable]"
  }
} 
Example 131
Source File: CatalogFileIndex.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StructType



private class PrunedInMemoryFileIndex(
    sparkSession: SparkSession,
    tableBasePath: Path,
    fileStatusCache: FileStatusCache,
    override val partitionSpec: PartitionSpec,
    override val metadataOpsTimeNs: Option[Long])
  extends InMemoryFileIndex(
    sparkSession,
    partitionSpec.partitions.map(_.path),
    Map.empty,
    Some(partitionSpec.partitionColumns),
    fileStatusCache) 
Example 132
Source File: SaveIntoDataSourceCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.CreatableRelationProvider


case class SaveIntoDataSourceCommand(
    query: LogicalPlan,
    dataSource: CreatableRelationProvider,
    options: Map[String, String],
    mode: SaveMode) extends RunnableCommand {

  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    dataSource.createRelation(
      sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))

    Seq.empty[Row]
  }

  override def simpleString: String = {
    val redacted = SQLConf.get.redactOptions(options)
    s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}"
  }
} 
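A SaveIntoDataSourceCommand is what the DataFrameWriter path plans for a V1 CreatableRelationProvider. A sketch of the user-facing call, assuming an existing DataFrame df and hypothetical JDBC connection details:

// DataFrameWriter.save() on a CreatableRelationProvider ends up as a SaveIntoDataSourceCommand.
df.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/testdb") // assumed connection details
  .option("dbtable", "public.events")                  // assumed table
  .mode("append")
  .save()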
Example 133
Source File: HadoopFsRelation.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.{StructField, StructType}



case class HadoopFsRelation(
    location: FileIndex,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  private def getColName(f: StructField): String = {
    if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
      f.name
    } else {
      f.name.toLowerCase(Locale.ROOT)
    }
  }

  val overlappedPartCols = mutable.Map.empty[String, StructField]
  partitionSchema.foreach { partitionField =>
    if (dataSchema.exists(getColName(_) == getColName(partitionField))) {
      overlappedPartCols += getColName(partitionField) -> partitionField
    }
  }

  // When data and partition schemas have overlapping columns, the output
  // schema respects the order of the data schema for the overlapping columns, and it
  // respects the data types of the partition schema.
  val schema: StructType = {
    StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++
      partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f))))
  }

  def partitionSchemaOption: Option[StructType] =
    if (partitionSchema.isEmpty) None else Some(partitionSchema)

  override def toString: String = {
    fileFormat match {
      case source: DataSourceRegister => source.shortName()
      case _ => "HadoopFiles"
    }
  }

  override def sizeInBytes: Long = {
    val compressionFactor = sqlContext.conf.fileCompressionFactor
    (location.sizeInBytes * compressionFactor).toLong
  }


  override def inputFiles: Array[String] = location.inputFiles
} 
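The schema construction above appends partition columns that are not already present in the data schema. A small spark-shell sketch, assuming a partitioned Parquet layout such as /tmp/events/date=2020-01-01/part-*.parquet:

val df = spark.read.parquet("/tmp/events") // path and partitioning are assumptions
df.printSchema() // data columns first, then the inferred "date" partition column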
Example 134
Source File: SQLExecution.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}

object SQLExecution {

  val EXECUTION_ID_KEY = "spark.sql.execution.id"

  private val _nextExecutionId = new AtomicLong(0)

  private def nextExecutionId: Long = _nextExecutionId.getAndIncrement

  private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]()

  def getQueryExecution(executionId: Long): QueryExecution = {
    executionIdToQueryExecution.get(executionId)
  }

  private val testing = sys.props.contains("spark.testing")

  private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
    val sc = sparkSession.sparkContext
    // only throw an exception during tests. a missing execution ID should not fail a job.
    if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) {
      // Attention testers: when a test fails with this exception, it means that the action that
      // started execution of a query didn't call withNewExecutionId. The execution ID should be
      // set by calling withNewExecutionId in the action that begins execution, like
      // Dataset.collect or DataFrameWriter.insertInto.
      throw new IllegalStateException("Execution ID should be set")
    }
  }

  
  def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = {
    val sc = sparkSession.sparkContext
    // Set all the specified SQL configs to local properties, so that they can be available at
    // the executor side.
    val allConfigs = sparkSession.sessionState.conf.getAllConfs
    val originalLocalProps = allConfigs.collect {
      case (key, value) if key.startsWith("spark") =>
        val originalValue = sc.getLocalProperty(key)
        sc.setLocalProperty(key, value)
        (key, originalValue)
    }

    try {
      body
    } finally {
      for ((key, value) <- originalLocalProps) {
        sc.setLocalProperty(key, value)
      }
    }
  }
} 
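withSQLConfPropagated copies the session's spark.* SQL configs into local properties for the duration of the block and restores them afterwards. A minimal usage sketch, assuming an existing SparkSession named spark:

// The body runs with the session's SQL configs visible as local properties on the executor side.
val n = SQLExecution.withSQLConfPropagated(spark) {
  spark.range(0, 10).count()
}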
Example 135
Source File: subquery.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}


case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
} 
Example 136
Source File: ExistingRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Encoder, Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

object RDDConversions {
  def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
    data.mapPartitions { iterator =>
      val numColumns = outputTypes.length
      val mutableRow = new GenericInternalRow(numColumns)
      val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
      iterator.map { r =>
        var i = 0
        while (i < numColumns) {
          mutableRow(i) = converters(i)(r.productElement(i))
          i += 1
        }

        mutableRow
      }
    }
  }
}

case class RDDScanExec(
    output: Seq[Attribute],
    rdd: RDD[InternalRow],
    name: String,
    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {

  private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("")

  override val nodeName: String = s"Scan $name$rddName"

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
      val proj = UnsafeProjection.create(schema)
      proj.initialize(index)
      iter.map { r =>
        numOutputRows += 1
        proj(r)
      }
    }
  }

  override def simpleString: String = {
    s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}"
  }
} 
Example 137
Source File: cache.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

case class CacheTableCommand(
    tableIdent: TableIdentifier,
    plan: Option[LogicalPlan],
    isLazy: Boolean) extends RunnableCommand {
  require(plan.isEmpty || tableIdent.database.isEmpty,
    "Database name is not allowed in CACHE TABLE AS SELECT")

  override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

  override def run(sparkSession: SparkSession): Seq[Row] = {
    plan.foreach { logicalPlan =>
      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
    }
    sparkSession.catalog.cacheTable(tableIdent.quotedString)

    if (!isLazy) {
      // Performs eager caching
      sparkSession.table(tableIdent).count()
    }

    Seq.empty[Row]
  }
}


case class UncacheTableCommand(
    tableIdent: TableIdentifier,
    ifExists: Boolean) extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val tableId = tableIdent.quotedString
    if (!ifExists || sparkSession.catalog.tableExists(tableId)) {
      sparkSession.catalog.uncacheTable(tableId)
    }
    Seq.empty[Row]
  }
}


case class ClearCacheCommand() extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    sparkSession.catalog.clearCache()
    Seq.empty[Row]
  }

  override def makeCopy(newArgs: Array[AnyRef]): ClearCacheCommand = ClearCacheCommand()
} 
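These commands back the user-facing caching APIs. An equivalent spark-shell sketch, assuming a registered table named events:

spark.sql("CACHE TABLE events")        // eager, like CacheTableCommand with isLazy = false
spark.sql("CACHE LAZY TABLE events")   // lazy variant
spark.catalog.uncacheTable("events")   // UncacheTableCommand
spark.catalog.clearCache()             // ClearCacheCommand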
Example 138
Source File: resources.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
Example 139
Source File: AnalyzeTableCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTableType



case class AnalyzeTableCommand(
    tableIdent: TableIdentifier,
    noscan: Boolean = true) extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val sessionState = sparkSession.sessionState
    val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase)
    val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
    val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB)
    if (tableMeta.tableType == CatalogTableType.VIEW) {
      throw new AnalysisException("ANALYZE TABLE is not supported on views.")
    }

    // Compute stats for the whole table
    val newTotalSize = CommandUtils.calculateTotalSize(sparkSession, tableMeta)
    val newRowCount =
      if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count()))

    // Update the metastore if the above statistics of the table are different from those
    // recorded in the metastore.
    val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount)
    if (newStats.isDefined) {
      sessionState.catalog.alterTableStats(tableIdentWithDB, newStats)
    }

    Seq.empty[Row]
  }
} 
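AnalyzeTableCommand corresponds to the ANALYZE TABLE statement. A spark-shell sketch, assuming a table named events:

spark.sql("ANALYZE TABLE events COMPUTE STATISTICS NOSCAN") // size-only stats; newRowCount stays None
spark.sql("ANALYZE TABLE events COMPUTE STATISTICS")        // also scans to compute the row count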
Example 140
Source File: DataWritingCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.datasources.FileFormatWriter
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.util.SerializableConfiguration


object DataWritingCommand {

  def logicalPlanOutputWithNames(
      query: LogicalPlan,
      names: Seq[String]): Seq[Attribute] = {
    // Save the output attributes to a variable to avoid duplicated function calls.
    val outputAttributes = query.output
    assert(outputAttributes.length == names.length,
      "The length of provided names doesn't match the length of output attributes.")
    outputAttributes.zip(names).map { case (attr, outputName) =>
      attr.withName(outputName)
    }
  }
} 
Example 141
Source File: StreamingQueryWrapper.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.util.UUID

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus}


  def explainInternal(extended: Boolean): String = {
    streamingQuery.explainInternal(extended)
  }

  override def sparkSession: SparkSession = {
    streamingQuery.sparkSession
  }

  override def recentProgress: Array[StreamingQueryProgress] = {
    streamingQuery.recentProgress
  }

  override def status: StreamingQueryStatus = {
    streamingQuery.status
  }

  override def exception: Option[StreamingQueryException] = {
    streamingQuery.exception
  }
} 
Example 142
Source File: MetadataLogFileIndex.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import scala.collection.mutable

import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.types.StructType



class MetadataLogFileIndex(
    sparkSession: SparkSession,
    path: Path,
    userSpecifiedSchema: Option[StructType])
  extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) {

  private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
  logInfo(s"Reading streaming file log from $metadataDirectory")
  private val metadataLog =
    new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toUri.toString)
  private val allFilesFromLog = metadataLog.allFiles().map(_.toFileStatus).filterNot(_.isDirectory)
  private var cachedPartitionSpec: PartitionSpec = _

  override protected val leafFiles: mutable.LinkedHashMap[Path, FileStatus] = {
    new mutable.LinkedHashMap ++= allFilesFromLog.map(f => f.getPath -> f)
  }

  override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = {
    allFilesFromLog.groupBy(_.getPath.getParent)
  }

  override def rootPaths: Seq[Path] = path :: Nil

  override def refresh(): Unit = { }

  override def partitionSpec(): PartitionSpec = {
    if (cachedPartitionSpec == null) {
      cachedPartitionSpec = inferPartitioning()
    }
    cachedPartitionSpec
  }
} 
Example 143
Source File: OffsetSeqLog.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming


import java.io.{InputStream, OutputStream}
import java.nio.charset.StandardCharsets._

import scala.io.{Source => IOSource}

import org.apache.spark.sql.SparkSession


class OffsetSeqLog(sparkSession: SparkSession, path: String)
  extends HDFSMetadataLog[OffsetSeq](sparkSession, path) {

  override protected def deserialize(in: InputStream): OffsetSeq = {
    // called inside a try-finally where the underlying stream is closed in the caller
    def parseOffset(value: String): Offset = value match {
      case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null
      case json => SerializedOffset(json)
    }
    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
    if (!lines.hasNext) {
      throw new IllegalStateException("Incomplete log file")
    }

    val version = parseVersion(lines.next(), OffsetSeqLog.VERSION)

    // read metadata
    val metadata = lines.next().trim match {
      case "" => None
      case md => Some(md)
    }
    OffsetSeq.fill(metadata, lines.map(parseOffset).toArray: _*)
  }

  override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = {
    // called inside a try-finally where the underlying stream is closed in the caller
    out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8))

    // write metadata
    out.write('\n')
    out.write(offsetSeq.metadata.map(_.json).getOrElse("").getBytes(UTF_8))

    // write offsets, one per line
    offsetSeq.offsets.map(_.map(_.json)).foreach { offset =>
      out.write('\n')
      offset match {
        case Some(json: String) => out.write(json.getBytes(UTF_8))
        case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8))
      }
    }
  }
}

object OffsetSeqLog {
  private[streaming] val VERSION = 1
  private val SERIALIZED_VOID_OFFSET = "-"
} 
Example 144
Source File: FileStreamSinkLog.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.net.URI

import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf


class FileStreamSinkLog(
    metadataLogVersion: Int,
    sparkSession: SparkSession,
    path: String)
  extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) {

  private implicit val formats = Serialization.formats(NoTypeHints)

  protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSinkLogCleanupDelay

  protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion

  protected override val defaultCompactInterval =
    sparkSession.sessionState.conf.fileSinkLogCompactInterval

  require(defaultCompactInterval > 0,
    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " +
      "to a positive value.")

  override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = {
    val deletedFiles = logs.filter(_.action == FileStreamSinkLog.DELETE_ACTION).map(_.path).toSet
    if (deletedFiles.isEmpty) {
      logs
    } else {
      logs.filter(f => !deletedFiles.contains(f.path))
    }
  }
}

object FileStreamSinkLog {
  val VERSION = 1
  val DELETE_ACTION = "delete"
  val ADD_ACTION = "add"
} 
Example 145
Source File: LibSVMTransformationLocalFunctionalTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.transformation

import java.io.{File, FileWriter}

import collection.JavaConverters._
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.sql.SparkSession

import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.LibSVMResponseRowDeserializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.LibSVMRequestRowSerializer

class LibSVMTransformationLocalFunctionalTests extends FlatSpec with Matchers with MockitoSugar
  with BeforeAndAfter {

  val spark = SparkSession.builder
    .master("local")
    .appName("spark session")
    .getOrCreate()

  var libsvmDataFile : File = _
  val libsvmdata =
    "1.0 1:1.5 2:3.0 28:-39.935 55:0.01\n" +
      "0.0 2:3.0 28:-39.935 55:0.01\n" +
      "-1.0 23:-39.935 55:0.01\n" +
      "3.0 1:1.5 2:3.0"
  before {
    libsvmDataFile = File.createTempFile("temp", "temp")
    val fw = new FileWriter(libsvmDataFile)
    fw.write(libsvmdata)
    fw.close()
  }

  "LibSVMSerialization" should "serialize Spark loaded libsvm file to same contents" in {
    import spark.implicits._

    val df = spark.read.format("libsvm").load(libsvmDataFile.getPath)
    val libsvmSerializer = new LibSVMRequestRowSerializer(Some(df.schema))
    val result = df.map(row => new String(libsvmSerializer.serializeRow(row))).collect().mkString
    assert (libsvmdata.trim == result.trim)
  }

  "LibSVMDeserialization" should "deserialize serialized lib svm records" in {

    val libsvmdata =
      "1.0 1:1.5 2:3.0 28:-39.935 55:0.01\n" +
        "0.0 2:3.0 28:-39.935 55:0.01\n" +
        "-1.0 23:-39.935 55:0.01\n" +
        "3.0 1:1.5 2:3.0"

    val libsvmDeserializer = new LibSVMResponseRowDeserializer (55)
    val rowList = libsvmDeserializer.deserializeResponse(libsvmdata.getBytes).toBuffer.asJava
    val deserializedDataFrame = spark.createDataFrame(rowList, libsvmDeserializer.schema)
    val sparkProducedDataFrame = spark.read.format("libsvm").load(libsvmDataFile.getPath)

    val deserializedRows = deserializedDataFrame.collectAsList()
    val sparkRows = sparkProducedDataFrame.collectAsList()

    assert (deserializedRows == sparkRows)
  }
} 
Example 146
Source File: LocalBasedStrategies.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.optimizer.StarryLocalRelationReplace


object LocalBasedStrategies {

  def register(sparkSession: SparkSession): Unit = {
    sparkSession.experimental.extraStrategies = Seq(
      StarryAggStrategy(),
      StarryJoinLocalStrategy(sparkSession.sessionState.conf),
      StarryUnionLocalStrategy(),
      StarryLimitLocalStrategy(),
      StarryLocalTableScanStrategies()
    ) ++: sparkSession.experimental.extraStrategies

    sparkSession.experimental.extraOptimizations = Seq(
      StarryLocalRelationReplace
    )
  }

  def unRegister(sparkSession: SparkSession): Unit = {
    sparkSession.experimental.extraStrategies =
      sparkSession.experimental.extraStrategies
        .filter(strategy => !strategy.isInstanceOf[StarryJoinLocalStrategy])
        .filter(strategy => !strategy.isInstanceOf[StarryUnionLocalStrategy])
        .filter(strategy => !strategy.isInstanceOf[StarryLimitLocalStrategy])
        .filter(strategy => !strategy.isInstanceOf[StarryAggStrategy])
        .filter(strategy => !strategy.isInstanceOf[StarryLocalTableScanStrategies])

    sparkSession.experimental.extraOptimizations = Seq()
  }

} 
Example 147
Source File: SparkPlanExecutor.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package com.github.passionke.starry

import org.apache.spark.{Partition, StarryTaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{ReuseSubquery, SparkPlan}


object SparkPlanExecutor {

  def exec(plan: SparkPlan, sparkSession: SparkSession) = {
    val newPlan = Seq(
      ReuseSubquery(sparkSession.sessionState.conf))
      .foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
    doExec(newPlan)
  }

  def firstPartition(rdd: RDD[InternalRow]): Partition = {
    rdd.partitions.head
  }

  def doExec(sparkPlan: SparkPlan): List[InternalRow] = {
    val rdd = sparkPlan.execute().map(ite => ite.copy())
    val partition = firstPartition(rdd)
    rdd.compute(partition, new StarryTaskContext).toList
  }

  def rddCompute(rdd: RDD[InternalRow]): List[InternalRow] = {
    val partition = firstPartition(rdd)
    rdd.compute(partition, new StarryTaskContext).toList
  }

} 
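A hedged sketch of driving this executor from a DataFrame, reusing the starry Spark session defined in the next example (the query itself is an assumption):

val df = Spark.sparkSession.sql("select 1 as id")
val rows = SparkPlanExecutor.exec(df.queryExecution.executedPlan, Spark.sparkSession)
rows.foreach(println) // InternalRow values computed on the driver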
Example 148
Source File: Spark.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import com.github.passionke.starry.StarrySparkContext
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.execution.LocalBasedStrategies


object Spark {

  val sparkConf = new SparkConf()
  sparkConf.setMaster("local[*]")
  sparkConf.setAppName("aloha")
  sparkConf
    .set("spark.default.parallelism", "1")
    .set("spark.sql.shuffle.partitions", "1")
    .set("spark.broadcast.manager", "rotary")
    .set("rotary.shuffer", "true")
    .set("spark.sql.codegen.wholeStage", "false")
    .set("spark.sql.extensions", "org.apache.spark.sql.StarrySparkSessionExtension")
    .set("spark.driver.allowMultipleContexts", "true") // for test only
  val sparkContext = new StarrySparkContext(sparkConf)
  val sparkSession: SparkSession =
    SparkSession.builder
      .sparkContext(sparkContext)
      .getOrCreate

  LocalBasedStrategies.register(sparkSession)
} 
Example 149
Source File: Database.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package com.github.passionke.student50

import org.apache.spark.Spark
import org.apache.spark.sql.SparkSession


object Database {

  def sparkSession(): SparkSession = {
    val sparkSession: SparkSession = Spark.sparkSession
    sparkSession.sparkContext.setLogLevel("WARN")
    sparkSession.createDataFrame(Student.students()).createOrReplaceTempView("student")
    sparkSession.createDataFrame(Teacher.teachers()).createOrReplaceTempView("techer")
    sparkSession.createDataFrame(Score.scores()).createOrReplaceTempView("score")
    sparkSession.createDataFrame(Course.courses()).createOrReplaceTempView("course")
    sparkSession
  }
} 
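A short usage sketch querying the temporary views registered above (the queries are assumptions; note the view is registered under the name "techer" as spelled in the source):

val spark = Database.sparkSession()
spark.sql("select count(*) from student").show()
spark.sql("select * from score").show(10)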
Example 150
Source File: KinesisWriter.scala    From kinesis-sql   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.kinesis

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.util.Utils

private[kinesis] object KinesisWriter extends Logging {

  val DATA_ATTRIBUTE_NAME: String = "data"
  val PARTITION_KEY_ATTRIBUTE_NAME: String = "partitionKey"

  override def toString: String = "KinesisWriter"

  def write(sparkSession: SparkSession,
            queryExecution: QueryExecution,
            kinesisParameters: Map[String, String]): Unit = {
    val schema = queryExecution.analyzed.output

    SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
      queryExecution.toRdd.foreachPartition { iter =>
        val writeTask = new KinesisWriteTask(kinesisParameters, schema)
        Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
          finallyBlock = writeTask.close())
      }
    }
  }
} 
Example 151
Source File: SchemaJsonExample.scala    From spark-schema-registry   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.registry.examples

import java.util.UUID

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_json, struct, to_json}
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.types._


object SchemaJsonExample {

  def main(args: Array[String]): Unit = {

    val bootstrapServers = if (args.length > 0) args(0) else "localhost:9092"
    val topic = if (args.length > 1) args(1) else "topic1"
    val outTopic = if (args.length > 2) args(2) else "topic1-out"
    val checkpointLocation =
      if (args.length > 3) args(3) else "/tmp/temporary-" + UUID.randomUUID.toString

    val spark = SparkSession
      .builder
      .appName("SchemaExample")
      .getOrCreate()

    val messages = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServers)
      .option("subscribe", topic)
      .load()

    import spark.implicits._

    // the schema for truck events
    val schema = StructType(Seq(
      StructField("driverId", IntegerType, nullable = false),
      StructField("truckId", IntegerType, nullable = false),
      StructField("eventTime", StringType, nullable = false),
      StructField("eventType", StringType, nullable = false),
      StructField("longitude", DoubleType, nullable = false),
      StructField("latitude", DoubleType, nullable = false),
      StructField("eventKey", StringType, nullable = false),
      StructField("correlationId", StringType, nullable = false),
      StructField("driverName", StringType, nullable = false),
      StructField("routeId", IntegerType, nullable = false),
      StructField("routeName", StringType, nullable = false),
      StructField("eventDate", StringType, nullable = false),
      StructField("miles", IntegerType, nullable = false)
    ))

    // read messages from kafka and parse it using the above schema
    val df = messages
      .select(from_json($"value".cast("string"), schema).alias("value"))

    // project (driverId, truckId, miles) for the events where miles > 300
    val filtered = df.select($"value.driverId", $"value.truckId", $"value.miles")
      .where("value.miles > 300")

    // write the output to a kafka topic serialized as a JSON string.
    // should produce events like {"driverId":14,"truckId":25,"miles":373}
    val query = filtered
      .select(to_json(struct($"*")).alias("value"))
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServers)
      .option("topic", outTopic)
      .option("checkpointLocation", checkpointLocation)
      .trigger(Trigger.ProcessingTime(10000))
      .outputMode(OutputMode.Append())
      .start()

    query.awaitTermination()
  }

} 
Example 152
Source File: SchemaRegistryAvroReader.scala    From spark-schema-registry   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.registry.examples

import java.util.UUID

import com.hortonworks.spark.registry.util._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{OutputMode, Trigger}


object SchemaRegistryAvroReader {

  def main(args: Array[String]): Unit = {

    val schemaRegistryUrl = if (args.length > 0) args(0) else "http://localhost:9090/api/v1/"
    val bootstrapServers = if (args.length > 1) args(1) else "localhost:9092"
    val topic = if (args.length > 2) args(2) else "topic1-out"
    val checkpointLocation =
      if (args.length > 3) args(3) else "/tmp/temporary-" + UUID.randomUUID.toString
    val securityProtocol =
      if (args.length > 4) Option(args(4)) else None

    val spark = SparkSession
      .builder
      .appName("SchemaRegistryAvroReader")
      .getOrCreate()

    val reader = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServers)
      .option("subscribe", topic)

    val messages = securityProtocol
      .map(p => reader.option("kafka.security.protocol", p).load())
      .getOrElse(reader.load())

    import spark.implicits._

    // the schema registry client config
    val config = Map[String, Object]("schema.registry.url" -> schemaRegistryUrl)

    // the schema registry config that will be implicitly passed
    implicit val srConfig: SchemaRegistryConfig = SchemaRegistryConfig(config)

    // Read messages from kafka and deserialize.
    // This uses the schema registry schema associated with the topic.
    val df = messages
      .select(from_sr($"value", topic).alias("message"))

    // write the output to console
    // should produce events like {"driverId":14,"truckId":25,"miles":373}
    val query = df
      .writeStream
      .format("console")
      .trigger(Trigger.ProcessingTime(10000))
      .outputMode(OutputMode.Append())
      .start()

    query.awaitTermination()
  }

} 
Example 153
Source File: LOFSuite.scala    From spark-lof   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.outlier

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.apache.spark.sql.functions._

object LOFSuite {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("LOFExample")
      .master("local[4]")
      .getOrCreate()

    val schema = new StructType(Array(
      new StructField("col1", DataTypes.DoubleType),
      new StructField("col2", DataTypes.DoubleType)))
    val df = spark.read.schema(schema).csv("data/outlier.csv")

    val assembler = new VectorAssembler()
      .setInputCols(df.columns)
      .setOutputCol("features")
    val data = assembler.transform(df).repartition(4)

    val startTime = System.currentTimeMillis()
    val result = new LOF()
      .setMinPts(5)
      .transform(data)
    val endTime = System.currentTimeMillis()
    result.count()

    // Outliers have much higher LOF value than normal data
    result.sort(desc(LOF.lof)).head(10).foreach { row =>
      println(row.get(0) + " | " + row.get(1) + " | " + row.get(2))
    }
    println("Total time = " + (endTime - startTime) / 1000.0 + "s")
  }
} 
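To try the same pipeline without the CSV file, the sketch below (an illustrative addition, not from the spark-lof project) builds a small in-memory DataFrame and assembles the features column that the LOF transformer above expects; the data values and object name are invented.

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object LOFInMemoryData {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("LOFInMemoryData").master("local[2]").getOrCreate()
    import spark.implicits._

    // Most points cluster near the origin; the last one is an obvious outlier.
    val df = Seq((0.1, 0.2), (0.2, 0.1), (0.15, 0.25), (10.0, 10.0)).toDF("col1", "col2")

    val features = new VectorAssembler()
      .setInputCols(df.columns)
      .setOutputCol("features")
      .transform(df)

    // "features" is the vector column that new LOF().setMinPts(...).transform(...) consumes.
    features.show(false)
    spark.stop()
  }
}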
Example 154
Source File: ElasticSink.scala    From Spark-Structured-Streaming-Examples   with Apache License 2.0 5 votes vote down vote up
package elastic

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import radio.{SimpleSongAggregation, Song}
import org.elasticsearch.spark.sql.streaming._
import org.elasticsearch.spark.sql._
import org.elasticsearch.spark.sql.streaming.EsSparkSqlStreamingSink

object ElasticSink {
  def writeStream(ds: Dataset[Song] ) : StreamingQuery = {
    ds   //Append output mode not supported when there are streaming aggregations on streaming DataFrames/DataSets without watermark
      .writeStream
      .outputMode(OutputMode.Append) //Only mode for ES
      .format("org.elasticsearch.spark.sql") //es
      .queryName("ElasticSink")
      .start("test/broadcast") //ES index
  }

} 
Example 155
Source File: SparkHelper.scala    From Spark-Structured-Streaming-Examples   with Apache License 2.0 5 votes vote down vote up
package spark

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

object SparkHelper {
  def getAndConfigureSparkSession() = {
    val conf = new SparkConf()
      .setAppName("Structured Streaming from Parquet to Cassandra")
      .setMaster("local[2]")
      .set("spark.cassandra.connection.host", "127.0.0.1")
      .set("spark.sql.streaming.checkpointLocation", "checkpoint")
      .set("es.nodes", "localhost") // full config : https://www.elastic.co/guide/en/elasticsearch/hadoop/current/configuration.html
      .set("es.index.auto.create", "true") //https://www.elastic.co/guide/en/elasticsearch/hadoop/current/spark.html
      .set("es.nodes.wan.only", "true")

    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")

    SparkSession
      .builder()
      .getOrCreate()
  }

  def getSparkSession() = {
    SparkSession
      .builder()
      .getOrCreate()
  }
} 
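A tiny usage sketch, added here only for illustration: the helper is typically called once at application start-up, and the returned session already carries the Cassandra, Elasticsearch and checkpoint settings configured above.

import spark.SparkHelper

object SparkHelperUsage {
  def main(args: Array[String]): Unit = {
    val session = SparkHelper.getAndConfigureSparkSession()
    // Any later SparkHelper.getSparkSession() call returns this same configured session.
    session.range(0, 5).show()
    session.stop()
  }
}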
Example 156
Source File: SparkSessionTestWrapper.scala    From spark-stringmetric   with MIT License 5 votes vote down vote up
package com.github.mrpowers.spark.stringmetric

import org.apache.spark.sql.SparkSession
import org.apache.log4j.{Logger, Level}

trait SparkSessionTestWrapper {

  lazy val spark: SparkSession = {
    Logger.getLogger("org").setLevel(Level.OFF)
    SparkSession
      .builder()
      .master("local")
      .appName("spark session")
      .config("spark.sql.shuffle.partitions", "1")
      .getOrCreate()
  }

} 
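A hedged sketch of how such a wrapper is normally mixed into a test; the spec name and test data below are invented, but the trait mix-in and ScalaTest FunSuite usage follow the pattern used elsewhere in these examples.

import com.github.mrpowers.spark.stringmetric.SparkSessionTestWrapper
import org.scalatest.FunSuite

class WordCountSpec extends FunSuite with SparkSessionTestWrapper {

  test("creates a small local DataFrame with one shuffle partition") {
    import spark.implicits._
    val df = Seq("apple", "appel", "apfel").toDF("word")
    assert(df.count() === 3)
    assert(spark.conf.get("spark.sql.shuffle.partitions") === "1")
  }
}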
Example 157
Source File: IndexConf.scala    From parquet-index   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.internal

import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.sql.SparkSession

object IndexConf {
  import SQLConf.buildConf

  val METASTORE_LOCATION = buildConf("spark.sql.index.metastore").
    doc("Metastore location or root directory to store index information, will be created " +
      "if path does not exist").
    stringConf.
    createWithDefault("")

  val CREATE_IF_NOT_EXISTS = buildConf("spark.sql.index.createIfNotExists").
    doc("When set to true, creates index if one does not exist in metastore for the table").
    booleanConf.
    createWithDefault(false)

  val NUM_PARTITIONS = buildConf("spark.sql.index.partitions").
    doc("When creating index uses this number of partitions. If value is non-positive or not " +
      "provided then uses `sc.defaultParallelism * 3` or `spark.sql.shuffle.partitions` " +
      "configuration value, whichever is smaller").
    intConf.
    createWithDefault(0)

  val PARQUET_FILTER_STATISTICS_ENABLED =
    buildConf("spark.sql.index.parquet.filter.enabled").
    doc("When set to true, writes filter statistics for indexed columns when creating table " +
      "index, otherwise only min/max statistics are used. Filter statistics are always used " +
      "during filtering stage, if applicable").
    booleanConf.
    createWithDefault(true)

  val PARQUET_FILTER_STATISTICS_TYPE = buildConf("spark.sql.index.parquet.filter.type").
    doc("When filter statistics enabled, selects type of statistics to use when creating index. " +
      "Available options are `bloom`, `dict`").
    stringConf.
    createWithDefault("bloom")

  val PARQUET_FILTER_STATISTICS_EAGER_LOADING =
    buildConf("spark.sql.index.parquet.filter.eagerLoading").
    doc("When set to true, read and load all filter statistics in memory the first time catalog " +
      "is resolved, otherwise load them lazily as needed when evaluating predicate. " +
      "Eager loading removes IO of reading filter data from disk, but requires extra memory").
    booleanConf.
    createWithDefault(false)

  
  def unsetConf(entry: ConfigEntry[_]): Unit = {
    sqlConf.unsetConf(entry)
  }
} 
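For illustration only (not taken from the parquet-index sources), the keys defined above can be set like any other SQL conf on a running session; the paths and values below are placeholders.

import org.apache.spark.sql.SparkSession

object IndexConfSettings {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("index-conf").getOrCreate()

    // Same keys as the ConfigEntry definitions above; values are examples only.
    spark.conf.set("spark.sql.index.metastore", "/tmp/index_metastore")
    spark.conf.set("spark.sql.index.createIfNotExists", "true")
    spark.conf.set("spark.sql.index.parquet.filter.type", "bloom")

    println(spark.conf.get("spark.sql.index.metastore"))
    spark.stop()
  }
}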
Example 158
Source File: SparkLocal.scala    From parquet-index   with Apache License 2.0 5 votes vote down vote up
package com.github.lightcopy.testutil

import org.apache.log4j.Level
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession


  private def localConf: SparkConf = {
    new SparkConf().
      setMaster("local[4]").
      setAppName("spark-local-test").
      set("spark.driver.memory", "1g").
      set("spark.executor.memory", "2g")
  }

  override def createSparkSession(): SparkSession = {
    SparkSession.builder().config(localConf).getOrCreate()
  }
} 
Example 159
Source File: SQLAggregationScala.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.examples.twitter

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.infinispan.spark.examples.twitter.Sample.{getSparkConf, usage}
import org.infinispan.spark.rdd.InfinispanRDD


object SQLAggregationScala {

   def main(args: Array[String]) {
      if (args.length < 1) {
         usage("SQLAggregationScala")
      }

      // Reduce the log level in the driver
      Logger.getLogger("org").setLevel(Level.WARN)

      val infinispanHost = args(0)

      // Create Spark Context
      val conf = getSparkConf("spark-infinispan-rdd-aggregation-scala")
      val sc = new SparkContext(conf)

      // Populate infinispan properties
      val config = Sample.getConnectorConf(infinispanHost)

      // Create RDD from infinispan data
      val infinispanRDD = new InfinispanRDD[Long, Tweet](sc, config)

      // Create a SQLContext, register a data frame and a temp table
      val valuesRDD = infinispanRDD.values
      val sparkSession = SparkSession.builder().config(conf).getOrCreate()
      val dataFrame = sparkSession.createDataFrame(valuesRDD, classOf[Tweet])
      dataFrame.createOrReplaceTempView("tweets")

      // Run the Query, collect and print results
      sparkSession.sql("SELECT country, count(*) as c from tweets WHERE country != 'N/A' GROUP BY country ORDER BY c desc")
        .collect().take(20).foreach(println)

   }

} 
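The aggregation itself does not depend on Infinispan; a self-contained sketch with an in-memory stand-in for the tweets view (the sample rows and object name are invented) looks like this:

import org.apache.spark.sql.SparkSession

object SQLAggregationLocal {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sql-aggregation-local").getOrCreate()
    import spark.implicits._

    // Stand-in for the data frame built from the Infinispan RDD.
    Seq(("PL", "tweet a"), ("PL", "tweet b"), ("US", "tweet c"), ("N/A", "tweet d"))
      .toDF("country", "text")
      .createOrReplaceTempView("tweets")

    spark.sql("SELECT country, count(*) AS c FROM tweets WHERE country != 'N/A' GROUP BY country ORDER BY c DESC")
      .show()

    spark.stop()
  }
}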
Example 160
Source File: DataSetSuite.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.suites

import org.apache.spark.sql.SparkSession
import org.infinispan.spark.domain.Runner
import org.infinispan.spark.test._
import org.scalatest.{DoNotDiscover, FunSuite, Matchers}

@DoNotDiscover
class DataSetSuite extends FunSuite with RunnersCache with Spark with MultipleServers with Matchers
  with DatasetAssertions[Runner] {

   override protected def getNumEntries: Int = 100

   override def getConfiguration = {
      val config = super.getConfiguration
      config.addProtoAnnotatedClass(classOf[Runner])
      config.setAutoRegisterProto()
      config
   }

   test("read data using the DataFrame API") {
      val config = getConfiguration.toStringsMap
      val df = getSparkSession.read.format("infinispan").options(config).load()

      val filter = df.filter(df("age").gt(30)).filter(df("age").lt(40))

      assertDataset(filter, r => r.getAge > 30 && r.getAge < 40)
   }

   test("read using SQL single filter") {
      val config = getConfiguration.toStringsMap
      val session = getSparkSession

      val df = getSparkSession.read.format("infinispan").options(config).load()

      df.createOrReplaceTempView("runner")

      assertSql(session, "From runner where age > 30", _.getAge > 30)
      assertSql(session, "From runner where age >= 30", _.getAge >= 30)
      assertSql(session, "From runner where age < 50", _.getAge < 50)
      assertSql(session, "From runner where age <= 50", _.getAge <= 50)
      assertSql(session, "From runner where name LIKE 'runner1%'", _.getName.startsWith("runner1"))
      assertSql(session, "From runner where name LIKE '%unner2%'", _.getName.contains("unner2"))
   }

   test("read using SQL and projections") {
      val config = getConfiguration.toStringsMap
      val session: SparkSession = getSparkSession
      val df = session.read.format("infinispan").options(config).load()
      df.createOrReplaceTempView("runner")

      val rows = getSparkSession.sql("select name, finished from runner where age > 20").collect()

      assertRows(rows, _.getAge > 20)

      val firstRow = rows(0)

      firstRow.length shouldBe 2
      firstRow.get(0).getClass shouldBe classOf[String]
      firstRow.get(1).getClass shouldBe classOf[java.lang.Boolean]
   }

   test("read using SQL and combined predicates") {
      val config = getConfiguration.toStringsMap
      implicit val session = getSparkSession
      val df = session.read.format("infinispan").options(config).load()
      df.createOrReplaceTempView("runner")

      assertSql(session, "select * from runner where finished = false and (age > 30 or age < 50)", r => !r.getFinished && (r.getAge > 30 || r.getAge < 50))
   }

   override def row2String(e: Runner): String = e.getName
} 
Example 161
Source File: SQLSuite.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.suites

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.infinispan.spark.domain.Runner
import org.infinispan.spark.test._
import org.scalatest.{DoNotDiscover, FunSuite, Matchers}

@DoNotDiscover
class SQLSuite extends FunSuite with RunnersCache with Spark with MultipleServers with Matchers {

   override def getNumEntries: Int = 100

   test("SQL Group By") {
      withSession { (session, runnersRDD) =>
         val winners = session.sql(
            """
              |SELECT MIN(r.finishTimeSeconds) as time, first(r.name) as name, first(r.age) as age
              |FROM runners r WHERE
              |r.finished = true GROUP BY r.age
              |
            """.stripMargin).collect()

         
         winners.foreach { row =>
            val winnerTime = row.getAs[Int]("time")
            val age = row.getAs[Int]("age")
            val fasterOfAge = runnersRDD.filter(r => r.getAge == age && r.getFinished).sortBy(_.getFinishTimeSeconds).first()
            fasterOfAge.getFinishTimeSeconds shouldBe winnerTime
         }
      }
   }

   test("SQL Count") {
      withSession { (session, _) =>
         val count = session.sql("SELECT count(*) AS result from runners").collect().head.getAs[Long]("result")
         count shouldBe getNumEntries
      }
   }


   private def withSession(f: (SparkSession, RDD[Runner]) => Any) = {
      val runnersRDD = createInfinispanRDD[Integer, Runner].values
      val session = SparkSession.builder().config(getSparkConfig).getOrCreate()
      val dataFrame = session.createDataFrame(runnersRDD, classOf[Runner])
      dataFrame.createOrReplaceTempView("runners")
      f(session, runnersRDD)
   }

   override def getCacheConfig: Option[String] = Some("""{"replicated-cache":{"mode":"SYNC"}}""")
} 
Example 162
Source File: HiveContextSuite.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.suites

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.infinispan.spark.domain.Runner
import org.infinispan.spark.test.{RunnersCache, SingleStandardServer, Spark}
import org.scalatest.{DoNotDiscover, FunSuite, Matchers}

@DoNotDiscover
class HiveContextSuite extends FunSuite with RunnersCache with Spark with SingleStandardServer with Matchers {
   override protected def getNumEntries: Int = 200

   test("Hive SQL") {
      withHiveContext { (session: SparkSession, _) =>
         val sample = session.sql(
            """
             SELECT * FROM runners TABLESAMPLE(10 ROWS) s
            """.stripMargin).collect()

         sample.length shouldBe 10
      }
   }

   private def withHiveContext(f: (SparkSession, RDD[Runner]) => Any) = {
      val runnersRDD = createInfinispanRDD[Integer, Runner].values
      val sparkSession = SparkSession.builder().enableHiveSupport().config(getSparkConfig).getOrCreate()
      val dataFrame = sparkSession.createDataFrame(runnersRDD, classOf[Runner])
      dataFrame.createOrReplaceTempView("runners")
      f(sparkSession, runnersRDD)
   }

} 
Example 163
Source File: JavaSpark.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.test

import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession
import org.infinispan.spark.serializer._
import org.scalatest.{BeforeAndAfterAll, Suite}


trait JavaSpark extends BeforeAndAfterAll {
   this: Suite with RemoteTest =>

   private lazy val config: SparkConf = new SparkConf().setMaster("local[4]")
     .setAppName(this.getClass.getName)
     .set("spark.serializer", classOf[JBossMarshallingSerializer].getName)
     .set("spark.driver.host", "127.0.0.1")

   protected var sparkSession: SparkSession = _
   protected var jsc: JavaSparkContext = _

   override protected def beforeAll(): Unit = {
      sparkSession = SparkSession.builder().config(config).getOrCreate()
      jsc = new JavaSparkContext(sparkSession.sparkContext)
      super.beforeAll()
   }

   override protected def afterAll(): Unit = {
      jsc.stop()
      sparkSession.stop()
      super.afterAll()
   }
} 
Example 164
Source File: S3ADataFrames.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.spark.cloud.s3

import com.cloudera.spark.cloud.common.CloudTestKeys
import com.cloudera.spark.cloud.operations.CloudDataFrames
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.sql.SparkSession


object S3ADataFrames extends CloudDataFrames with S3AExampleSetup {

  override def extraValidation(
      session: SparkSession,
      conf: Configuration,
      fs: FileSystem,
      results: Seq[(String, Path, Long, Long)]): Unit = {

    val operations = new S3AOperations(fs)
    if (conf.getBoolean(CloudTestKeys.S3A_COMMITTER_TEST_ENABLED, false)) {
      results.foreach((tuple: (String, Path, Long, Long)) => {
        operations.verifyS3Committer(tuple._2, None, None, "")
      })
    }

  }
} 
Example 165
Source File: HiveTestTrait.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import java.io.File

import com.cloudera.spark.cloud.ObjectStoreConfigurations
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.sql.hive.test.TestHiveContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils


trait HiveTestTrait extends SparkFunSuite with BeforeAndAfterAll {
//  override protected val enableAutoThreadAudit = false
  protected var hiveContext: HiveInstanceForTests = _
  protected var spark: SparkSession = _


  protected override def beforeAll(): Unit = {
    super.beforeAll()
    // set up spark and hive context
    hiveContext = new HiveInstanceForTests()
    spark = hiveContext.sparkSession
  }

  protected override def afterAll(): Unit = {
    try {
      SparkSession.clearActiveSession()

      if (hiveContext != null) {
        hiveContext.reset()
        hiveContext = null
      }
      if (spark != null) {
        spark.close()
        spark = null
      }
    } finally {
      super.afterAll()
    }
  }

}

class HiveInstanceForTests
  extends TestHiveContext(
    new SparkContext(
      System.getProperty("spark.sql.test.master", "local[1]"),
      "TestSQLContext",
      new SparkConf()
        .setAll(ObjectStoreConfigurations.RW_TEST_OPTIONS)
        .set("spark.sql.warehouse.dir",
          TestSetup.makeWarehouseDir().toURI.getPath)
    )
  ) {

}




object TestSetup {

  def makeWarehouseDir(): File = {
    val warehouseDir = Utils.createTempDir(namePrefix = "warehouse")
    warehouseDir.delete()
    warehouseDir
  }
} 
Example 166
Source File: AzureDataFrameSuite.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.spark.cloud.azure

import com.cloudera.spark.cloud.common.DataFrameTests

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StringType


  def example(sparkConf: SparkConf): Unit = {
    val spark = SparkSession
        .builder
        .appName("DataFrames")
        .config(sparkConf)
        .getOrCreate()
    import spark.implicits._
    val numRows = 1000
    val sourceData = spark.range(0, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
    val dest = "wasb://[email protected]/dataframes"
    val orcFile = dest + "/data.orc"
    sourceData.write.format("orc").save(orcFile)
    // read it back
    val orcData = spark.read.format("orc").load(orcFile)
    // save it to parquet
    val parquetFile = dest + "/data.parquet"
    orcData.write.format("parquet").save(parquetFile)
    spark.stop()
  }
} 
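The same ORC-to-Parquet round trip can be rehearsed locally before pointing it at a wasb:// destination; the sketch below is an illustrative addition with local placeholder paths, not part of the original suite.

import org.apache.spark.sql.SparkSession

object OrcToParquetLocal {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("orc-to-parquet").getOrCreate()
    import spark.implicits._

    val dest = "/tmp/dataframes-example" // local stand-in for the object store destination
    val numRows = 1000
    val sourceData = spark.range(0, numRows).select($"id".as("l"), $"id".cast("string").as("s"))

    sourceData.write.mode("overwrite").format("orc").save(dest + "/data.orc")
    val orcData = spark.read.format("orc").load(dest + "/data.orc")
    orcData.write.mode("overwrite").format("parquet").save(dest + "/data.parquet")

    // Round-trip check: the Parquet copy should hold the same number of rows.
    assert(spark.read.parquet(dest + "/data.parquet").count() == numRows)
    spark.stop()
  }
}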
Example 167
Source File: SequilaThriftServer.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver


import java.io.File

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.HiveThriftServer2Listener
import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2Seq.HiveThriftServer2ListenerSeq
import org.apache.spark.sql.hive.thriftserver._
import org.apache.spark.sql.{SQLContext, SequilaSession, SparkSession}
import org.biodatageeks.sequila.utils.{SequilaRegister, UDFRegister}
import org.apache.spark.sql.hive.thriftserver.ui.{ThriftServerTab, ThriftServerTabSeq}



object SequilaThriftServer extends Logging {
  var uiTab: Option[ThriftServerTabSeq] = None
  var listener: HiveThriftServer2ListenerSeq = _

  @DeveloperApi
  def startWithContext(ss: SequilaSession): Unit = {
    //System.setSecurityManager(null)
    val server = new HiveThriftServer2Seq(ss)

    val executionHive = HiveUtils.newClientForExecution(
      ss.sqlContext.sparkContext.conf,
      ss.sparkContext.hadoopConfiguration)

    server.init(executionHive.conf)
    server.start()
    listener = new HiveThriftServer2ListenerSeq(server, ss.sqlContext.conf)
    ss.sqlContext.sparkContext.addSparkListener(listener)
    uiTab = if (ss.sqlContext.sparkContext.getConf.getBoolean("spark.ui.enabled", true)) {
      Some(new ThriftServerTabSeq(ss.sqlContext.sparkContext,listener))
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    //System.setSecurityManager(null)
    val spark = SparkSession
      .builder
        .config("spark.sql.hive.thriftServer.singleSession","true")
        .config("spark.sql.warehouse.dir",sys.env.getOrElse("SEQ_METASTORE_LOCATION",System.getProperty("user.dir")) )
//        .config("spark.hadoop.hive.metastore.uris","thrift://localhost:9083")
      .enableHiveSupport()
//       .master("local[1]")
      .getOrCreate
    val ss = new SequilaSession(spark)
    UDFRegister.register(ss)
    SequilaRegister.register(ss)


    HiveThriftServer2Seq.startWithContext(ss)
  }

} 
Example 168
Source File: BEDRelation.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.datasources.BED

import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Encoders, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources.{BaseRelation, Filter, PrunedFilteredScan}
import org.biodatageeks.sequila.utils.{Columns, DataQualityFuncs}

class BEDRelation(path: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation
    with PrunedFilteredScan
    with Serializable {

  @transient val logger = Logger.getLogger(this.getClass.getCanonicalName)
  override def schema: org.apache.spark.sql.types.StructType = Encoders.product[org.biodatageeks.formats.BrowserExtensibleData].schema

  private def getValueFromColumn(colName:String, r:Array[String]): Any = {
    colName match {
      case Columns.CONTIG       =>  DataQualityFuncs.cleanContig(r(0) )
      case Columns.START        =>  r(1).toInt + 1 //Convert interval to 1-based
      case Columns.END          =>  r(2).toInt
      case Columns.NAME         =>  if (r.length > 3) Some (r(3)) else None
      case Columns.SCORE        =>  if (r.length > 4) Some (r(4).toInt) else None
      case Columns.STRAND       =>  if (r.length > 5) Some (r(5)) else None
      case Columns.THICK_START  =>  if (r.length > 6) Some (r(6).toInt) else None
      case Columns.THICK_END    =>  if (r.length > 7) Some (r(7).toInt) else None
      case Columns.ITEM_RGB     =>  if (r.length > 8) Some (r(8).split(",").map(_.toInt)) else None
      case Columns.BLOCK_COUNT  =>  if (r.length > 9) Some (r(9).toInt) else None
      case Columns.BLOCK_SIZES  =>  if (r.length > 10) Some (r(10).split(",").map(_.toInt)) else None
      case Columns.BLOCK_STARTS =>  if (r.length > 11) Some (r(11).split(",").map(_.toInt)) else None
      case _                    =>  throw new Exception(s"Unknown column found: ${colName}")
    }
  }
  override def buildScan(requiredColumns:Array[String], filters:Array[Filter]): RDD[Row] = {
    sqlContext
      .sparkContext
      .textFile(path)
      .filter(!_.toLowerCase.startsWith("track"))
      .filter(!_.toLowerCase.startsWith("browser"))
      .map(_.split("\t"))
      .map(r=>
            {
              val record = new Array[Any](requiredColumns.length)
              //requiredColumns.
              for (i <- 0 to requiredColumns.length - 1) {
                record(i) = getValueFromColumn(requiredColumns(i), r)
              }
              Row.fromSeq(record)
            }

    )
  }

} 
Example 169
Source File: VCFRelation.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.datasources.VCF

import io.projectglow.Glow
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources._
import org.biodatageeks.sequila.utils.{Columns, DataQualityFuncs}
import org.apache.spark.sql.functions._




class VCFRelation(path: String,
                  normalization_mode: Option[String] = None,
                  ref_genome_path : Option[String] = None )(@transient val sqlContext: SQLContext) extends BaseRelation
  with PrunedScan
  with Serializable
  with Logging {

  val spark: SparkSession = sqlContext.sparkSession

  val cleanContigUDF = udf[String, String](DataQualityFuncs.cleanContig)

  lazy val inputDf: DataFrame = spark
    .read
    .format("vcf")
    .option("splitToBiallelic", "true")
    .load(path)
  lazy val dfNormalized = {
    normalization_mode match {
    case Some(m) => {
      if ((m.equalsIgnoreCase("normalize") || m.equalsIgnoreCase("split_and_normalize"))
        && ref_genome_path.isEmpty) throw new Exception(s"Variant normalization mode specified but ref_genome_path is empty")
      Glow.transform(m.toLowerCase(), inputDf, Map("reference_genome_path" -> ref_genome_path.get))
    }
    case _ => inputDf
    }
  }.withColumnRenamed("contigName", Columns.CONTIG)
    .withColumnRenamed("start", Columns.START)
    .withColumnRenamed("end", Columns.END)
    .withColumnRenamed("referenceAllele", Columns.REF)
    .withColumnRenamed("alternateAlleles", Columns.ALT)

  lazy val df = dfNormalized
    .withColumn(Columns.CONTIG, cleanContigUDF(dfNormalized(Columns.CONTIG)))

  override def schema: org.apache.spark.sql.types.StructType = {
   df.schema
  }

  override def buildScan(requiredColumns: Array[String] ): RDD[Row] = {

    {
      if (requiredColumns.length > 0)
        df.select(requiredColumns.head, requiredColumns.tail: _*)
      else
        df
    }.rdd


  }

} 
Example 170
Source File: TableFuncs.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.utils

import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.SparkSession

object TableFuncs{

  def getTableMetadata(spark:SparkSession, tableName:String) = {
    val catalog = spark.sessionState.catalog
    val tId = spark.sessionState.sqlParser.parseTableIdentifier(tableName)
    catalog.getTableMetadata(tId)
  }

  def getTableDirectory(spark: SparkSession, tableName:String) ={
    getTableMetadata(spark,tableName)
      .location
      .toString
      .split('/')
      .dropRight(1)
      .mkString("/")
  }

  def getExactSamplePath(spark: SparkSession, path:String) = {
    val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)
    val statuses = fs.globStatus(new org.apache.hadoop.fs.Path(path))
    statuses.head.getPath.toString
  }

  def getParentFolderPath(spark: SparkSession, path: String): String = {
    val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)
    (new org.apache.hadoop.fs.Path(path)).getParent.toString
  }

  def getAllSamples(spark: SparkSession, path:String) = {
    val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)
    val statuses = fs.globStatus(new org.apache.hadoop.fs.Path(path))
    //println(statuses.length)
    statuses
      .map(_.getPath.toString.split('/').takeRight(1).head.split('.').take(1).head)
  }
} 
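A short, assumption-heavy usage sketch: the table name "reads" and the *.bam glob are placeholders, and the table must already be registered in the catalog (for example via the BAM data source shown in the PileupApp example further down).

import org.apache.spark.sql.SparkSession
import org.biodatageeks.sequila.utils.TableFuncs

object TableFuncsUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").enableHiveSupport().getOrCreate()

    val tableName = "reads" // placeholder; must exist in the catalog
    val dir = TableFuncs.getTableDirectory(spark, tableName)
    println(s"table directory: $dir")

    // List sample names derived from the files sitting next to the table location.
    TableFuncs.getAllSamples(spark, dir + "/*.bam").foreach(println)

    spark.stop()
  }
}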
Example 171
Source File: SequilaRegister.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.utils

import htsjdk.samtools.ValidationStringency
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.SequilaDataSourceStrategy
import org.biodatageeks.sequila.utvf.GenomicIntervalStrategy
import org.biodatageeks.sequila.coverage.CoverageStrategy
import org.biodatageeks.sequila.pileup.PileupStrategy
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim

object SequilaRegister {

  def register(spark : SparkSession) = {
    spark.experimental.extraStrategies =
      Seq(
        new SequilaDataSourceStrategy(spark),
        new IntervalTreeJoinStrategyOptim(spark),
        new CoverageStrategy(spark),
        new PileupStrategy(spark),
        new GenomicIntervalStrategy(spark)

      )
    
    spark
      .sparkContext
      .hadoopConfiguration
      .setInt("mapred.max.split.size", spark.sqlContext.getConf(InternalParams.InputSplitSize,"134217728").toInt)

    spark
      .sqlContext
      .setConf(InternalParams.IOReadAlignmentMethod,"hadoopBAM")

    spark
      .sqlContext
      .setConf(InternalParams.BAMValidationStringency, ValidationStringency.LENIENT.toString)

    spark
      .sqlContext
      .setConf(InternalParams.EnableInstrumentation, "false")
  }

} 
Example 172
Source File: GenomicIntervalStrategy.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.utvf

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{DataFrame, GenomicInterval, SparkSession, Strategy}
import org.apache.spark.unsafe.types.UTF8String

case class GIntervalRow(contigName: String, start: Int, end: Int)
class GenomicIntervalStrategy( spark: SparkSession) extends Strategy with Serializable  {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

    case GenomicInterval(contigName, start, end,output) => GenomicIntervalPlan(plan,spark,GIntervalRow(contigName,start,end),output) :: Nil
    case _ => Nil

  }
}

case class GenomicIntervalPlan(plan: LogicalPlan, spark: SparkSession,interval:GIntervalRow, output: Seq[Attribute]) extends SparkPlan with Serializable {
  def doExecute(): org.apache.spark.rdd.RDD[InternalRow] = {
    import spark.implicits._

    lazy val genomicInterval = spark.createDataset(Seq(interval))
    genomicInterval
        .rdd
      .map(r=>{
        val proj =  UnsafeProjection.create(schema)
        proj.apply(InternalRow.fromSeq(Seq(UTF8String.fromString(r.contigName),r.start,r.end)))
        }
      )
  }
  def children: Seq[SparkPlan] = Nil
} 
Example 173
Source File: Pileup.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.pileup

import htsjdk.samtools.SAMRecord
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.biodatageeks.sequila.datasources.BAM.BDGAlignFileReaderWriter
import org.biodatageeks.sequila.datasources.InputDataType
import org.biodatageeks.sequila.inputformats.BDGAlignInputFormat
import org.biodatageeks.sequila.utils.{InternalParams, TableFuncs}
import org.seqdoop.hadoop_bam.CRAMBDGInputFormat
import org.slf4j.LoggerFactory

import scala.reflect.ClassTag


class Pileup[T<:BDGAlignInputFormat](spark:SparkSession)(implicit c: ClassTag[T]) extends BDGAlignFileReaderWriter[T] {
  val logger = LoggerFactory.getLogger(this.getClass.getCanonicalName)

  def handlePileup(tableName: String, sampleId: String, refPath:String, output: Seq[Attribute]): RDD[InternalRow] = {
    logger.info("Calculating pileup on table: {}", tableName)

    lazy val allAlignments = readTableFile(name=tableName, sampleId)

    if(logger.isDebugEnabled()) logger.debug("Processing {} reads in total", allAlignments.count() )

    val alignments = filterAlignments(allAlignments )


    PileupMethods.calculatePileup(alignments, spark ,refPath)

  }

  private def filterAlignments(alignments:RDD[SAMRecord]): RDD[SAMRecord] = {
    // any other filtering conditions should go here
    val filterFlag = spark.conf.get(InternalParams.filterReadsByFlag, "1796").toInt
    val cleaned = alignments.filter(read => read.getContig != null && (read.getFlags & filterFlag) == 0)
    if(logger.isDebugEnabled()) logger.debug("Processing {} cleaned reads in total", cleaned.count() )
    cleaned
  }

  private def readTableFile(name: String, sampleId: String): RDD[SAMRecord] = {
    val metadata = TableFuncs.getTableMetadata(spark, name)
    val path = metadata.location.toString

    val samplePathTemplate = (
      path
      .split('/')
      .dropRight(1) ++ Array(s"$sampleId*.{{fileExtension}}"))
      .mkString("/")

    metadata.provider match {
      case Some(f) =>
        if (f == InputDataType.BAMInputDataType)
           readBAMFile(spark.sqlContext, samplePathTemplate.replace("{{fileExtension}}", "bam"), refPath = None)
        else if (f == InputDataType.CRAMInputDataType) {
          val refPath = spark.sqlContext
            .sparkContext
            .hadoopConfiguration
            .get(CRAMBDGInputFormat.REFERENCE_SOURCE_PATH_PROPERTY)
           readBAMFile(spark.sqlContext, samplePathTemplate.replace("{{fileExtension}}", "cram"), Some(refPath))
        }
        else throw new Exception("Only BAM and CRAM file formats are supported in bdg_coverage.")
      case None => throw new Exception("Wrong file extension - only BAM and CRAM file formats are supported in bdg_coverage.")
    }
  }
} 
Example 174
Source File: PileupStrategy.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.pileup

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{PileupTemplate, SparkSession, Strategy}
import org.biodatageeks.sequila.datasources.BAM.BDGAlignFileReaderWriter
import org.biodatageeks.sequila.datasources.InputDataType
import org.biodatageeks.sequila.inputformats.BDGAlignInputFormat
import org.biodatageeks.sequila.utils.TableFuncs
import org.seqdoop.hadoop_bam.{BAMBDGInputFormat, CRAMBDGInputFormat}

import scala.reflect.ClassTag

class PileupStrategy (spark:SparkSession) extends Strategy with Serializable {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
    plan match {
      case PileupTemplate(tableName, sampleId, refPath, output) =>
        val inputFormat = TableFuncs.getTableMetadata(spark, tableName).provider
        inputFormat match {
          case Some(f) =>
            if (f == InputDataType.BAMInputDataType)
              PileupPlan[BAMBDGInputFormat](plan, spark, tableName, sampleId, refPath, output) :: Nil
            else if (f == InputDataType.CRAMInputDataType)
              PileupPlan[CRAMBDGInputFormat](plan, spark, tableName, sampleId, refPath, output) :: Nil
            else Nil
          case None => throw new RuntimeException("Only BAM and CRAM file formats are supported in pileup function.")
        }
      case _ => Nil
    }
  }
}

case class PileupPlan [T<:BDGAlignInputFormat](plan:LogicalPlan, spark:SparkSession,
                                               tableName:String,
                                               sampleId:String,
                                               refPath: String,
                                               output:Seq[Attribute])(implicit c: ClassTag[T])
  extends SparkPlan with Serializable  with BDGAlignFileReaderWriter [T]{

  override def children: Seq[SparkPlan] = Nil

  override protected def doExecute(): RDD[InternalRow] = {
   new Pileup(spark).handlePileup(tableName, sampleId, refPath, output)
  }

} 
Example 175
Source File: PileupMethods.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.pileup

import htsjdk.samtools.SAMRecord
import org.apache.spark.rdd.MetricsContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.biodatageeks.sequila.pileup.model.{Reference, _}
import org.biodatageeks.sequila.pileup.timers.PileupTimers._
import org.biodatageeks.sequila.utils.InternalParams
import org.slf4j.{Logger, LoggerFactory}

import AggregateRDDOperations.implicits._
import AlignmentsRDDOperations.implicits._

object PileupMethods {

  def calculatePileup(alignments: RDD[SAMRecord], spark: SparkSession, refPath: String): RDD[InternalRow] = {

    Reference.init(refPath)
    val enableInstrumentation = spark
      .sqlContext
      .getConf(InternalParams.EnableInstrumentation).toBoolean
    val alignmentsInstr = if(enableInstrumentation) alignments.instrument() else alignments
    val aggregates = ContigAggrTimer.time {
      alignmentsInstr.assembleContigAggregates()
        .persist(StorageLevel.MEMORY_AND_DISK) //FIXME: Add automatic unpersist
    }
    val accumulator = AccumulatorTimer.time {aggregates.accumulateTails(spark)}

    val broadcast = BroadcastTimer.time{
      spark.sparkContext.broadcast(accumulator.value().prepareOverlaps())
    }
    val adjustedEvents = AdjustedEventsTimer.time {aggregates.adjustWithOverlaps(broadcast) }
    val pileup = EventsToPileupTimer.time {adjustedEvents.toPileup}
    pileup
  }
} 
Example 176
Source File: MDTagParser.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.pileup

import java.io.File

import htsjdk.samtools.reference.IndexedFastaSequenceFile
import htsjdk.samtools.{Cigar, CigarOperator, SAMRecord}
import org.apache.log4j.Logger
import org.apache.spark.sql.SparkSession
import org.biodatageeks.sequila.datasources.BAM.BDGAlignFileReaderWriter
import org.seqdoop.hadoop_bam.BAMBDGInputFormat

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

case class MDOperator(length: Int, base: Char) { // 'S' means skip n positions, no fix needed
  def isDeletion:Boolean = base.isLower
  def isNonDeletion:Boolean = base.isUpper
}
object MDTagParser{

  val logger: Logger = Logger.getLogger(this.getClass.getCanonicalName)
  val pattern = "([0-9]+)\\^?([A-Za-z]+)?".r

  def parseMDTag(t : String) = {

    if (isAllDigits(t)) {
      Array[MDOperator](MDOperator(t.toInt, 'S'))
    }
    else {
      val ab = new ArrayBuffer[MDOperator]()
      val matches = pattern
        .findAllIn(t)
      while (matches.hasNext) {
        val m = matches.next()
        if(m.last.isLetter && !m.contains('^') ){
          val skipPos = m.dropRight(1).toInt
          ab.append(MDOperator(skipPos, 'S') )
          ab.append(MDOperator(0, m.last.toUpper))
        }
        else if (m.last.isLetter && m.contains('^') ){ //encoding deletions as lowercase
          val arr =  m.split('^')
          val skipPos = arr.head.toInt
          ab.append(MDOperator(skipPos, 'S') )
          arr(1).foreach { b =>
            ab.append(MDOperator(0, b.toLower))
          }
        }
        else ab.append(MDOperator(m.toInt, 'S') )
      }
      ab.toArray
    }
  }


  private def isAllDigits(s: String) : Boolean = {
    val len = s.length
    var i = 0
      while(i < len){
        if(! s(i).isDigit ) return false
        i += 1
      }
    true
  }

} 
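To see what the parser produces, here is a small driver (added for illustration; the MD tag value is made up) together with the output one would expect from the code above.

import org.biodatageeks.sequila.pileup.MDTagParser

object MDTagParserDemo {
  def main(args: Array[String]): Unit = {
    // "10A5^AC6": 10 matching bases, mismatch A, 5 matches, deletion of AC, 6 matches.
    val ops = MDTagParser.parseMDTag("10A5^AC6")
    ops.foreach(println)
    // Expected, following the parsing rules above:
    // MDOperator(10,S), MDOperator(0,A), MDOperator(5,S),
    // MDOperator(0,a), MDOperator(0,c), MDOperator(6,S)
  }
}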
Example 177
Source File: NCListsJoin.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.NCList

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{SparkPlan, _}

@DeveloperApi
case class NCListsJoin(left: SparkPlan,
                     right: SparkPlan,
                     condition: Seq[Expression],
                     context: SparkSession) extends BinaryExecNode {
  def output = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  lazy val (buildKeys, streamedKeys) = (List(condition(0), condition(1)),
    List(condition(2), condition(3)))

  @transient lazy val buildKeyGenerator = new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator = new InterpretedProjection(streamedKeys,
    streamedPlan.output)

  protected override def doExecute(): RDD[InternalRow] = {
    val v1 = left.execute()
    val v2 = right.execute()

    val v1kv = v1.map(x => {
      val v1Key = buildKeyGenerator(x)

      (new Interval[Int](v1Key.getInt(0), v1Key.getInt(1)),
        x.copy())
    } )

    val v2kv = v2.map(x => {
      val v2Key = streamKeyGenerator(x)
      (new Interval[Int](v2Key.getInt(0), v2Key.getInt(1)),
        x.copy())
    } )
    
    if (v1.count <= v2.count) {


      val v3 = NCListsJoinImpl.overlapJoin(context.sparkContext, v1kv, v2kv).flatMap(l => l._2.map(r => (l._1, r)))
      v3.map {
        case (l: InternalRow, r: InternalRow) => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema);
          joiner.join(l.asInstanceOf[UnsafeRow], r.asInstanceOf[UnsafeRow]).asInstanceOf[InternalRow] //resultProj(joinedRow(l, r)) joiner.joiner
        }
      }
    } else {
      val v3 = NCListsJoinImpl.overlapJoin(context.sparkContext, v2kv, v1kv).flatMap(l => l._2.map(r => (l._1, r)))
      v3.map {
        case (r: InternalRow, l: InternalRow) => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema);
          joiner.join(l.asInstanceOf[UnsafeRow], r.asInstanceOf[UnsafeRow]).asInstanceOf[InternalRow] //resultProj(joinedRow(l, r)) joiner.joiner
        }
      }
    }
  }
} 
Example 178
Source File: NCListsJoinStrategy.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.NCList

import org.biodatageeks.sequila.rangejoins.common.{ExtractRangeJoinKeys, ExtractRangeJoinKeysWithEquality}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{SparkSession, Strategy}
import org.biodatageeks.sequila.rangejoins.methods.NCList.NCListsJoinChromosome

class NCListsJoinStrategy(spark: SparkSession) extends Strategy with Serializable with  PredicateHelper {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case ExtractRangeJoinKeys(joinType, rangeJoinKeys, left, right) =>
      NCListsJoin(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case ExtractRangeJoinKeysWithEquality(joinType, rangeJoinKeys, left, right) =>
      NCListsJoinChromosome(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case _ =>
      Nil
  }
} 
Example 179
Source File: NCListsJoinChromosome.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.methods.NCList

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{SparkPlan, _}
import org.biodatageeks.sequila.rangejoins.NCList.{Interval, NCListsJoinImpl}

@DeveloperApi
case class NCListsJoinChromosome(left: SparkPlan,
                     right: SparkPlan,
                     condition: Seq[Expression],
                     context: SparkSession) extends BinaryExecNode {
  def output = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  lazy val (buildKeys, streamedKeys) = (List(condition(0), condition(1),condition(4)),
    List(condition(2), condition(3),condition(5)))


  @transient lazy val buildKeyGenerator = new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator = new InterpretedProjection(streamedKeys,
    streamedPlan.output)

  protected override def doExecute(): RDD[InternalRow] = {
    val v1 = left.execute()
    val v2 = right.execute()

    val v1kv = v1.map(x => {
      val v1Key = buildKeyGenerator(x)

      ((v1Key.getString(2),new Interval[Int](v1Key.getInt(0), v1Key.getInt(1))),
        x.copy())
    } )

    val v2kv = v2.map(x => {
      val v2Key = streamKeyGenerator(x)
      ((v2Key.getString(2),new Interval[Int](v2Key.getInt(0), v2Key.getInt(1))),
        x.copy())
    } )
    
    if (v1.count <= v2.count) {


      val v3 = NCListsJoinChromosomeImpl.overlapJoin(context.sparkContext, v1kv, v2kv)
      v3.mapPartitions(
        p => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
          p.map(r => joiner.join(r._1.asInstanceOf[UnsafeRow], r._2.asInstanceOf[UnsafeRow]))
        }
      )
    } else {
      val v3 = NCListsJoinChromosomeImpl.overlapJoin(context.sparkContext, v2kv, v1kv)
      v3.mapPartitions(
        p => {
          val joiner = GenerateUnsafeRowJoiner.create(right.schema, left.schema)
          p.map(r=>joiner.join(r._2.asInstanceOf[UnsafeRow],r._1.asInstanceOf[UnsafeRow]))
        }

      )
    }
  }
} 
Example 180
Source File: IntervalTreeJoinOptim.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.IntervalTree

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, _}
import org.apache.spark.sql.internal.SQLConf

@DeveloperApi
case class IntervalTreeJoinOptim(left: SparkPlan,
                                 right: SparkPlan,
                                 condition: Seq[Expression],
                                 context: SparkSession, leftLogicalPlan: LogicalPlan, rightLogicalPlan: LogicalPlan) extends BinaryExecNode {
  def output = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  lazy val (buildKeys, streamedKeys) = (List(condition(0), condition(1)),
    List(condition(2), condition(3)))

  @transient lazy val buildKeyGenerator = new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator = new InterpretedProjection(streamedKeys,
    streamedPlan.output)

  protected override def doExecute(): RDD[InternalRow] = {
    val v1 = left.execute()
    val v1kv = v1.map(x => {
      val v1Key = buildKeyGenerator(x)

      (new IntervalWithRow[Int](v1Key.getInt(0), v1Key.getInt(1),
        x) )

    })
    val v2 = right.execute()
    val v2kv = v2.map(x => {
      val v2Key = streamKeyGenerator(x)
      (new IntervalWithRow[Int](v2Key.getInt(0), v2Key.getInt(1),
        x) )
    })
    

    val conf = new SQLConf()
    val v1Size =
      if(leftLogicalPlan
      .stats
      .sizeInBytes >0) leftLogicalPlan.stats.sizeInBytes.toLong
      else
        v1.count

    val v2Size = if(rightLogicalPlan
      .stats
      .sizeInBytes > 0) rightLogicalPlan.stats.sizeInBytes.toLong
    else
      v2.count
    if ( v1Size <= v2Size ) {
      val v3 = IntervalTreeJoinOptimImpl.overlapJoin(context.sparkContext, v1kv, v2kv,v1.count())
     v3.mapPartitions(
       p => {
         val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
         p.map(r=>joiner.join(r._1.asInstanceOf[UnsafeRow],r._2.asInstanceOf[UnsafeRow]))
       }



     )

    }
    else {
      val v3 = IntervalTreeJoinOptimImpl.overlapJoin(context.sparkContext, v2kv, v1kv, v2.count())
      v3.mapPartitions(
        p => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
          p.map(r=>joiner.join(r._2.asInstanceOf[UnsafeRow],r._1.asInstanceOf[UnsafeRow]))
        }

      )
    }

  }
} 
Example 181
Source File: IntervalTreeJoinChromosome.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.methods.genApp

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedProjection, UnsafeRow}
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.biodatageeks.sequila.rangejoins.genApp.Interval

@DeveloperApi
case class
IntervalTreeJoinChromosome(left: SparkPlan,
                             right: SparkPlan,
                             condition: Seq[Expression],
                             context: SparkSession) extends BinaryExecNode {
  def output = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  lazy val (buildKeys, streamedKeys) = (List(condition(0), condition(1),condition(4)),
    List(condition(2), condition(3),condition(5)))

  @transient lazy val buildKeyGenerator = new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator = new InterpretedProjection(streamedKeys,
    streamedPlan.output)

  protected override def doExecute(): RDD[InternalRow] = {
    val v1 = left.execute()
    val v1kv = v1.map(x => {
      val v1Key = buildKeyGenerator(x)

      ((v1Key.getString(2),new Interval[Int](v1Key.getInt(0), v1Key.getInt(1))),
        x.copy())
    })
    val v2 = right.execute()
    val v2kv = v2.map(x => {
      val v2Key = streamKeyGenerator(x)
      ((v2Key.getString(2),new Interval[Int](v2Key.getInt(0), v2Key.getInt(1))),
        x.copy())
    })
    
    if (v1.count <= v2.count) {
      val v3 = IntervalTreeJoinChromosomeImpl.overlapJoin(context.sparkContext, v1kv, v2kv)
        .flatMap(l => l._2
          .map(r => (l._1, r)))
      v3.mapPartitions(
        p => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
          p.map(r => joiner.join(r._1.asInstanceOf[UnsafeRow], r._2.asInstanceOf[UnsafeRow]))
        }
      )
    }
    else {
      val v3 = IntervalTreeJoinChromosomeImpl.overlapJoin(context.sparkContext, v2kv, v1kv).flatMap(l => l._2.map(r => (l._1, r)))
      v3.mapPartitions(
        p => {
          val joiner = GenerateUnsafeRowJoiner.create(right.schema, left.schema)
          p.map(r=>joiner.join(r._2.asInstanceOf[UnsafeRow],r._1.asInstanceOf[UnsafeRow]))
        }

      )
    }

  }
} 
Example 182
Source File: IntervalTreeJoin.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.genApp

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedProjection, UnsafeRow}
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}

@DeveloperApi
case class IntervalTreeJoin(left: SparkPlan,
                     right: SparkPlan,
                     condition: Seq[Expression],
                     context: SparkSession) extends BinaryExecNode {
  def output = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  lazy val (buildKeys, streamedKeys) = (List(condition(0), condition(1)),
    List(condition(2), condition(3)))

  @transient lazy val buildKeyGenerator = new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator = new InterpretedProjection(streamedKeys,
    streamedPlan.output)

  protected override def doExecute(): RDD[InternalRow] = {
    val v1 = left.execute()
    val v1kv = v1.map(x => {
      val v1Key = buildKeyGenerator(x)

      (new Interval[Int](v1Key.getInt(0), v1Key.getInt(1)),
        x.copy())
    })
    val v2 = right.execute()
    val v2kv = v2.map(x => {
      val v2Key = streamKeyGenerator(x)
      (new Interval[Int](v2Key.getInt(0), v2Key.getInt(1)),
        x.copy())
    })
    
    if (v1.count <= v2.count) {
      val v3 = IntervalTreeJoinImpl.overlapJoin(context.sparkContext, v1kv, v2kv)
        .flatMap(l => l._2
        .map(r => (l._1, r)))
      v3.map {
        case (l: InternalRow, r: InternalRow) => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema);
          joiner.join(l.asInstanceOf[UnsafeRow], r.asInstanceOf[UnsafeRow]).asInstanceOf[InternalRow] //resultProj(joinedRow(l, r)) joiner.joiner
        }
      }
    }
    else {
      val v3 = IntervalTreeJoinImpl.overlapJoin(context.sparkContext, v2kv, v1kv).flatMap(l => l._2.map(r => (l._1, r)))
      v3.map {
        case (r: InternalRow, l: InternalRow) => {
          val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema);
          joiner.join(l.asInstanceOf[UnsafeRow], r.asInstanceOf[UnsafeRow]).asInstanceOf[InternalRow] //resultProj(joinedRow(l, r)) joiner.joiner
        }
      }
    }

  }
} 
Example 183
Source File: IntervalTreeJoinStrategy.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.genApp

import org.biodatageeks.sequila.rangejoins.common.{ExtractRangeJoinKeys, ExtractRangeJoinKeysWithEquality}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{SparkSession, Strategy}
import org.biodatageeks.sequila.rangejoins.methods.genApp.IntervalTreeJoinChromosome

class IntervalTreeJoinStrategy(spark: SparkSession) extends Strategy with Serializable with  PredicateHelper {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case ExtractRangeJoinKeys(joinType, rangeJoinKeys, left, right) =>
      IntervalTreeJoin(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case ExtractRangeJoinKeysWithEquality(joinType, rangeJoinKeys, left, right) =>
      IntervalTreeJoinChromosome(planLater(left), planLater(right), rangeJoinKeys, spark) :: Nil
    case _ =>
      Nil
  }
} 
Example 184
Source File: JoinOptimizerChromosome.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.optimizer

import jdk.nashorn.internal.ir.debug.ObjectSizeCalculator
import org.apache.log4j.Logger
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.util.SizeEstimator
import org.biodatageeks.sequila.rangejoins.IntervalTree.{Interval, IntervalWithRow}
import org.biodatageeks.sequila.rangejoins.optimizer.RangeJoinMethod.RangeJoinMethod


class JoinOptimizerChromosome(spark: SparkSession, rdd: RDD[(String,Interval[Int],InternalRow)], rddCount : Long) {

  val logger =  Logger.getLogger(this.getClass.getCanonicalName)
  val maxBroadcastSize = spark.sqlContext
    .getConf("spark.biodatageeks.rangejoin.maxBroadcastSize","0") match {
    case "0" => 0.1*scala.math.max((spark.sparkContext.getConf.getSizeAsBytes("spark.driver.memory","0")),1024*(1024*1024)) //defaults 128MB or 0.1 * Spark Driver's memory
    case _ => spark.sqlContext.getConf("spark.biodatageeks.rangejoin.maxBroadcastSize").toLong }
  val estBroadcastSize = estimateBroadcastSize(rdd,rddCount)


   private def estimateBroadcastSize(rdd: RDD[(String,Interval[Int],InternalRow)], rddCount: Long): Long = {
     try{
       (ObjectSizeCalculator.getObjectSize(rdd.first()) * rddCount) /10
     }
     catch {
       case e @ (_ : NoClassDefFoundError | _ : ExceptionInInitializerError ) => {
         logger.warn("Method ObjectSizeCalculator.getObjectSize not available falling back to Spark methods")
         SizeEstimator.estimate(rdd.first()) * rddCount
       }
     }
     //FIXME: Not clear why the estimate is ~10x the actual size (Spark row representation or getObjectSize in bits?)
  }

  def debugInfo = {
    s"""
       |Broadcast structure size is ~ ${math.rint(100*estBroadcastSize/1024.0)/100} kb
       |spark.biodatageeks.rangejoin.maxBroadcastSize is set to ${(maxBroadcastSize/1024).toInt} kb
       |Using ${getRangeJoinMethod.toString} join method
     """.stripMargin
  }

  private def estimateRDDSizeSpark(rdd: RDD[(String,Interval[Int],InternalRow)]): Long = {
    math.round(SizeEstimator.estimate(rdd)/1024.0)
  }

  
  def getRangeJoinMethod : RangeJoinMethod ={

    if (estimateBroadcastSize(rdd, rddCount) <= maxBroadcastSize)
      RangeJoinMethod.JoinWithRowBroadcast
    else
      RangeJoinMethod.TwoPhaseJoin

  }



} 
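Since the optimizer reads its threshold from spark.biodatageeks.rangejoin.maxBroadcastSize, tuning it is just a conf change on the session; the 64 MB value below is only an example, not a recommendation from the project.

import org.apache.spark.sql.SparkSession

object RangeJoinTuning {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("range-join-tuning").getOrCreate()

    // Allow interval sets of up to ~64 MB (estimated) to take the broadcast path.
    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.maxBroadcastSize", (64 * 1024 * 1024).toString)

    println(spark.sqlContext.getConf("spark.biodatageeks.rangejoin.maxBroadcastSize"))
    spark.stop()
  }
}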
Example 185
Source File: PileupApp.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.apps

import java.io.{OutputStreamWriter, PrintWriter}

import org.apache.spark.sql.{SequilaSession, SparkSession}
import org.bdgenomics.utils.instrumentation.{Metrics, MetricsListener, RecordedMetrics}
import org.biodatageeks.sequila.utils.{InternalParams, SequilaRegister}

object PileupApp extends App{
  override def main(args: Array[String]): Unit = {

    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .config("spark.driver.memory","4g")
      .config( "spark.serializer", "org.apache.spark.serializer.KryoSerializer" )
      .enableHiveSupport()
      .getOrCreate()

    val ss = SequilaSession(spark)
    SequilaRegister.register(ss)
    spark.sparkContext.setLogLevel("INFO")

    val bamPath = "/Users/aga/NA12878.chr20.md.bam"
    val referencePath = "/Users/aga/Homo_sapiens_assembly18_chr20.fasta"

    //    val bamPath = "/Users/marek/data/NA12878.chrom20.ILLUMINA.bwa.CEU.low_coverage.20121211.md.bam"
    //    val referencePath = "/Users/marek/data/hs37d5.fa"

    val tableNameBAM = "reads"

    ss.sql(s"""DROP  TABLE IF  EXISTS $tableNameBAM""")
    ss.sql(s"""
              |CREATE TABLE $tableNameBAM
              |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
              |OPTIONS(path "$bamPath")
              |
      """.stripMargin)

    val query =
      s"""
         |SELECT count(*)
         |FROM  pileup('$tableNameBAM', 'NA12878', '${referencePath}')
       """.stripMargin
    ss
      .sqlContext
      .setConf(InternalParams.EnableInstrumentation, "true")
    Metrics.initialize(ss.sparkContext)
    val metricsListener = new MetricsListener(new RecordedMetrics())
    ss
      .sparkContext
      .addSparkListener(metricsListener)
    val results = ss.sql(query)
    ss.time{
      results.show()
    }
    val writer = new PrintWriter(new OutputStreamWriter(System.out, "UTF-8"))
    Metrics.print(writer, Some(metricsListener.metrics.sparkMetrics.stageTimes))
    writer.close()
    ss.stop()
  }
} 
Example 186
Source File: FeatureCounts.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.apps

import htsjdk.samtools.ValidationStringency
import org.apache.hadoop.io.LongWritable
import org.apache.spark.sql.SparkSession
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
import org.biodatageeks.sequila.utils.Columns
import org.rogach.scallop.ScallopConf
import org.seqdoop.hadoop_bam.{BAMInputFormat, SAMRecordWritable}
import org.seqdoop.hadoop_bam.util.SAMHeaderReader

object FeatureCounts {
  case class Region(contig:String, pos_start:Int, pos_end:Int)
  class RunConf(args:Array[String]) extends ScallopConf(args){

    val output = opt[String](required = true)
    val annotations = opt[String](required = true)
    val readsFile = trailArg[String](required = true)
    val Format = trailArg[String](required = false)
    verify()
  }

  def main(args: Array[String]): Unit = {
    val runConf = new RunConf(args)
    val spark = SparkSession
      .builder()
      .appName("SeQuiLa-FC")
      .getOrCreate()

    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.useJoinOrder","true")
    //spark.sqlContext.setConf("spark.biodatageeks.rangejoin.maxBroadcastSize", (1024).toString)
    spark.experimental.extraStrategies = new IntervalTreeJoinStrategyOptim(spark) :: Nil



    val query ="""SELECT targets.GeneId AS GeneId,
                     targets.Chr AS Chr,
                     targets.Start AS Start,
                     targets.End AS End,
                     targets.Strand AS Strand,
                     CAST(targets.End AS INTEGER)-CAST(targets.Start AS INTEGER) + 1 AS Length,
                     count(*) AS Counts
            FROM reads JOIN targets
      |ON (
      |  targets.Chr=reads.contigName
      |  AND
      |  reads.end >= CAST(targets.Start AS INTEGER)
      |  AND
      |  reads.start <= CAST(targets.End AS INTEGER)
      |)
      |GROUP BY targets.GeneId,targets.Chr,targets.Start,targets.End,targets.Strand""".stripMargin
      spark
        .sparkContext
        .setLogLevel("ERROR")

      spark
        .sparkContext
        .hadoopConfiguration.set(SAMHeaderReader.VALIDATION_STRINGENCY_PROPERTY, ValidationStringency.SILENT.toString)

      val alignments = spark
        .sparkContext.newAPIHadoopFile[LongWritable, SAMRecordWritable, BAMInputFormat](runConf.readsFile())
        .map(_._2.get)
        .map(r => Region(r.getContig, r.getStart, r.getEnd))

      val readsTable = spark.sqlContext.createDataFrame(alignments)
      readsTable.createOrReplaceTempView("reads")

      val targets = spark
        .read
        .option("header", "true")
        .option("delimiter", "\t")
        .csv(runConf.annotations())
      targets
        .withColumnRenamed("contigName", Columns.CONTIG)
        .createOrReplaceTempView("targets")

      spark.sql(query)
        .orderBy("GeneId")
        .coalesce(1)
        .write
        .option("header", "true")
        .option("delimiter", "\t")
        .csv(runConf.output())
  }

} 
Example 187
Source File: PileupTestBase.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.pileup

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, ShortType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

class PileupTestBase extends FunSuite
  with DataFrameSuiteBase
  with BeforeAndAfter
  with SharedSparkContext{

  val sampleId = "NA12878.multichrom.md"
  val samResPath: String = getClass.getResource("/multichrom/mdbam/samtools.pileup").getPath
  val referencePath: String = getClass.getResource("/reference/Homo_sapiens_assembly18_chr1_chrM.small.fasta").getPath
  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
  val cramPath : String = getClass.getResource(s"/multichrom/mdcram/${sampleId}.cram").getPath
  val tableName = "reads_bam"
  val tableNameCRAM = "reads_cram"

  val schema: StructType = StructType(
    List(
      StructField("contig", StringType, nullable = true),
      StructField("position", IntegerType, nullable = true),
      StructField("reference", StringType, nullable = true),
      StructField("coverage", ShortType, nullable = true),
      StructField("pileup", StringType, nullable = true),
      StructField("quality", StringType, nullable = true)
    )
  )
  before {
    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
    spark
      .conf.set("spark.sql.shuffle.partitions",1) //FIXME: In order to get orderBy in Samtools tests working - related to exchange partitions stage
    spark.sql(s"DROP TABLE IF EXISTS $tableName")
    spark.sql(
      s"""
         |CREATE TABLE $tableName
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

    spark.sql(s"DROP TABLE IF EXISTS $tableNameCRAM")
    spark.sql(
      s"""
         |CREATE TABLE $tableNameCRAM
         |USING org.biodatageeks.sequila.datasources.BAM.CRAMDataSource
         |OPTIONS(path "$cramPath", refPath "$referencePath" )
         |
      """.stripMargin)

    val mapToString = (map: Map[Byte, Short]) => {
      if (map == null)
        "null"
      else
        map.map({
          case (k, v) => k.toChar -> v}).mkString.replace(" -> ", ":")
    }

    val byteToString = ((byte: Byte) => byte.toString)

    spark.udf.register("mapToString", mapToString)
    spark.udf.register("byteToString", byteToString)
  }

} 
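A minimal sketch, not part of the project's test suite, of how a test could build on PileupTestBase: the before block above already creates the BAM-backed table, so a subclass only needs to register SeQuiLa and query it (the class name is hypothetical).

package org.biodatageeks.sequila.tests.pileup

import org.apache.spark.sql.SequilaSession
import org.biodatageeks.sequila.utils.SequilaRegister

class PileupSmokeTestSketch extends PileupTestBase {

  test("BAM-backed table created in before() is queryable") {
    val ss = SequilaSession(spark)
    SequilaRegister.register(ss)
    // `reads_bam` was registered by PileupTestBase.before using the bundled BAM file
    val reads = ss.sql(s"SELECT * FROM $tableName")
    assert(reads.count() > 0)
  }
}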
Example 188
Source File: Writer.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.pileup

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}

object Writer {

  val mapToString = (map: Map[Byte, Short]) => {
    if (map == null)
      "null"
    else
      map.map({
        case (k, v) => k.toChar -> v
      }).toSeq.sortBy(_._1).mkString.replace(" -> ", ":")
  }

  def saveToFile(spark: SparkSession, res: Dataset[Row], path: String) = {
    spark.udf.register("mapToString", mapToString)
    res
      .selectExpr("contig", "pos_start", "pos_end", "ref", "cast(coverage as int)", "mapToString(alts)")
      .coalesce(1)
      .write
      .mode(SaveMode.Overwrite)
      .csv(path)
  }
} 
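A small usage sketch for the helper above, assuming a hand-built Dataset whose columns match what saveToFile selects (the sample data and output path are made up).

package org.biodatageeks.sequila.tests.pileup

import org.apache.spark.sql.SparkSession

object WriterUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("WriterUsageSketch")
      .getOrCreate()
    import spark.implicits._

    // A tiny stand-in for a pileup result with the columns saveToFile expects
    val pileupDF = Seq(
      ("chr1", 100, 110, "A", 12.toShort, Map('C'.toByte -> 3.toShort))
    ).toDF("contig", "pos_start", "pos_end", "ref", "coverage", "alts")

    Writer.saveToFile(spark, pileupDF, "/tmp/pileup_csv")
    spark.stop()
  }
}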
Example 189
Source File: LongReadsTestSuite.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.coverage

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{SequilaSession, SparkSession}
import org.biodatageeks.sequila.utils.{Columns, InternalParams, SequilaRegister}
import org.scalatest.{BeforeAndAfter, FunSuite}

class LongReadsTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val bamPath: String =
    getClass.getResource("/nanopore_guppy_slice.bam").getPath
  val splitSize = 30000
  val tableNameBAM = "reads"

  before {

    System.setSecurityManager(null)
    spark.sql(s"DROP TABLE IF EXISTS $tableNameBAM")
    spark.sql(s"""
         |CREATE TABLE $tableNameBAM
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

  }
  test("BAM - Nanopore with guppy basecaller") {

    val session: SparkSession = SequilaSession(spark)
    SequilaRegister.register(session)
    session.sparkContext
      .setLogLevel("WARN")
    val bdg = session.sql(s"SELECT * FROM ${tableNameBAM}")
    assert(bdg.count() === 150)
  }

  test("BAM - coverage - Nanopore with guppy basecaller") {
    spark.sqlContext.setConf(InternalParams.InputSplitSize,
                             (splitSize * 10).toString)
    val session2: SparkSession = SequilaSession(spark)
    SequilaRegister.register(session2)
    val query =
      s"""SELECT ${Columns.CONTIG}, ${Columns.START}, ${Columns.COVERAGE}
        FROM bdg_coverage('$tableNameBAM','nanopore_guppy_slice','bases')
        order by ${Columns.CONTIG},${Columns.START},${Columns.END}
        """.stripMargin
    val covMultiPartitionDF = session2.sql(query)

    //covMultiPartitionDF.coalesce(1).write.mode("overwrite").option("delimiter", "\t").csv("/Users/aga/workplace/multiPart.csv")
    assert(covMultiPartitionDF.count() == 45620) // total count check 45620<---> 45842

    assert(covMultiPartitionDF.filter(s"${Columns.COVERAGE}== 0").count == 0)

    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5010515")
        .first()
        .getShort(2) == 1) // value check [first element]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5022667")
        .first()
        .getShort(2) == 15) // value check [partition boundary]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5036398")
        .first()
        .getShort(2) == 14) // value check [partition boundary]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5056356")
        .first()
        .getShort(2) == 1) // value check [last element]

  }

} 
Example 190
Source File: 5_DataFrameAndSql.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.spark_streaming

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}

// The body of the streaming driver object was elided in this listing; a minimal
// reconstructed sketch of the usual pattern follows after SparkSessionSingleton below.

object SparkSessionSingleton {
  @transient private var instance: SparkSession = _

  def getInstance(conf: SparkConf): SparkSession = {
    if (instance == null) {
      instance = SparkSession
        .builder()
        .config(conf)
        .getOrCreate()
    }
    instance
  }
} 
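A minimal reconstructed sketch of the driver that was elided above, assuming the usual socket-DStream pattern: each RDD is converted to a DataFrame via SparkSessionSingleton and queried with SQL (host, port, batch interval and object name are assumptions).

package com.sev7e0.wow.spark_streaming

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}

object DataFrameAndSqlSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("DataFrameAndSqlSketch")
    val ssc = new StreamingContext(conf, Seconds(5))

    val lines = ssc.socketTextStream("localhost", 9999, StorageLevel.MEMORY_AND_DISK_SER)
    val words = lines.flatMap(_.split(" "))

    words.foreachRDD { rdd =>
      // Reuse a single SparkSession built from the RDD's SparkConf
      val spark = SparkSessionSingleton.getInstance(rdd.sparkContext.getConf)
      import spark.implicits._

      val wordsDF = rdd.toDF("word")
      wordsDF.createOrReplaceTempView("words")
      spark.sql("SELECT word, count(*) AS total FROM words GROUP BY word").show()
    }

    ssc.start()
    ssc.awaitTermination()
  }
}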
Example 191
Source File: A_1_WindowOperation.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.structured_streaming

import java.sql.Timestamp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.OutputMode

object A_1_WindowOperation {

  def main(args: Array[String]): Unit = {

    if (args.length < 3) {
      println(s" Usage: StructuredNetworkWordCountWindowed <hostname> <port>" +
        " <window duration in seconds> [<slide duration in seconds>]")
      System.exit(1)
    }

    val host = args(0)
    val port = args(1).toInt
    val windowSize = args(2).toInt
    val slideSize = if (args.length == 3) windowSize else args(3).toInt
    if (slideSize > windowSize) {
      System.err.println("<slide duration> must be less than or equal to <window duration>")
      System.exit(1)
    }

    val windowDuration = s"$windowSize seconds"
    val slideDuration = s"$slideSize seconds"

    val spark = SparkSession.builder()
      .master("local")
      .appName(A_1_WindowOperation.getClass.getName)
      .getOrCreate()
    val lines = spark.readStream
      .format("socket")
      .option("host", host)
      .option("port", port)
      .load()
    import spark.implicits._

    val words = lines.as[(String, Timestamp)]
      .flatMap(line => line._1.split(" ").map(word => (word, line._2))).toDF()

    val windowCount = words.groupBy(
      window($"timestamp", windowDuration, slideDuration)
      , $"word").count().orderBy("window")

    val query = windowCount.writeStream
      .outputMode(OutputMode.Complete())
      .format("console")
      .option("truncate", "false")
      .start()

    query.awaitTermination()


  }
} 
Example 192
Source File: A_5_StreamingWordWcount.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.structured_streaming

import com.sev7e0.wow.spark_streaming.StreamingLogger
import org.apache.spark.sql.SparkSession

object A_5_StreamingWordWcount {

  val MASTER = "local"
  val HOST = "localhost"
  val PORT = 9999

  
  def main(args: Array[String]): Unit = {

    // Create the SparkSession
    val spark = SparkSession.builder()
      .appName(A_5_StreamingWordWcount.getClass.getName)
      .master(MASTER)
      .getOrCreate()

    StreamingLogger.setLoggerLevel()

    // Input table: lines read from the socket source
    val line = spark
      .readStream
      .format("socket")
      .option("host", HOST)
      .option("port", PORT)
      .load()

    // Print the schema
    line.printSchema()

    // Implicitly convert the DataFrame to a Dataset
    import spark.implicits._
    val word = line.as[String].flatMap(_.split(" "))

    // Operate on the stream: count words by value
    val count = word.groupBy("value").count()

    val query = count.writeStream
      .outputMode(outputMode = "complete")
      .format("console")
      .start()

    query.awaitTermination()
  }

} 
Example 193
Source File: A_1_BasicOperation.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.structured_streaming

import java.sql.Timestamp

import org.apache.spark.sql.types.{BooleanType, StringType, StructType, TimestampType}
import org.apache.spark.sql.{Dataset, SparkSession}

object A_1_BasicOperation {

  // Date/time fields must use java.sql.Timestamp in case classes; Catalyst represents them as TimestampType
  case class DeviceData(device: String, deviceType: String, signal: Double, time: Timestamp)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(A_1_BasicOperation.getClass.getName)
      .master("local")
      .getOrCreate()
    // NOTE: this schema is declared but not used below; the case class models `signal` as Double
    val timeStructType = new StructType().add("device", StringType)
      .add("deviceType", StringType)
      .add("signal", BooleanType)
      .add("time", TimestampType)

    val dataFrame = spark.read.json("src/main/resources/sparkresource/device.json")
    import spark.implicits._
    val ds: Dataset[DeviceData] = dataFrame.as[DeviceData]

    // Untyped, SQL-like query
    dataFrame.select("device").where("signal>10").show()
    // Typed query on the Dataset
    ds.filter(_.signal > 10).map(_.device).show()

    // Untyped groupBy with a count per device type
    dataFrame.groupBy("deviceType").count().show()


    import org.apache.spark.sql.expressions.scalalang.typed
    // Typed aggregation: average signal per device type
    ds.groupByKey(_.deviceType).agg(typed.avg(_.signal)).show()

    // Alternatively, register a temp view and query it with SQL
    dataFrame.createOrReplaceTempView("device")
    spark.sql("select * from device").show()

    // isStreaming indicates whether this DataFrame is backed by streaming data
    println(dataFrame.isStreaming)
  }
} 
Example 194
Source File: A_3_StreamDeduplication.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.structured_streaming

import org.apache.spark.sql.SparkSession

object A_3_StreamDeduplication {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder()
      .master("local")
      .appName("StreamDeduplication")
      .getOrCreate()
    // NOTE: a source format would need to be configured here (e.g. .format("rate")) for this to run
    val streamDF = session.readStream.load()

    // Deduplicate on the guid column without a watermark (state grows without bound)
    streamDF.dropDuplicates("guid")

    // With a watermark on eventTime, expired deduplication state can be dropped
    streamDF.withWatermark("eventTime", "2 seconds")
      .dropDuplicates("guid")

  }
} 
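Since the snippet above never configures a source or starts a query, here is a runnable sketch of watermark-based deduplication end to end; the rate source, derived columns and object name are assumptions for illustration.

package com.sev7e0.wow.structured_streaming

import org.apache.spark.sql.SparkSession

object A_3_StreamDeduplicationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("StreamDeduplicationSketch")
      .getOrCreate()
    import spark.implicits._

    // The built-in rate source emits (timestamp, value); derive a repeating guid
    // so that duplicates actually occur.
    val events = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .withColumnRenamed("timestamp", "eventTime")
      .withColumn("guid", ($"value" % 5).cast("string"))

    // Keeping the watermark column in the key lets Spark drop old deduplication state
    val deduplicated = events
      .withWatermark("eventTime", "2 seconds")
      .dropDuplicates("guid", "eventTime")

    deduplicated.writeStream
      .format("console")
      .option("truncate", "false")
      .start()
      .awaitTermination()
  }
}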
Example 195
Source File: A_6_ContinuousProcessing.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.structured_streaming

import org.apache.spark.sql.SparkSession

object A_6_ContinuousProcessing {

  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder()
      .appName("ContinuousProcessing")
      .master("local")
      .getOrCreate()

    import org.apache.spark.sql.streaming.Trigger
    session.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
      .option("subscribe", "topic1")
      .load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.server", "host1:port1,host2:port2")
      .option("subscribe", "outPutTopic")
      .trigger(Trigger.Continuous("1 second"))
      .start()

    
  }

} 
Example 196
Source File: A_1_DataFrameTest.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}


// The enclosing object, main method and SparkSession setup were elided in this
// listing; the wrapper below is a minimal reconstruction so that the snippet compiles.
object A_1_DataFrameTest {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .appName("A_1_DataFrameTest")
      .master("local")
      .getOrCreate()

    // Define a schema from a space-separated list of field names
    val schemaString = "name age"
    val fields = schemaString.split(" ")
      .map(filedName => StructField(filedName, StringType, nullable = true))
    val structType = StructType(fields)

    val personRDD = sparkSession.sparkContext.textFile("src/main/resources/sparkresource/people.txt")
      .map(_.split(","))
      // Convert each split line into a Row
      .map(attr => Row(attr(0), attr(1).trim))
    // Apply the schema to the RDD and create a DataFrame
    sparkSession.createDataFrame(personRDD,structType).createOrReplaceTempView("people1")
    val dataFrameBySchema = sparkSession.sql("select name,age from people1 where age > 19 ")
    dataFrameBySchema.show()
  }

} 
Example 197
Source File: A_2_DataSetTest.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.SparkSession



object A_2_DataSetTest {
  case class Person( name: String, age: Int)
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().appName("DataSetTest")
      .master("local")
      .getOrCreate()
    import sparkSession.implicits._
    val dataSet= Seq(Person("lee",18)).toDS()
    dataSet.show()
    val frame = Seq(1,5,7).toDS()
    val inns = frame.map(_ + 1).collect()
    inns.foreach(print(_))
  }

} 
Example 198
Source File: A_9_MyAverageByAggregator.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator



case class Employee(name:String, salary:Long)
case class Average(var sum:Long, var count:Long)
object A_9_MyAverageByAggregator extends Aggregator[Employee, Average, Double]{
  override def zero: Average = Average(0L,0L)

  override def reduce(b: Average, a: Employee): Average = {
    b.sum += a.salary
    b.count+=1
    b
  }

  override def merge(b1: Average, b2: Average): Average = {
    b1.count+=b2.count
    b1.sum+=b2.sum
    b1
  }

  override def finish(reduction: Average): Double = reduction.sum.toDouble/reduction.count

  override def bufferEncoder: Encoder[Average] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().master("local").appName("MyAverageByAggregator")
      .getOrCreate()
    // Implicit conversions for Dataset encoders
    import sparkSession.implicits._
    val dataFrame = sparkSession.read.json("src/main/resources/sparkresource/employees.json").as[Employee]
    dataFrame.show()

    val salary_average = A_9_MyAverageByAggregator.toColumn.name("salary_average")

    val frame = dataFrame.select(salary_average)
    frame.show()
  }
} 
Example 199
Source File: A_8_MyAverage.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


object A_8_MyAverage extends UserDefinedAggregateFunction{
  override def inputSchema: StructType = StructType(StructField("inputColumn",LongType)::Nil)

  override def bufferSchema: StructType = {
    StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
  }

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)){
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1

    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().appName("A_8_MyAverage")
      .master("local")
      .getOrCreate()
    sparkSession.udf.register("A_8_MyAverage",A_8_MyAverage)

    val dataFrame = sparkSession.read.json("src/main/resources/sparkresource/employees.json")
    dataFrame.createOrReplaceTempView("employees")

    val result = sparkSession.sql("select A_8_MyAverage(salary) as average_salary from employees")
    result.show()
  }
} 
Example 200
Source File: A_0_LoadSaveFunction.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.SparkSession


    val orcDF = session.read.format("orc").load("/Users/sev7e0/workspace/idea-workspace/sparklearn/src/main/resources/spark resource")
    orcDF.show()

//    session.sqlContext
//    users.write.format("orc")
//      .option("orc.bloom.filter.columns","favorite_color")
//      .option("orc.dictionary.key.threshold","1.0")
//      .save("users_with_option.orc")

    // A format name can be specified explicitly to convert from any data source
    val peopleDS = session.read.json("src/main/resources/sparkresource/people.json")
    peopleDS.select("name","age").write.mode(saveMode = "overwrite").format("parquet").save("nameAndAge.parquet")

    // Convert from one format to another
    val dataFrameCSV = session.read.option("sep",";").option("inferSchema","true")
      .option("header","true").csv("src/main/resources/sparkresource/people.csv")
    dataFrameCSV.select("name","age").write.mode(saveMode = "overwrite").format("json").save("csvToJson.json")

    // Run SQL directly on a file
    val fromParquetFile = session.sql("SELECT * FROM parquet.`src/main/resources/sparkresource/users.parquet`")
//    fromParquetFile.show()

    // Output files can be partitioned by a key column when writing
    users.write.partitionBy("favorite_color").format("parquet").mode(saveMode = "overwrite")
      .save("namesPartitionByColor.parquet")
  }

}